diff --git a/go/mysql/client.go b/go/mysql/client.go index cba9db8f7c5..242f798fcf5 100644 --- a/go/mysql/client.go +++ b/go/mysql/client.go @@ -242,7 +242,7 @@ func setCollationForConnection(c *Conn, params *ConnParams) error { // getHandshakeCharacterSet returns the collation ID of DefaultCollation in an // 8 bits integer which will be used to feed the handshake protocol's packet. func getHandshakeCharacterSet() (uint8, error) { - coll := collations.Default().LookupByName(DefaultCollation) + coll := collations.Local().LookupByName(DefaultCollation) if coll == nil { // theoretically, this should never happen from an end user perspective return 0, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "cannot resolve collation ID for collation: '%s'", DefaultCollation) diff --git a/go/mysql/collations/8bit.go b/go/mysql/collations/8bit.go index f84a5adfdc4..07123af1147 100644 --- a/go/mysql/collations/8bit.go +++ b/go/mysql/collations/8bit.go @@ -224,12 +224,14 @@ func weightStringPadingSimple(padChar byte, dst []byte, numCodepoints int, padTo return dst } +const CollationBinaryID ID = 63 + type Collation_binary struct{} func (c *Collation_binary) Init() {} func (c *Collation_binary) ID() ID { - return 63 + return CollationBinaryID } func (c *Collation_binary) Name() string { diff --git a/go/mysql/collations/cached_size.go b/go/mysql/collations/cached_size.go new file mode 100644 index 00000000000..41b04328bf1 --- /dev/null +++ b/go/mysql/collations/cached_size.go @@ -0,0 +1,71 @@ +/* +Copyright 2021 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +// Code generated by Sizegen. DO NOT EDIT. + +package collations + +import hack "vitess.io/vitess/go/hack" + +type cachedObject interface { + CachedSize(alloc bool) int64 +} + +func (cached *eightbitWildcard) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(32) + } + // field pattern []int16 + { + size += hack.RuntimeAllocSize(int64(cap(cached.pattern)) * int64(2)) + } + return size +} +func (cached *fastMatcher) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(48) + } + // field pattern []byte + { + size += hack.RuntimeAllocSize(int64(cap(cached.pattern))) + } + return size +} +func (cached *unicodeWildcard) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(48) + } + // field charset vitess.io/vitess/go/mysql/collations/internal/charset/types.Charset + if cc, ok := cached.charset.(cachedObject); ok { + size += cc.CachedSize(true) + } + // field pattern []rune + { + size += hack.RuntimeAllocSize(int64(cap(cached.pattern)) * int64(4)) + } + return size +} diff --git a/go/mysql/collations/coercion.go b/go/mysql/collations/coercion.go index 36bf7e714f6..227389a52ba 100644 --- a/go/mysql/collations/coercion.go +++ b/go/mysql/collations/coercion.go @@ -24,8 +24,8 @@ import ( ) func init() { - if unsafe.Sizeof(TypedCollationID{}) != 4 { - panic("TypedCollationID should fit in an int32") + if unsafe.Sizeof(TypedCollation{}) != 4 { + panic("TypedCollation should fit in an int32") } } @@ -95,82 +95,66 @@ const ( RepertoireUnicode ) -// Coercion is a function that will transform either the left or right -// arguments of the function into the same character set. The `dst` argument +// Coercion is a function that will transform either the given argument +// arguments of the function into a specific character set. The `dst` argument // will be used as the destination of the coerced argument, but it can be nil. -// The function returns the given left and right arguments: one of the arguments -// will be the same value that was passed in, while the other will be the -// same value but transcoded into a different character set, depending on which -// of the arguments is supposed to be coerced. -// If the contents of the argument that must be transcoded cannot be mapped -// to the target charset, an error will be returned. -type Coercion func(dst, left, right []byte) ([]byte, []byte, error) +type Coercion func(dst, in []byte) ([]byte, error) // TypedCollation is the Collation of a SQL expression, including its coercibility // and repertoire. type TypedCollation struct { - Collation Collation - Coercibility Coercibility - Repertoire Repertoire -} - -// TypedCollationID is like TypedCollation but the actual collation is stored with its -// Collation ID, so the total size of the struct is 4 bytes. This is useful for type -// processing in the AST. -type TypedCollationID struct { Collation ID Coercibility Coercibility Repertoire Repertoire } -func (env *Environment) TypedCollation(tid TypedCollationID) *TypedCollation { - return &TypedCollation{ - Collation: env.LookupByID(tid.Collation), - Coercibility: tid.Coercibility, - Repertoire: tid.Repertoire, - } +func (tc TypedCollation) Valid() bool { + return tc.Collation != Unknown } -func checkCompatibleCollations(left, right *TypedCollation) bool { - leftCS := left.Collation.Charset() - rightCS := right.Collation.Charset() +func checkCompatibleCollations( + left Collation, leftCoercibility Coercibility, leftRepertoire Repertoire, + right Collation, rightCoercibility Coercibility, rightRepertoire Repertoire, +) bool { + leftCS := left.Charset() + rightCS := right.Charset() switch leftCS.(type) { case charset.Charset_utf8mb4: - if left.Coercibility <= right.Coercibility { + if leftCoercibility <= rightCoercibility { return true } case charset.Charset_utf32: switch { - case left.Coercibility < right.Coercibility: + case leftCoercibility < rightCoercibility: return true - case left.Coercibility == right.Coercibility: + case leftCoercibility == rightCoercibility: if !charset.IsUnicode(rightCS) { return true } - if !left.Collation.IsBinary() { + if !left.IsBinary() { return true } } case charset.Charset_utf8, charset.Charset_ucs2, charset.Charset_utf16, charset.Charset_utf16le: switch { - case left.Coercibility < right.Coercibility: + case leftCoercibility < rightCoercibility: return true - case left.Coercibility == right.Coercibility: + case leftCoercibility == rightCoercibility: if !charset.IsUnicode(rightCS) { return true } } } - if right.Repertoire == RepertoireASCII { + if rightRepertoire == RepertoireASCII { switch { - case left.Coercibility < right.Coercibility: + case leftCoercibility < rightCoercibility: return true - case left.Coercibility == right.Coercibility: - if left.Repertoire == RepertoireUnicode { + case leftCoercibility == rightCoercibility: + if leftRepertoire == RepertoireUnicode { return true } } @@ -179,10 +163,6 @@ func checkCompatibleCollations(left, right *TypedCollation) bool { return false } -func noCoercion(_, left, right []byte) ([]byte, []byte, error) { - return left, right, nil -} - // CoercionOptions is used to configure how aggressive the algorithm can be // when merging two different collations by transcoding them. type CoercionOptions struct { @@ -222,71 +202,75 @@ type CoercionOptions struct { // // If the collations for both sides of the expression are not compatible, an error // will be returned and the returned TypedCollation and Coercion will be nil. -func (env *Environment) MergeCollations(left, right *TypedCollation, opt CoercionOptions) (*TypedCollation, Coercion, error) { - leftCS := left.Collation.Charset() - rightCS := right.Collation.Charset() +func (env *Environment) MergeCollations(left, right TypedCollation, opt CoercionOptions) (TypedCollation, Coercion, Coercion, error) { + leftColl := env.LookupByID(left.Collation) + rightColl := env.LookupByID(right.Collation) + if leftColl == nil || rightColl == nil { + return TypedCollation{}, nil, nil, fmt.Errorf("unsupported TypeCollationID: %v / %v", left.Collation, right.Collation) + } + leftCS := leftColl.Charset() + rightCS := rightColl.Charset() if leftCS.Name() == rightCS.Name() { switch { case left.Coercibility < right.Coercibility: left.Repertoire |= right.Repertoire - return left, noCoercion, nil + return left, nil, nil, nil case left.Coercibility > right.Coercibility: right.Repertoire |= left.Repertoire - return right, noCoercion, nil + return right, nil, nil, nil - case left.Collation.ID() == right.Collation.ID(): + case left.Collation == right.Collation: left.Repertoire |= right.Repertoire - return left, noCoercion, nil + return left, nil, nil, nil } if left.Coercibility == CoerceExplicit { goto cannotCoerce } - leftCsBin := left.Collation.IsBinary() - rightCsBin := right.Collation.IsBinary() + leftCsBin := leftColl.IsBinary() + rightCsBin := rightColl.IsBinary() switch { case leftCsBin && rightCsBin: left.Coercibility = CoerceNone - return left, noCoercion, nil + return left, nil, nil, nil case leftCsBin: - return left, noCoercion, nil + return left, nil, nil, nil case rightCsBin: - return right, noCoercion, nil + return right, nil, nil, nil } defaults := env.byCharset[leftCS.Name()] - defaults.Binary.Init() - return &TypedCollation{ - Collation: defaults.Binary, + return TypedCollation{ + Collation: defaults.Binary.ID(), Coercibility: CoerceNone, Repertoire: left.Repertoire | right.Repertoire, - }, noCoercion, nil + }, nil, nil, nil } - if _, leftIsBinary := left.Collation.(*Collation_binary); leftIsBinary { + if _, leftIsBinary := leftColl.(*Collation_binary); leftIsBinary { if left.Coercibility <= right.Coercibility { - return left, noCoercion, nil + return left, nil, nil, nil } - return right, noCoercion, nil + return right, nil, nil, nil } - if _, rightIsBinary := right.Collation.(*Collation_binary); rightIsBinary { + if _, rightIsBinary := rightColl.(*Collation_binary); rightIsBinary { if left.Coercibility >= right.Coercibility { - return right, noCoercion, nil + return right, nil, nil, nil } - return left, noCoercion, nil + return left, nil, nil, nil } if opt.ConvertToSuperset { - if checkCompatibleCollations(left, right) { + if checkCompatibleCollations(leftColl, left.Coercibility, left.Repertoire, rightColl, right.Coercibility, right.Repertoire) { goto coerceToLeft } - if checkCompatibleCollations(right, left) { + if checkCompatibleCollations(rightColl, right.Coercibility, right.Repertoire, leftColl, left.Coercibility, left.Repertoire) { goto coerceToRight } } @@ -301,18 +285,28 @@ func (env *Environment) MergeCollations(left, right *TypedCollation, opt Coercio } cannotCoerce: - return nil, nil, fmt.Errorf("Illegal mix of collations (%s,%s) and (%s,%s)", - left.Collation.Name(), left.Coercibility, right.Collation.Name(), right.Coercibility) + return TypedCollation{}, nil, nil, fmt.Errorf("Illegal mix of collations (%s,%s) and (%s,%s)", + leftColl.Name(), left.Coercibility, rightColl.Name(), right.Coercibility) coerceToLeft: - return left, func(dst, left, right []byte) ([]byte, []byte, error) { - trans, err := charset.Convert(dst, leftCS, right, rightCS) - return left, trans, err - }, nil + return left, nil, + func(dst, in []byte) ([]byte, error) { + return charset.Convert(dst, leftCS, in, rightCS) + }, nil coerceToRight: - return right, func(dst, left, right []byte) ([]byte, []byte, error) { - trans, err := charset.Convert(dst, rightCS, left, leftCS) - return trans, right, err - }, nil + return right, + func(dst, in []byte) ([]byte, error) { + return charset.Convert(dst, rightCS, in, leftCS) + }, nil, nil +} + +func (env *Environment) EnsureCollate(fromID, toID ID) error { + // these two lookups should never fail + from := env.LookupByID(fromID) + to := env.LookupByID(toID) + if from.Charset().Name() != to.Charset().Name() { + return fmt.Errorf("COLLATION '%s' is not valid for CHARACTER SET '%s'", to.Name(), from.Charset().Name()) + } + return nil } diff --git a/go/mysql/collations/env.go b/go/mysql/collations/env.go index 6309559b6ae..ff36c9a4333 100644 --- a/go/mysql/collations/env.go +++ b/go/mysql/collations/env.go @@ -20,6 +20,8 @@ import ( "fmt" "strings" "sync" + + "vitess.io/vitess/go/vt/servenv" ) type colldefaults struct { @@ -229,8 +231,18 @@ func makeEnv(version collver) *Environment { return env } -// Default is the default collation Environment for Vitess. This is set to -// the collation set and defaults available in MySQL 8.0 -func Default() *Environment { - return fetchCacheEnvironment(collverMySQL80) +var defaultEnv *Environment +var defaultEnvInit sync.Once + +// Local is the default collation Environment for Vitess. This depends +// on the value of the `mysql_server_version` flag passed to this Vitess process. +func Local() *Environment { + defaultEnvInit.Do(func() { + if *servenv.MySQLServerVersion == "" { + defaultEnv = fetchCacheEnvironment(collverMySQL80) + } else { + defaultEnv = NewEnvironment(*servenv.MySQLServerVersion) + } + }) + return defaultEnv } diff --git a/go/mysql/collations/integration/coercion_test.go b/go/mysql/collations/integration/coercion_test.go index 0f4fa69ce82..fd1f8ef0998 100644 --- a/go/mysql/collations/integration/coercion_test.go +++ b/go/mysql/collations/integration/coercion_test.go @@ -41,7 +41,7 @@ type RemoteCoercionResult struct { type RemoteCoercionTest interface { Expression() string - Test(t *testing.T, remote *RemoteCoercionResult, localCollation *collations.TypedCollation, localCoercion collations.Coercion) + Test(t *testing.T, remote *RemoteCoercionResult, local collations.TypedCollation, coerce1, coerce2 collations.Coercion) } type testConcat struct { @@ -55,17 +55,24 @@ func (tc *testConcat) Expression() string { ) } -func (tc *testConcat) Test(t *testing.T, remote *RemoteCoercionResult, local *collations.TypedCollation, coercion collations.Coercion) { - if local.Collation.Name() != remote.Collation.Name() { - t.Errorf("bad collation resolved: local is %s, remote is %s", local.Collation.Name(), remote.Collation.Name()) +func (tc *testConcat) Test(t *testing.T, remote *RemoteCoercionResult, local collations.TypedCollation, coercion1, coercion2 collations.Coercion) { + localCollation := defaultenv.LookupByID(local.Collation) + if localCollation.Name() != remote.Collation.Name() { + t.Errorf("bad collation resolved: local is %s, remote is %s", localCollation.Name(), remote.Collation.Name()) } if local.Coercibility != remote.Coercibility { t.Errorf("bad coercibility resolved: local is %d, remote is %d", local.Coercibility, remote.Coercibility) } - leftText, rightText, err := coercion(nil, tc.left.Text, tc.right.Text) + leftText, err := coercion1(nil, tc.left.Text) if err != nil { - t.Errorf("failed to transcode left/right: %v", err) + t.Errorf("failed to transcode left: %v", err) + return + } + + rightText, err := coercion2(nil, tc.right.Text) + if err != nil { + t.Errorf("failed to transcode right: %v", err) return } @@ -77,7 +84,7 @@ func (tc *testConcat) Test(t *testing.T, remote *RemoteCoercionResult, local *co t.Errorf("failed to concatenate text;\n\tCONCAT(%v COLLATE %s, %v COLLATE %s) = \n\tCONCAT(%v, %v) COLLATE %s = \n\t\t%v\n\n\texpected: %v", tc.left.Text, tc.left.Collation.Name(), tc.right.Text, tc.right.Collation.Name(), - leftText, rightText, local.Collation.Name(), + leftText, rightText, localCollation.Name(), concat.Bytes(), remote.Expr.ToBytes(), ) } @@ -94,18 +101,25 @@ func (tc *testComparison) Expression() string { ) } -func (tc *testComparison) Test(t *testing.T, remote *RemoteCoercionResult, localCollation *collations.TypedCollation, localCoercion collations.Coercion) { - leftText, rightText, err := localCoercion(nil, tc.left.Text, tc.right.Text) +func (tc *testComparison) Test(t *testing.T, remote *RemoteCoercionResult, local collations.TypedCollation, coerce1, coerce2 collations.Coercion) { + leftText, err := coerce1(nil, tc.left.Text) if err != nil { - t.Errorf("failed to transcode left/right: %v", err) + t.Errorf("failed to transcode left: %v", err) + return + } + + rightText, err := coerce2(nil, tc.right.Text) + if err != nil { + t.Errorf("failed to transcode right: %v", err) return } remoteEquals := remote.Expr.ToBytes()[0] == '1' - localEquals := localCollation.Collation.Collate(leftText, rightText, false) == 0 + localCollation := defaultenv.LookupByID(local.Collation) + localEquals := localCollation.Collate(leftText, rightText, false) == 0 if remoteEquals != localEquals { t.Errorf("failed to collate %#v = %#v with collation %s (expected %v, got %v)", - leftText, rightText, localCollation.Collation.Name(), remoteEquals, localEquals) + leftText, rightText, localCollation.Name(), remoteEquals, localEquals) } } @@ -146,22 +160,30 @@ func TestComparisonSemantics(t *testing.T) { t.Run(tc.name, func(t *testing.T) { for _, collA := range testInputs { for _, collB := range testInputs { - left := &collations.TypedCollation{ - Collation: collA.Collation, + left := collations.TypedCollation{ + Collation: collA.Collation.ID(), Coercibility: 0, Repertoire: collations.RepertoireASCII, } - right := &collations.TypedCollation{ - Collation: collB.Collation, + right := collations.TypedCollation{ + Collation: collB.Collation.ID(), Coercibility: 0, Repertoire: collations.RepertoireASCII, } - resultLocal, coercionLocal, errLocal := defaultenv.MergeCollations(left, right, + resultLocal, coercionLocal1, coercionLocal2, errLocal := defaultenv.MergeCollations(left, right, collations.CoercionOptions{ ConvertToSuperset: true, ConvertWithCoercion: true, }) + // for strings that do not coerce, replace with a no-op coercion function + if coercionLocal1 == nil { + coercionLocal1 = func(_, in []byte) ([]byte, error) { return in, nil } + } + if coercionLocal2 == nil { + coercionLocal2 = func(_, in []byte) ([]byte, error) { return in, nil } + } + remoteTest := tc.make(collA, collB) expr := remoteTest.Expression() query := fmt.Sprintf("SELECT CAST((%s) AS BINARY), COLLATION(%s), COERCIBILITY(%s)", expr, expr, expr) @@ -192,7 +214,7 @@ func TestComparisonSemantics(t *testing.T) { Expr: resultRemote.Rows[0][0], Collation: remoteCollation, Coercibility: collations.Coercibility(remoteCI), - }, resultLocal, coercionLocal) + }, resultLocal, coercionLocal1, coercionLocal2) } } }) diff --git a/go/mysql/collations/integration/main_test.go b/go/mysql/collations/integration/main_test.go index 030368a1ec8..0a1a3b1528a 100644 --- a/go/mysql/collations/integration/main_test.go +++ b/go/mysql/collations/integration/main_test.go @@ -38,7 +38,7 @@ var ( var waitmysql = flag.Bool("waitmysql", false, "") -var defaultenv = collations.Default() +var defaultenv = collations.Local() func mysqlconn(t *testing.T) *mysql.Conn { conn, err := mysql.Connect(context.Background(), &connParams) diff --git a/go/mysql/collations/tools/maketestdata/maketestdata.go b/go/mysql/collations/tools/maketestdata/maketestdata.go index 1665de68a20..8d4d7c27bf9 100644 --- a/go/mysql/collations/tools/maketestdata/maketestdata.go +++ b/go/mysql/collations/tools/maketestdata/maketestdata.go @@ -159,7 +159,7 @@ func colldump(collation string, input []byte) []byte { } func main() { - var defaults = collations.Default() + var defaults = collations.Local() var collationsForLanguage = make(map[testutil.Lang][]collations.Collation) var allcollations = defaults.AllCollations() for lang := range testutil.KnownLanguages { diff --git a/go/sqltypes/cached_size.go b/go/sqltypes/cached_size.go index f0eb3a6d920..3c84ebbe525 100644 --- a/go/sqltypes/cached_size.go +++ b/go/sqltypes/cached_size.go @@ -82,6 +82,8 @@ func (cached *Value) CachedSize(alloc bool) int64 { size += int64(32) } // field val []byte - size += hack.RuntimeAllocSize(int64(cap(cached.val))) + { + size += hack.RuntimeAllocSize(int64(cap(cached.val))) + } return size } diff --git a/go/tools/sizegen/sizegen.go b/go/tools/sizegen/sizegen.go index 97ce6ffe1cf..61a90d73e43 100644 --- a/go/tools/sizegen/sizegen.go +++ b/go/tools/sizegen/sizegen.go @@ -101,14 +101,14 @@ func isPod(tt types.Type) bool { } } return true - + case *types.Named: + return isPod(tt.Underlying()) case *types.Basic: switch tt.Kind() { case types.String, types.UnsafePointer: return false } return true - default: return false } @@ -333,31 +333,45 @@ func (sizegen *sizegen) sizeStmtForType(fieldName *jen.Statement, field types.Ty case *types.Slice: elemT := node.Elem() elemSize := sizegen.sizes.Sizeof(elemT) + var cond *jen.Statement + var stmt []jen.Code + var flag codeFlag + + if alloc { + cond = jen.If(fieldName.Clone().Op("!=").Nil()) + fieldName = jen.Op("*").Add(fieldName) + stmt = append(stmt, jen.Id("size").Op("+=").Lit(hack.RuntimeAllocSize(8*3))) + } switch elemSize { case 0: return nil, 0 case 1: - return jen.Id("size").Op("+=").Do(mallocsize(jen.Int64().Call(jen.Cap(fieldName)))), 0 + stmt = append(stmt, jen.Id("size").Op("+=").Do(mallocsize(jen.Int64().Call(jen.Cap(fieldName))))) default: - stmt, flag := sizegen.sizeStmtForType(jen.Id("elem"), elemT, false) - return jen.BlockFunc(func(b *jen.Group) { - b.Add( - jen.Id("size"). - Op("+="). - Do(mallocsize(jen.Int64().Call(jen.Cap(fieldName)). - Op("*"). - Lit(sizegen.sizes.Sizeof(elemT))), - )) - - if stmt != nil { - b.Add(jen.For(jen.List(jen.Id("_"), jen.Id("elem")).Op(":=").Range().Add(fieldName)).Block(stmt)) - } - }), flag + var nested jen.Code + nested, flag = sizegen.sizeStmtForType(jen.Id("elem"), elemT, false) + + stmt = append(stmt, + jen.Id("size"). + Op("+="). + Do(mallocsize(jen.Int64().Call(jen.Cap(fieldName)). + Op("*"). + Lit(sizegen.sizes.Sizeof(elemT))), + )) + + if nested != nil { + stmt = append(stmt, jen.For(jen.List(jen.Id("_"), jen.Id("elem")).Op(":=").Range().Add(fieldName)).Block(nested)) + } } + if cond != nil { + return cond.Block(stmt...), flag + } + return jen.Block(stmt...), flag + case *types.Map: keySize, keyFlag := sizegen.sizeStmtForType(jen.Id("k"), node.Key(), false) valSize, valFlag := sizegen.sizeStmtForType(jen.Id("v"), node.Elem(), false) @@ -436,6 +450,11 @@ func (sizegen *sizegen) sizeStmtForType(fieldName *jen.Statement, field types.Ty return nil, 0 } return jen.Id("size").Op("+=").Do(mallocsize(jen.Lit(sizegen.sizes.Sizeof(node)))), 0 + + case *types.Signature: + // assume that function pointers do not allocate (although they might, if they're closures) + return nil, 0 + default: log.Printf("unhandled type: %T", node) return nil, 0 diff --git a/go/trace/trace.go b/go/trace/trace.go index 181d3964e57..6234ff349b2 100644 --- a/go/trace/trace.go +++ b/go/trace/trace.go @@ -20,16 +20,15 @@ limitations under the License. package trace import ( + "context" "flag" + "fmt" "io" "strings" - "context" - "google.golang.org/grpc" "vitess.io/vitess/go/vt/log" - "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" ) @@ -67,8 +66,8 @@ func NewFromString(inCtx context.Context, parent, label string) (Span, context.C // AnnotateSQL annotates information about a sql query in the span. This is done in a way // so as to not leak personally identifying information (PII), or sensitive personal information (SPI) -func AnnotateSQL(span Span, sql string) { - span.Annotate("sql-statement-type", sqlparser.Preview(sql).String()) +func AnnotateSQL(span Span, strippedSQL fmt.Stringer) { + span.Annotate("sql-statement-type", strippedSQL.String()) } // FromContext returns the Span from a Context if present. The bool return diff --git a/go/trace/trace_test.go b/go/trace/trace_test.go index c5e30fa2333..fe1d36d7dc9 100644 --- a/go/trace/trace_test.go +++ b/go/trace/trace_test.go @@ -17,13 +17,12 @@ limitations under the License. package trace import ( + "context" "fmt" "io" "strings" "testing" - "context" - "google.golang.org/grpc" ) @@ -64,31 +63,6 @@ func TestRegisterService(t *testing.T) { } } -func TestProtectPII(t *testing.T) { - // set up fake tracer that we can assert on - fakeName := "test" - var tracer *fakeTracer - tracingBackendFactories[fakeName] = func(serviceName string) (tracingService, io.Closer, error) { - tracer = &fakeTracer{name: serviceName} - return tracer, tracer, nil - } - - tracingServer = &fakeName - - serviceName := "vtservice" - closer := StartTracing(serviceName) - _, ok := closer.(*fakeTracer) - if !ok { - t.Fatalf("did not get the expected tracer") - } - - span, _ := NewSpan(context.Background(), "span-name") - AnnotateSQL(span, "SELECT * FROM Tabble WHERE name = 'SECRET_INFORMATION'") - span.Finish() - - tracer.assertNoSpanWith(t, "SECRET_INFORMATION") -} - type fakeTracer struct { name string log []string diff --git a/go/vt/proto/query/cached_size.go b/go/vt/proto/query/cached_size.go index 77de4dde69c..735bd555e55 100644 --- a/go/vt/proto/query/cached_size.go +++ b/go/vt/proto/query/cached_size.go @@ -28,9 +28,13 @@ func (cached *BindVariable) CachedSize(alloc bool) int64 { size += int64(96) } // field unknownFields []byte - size += hack.RuntimeAllocSize(int64(cap(cached.unknownFields))) + { + size += hack.RuntimeAllocSize(int64(cap(cached.unknownFields))) + } // field Value []byte - size += hack.RuntimeAllocSize(int64(cap(cached.Value))) + { + size += hack.RuntimeAllocSize(int64(cap(cached.Value))) + } // field Values []*vitess.io/vitess/go/vt/proto/query.Value { size += hack.RuntimeAllocSize(int64(cap(cached.Values)) * int64(8)) @@ -49,7 +53,9 @@ func (cached *Field) CachedSize(alloc bool) int64 { size += int64(160) } // field unknownFields []byte - size += hack.RuntimeAllocSize(int64(cap(cached.unknownFields))) + { + size += hack.RuntimeAllocSize(int64(cap(cached.unknownFields))) + } // field Name string size += hack.RuntimeAllocSize(int64(len(cached.Name))) // field Table string @@ -73,7 +79,9 @@ func (cached *QueryWarning) CachedSize(alloc bool) int64 { size += int64(64) } // field unknownFields []byte - size += hack.RuntimeAllocSize(int64(cap(cached.unknownFields))) + { + size += hack.RuntimeAllocSize(int64(cap(cached.unknownFields))) + } // field Message string size += hack.RuntimeAllocSize(int64(len(cached.Message))) return size @@ -87,7 +95,9 @@ func (cached *Target) CachedSize(alloc bool) int64 { size += int64(96) } // field unknownFields []byte - size += hack.RuntimeAllocSize(int64(cap(cached.unknownFields))) + { + size += hack.RuntimeAllocSize(int64(cap(cached.unknownFields))) + } // field Keyspace string size += hack.RuntimeAllocSize(int64(len(cached.Keyspace))) // field Shard string @@ -105,8 +115,12 @@ func (cached *Value) CachedSize(alloc bool) int64 { size += int64(80) } // field unknownFields []byte - size += hack.RuntimeAllocSize(int64(cap(cached.unknownFields))) + { + size += hack.RuntimeAllocSize(int64(cap(cached.unknownFields))) + } // field Value []byte - size += hack.RuntimeAllocSize(int64(cap(cached.Value))) + { + size += hack.RuntimeAllocSize(int64(cap(cached.Value))) + } return size } diff --git a/go/vt/proto/topodata/cached_size.go b/go/vt/proto/topodata/cached_size.go index 7358cc3fca3..92da50b703e 100644 --- a/go/vt/proto/topodata/cached_size.go +++ b/go/vt/proto/topodata/cached_size.go @@ -28,10 +28,16 @@ func (cached *KeyRange) CachedSize(alloc bool) int64 { size += int64(96) } // field unknownFields []byte - size += hack.RuntimeAllocSize(int64(cap(cached.unknownFields))) + { + size += hack.RuntimeAllocSize(int64(cap(cached.unknownFields))) + } // field Start []byte - size += hack.RuntimeAllocSize(int64(cap(cached.Start))) + { + size += hack.RuntimeAllocSize(int64(cap(cached.Start))) + } // field End []byte - size += hack.RuntimeAllocSize(int64(cap(cached.End))) + { + size += hack.RuntimeAllocSize(int64(cap(cached.End))) + } return size } diff --git a/go/vt/servenv/buildinfo.go b/go/vt/servenv/buildinfo.go index a6128d98428..ae0f100bdd8 100644 --- a/go/vt/servenv/buildinfo.go +++ b/go/vt/servenv/buildinfo.go @@ -24,7 +24,6 @@ import ( "time" "vitess.io/vitess/go/stats" - "vitess.io/vitess/go/vt/sqlparser" ) var ( @@ -87,8 +86,8 @@ func (v *versionInfo) String() string { } func (v *versionInfo) MySQLVersion() string { - if *sqlparser.MySQLServerVersion != "" { - return *sqlparser.MySQLServerVersion + if *MySQLServerVersion != "" { + return *MySQLServerVersion } return "5.7.9-vitess-" + v.version } diff --git a/go/vt/servenv/version.go b/go/vt/servenv/version.go index b172dd23a41..b76ec11feec 100644 --- a/go/vt/servenv/version.go +++ b/go/vt/servenv/version.go @@ -1,3 +1,10 @@ package servenv +import "flag" + const versionName = "13.0.0-SNAPSHOT" + +// MySQLServerVersion is what Vitess will present as it's version during the connection handshake, +// and as the value to the @@version system variable. If nothing is provided, Vitess will report itself as +// a specific MySQL version with the vitess version appended to it +var MySQLServerVersion = flag.String("mysql_server_version", "", "MySQL server version to advertise.") diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index 88ae4c3702d..86da5c07dae 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -1936,8 +1936,8 @@ type ( // CollateExpr represents dynamic collate operator. CollateExpr struct { - Expr Expr - Charset string + Expr Expr + Collation string } // FuncExpr represents a function call. diff --git a/go/vt/sqlparser/ast_equals.go b/go/vt/sqlparser/ast_equals.go index 554efa19752..8e5f2ad8f01 100644 --- a/go/vt/sqlparser/ast_equals.go +++ b/go/vt/sqlparser/ast_equals.go @@ -1256,7 +1256,7 @@ func EqualsRefOfCollateExpr(a, b *CollateExpr) bool { if a == nil || b == nil { return false } - return a.Charset == b.Charset && + return a.Collation == b.Collation && EqualsExpr(a.Expr, b.Expr) } diff --git a/go/vt/sqlparser/ast_format.go b/go/vt/sqlparser/ast_format.go index 1ee01cb86f0..16e8e73f34b 100644 --- a/go/vt/sqlparser/ast_format.go +++ b/go/vt/sqlparser/ast_format.go @@ -1220,7 +1220,7 @@ func (node *CurTimeFuncExpr) Format(buf *TrackedBuffer) { // Format formats the node. func (node *CollateExpr) Format(buf *TrackedBuffer) { - buf.astPrintf(node, "%v collate %s", node.Expr, node.Charset) + buf.astPrintf(node, "%v collate %s", node.Expr, node.Collation) } // Format formats the node. diff --git a/go/vt/sqlparser/ast_format_fast.go b/go/vt/sqlparser/ast_format_fast.go index 0f2e50399c6..e08aae1cd66 100644 --- a/go/vt/sqlparser/ast_format_fast.go +++ b/go/vt/sqlparser/ast_format_fast.go @@ -1627,7 +1627,7 @@ func (node *CurTimeFuncExpr) formatFast(buf *TrackedBuffer) { func (node *CollateExpr) formatFast(buf *TrackedBuffer) { buf.printExpr(node, node.Expr, true) buf.WriteString(" collate ") - buf.WriteString(node.Charset) + buf.WriteString(node.Collation) } // formatFast formats the node. diff --git a/go/vt/sqlparser/cached_size.go b/go/vt/sqlparser/cached_size.go index ff2007ca9b2..bf2542eefd7 100644 --- a/go/vt/sqlparser/cached_size.go +++ b/go/vt/sqlparser/cached_size.go @@ -490,8 +490,8 @@ func (cached *CollateExpr) CachedSize(alloc bool) int64 { if cc, ok := cached.Expr.(cachedObject); ok { size += cc.CachedSize(true) } - // field Charset string - size += hack.RuntimeAllocSize(int64(len(cached.Charset))) + // field Collation string + size += hack.RuntimeAllocSize(int64(len(cached.Collation))) return size } func (cached *ColumnDefinition) CachedSize(alloc bool) int64 { diff --git a/go/vt/sqlparser/parser.go b/go/vt/sqlparser/parser.go index b45ffdf44e3..5a317288383 100644 --- a/go/vt/sqlparser/parser.go +++ b/go/vt/sqlparser/parser.go @@ -25,15 +25,12 @@ import ( "sync" "vitess.io/vitess/go/vt/log" + "vitess.io/vitess/go/vt/servenv" "vitess.io/vitess/go/vt/vterrors" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" ) -// MySQLServerVersion is what Vitess will present as it's version during the connection handshake, -// and as the value to the @@version system variable. If nothing is provided, Vitess will report itself as -// a specific MySQL version with the vitess version appended to it -var MySQLServerVersion = flag.String("mysql_server_version", "", "MySQL server version to advertise.") var versionFlagSync sync.Once // parserPool is a pool for parser objects. @@ -110,8 +107,8 @@ func Parse2(sql string) (Statement, BindVars, error) { func checkParserVersionFlag() { if flag.Parsed() { versionFlagSync.Do(func() { - if *MySQLServerVersion != "" { - convVersion, err := convertMySQLVersionToCommentVersion(*MySQLServerVersion) + if *servenv.MySQLServerVersion != "" { + convVersion, err := convertMySQLVersionToCommentVersion(*servenv.MySQLServerVersion) if err != nil { log.Error(err) } else { diff --git a/go/vt/sqlparser/sql.go b/go/vt/sqlparser/sql.go index cc5e31642e6..36d47ac1fdd 100644 --- a/go/vt/sqlparser/sql.go +++ b/go/vt/sqlparser/sql.go @@ -12072,7 +12072,7 @@ yydefault: var yyLOCAL Expr //line sql.y:4121 { - yyLOCAL = &CollateExpr{Expr: yyDollar[1].exprUnion(), Charset: yyDollar[3].str} + yyLOCAL = &CollateExpr{Expr: yyDollar[1].exprUnion(), Collation: yyDollar[3].str} } yyVAL.union = yyLOCAL case 801: diff --git a/go/vt/sqlparser/sql.y b/go/vt/sqlparser/sql.y index 30b1128f8f9..a1d88c40064 100644 --- a/go/vt/sqlparser/sql.y +++ b/go/vt/sqlparser/sql.y @@ -4119,7 +4119,7 @@ function_call_keyword } | simple_expr COLLATE charset %prec UNARY { - $$ = &CollateExpr{Expr: $1, Charset: $3} + $$ = &CollateExpr{Expr: $1, Collation: $3} } | literal_or_null { diff --git a/go/vt/vtgate/engine/cached_size.go b/go/vt/vtgate/engine/cached_size.go index 210a2db812c..28ecb36e91b 100644 --- a/go/vt/vtgate/engine/cached_size.go +++ b/go/vt/vtgate/engine/cached_size.go @@ -167,12 +167,16 @@ func (cached *Distinct) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(16) + size += int64(48) } // field Source vitess.io/vitess/go/vt/vtgate/engine.Primitive if cc, ok := cached.Source.(cachedObject); ok { size += cc.CachedSize(true) } + // field ColCollations []vitess.io/vitess/go/mysql/collations.ID + { + size += hack.RuntimeAllocSize(int64(cap(cached.ColCollations)) * int64(2)) + } return size } func (cached *Filter) CachedSize(alloc bool) int64 { @@ -183,7 +187,7 @@ func (cached *Filter) CachedSize(alloc bool) int64 { if alloc { size += int64(48) } - // field ASTPred vitess.io/vitess/go/vt/vtgate/evalengine.Expr + // field Predicate vitess.io/vitess/go/vt/vtgate/evalengine.Expr if cc, ok := cached.Predicate.(cachedObject); ok { size += cc.CachedSize(true) } @@ -245,6 +249,32 @@ func (cached *GroupByParams) CachedSize(alloc bool) int64 { } return size } +func (cached *HashJoin) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(112) + } + // field Left vitess.io/vitess/go/vt/vtgate/engine.Primitive + if cc, ok := cached.Left.(cachedObject); ok { + size += cc.CachedSize(true) + } + // field Right vitess.io/vitess/go/vt/vtgate/engine.Primitive + if cc, ok := cached.Right.(cachedObject); ok { + size += cc.CachedSize(true) + } + // field Cols []int + { + size += hack.RuntimeAllocSize(int64(cap(cached.Cols)) * int64(8)) + } + // field ASTPred vitess.io/vitess/go/vt/sqlparser.Expr + if cc, ok := cached.ASTPred.(cachedObject); ok { + size += cc.CachedSize(true) + } + return size +} func (cached *Insert) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) @@ -289,7 +319,7 @@ func (cached *Join) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(80) + size += int64(96) } // field Left vitess.io/vitess/go/vt/vtgate/engine.Primitive if cc, ok := cached.Left.(cachedObject); ok { @@ -317,6 +347,10 @@ func (cached *Join) CachedSize(alloc bool) int64 { size += hack.RuntimeAllocSize(int64(len(k))) } } + // field ASTPred vitess.io/vitess/go/vt/sqlparser.Expr + if cc, ok := cached.ASTPred.(cachedObject); ok { + size += cc.CachedSize(true) + } return size } func (cached *Limit) CachedSize(alloc bool) int64 { @@ -388,9 +422,6 @@ func (cached *MemorySort) CachedSize(alloc bool) int64 { // field OrderBy []vitess.io/vitess/go/vt/vtgate/engine.OrderByParams { size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(36)) - for _, elem := range cached.OrderBy { - size += elem.CachedSize(false) - } } // field Input vitess.io/vitess/go/vt/vtgate/engine.Primitive if cc, ok := cached.Input.(cachedObject); ok { @@ -418,9 +449,6 @@ func (cached *MergeSort) CachedSize(alloc bool) int64 { // field OrderBy []vitess.io/vitess/go/vt/vtgate/engine.OrderByParams { size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(36)) - for _, elem := range cached.OrderBy { - size += elem.CachedSize(false) - } } return size } @@ -448,16 +476,6 @@ func (cached *OnlineDDL) CachedSize(alloc bool) int64 { } return size } -func (cached *OrderByParams) CachedSize(alloc bool) int64 { - if cached == nil { - return int64(0) - } - size := int64(0) - if alloc { - size += int64(48) - } - return size -} func (cached *OrderedAggregate) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) @@ -658,9 +676,6 @@ func (cached *Route) CachedSize(alloc bool) int64 { // field OrderBy []vitess.io/vitess/go/vt/vtgate/engine.OrderByParams { size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(36)) - for _, elem := range cached.OrderBy { - size += elem.CachedSize(false) - } } // field SysTableTableSchema []vitess.io/vitess/go/vt/vtgate/evalengine.Expr { diff --git a/go/vt/vtgate/engine/distinct.go b/go/vt/vtgate/engine/distinct.go index 1c45ac77db7..6ac82414057 100644 --- a/go/vt/vtgate/engine/distinct.go +++ b/go/vt/vtgate/engine/distinct.go @@ -234,7 +234,7 @@ func (d *Distinct) description() PrimitiveDescription { other = map[string]interface{}{} var colls []string for _, collation := range d.ColCollations { - coll := collations.Default().LookupByID(collation) + coll := collations.Local().LookupByID(collation) if coll == nil { colls = append(colls, "UNKNOWN") } else { diff --git a/go/vt/vtgate/engine/fake_vcursor_test.go b/go/vt/vtgate/engine/fake_vcursor_test.go index fda989531e2..77a3a3c1c22 100644 --- a/go/vt/vtgate/engine/fake_vcursor_test.go +++ b/go/vt/vtgate/engine/fake_vcursor_test.go @@ -56,7 +56,7 @@ type noopVCursor struct { } // ConnCollation implements VCursor -func (t *noopVCursor) ConnCollation() collations.Collation { +func (t *noopVCursor) ConnCollation() collations.ID { panic("implement me") } diff --git a/go/vt/vtgate/engine/filter.go b/go/vt/vtgate/engine/filter.go index 27480e17c4f..920b8a0bfa8 100644 --- a/go/vt/vtgate/engine/filter.go +++ b/go/vt/vtgate/engine/filter.go @@ -57,7 +57,7 @@ func (f *Filter) TryExecute(vcursor VCursor, bindVars map[string]*querypb.BindVa if err != nil { return nil, err } - env := evalengine.ExpressionEnv{ + env := &evalengine.ExpressionEnv{ BindVars: bindVars, } var rows [][]sqltypes.Value @@ -81,7 +81,7 @@ func (f *Filter) TryExecute(vcursor VCursor, bindVars map[string]*querypb.BindVa // TryStreamExecute satisfies the Primitive interface. func (f *Filter) TryStreamExecute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { - env := evalengine.ExpressionEnv{ + env := &evalengine.ExpressionEnv{ BindVars: bindVars, } filter := func(results *sqltypes.Result) error { diff --git a/go/vt/vtgate/engine/hash_join.go b/go/vt/vtgate/engine/hash_join.go index e0f8f7530e9..c76f3cf6631 100644 --- a/go/vt/vtgate/engine/hash_join.go +++ b/go/vt/vtgate/engine/hash_join.go @@ -245,7 +245,7 @@ func (hj *HashJoin) description() PrimitiveDescription { "Predicate": sqlparser.String(hj.ASTPred), "ComparisonType": hj.ComparisonType.String(), } - coll := collations.Default().LookupByID(hj.Collation) + coll := collations.Local().LookupByID(hj.Collation) if coll != nil { other["Collation"] = coll.Name() } diff --git a/go/vt/vtgate/engine/memory_sort_test.go b/go/vt/vtgate/engine/memory_sort_test.go index 13ba37c041b..c6129032236 100644 --- a/go/vt/vtgate/engine/memory_sort_test.go +++ b/go/vt/vtgate/engine/memory_sort_test.go @@ -223,7 +223,7 @@ func TestMemorySortStreamExecuteCollation(t *testing.T) { )}, } - collationID, _ := collations.Default().LookupID("utf8mb4_hu_0900_ai_ci") + collationID, _ := collations.Local().LookupID("utf8mb4_hu_0900_ai_ci") ms := &MemorySort{ OrderBy: []OrderByParams{{ Col: 0, @@ -313,7 +313,7 @@ func TestMemorySortExecuteCollation(t *testing.T) { )}, } - collationID, _ := collations.Default().LookupID("utf8mb4_hu_0900_ai_ci") + collationID, _ := collations.Local().LookupID("utf8mb4_hu_0900_ai_ci") ms := &MemorySort{ OrderBy: []OrderByParams{{ Col: 0, diff --git a/go/vt/vtgate/engine/merge_sort_test.go b/go/vt/vtgate/engine/merge_sort_test.go index b9483afce9c..2155d1fec4b 100644 --- a/go/vt/vtgate/engine/merge_sort_test.go +++ b/go/vt/vtgate/engine/merge_sort_test.go @@ -176,7 +176,7 @@ func TestMergeSortCollation(t *testing.T) { ), }} - collationID, _ := collations.Default().LookupID("utf8mb4_hu_0900_ai_ci") + collationID, _ := collations.Local().LookupID("utf8mb4_hu_0900_ai_ci") orderBy := []OrderByParams{{ Col: 0, CollationID: collationID, diff --git a/go/vt/vtgate/engine/ordered_aggregate_test.go b/go/vt/vtgate/engine/ordered_aggregate_test.go index 9828cfc1cb7..c6f9ab42797 100644 --- a/go/vt/vtgate/engine/ordered_aggregate_test.go +++ b/go/vt/vtgate/engine/ordered_aggregate_test.go @@ -1076,7 +1076,7 @@ func TestOrderedAggregateCollate(t *testing.T) { )}, } - collationID, _ := collations.Default().LookupID("utf8mb4_0900_ai_ci") + collationID, _ := collations.Local().LookupID("utf8mb4_0900_ai_ci") oa := &OrderedAggregate{ Aggregates: []*AggregateParams{{ Opcode: AggregateCount, @@ -1118,7 +1118,7 @@ func TestOrderedAggregateCollateAS(t *testing.T) { )}, } - collationID, _ := collations.Default().LookupID("utf8mb4_0900_as_ci") + collationID, _ := collations.Local().LookupID("utf8mb4_0900_as_ci") oa := &OrderedAggregate{ Aggregates: []*AggregateParams{{ Opcode: AggregateCount, @@ -1162,7 +1162,7 @@ func TestOrderedAggregateCollateKS(t *testing.T) { )}, } - collationID, _ := collations.Default().LookupID("utf8mb4_ja_0900_as_cs_ks") + collationID, _ := collations.Local().LookupID("utf8mb4_ja_0900_as_cs_ks") oa := &OrderedAggregate{ Aggregates: []*AggregateParams{{ Opcode: AggregateCount, diff --git a/go/vt/vtgate/engine/primitive.go b/go/vt/vtgate/engine/primitive.go index bebe886e097..4c067b64f7d 100644 --- a/go/vt/vtgate/engine/primitive.go +++ b/go/vt/vtgate/engine/primitive.go @@ -94,7 +94,7 @@ type ( Session() SessionActions - ConnCollation() collations.Collation + ConnCollation() collations.ID ExecuteLock(rs *srvtopo.ResolvedShard, query *querypb.BoundQuery) (*sqltypes.Result, error) diff --git a/go/vt/vtgate/engine/projection.go b/go/vt/vtgate/engine/projection.go index d749e31139c..470137ce8c8 100644 --- a/go/vt/vtgate/engine/projection.go +++ b/go/vt/vtgate/engine/projection.go @@ -38,7 +38,7 @@ func (p *Projection) TryExecute(vcursor VCursor, bindVars map[string]*querypb.Bi return nil, err } - env := evalengine.ExpressionEnv{ + env := &evalengine.ExpressionEnv{ BindVars: bindVars, } @@ -71,7 +71,7 @@ func (p *Projection) TryStreamExecute(vcursor VCursor, bindVars map[string]*quer return err } - env := evalengine.ExpressionEnv{ + env := &evalengine.ExpressionEnv{ BindVars: bindVars, } @@ -111,7 +111,7 @@ func (p *Projection) GetFields(vcursor VCursor, bindVars map[string]*querypb.Bin } func (p *Projection) addFields(qr *sqltypes.Result, bindVars map[string]*querypb.BindVariable) error { - env := evalengine.ExpressionEnv{BindVars: bindVars} + env := &evalengine.ExpressionEnv{BindVars: bindVars} for i, col := range p.Cols { q, err := p.Exprs[i].Type(env) if err != nil { @@ -134,7 +134,7 @@ func (p *Projection) Inputs() []Primitive { func (p *Projection) description() PrimitiveDescription { var exprs []string for _, e := range p.Exprs { - exprs = append(exprs, e.String()) + exprs = append(exprs, evalengine.FormatExpr(e)) } return PrimitiveDescription{ OperatorType: "Projection", diff --git a/go/vt/vtgate/engine/route.go b/go/vt/vtgate/engine/route.go index 72b686f676e..7b6f37374fa 100644 --- a/go/vt/vtgate/engine/route.go +++ b/go/vt/vtgate/engine/route.go @@ -154,7 +154,7 @@ func (obp OrderByParams) String() string { val += " ASC" } if obp.CollationID != collations.Unknown { - collation := collations.Default().LookupByID(obp.CollationID) + collation := collations.Local().LookupByID(obp.CollationID) val += " COLLATE " + collation.Name() } return val @@ -463,7 +463,7 @@ func (route *Route) routeInfoSchemaQuery(vcursor VCursor, bindVars map[string]*q return defaultRoute() } - env := evalengine.ExpressionEnv{ + env := &evalengine.ExpressionEnv{ BindVars: bindVars, Row: []sqltypes.Value{}, } @@ -804,7 +804,7 @@ func (route *Route) description() PrimitiveDescription { if idx != 0 { sysTabSchema += ", " } - sysTabSchema += tableSchema.String() + sysTabSchema += evalengine.FormatExpr(tableSchema) } sysTabSchema += "]" other["SysTableTableSchema"] = sysTabSchema @@ -812,7 +812,7 @@ func (route *Route) description() PrimitiveDescription { if len(route.SysTableTableName) != 0 { var sysTableName []string for k, v := range route.SysTableTableName { - sysTableName = append(sysTableName, k+":"+v.String()) + sysTableName = append(sysTableName, k+":"+evalengine.FormatExpr(v)) } sort.Strings(sysTableName) other["SysTableTableName"] = "[" + strings.Join(sysTableName, ", ") + "]" diff --git a/go/vt/vtgate/engine/route_test.go b/go/vt/vtgate/engine/route_test.go index 3bfe7ef80f0..54cf19b5cd4 100644 --- a/go/vt/vtgate/engine/route_test.go +++ b/go/vt/vtgate/engine/route_test.go @@ -83,7 +83,7 @@ func TestSelectInformationSchemaWithTableAndSchemaWithRoutedTables(t *testing.T) stringListToExprList := func(in []string) []evalengine.Expr { var schema []evalengine.Expr for _, s := range in { - schema = append(schema, evalengine.NewLiteralString([]byte(s), 0)) + schema = append(schema, evalengine.NewLiteralString([]byte(s), collations.TypedCollation{})) } return schema } @@ -98,7 +98,7 @@ func TestSelectInformationSchemaWithTableAndSchemaWithRoutedTables(t *testing.T) tests := []testCase{{ testName: "both schema and table predicates - routed table", tableSchema: []string{"schema"}, - tableName: map[string]evalengine.Expr{"table_name": evalengine.NewLiteralString([]byte("table"), 0)}, + tableName: map[string]evalengine.Expr{"table_name": evalengine.NewLiteralString([]byte("table"), collations.TypedCollation{})}, routed: true, expectedLog: []string{ "FindTable(`schema`.`table`)", @@ -107,7 +107,7 @@ func TestSelectInformationSchemaWithTableAndSchemaWithRoutedTables(t *testing.T) }, { testName: "both schema and table predicates - not routed", tableSchema: []string{"schema"}, - tableName: map[string]evalengine.Expr{"table_name": evalengine.NewLiteralString([]byte("table"), 0)}, + tableName: map[string]evalengine.Expr{"table_name": evalengine.NewLiteralString([]byte("table"), collations.TypedCollation{})}, routed: false, expectedLog: []string{ "FindTable(`schema`.`table`)", @@ -116,7 +116,7 @@ func TestSelectInformationSchemaWithTableAndSchemaWithRoutedTables(t *testing.T) }, { testName: "multiple schema and table predicates", tableSchema: []string{"schema", "schema", "schema"}, - tableName: map[string]evalengine.Expr{"t1": evalengine.NewLiteralString([]byte("table"), 0), "t2": evalengine.NewLiteralString([]byte("table"), 0), "t3": evalengine.NewLiteralString([]byte("table"), 0)}, + tableName: map[string]evalengine.Expr{"t1": evalengine.NewLiteralString([]byte("table"), collations.TypedCollation{}), "t2": evalengine.NewLiteralString([]byte("table"), collations.TypedCollation{}), "t3": evalengine.NewLiteralString([]byte("table"), collations.TypedCollation{})}, routed: false, expectedLog: []string{ "FindTable(`schema`.`table`)", @@ -126,7 +126,7 @@ func TestSelectInformationSchemaWithTableAndSchemaWithRoutedTables(t *testing.T) "ExecuteMultiShard schema.1: dummy_select {__replacevtschemaname: type:INT64 value:\"1\" t1: type:VARBINARY value:\"table\" t2: type:VARBINARY value:\"table\" t3: type:VARBINARY value:\"table\"} false false"}, }, { testName: "table name predicate - routed table", - tableName: map[string]evalengine.Expr{"table_name": evalengine.NewLiteralString([]byte("tableName"), 0)}, + tableName: map[string]evalengine.Expr{"table_name": evalengine.NewLiteralString([]byte("tableName"), collations.TypedCollation{})}, routed: true, expectedLog: []string{ "FindTable(tableName)", @@ -134,7 +134,7 @@ func TestSelectInformationSchemaWithTableAndSchemaWithRoutedTables(t *testing.T) "ExecuteMultiShard routedKeyspace.1: dummy_select {table_name: type:VARBINARY value:\"routedTable\"} false false"}, }, { testName: "table name predicate - not routed", - tableName: map[string]evalengine.Expr{"table_name": evalengine.NewLiteralString([]byte("tableName"), 0)}, + tableName: map[string]evalengine.Expr{"table_name": evalengine.NewLiteralString([]byte("tableName"), collations.TypedCollation{})}, routed: false, expectedLog: []string{ "FindTable(tableName)", @@ -993,7 +993,7 @@ func TestRouteSortCollation(t *testing.T) { "dummy_select_field", ) - collationID, _ := collations.Default().LookupID("utf8mb4_hu_0900_ai_ci") + collationID, _ := collations.Local().LookupID("utf8mb4_hu_0900_ai_ci") sel.OrderBy = []OrderByParams{{ Col: 0, diff --git a/go/vt/vtgate/engine/set.go b/go/vt/vtgate/engine/set.go index 430c0367a44..0c24504c410 100644 --- a/go/vt/vtgate/engine/set.go +++ b/go/vt/vtgate/engine/set.go @@ -52,7 +52,7 @@ type ( // SetOp is an interface that different type of set operations implements. SetOp interface { - Execute(vcursor VCursor, env evalengine.ExpressionEnv) error + Execute(vcursor VCursor, env *evalengine.ExpressionEnv) error VariableName() string } @@ -118,7 +118,7 @@ func (s *Set) TryExecute(vcursor VCursor, bindVars map[string]*querypb.BindVaria if len(input.Rows) != 1 { return nil, vterrors.Errorf(vtrpcpb.Code_FAILED_PRECONDITION, "should get a single row") } - env := evalengine.ExpressionEnv{ + env := &evalengine.ExpressionEnv{ BindVars: bindVars, Row: input.Rows[0], } @@ -171,7 +171,7 @@ func (u *UserDefinedVariable) MarshalJSON() ([]byte, error) { }{ Type: "UserDefinedVariable", Name: u.Name, - Expr: u.Expr.String(), + Expr: evalengine.FormatExpr(u.Expr), }) } @@ -182,7 +182,7 @@ func (u *UserDefinedVariable) VariableName() string { } // Execute implements the SetOp interface method. -func (u *UserDefinedVariable) Execute(vcursor VCursor, env evalengine.ExpressionEnv) error { +func (u *UserDefinedVariable) Execute(vcursor VCursor, env *evalengine.ExpressionEnv) error { value, err := u.Expr.Evaluate(env) if err != nil { return err @@ -210,7 +210,7 @@ func (svi *SysVarIgnore) VariableName() string { } // Execute implements the SetOp interface method. -func (svi *SysVarIgnore) Execute(VCursor, evalengine.ExpressionEnv) error { +func (svi *SysVarIgnore) Execute(VCursor, *evalengine.ExpressionEnv) error { log.Infof("Ignored inapplicable SET %v = %v", svi.Name, svi.Expr) return nil } @@ -235,7 +235,7 @@ func (svci *SysVarCheckAndIgnore) VariableName() string { } // Execute implements the SetOp interface method -func (svci *SysVarCheckAndIgnore) Execute(vcursor VCursor, env evalengine.ExpressionEnv) error { +func (svci *SysVarCheckAndIgnore) Execute(vcursor VCursor, env *evalengine.ExpressionEnv) error { rss, _, err := vcursor.ResolveDestinations(svci.Keyspace.Name, nil, []key.Destination{svci.TargetDestination}) if err != nil { return err @@ -279,7 +279,7 @@ func (svs *SysVarReservedConn) VariableName() string { } // Execute implements the SetOp interface method -func (svs *SysVarReservedConn) Execute(vcursor VCursor, env evalengine.ExpressionEnv) error { +func (svs *SysVarReservedConn) Execute(vcursor VCursor, env *evalengine.ExpressionEnv) error { // For those running on advanced vitess settings. if svs.TargetDestination != nil { rss, _, err := vcursor.ResolveDestinations(svs.Keyspace.Name, nil, []key.Destination{svs.TargetDestination}) @@ -313,7 +313,7 @@ func (svs *SysVarReservedConn) Execute(vcursor VCursor, env evalengine.Expressio return vterrors.Aggregate(errs) } -func (svs *SysVarReservedConn) execSetStatement(vcursor VCursor, rss []*srvtopo.ResolvedShard, env evalengine.ExpressionEnv) error { +func (svs *SysVarReservedConn) execSetStatement(vcursor VCursor, rss []*srvtopo.ResolvedShard, env *evalengine.ExpressionEnv) error { queries := make([]*querypb.BoundQuery, len(rss)) for i := 0; i < len(rss); i++ { queries[i] = &querypb.BoundQuery{ @@ -325,7 +325,7 @@ func (svs *SysVarReservedConn) execSetStatement(vcursor VCursor, rss []*srvtopo. return vterrors.Aggregate(errs) } -func (svs *SysVarReservedConn) checkAndUpdateSysVar(vcursor VCursor, res evalengine.ExpressionEnv) (bool, error) { +func (svs *SysVarReservedConn) checkAndUpdateSysVar(vcursor VCursor, res *evalengine.ExpressionEnv) (bool, error) { sysVarExprValidationQuery := fmt.Sprintf("select %s from dual where @@%s != %s", svs.Expr, svs.Name, svs.Expr) if svs.Name == "sql_mode" { sysVarExprValidationQuery = fmt.Sprintf("select @@%s orig, %s new", svs.Name, svs.Expr) @@ -412,12 +412,12 @@ func (svss *SysVarSetAware) MarshalJSON() ([]byte, error) { }{ Type: "SysVarAware", Name: svss.Name, - Expr: svss.Expr.String(), + Expr: evalengine.FormatExpr(svss.Expr), }) } // Execute implements the SetOp interface method -func (svss *SysVarSetAware) Execute(vcursor VCursor, env evalengine.ExpressionEnv) error { +func (svss *SysVarSetAware) Execute(vcursor VCursor, env *evalengine.ExpressionEnv) error { var err error switch svss.Name { case sysvars.Autocommit.Name: @@ -512,7 +512,7 @@ func (svss *SysVarSetAware) Execute(vcursor VCursor, env evalengine.ExpressionEn return err } -func (svss *SysVarSetAware) evalAsInt64(env evalengine.ExpressionEnv) (int64, error) { +func (svss *SysVarSetAware) evalAsInt64(env *evalengine.ExpressionEnv) (int64, error) { value, err := svss.Expr.Evaluate(env) if err != nil { return 0, err @@ -529,7 +529,7 @@ func (svss *SysVarSetAware) evalAsInt64(env evalengine.ExpressionEnv) (int64, er return intValue, nil } -func (svss *SysVarSetAware) evalAsFloat(env evalengine.ExpressionEnv) (float64, error) { +func (svss *SysVarSetAware) evalAsFloat(env *evalengine.ExpressionEnv) (float64, error) { value, err := svss.Expr.Evaluate(env) if err != nil { return 0, err @@ -543,7 +543,7 @@ func (svss *SysVarSetAware) evalAsFloat(env evalengine.ExpressionEnv) (float64, return floatValue, nil } -func (svss *SysVarSetAware) evalAsString(env evalengine.ExpressionEnv) (string, error) { +func (svss *SysVarSetAware) evalAsString(env *evalengine.ExpressionEnv) (string, error) { value, err := svss.Expr.Evaluate(env) if err != nil { return "", err @@ -556,7 +556,7 @@ func (svss *SysVarSetAware) evalAsString(env evalengine.ExpressionEnv) (string, return v.ToString(), nil } -func (svss *SysVarSetAware) setBoolSysVar(env evalengine.ExpressionEnv, setter func(bool) error) error { +func (svss *SysVarSetAware) setBoolSysVar(env *evalengine.ExpressionEnv, setter func(bool) error) error { value, err := svss.Expr.Evaluate(env) if err != nil { return err diff --git a/go/vt/vtgate/engine/set_test.go b/go/vt/vtgate/engine/set_test.go index 1464b0cee42..1d4a05aed7d 100644 --- a/go/vt/vtgate/engine/set_test.go +++ b/go/vt/vtgate/engine/set_test.go @@ -23,6 +23,7 @@ import ( "github.com/stretchr/testify/require" + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/vt/vtgate/evalengine" "vitess.io/vitess/go/sqltypes" @@ -98,7 +99,7 @@ func TestSetTable(t *testing.T) { setOps: []SetOp{ &UserDefinedVariable{ Name: "x", - Expr: evalengine.NewColumn(0, 0), + Expr: evalengine.NewColumn(0, collations.TypedCollation{}), }, }, qr: []*sqltypes.Result{sqltypes.MakeTestResult( diff --git a/go/vt/vtgate/evalengine/arithmetic.go b/go/vt/vtgate/evalengine/arithmetic.go index 17ca0840428..ed73e997520 100644 --- a/go/vt/vtgate/evalengine/arithmetic.go +++ b/go/vt/vtgate/evalengine/arithmetic.go @@ -243,7 +243,7 @@ func NullsafeCompare(v1, v2 sqltypes.Value, collationID collations.ID) (int, err ID: collationID, } } - collation := collations.Default().LookupByID(collationID) + collation := collations.Local().LookupByID(collationID) if collation == nil { return 0, UnsupportedCollationError{ ID: collationID, @@ -267,32 +267,37 @@ func NullsafeCompare(v1, v2 sqltypes.Value, collationID collations.ID) (int, err // HashCode is a type alias to the code easier to read type HashCode = uintptr -// NullsafeHashcode returns an int64 hashcode that is guaranteed to be the same -// for two values that are considered equal by `NullsafeCompare`. -func NullsafeHashcode(v sqltypes.Value, collation collations.ID, coerceType querypb.Type) (HashCode, error) { - castValue, err := castTo(v, coerceType) - if err != nil { - return 0, err - } +func (er EvalResult) nullSafeHashcode() (HashCode, error) { switch { - case sqltypes.IsNull(castValue.typ): - return HashCode(math.MaxInt64), nil - case sqltypes.IsNumber(castValue.typ): - return numericalHashCode(castValue), nil - case sqltypes.IsText(castValue.typ): - coll := collations.Default().LookupByID(collation) + case sqltypes.IsNull(er.typ): + return HashCode(math.MaxUint64), nil + case sqltypes.IsNumber(er.typ): + return numericalHashCode(er), nil + case sqltypes.IsText(er.typ): + coll := collations.Local().LookupByID(er.collation.Collation) if coll == nil { return 0, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "text type with an unknown/unsupported collation cannot be hashed") } - return coll.Hash(castValue.bytes, 0), nil - case sqltypes.IsDate(castValue.typ): - time, err := parseDate(castValue) + return coll.Hash(er.bytes, 0), nil + case sqltypes.IsDate(er.typ): + time, err := parseDate(er) if err != nil { return 0, err } return uintptr(time.UnixNano()), nil } - return 0, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "types does not support hashcode yet: %v", castValue.typ) + return 0, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "types does not support hashcode yet: %v", er.typ) +} + +// NullsafeHashcode returns an int64 hashcode that is guaranteed to be the same +// for two values that are considered equal by `NullsafeCompare`. +func NullsafeHashcode(v sqltypes.Value, collation collations.ID, coerceType querypb.Type) (HashCode, error) { + castValue, err := castTo(v, coerceType) + if err != nil { + return 0, err + } + castValue.collation.Collation = collation + return castValue.nullSafeHashcode() } func castTo(v sqltypes.Value, typ querypb.Type) (EvalResult, error) { @@ -306,22 +311,22 @@ func castTo(v sqltypes.Value, typ querypb.Type) (EvalResult, error) { if err != nil { return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "%v", err) } - return EvalResult{fval: float64(ival), typ: sqltypes.Float64}, nil + return EvalResult{numval: math.Float64bits(float64(ival)), typ: sqltypes.Float64}, nil case v.IsUnsigned(): uval, err := strconv.ParseUint(v.RawStr(), 10, 64) if err != nil { return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "%v", err) } - return EvalResult{fval: float64(uval), typ: sqltypes.Float64}, nil + return EvalResult{numval: math.Float64bits(float64(uval)), typ: sqltypes.Float64}, nil case v.IsFloat() || v.Type() == sqltypes.Decimal: fval, err := strconv.ParseFloat(v.RawStr(), 64) if err != nil { return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "%v", err) } - return EvalResult{fval: fval, typ: sqltypes.Float64}, nil + return EvalResult{numval: math.Float64bits(fval), typ: sqltypes.Float64}, nil case v.IsText() || v.IsBinary(): fval := parseStringToFloat(v.RawStr()) - return EvalResult{fval: fval, typ: sqltypes.Float64}, nil + return EvalResult{numval: math.Float64bits(fval), typ: sqltypes.Float64}, nil default: return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "coercion should not try to coerce this value to a float: %v", v) } @@ -333,13 +338,13 @@ func castTo(v sqltypes.Value, typ querypb.Type) (EvalResult, error) { if err != nil { return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "%v", err) } - return EvalResult{ival: ival, typ: sqltypes.Int64}, nil + return EvalResult{numval: uint64(ival), typ: sqltypes.Int64}, nil case v.IsUnsigned(): uval, err := strconv.ParseUint(v.RawStr(), 10, 64) if err != nil { return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "%v", err) } - return EvalResult{ival: int64(uval), typ: sqltypes.Int64}, nil + return EvalResult{numval: uval, typ: sqltypes.Int64}, nil default: return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "coercion should not try to coerce this value to a signed int: %v", v) } @@ -351,13 +356,13 @@ func castTo(v sqltypes.Value, typ querypb.Type) (EvalResult, error) { if err != nil { return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "%v", err) } - return EvalResult{uval: uint64(uval), typ: sqltypes.Uint64}, nil + return EvalResult{numval: uint64(uval), typ: sqltypes.Uint64}, nil case v.IsUnsigned(): uval, err := strconv.ParseUint(v.RawStr(), 10, 64) if err != nil { return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "%v", err) } - return EvalResult{uval: uval, typ: sqltypes.Uint64}, nil + return EvalResult{numval: uval, typ: sqltypes.Uint64}, nil default: return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "coercion should not try to coerce this value to a unsigned int: %v", v) } @@ -462,16 +467,16 @@ func addNumeric(v1, v2 EvalResult) EvalResult { v1, v2 = makeNumericAndPrioritize(v1, v2) switch v1.typ { case sqltypes.Int64: - return intPlusInt(v1.ival, v2.ival) + return intPlusInt(v1.numval, v2.numval) case sqltypes.Uint64: switch v2.typ { case sqltypes.Int64: - return uintPlusInt(v1.uval, v2.ival) + return uintPlusInt(v1.numval, v2.numval) case sqltypes.Uint64: - return uintPlusUint(v1.uval, v2.uval) + return uintPlusUint(v1.numval, v2.numval) } case sqltypes.Float64: - return floatPlusAny(v1.fval, v2) + return floatPlusAny(math.Float64frombits(v1.numval), v2) } panic("unreachable") } @@ -480,16 +485,16 @@ func addNumericWithError(v1, v2 EvalResult) (EvalResult, error) { v1, v2 = makeNumericAndPrioritize(v1, v2) switch v1.typ { case sqltypes.Int64: - return intPlusIntWithError(v1.ival, v2.ival) + return intPlusIntWithError(v1.numval, v2.numval) case sqltypes.Uint64: switch v2.typ { case sqltypes.Int64: - return uintPlusIntWithError(v1.uval, v2.ival) + return uintPlusIntWithError(v1.numval, v2.numval) case sqltypes.Uint64: - return uintPlusUintWithError(v1.uval, v2.uval) + return uintPlusUintWithError(v1.numval, v2.numval) } case sqltypes.Float64, sqltypes.Decimal: - return floatPlusAny(v1.fval, v2), nil + return floatPlusAny(math.Float64frombits(v1.numval), v2), nil } return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid arithmetic between: %s %s", v1.Value().String(), v2.Value().String()) @@ -502,23 +507,23 @@ func subtractNumericWithError(i1, i2 EvalResult) (EvalResult, error) { case sqltypes.Int64: switch v2.typ { case sqltypes.Int64: - return intMinusIntWithError(v1.ival, v2.ival) + return intMinusIntWithError(v1.numval, v2.numval) case sqltypes.Uint64: - return intMinusUintWithError(v1.ival, v2.uval) + return intMinusUintWithError(v1.numval, v2.numval) case sqltypes.Float64: - return anyMinusFloat(v1, v2.fval), nil + return anyMinusFloat(v1, math.Float64frombits(v2.numval)), nil } case sqltypes.Uint64: switch v2.typ { case sqltypes.Int64: - return uintMinusIntWithError(v1.uval, v2.ival) + return uintMinusIntWithError(v1.numval, v2.numval) case sqltypes.Uint64: - return uintMinusUintWithError(v1.uval, v2.uval) + return uintMinusUintWithError(v1.numval, v2.numval) case sqltypes.Float64: - return anyMinusFloat(v1, v2.fval), nil + return anyMinusFloat(v1, math.Float64frombits(v2.numval)), nil } case sqltypes.Float64: - return floatMinusAny(v1.fval, v2), nil + return floatMinusAny(math.Float64frombits(v1.numval), v2), nil } return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid arithmetic between: %s %s", v1.Value().String(), v2.Value().String()) } @@ -527,16 +532,16 @@ func multiplyNumericWithError(v1, v2 EvalResult) (EvalResult, error) { v1, v2 = makeNumericAndPrioritize(v1, v2) switch v1.typ { case sqltypes.Int64: - return intTimesIntWithError(v1.ival, v2.ival) + return intTimesIntWithError(v1.numval, v2.numval) case sqltypes.Uint64: switch v2.typ { case sqltypes.Int64: - return uintTimesIntWithError(v1.uval, v2.ival) + return uintTimesIntWithError(v1.numval, v2.numval) case sqltypes.Uint64: - return uintTimesUintWithError(v1.uval, v2.uval) + return uintTimesUintWithError(v1.numval, v2.numval) } case sqltypes.Float64: - return floatTimesAny(v1.fval, v2), nil + return floatTimesAny(math.Float64frombits(v1.numval), v2), nil } return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid arithmetic between: %s %s", v1.Value().String(), v2.Value().String()) @@ -547,13 +552,13 @@ func divideNumericWithError(i1, i2 EvalResult) (EvalResult, error) { v2 := makeNumeric(i2) switch v1.typ { case sqltypes.Int64: - return floatDivideAnyWithError(float64(v1.ival), v2) + return floatDivideAnyWithError(float64(int64(v1.numval)), v2) case sqltypes.Uint64: - return floatDivideAnyWithError(float64(v1.uval), v2) + return floatDivideAnyWithError(float64(v1.numval), v2) case sqltypes.Float64: - return floatDivideAnyWithError(v1.fval, v2) + return floatDivideAnyWithError(math.Float64frombits(v1.numval), v2) } return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid arithmetic between: %s %s", v1.Value().String(), v2.Value().String()) } @@ -578,15 +583,15 @@ func makeNumericAndPrioritize(i1, i2 EvalResult) (EvalResult, EvalResult) { func makeFloat(v EvalResult) EvalResult { if sqltypes.IsIntegral(v.typ) { - return EvalResult{fval: float64(v.ival), typ: sqltypes.Float64} + return EvalResult{numval: math.Float64bits(float64(int64(v.numval))), typ: sqltypes.Float64} } if sqltypes.IsFloat(v.typ) { return v } if fval, err := strconv.ParseFloat(string(v.bytes), 64); err == nil { - return EvalResult{fval: fval, typ: sqltypes.Float64} + return EvalResult{numval: math.Float64bits(fval), typ: sqltypes.Float64} } - return EvalResult{ival: 0, typ: sqltypes.Int64} + return EvalResult{numval: 0, typ: sqltypes.Int64} } func makeNumeric(v EvalResult) EvalResult { @@ -594,15 +599,16 @@ func makeNumeric(v EvalResult) EvalResult { return v } if ival, err := strconv.ParseInt(string(v.bytes), 10, 64); err == nil { - return EvalResult{ival: ival, typ: sqltypes.Int64} + return EvalResult{numval: uint64(ival), typ: sqltypes.Int64} } if fval, err := strconv.ParseFloat(string(v.bytes), 64); err == nil { - return EvalResult{fval: fval, typ: sqltypes.Float64} + return EvalResult{numval: math.Float64bits(fval), typ: sqltypes.Float64} } - return EvalResult{ival: 0, typ: sqltypes.Int64} + return EvalResult{numval: 0, typ: sqltypes.Int64} } -func intPlusInt(v1, v2 int64) EvalResult { +func intPlusInt(v1u, v2u uint64) EvalResult { + v1, v2 := int64(v1u), int64(v2u) result := v1 + v2 if v1 > 0 && v2 > 0 && result < 0 { goto overflow @@ -610,71 +616,78 @@ func intPlusInt(v1, v2 int64) EvalResult { if v1 < 0 && v2 < 0 && result > 0 { goto overflow } - return EvalResult{typ: sqltypes.Int64, ival: result} + return EvalResult{typ: sqltypes.Int64, numval: uint64(result)} overflow: - return EvalResult{typ: sqltypes.Float64, fval: float64(v1) + float64(v2)} + return EvalResult{typ: sqltypes.Float64, numval: math.Float64bits(float64(v1) + float64(v2))} } -func intPlusIntWithError(v1, v2 int64) (EvalResult, error) { +func intPlusIntWithError(v1u, v2u uint64) (EvalResult, error) { + v1, v2 := int64(v1u), int64(v2u) result := v1 + v2 if (result > v1) != (v2 > 0) { return EvalResult{}, dataOutOfRangeError(v1, v2, "BIGINT", "+") } - return EvalResult{typ: sqltypes.Int64, ival: result}, nil + return EvalResult{typ: sqltypes.Int64, numval: uint64(result)}, nil } -func intMinusIntWithError(v1, v2 int64) (EvalResult, error) { +func intMinusIntWithError(v1u, v2u uint64) (EvalResult, error) { + v1, v2 := int64(v1u), int64(v2u) result := v1 - v2 if (result < v1) != (v2 > 0) { return EvalResult{}, dataOutOfRangeError(v1, v2, "BIGINT", "-") } - return EvalResult{typ: sqltypes.Int64, ival: result}, nil + return EvalResult{typ: sqltypes.Int64, numval: uint64(result)}, nil } -func intTimesIntWithError(v1, v2 int64) (EvalResult, error) { +func intTimesIntWithError(v1u, v2u uint64) (EvalResult, error) { + v1, v2 := int64(v1u), int64(v2u) result := v1 * v2 if v1 != 0 && result/v1 != v2 { return EvalResult{}, dataOutOfRangeError(v1, v2, "BIGINT", "*") } - return EvalResult{typ: sqltypes.Int64, ival: result}, nil + return EvalResult{typ: sqltypes.Int64, numval: uint64(result)}, nil } -func intMinusUintWithError(v1 int64, v2 uint64) (EvalResult, error) { +func intMinusUintWithError(v1u uint64, v2 uint64) (EvalResult, error) { + v1 := int64(v1u) if v1 < 0 || v1 < int64(v2) { return EvalResult{}, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "-") } - return uintMinusUintWithError(uint64(v1), v2) + return uintMinusUintWithError(v1u, v2) } -func uintPlusInt(v1 uint64, v2 int64) EvalResult { - return uintPlusUint(v1, uint64(v2)) +func uintPlusInt(v1 uint64, v2 uint64) EvalResult { + return uintPlusUint(v1, v2) } -func uintPlusIntWithError(v1 uint64, v2 int64) (EvalResult, error) { +func uintPlusIntWithError(v1 uint64, v2u uint64) (EvalResult, error) { + v2 := int64(v2u) result := v1 + uint64(v2) if v2 < 0 && v1 < uint64(-v2) || v2 > 0 && (result < v1 || result < uint64(v2)) { return EvalResult{}, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "+") } // convert to int -> uint is because for numeric operators (such as + or -) // where one of the operands is an unsigned integer, the result is unsigned by default. - return EvalResult{typ: sqltypes.Uint64, uval: result}, nil + return EvalResult{typ: sqltypes.Uint64, numval: result}, nil } -func uintMinusIntWithError(v1 uint64, v2 int64) (EvalResult, error) { +func uintMinusIntWithError(v1 uint64, v2u uint64) (EvalResult, error) { + v2 := int64(v2u) if int64(v1) < v2 && v2 > 0 { return EvalResult{}, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "-") } // uint - (- int) = uint + int if v2 < 0 { - return uintPlusIntWithError(v1, -v2) + return uintPlusIntWithError(v1, uint64(-v2)) } return uintMinusUintWithError(v1, uint64(v2)) } -func uintTimesIntWithError(v1 uint64, v2 int64) (EvalResult, error) { +func uintTimesIntWithError(v1 uint64, v2u uint64) (EvalResult, error) { + v2 := int64(v2u) if v2 < 0 || int64(v1) < 0 { return EvalResult{}, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "*") } @@ -684,9 +697,9 @@ func uintTimesIntWithError(v1 uint64, v2 int64) (EvalResult, error) { func uintPlusUint(v1, v2 uint64) EvalResult { result := v1 + v2 if result < v2 { - return EvalResult{typ: sqltypes.Float64, fval: float64(v1) + float64(v2)} + return EvalResult{typ: sqltypes.Float64, numval: math.Float64bits(float64(v1) + float64(v2))} } - return EvalResult{typ: sqltypes.Uint64, uval: result} + return EvalResult{typ: sqltypes.Uint64, numval: result} } func uintPlusUintWithError(v1, v2 uint64) (EvalResult, error) { @@ -694,7 +707,7 @@ func uintPlusUintWithError(v1, v2 uint64) (EvalResult, error) { if result < v1 || result < v2 { return EvalResult{}, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "+") } - return EvalResult{typ: sqltypes.Uint64, uval: result}, nil + return EvalResult{typ: sqltypes.Uint64, numval: result}, nil } func uintMinusUintWithError(v1, v2 uint64) (EvalResult, error) { @@ -703,7 +716,7 @@ func uintMinusUintWithError(v1, v2 uint64) (EvalResult, error) { return EvalResult{}, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "-") } - return EvalResult{typ: sqltypes.Uint64, uval: result}, nil + return EvalResult{typ: sqltypes.Uint64, numval: result}, nil } func uintTimesUintWithError(v1, v2 uint64) (EvalResult, error) { @@ -711,65 +724,47 @@ func uintTimesUintWithError(v1, v2 uint64) (EvalResult, error) { if result < v2 || result < v1 { return EvalResult{}, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "*") } - return EvalResult{typ: sqltypes.Uint64, uval: result}, nil + return EvalResult{typ: sqltypes.Uint64, numval: result}, nil } -func floatPlusAny(v1 float64, v2 EvalResult) EvalResult { +func coerceToFloat(v2 EvalResult) float64 { switch v2.typ { case sqltypes.Int64: - v2.fval = float64(v2.ival) + return float64(int64(v2.numval)) case sqltypes.Uint64: - v2.fval = float64(v2.uval) + return float64(v2.numval) + default: + return math.Float64frombits(v2.numval) } - return EvalResult{typ: sqltypes.Float64, fval: v1 + v2.fval} +} + +func floatPlusAny(v1 float64, v2 EvalResult) EvalResult { + return EvalResult{typ: sqltypes.Float64, numval: math.Float64bits(v1 + coerceToFloat(v2))} } func floatMinusAny(v1 float64, v2 EvalResult) EvalResult { - switch v2.typ { - case sqltypes.Int64: - v2.fval = float64(v2.ival) - case sqltypes.Uint64: - v2.fval = float64(v2.uval) - } - return EvalResult{typ: sqltypes.Float64, fval: v1 - v2.fval} + return EvalResult{typ: sqltypes.Float64, numval: math.Float64bits(v1 - coerceToFloat(v2))} } func floatTimesAny(v1 float64, v2 EvalResult) EvalResult { - switch v2.typ { - case sqltypes.Int64: - v2.fval = float64(v2.ival) - case sqltypes.Uint64: - v2.fval = float64(v2.uval) - } - return EvalResult{typ: sqltypes.Float64, fval: v1 * v2.fval} + return EvalResult{typ: sqltypes.Float64, numval: math.Float64bits(v1 * coerceToFloat(v2))} } func floatDivideAnyWithError(v1 float64, v2 EvalResult) (EvalResult, error) { - switch v2.typ { - case sqltypes.Int64: - v2.fval = float64(v2.ival) - case sqltypes.Uint64: - v2.fval = float64(v2.uval) - } - result := v1 / v2.fval - divisorLessThanOne := v2.fval < 1 - resultMismatch := v2.fval*result != v1 + v2f := coerceToFloat(v2) + result := v1 / v2f + divisorLessThanOne := v2f < 1 + resultMismatch := v2f*result != v1 if divisorLessThanOne && resultMismatch { - return EvalResult{}, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.DataOutOfRange, "%s value is out of range in %v / %v", "BIGINT", v1, v2.fval) + return EvalResult{}, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.DataOutOfRange, "%s value is out of range in %v / %v", "BIGINT", v1, v2f) } - return EvalResult{typ: sqltypes.Float64, fval: v1 / v2.fval}, nil + return EvalResult{typ: sqltypes.Float64, numval: math.Float64bits(result)}, nil } func anyMinusFloat(v1 EvalResult, v2 float64) EvalResult { - switch v1.typ { - case sqltypes.Int64: - v1.fval = float64(v1.ival) - case sqltypes.Uint64: - v1.fval = float64(v1.uval) - } - return EvalResult{typ: sqltypes.Float64, fval: v1.fval - v2} + return EvalResult{typ: sqltypes.Float64, numval: math.Float64bits(coerceToFloat(v1) - v2)} } func parseStringToFloat(str string) float64 { diff --git a/go/vt/vtgate/evalengine/arithmetic_test.go b/go/vt/vtgate/evalengine/arithmetic_test.go index 679dd61c2d4..f2e06a0c8d7 100644 --- a/go/vt/vtgate/evalengine/arithmetic_test.go +++ b/go/vt/vtgate/evalengine/arithmetic_test.go @@ -511,7 +511,7 @@ func TestNullSafeAdd(t *testing.T) { } func TestNullsafeCompare(t *testing.T) { - collation := collations.Default().LookupByName("utf8mb4_general_ci").ID() + collation := collations.Local().LookupByName("utf8mb4_general_ci").ID() tcases := []struct { v1, v2 sqltypes.Value out int @@ -613,7 +613,7 @@ func TestNullsafeCompare(t *testing.T) { } func getCollationID(collation string) collations.ID { - id, _ := collations.Default().LookupID(collation) + id, _ := collations.Local().LookupID(collation) return id } @@ -1015,21 +1015,21 @@ func TestNewNumeric(t *testing.T) { err error }{{ v: NewInt64(1), - out: EvalResult{typ: querypb.Type_INT64, ival: 1}, + out: EvalResult{typ: querypb.Type_INT64, numval: 1}, }, { v: NewUint64(1), - out: EvalResult{typ: querypb.Type_UINT64, uval: 1}, + out: EvalResult{typ: querypb.Type_UINT64, numval: 1}, }, { v: NewFloat64(1), - out: EvalResult{typ: querypb.Type_FLOAT64, fval: 1}, + out: EvalResult{typ: querypb.Type_FLOAT64, numval: math.Float64bits(1.0)}, }, { // For non-number type, Int64 is the default. v: TestValue(querypb.Type_VARCHAR, "1"), - out: EvalResult{typ: querypb.Type_INT64, ival: 1}, + out: EvalResult{typ: querypb.Type_INT64, numval: 1}, }, { // If Int64 can't work, we use Float64. v: TestValue(querypb.Type_VARCHAR, "1.2"), - out: EvalResult{typ: querypb.Type_FLOAT64, fval: 1.2}, + out: EvalResult{typ: querypb.Type_FLOAT64, numval: math.Float64bits(1.2)}, }, { // Only valid Int64 allowed if type is Int64. v: TestValue(querypb.Type_INT64, "1.2"), @@ -1044,7 +1044,7 @@ func TestNewNumeric(t *testing.T) { err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseFloat: parsing \"abcd\": invalid syntax"), }, { v: TestValue(querypb.Type_VARCHAR, "abcd"), - out: EvalResult{typ: querypb.Type_FLOAT64, fval: 0}, + out: EvalResult{typ: querypb.Type_FLOAT64, numval: math.Float64bits(0)}, }} for _, tcase := range tcases { got, err := newEvalResult(tcase.v) @@ -1066,21 +1066,21 @@ func TestNewIntegralNumeric(t *testing.T) { err error }{{ v: NewInt64(1), - out: EvalResult{typ: querypb.Type_INT64, ival: 1}, + out: EvalResult{typ: querypb.Type_INT64, numval: 1}, }, { v: NewUint64(1), - out: EvalResult{typ: querypb.Type_UINT64, uval: 1}, + out: EvalResult{typ: querypb.Type_UINT64, numval: 1}, }, { v: NewFloat64(1), - out: EvalResult{typ: querypb.Type_INT64, ival: 1}, + out: EvalResult{typ: querypb.Type_INT64, numval: 1}, }, { // For non-number type, Int64 is the default. v: TestValue(querypb.Type_VARCHAR, "1"), - out: EvalResult{typ: querypb.Type_INT64, ival: 1}, + out: EvalResult{typ: querypb.Type_INT64, numval: 1}, }, { // If Int64 can't work, we use Uint64. v: TestValue(querypb.Type_VARCHAR, "18446744073709551615"), - out: EvalResult{typ: querypb.Type_UINT64, uval: 18446744073709551615}, + out: EvalResult{typ: querypb.Type_UINT64, numval: 18446744073709551615}, }, { // Only valid Int64 allowed if type is Int64. v: TestValue(querypb.Type_INT64, "1.2"), @@ -1112,48 +1112,48 @@ func TestAddNumeric(t *testing.T) { out EvalResult err error }{{ - v1: EvalResult{typ: querypb.Type_INT64, ival: 1}, - v2: EvalResult{typ: querypb.Type_INT64, ival: 2}, - out: EvalResult{typ: querypb.Type_INT64, ival: 3}, + v1: EvalResult{typ: querypb.Type_INT64, numval: 1}, + v2: EvalResult{typ: querypb.Type_INT64, numval: 2}, + out: EvalResult{typ: querypb.Type_INT64, numval: 3}, }, { - v1: EvalResult{typ: querypb.Type_INT64, ival: 1}, - v2: EvalResult{typ: querypb.Type_UINT64, uval: 2}, - out: EvalResult{typ: querypb.Type_UINT64, uval: 3}, + v1: EvalResult{typ: querypb.Type_INT64, numval: 1}, + v2: EvalResult{typ: querypb.Type_UINT64, numval: 2}, + out: EvalResult{typ: querypb.Type_UINT64, numval: 3}, }, { - v1: EvalResult{typ: querypb.Type_INT64, ival: 1}, - v2: EvalResult{typ: querypb.Type_FLOAT64, fval: 2}, - out: EvalResult{typ: querypb.Type_FLOAT64, fval: 3}, + v1: EvalResult{typ: querypb.Type_INT64, numval: 1}, + v2: EvalResult{typ: querypb.Type_FLOAT64, numval: math.Float64bits(2)}, + out: EvalResult{typ: querypb.Type_FLOAT64, numval: math.Float64bits(3)}, }, { - v1: EvalResult{typ: querypb.Type_UINT64, uval: 1}, - v2: EvalResult{typ: querypb.Type_UINT64, uval: 2}, - out: EvalResult{typ: querypb.Type_UINT64, uval: 3}, + v1: EvalResult{typ: querypb.Type_UINT64, numval: 1}, + v2: EvalResult{typ: querypb.Type_UINT64, numval: 2}, + out: EvalResult{typ: querypb.Type_UINT64, numval: 3}, }, { - v1: EvalResult{typ: querypb.Type_UINT64, uval: 1}, - v2: EvalResult{typ: querypb.Type_FLOAT64, fval: 2}, - out: EvalResult{typ: querypb.Type_FLOAT64, fval: 3}, + v1: EvalResult{typ: querypb.Type_UINT64, numval: 1}, + v2: EvalResult{typ: querypb.Type_FLOAT64, numval: math.Float64bits(2)}, + out: EvalResult{typ: querypb.Type_FLOAT64, numval: math.Float64bits(3)}, }, { - v1: EvalResult{typ: querypb.Type_FLOAT64, fval: 1}, - v2: EvalResult{typ: querypb.Type_FLOAT64, fval: 2}, - out: EvalResult{typ: querypb.Type_FLOAT64, fval: 3}, + v1: EvalResult{typ: querypb.Type_FLOAT64, numval: math.Float64bits(1)}, + v2: EvalResult{typ: querypb.Type_FLOAT64, numval: math.Float64bits(2)}, + out: EvalResult{typ: querypb.Type_FLOAT64, numval: math.Float64bits(3)}, }, { // Int64 overflow. - v1: EvalResult{typ: querypb.Type_INT64, ival: 9223372036854775807}, - v2: EvalResult{typ: querypb.Type_INT64, ival: 2}, - out: EvalResult{typ: querypb.Type_FLOAT64, fval: 9223372036854775809}, + v1: EvalResult{typ: querypb.Type_INT64, numval: 9223372036854775807}, + v2: EvalResult{typ: querypb.Type_INT64, numval: 2}, + out: EvalResult{typ: querypb.Type_FLOAT64, numval: math.Float64bits(9223372036854775809)}, }, { // Int64 underflow. - v1: EvalResult{typ: querypb.Type_INT64, ival: -9223372036854775807}, - v2: EvalResult{typ: querypb.Type_INT64, ival: -2}, - out: EvalResult{typ: querypb.Type_FLOAT64, fval: -9223372036854775809}, + v1: EvalResult{typ: querypb.Type_INT64, numval: castuint64(-9223372036854775807)}, + v2: EvalResult{typ: querypb.Type_INT64, numval: castuint64(-2)}, + out: EvalResult{typ: querypb.Type_FLOAT64, numval: math.Float64bits(-9223372036854775809.0)}, }, { - v1: EvalResult{typ: querypb.Type_INT64, ival: -1}, - v2: EvalResult{typ: querypb.Type_UINT64, uval: 2}, - out: EvalResult{typ: querypb.Type_FLOAT64, fval: 18446744073709551617}, + v1: EvalResult{typ: querypb.Type_INT64, numval: castuint64(-1)}, + v2: EvalResult{typ: querypb.Type_UINT64, numval: 2}, + out: EvalResult{typ: querypb.Type_FLOAT64, numval: math.Float64bits(18446744073709551617.0)}, }, { // Uint64 overflow. - v1: EvalResult{typ: querypb.Type_UINT64, uval: 18446744073709551615}, - v2: EvalResult{typ: querypb.Type_UINT64, uval: 2}, - out: EvalResult{typ: querypb.Type_FLOAT64, fval: 18446744073709551617}, + v1: EvalResult{typ: querypb.Type_UINT64, numval: 18446744073709551615}, + v2: EvalResult{typ: querypb.Type_UINT64, numval: 2}, + out: EvalResult{typ: querypb.Type_FLOAT64, numval: math.Float64bits(18446744073709551617.0)}, }} for _, tcase := range tcases { got := addNumeric(tcase.v1, tcase.v2) @@ -1162,10 +1162,14 @@ func TestAddNumeric(t *testing.T) { } } +func castuint64(i int64) uint64 { + return uint64(i) +} + func TestPrioritize(t *testing.T) { - ival := EvalResult{typ: querypb.Type_INT64, ival: -1} - uval := EvalResult{typ: querypb.Type_UINT64, uval: 1} - fval := EvalResult{typ: querypb.Type_FLOAT64, fval: 1.2} + ival := EvalResult{typ: querypb.Type_INT64, numval: castuint64(-1)} + uval := EvalResult{typ: querypb.Type_UINT64, numval: 1} + fval := EvalResult{typ: querypb.Type_FLOAT64, numval: math.Float64bits(1.2)} textIntval := EvalResult{typ: querypb.Type_VARBINARY, bytes: []byte("-1")} textFloatval := EvalResult{typ: querypb.Type_VARBINARY, bytes: []byte("1.2")} @@ -1230,52 +1234,52 @@ func TestToSqlValue(t *testing.T) { err error }{{ typ: querypb.Type_INT64, - v: EvalResult{typ: querypb.Type_INT64, ival: 1}, + v: EvalResult{typ: querypb.Type_INT64, numval: 1}, out: NewInt64(1), }, { typ: querypb.Type_INT64, - v: EvalResult{typ: querypb.Type_UINT64, uval: 1}, + v: EvalResult{typ: querypb.Type_UINT64, numval: 1}, out: NewInt64(1), }, { typ: querypb.Type_INT64, - v: EvalResult{typ: querypb.Type_FLOAT64, fval: 1.2e-16}, + v: EvalResult{typ: querypb.Type_FLOAT64, numval: math.Float64bits(1.2e-16)}, out: NewInt64(0), }, { typ: querypb.Type_UINT64, - v: EvalResult{typ: querypb.Type_INT64, ival: 1}, + v: EvalResult{typ: querypb.Type_INT64, numval: 1}, out: NewUint64(1), }, { typ: querypb.Type_UINT64, - v: EvalResult{typ: querypb.Type_UINT64, uval: 1}, + v: EvalResult{typ: querypb.Type_UINT64, numval: 1}, out: NewUint64(1), }, { typ: querypb.Type_UINT64, - v: EvalResult{typ: querypb.Type_FLOAT64, fval: 1.2e-16}, + v: EvalResult{typ: querypb.Type_FLOAT64, numval: math.Float64bits(1.2e-16)}, out: NewUint64(0), }, { typ: querypb.Type_FLOAT64, - v: EvalResult{typ: querypb.Type_INT64, ival: 1}, + v: EvalResult{typ: querypb.Type_INT64, numval: 1}, out: TestValue(querypb.Type_FLOAT64, "1"), }, { typ: querypb.Type_FLOAT64, - v: EvalResult{typ: querypb.Type_UINT64, uval: 1}, + v: EvalResult{typ: querypb.Type_UINT64, numval: 1}, out: TestValue(querypb.Type_FLOAT64, "1"), }, { typ: querypb.Type_FLOAT64, - v: EvalResult{typ: querypb.Type_FLOAT64, fval: 1.2e-16}, + v: EvalResult{typ: querypb.Type_FLOAT64, numval: math.Float64bits(1.2e-16)}, out: TestValue(querypb.Type_FLOAT64, "1.2e-16"), }, { typ: querypb.Type_DECIMAL, - v: EvalResult{typ: querypb.Type_INT64, ival: 1}, + v: EvalResult{typ: querypb.Type_INT64, numval: 1}, out: TestValue(querypb.Type_DECIMAL, "1"), }, { typ: querypb.Type_DECIMAL, - v: EvalResult{typ: querypb.Type_UINT64, uval: 1}, + v: EvalResult{typ: querypb.Type_UINT64, numval: 1}, out: TestValue(querypb.Type_DECIMAL, "1"), }, { // For float, we should not use scientific notation. typ: querypb.Type_DECIMAL, - v: EvalResult{typ: querypb.Type_FLOAT64, fval: 1.2e-16}, + v: EvalResult{typ: querypb.Type_FLOAT64, numval: math.Float64bits(1.2e-16)}, out: TestValue(querypb.Type_DECIMAL, "0.00000000000000012"), }} for _, tcase := range tcases { @@ -1289,17 +1293,17 @@ func TestToSqlValue(t *testing.T) { func TestCompareNumeric(t *testing.T) { values := []EvalResult{ - {typ: querypb.Type_INT64, ival: 1}, - {typ: querypb.Type_INT64, ival: -1}, - {typ: querypb.Type_INT64, ival: 0}, - {typ: querypb.Type_INT64, ival: 2}, - {typ: querypb.Type_UINT64, uval: 1}, - {typ: querypb.Type_UINT64, uval: 0}, - {typ: querypb.Type_UINT64, uval: 2}, - {typ: querypb.Type_FLOAT64, fval: 1}, - {typ: querypb.Type_FLOAT64, fval: -1}, - {typ: querypb.Type_FLOAT64, fval: 0}, - {typ: querypb.Type_FLOAT64, fval: 2}, + {typ: querypb.Type_INT64, numval: 1}, + {typ: querypb.Type_INT64, numval: castuint64(-1)}, + {typ: querypb.Type_INT64, numval: 0}, + {typ: querypb.Type_INT64, numval: 2}, + {typ: querypb.Type_UINT64, numval: 1}, + {typ: querypb.Type_UINT64, numval: 0}, + {typ: querypb.Type_UINT64, numval: 2}, + {typ: querypb.Type_FLOAT64, numval: math.Float64bits(1.0)}, + {typ: querypb.Type_FLOAT64, numval: math.Float64bits(-1.0)}, + {typ: querypb.Type_FLOAT64, numval: math.Float64bits(0.0)}, + {typ: querypb.Type_FLOAT64, numval: math.Float64bits(2.0)}, } // cmpResults is a 2D array with the comparison expectations if we compare all values with each other @@ -1618,8 +1622,8 @@ func BenchmarkAddGoInterface(b *testing.B) { } func BenchmarkAddGoNonInterface(b *testing.B) { - v1 := EvalResult{typ: querypb.Type_INT64, ival: 1} - v2 := EvalResult{typ: querypb.Type_INT64, ival: 12} + v1 := EvalResult{typ: querypb.Type_INT64, numval: 1} + v2 := EvalResult{typ: querypb.Type_INT64, numval: 12} for i := 0; i < b.N; i++ { if v1.typ != querypb.Type_INT64 { b.Error("type assertion failed") @@ -1627,7 +1631,7 @@ func BenchmarkAddGoNonInterface(b *testing.B) { if v2.typ != querypb.Type_INT64 { b.Error("type assertion failed") } - v1 = EvalResult{typ: querypb.Type_INT64, ival: v1.ival + v2.ival} + v1 = EvalResult{typ: querypb.Type_INT64, numval: uint64(int64(v1.numval) + int64(v2.numval))} } } diff --git a/go/vt/vtgate/evalengine/binary.go b/go/vt/vtgate/evalengine/binary.go index 8637f7b302f..6316055016a 100644 --- a/go/vt/vtgate/evalengine/binary.go +++ b/go/vt/vtgate/evalengine/binary.go @@ -17,6 +17,7 @@ limitations under the License. package evalengine import ( + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" ) @@ -46,8 +47,12 @@ var _ BinaryOp = (*Subtraction)(nil) var _ BinaryOp = (*Multiplication)(nil) var _ BinaryOp = (*Division)(nil) +func (b *BinaryExpr) Collation() collations.TypedCollation { + return collationNumeric +} + // Evaluate implements the Expr interface -func (b *BinaryExpr) Evaluate(env ExpressionEnv) (EvalResult, error) { +func (b *BinaryExpr) Evaluate(env *ExpressionEnv) (EvalResult, error) { lVal, err := b.Left.Evaluate(env) if err != nil { return EvalResult{}, err @@ -60,7 +65,7 @@ func (b *BinaryExpr) Evaluate(env ExpressionEnv) (EvalResult, error) { } // Type implements the Expr interface -func (b *BinaryExpr) Type(env ExpressionEnv) (querypb.Type, error) { +func (b *BinaryExpr) Type(env *ExpressionEnv) (querypb.Type, error) { ltype, err := b.Left.Type(env) if err != nil { return 0, err @@ -73,11 +78,6 @@ func (b *BinaryExpr) Type(env ExpressionEnv) (querypb.Type, error) { return b.Op.Type(typ), nil } -// String implements the Expr interface -func (b *BinaryExpr) String() string { - return b.Left.String() + " " + b.Op.String() + " " + b.Right.String() -} - // Evaluate implements the BinaryOp interface func (a *Addition) Evaluate(left, right EvalResult) (EvalResult, error) { return addNumericWithError(left, right) diff --git a/go/vt/vtgate/evalengine/cached_size.go b/go/vt/vtgate/evalengine/cached_size.go index ae7469a0d15..af872b84816 100644 --- a/go/vt/vtgate/evalengine/cached_size.go +++ b/go/vt/vtgate/evalengine/cached_size.go @@ -57,6 +57,20 @@ func (cached *BindVariable) CachedSize(alloc bool) int64 { size += hack.RuntimeAllocSize(int64(len(cached.Key))) return size } +func (cached *CollateExpr) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(24) + } + // field Expr vitess.io/vitess/go/vt/vtgate/evalengine.Expr + if cc, ok := cached.Expr.(cachedObject); ok { + size += cc.CachedSize(true) + } + return size +} func (cached *Column) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) @@ -73,7 +87,7 @@ func (cached *ComparisonExpr) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(48) + size += int64(80) } // field Op vitess.io/vitess/go/vt/vtgate/evalengine.ComparisonOp if cc, ok := cached.Op.(cachedObject); ok { @@ -89,16 +103,62 @@ func (cached *ComparisonExpr) CachedSize(alloc bool) int64 { } return size } +func (cached *EqualOp) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(24) + } + // field Operator string + size += hack.RuntimeAllocSize(int64(len(cached.Operator))) + return size +} func (cached *EvalResult) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) } size := int64(0) if alloc { - size += int64(64) + size += int64(48) } // field bytes []byte - size += hack.RuntimeAllocSize(int64(cap(cached.bytes))) + { + size += hack.RuntimeAllocSize(int64(cap(cached.bytes))) + } + // field tuple *[]vitess.io/vitess/go/vt/vtgate/evalengine.EvalResult + if cached.tuple != nil { + size += int64(24) + size += hack.RuntimeAllocSize(int64(cap(*cached.tuple)) * int64(48)) + for _, elem := range *cached.tuple { + size += elem.CachedSize(false) + } + } + return size +} +func (cached *InOp) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(8) + } + return size +} +func (cached *LikeOp) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(24) + } + // field Match vitess.io/vitess/go/mysql/collations.WildcardPattern + if cc, ok := cached.Match.(cachedObject); ok { + size += cc.CachedSize(true) + } return size } func (cached *Literal) CachedSize(alloc bool) int64 { @@ -107,9 +167,19 @@ func (cached *Literal) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(64) + size += int64(48) } // field Val vitess.io/vitess/go/vt/vtgate/evalengine.EvalResult size += cached.Val.CachedSize(false) return size } +func (cached *RegexpOp) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(8) + } + return size +} diff --git a/go/vt/vtgate/evalengine/casting.go b/go/vt/vtgate/evalengine/casting.go index 0ea9bf80bc7..97add6ba2e9 100644 --- a/go/vt/vtgate/evalengine/casting.go +++ b/go/vt/vtgate/evalengine/casting.go @@ -40,9 +40,9 @@ func (e *EvalResult) ToBooleanStrict() (bool, error) { switch e.typ { case sqltypes.Int8, sqltypes.Int16, sqltypes.Int32, sqltypes.Int64: - return intToBool(int(e.ival)) + return intToBool(int(e.numval)) case sqltypes.Uint8, sqltypes.Uint16, sqltypes.Uint32, sqltypes.Uint64: - return intToBool(int(e.uval)) + return intToBool(int(e.numval)) case sqltypes.VarBinary: lower := strings.ToLower(string(e.bytes)) switch lower { diff --git a/go/vt/vtgate/evalengine/casting_test.go b/go/vt/vtgate/evalengine/casting_test.go index befd4fadf34..12a241c3aed 100644 --- a/go/vt/vtgate/evalengine/casting_test.go +++ b/go/vt/vtgate/evalengine/casting_test.go @@ -18,6 +18,7 @@ package evalengine import ( "fmt" + "math" "testing" "github.com/stretchr/testify/require" @@ -27,36 +28,36 @@ import ( func TestEvalResultToBooleanStrict(t *testing.T) { trueValues := []*EvalResult{{ - typ: sqltypes.Int64, - ival: 1, + typ: sqltypes.Int64, + numval: 1, }, { - typ: sqltypes.Uint64, - uval: 1, + typ: sqltypes.Uint64, + numval: 1, }, { - typ: sqltypes.Int8, - ival: 1, + typ: sqltypes.Int8, + numval: 1, }} falseValues := []*EvalResult{{ - typ: sqltypes.Int64, - ival: 0, + typ: sqltypes.Int64, + numval: 0, }, { - typ: sqltypes.Uint64, - uval: 0, + typ: sqltypes.Uint64, + numval: 0, }, { - typ: sqltypes.Int8, - uval: 0, + typ: sqltypes.Int8, + numval: 0, }} invalid := []*EvalResult{{ typ: sqltypes.VarChar, bytes: []byte("foobar"), }, { - typ: sqltypes.Float32, - fval: 1, + typ: sqltypes.Float32, + numval: math.Float64bits(1.0), }, { - typ: sqltypes.Int64, - ival: 12, + typ: sqltypes.Int64, + numval: 12, }} for _, res := range trueValues { diff --git a/go/vt/vtgate/evalengine/comparisons.go b/go/vt/vtgate/evalengine/comparisons.go index 7be6653fe15..0ca57b0e636 100644 --- a/go/vt/vtgate/evalengine/comparisons.go +++ b/go/vt/vtgate/evalengine/comparisons.go @@ -17,6 +17,7 @@ limitations under the License. package evalengine import ( + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" @@ -33,54 +34,64 @@ type ( } ComparisonExpr struct { - Op ComparisonOp - Left, Right Expr + Op ComparisonOp + Left, Right Expr + CoerceLeft, CoerceRight collations.Coercion + TypedCollation collations.TypedCollation + } + + EqualOp struct { + Operator string + Compare func(cmp int) bool } - EqualOp struct{} - NotEqualOp struct{} NullSafeEqualOp struct{} - LessThanOp struct{} - LessEqualOp struct{} - GreaterThanOp struct{} - GreaterEqualOp struct{} - InOp struct{} - NotInOp struct{} - LikeOp struct{} - NotLikeOp struct{} - RegexpOp struct{} - NotRegexpOp struct{} + + InOp struct { + Negate bool + Hashed map[uintptr]int + } + LikeOp struct { + Negate bool + Match collations.WildcardPattern + } + RegexpOp struct { + Negate bool + } ) var ( - resultTrue = EvalResult{typ: sqltypes.Int32, ival: 1} - resultFalse = EvalResult{typ: sqltypes.Int32, ival: 0} + resultTrue = EvalResult{typ: sqltypes.Int32, numval: 1} + resultFalse = EvalResult{typ: sqltypes.Int32, numval: 0} resultNull = EvalResult{typ: sqltypes.Null} ) var _ ComparisonOp = (*EqualOp)(nil) -var _ ComparisonOp = (*NotEqualOp)(nil) -var _ ComparisonOp = (*NullSafeEqualOp)(nil) -var _ ComparisonOp = (*LessThanOp)(nil) -var _ ComparisonOp = (*LessEqualOp)(nil) -var _ ComparisonOp = (*GreaterThanOp)(nil) -var _ ComparisonOp = (*GreaterEqualOp)(nil) var _ ComparisonOp = (*InOp)(nil) -var _ ComparisonOp = (*NotInOp)(nil) var _ ComparisonOp = (*LikeOp)(nil) -var _ ComparisonOp = (*NotLikeOp)(nil) var _ ComparisonOp = (*RegexpOp)(nil) -var _ ComparisonOp = (*NotRegexpOp)(nil) -func (c *ComparisonExpr) evaluateComparisonExprs(env ExpressionEnv) (EvalResult, EvalResult, error) { +func (c *ComparisonExpr) Collation() collations.TypedCollation { + return c.TypedCollation +} + +func (c *ComparisonExpr) evaluateComparisonExprs(env *ExpressionEnv) (EvalResult, EvalResult, error) { var lVal, rVal EvalResult var err error if lVal, err = c.Left.Evaluate(env); err != nil { return EvalResult{}, EvalResult{}, err } + if sqltypes.IsText(lVal.typ) && c.CoerceLeft != nil { + lVal.bytes, _ = c.CoerceLeft(nil, lVal.bytes) + lVal.collation = c.TypedCollation + } if rVal, err = c.Right.Evaluate(env); err != nil { return EvalResult{}, EvalResult{}, err } + if sqltypes.IsText(rVal.typ) && c.CoerceRight != nil { + rVal.bytes, _ = c.CoerceRight(nil, rVal.bytes) + rVal.collation = c.TypedCollation + } return lVal, rVal, nil } @@ -123,9 +134,19 @@ func evalResultsAreDateAndNumeric(l, r EvalResult) bool { return sqltypes.IsDate(l.typ) && sqltypes.IsNumber(r.typ) || sqltypes.IsNumber(l.typ) && sqltypes.IsDate(r.typ) } +func nullSafeCoerceAndCompare(lVal, rVal EvalResult) (comp int, isNull bool, err error) { + if lVal.collation.Collation != rVal.collation.Collation { + lVal, rVal, err = mergeCollations(lVal, rVal) + if err != nil { + return 0, false, err + } + } + return nullSafeCompare(lVal, rVal) +} + // For more details on comparison expression evaluation and type conversion: // - https://dev.mysql.com/doc/refman/8.0/en/type-conversion.html -func nullSafeExecuteComparison(lVal, rVal EvalResult) (comp int, isNull bool, err error) { +func nullSafeCompare(lVal, rVal EvalResult) (comp int, isNull bool, err error) { lVal = foldSingleLenTuples(lVal) rVal = foldSingleLenTuples(rVal) if hasNullEvalResult(lVal, rVal) { @@ -133,8 +154,8 @@ func nullSafeExecuteComparison(lVal, rVal EvalResult) (comp int, isNull bool, er } switch { case evalResultsAreStrings(lVal, rVal): - comp, err = compareStrings(lVal, rVal) - return comp, false, err + comp = compareStrings(lVal, rVal) + return comp, false, nil case evalResultsAreSameNumericType(lVal, rVal), needsDecimalHandling(lVal, rVal): comp, err = compareNumeric(lVal, rVal) @@ -159,7 +180,7 @@ func nullSafeExecuteComparison(lVal, rVal EvalResult) (comp int, isNull bool, er case lVal.typ == querypb.Type_TUPLE && rVal.typ == querypb.Type_TUPLE: return compareTuples(lVal, rVal) case lVal.typ == querypb.Type_TUPLE: - return 0, false, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.OperandColumns, "Operand should contain %d column(s)", len(lVal.tupleResults)) + return 0, false, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.OperandColumns, "Operand should contain %d column(s)", len(*lVal.tuple)) case rVal.typ == querypb.Type_TUPLE: return 0, false, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.OperandColumns, "Operand should contain 1 column(s)") @@ -177,14 +198,21 @@ func nullSafeExecuteComparison(lVal, rVal EvalResult) (comp int, isNull bool, er } func foldSingleLenTuples(val EvalResult) EvalResult { - if val.typ == querypb.Type_TUPLE && len(val.tupleResults) == 1 { - val = val.tupleResults[0] + if val.typ == querypb.Type_TUPLE && len(*val.tuple) == 1 { + return (*val.tuple)[0] } return val } +func boolResult(result, negate bool) EvalResult { + if result == !negate { + return resultTrue + } + return resultFalse +} + // Evaluate implements the Expr interface -func (c *ComparisonExpr) Evaluate(env ExpressionEnv) (EvalResult, error) { +func (c *ComparisonExpr) Evaluate(env *ExpressionEnv) (EvalResult, error) { if c.Op == nil { return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "a comparison expression needs a comparison operator") } @@ -198,25 +226,21 @@ func (c *ComparisonExpr) Evaluate(env ExpressionEnv) (EvalResult, error) { } // Type implements the Expr interface -func (c *ComparisonExpr) Type(ExpressionEnv) (querypb.Type, error) { +func (c *ComparisonExpr) Type(*ExpressionEnv) (querypb.Type, error) { return querypb.Type_INT32, nil } -// String implements the Expr interface -func (c *ComparisonExpr) String() string { - return c.Left.String() + " " + c.Op.String() + " " + c.Right.String() -} - // Evaluate implements the ComparisonOp interface func (e *EqualOp) Evaluate(left, right EvalResult) (EvalResult, error) { - numeric, isNull, err := nullSafeExecuteComparison(left, right) + // No need to coerce here because the caller ComparisonExpr.Evaluate has coerced for us + numeric, isNull, err := nullSafeCompare(left, right) if err != nil { return EvalResult{}, err } if isNull { return resultNull, err } - if numeric == 0 { + if e.Compare(numeric) { return resultTrue, nil } return resultFalse, nil @@ -229,32 +253,7 @@ func (e *EqualOp) Type() querypb.Type { // String implements the ComparisonOp interface func (e *EqualOp) String() string { - return "=" -} - -// Evaluate implements the ComparisonOp interface -func (n *NotEqualOp) Evaluate(left, right EvalResult) (EvalResult, error) { - numeric, isNull, err := nullSafeExecuteComparison(left, right) - if err != nil { - return EvalResult{}, err - } - if isNull { - return resultNull, err - } - if numeric != 0 { - return resultTrue, nil - } - return resultFalse, nil -} - -// Type implements the ComparisonOp interface -func (n *NotEqualOp) Type() querypb.Type { - return querypb.Type_INT32 -} - -// String implements the ComparisonOp interface -func (n *NotEqualOp) String() string { - return "!=" + return e.Operator } // Evaluate implements the ComparisonOp interface @@ -272,126 +271,51 @@ func (n *NullSafeEqualOp) String() string { return "<=>" } -// Evaluate implements the ComparisonOp interface -func (l *LessThanOp) Evaluate(left, right EvalResult) (EvalResult, error) { - numeric, isNull, err := nullSafeExecuteComparison(left, right) - if err != nil { - return EvalResult{}, err - } - if isNull { - return resultNull, err - } - if numeric < 0 { - return resultTrue, nil - } - return resultFalse, nil -} - -// Type implements the ComparisonOp interface -func (l *LessThanOp) Type() querypb.Type { - return querypb.Type_INT32 -} - -// String implements the ComparisonOp interface -func (l *LessThanOp) String() string { - return "<" -} - -// Evaluate implements the ComparisonOp interface -func (l *LessEqualOp) Evaluate(left, right EvalResult) (EvalResult, error) { - numeric, isNull, err := nullSafeExecuteComparison(left, right) - if err != nil { - return EvalResult{}, err - } - if isNull { - return resultNull, err - } - if numeric <= 0 { - return resultTrue, nil - } - return resultFalse, nil -} - -// Type implements the ComparisonOp interface -func (l *LessEqualOp) Type() querypb.Type { - return querypb.Type_INT32 -} - -// String implements the ComparisonOp interface -func (l *LessEqualOp) String() string { - return "<=" -} - -// Evaluate implements the ComparisonOp interface -func (g *GreaterThanOp) Evaluate(left, right EvalResult) (EvalResult, error) { - numeric, isNull, err := nullSafeExecuteComparison(left, right) - if err != nil { - return EvalResult{}, err - } - if isNull { - return resultNull, err - } - if numeric > 0 { - return resultTrue, nil - } - return resultFalse, nil -} - -// Type implements the ComparisonOp interface -func (g *GreaterThanOp) Type() querypb.Type { - return querypb.Type_INT32 -} - -// String implements the ComparisonOp interface -func (g *GreaterThanOp) String() string { - return ">" -} - -// Evaluate implements the ComparisonOp interface -func (g *GreaterEqualOp) Evaluate(left, right EvalResult) (EvalResult, error) { - numeric, isNull, err := nullSafeExecuteComparison(left, right) - if err != nil { - return EvalResult{}, err - } - if isNull { - return resultNull, err - } - if numeric >= 0 { - return resultTrue, nil - } - return resultFalse, nil -} - -// Type implements the ComparisonOp interface -func (g *GreaterEqualOp) Type() querypb.Type { - return querypb.Type_INT32 -} - -// String implements the ComparisonOp interface -func (g *GreaterEqualOp) String() string { - return ">=" -} - // Evaluate implements the ComparisonOp interface func (i *InOp) Evaluate(left, right EvalResult) (EvalResult, error) { if right.typ != querypb.Type_TUPLE { return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "rhs of an In operation should be a tuple") } - returnValue := resultFalse - for _, result := range right.tupleResults { - res, err := (&EqualOp{}).Evaluate(left, result) + + var foundNull, found bool + + if i.Hashed != nil { + hash, err := left.nullSafeHashcode() if err != nil { return EvalResult{}, err } - if res.typ == querypb.Type_NULL_TYPE { - returnValue = resultNull - continue + if idx, ok := i.Hashed[hash]; ok { + var numeric int + numeric, foundNull, err = nullSafeCoerceAndCompare(left, (*right.tuple)[idx]) + if err != nil { + return EvalResult{}, err + } + found = numeric == 0 } - if sqltypes.IsIntegral(res.typ) && res.ival == 1 { - return resultTrue, nil + } else { + for _, rtuple := range *right.tuple { + numeric, isNull, err := nullSafeCoerceAndCompare(left, rtuple) + if err != nil { + return EvalResult{}, err + } + if isNull { + foundNull = true + continue + } + if numeric == 0 { + found = true + break + } } } - return returnValue, nil + + if found { + return boolResult(found, i.Negate), nil + } + if foundNull { + return resultNull, nil + } + return boolResult(found, i.Negate), nil } // Type implements the ComparisonOp interface @@ -401,29 +325,25 @@ func (i *InOp) Type() querypb.Type { // String implements the ComparisonOp interface func (i *InOp) String() string { + if i.Negate { + return "not in" + } return "in" } -// Evaluate implements the ComparisonOp interface -func (n *NotInOp) Evaluate(left, right EvalResult) (EvalResult, error) { - res, err := (&InOp{}).Evaluate(left, right) - res.ival = 1 - res.ival - return res, err -} - -// Type implements the ComparisonOp interface -func (n *NotInOp) Type() querypb.Type { - return querypb.Type_INT32 -} - -// String implements the ComparisonOp interface -func (n *NotInOp) String() string { - return "not in" -} - -// Evaluate implements the ComparisonOp interface func (l *LikeOp) Evaluate(left, right EvalResult) (EvalResult, error) { - panic("implement me") + if left.collation.Collation != right.collation.Collation { + panic("LikeOp: did not coerce") + } + var matched bool + if l.Match != nil { + matched = l.Match.Match(left.bytes) + } else { + coll := collations.Local().LookupByID(left.collation.Collation) + wc := coll.Wildcard(right.bytes, 0, 0, 0) + matched = wc.Match(left.bytes) + } + return boolResult(matched, l.Negate), nil } // Type implements the ComparisonOp interface @@ -433,25 +353,12 @@ func (l *LikeOp) Type() querypb.Type { // String implements the ComparisonOp interface func (l *LikeOp) String() string { + if l.Negate { + return "not like" + } return "like" } -// Evaluate implements the ComparisonOp interface -func (n *NotLikeOp) Evaluate(left, right EvalResult) (EvalResult, error) { - panic("implement me") -} - -// Type implements the ComparisonOp interface -func (n *NotLikeOp) Type() querypb.Type { - return querypb.Type_INT32 -} - -// String implements the ComparisonOp interface -func (n *NotLikeOp) String() string { - return "not like" -} - -// Evaluate implements the ComparisonOp interface func (r *RegexpOp) Evaluate(left, right EvalResult) (EvalResult, error) { panic("implement me") } @@ -465,18 +372,3 @@ func (r *RegexpOp) Type() querypb.Type { func (r *RegexpOp) String() string { return "regexp" } - -// Evaluate implements the ComparisonOp interface -func (n *NotRegexpOp) Evaluate(left, right EvalResult) (EvalResult, error) { - panic("implement me") -} - -// Type implements the ComparisonOp interface -func (n *NotRegexpOp) Type() querypb.Type { - return querypb.Type_INT32 -} - -// String implements the ComparisonOp interface -func (n *NotRegexpOp) String() string { - return "not regexp" -} diff --git a/go/vt/vtgate/evalengine/comparisons_test.go b/go/vt/vtgate/evalengine/comparisons_test.go index 24d389486eb..5292fdaf913 100644 --- a/go/vt/vtgate/evalengine/comparisons_test.go +++ b/go/vt/vtgate/evalengine/comparisons_test.go @@ -24,6 +24,7 @@ import ( "github.com/stretchr/testify/require" "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/sqltypes" @@ -44,14 +45,18 @@ var ( T = true F = false - defaultCollation = collations.Default().LookupByName("utf8mb4_bin").ID() + defaultCollation = collations.TypedCollation{ + Collation: collations.Local().LookupByName("utf8mb4_bin").ID(), + Coercibility: collations.CoerceImplicit, + Repertoire: collations.RepertoireASCII, + } ) func (tc testCase) run(t *testing.T) { if tc.bv == nil { tc.bv = map[string]*querypb.BindVariable{} } - env := ExpressionEnv{ + env := &ExpressionEnv{ BindVars: tc.bv, Row: tc.row, } @@ -64,9 +69,9 @@ func (tc testCase) run(t *testing.T) { if tc.err == "" { require.NoError(t, err) if tc.out != nil && *tc.out { - require.EqualValues(t, 1, got.ival) + require.EqualValues(t, 1, got.numval) } else if tc.out != nil && !*tc.out { - require.EqualValues(t, 0, got.ival) + require.EqualValues(t, 0, got.numval) } else { require.EqualValues(t, sqltypes.Null, got.typ) } @@ -75,93 +80,95 @@ func (tc testCase) run(t *testing.T) { } } +var cmpop = translateComparisonOperator + // This test tests the comparison of two integers func TestCompareIntegers(t *testing.T) { tests := []testCase{ { name: "integers are equal (1)", v1: NewColumn(0, defaultCollation), v2: NewColumn(0, defaultCollation), - out: &T, op: &EqualOp{}, + out: &T, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewInt64(18)}, }, { name: "integers are equal (2)", v1: NewLiteralInt(56), v2: NewLiteralInt(56), - out: &T, op: &EqualOp{}, + out: &T, op: cmpop(sqlparser.EqualOp), }, { name: "integers are not equal (1)", v1: NewLiteralInt(56), v2: NewLiteralInt(10), - out: &F, op: &EqualOp{}, + out: &F, op: cmpop(sqlparser.EqualOp), }, { name: "integers are not equal (2)", v1: NewLiteralInt(56), v2: NewLiteralInt(10), - out: &T, op: &NotEqualOp{}, + out: &T, op: cmpop(sqlparser.NotEqualOp), }, { name: "integers are not equal (3)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &F, op: &EqualOp{}, + out: &F, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewInt64(18), sqltypes.NewInt64(98)}, }, { name: "unsigned integers are equal", v1: NewColumn(0, defaultCollation), v2: NewColumn(0, defaultCollation), - out: &T, op: &EqualOp{}, + out: &T, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewUint64(18)}, }, { name: "unsigned integer and integer are equal", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &EqualOp{}, + out: &T, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewUint64(18), sqltypes.NewInt64(18)}, }, { name: "unsigned integer and integer are not equal", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &NotEqualOp{}, + out: &T, op: cmpop(sqlparser.NotEqualOp), row: []sqltypes.Value{sqltypes.NewUint64(18), sqltypes.NewInt64(42)}, }, { name: "integer is less than integer", v1: NewLiteralInt(3549), v2: NewLiteralInt(8072), - out: &T, op: &LessThanOp{}, + out: &T, op: cmpop(sqlparser.LessThanOp), }, { name: "integer is not less than integer", v1: NewLiteralInt(3549), v2: NewLiteralInt(21), - out: &F, op: &LessThanOp{}, + out: &F, op: cmpop(sqlparser.LessThanOp), }, { name: "integer is less-equal to integer (1)", v1: NewLiteralInt(3549), v2: NewLiteralInt(9863), - out: &T, op: &LessEqualOp{}, + out: &T, op: cmpop(sqlparser.LessEqualOp), }, { name: "integer is less-equal to integer (2)", v1: NewLiteralInt(3549), v2: NewLiteralInt(3549), - out: &T, op: &LessEqualOp{}, + out: &T, op: cmpop(sqlparser.LessEqualOp), }, { name: "integer is greater than integer", v1: NewLiteralInt(9809), v2: NewLiteralInt(9800), - out: &T, op: &GreaterThanOp{}, + out: &T, op: cmpop(sqlparser.GreaterThanOp), }, { name: "integer is not greater than integer", v1: NewLiteralInt(549), v2: NewLiteralInt(21579), - out: &F, op: &GreaterThanOp{}, + out: &F, op: cmpop(sqlparser.GreaterThanOp), }, { name: "integer is greater-equal to integer (1)", v1: NewLiteralInt(987), v2: NewLiteralInt(15), - out: &T, op: &GreaterEqualOp{}, + out: &T, op: cmpop(sqlparser.GreaterEqualOp), }, { name: "integer is greater-equal to integer (2)", v1: NewLiteralInt(3549), v2: NewLiteralInt(3549), - out: &T, op: &GreaterEqualOp{}, + out: &T, op: cmpop(sqlparser.GreaterEqualOp), }, } @@ -178,69 +185,69 @@ func TestCompareFloats(t *testing.T) { { name: "floats are equal (1)", v1: NewColumn(0, defaultCollation), v2: NewColumn(0, defaultCollation), - out: &T, op: &EqualOp{}, + out: &T, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewFloat64(18)}, }, { name: "floats are equal (2)", v1: NewLiteralFloat(3549.9), v2: NewLiteralFloat(3549.9), - out: &T, op: &EqualOp{}, + out: &T, op: cmpop(sqlparser.EqualOp), }, { name: "floats are not equal (1)", v1: NewLiteralFloat(7858.016), v2: NewLiteralFloat(8943298.56), - out: &F, op: &EqualOp{}, + out: &F, op: cmpop(sqlparser.EqualOp), }, { name: "floats are not equal (2)", v1: NewLiteralFloat(351049.65), v2: NewLiteralFloat(62508.99), - out: &T, op: &NotEqualOp{}, + out: &T, op: cmpop(sqlparser.NotEqualOp), }, { name: "floats are not equal (3)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &F, op: &EqualOp{}, + out: &F, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewFloat64(16516.84), sqltypes.NewFloat64(219541.01)}, }, { name: "float is less than float", v1: NewLiteralFloat(3549.9), v2: NewLiteralFloat(8072), - out: &T, op: &LessThanOp{}, + out: &T, op: cmpop(sqlparser.LessThanOp), }, { name: "float is not less than float", v1: NewLiteralFloat(3549.9), v2: NewLiteralFloat(21.564), - out: &F, op: &LessThanOp{}, + out: &F, op: cmpop(sqlparser.LessThanOp), }, { name: "float is less-equal to float (1)", v1: NewLiteralFloat(3549.9), v2: NewLiteralFloat(9863), - out: &T, op: &LessEqualOp{}, + out: &T, op: cmpop(sqlparser.LessEqualOp), }, { name: "float is less-equal to float (2)", v1: NewLiteralFloat(3549.9), v2: NewLiteralFloat(3549.9), - out: &T, op: &LessEqualOp{}, + out: &T, op: cmpop(sqlparser.LessEqualOp), }, { name: "float is greater than float", v1: NewLiteralFloat(9808.549), v2: NewLiteralFloat(9808.540), - out: &T, op: &GreaterThanOp{}, + out: &T, op: cmpop(sqlparser.GreaterThanOp), }, { name: "float is not greater than float", v1: NewLiteralFloat(549.02), v2: NewLiteralFloat(21579.64), - out: &F, op: &GreaterThanOp{}, + out: &F, op: cmpop(sqlparser.GreaterThanOp), }, { name: "float is greater-equal to float (1)", v1: NewLiteralFloat(987.30), v2: NewLiteralFloat(15.5), - out: &T, op: &GreaterEqualOp{}, + out: &T, op: cmpop(sqlparser.GreaterEqualOp), }, { name: "float is greater-equal to float (2)", v1: NewLiteralFloat(3549.9), v2: NewLiteralFloat(3549.9), - out: &T, op: &GreaterEqualOp{}, + out: &T, op: cmpop(sqlparser.GreaterEqualOp), }, } @@ -257,37 +264,37 @@ func TestCompareDecimals(t *testing.T) { { name: "decimals are equal", v1: NewColumn(0, defaultCollation), v2: NewColumn(0, defaultCollation), - out: &T, op: &EqualOp{}, + out: &T, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewDecimal("12.9019")}, }, { name: "decimals are not equal", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &NotEqualOp{}, + out: &T, op: cmpop(sqlparser.NotEqualOp), row: []sqltypes.Value{sqltypes.NewDecimal("12.9019"), sqltypes.NewDecimal("489.156849")}, }, { name: "decimal is greater than decimal", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &GreaterThanOp{}, + out: &T, op: cmpop(sqlparser.GreaterThanOp), row: []sqltypes.Value{sqltypes.NewDecimal("192.129"), sqltypes.NewDecimal("192.128")}, }, { name: "decimal is not greater than decimal", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &F, op: &GreaterThanOp{}, + out: &F, op: cmpop(sqlparser.GreaterThanOp), row: []sqltypes.Value{sqltypes.NewDecimal("192.128"), sqltypes.NewDecimal("192.129")}, }, { name: "decimal is less than decimal", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &LessThanOp{}, + out: &T, op: cmpop(sqlparser.LessThanOp), row: []sqltypes.Value{sqltypes.NewDecimal("192.128"), sqltypes.NewDecimal("192.129")}, }, { name: "decimal is not less than decimal", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &F, op: &LessThanOp{}, + out: &F, op: cmpop(sqlparser.LessThanOp), row: []sqltypes.Value{sqltypes.NewDecimal("192.129"), sqltypes.NewDecimal("192.128")}, }, } @@ -305,121 +312,121 @@ func TestCompareNumerics(t *testing.T) { { name: "decimal and float are equal", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &EqualOp{}, + out: &T, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewFloat64(189.6), sqltypes.NewDecimal("189.6")}, }, { name: "decimal and float with negative values are equal", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &EqualOp{}, + out: &T, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewFloat64(-98.1839), sqltypes.NewDecimal("-98.1839")}, }, { name: "decimal and float with negative values are not equal (1)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &F, op: &EqualOp{}, + out: &F, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewFloat64(-98.9381), sqltypes.NewDecimal("-98.1839")}, }, { name: "decimal and float with negative values are not equal (2)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &NotEqualOp{}, + out: &T, op: cmpop(sqlparser.NotEqualOp), row: []sqltypes.Value{sqltypes.NewFloat64(-98.9381), sqltypes.NewDecimal("-98.1839")}, }, { name: "decimal and integer are equal (1)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &EqualOp{}, + out: &T, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewInt64(8979), sqltypes.NewDecimal("8979")}, }, { name: "decimal and integer are equal (2)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &EqualOp{}, + out: &T, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewDecimal("8979.0000"), sqltypes.NewInt64(8979)}, }, { name: "decimal and unsigned integer are equal (1)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &EqualOp{}, + out: &T, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewUint64(901), sqltypes.NewDecimal("901")}, }, { name: "decimal and unsigned integer are equal (2)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &EqualOp{}, + out: &T, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewDecimal("901.00"), sqltypes.NewUint64(901)}, }, { name: "decimal and unsigned integer are not equal (1)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &NotEqualOp{}, + out: &T, op: cmpop(sqlparser.NotEqualOp), row: []sqltypes.Value{sqltypes.NewDecimal("192.129"), sqltypes.NewUint64(192)}, }, { name: "decimal and unsigned integer are not equal (2)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &F, op: &EqualOp{}, + out: &F, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewDecimal("192.129"), sqltypes.NewUint64(192)}, }, { name: "decimal is greater than integer", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &GreaterThanOp{}, + out: &T, op: cmpop(sqlparser.GreaterThanOp), row: []sqltypes.Value{sqltypes.NewDecimal("1.01"), sqltypes.NewInt64(1)}, }, { name: "decimal is greater-equal to integer", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &GreaterEqualOp{}, + out: &T, op: cmpop(sqlparser.GreaterEqualOp), row: []sqltypes.Value{sqltypes.NewDecimal("1.00"), sqltypes.NewInt64(1)}, }, { name: "decimal is less than integer", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &LessThanOp{}, + out: &T, op: cmpop(sqlparser.LessThanOp), row: []sqltypes.Value{sqltypes.NewDecimal(".99"), sqltypes.NewInt64(1)}, }, { name: "decimal is less-equal to integer", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &LessEqualOp{}, + out: &T, op: cmpop(sqlparser.LessEqualOp), row: []sqltypes.Value{sqltypes.NewDecimal("1.00"), sqltypes.NewInt64(1)}, }, { name: "decimal is greater than float", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &GreaterThanOp{}, + out: &T, op: cmpop(sqlparser.GreaterThanOp), row: []sqltypes.Value{sqltypes.NewDecimal("849.896"), sqltypes.NewFloat64(86.568)}, }, { name: "decimal is not greater than float", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &F, op: &GreaterThanOp{}, + out: &F, op: cmpop(sqlparser.GreaterThanOp), row: []sqltypes.Value{sqltypes.NewDecimal("15.23"), sqltypes.NewFloat64(8689.5)}, }, { name: "decimal is greater-equal to float (1)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &GreaterEqualOp{}, + out: &T, op: cmpop(sqlparser.GreaterEqualOp), row: []sqltypes.Value{sqltypes.NewDecimal("65"), sqltypes.NewFloat64(65)}, }, { name: "decimal is greater-equal to float (2)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &GreaterEqualOp{}, + out: &T, op: cmpop(sqlparser.GreaterEqualOp), row: []sqltypes.Value{sqltypes.NewDecimal("65"), sqltypes.NewFloat64(60)}, }, { name: "decimal is less than float", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &LessThanOp{}, + out: &T, op: cmpop(sqlparser.LessThanOp), row: []sqltypes.Value{sqltypes.NewDecimal("0.998"), sqltypes.NewFloat64(0.999)}, }, { name: "decimal is less-equal to float", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &LessEqualOp{}, + out: &T, op: cmpop(sqlparser.LessEqualOp), row: []sqltypes.Value{sqltypes.NewDecimal("1.000101"), sqltypes.NewFloat64(1.00101)}, }, } @@ -437,73 +444,73 @@ func TestCompareDatetime(t *testing.T) { { name: "datetimes are equal", v1: NewColumn(0, defaultCollation), v2: NewColumn(0, defaultCollation), - out: &T, op: &EqualOp{}, + out: &T, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewDatetime("2021-10-22 12:00:00")}, }, { name: "datetimes are not equal (1)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &F, op: &EqualOp{}, + out: &F, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewDatetime("2021-10-22 12:00:00"), sqltypes.NewDatetime("2020-10-22 12:00:00")}, }, { name: "datetimes are not equal (2)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &F, op: &EqualOp{}, + out: &F, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewDatetime("2021-10-22 12:00:00"), sqltypes.NewDatetime("2021-10-22 10:23:56")}, }, { name: "datetimes are not equal (3)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &NotEqualOp{}, + out: &T, op: cmpop(sqlparser.NotEqualOp), row: []sqltypes.Value{sqltypes.NewDatetime("2021-10-01 00:00:00"), sqltypes.NewDatetime("2021-02-01 00:00:00")}, }, { name: "datetime is greater than datetime", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &GreaterThanOp{}, + out: &T, op: cmpop(sqlparser.GreaterThanOp), row: []sqltypes.Value{sqltypes.NewDatetime("2021-10-30 10:42:50"), sqltypes.NewDatetime("2021-10-01 13:10:02")}, }, { name: "datetime is not greater than datetime", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &F, op: &GreaterThanOp{}, + out: &F, op: cmpop(sqlparser.GreaterThanOp), row: []sqltypes.Value{sqltypes.NewDatetime("2021-10-01 13:10:02"), sqltypes.NewDatetime("2021-10-30 10:42:50")}, }, { name: "datetime is less than datetime", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &LessThanOp{}, + out: &T, op: cmpop(sqlparser.LessThanOp), row: []sqltypes.Value{sqltypes.NewDatetime("2021-10-01 13:10:02"), sqltypes.NewDatetime("2021-10-30 10:42:50")}, }, { name: "datetime is not less than datetime", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &F, op: &LessThanOp{}, + out: &F, op: cmpop(sqlparser.LessThanOp), row: []sqltypes.Value{sqltypes.NewDatetime("2021-10-30 10:42:50"), sqltypes.NewDatetime("2021-10-01 13:10:02")}, }, { name: "datetime is greater-equal to datetime (1)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &GreaterEqualOp{}, + out: &T, op: cmpop(sqlparser.GreaterEqualOp), row: []sqltypes.Value{sqltypes.NewDatetime("2021-10-30 10:42:50"), sqltypes.NewDatetime("2021-10-30 10:42:50")}, }, { name: "datetime is greater-equal to datetime (2)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &GreaterEqualOp{}, + out: &T, op: cmpop(sqlparser.GreaterEqualOp), row: []sqltypes.Value{sqltypes.NewDatetime("2021-10-30 10:42:50"), sqltypes.NewDatetime("2021-10-01 13:10:02")}, }, { name: "datetime is less-equal to datetime (1)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &LessEqualOp{}, + out: &T, op: cmpop(sqlparser.LessEqualOp), row: []sqltypes.Value{sqltypes.NewDatetime("2021-10-30 10:42:50"), sqltypes.NewDatetime("2021-10-30 10:42:50")}, }, { name: "datetime is less-equal to datetime (2)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &LessEqualOp{}, + out: &T, op: cmpop(sqlparser.LessEqualOp), row: []sqltypes.Value{sqltypes.NewDatetime("2021-10-01 13:10:02"), sqltypes.NewDatetime("2021-10-30 10:42:50")}, }, } @@ -521,73 +528,73 @@ func TestCompareTimestamp(t *testing.T) { { name: "timestamps are equal", v1: NewColumn(0, defaultCollation), v2: NewColumn(0, defaultCollation), - out: &T, op: &EqualOp{}, + out: &T, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewTimestamp("2021-10-22 12:00:00")}, }, { name: "timestamps are not equal (1)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &F, op: &EqualOp{}, + out: &F, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewTimestamp("2021-10-22 12:00:00"), sqltypes.NewTimestamp("2020-10-22 12:00:00")}, }, { name: "timestamps are not equal (2)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &F, op: &EqualOp{}, + out: &F, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewTimestamp("2021-10-22 12:00:00"), sqltypes.NewTimestamp("2021-10-22 10:23:56")}, }, { name: "timestamps are not equal (3)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &NotEqualOp{}, + out: &T, op: cmpop(sqlparser.NotEqualOp), row: []sqltypes.Value{sqltypes.NewTimestamp("2021-10-01 00:00:00"), sqltypes.NewTimestamp("2021-02-01 00:00:00")}, }, { name: "timestamp is greater than timestamp", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &GreaterThanOp{}, + out: &T, op: cmpop(sqlparser.GreaterThanOp), row: []sqltypes.Value{sqltypes.NewTimestamp("2021-10-30 10:42:50"), sqltypes.NewTimestamp("2021-10-01 13:10:02")}, }, { name: "timestamp is not greater than timestamp", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &F, op: &GreaterThanOp{}, + out: &F, op: cmpop(sqlparser.GreaterThanOp), row: []sqltypes.Value{sqltypes.NewTimestamp("2021-10-01 13:10:02"), sqltypes.NewTimestamp("2021-10-30 10:42:50")}, }, { name: "timestamp is less than timestamp", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &LessThanOp{}, + out: &T, op: cmpop(sqlparser.LessThanOp), row: []sqltypes.Value{sqltypes.NewTimestamp("2021-10-01 13:10:02"), sqltypes.NewTimestamp("2021-10-30 10:42:50")}, }, { name: "timestamp is not less than timestamp", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &F, op: &LessThanOp{}, + out: &F, op: cmpop(sqlparser.LessThanOp), row: []sqltypes.Value{sqltypes.NewTimestamp("2021-10-30 10:42:50"), sqltypes.NewTimestamp("2021-10-01 13:10:02")}, }, { name: "timestamp is greater-equal to timestamp (1)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &GreaterEqualOp{}, + out: &T, op: cmpop(sqlparser.GreaterEqualOp), row: []sqltypes.Value{sqltypes.NewTimestamp("2021-10-30 10:42:50"), sqltypes.NewTimestamp("2021-10-30 10:42:50")}, }, { name: "timestamp is greater-equal to timestamp (2)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &GreaterEqualOp{}, + out: &T, op: cmpop(sqlparser.GreaterEqualOp), row: []sqltypes.Value{sqltypes.NewTimestamp("2021-10-30 10:42:50"), sqltypes.NewTimestamp("2021-10-01 13:10:02")}, }, { name: "timestamp is less-equal to timestamp (1)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &LessEqualOp{}, + out: &T, op: cmpop(sqlparser.LessEqualOp), row: []sqltypes.Value{sqltypes.NewTimestamp("2021-10-30 10:42:50"), sqltypes.NewTimestamp("2021-10-30 10:42:50")}, }, { name: "timestamp is less-equal to timestamp (2)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &LessEqualOp{}, + out: &T, op: cmpop(sqlparser.LessEqualOp), row: []sqltypes.Value{sqltypes.NewTimestamp("2021-10-01 13:10:02"), sqltypes.NewTimestamp("2021-10-30 10:42:50")}, }, } @@ -605,67 +612,67 @@ func TestCompareDate(t *testing.T) { { name: "dates are equal", v1: NewColumn(0, defaultCollation), v2: NewColumn(0, defaultCollation), - out: &T, op: &EqualOp{}, + out: &T, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewDate("2021-10-22")}, }, { name: "dates are not equal (1)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &F, op: &EqualOp{}, + out: &F, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewDate("2021-10-22"), sqltypes.NewDate("2020-10-21")}, }, { name: "dates are not equal (2)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &NotEqualOp{}, + out: &T, op: cmpop(sqlparser.NotEqualOp), row: []sqltypes.Value{sqltypes.NewDate("2021-10-01"), sqltypes.NewDate("2021-02-01")}, }, { name: "date is greater than date", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &GreaterThanOp{}, + out: &T, op: cmpop(sqlparser.GreaterThanOp), row: []sqltypes.Value{sqltypes.NewDate("2021-10-30"), sqltypes.NewDate("2021-10-01")}, }, { name: "date is not greater than date", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &F, op: &GreaterThanOp{}, + out: &F, op: cmpop(sqlparser.GreaterThanOp), row: []sqltypes.Value{sqltypes.NewDate("2021-10-01"), sqltypes.NewDate("2021-10-30")}, }, { name: "date is less than date", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &LessThanOp{}, + out: &T, op: cmpop(sqlparser.LessThanOp), row: []sqltypes.Value{sqltypes.NewDate("2021-10-01"), sqltypes.NewDate("2021-10-30")}, }, { name: "date is not less than date", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &F, op: &LessThanOp{}, + out: &F, op: cmpop(sqlparser.LessThanOp), row: []sqltypes.Value{sqltypes.NewDate("2021-10-30"), sqltypes.NewDate("2021-10-01")}, }, { name: "date is greater-equal to date (1)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &GreaterEqualOp{}, + out: &T, op: cmpop(sqlparser.GreaterEqualOp), row: []sqltypes.Value{sqltypes.NewDate("2021-10-30"), sqltypes.NewDate("2021-10-30")}, }, { name: "date is greater-equal to date (2)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &GreaterEqualOp{}, + out: &T, op: cmpop(sqlparser.GreaterEqualOp), row: []sqltypes.Value{sqltypes.NewDate("2021-10-30"), sqltypes.NewDate("2021-10-01")}, }, { name: "date is less-equal to date (1)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &LessEqualOp{}, + out: &T, op: cmpop(sqlparser.LessEqualOp), row: []sqltypes.Value{sqltypes.NewDate("2021-10-30"), sqltypes.NewDate("2021-10-30")}, }, { name: "date is less-equal to date (2)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &LessEqualOp{}, + out: &T, op: cmpop(sqlparser.LessEqualOp), row: []sqltypes.Value{sqltypes.NewDate("2021-10-01"), sqltypes.NewDate("2021-10-30")}, }, } @@ -683,67 +690,67 @@ func TestCompareTime(t *testing.T) { { name: "times are equal", v1: NewColumn(0, defaultCollation), v2: NewColumn(0, defaultCollation), - out: &T, op: &EqualOp{}, + out: &T, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewTime("12:00:00")}, }, { name: "times are not equal (1)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &F, op: &EqualOp{}, + out: &F, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewTime("12:00:00"), sqltypes.NewTime("10:23:56")}, }, { name: "times are not equal (2)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &NotEqualOp{}, + out: &T, op: cmpop(sqlparser.NotEqualOp), row: []sqltypes.Value{sqltypes.NewTime("00:00:00"), sqltypes.NewTime("10:15:00")}, }, { name: "time is greater than time", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &GreaterThanOp{}, + out: &T, op: cmpop(sqlparser.GreaterThanOp), row: []sqltypes.Value{sqltypes.NewTime("18:14:35"), sqltypes.NewTime("13:01:38")}, }, { name: "time is not greater than time", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &F, op: &GreaterThanOp{}, + out: &F, op: cmpop(sqlparser.GreaterThanOp), row: []sqltypes.Value{sqltypes.NewTime("02:46:02"), sqltypes.NewTime("10:42:50")}, }, { name: "time is less than time", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &LessThanOp{}, + out: &T, op: cmpop(sqlparser.LessThanOp), row: []sqltypes.Value{sqltypes.NewTime("04:30:00"), sqltypes.NewTime("09:23:48")}, }, { name: "time is not less than time", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &F, op: &LessThanOp{}, + out: &F, op: cmpop(sqlparser.LessThanOp), row: []sqltypes.Value{sqltypes.NewTime("15:21:00"), sqltypes.NewTime("10:00:00")}, }, { name: "time is greater-equal to time (1)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &GreaterEqualOp{}, + out: &T, op: cmpop(sqlparser.GreaterEqualOp), row: []sqltypes.Value{sqltypes.NewTime("10:42:50"), sqltypes.NewTime("10:42:50")}, }, { name: "time is greater-equal to time (2)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &GreaterEqualOp{}, + out: &T, op: cmpop(sqlparser.GreaterEqualOp), row: []sqltypes.Value{sqltypes.NewTime("19:42:50"), sqltypes.NewTime("13:10:02")}, }, { name: "time is less-equal to time (1)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &LessEqualOp{}, + out: &T, op: cmpop(sqlparser.LessEqualOp), row: []sqltypes.Value{sqltypes.NewTime("10:42:50"), sqltypes.NewTime("10:42:50")}, }, { name: "time is less-equal to time (2)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &LessEqualOp{}, + out: &T, op: cmpop(sqlparser.LessEqualOp), row: []sqltypes.Value{sqltypes.NewTime("10:10:02"), sqltypes.NewTime("10:42:50")}, }, } @@ -761,13 +768,13 @@ func TestCompareDates(t *testing.T) { { name: "date equal datetime", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &EqualOp{}, + out: &T, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewDate("2021-10-22"), sqltypes.NewDatetime("2021-10-22 00:00:00")}, }, { name: "date equal datetime through bind variables", v1: NewBindVar("k1", defaultCollation), v2: NewBindVar("k2", defaultCollation), - out: &T, op: &EqualOp{}, + out: &T, op: cmpop(sqlparser.EqualOp), bv: map[string]*querypb.BindVariable{ "k1": {Type: sqltypes.Date, Value: []byte("2021-10-22")}, "k2": {Type: sqltypes.Datetime, Value: []byte("2021-10-22 00:00:00")}, @@ -776,7 +783,7 @@ func TestCompareDates(t *testing.T) { { name: "date not equal datetime through bind variables", v1: NewBindVar("k1", defaultCollation), v2: NewBindVar("k2", defaultCollation), - out: &T, op: &NotEqualOp{}, + out: &T, op: cmpop(sqlparser.NotEqualOp), bv: map[string]*querypb.BindVariable{ "k1": {Type: sqltypes.Date, Value: []byte("2021-02-20")}, "k2": {Type: sqltypes.Datetime, Value: []byte("2021-10-22 00:00:00")}, @@ -785,73 +792,73 @@ func TestCompareDates(t *testing.T) { { name: "date not equal datetime", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &NotEqualOp{}, + out: &T, op: cmpop(sqlparser.NotEqualOp), row: []sqltypes.Value{sqltypes.NewDate("2021-10-22"), sqltypes.NewDatetime("2021-10-20 00:06:00")}, }, { name: "date equal timestamp", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &EqualOp{}, + out: &T, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewDate("2021-10-22"), sqltypes.NewTimestamp("2021-10-22 00:00:00")}, }, { name: "date not equal timestamp", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &NotEqualOp{}, + out: &T, op: cmpop(sqlparser.NotEqualOp), row: []sqltypes.Value{sqltypes.NewDate("2021-10-22"), sqltypes.NewTimestamp("2021-10-22 16:00:00")}, }, { name: "date equal time", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &EqualOp{}, + out: &T, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewDate(time.Now().Format("2006-01-02")), sqltypes.NewTime("00:00:00")}, }, { name: "date not equal time", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &NotEqualOp{}, + out: &T, op: cmpop(sqlparser.NotEqualOp), row: []sqltypes.Value{sqltypes.NewDate(time.Now().Format("2006-01-02")), sqltypes.NewTime("12:00:00")}, }, { name: "string equal datetime", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &EqualOp{}, + out: &T, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewVarChar("2021-10-22"), sqltypes.NewDatetime("2021-10-22 00:00:00")}, }, { name: "string equal timestamp", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &EqualOp{}, + out: &T, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewVarChar("2021-10-22 00:00:00"), sqltypes.NewTimestamp("2021-10-22 00:00:00")}, }, { name: "string not equal timestamp", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &NotEqualOp{}, + out: &T, op: cmpop(sqlparser.NotEqualOp), row: []sqltypes.Value{sqltypes.NewVarChar("2021-10-22 06:00:30"), sqltypes.NewTimestamp("2021-10-20 15:02:10")}, }, { name: "string equal time", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &EqualOp{}, + out: &T, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewVarChar("00:05:12"), sqltypes.NewTime("00:05:12")}, }, { name: "string equal date", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &EqualOp{}, + out: &T, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewVarChar("2021-02-22"), sqltypes.NewDate("2021-02-22")}, }, { name: "string not equal date (1, date on the RHS)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &NotEqualOp{}, + out: &T, op: cmpop(sqlparser.NotEqualOp), row: []sqltypes.Value{sqltypes.NewVarChar("2021-02-20"), sqltypes.NewDate("2021-03-30")}, }, { name: "string not equal date (2, date on the LHS)", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &NotEqualOp{}, + out: &T, op: cmpop(sqlparser.NotEqualOp), row: []sqltypes.Value{sqltypes.NewDate("2021-03-30"), sqltypes.NewVarChar("2021-02-20")}, }, } @@ -869,22 +876,15 @@ func TestCompareStrings(t *testing.T) { { name: "string equal string", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &EqualOp{}, + out: &T, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewVarChar("toto"), sqltypes.NewVarChar("toto")}, }, { name: "string equal number", v1: NewColumn(0, defaultCollation), v2: NewColumn(1, defaultCollation), - out: &T, op: &EqualOp{}, + out: &T, op: cmpop(sqlparser.EqualOp), row: []sqltypes.Value{sqltypes.NewVarChar("1"), sqltypes.NewInt64(1)}, }, - { - name: "string equal string unknown collation", - v1: NewColumn(0, collations.Unknown), v2: NewColumn(1, collations.Unknown), - op: &EqualOp{}, - err: "cannot compare strings with an unknown collation", - row: []sqltypes.Value{sqltypes.NewVarChar("1"), sqltypes.NewVarChar("1")}, - }, } for i, tcase := range tests { @@ -899,47 +899,47 @@ func TestInOp(t *testing.T) { tests := []testCase{ { name: "integer In tuple", - v1: NewLiteralInt(52), v2: Tuple{NewLiteralInt(52), NewLiteralInt(54)}, + v1: NewLiteralInt(52), v2: TupleExpr{NewLiteralInt(52), NewLiteralInt(54)}, out: &T, op: &InOp{}, }, { name: "integer not In tuple", - v1: NewLiteralInt(51), v2: Tuple{NewLiteralInt(52), NewLiteralInt(54)}, + v1: NewLiteralInt(51), v2: TupleExpr{NewLiteralInt(52), NewLiteralInt(54)}, out: &F, op: &InOp{}, }, { name: "integer In tuple - single value", - v1: NewLiteralInt(52), v2: Tuple{NewLiteralInt(52)}, + v1: NewLiteralInt(52), v2: TupleExpr{NewLiteralInt(52)}, out: &T, op: &InOp{}, }, { name: "integer not In tuple - single value", - v1: NewLiteralInt(51), v2: Tuple{NewLiteralInt(52)}, + v1: NewLiteralInt(51), v2: TupleExpr{NewLiteralInt(52)}, out: &F, op: &InOp{}, }, { name: "integer not In tuple - no value", - v1: NewLiteralInt(51), v2: Tuple{}, + v1: NewLiteralInt(51), v2: TupleExpr{}, out: &F, op: &InOp{}, }, { name: "integer not In tuple - null value", - v1: NewLiteralInt(51), v2: Tuple{Null{}}, + v1: NewLiteralInt(51), v2: TupleExpr{NewLiteralNull()}, out: nil, op: &InOp{}, }, { name: "integer not In tuple but with Null inside", - v1: NewLiteralInt(52), v2: Tuple{Null{}, NewLiteralInt(51), NewLiteralInt(54), Null{}}, + v1: NewLiteralInt(52), v2: TupleExpr{NewLiteralNull(), NewLiteralInt(51), NewLiteralInt(54), NewLiteralNull()}, out: nil, op: &InOp{}, }, { name: "integer In tuple with null inside", - v1: NewLiteralInt(52), v2: Tuple{Null{}, NewLiteralInt(52), NewLiteralInt(54)}, + v1: NewLiteralInt(52), v2: TupleExpr{NewLiteralNull(), NewLiteralInt(52), NewLiteralInt(54)}, out: &T, op: &InOp{}, }, { name: "Null In tuple", - v1: Null{}, v2: Tuple{Null{}, NewLiteralInt(52), NewLiteralInt(54)}, + v1: NewLiteralNull(), v2: TupleExpr{NewLiteralNull(), NewLiteralInt(52), NewLiteralInt(54)}, out: nil, op: &InOp{}, }, @@ -957,49 +957,49 @@ func TestNotInOp(t *testing.T) { tests := []testCase{ { name: "integer In tuple", - v1: NewLiteralInt(52), v2: Tuple{NewLiteralInt(52), NewLiteralInt(54)}, + v1: NewLiteralInt(52), v2: TupleExpr{NewLiteralInt(52), NewLiteralInt(54)}, out: &F, - op: &NotInOp{}, + op: &InOp{Negate: true}, }, { name: "integer not In tuple", - v1: NewLiteralInt(51), v2: Tuple{NewLiteralInt(52), NewLiteralInt(54)}, + v1: NewLiteralInt(51), v2: TupleExpr{NewLiteralInt(52), NewLiteralInt(54)}, out: &T, - op: &NotInOp{}, + op: &InOp{Negate: true}, }, { name: "integer In tuple - single value", - v1: NewLiteralInt(52), v2: Tuple{NewLiteralInt(52)}, + v1: NewLiteralInt(52), v2: TupleExpr{NewLiteralInt(52)}, out: &F, - op: &NotInOp{}, + op: &InOp{Negate: true}, }, { name: "integer not In tuple - single value", - v1: NewLiteralInt(51), v2: Tuple{NewLiteralInt(52)}, + v1: NewLiteralInt(51), v2: TupleExpr{NewLiteralInt(52)}, out: &T, - op: &NotInOp{}, + op: &InOp{Negate: true}, }, { name: "integer not In tuple - no value", - v1: NewLiteralInt(51), v2: Tuple{}, + v1: NewLiteralInt(51), v2: TupleExpr{}, out: &T, - op: &NotInOp{}, + op: &InOp{Negate: true}, }, { name: "integer not In tuple - null value", - v1: NewLiteralInt(51), v2: Tuple{Null{}}, + v1: NewLiteralInt(51), v2: TupleExpr{NewLiteralNull()}, out: nil, - op: &NotInOp{}, + op: &InOp{Negate: true}, }, { name: "integer not In tuple but with Null inside", - v1: NewLiteralInt(52), v2: Tuple{Null{}, NewLiteralInt(51), NewLiteralInt(54), Null{}}, + v1: NewLiteralInt(52), v2: TupleExpr{NewLiteralNull(), NewLiteralInt(51), NewLiteralInt(54), NewLiteralNull()}, out: nil, - op: &NotInOp{}, + op: &InOp{Negate: true}, }, { name: "integer In tuple with null inside", - v1: NewLiteralInt(52), v2: Tuple{Null{}, NewLiteralInt(52), NewLiteralInt(54)}, + v1: NewLiteralInt(52), v2: TupleExpr{NewLiteralNull(), NewLiteralInt(52), NewLiteralInt(54)}, out: &F, - op: &NotInOp{}, + op: &InOp{Negate: true}, }, { name: "Null In tuple", - v1: Null{}, v2: Tuple{Null{}, NewLiteralInt(52), NewLiteralInt(54)}, + v1: NewLiteralNull(), v2: TupleExpr{NewLiteralNull(), NewLiteralInt(52), NewLiteralInt(54)}, out: nil, - op: &NotInOp{}, + op: &InOp{Negate: true}, }, } diff --git a/go/vt/vtgate/evalengine/convert.go b/go/vt/vtgate/evalengine/convert.go index 9447f104726..ce4d479ab27 100644 --- a/go/vt/vtgate/evalengine/convert.go +++ b/go/vt/vtgate/evalengine/convert.go @@ -18,6 +18,7 @@ package evalengine import ( "vitess.io/vitess/go/mysql/collations" + querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" @@ -37,46 +38,209 @@ var ErrConvertExprNotSupported = "expr cannot be converted, not supported" func translateComparisonOperator(op sqlparser.ComparisonExprOperator) ComparisonOp { switch op { case sqlparser.EqualOp: - return &EqualOp{} + return &EqualOp{"=", func(cmp int) bool { return cmp == 0 }} case sqlparser.LessThanOp: - return &LessThanOp{} + return &EqualOp{"<", func(cmp int) bool { return cmp < 0 }} case sqlparser.GreaterThanOp: - return &GreaterThanOp{} + return &EqualOp{">", func(cmp int) bool { return cmp > 0 }} case sqlparser.LessEqualOp: - return &LessEqualOp{} + return &EqualOp{"<=", func(cmp int) bool { return cmp <= 0 }} case sqlparser.GreaterEqualOp: - return &GreaterEqualOp{} + return &EqualOp{">=", func(cmp int) bool { return cmp >= 0 }} case sqlparser.NotEqualOp: - return &NotEqualOp{} + return &EqualOp{"!=", func(cmp int) bool { return cmp != 0 }} case sqlparser.NullSafeEqualOp: return &NullSafeEqualOp{} case sqlparser.InOp: return &InOp{} case sqlparser.NotInOp: - return &NotInOp{} + return &InOp{Negate: true} case sqlparser.LikeOp: return &LikeOp{} case sqlparser.NotLikeOp: - return &NotLikeOp{} + return &LikeOp{Negate: true} case sqlparser.RegexpOp: return &RegexpOp{} case sqlparser.NotRegexpOp: - return &NotRegexpOp{} + return &RegexpOp{Negate: true} default: return nil } } -func getCollation(expr sqlparser.Expr, lookup ConverterLookup) collations.ID { - collation := collations.Unknown +func getCollation(expr sqlparser.Expr, lookup ConverterLookup) collations.TypedCollation { + collation := collations.TypedCollation{ + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireUnicode, + } if lookup != nil { - collation = lookup.CollationIDLookup(expr) + collation.Collation = lookup.CollationIDLookup(expr) + } else { + sysdefault, _ := collations.Local().ResolveCollation("", "") + collation.Collation = sysdefault.ID() } return collation } +func ConvertEx(e sqlparser.Expr, lookup ConverterLookup, simplify bool) (Expr, error) { + expr, err := convertExpr(e, lookup) + if err != nil { + return nil, err + } + if simplify { + expr, err = simplifyExpr(expr) + } + return expr, err +} + // Convert converts between AST expressions and executable expressions func Convert(e sqlparser.Expr, lookup ConverterLookup) (Expr, error) { + return ConvertEx(e, lookup, true) +} + +func simplifyExpr(e Expr) (Expr, error) { + var err error + + switch node := e.(type) { + case *ComparisonExpr: + node.Left, err = simplifyExpr(node.Left) + if err != nil { + return nil, err + } + node.Right, err = simplifyExpr(node.Right) + if err != nil { + return nil, err + } + lit1, _ := node.Left.(*Literal) + lit2, _ := node.Right.(*Literal) + if lit1 != nil && lit2 != nil { + res, err := node.Evaluate(nil) + if err != nil { + return nil, err + } + return &Literal{Val: res}, nil + } + + if lit1 != nil && node.CoerceLeft != nil { + lit1.Val.bytes, _ = node.CoerceLeft(nil, lit1.Val.bytes) + lit1.Val.collation = node.TypedCollation + node.CoerceLeft = nil + } + if lit2 != nil && node.CoerceRight != nil { + lit2.Val.bytes, _ = node.CoerceRight(nil, lit2.Val.bytes) + lit2.Val.collation = node.TypedCollation + node.CoerceRight = nil + } + + switch op := node.Op.(type) { + case *LikeOp: + if lit2 != nil { + coll := collations.Local().LookupByID(node.TypedCollation.Collation) + op.Match = coll.Wildcard(lit2.Val.bytes, 0, 0, 0) + } + + case *InOp: + if tuple, ok := node.Right.(TupleExpr); ok { + var ( + collation collations.ID + typ querypb.Type + optimize = true + literalTuple = true + ) + + for i, expr := range tuple { + if lit, ok := expr.(*Literal); ok { + thisColl := lit.Val.collation.Collation + thisTyp := lit.Val.typ + if i == 0 { + collation = thisColl + typ = thisTyp + continue + } + if collation == thisColl && typ == thisTyp { + continue + } + optimize = false + continue + } + literalTuple = false + break + } + + if lit1 != nil && literalTuple { + res, err := node.Evaluate(nil) + if err != nil { + return nil, err + } + return &Literal{Val: res}, nil + } + + if optimize && literalTuple { + op.Hashed = make(map[HashCode]int) + for i, expr := range tuple { + lit := expr.(*Literal) + hash, err := lit.Val.nullSafeHashcode() + if err != nil { + op.Hashed = nil + break + } + if collidx, collision := op.Hashed[hash]; collision { + cmp, _, err := nullSafeCompare(lit.Val, tuple[collidx].(*Literal).Val) + if cmp != 0 || err != nil { + op.Hashed = nil + break + } + continue + } + op.Hashed[hash] = i + } + } + } + } + + case *BinaryExpr: + node.Left, err = simplifyExpr(node.Left) + if err != nil { + return nil, err + } + node.Right, err = simplifyExpr(node.Right) + if err != nil { + return nil, err + } + _, lit1 := node.Left.(*Literal) + _, lit2 := node.Right.(*Literal) + if lit1 && lit2 { + res, err := node.Evaluate(nil) + if err != nil { + return nil, err + } + return &Literal{Val: res}, nil + } + + case *CollateExpr: + lit, _ := node.Expr.(*Literal) + if lit != nil { + res, err := node.Evaluate(nil) + if err != nil { + return nil, err + } + return &Literal{Val: res}, nil + } + + case TupleExpr: + var err error + for i, expr := range node { + expr, err = simplifyExpr(expr) + if err != nil { + return nil, err + } + node[i] = expr + } + } + return e, nil +} + +func convertExpr(e sqlparser.Expr, lookup ConverterLookup) (Expr, error) { switch node := e.(type) { case *sqlparser.ColName: if lookup == nil { @@ -89,24 +253,37 @@ func Convert(e sqlparser.Expr, lookup ConverterLookup) (Expr, error) { collation := getCollation(node, lookup) return NewColumn(idx, collation), nil case *sqlparser.ComparisonExpr: - left, err := Convert(node.Left, lookup) + left, err := convertExpr(node.Left, lookup) if err != nil { return nil, err } - right, err := Convert(node.Right, lookup) + right, err := convertExpr(node.Right, lookup) if err != nil { return nil, err } - return &ComparisonExpr{ + comp := &ComparisonExpr{ Op: translateComparisonOperator(node.Operator), Left: left, Right: right, - }, nil - case sqlparser.Argument: - collation := collations.Unknown - if lookup != nil { - collation = lookup.CollationIDLookup(e) } + + leftColl := left.Collation() + rightColl := right.Collation() + if leftColl.Valid() && rightColl.Valid() { + env := collations.Local() + comp.TypedCollation, comp.CoerceLeft, comp.CoerceRight, err = + env.MergeCollations(leftColl, rightColl, collations.CoercionOptions{ + ConvertToSuperset: true, + ConvertWithCoercion: true, + }) + if err != nil { + return nil, err + } + } + + return comp, nil + case sqlparser.Argument: + collation := getCollation(e, lookup) return NewBindVar(string(node), collation), nil case *sqlparser.Literal: switch node.Type { @@ -122,9 +299,9 @@ func Convert(e sqlparser.Expr, lookup ConverterLookup) (Expr, error) { } case sqlparser.BoolVal: if node { - return NewLiteralIntFromBytes([]byte("1")) + return NewLiteralInt(1), nil } - return NewLiteralIntFromBytes([]byte("0")) + return NewLiteralInt(0), nil case *sqlparser.BinaryExpr: var op BinaryOp switch node.Operator { @@ -139,11 +316,11 @@ func Convert(e sqlparser.Expr, lookup ConverterLookup) (Expr, error) { default: return nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "%s: %T", ErrConvertExprNotSupported, e) } - left, err := Convert(node.Left, lookup) + left, err := convertExpr(node.Left, lookup) if err != nil { return nil, err } - right, err := Convert(node.Right, lookup) + right, err := convertExpr(node.Right, lookup) if err != nil { return nil, err } @@ -153,17 +330,34 @@ func Convert(e sqlparser.Expr, lookup ConverterLookup) (Expr, error) { Right: right, }, nil case sqlparser.ValTuple: - var res Tuple + var exprs TupleExpr for _, expr := range node { - convertedExpr, err := Convert(expr, lookup) + convertedExpr, err := convertExpr(expr, lookup) if err != nil { return nil, err } - res = append(res, convertedExpr) + exprs = append(exprs, convertedExpr) } - return res, nil + return exprs, nil case *sqlparser.NullVal: - return Null{}, nil + return NewLiteralNull(), nil + case *sqlparser.CollateExpr: + expr, err := convertExpr(node.Expr, lookup) + if err != nil { + return nil, err + } + coll := collations.Local().LookupByName(node.Collation) + if coll == nil { + return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "Unknown collation: '%s'", node.Collation) + } + return &CollateExpr{ + Expr: expr, + TypedCollation: collations.TypedCollation{ + Collation: coll.ID(), + Coercibility: collations.CoerceExplicit, + Repertoire: collations.RepertoireUnicode, + }, + }, nil } return nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "%s: %T", ErrConvertExprNotSupported, e) } diff --git a/go/vt/vtgate/evalengine/convert_test.go b/go/vt/vtgate/evalengine/convert_test.go index b350d155312..a97313656da 100644 --- a/go/vt/vtgate/evalengine/convert_test.go +++ b/go/vt/vtgate/evalengine/convert_test.go @@ -17,8 +17,10 @@ limitations under the License. package evalengine import ( + "strings" "testing" + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/sqltypes" @@ -34,6 +36,98 @@ These tests should in theory live in the sqltypes package but they live here so exercise both expression conversion and evaluation in the same test file */ +type dummyCollation collations.ID + +func (d dummyCollation) ColumnLookup(_ *sqlparser.ColName) (int, error) { + panic("not supported") +} + +func (d dummyCollation) CollationIDLookup(_ sqlparser.Expr) collations.ID { + return collations.ID(d) +} + +func TestConvertSimplification(t *testing.T) { + type ast struct { + literal, err string + } + ok := func(in string) ast { + return ast{literal: in} + } + err := func(in string) ast { + return ast{err: in} + } + + var testCases = []struct { + expression string + converted ast + simplified ast + }{ + {"42", ok("INT64(42)"), ok("INT64(42)")}, + {"1 + (1 + 1) * 8", ok("INT64(1) + ((INT64(1) + INT64(1)) * INT64(8))"), ok("INT64(17)")}, + {"1.0 + (1 + 1) * 8.0", ok("FLOAT64(1) + ((INT64(1) + INT64(1)) * FLOAT64(8))"), ok("FLOAT64(17)")}, + {"'pokemon' LIKE 'poke%'", ok("VARBINARY(\"pokemon\") like VARBINARY(\"poke%\")"), ok("INT32(1)")}, + { + "'foo' COLLATE utf8mb4_general_ci IN ('bar' COLLATE latin1_swedish_ci, 'baz')", + ok(`VARBINARY("foo") COLLATE utf8mb4_general_ci in (VARBINARY("bar") COLLATE latin1_swedish_ci, VARBINARY("baz"))`), + err("COLLATION 'latin1_swedish_ci' is not valid for CHARACTER SET 'utf8mb4'"), + }, + {`"pokemon" in ("bulbasaur", "venusaur", "charizard")`, + ok(`VARBINARY("pokemon") in (VARBINARY("bulbasaur"), VARBINARY("venusaur"), VARBINARY("charizard"))`), + ok("INT32(0)"), + }, + {`"pokemon" in ("bulbasaur", "venusaur", "pokemon")`, + ok(`VARBINARY("pokemon") in (VARBINARY("bulbasaur"), VARBINARY("venusaur"), VARBINARY("pokemon"))`), + ok("INT32(1)"), + }, + {`"pokemon" in ("bulbasaur", "venusaur", "pokemon", NULL)`, + ok(`VARBINARY("pokemon") in (VARBINARY("bulbasaur"), VARBINARY("venusaur"), VARBINARY("pokemon"), NULL)`), + ok(`INT32(1)`), + }, + {`"pokemon" in ("bulbasaur", "venusaur", NULL)`, + ok(`VARBINARY("pokemon") in (VARBINARY("bulbasaur"), VARBINARY("venusaur"), NULL)`), + ok(`NULL`), + }, + } + + for _, tc := range testCases { + t.Run(tc.expression, func(t *testing.T) { + stmt, err := sqlparser.Parse("select " + tc.expression) + if err != nil { + t.Fatal(err) + } + + astExpr := stmt.(*sqlparser.Select).SelectExprs[0].(*sqlparser.AliasedExpr).Expr + converted, err := ConvertEx(astExpr, dummyCollation(45), false) + if err != nil { + if tc.converted.err == "" { + t.Fatalf("failed to Convert (simplify=false): %v", err) + } + if !strings.Contains(err.Error(), tc.converted.err) { + t.Fatalf("wrong Convert error (simplify=false): %q (expected %q)", err, tc.converted.err) + } + return + } + if FormatExpr(converted) != tc.converted.literal { + t.Errorf("mismatch (simplify=false): got %s, expected %s", FormatExpr(converted), tc.converted.literal) + } + + simplified, err := ConvertEx(astExpr, dummyCollation(45), true) + if err != nil { + if tc.simplified.err == "" { + t.Fatalf("failed to Convert (simplify=true): %v", err) + } + if !strings.Contains(err.Error(), tc.simplified.err) { + t.Fatalf("wrong Convert error (simplify=true): %q (expected %q)", err, tc.simplified.err) + } + return + } + if FormatExpr(simplified) != tc.simplified.literal { + t.Errorf("mismatch (simplify=true): got %s, expected %s", FormatExpr(simplified), tc.simplified.literal) + } + }) + } +} + func TestEvaluate(t *testing.T) { type testCase struct { expression string @@ -117,10 +211,10 @@ func TestEvaluate(t *testing.T) { stmt, err := sqlparser.Parse("select " + test.expression) require.NoError(t, err) astExpr := stmt.(*sqlparser.Select).SelectExprs[0].(*sqlparser.AliasedExpr).Expr - sqltypesExpr, err := Convert(astExpr, nil) + sqltypesExpr, err := Convert(astExpr, dummyCollation(45)) require.Nil(t, err) require.NotNil(t, sqltypesExpr) - env := ExpressionEnv{ + env := &ExpressionEnv{ BindVars: map[string]*querypb.BindVariable{ "exp": sqltypes.Int64BindVariable(66), "string_bind_variable": sqltypes.StringBindVariable("bar"), diff --git a/go/vt/vtgate/evalengine/evalengine.go b/go/vt/vtgate/evalengine/evalengine.go index f14b17fdb29..f05e7cc7780 100644 --- a/go/vt/vtgate/evalengine/evalengine.go +++ b/go/vt/vtgate/evalengine/evalengine.go @@ -18,14 +18,13 @@ package evalengine import ( "math" + "strconv" "time" "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" - "strconv" - querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" @@ -67,12 +66,12 @@ func ToUint64(v sqltypes.Value) (uint64, error) { } switch num.typ { case sqltypes.Int64: - if num.ival < 0 { - return 0, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "negative number cannot be converted to unsigned: %d", num.ival) + if num.numval > math.MaxInt64 { + return 0, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "negative number cannot be converted to unsigned: %d", int64(num.numval)) } - return uint64(num.ival), nil + return num.numval, nil case sqltypes.Uint64: - return num.uval, nil + return num.numval, nil } panic("unreachable") } @@ -85,11 +84,11 @@ func ToInt64(v sqltypes.Value) (int64, error) { } switch num.typ { case sqltypes.Int64: - return num.ival, nil + return int64(num.numval), nil case sqltypes.Uint64: - ival := int64(num.uval) + ival := int64(num.numval) if ival < 0 { - return 0, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unsigned number overflows int64 value: %d", num.uval) + return 0, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unsigned number overflows int64 value: %d", num.numval) } return ival, nil } @@ -104,11 +103,11 @@ func ToFloat64(v sqltypes.Value) (float64, error) { } switch num.typ { case sqltypes.Int64: - return float64(num.ival), nil + return float64(int64(num.numval)), nil case sqltypes.Uint64: - return float64(num.uval), nil + return float64(num.numval), nil case sqltypes.Float64: - return num.fval, nil + return math.Float64frombits(num.numval), nil } if sqltypes.IsText(num.typ) || sqltypes.IsBinary(num.typ) { @@ -155,13 +154,13 @@ func newEvalResult(v sqltypes.Value) (EvalResult, error) { if err != nil { return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) } - return EvalResult{ival: ival, typ: sqltypes.Int64}, nil + return EvalResult{numval: uint64(ival), typ: sqltypes.Int64}, nil case v.IsUnsigned(): uval, err := strconv.ParseUint(string(raw), 10, 64) if err != nil { return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) } - return EvalResult{uval: uval, typ: sqltypes.Uint64}, nil + return EvalResult{numval: uval, typ: sqltypes.Uint64}, nil case v.IsFloat() || v.Type() == sqltypes.Decimal: fval, err := strconv.ParseFloat(string(raw), 64) if err != nil { @@ -171,7 +170,7 @@ func newEvalResult(v sqltypes.Value) (EvalResult, error) { if v.Type() == sqltypes.Decimal { typ = sqltypes.Decimal } - return EvalResult{fval: fval, typ: typ}, nil + return EvalResult{numval: math.Float64bits(fval), typ: typ}, nil default: return EvalResult{typ: v.Type(), bytes: raw}, nil } @@ -186,21 +185,21 @@ func newIntegralNumeric(v sqltypes.Value) (EvalResult, error) { if err != nil { return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) } - return EvalResult{ival: ival, typ: sqltypes.Int64}, nil + return EvalResult{numval: uint64(ival), typ: sqltypes.Int64}, nil case v.IsUnsigned(): uval, err := strconv.ParseUint(str, 10, 64) if err != nil { return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) } - return EvalResult{uval: uval, typ: sqltypes.Uint64}, nil + return EvalResult{numval: uval, typ: sqltypes.Uint64}, nil } // For other types, do best effort. if ival, err := strconv.ParseInt(str, 10, 64); err == nil { - return EvalResult{ival: ival, typ: sqltypes.Int64}, nil + return EvalResult{numval: uint64(ival), typ: sqltypes.Int64}, nil } if uval, err := strconv.ParseUint(str, 10, 64); err == nil { - return EvalResult{uval: uval, typ: sqltypes.Uint64}, nil + return EvalResult{numval: uval, typ: sqltypes.Uint64}, nil } return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "could not parse value: '%s'", str) } @@ -210,33 +209,31 @@ func (v EvalResult) toSQLValue(resultType querypb.Type) sqltypes.Value { case sqltypes.IsSigned(resultType): switch v.typ { case sqltypes.Int64, sqltypes.Int32: - return sqltypes.MakeTrusted(resultType, strconv.AppendInt(nil, int64(v.ival), 10)) + return sqltypes.MakeTrusted(resultType, strconv.AppendInt(nil, int64(v.numval), 10)) case sqltypes.Uint64, sqltypes.Uint32: - return sqltypes.MakeTrusted(resultType, strconv.AppendInt(nil, int64(v.uval), 10)) + return sqltypes.MakeTrusted(resultType, strconv.AppendInt(nil, int64(v.numval), 10)) case sqltypes.Float64, sqltypes.Float32: - return sqltypes.MakeTrusted(resultType, strconv.AppendInt(nil, int64(v.fval), 10)) + return sqltypes.MakeTrusted(resultType, strconv.AppendInt(nil, int64(math.Float64frombits(v.numval)), 10)) } case sqltypes.IsUnsigned(resultType): switch v.typ { - case sqltypes.Uint64, sqltypes.Uint32: - return sqltypes.MakeTrusted(resultType, strconv.AppendUint(nil, uint64(v.uval), 10)) - case sqltypes.Int64, sqltypes.Int32: - return sqltypes.MakeTrusted(resultType, strconv.AppendUint(nil, uint64(v.ival), 10)) + case sqltypes.Uint64, sqltypes.Uint32, sqltypes.Int64, sqltypes.Int32: + return sqltypes.MakeTrusted(resultType, strconv.AppendUint(nil, uint64(v.numval), 10)) case sqltypes.Float64, sqltypes.Float32: - return sqltypes.MakeTrusted(resultType, strconv.AppendUint(nil, uint64(v.fval), 10)) + return sqltypes.MakeTrusted(resultType, strconv.AppendUint(nil, uint64(math.Float64frombits(v.numval)), 10)) } case sqltypes.IsFloat(resultType) || resultType == sqltypes.Decimal: switch v.typ { case sqltypes.Int64, sqltypes.Int32: - return sqltypes.MakeTrusted(resultType, strconv.AppendInt(nil, int64(v.ival), 10)) + return sqltypes.MakeTrusted(resultType, strconv.AppendInt(nil, int64(v.numval), 10)) case sqltypes.Uint64, sqltypes.Uint32: - return sqltypes.MakeTrusted(resultType, strconv.AppendUint(nil, uint64(v.uval), 10)) + return sqltypes.MakeTrusted(resultType, strconv.AppendUint(nil, uint64(v.numval), 10)) case sqltypes.Float64, sqltypes.Float32: format := byte('g') if resultType == sqltypes.Decimal { format = 'f' } - return sqltypes.MakeTrusted(resultType, strconv.AppendFloat(nil, float64(v.fval), format, -1, 64)) + return sqltypes.MakeTrusted(resultType, strconv.AppendFloat(nil, math.Float64frombits(v.numval), format, -1, 64)) } default: return sqltypes.MakeTrusted(resultType, v.bytes) @@ -245,15 +242,7 @@ func (v EvalResult) toSQLValue(resultType querypb.Type) sqltypes.Value { } func numericalHashCode(v EvalResult) HashCode { - switch { - case sqltypes.IsSigned(v.typ): - return HashCode(v.ival) - case sqltypes.IsUnsigned(v.typ): - return HashCode(v.uval) - case sqltypes.IsFloat(v.typ) || v.typ == sqltypes.Decimal: - return HashCode(math.Float64bits(v.fval)) - } - panic("BUG: this is not a numerical value") + return HashCode(v.numval) } func compareNumeric(v1, v2 EvalResult) (int, error) { @@ -263,44 +252,44 @@ func compareNumeric(v1, v2 EvalResult) (int, error) { case sqltypes.Int64: switch v2.typ { case sqltypes.Uint64: - if v1.ival < 0 { + if v1.numval > math.MaxInt64 { return -1, nil } - v1 = EvalResult{typ: sqltypes.Uint64, uval: uint64(v1.ival)} + v1 = EvalResult{typ: sqltypes.Uint64, numval: uint64(v1.numval)} case sqltypes.Float64, sqltypes.Decimal: - v1 = EvalResult{typ: v2.typ, fval: float64(v1.ival)} + v1 = EvalResult{typ: v2.typ, numval: math.Float64bits(float64(int64(v1.numval)))} } case sqltypes.Uint64: switch v2.typ { case sqltypes.Int64: - if v2.ival < 0 { + if v2.numval > math.MaxInt64 { return 1, nil } - v2 = EvalResult{typ: sqltypes.Uint64, uval: uint64(v2.ival)} + v2 = EvalResult{typ: sqltypes.Uint64, numval: uint64(v2.numval)} case sqltypes.Float64, sqltypes.Decimal: - v1 = EvalResult{typ: v2.typ, fval: float64(v1.uval)} + v1 = EvalResult{typ: v2.typ, numval: math.Float64bits(float64(v1.numval))} } case sqltypes.Float64: switch v2.typ { case sqltypes.Int64: - v2 = EvalResult{typ: sqltypes.Float64, fval: float64(v2.ival)} + v2 = EvalResult{typ: sqltypes.Float64, numval: math.Float64bits(float64(int64(v2.numval)))} case sqltypes.Uint64: - if v1.fval < 0 { + if math.Float64frombits(v1.numval) < 0 { return -1, nil } - v2 = EvalResult{typ: sqltypes.Float64, fval: float64(v2.uval)} + v2 = EvalResult{typ: sqltypes.Float64, numval: math.Float64bits(float64(v2.numval))} case sqltypes.Decimal: v2.typ = sqltypes.Float64 } case sqltypes.Decimal: switch v2.typ { case sqltypes.Int64: - v2 = EvalResult{typ: sqltypes.Decimal, fval: float64(v2.ival)} + v2 = EvalResult{typ: sqltypes.Decimal, numval: math.Float64bits(float64(int64(v2.numval)))} case sqltypes.Uint64: - if v1.fval < 0 { + if math.Float64frombits(v1.numval) < 0 { return -1, nil } - v2 = EvalResult{typ: sqltypes.Decimal, fval: float64(v2.uval)} + v2 = EvalResult{typ: sqltypes.Decimal, numval: math.Float64bits(float64(v2.numval))} case sqltypes.Float64: v1.typ = sqltypes.Float64 } @@ -309,24 +298,26 @@ func compareNumeric(v1, v2 EvalResult) (int, error) { // Both values are of the same type. switch v1.typ { case sqltypes.Int64: + v1v, v2v := int64(v1.numval), int64(v2.numval) switch { - case v1.ival == v2.ival: + case v1v == v2v: return 0, nil - case v1.ival < v2.ival: + case v1v < v2v: return -1, nil } case sqltypes.Uint64: switch { - case v1.uval == v2.uval: + case v1.numval == v2.numval: return 0, nil - case v1.uval < v2.uval: + case v1.numval < v2.numval: return -1, nil } case sqltypes.Float64, sqltypes.Decimal: + v1v, v2v := math.Float64frombits(v1.numval), math.Float64frombits(v2.numval) switch { - case v1.fval == v2.fval: + case v1v == v2v: return 0, nil - case v1.fval < v2.fval: + case v1v < v2v: return -1, nil } } @@ -416,14 +407,37 @@ func compareDateAndString(l, r EvalResult) (int, error) { return compareGoTimes(lTime, rTime) } +func mergeCollations(left, right EvalResult) (EvalResult, EvalResult, error) { + if !sqltypes.IsText(left.typ) || !sqltypes.IsText(right.typ) { + return left, right, nil + } + env := collations.Local() + tc, coerceLeft, coerceRight, err := env.MergeCollations(left.collation, right.collation, collations.CoercionOptions{ + ConvertToSuperset: true, + ConvertWithCoercion: true, + }) + if err != nil { + return EvalResult{}, EvalResult{}, err + } + if coerceLeft != nil { + left.bytes, _ = coerceLeft(nil, left.bytes) + } + if coerceRight != nil { + right.bytes, _ = coerceRight(nil, right.bytes) + } + left.collation = tc + right.collation = tc + return left, right, nil +} + func compareTuples(lVal EvalResult, rVal EvalResult) (int, bool, error) { - if len(lVal.tupleResults) != len(rVal.tupleResults) { - return 0, false, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.OperandColumns, "Operand should contain %d column(s)", len(lVal.tupleResults)) + if len(*lVal.tuple) != len(*rVal.tuple) { + return 0, false, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.OperandColumns, "Operand should contain %d column(s)", len(*lVal.tuple)) } hasSeenNull := false - for idx, lResult := range lVal.tupleResults { - rResult := rVal.tupleResults[idx] - res, isNull, err := nullSafeExecuteComparison(lResult, rResult) + for idx, lResult := range *lVal.tuple { + rResult := (*rVal.tuple)[idx] + res, isNull, err := nullSafeCoerceAndCompare(lResult, rResult) if isNull { hasSeenNull = true } @@ -446,19 +460,10 @@ func compareGoTimes(lTime, rTime time.Time) (int, error) { // More on string collations coercibility on MySQL documentation: // - https://dev.mysql.com/doc/refman/8.0/en/charset-collation-coercibility.html -func compareStrings(l, r EvalResult) (int, error) { - // If one of the strings has an unknown collation we fail, though such error should - // already be handled before the execution by the planner. - if l.collation == collations.Unknown || r.collation == collations.Unknown { - return 0, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "cannot compare strings with an unknown collation") - } - - // We cannot compare different collations for now, so we fail - // TODO: support multiple collations comparison - if r.collation != l.collation { - return 0, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "cannot compare strings with different collations") +func compareStrings(l, r EvalResult) int { + if l.collation.Collation != r.collation.Collation { + panic("compareStrings: did not coerce") } - - collation := collations.Default().LookupByID(l.collation) - return collation.Collate(l.bytes, r.bytes, false), nil + collation := collations.Local().LookupByID(l.collation.Collation) + return collation.Collate(l.bytes, r.bytes, false) } diff --git a/go/vt/vtgate/evalengine/expressions.go b/go/vt/vtgate/evalengine/expressions.go index 83dfa590f0f..734e9077ee4 100644 --- a/go/vt/vtgate/evalengine/expressions.go +++ b/go/vt/vtgate/evalengine/expressions.go @@ -18,8 +18,10 @@ package evalengine import ( "fmt" + "math" "strconv" "strings" + "unicode/utf8" "vitess.io/vitess/go/mysql/collations" @@ -32,13 +34,11 @@ import ( type ( EvalResult struct { - typ querypb.Type - ival int64 - uval uint64 - fval float64 - bytes []byte - collation collations.ID - tupleResults []EvalResult + typ querypb.Type + collation collations.TypedCollation + numval uint64 + bytes []byte + tuple *[]EvalResult } // ExpressionEnv contains the environment that the expression @@ -50,41 +50,120 @@ type ( // Expr is the interface that all evaluating expressions must implement Expr interface { - Evaluate(env ExpressionEnv) (EvalResult, error) - Type(env ExpressionEnv) (querypb.Type, error) - String() string + Evaluate(env *ExpressionEnv) (EvalResult, error) + Type(env *ExpressionEnv) (querypb.Type, error) + Collation() collations.TypedCollation + format(buf *strings.Builder, wrap bool) } - // Expressions - Null struct{} Literal struct { - Val EvalResult - Collation collations.ID + Val EvalResult } BindVariable struct { Key string - Collation collations.ID + collation collations.TypedCollation } Column struct { Offset int - Collation collations.ID + collation collations.TypedCollation + } + TupleExpr []Expr + CollateExpr struct { + Expr Expr + TypedCollation collations.TypedCollation } - Tuple []Expr ) -var _ Expr = (*Null)(nil) +func (t TupleExpr) Collation() collations.TypedCollation { + // a Tuple does not have a collation, but an individual collation for every element of the tuple + return collations.TypedCollation{} +} + +func FormatExpr(expr Expr) string { + var bld strings.Builder + expr.format(&bld, false) + return bld.String() +} + +func (l *Literal) format(w *strings.Builder, _ bool) { + w.WriteString(l.Val.Value().String()) +} +func (bv *BindVariable) format(w *strings.Builder, _ bool) { + w.WriteByte(':') + w.WriteString(bv.Key) +} +func (c *Column) format(w *strings.Builder, _ bool) { + fmt.Fprintf(w, "[COLUMN %d]", c.Offset) +} +func (b *BinaryExpr) format(w *strings.Builder, wrap bool) { + if wrap { + w.WriteByte('(') + } + + b.Left.format(w, true) + w.WriteString(" ") + w.WriteString(b.Op.String()) + w.WriteString(" ") + b.Right.format(w, true) + + if wrap { + w.WriteByte(')') + } +} +func (c *ComparisonExpr) format(w *strings.Builder, wrap bool) { + if wrap { + w.WriteByte('(') + } + + c.Left.format(w, true) + w.WriteString(" ") + w.WriteString(c.Op.String()) + w.WriteString(" ") + c.Right.format(w, true) + + if wrap { + w.WriteByte(')') + } +} +func (t TupleExpr) format(w *strings.Builder, wrap bool) { + w.WriteByte('(') + for i, expr := range t { + if i > 0 { + w.WriteString(", ") + } + expr.format(w, wrap) + } + w.WriteByte(')') +} +func (c *CollateExpr) format(w *strings.Builder, wrap bool) { + c.Expr.format(w, wrap) + coll := collations.Local().LookupByID(c.TypedCollation.Collation) + fmt.Fprintf(w, " COLLATE %s", coll.Name()) +} + var _ Expr = (*Literal)(nil) var _ Expr = (*BindVariable)(nil) var _ Expr = (*Column)(nil) var _ Expr = (*BinaryExpr)(nil) var _ Expr = (*ComparisonExpr)(nil) -var _ Expr = (Tuple)(nil) +var _ Expr = (TupleExpr)(nil) +var _ Expr = (*CollateExpr)(nil) // Value allows for retrieval of the value we expose for public consumption func (e EvalResult) Value() sqltypes.Value { return e.toSQLValue(e.typ) } +var collationNull = collations.TypedCollation{ + Collation: collations.CollationBinaryID, + Coercibility: collations.CoerceIgnorable, + Repertoire: collations.RepertoireASCII, +} + +func NewLiteralNull() Expr { + return &Literal{Val: EvalResult{typ: querypb.Type_NULL_TYPE, collation: collationNull}} +} + // NewLiteralIntFromBytes returns a literal expression func NewLiteralIntFromBytes(val []byte) (Expr, error) { ival, err := strconv.ParseInt(string(val), 10, 64) @@ -94,14 +173,20 @@ func NewLiteralIntFromBytes(val []byte) (Expr, error) { return NewLiteralInt(ival), nil } +var collationNumeric = collations.TypedCollation{ + Collation: collations.CollationBinaryID, + Coercibility: collations.CoerceNumeric, + Repertoire: collations.RepertoireASCII, +} + // NewLiteralInt returns a literal expression func NewLiteralInt(i int64) Expr { - return &Literal{Val: EvalResult{typ: sqltypes.Int64, ival: i}} + return &Literal{Val: EvalResult{typ: sqltypes.Int64, numval: uint64(i), collation: collationNumeric}} } // NewLiteralFloat returns a literal expression func NewLiteralFloat(val float64) Expr { - return &Literal{Val: EvalResult{typ: sqltypes.Float64, fval: val}} + return &Literal{Val: EvalResult{typ: sqltypes.Float64, numval: math.Float64bits(val), collation: collationNumeric}} } // NewLiteralFloatFromBytes returns a float literal expression from a slice of bytes @@ -110,68 +195,63 @@ func NewLiteralFloatFromBytes(val []byte) (Expr, error) { if err != nil { return nil, err } - return &Literal{Val: EvalResult{typ: sqltypes.Float64, fval: fval}}, nil + return &Literal{Val: EvalResult{typ: sqltypes.Float64, numval: math.Float64bits(fval)}}, nil } // NewLiteralString returns a literal expression -func NewLiteralString(val []byte, collation collations.ID) Expr { - return &Literal{Val: EvalResult{typ: sqltypes.VarBinary, bytes: val}, Collation: collation} +func NewLiteralString(val []byte, collation collations.TypedCollation) Expr { + collation.Repertoire = collations.RepertoireASCII + for _, b := range val { + if b >= utf8.RuneSelf { + collation.Repertoire = collations.RepertoireUnicode + break + } + } + return &Literal{Val: EvalResult{typ: sqltypes.VarBinary, bytes: val, collation: collation}} } // NewBindVar returns a bind variable -func NewBindVar(key string, collation collations.ID) Expr { +func NewBindVar(key string, collation collations.TypedCollation) Expr { return &BindVariable{ Key: key, - Collation: collation, + collation: collation, } } // NewColumn returns a bind variable -func NewColumn(offset int, collation collations.ID) Expr { +func NewColumn(offset int, collation collations.TypedCollation) Expr { return &Column{ Offset: offset, - Collation: collation, + collation: collation, } } // Evaluate implements the Expr interface -func (n Null) Evaluate(ExpressionEnv) (EvalResult, error) { - return EvalResult{}, nil -} - -// Type implements the Expr interface -func (n Null) Type(ExpressionEnv) (querypb.Type, error) { - return querypb.Type_NULL_TYPE, nil -} - -// String implements the Expr interface -func (n Null) String() string { - return "null" +func (l *Literal) Evaluate(*ExpressionEnv) (EvalResult, error) { + return l.Val, nil } -// Evaluate implements the Expr interface -func (l *Literal) Evaluate(ExpressionEnv) (EvalResult, error) { - eval := l.Val - eval.collation = l.Collation - return eval, nil +func (l *Literal) Collation() collations.TypedCollation { + return l.Val.collation } -// Evaluate implements the Expr interface -func (t Tuple) Evaluate(env ExpressionEnv) (EvalResult, error) { - var res EvalResult - res.typ = querypb.Type_TUPLE +func (t TupleExpr) Evaluate(env *ExpressionEnv) (EvalResult, error) { + var tup []EvalResult for _, expr := range t { evalRes, err := expr.Evaluate(env) if err != nil { return EvalResult{}, err } - res.tupleResults = append(res.tupleResults, evalRes) + tup = append(tup, evalRes) } - return res, nil + return EvalResult{ + typ: querypb.Type_TUPLE, + tuple: &tup, + }, nil } // Evaluate implements the Expr interface -func (b *BindVariable) Evaluate(env ExpressionEnv) (EvalResult, error) { +func (b *BindVariable) Evaluate(env *ExpressionEnv) (EvalResult, error) { val, ok := env.BindVars[b.Key] if !ok { return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Bind variable not found") @@ -180,20 +260,28 @@ func (b *BindVariable) Evaluate(env ExpressionEnv) (EvalResult, error) { if err != nil { return EvalResult{}, err } - eval.collation = b.Collation + eval.collation = b.collation return eval, nil } +func (b *BindVariable) Collation() collations.TypedCollation { + return b.collation +} + // Evaluate implements the Expr interface -func (c *Column) Evaluate(env ExpressionEnv) (EvalResult, error) { +func (c *Column) Evaluate(env *ExpressionEnv) (EvalResult, error) { value := env.Row[c.Offset] numeric, err := newEvalResult(value) - numeric.collation = c.Collation + numeric.collation = c.collation return numeric, err } +func (c *Column) Collation() collations.TypedCollation { + return c.collation +} + // Type implements the Expr interface -func (b *BindVariable) Type(env ExpressionEnv) (querypb.Type, error) { +func (b *BindVariable) Type(env *ExpressionEnv) (querypb.Type, error) { e := env.BindVars v, found := e[b.Key] if !found { @@ -203,44 +291,19 @@ func (b *BindVariable) Type(env ExpressionEnv) (querypb.Type, error) { } // Type implements the Expr interface -func (l *Literal) Type(ExpressionEnv) (querypb.Type, error) { +func (l *Literal) Type(*ExpressionEnv) (querypb.Type, error) { return l.Val.typ, nil } // Type implements the Expr interface -func (t Tuple) Type(env ExpressionEnv) (querypb.Type, error) { +func (t TupleExpr) Type(*ExpressionEnv) (querypb.Type, error) { return querypb.Type_TUPLE, nil } -// Type implements the Expr interface -func (c *Column) Type(ExpressionEnv) (querypb.Type, error) { +func (c *Column) Type(*ExpressionEnv) (querypb.Type, error) { return sqltypes.Float64, nil } -// String implements the Expr interface -func (b *BindVariable) String() string { - return ":" + b.Key -} - -// String implements the Expr interface -func (l *Literal) String() string { - return l.Val.Value().String() -} - -// String implements the Expr interface -func (t Tuple) String() string { - var stringSlice []string - for _, expr := range t { - stringSlice = append(stringSlice, expr.String()) - } - return "(" + strings.Join(stringSlice, ",") + ")" -} - -// String implements the Expr interface -func (c *Column) String() string { - return fmt.Sprintf("column %d from the input", c.Offset) -} - func mergeNumericalTypes(ltype, rtype querypb.Type) querypb.Type { switch ltype { case sqltypes.Int64: @@ -262,25 +325,25 @@ func evaluateByType(val *querypb.BindVariable) (EvalResult, error) { if err != nil { ival = 0 } - return EvalResult{typ: sqltypes.Int64, ival: ival}, nil + return EvalResult{typ: sqltypes.Int64, numval: uint64(ival)}, nil case sqltypes.Int32: ival, err := strconv.ParseInt(string(val.Value), 10, 32) if err != nil { ival = 0 } - return EvalResult{typ: sqltypes.Int32, ival: ival}, nil + return EvalResult{typ: sqltypes.Int32, numval: uint64(ival)}, nil case sqltypes.Uint64: uval, err := strconv.ParseUint(string(val.Value), 10, 64) if err != nil { uval = 0 } - return EvalResult{typ: sqltypes.Uint64, uval: uval}, nil + return EvalResult{typ: sqltypes.Uint64, numval: uval}, nil case sqltypes.Float64: fval, err := strconv.ParseFloat(string(val.Value), 64) if err != nil { fval = 0 } - return EvalResult{typ: sqltypes.Float64, fval: fval}, nil + return EvalResult{typ: sqltypes.Float64, numval: math.Float64bits(fval)}, nil case sqltypes.VarChar, sqltypes.Text, sqltypes.VarBinary: return EvalResult{typ: sqltypes.VarBinary, bytes: val.Value}, nil case sqltypes.Time, sqltypes.Datetime, sqltypes.Timestamp, sqltypes.Date: @@ -293,5 +356,25 @@ func evaluateByType(val *querypb.BindVariable) (EvalResult, error) { // debugString prints the entire EvalResult in a debug format func (e *EvalResult) debugString() string { - return fmt.Sprintf("(%s) %d %d %f %s", querypb.Type_name[int32(e.typ)], e.ival, e.uval, e.fval, string(e.bytes)) + return fmt.Sprintf("(%s) 0x%08x %s", querypb.Type_name[int32(e.typ)], e.numval, e.bytes) +} + +func (c *CollateExpr) Evaluate(env *ExpressionEnv) (EvalResult, error) { + res, err := c.Expr.Evaluate(env) + if err != nil { + return EvalResult{}, err + } + if err := collations.Local().EnsureCollate(res.collation.Collation, c.TypedCollation.Collation); err != nil { + return EvalResult{}, vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, err.Error()) + } + res.collation = c.TypedCollation + return res, nil +} + +func (c *CollateExpr) Type(env *ExpressionEnv) (querypb.Type, error) { + return c.Expr.Type(env) +} + +func (c *CollateExpr) Collation() collations.TypedCollation { + return c.TypedCollation } diff --git a/go/vt/vtgate/evalengine/hash_code_test.go b/go/vt/vtgate/evalengine/hash_code_test.go index 21df5bffba7..e36b31bcfec 100644 --- a/go/vt/vtgate/evalengine/hash_code_test.go +++ b/go/vt/vtgate/evalengine/hash_code_test.go @@ -33,28 +33,26 @@ import ( func TestHashCodesRandom(t *testing.T) { tested := 0 equal := 0 - collation := collations.Default().LookupByName("utf8mb4_general_ci").ID() + collation := collations.Local().LookupByName("utf8mb4_general_ci").ID() endTime := time.Now().Add(1 * time.Second) for time.Now().Before(endTime) { - t.Run(fmt.Sprintf("test %d", tested), func(t *testing.T) { - tested++ - v1, v2 := randomValues() - cmp, err := NullsafeCompare(v1, v2, collation) - require.NoErrorf(t, err, "%s compared with %s", v1.String(), v2.String()) - typ, err := CoerceTo(v1.Type(), v2.Type()) - require.NoError(t, err) - - hash1, err := NullsafeHashcode(v1, collation, typ) - require.NoError(t, err) - hash2, err := NullsafeHashcode(v2, collation, typ) - require.NoError(t, err) - if cmp == 0 { - equal++ - require.Equalf(t, hash1, hash2, "values %s and %s are considered equal but produce different hash codes: %d & %d", v1.String(), v2.String(), hash1, hash2) - } - }) + tested++ + v1, v2 := randomValues() + cmp, err := NullsafeCompare(v1, v2, collation) + require.NoErrorf(t, err, "%s compared with %s", v1.String(), v2.String()) + typ, err := CoerceTo(v1.Type(), v2.Type()) + require.NoError(t, err) + + hash1, err := NullsafeHashcode(v1, collation, typ) + require.NoError(t, err) + hash2, err := NullsafeHashcode(v2, collation, typ) + require.NoError(t, err) + if cmp == 0 { + equal++ + require.Equalf(t, hash1, hash2, "values %s and %s are considered equal but produce different hash codes: %d & %d", v1.String(), v2.String(), hash1, hash2) + } } - fmt.Printf("tested %d values, with %d equalities found\n", tested, equal) + t.Logf("tested %d values, with %d equalities found\n", tested, equal) } func randomValues() (sqltypes.Value, sqltypes.Value) { diff --git a/go/vt/vtgate/executor.go b/go/vt/vtgate/executor.go index 878285dec69..cf19734d8ab 100644 --- a/go/vt/vtgate/executor.go +++ b/go/vt/vtgate/executor.go @@ -162,7 +162,7 @@ func NewExecutor(ctx context.Context, serv srvtopo.Server, cell string, resolver func (e *Executor) Execute(ctx context.Context, method string, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable) (result *sqltypes.Result, err error) { span, ctx := trace.NewSpan(ctx, "executor.Execute") span.Annotate("method", method) - trace.AnnotateSQL(span, sql) + trace.AnnotateSQL(span, sqlparser.Preview(sql)) defer span.Finish() logStats := NewLogStats(ctx, method, sql, bindVars) @@ -218,7 +218,7 @@ func (e *Executor) StreamExecute( ) error { span, ctx := trace.NewSpan(ctx, "executor.StreamExecute") span.Annotate("method", method) - trace.AnnotateSQL(span, sql) + trace.AnnotateSQL(span, sqlparser.Preview(sql)) defer span.Finish() logStats := NewLogStats(ctx, method, sql, bindVars) diff --git a/go/vt/vtgate/planbuilder/builder.go b/go/vt/vtgate/planbuilder/builder.go index 468face549c..90098f0c28e 100644 --- a/go/vt/vtgate/planbuilder/builder.go +++ b/go/vt/vtgate/planbuilder/builder.go @@ -20,6 +20,7 @@ import ( "errors" "sort" + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/vtgate/semantics" @@ -53,6 +54,7 @@ type ContextVSchema interface { GetSemTable() *semantics.SemTable Planner() PlannerVersion SetPlannerVersion(pv PlannerVersion) + ConnCollation() collations.ID // ErrorIfShardedF will return an error if the keyspace is sharded, // and produce a warning if the vtgate if configured to do so diff --git a/go/vt/vtgate/planbuilder/collations_test.go b/go/vt/vtgate/planbuilder/collations_test.go index ea3b0d6fe00..e86cf6aaff5 100644 --- a/go/vt/vtgate/planbuilder/collations_test.go +++ b/go/vt/vtgate/planbuilder/collations_test.go @@ -65,7 +65,7 @@ func (tc *collationTestCase) addCollationsToSchema(vschema *vschemaWrapper) { func TestOrderedAggregateCollations(t *testing.T) { collid := func(collname string) collations.ID { - return collations.Default().LookupByName(collname).ID() + return collations.Local().LookupByName(collname).ID() } testCases := []collationTestCase{ { diff --git a/go/vt/vtgate/planbuilder/expression_converter.go b/go/vt/vtgate/planbuilder/expression_converter.go index 1bcde04bbc2..88528a85161 100644 --- a/go/vt/vtgate/planbuilder/expression_converter.go +++ b/go/vt/vtgate/planbuilder/expression_converter.go @@ -20,6 +20,7 @@ import ( "fmt" "strings" + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/engine" "vitess.io/vitess/go/vt/vtgate/evalengine" @@ -62,7 +63,8 @@ func booleanValues(astExpr sqlparser.Expr) evalengine.Expr { func identifierAsStringValue(astExpr sqlparser.Expr) evalengine.Expr { colName, isColName := astExpr.(*sqlparser.ColName) if isColName { - return evalengine.NewLiteralString([]byte(colName.Name.Lowered()), 0) + // TODO@collations: proper collation for column name + return evalengine.NewLiteralString([]byte(colName.Name.Lowered()), collations.TypedCollation{}) } return nil } diff --git a/go/vt/vtgate/planbuilder/expression_converter_test.go b/go/vt/vtgate/planbuilder/expression_converter_test.go index 86ad0061c2a..b1d4e9b851a 100644 --- a/go/vt/vtgate/planbuilder/expression_converter_test.go +++ b/go/vt/vtgate/planbuilder/expression_converter_test.go @@ -21,6 +21,7 @@ import ( "github.com/stretchr/testify/require" + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/evalengine" ) @@ -40,7 +41,7 @@ func TestConversion(t *testing.T) { expressionsOut: e(evalengine.NewLiteralInt(1)), }, { expressionsIn: "@@foo", - expressionsOut: e(evalengine.NewColumn(0, 0)), + expressionsOut: e(evalengine.NewColumn(0, collations.TypedCollation{})), }} for _, tc := range queries { diff --git a/go/vt/vtgate/planbuilder/plan_test.go b/go/vt/vtgate/planbuilder/plan_test.go index 4a6cd8afafc..b06cfd193d3 100644 --- a/go/vt/vtgate/planbuilder/plan_test.go +++ b/go/vt/vtgate/planbuilder/plan_test.go @@ -28,6 +28,7 @@ import ( "strings" "testing" + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/vt/vtgate/semantics" "github.com/google/go-cmp/cmp" @@ -478,6 +479,10 @@ type vschemaWrapper struct { version PlannerVersion } +func (vw *vschemaWrapper) ConnCollation() collations.ID { + return collations.Unknown +} + func (vw *vschemaWrapper) PlannerWarning(_ string) { } diff --git a/go/vt/vtgate/planbuilder/route_planning.go b/go/vt/vtgate/planbuilder/route_planning.go index a490bc052e4..74182ce5dae 100644 --- a/go/vt/vtgate/planbuilder/route_planning.go +++ b/go/vt/vtgate/planbuilder/route_planning.go @@ -21,6 +21,7 @@ import ( "sort" "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/vtgate/evalengine" "vitess.io/vitess/go/vt/vtgate/planbuilder/abstract" @@ -1031,14 +1032,14 @@ func canSelectDBAMerge(a, b *route) bool { // Inner might end up throwing an error at runtime, but if it doesn't then it is safe to merge. for _, aExpr := range a.eroute.SysTableTableSchema { for _, bExpr := range b.eroute.SysTableTableSchema { - if aExpr.String() == bExpr.String() { + if evalengine.FormatExpr(aExpr) == evalengine.FormatExpr(bExpr) { return true } } } for _, aExpr := range a.eroute.SysTableTableName { for _, bExpr := range b.eroute.SysTableTableName { - if aExpr.String() == bExpr.String() { + if evalengine.FormatExpr(aExpr) == evalengine.FormatExpr(bExpr) { return true } } diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.txt b/go/vt/vtgate/planbuilder/testdata/select_cases.txt index 94a003819ac..11c8c742254 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.txt @@ -1659,7 +1659,7 @@ Gen4 plan same as above "42 + 2" ], "Expressions": [ - "INT64(42) + INT64(2)" + "INT64(44)" ], "Inputs": [ { diff --git a/go/vt/vtgate/planbuilder/testdata/set_cases.txt b/go/vt/vtgate/planbuilder/testdata/set_cases.txt index d691e8162cd..eb8fd6e78c3 100644 --- a/go/vt/vtgate/planbuilder/testdata/set_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/set_cases.txt @@ -88,7 +88,7 @@ Gen4 plan same as above { "Type": "UserDefinedVariable", "Name": "foo", - "Expr": "column 0 from the input" + "Expr": "[COLUMN 0]" } ], "Inputs": [ diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index 15f725e1b1c..f81c005ee84 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -17,6 +17,7 @@ limitations under the License. package vtgate import ( + "context" "flag" "fmt" "net" @@ -32,8 +33,6 @@ import ( "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" - "context" - "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/trace" @@ -155,7 +154,7 @@ func startSpanTestable(ctx context.Context, query, label string, match := r.FindStringSubmatch(comments.Leading) span, ctx := getSpan(ctx, match, newSpan, label, newSpanFromString) - trace.AnnotateSQL(span, query) + trace.AnnotateSQL(span, sqlparser.Preview(query)) return span, ctx, nil } @@ -431,8 +430,8 @@ func initMySQLProtocol() { if err != nil { log.Exitf("mysql.NewListener failed: %v", err) } - if *sqlparser.MySQLServerVersion != "" { - mysqlListener.ServerVersion = *sqlparser.MySQLServerVersion + if *servenv.MySQLServerVersion != "" { + mysqlListener.ServerVersion = *servenv.MySQLServerVersion } if *mysqlSslCert != "" && *mysqlSslKey != "" { tlsVersion, err := vttls.TLSVersionToNumber(*mysqlTLSMinVersion) diff --git a/go/vt/vtgate/semantics/FakeSI.go b/go/vt/vtgate/semantics/FakeSI.go index 7ecf958a580..1c2071f26ce 100644 --- a/go/vt/vtgate/semantics/FakeSI.go +++ b/go/vt/vtgate/semantics/FakeSI.go @@ -17,6 +17,7 @@ limitations under the License. package semantics import ( + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/vt/key" topodatapb "vitess.io/vitess/go/vt/proto/topodata" "vitess.io/vitess/go/vt/sqlparser" @@ -39,3 +40,7 @@ func (s *FakeSI) FindTableOrVindex(tablename sqlparser.TableName) (*vindexes.Tab } return nil, s.VindexTables[sqlparser.String(tablename)], "", 0, nil, nil } + +func (FakeSI) ConnCollation() collations.ID { + return 45 +} diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index ad752071da0..e45f30dc03a 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -17,6 +17,7 @@ limitations under the License. package semantics import ( + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/vt/vtgate/vindexes" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" @@ -68,12 +69,12 @@ func Analyze(statement sqlparser.SelectStatement, currentDb string, si SchemaInf } // Creation of the semantic table - semTable := analyzer.newSemTable(statement) + semTable := analyzer.newSemTable(statement, si.ConnCollation()) return semTable, nil } -func (a analyzer) newSemTable(statement sqlparser.SelectStatement) *SemTable { +func (a analyzer) newSemTable(statement sqlparser.SelectStatement, coll collations.ID) *SemTable { return &SemTable{ Recursive: a.binder.recursive, Direct: a.binder.direct, @@ -86,6 +87,7 @@ func (a analyzer) newSemTable(statement sqlparser.SelectStatement) *SemTable { SubqueryMap: a.binder.subqueryMap, SubqueryRef: a.binder.subqueryRef, ColumnEqualities: map[columnName][]sqlparser.Expr{}, + DefaultCollation: coll, } } diff --git a/go/vt/vtgate/semantics/real_table.go b/go/vt/vtgate/semantics/real_table.go index 2c05b380880..d228c77a7d3 100644 --- a/go/vt/vtgate/semantics/real_table.go +++ b/go/vt/vtgate/semantics/real_table.go @@ -106,7 +106,7 @@ func vindexTableToColumnInfo(tbl *vindexes.Table) []ColumnInfo { for _, col := range tbl.Columns { var collation collations.ID if sqltypes.IsText(col.Type) { - collation, _ = collations.Default().LookupID(col.CollationName) + collation, _ = collations.Local().LookupID(col.CollationName) } cols = append(cols, ColumnInfo{ Name: col.Name.String(), diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index 90995e0815d..07033d8ba7d 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -92,6 +92,10 @@ type ( // if a == b and b == c then a == c ColumnEqualities map[columnName][]sqlparser.Expr + // DefaultCollation is the default collation for this query, which is usually + // inherited from the connection's default collation. + DefaultCollation collations.ID + Warning string } @@ -103,6 +107,7 @@ type ( // SchemaInformation is used tp provide table information from Vschema. SchemaInformation interface { FindTableOrVindex(tablename sqlparser.TableName) (*vindexes.Table, vindexes.Vindex, string, topodatapb.TabletType, key.Destination, error) + ConnCollation() collations.ID } ) @@ -214,7 +219,7 @@ func (st *SemTable) CollationFor(e sqlparser.Expr) collations.ID { if found { return typ.Collation } - return collations.Unknown + return st.DefaultCollation } // dependencies return the table dependencies of the expression. This method finds table dependencies recursively diff --git a/go/vt/vtgate/vcursor_impl.go b/go/vt/vtgate/vcursor_impl.go index a8f98a681cc..6f8656e16f7 100644 --- a/go/vt/vtgate/vcursor_impl.go +++ b/go/vt/vtgate/vcursor_impl.go @@ -83,18 +83,17 @@ type VSchemaOperator interface { // vcursorImpl implements the VCursor functionality used by dependent // packages to call back into VTGate. type vcursorImpl struct { - ctx context.Context - safeSession *SafeSession - keyspace string - tabletType topodatapb.TabletType - destination key.Destination - marginComments sqlparser.MarginComments - executor iExecute - resolver *srvtopo.Resolver - topoServer *topo.Server - logStats *LogStats - collationEnvironment *collations.Environment - collation collations.ID + ctx context.Context + safeSession *SafeSession + keyspace string + tabletType topodatapb.TabletType + destination key.Destination + marginComments sqlparser.MarginComments + executor iExecute + resolver *srvtopo.Resolver + topoServer *topo.Server + logStats *LogStats + collation collations.ID // rollbackOnPartialExec is set to true if any DML was successfully // executed. If there was a subsequent failure, the transaction @@ -145,7 +144,7 @@ func newVCursorImpl( } // we only support collations for the new TabletGateway implementation - collationEnv := collations.NewEnvironment(*sqlparser.MySQLServerVersion) + collationEnv := collations.Local() var connCollation collations.ID if executor != nil { if gw, isTabletGw := executor.resolver.resolver.GetGateway().(*TabletGateway); isTabletGw { @@ -161,27 +160,26 @@ func newVCursorImpl( } return &vcursorImpl{ - ctx: ctx, - safeSession: safeSession, - keyspace: keyspace, - tabletType: tabletType, - destination: destination, - marginComments: marginComments, - executor: executor, - logStats: logStats, - collationEnvironment: collationEnv, - collation: connCollation, - resolver: resolver, - vschema: vschema, - vm: vm, - topoServer: ts, - warnShardedOnly: warnShardedOnly, + ctx: ctx, + safeSession: safeSession, + keyspace: keyspace, + tabletType: tabletType, + destination: destination, + marginComments: marginComments, + executor: executor, + logStats: logStats, + collation: connCollation, + resolver: resolver, + vschema: vschema, + vm: vm, + topoServer: ts, + warnShardedOnly: warnShardedOnly, }, nil } // ConnCollation returns the collation of this session -func (vc *vcursorImpl) ConnCollation() collations.Collation { - return vc.collationEnvironment.LookupByID(vc.collation) +func (vc *vcursorImpl) ConnCollation() collations.ID { + return vc.collation } // Context returns the current Context. diff --git a/go/vt/vtgate/vindexes/cached_size.go b/go/vt/vtgate/vindexes/cached_size.go index 931cf2be124..9c8fe3da6aa 100644 --- a/go/vt/vtgate/vindexes/cached_size.go +++ b/go/vt/vtgate/vindexes/cached_size.go @@ -404,7 +404,9 @@ func (cached *Table) CachedSize(alloc bool) int64 { } } // field Pinned []byte - size += hack.RuntimeAllocSize(int64(cap(cached.Pinned))) + { + size += hack.RuntimeAllocSize(int64(cap(cached.Pinned))) + } return size } func (cached *UnicodeLooseMD5) CachedSize(alloc bool) int64 { diff --git a/go/vt/vttablet/tabletserver/connpool/dbconn.go b/go/vt/vttablet/tabletserver/connpool/dbconn.go index 2a32ec4718b..f785a142603 100644 --- a/go/vt/vttablet/tabletserver/connpool/dbconn.go +++ b/go/vt/vttablet/tabletserver/connpool/dbconn.go @@ -17,6 +17,7 @@ limitations under the License. package connpool import ( + "context" "fmt" "strings" "sync" @@ -26,8 +27,6 @@ import ( "vitess.io/vitess/go/vt/servenv" "vitess.io/vitess/go/vt/vterrors" - "context" - "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/sync2" @@ -195,7 +194,7 @@ func (dbc *DBConn) FetchNext(ctx context.Context, maxrows int, wantfields bool) // Stream executes the query and streams the results. func (dbc *DBConn) Stream(ctx context.Context, query string, callback func(*sqltypes.Result) error, alloc func() *sqltypes.Result, streamBufferSize int, includedFields querypb.ExecuteOptions_IncludedFields) error { span, ctx := trace.NewSpan(ctx, "DBConn.Stream") - trace.AnnotateSQL(span, query) + trace.AnnotateSQL(span, sqlparser.Preview(query)) defer span.Finish() resultSent := false diff --git a/go/vt/vttablet/tabletserver/query_executor.go b/go/vt/vttablet/tabletserver/query_executor.go index 63b7ce346a8..70734664d94 100644 --- a/go/vt/vttablet/tabletserver/query_executor.go +++ b/go/vt/vttablet/tabletserver/query_executor.go @@ -928,7 +928,7 @@ func (qre *QueryExecutor) execStatefulConn(conn *StatefulConnection, sql string, func (qre *QueryExecutor) execStreamSQL(conn *connpool.DBConn, sql string, callback func(*sqltypes.Result) error) error { span, ctx := trace.NewSpan(qre.ctx, "QueryExecutor.execStreamSQL") - trace.AnnotateSQL(span, sql) + trace.AnnotateSQL(span, sqlparser.Preview(sql)) callBackClosingSpan := func(result *sqltypes.Result) error { defer span.Finish() return callback(result) diff --git a/go/vt/vttablet/tabletserver/schema/cached_size.go b/go/vt/vttablet/tabletserver/schema/cached_size.go index 9c1d38cbc84..63779b0ea34 100644 --- a/go/vt/vttablet/tabletserver/schema/cached_size.go +++ b/go/vt/vttablet/tabletserver/schema/cached_size.go @@ -36,16 +36,6 @@ func (cached *MessageInfo) CachedSize(alloc bool) int64 { } return size } -func (cached *SequenceInfo) CachedSize(alloc bool) int64 { - if cached == nil { - return int64(0) - } - size := int64(0) - if alloc { - size += int64(24) - } - return size -} func (cached *Table) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) @@ -68,7 +58,9 @@ func (cached *Table) CachedSize(alloc bool) int64 { size += hack.RuntimeAllocSize(int64(cap(cached.PKColumns)) * int64(8)) } // field SequenceInfo *vitess.io/vitess/go/vt/vttablet/tabletserver/schema.SequenceInfo - size += cached.SequenceInfo.CachedSize(true) + if cached.SequenceInfo != nil { + size += hack.RuntimeAllocSize(int64(24)) + } // field MessageInfo *vitess.io/vitess/go/vt/vttablet/tabletserver/schema.MessageInfo size += cached.MessageInfo.CachedSize(true) return size diff --git a/go/vt/vttablet/tabletserver/tabletserver.go b/go/vt/vttablet/tabletserver/tabletserver.go index e19b141f54f..97457fd0550 100644 --- a/go/vt/vttablet/tabletserver/tabletserver.go +++ b/go/vt/vttablet/tabletserver/tabletserver.go @@ -685,7 +685,7 @@ func (tsv *TabletServer) ReadTransaction(ctx context.Context, target *querypb.Ta // Execute executes the query and returns the result as response. func (tsv *TabletServer) Execute(ctx context.Context, target *querypb.Target, sql string, bindVariables map[string]*querypb.BindVariable, transactionID, reservedID int64, options *querypb.ExecuteOptions) (result *sqltypes.Result, err error) { span, ctx := trace.NewSpan(ctx, "TabletServer.Execute") - trace.AnnotateSQL(span, sql) + trace.AnnotateSQL(span, sqlparser.Preview(sql)) defer span.Finish() if transactionID != 0 && reservedID != 0 && transactionID != reservedID { @@ -1365,7 +1365,7 @@ func (tsv *TabletServer) execRequest( if options != nil { span.Annotate("isolation-level", options.TransactionIsolation) } - trace.AnnotateSQL(span, sql) + trace.AnnotateSQL(span, sqlparser.Preview(sql)) if target != nil { span.Annotate("cell", target.Cell) span.Annotate("shard", target.Shard) diff --git a/go/vt/vttest/vtprocess.go b/go/vt/vttest/vtprocess.go index 30c061280ab..15a31d40bb7 100644 --- a/go/vt/vttest/vtprocess.go +++ b/go/vt/vttest/vtprocess.go @@ -27,8 +27,6 @@ import ( "syscall" "time" - "vitess.io/vitess/go/vt/sqlparser" - "google.golang.org/protobuf/encoding/prototext" "vitess.io/vitess/go/vt/log" @@ -261,8 +259,8 @@ func VtcomboProcess(env Environment, args *Config, mysql MySQLManager) *VtProces if args.VSchemaDDLAuthorizedUsers != "" { vt.ExtraArgs = append(vt.ExtraArgs, []string{"-vschema_ddl_authorized_users", args.VSchemaDDLAuthorizedUsers}...) } - if *sqlparser.MySQLServerVersion != "" { - vt.ExtraArgs = append(vt.ExtraArgs, "-mysql_server_version", *sqlparser.MySQLServerVersion) + if *servenv.MySQLServerVersion != "" { + vt.ExtraArgs = append(vt.ExtraArgs, "-mysql_server_version", *servenv.MySQLServerVersion) } if socket != "" { diff --git a/go/vt/wrangler/vdiff.go b/go/vt/wrangler/vdiff.go index 60027fc656c..c3a23d355d5 100644 --- a/go/vt/wrangler/vdiff.go +++ b/go/vt/wrangler/vdiff.go @@ -1148,8 +1148,8 @@ type contextVCursor struct { ctx context.Context } -func (vc *contextVCursor) ConnCollation() collations.Collation { - panic("implement me") +func (vc *contextVCursor) ConnCollation() collations.ID { + return collations.CollationBinaryID } func (vc *contextVCursor) ExecutePrimitive(primitive engine.Primitive, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {