You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							246 lines
						
					
					
						
							6.3 KiB
						
					
					
				
			
		
		
	
	
							246 lines
						
					
					
						
							6.3 KiB
						
					
					
				| // Copyright 2015 PingCAP, Inc.
 | |
| //
 | |
| // Licensed under the Apache License, Version 2.0 (the "License");
 | |
| // you may not use this file except in compliance with the License.
 | |
| // You may obtain a copy of the License at
 | |
| //
 | |
| //     http://www.apache.org/licenses/LICENSE-2.0
 | |
| //
 | |
| // Unless required by applicable law or agreed to in writing, software
 | |
| // distributed under the License is distributed on an "AS IS" BASIS,
 | |
| // See the License for the specific language governing permissions and
 | |
| // limitations under the License.
 | |
| 
 | |
| package optimizer
 | |
| 
 | |
| import (
 | |
| 	"github.com/juju/errors"
 | |
| 	"github.com/pingcap/tidb/ast"
 | |
| 	"github.com/pingcap/tidb/mysql"
 | |
| 	"github.com/pingcap/tidb/parser"
 | |
| 	"github.com/pingcap/tidb/parser/opcode"
 | |
| )
 | |
| 
 | |
| // Validate checkes whether the node is valid.
 | |
| func Validate(node ast.Node, inPrepare bool) error {
 | |
| 	v := validator{inPrepare: inPrepare}
 | |
| 	node.Accept(&v)
 | |
| 	return v.err
 | |
| }
 | |
| 
 | |
// validator is an ast.Visitor that validates
// ast Nodes parsed from parser.
type validator struct {
	// inPrepare is set by Validate; when false, Leave reports a syntax
	// error for any parameter marker ('?') it visits.
	inPrepare bool

	// inAggregate is true while the visitor is between Enter and Leave of
	// an aggregate function expression; it is used to reject nested
	// aggregate functions.
	inAggregate bool

	// wildCardCount is not referenced by the validation code in this part
	// of the file.
	wildCardCount int

	// err holds the validation failure recorded during the walk; it stays
	// nil when no check has failed.
	err error
}
 | |
| 
 | |
| func (v *validator) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
 | |
| 	switch in.(type) {
 | |
| 	case *ast.AggregateFuncExpr:
 | |
| 		if v.inAggregate {
 | |
| 			// Aggregate function can not contain aggregate function.
 | |
| 			v.err = ErrInvalidGroupFuncUse
 | |
| 			return in, true
 | |
| 		}
 | |
| 		v.inAggregate = true
 | |
| 	}
 | |
| 	return in, false
 | |
| }
 | |
| 
 | |
| func (v *validator) Leave(in ast.Node) (out ast.Node, ok bool) {
 | |
| 	switch x := in.(type) {
 | |
| 	case *ast.AggregateFuncExpr:
 | |
| 		v.inAggregate = false
 | |
| 	case *ast.BetweenExpr:
 | |
| 		v.checkAllOneColumn(x.Expr, x.Left, x.Right)
 | |
| 	case *ast.BinaryOperationExpr:
 | |
| 		v.checkBinaryOperation(x)
 | |
| 	case *ast.ByItem:
 | |
| 		v.checkAllOneColumn(x.Expr)
 | |
| 	case *ast.CreateTableStmt:
 | |
| 		v.checkAutoIncrement(x)
 | |
| 	case *ast.CompareSubqueryExpr:
 | |
| 		v.checkSameColumns(x.L, x.R)
 | |
| 	case *ast.FieldList:
 | |
| 		v.checkFieldList(x)
 | |
| 	case *ast.HavingClause:
 | |
| 		v.checkAllOneColumn(x.Expr)
 | |
| 	case *ast.IsNullExpr:
 | |
| 		v.checkAllOneColumn(x.Expr)
 | |
| 	case *ast.IsTruthExpr:
 | |
| 		v.checkAllOneColumn(x.Expr)
 | |
| 	case *ast.ParamMarkerExpr:
 | |
| 		if !v.inPrepare {
 | |
| 			v.err = parser.ErrSyntax.Gen("syntax error, unexpected '?'")
 | |
| 		}
 | |
| 	case *ast.PatternInExpr:
 | |
| 		v.checkSameColumns(append(x.List, x.Expr)...)
 | |
| 	}
 | |
| 
 | |
| 	return in, v.err == nil
 | |
| }
 | |
| 
 | |
| // checkAllOneColumn checks that all expressions have one column.
 | |
| // Expression may have more than one column when it is a rowExpr or
 | |
| // a Subquery with more than one result fields.
 | |
| func (v *validator) checkAllOneColumn(exprs ...ast.ExprNode) {
 | |
| 	for _, expr := range exprs {
 | |
| 		switch x := expr.(type) {
 | |
| 		case *ast.RowExpr:
 | |
| 			v.err = ErrOneColumn
 | |
| 		case *ast.SubqueryExpr:
 | |
| 			if len(x.Query.GetResultFields()) != 1 {
 | |
| 				v.err = ErrOneColumn
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 	return
 | |
| }
 | |
| 
 | |
| func checkAutoIncrementOp(colDef *ast.ColumnDef, num int) (bool, error) {
 | |
| 	var hasAutoIncrement bool
 | |
| 
 | |
| 	if colDef.Options[num].Tp == ast.ColumnOptionAutoIncrement {
 | |
| 		hasAutoIncrement = true
 | |
| 		if len(colDef.Options) == num+1 {
 | |
| 			return hasAutoIncrement, nil
 | |
| 		}
 | |
| 		for _, op := range colDef.Options[num+1:] {
 | |
| 			if op.Tp == ast.ColumnOptionDefaultValue {
 | |
| 				return hasAutoIncrement, errors.Errorf("Invalid default value for '%s'", colDef.Name.Name.O)
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 	if colDef.Options[num].Tp == ast.ColumnOptionDefaultValue && len(colDef.Options) != num+1 {
 | |
| 		for _, op := range colDef.Options[num+1:] {
 | |
| 			if op.Tp == ast.ColumnOptionAutoIncrement {
 | |
| 				return hasAutoIncrement, errors.Errorf("Invalid default value for '%s'", colDef.Name.Name.O)
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return hasAutoIncrement, nil
 | |
| }
 | |
| 
 | |
| func isConstraintKeyTp(constraints []*ast.Constraint, colDef *ast.ColumnDef) bool {
 | |
| 	for _, c := range constraints {
 | |
| 		if len(c.Keys) < 1 {
 | |
| 		}
 | |
| 		// If the constraint as follows: primary key(c1, c2)
 | |
| 		// we only support c1 column can be auto_increment.
 | |
| 		if colDef.Name.Name.L != c.Keys[0].Column.Name.L {
 | |
| 			continue
 | |
| 		}
 | |
| 		switch c.Tp {
 | |
| 		case ast.ConstraintPrimaryKey, ast.ConstraintKey, ast.ConstraintIndex,
 | |
| 			ast.ConstraintUniq, ast.ConstraintUniqIndex, ast.ConstraintUniqKey:
 | |
| 			return true
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return false
 | |
| }
 | |
| 
 | |
| func (v *validator) checkAutoIncrement(stmt *ast.CreateTableStmt) {
 | |
| 	var (
 | |
| 		isKey            bool
 | |
| 		count            int
 | |
| 		autoIncrementCol *ast.ColumnDef
 | |
| 	)
 | |
| 
 | |
| 	for _, colDef := range stmt.Cols {
 | |
| 		var hasAutoIncrement bool
 | |
| 		for i, op := range colDef.Options {
 | |
| 			ok, err := checkAutoIncrementOp(colDef, i)
 | |
| 			if err != nil {
 | |
| 				v.err = err
 | |
| 				return
 | |
| 			}
 | |
| 			if ok {
 | |
| 				hasAutoIncrement = true
 | |
| 			}
 | |
| 			switch op.Tp {
 | |
| 			case ast.ColumnOptionPrimaryKey, ast.ColumnOptionUniqKey, ast.ColumnOptionUniqIndex,
 | |
| 				ast.ColumnOptionUniq, ast.ColumnOptionKey, ast.ColumnOptionIndex:
 | |
| 				isKey = true
 | |
| 			}
 | |
| 		}
 | |
| 		if hasAutoIncrement {
 | |
| 			count++
 | |
| 			autoIncrementCol = colDef
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if count < 1 {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	if !isKey {
 | |
| 		isKey = isConstraintKeyTp(stmt.Constraints, autoIncrementCol)
 | |
| 	}
 | |
| 	if !isKey || count > 1 {
 | |
| 		v.err = errors.New("Incorrect table definition; there can be only one auto column and it must be defined as a key")
 | |
| 	}
 | |
| 
 | |
| 	switch autoIncrementCol.Tp.Tp {
 | |
| 	case mysql.TypeTiny, mysql.TypeShort, mysql.TypeLong,
 | |
| 		mysql.TypeFloat, mysql.TypeDouble, mysql.TypeLonglong, mysql.TypeInt24:
 | |
| 	default:
 | |
| 		v.err = errors.Errorf("Incorrect column specifier for column '%s'", autoIncrementCol.Name.Name.O)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (v *validator) checkBinaryOperation(x *ast.BinaryOperationExpr) {
 | |
| 	// row constructor only supports comparison operation.
 | |
| 	switch x.Op {
 | |
| 	case opcode.LT, opcode.LE, opcode.GE, opcode.GT, opcode.EQ, opcode.NE, opcode.NullEQ:
 | |
| 		v.checkSameColumns(x.L, x.R)
 | |
| 	default:
 | |
| 		v.checkAllOneColumn(x.L, x.R)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func columnCount(ex ast.ExprNode) int {
 | |
| 	switch x := ex.(type) {
 | |
| 	case *ast.RowExpr:
 | |
| 		return len(x.Values)
 | |
| 	case *ast.SubqueryExpr:
 | |
| 		return len(x.Query.GetResultFields())
 | |
| 	default:
 | |
| 		return 1
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (v *validator) checkSameColumns(exprs ...ast.ExprNode) {
 | |
| 	if len(exprs) == 0 {
 | |
| 		return
 | |
| 	}
 | |
| 	count := columnCount(exprs[0])
 | |
| 	for i := 1; i < len(exprs); i++ {
 | |
| 		if columnCount(exprs[i]) != count {
 | |
| 			v.err = ErrSameColumns
 | |
| 			return
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // checkFieldList checks if there is only one '*' and each field has only one column.
 | |
| func (v *validator) checkFieldList(x *ast.FieldList) {
 | |
| 	var hasWildCard bool
 | |
| 	for _, val := range x.Fields {
 | |
| 		if val.WildCard != nil && val.WildCard.Table.L == "" {
 | |
| 			if hasWildCard {
 | |
| 				v.err = ErrMultiWildCard
 | |
| 				return
 | |
| 			}
 | |
| 			hasWildCard = true
 | |
| 		}
 | |
| 		v.checkAllOneColumn(val.Expr)
 | |
| 		if v.err != nil {
 | |
| 			return
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 |