Skip to content

Commit

Permalink
Implement temporal comparisons
Browse files Browse the repository at this point in the history
This is currently missing and leads to incorrect types to be returned
for `LEAST` & `GREATEST` as comparison functions.

There's a little mismatch here in behavior compared to MySQL which I
argue is actually a bug in MySQL. In MySQL, a temporal type always has
the binary collation:

```
mysql> select NOW(6), collation(NOW(6));
+----------------------------+-------------------+
| NOW(6)                     | collation(NOW(6)) |
+----------------------------+-------------------+
| 2025-02-19 15:33:21.732301 | binary            |
+----------------------------+-------------------+
1 row in set (0.00 sec)

```

On MySQL 8.4, this results in:

```
mysql> select GREATEST(NOW(6), NOW(6)), collation(GREATEST(NOW(6), NOW(6)));
+----------------------------+-------------------------------------+
| GREATEST(NOW(6), NOW(6))   | collation(GREATEST(NOW(6), NOW(6))) |
+----------------------------+-------------------------------------+
| 2025-02-19 15:35:00.921308 | latin1_swedish_ci                   |
+----------------------------+-------------------------------------+
1 row in set (0.00 sec)
```

But on MySQL 8.0, it returns:

```
mysql> select GREATEST(NOW(6), NOW(6)), collation(GREATEST(NOW(6), NOW(6)));
+----------------------------+-------------------------------------+
| GREATEST(NOW(6), NOW(6))   | collation(GREATEST(NOW(6), NOW(6))) |
+----------------------------+-------------------------------------+
| 2025-02-19 15:35:00.921308 | utf8mb4_0900_ai_ci                  |
+----------------------------+-------------------------------------+
1 row in set (0.00 sec)
```

Neither of these collations make sense, because it really should not
change the collation and return `binary` still. That is what Vitess
still does with the changes here (hence the addition to the test
framework to allow skipping the collation check).

I'll also report the issue upstream to make it behave correctly there as
well.

Signed-off-by: Dirkjan Bussink <d.bussink@gmail.com>
  • Loading branch information
dbussink committed Feb 19, 2025
1 parent 2118bc3 commit a693845
Show file tree
Hide file tree
Showing 15 changed files with 911 additions and 385 deletions.
2 changes: 1 addition & 1 deletion go/vt/vtgate/evalengine/cached_size.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

35 changes: 34 additions & 1 deletion go/vt/vtgate/evalengine/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ func (c *compiler) compileToDecimal(ct ctype, offset int) ctype {
c.asm.Convert_id(offset)
case sqltypes.Uint64:
c.asm.Convert_ud(offset)
case sqltypes.Datetime, sqltypes.Time:
case sqltypes.Datetime, sqltypes.Time, sqltypes.Timestamp:
scale = ct.Size
size = ct.Size + decimalSizeBase
fallthrough
Expand All @@ -345,6 +345,28 @@ func (c *compiler) compileToDecimal(ct ctype, offset int) ctype {
return ctype{Type: sqltypes.Decimal, Flag: ct.Flag, Col: collationNumeric, Scale: scale, Size: size}
}

func (c *compiler) compileToTemporal(doct ctype, typ sqltypes.Type, offset, prec int) ctype {
switch doct.Type {
case typ:
if int(doct.Size) == prec {
return doct
}
fallthrough
default:
switch typ {
case sqltypes.Date:
c.asm.Convert_xD(offset, c.sqlmode.AllowZeroDate())
case sqltypes.Datetime:
c.asm.Convert_xDT(offset, prec, c.sqlmode.AllowZeroDate())
case sqltypes.Timestamp:
c.asm.Convert_xDTs(offset, prec, c.sqlmode.AllowZeroDate())
case sqltypes.Time:
c.asm.Convert_xT(offset, prec)
}
}
return ctype{Type: typ, Col: collationBinary, Flag: flagNullable}
}

func (c *compiler) compileToDate(doct ctype, offset int) ctype {
switch doct.Type {
case sqltypes.Date:
Expand All @@ -366,6 +388,17 @@ func (c *compiler) compileToDateTime(doct ctype, offset, prec int) ctype {
return ctype{Type: sqltypes.Datetime, Size: int32(prec), Col: collationBinary, Flag: flagNullable}
}

func (c *compiler) compileToTimestamp(doct ctype, offset, prec int) ctype {
switch doct.Type {
case sqltypes.Timestamp:
c.asm.Convert_tp(offset, prec)
return doct
default:
c.asm.Convert_xDTs(offset, prec, c.sqlmode.AllowZeroDate())
}
return ctype{Type: sqltypes.Timestamp, Size: int32(prec), Col: collationBinary, Flag: flagNullable}
}

func (c *compiler) compileToTime(doct ctype, offset, prec int) ctype {
switch doct.Type {
case sqltypes.Time:
Expand Down
53 changes: 51 additions & 2 deletions go/vt/vtgate/evalengine/compiler_asm.go
Original file line number Diff line number Diff line change
Expand Up @@ -767,11 +767,11 @@ func (asm *assembler) CmpDates() {
}, "CMP DATE(SP-2), DATE(SP-1)")
}

func (asm *assembler) Collate(col collations.ID) {
func (asm *assembler) Collate(col collations.TypedCollation) {
asm.emit(func(env *ExpressionEnv) int {
a := env.vm.stack[env.vm.sp-1].(*evalBytes)
a.tt = int16(sqltypes.VarChar)
a.col.Collation = col
a.col = col
return 1
}, "COLLATE VARCHAR(SP-1), %d", col)
}
Expand Down Expand Up @@ -1170,6 +1170,21 @@ func (asm *assembler) Convert_xDT(offset, prec int, allowZero bool) {
}, "CONV (SP-%d), DATETIME", offset)
}

func (asm *assembler) Convert_xDTs(offset, prec int, allowZero bool) {
asm.emit(func(env *ExpressionEnv) int {
// Need to explicitly check here or we otherwise
// store a nil wrapper in an interface vs. a direct
// nil.
dt := evalToTimestamp(env.vm.stack[env.vm.sp-offset], prec, env.now, allowZero)
if dt == nil {
env.vm.stack[env.vm.sp-offset] = nil
} else {
env.vm.stack[env.vm.sp-offset] = dt
}
return 1
}, "CONV (SP-%d), TIMESTAMP", offset)
}

func (asm *assembler) Convert_xT(offset, prec int) {
asm.emit(func(env *ExpressionEnv) int {
t := evalToTime(env.vm.stack[env.vm.sp-offset], prec)
Expand Down Expand Up @@ -2670,6 +2685,40 @@ func (asm *assembler) Fn_MULTICMP_u(args int, lessThan bool) {
}, "FN MULTICMP UINT64(SP-%d)...UINT64(SP-1)", args)
}

func (asm *assembler) Fn_MULTICMP_temporal(args int, lessThan bool) {
asm.adjustStack(-(args - 1))

asm.emit(func(env *ExpressionEnv) int {
var x *evalTemporal
x, _ = env.vm.stack[env.vm.sp-args].(*evalTemporal)
for sp := env.vm.sp - args + 1; sp < env.vm.sp; sp++ {
if env.vm.stack[sp] == nil {
if lessThan {
x = nil
}
continue
}
y := env.vm.stack[sp].(*evalTemporal)
if lessThan == (y.compare(x) < 0) {
x = y
}
}
env.vm.stack[env.vm.sp-args] = x
env.vm.sp -= args - 1
return 1
}, "FN MULTICMP TEMPORAL(SP-%d)...TEMPORAL(SP-1)", args)
}

func (asm *assembler) Fn_MULTICMP_temporal_fallback(f multiComparisonFunc, args int, cmp, prec int) {
asm.adjustStack(-(args - 1))

asm.emit(func(env *ExpressionEnv) int {
env.vm.stack[env.vm.sp-args], env.vm.err = f(env, env.vm.stack[env.vm.sp-args:env.vm.sp], cmp, prec)
env.vm.sp -= args - 1
return 1
}, "FN MULTICMP_FALLBACK TEMPORAL(SP-%d)...TEMPORAL(SP-1)", args)
}

func (asm *assembler) Fn_REPEAT(base sqltypes.Type, fallback sqltypes.Type) {
asm.adjustStack(-1)

Expand Down
29 changes: 29 additions & 0 deletions go/vt/vtgate/evalengine/compiler_asm_push.go
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,23 @@ func (asm *assembler) PushColumn_datetime(offset int) {
}, "PUSH DATETIME(:%d)", offset)
}

func push_timestamp(env *ExpressionEnv, raw []byte) int {
env.vm.stack[env.vm.sp], env.vm.err = parseTimestamp(raw)
env.vm.sp++
return 1
}

func (asm *assembler) PushColumn_timestamp(offset int) {
asm.adjustStack(1)
asm.emit(func(env *ExpressionEnv) int {
col := env.Row[offset]
if col.IsNull() {
return push_null(env)
}
return push_timestamp(env, col.Raw())
}, "PUSH TIMESTAMP(:%d)", offset)
}

func (asm *assembler) PushBVar_datetime(key string) {
asm.adjustStack(1)
asm.emit(func(env *ExpressionEnv) int {
Expand All @@ -374,6 +391,18 @@ func (asm *assembler) PushBVar_datetime(key string) {
}, "PUSH DATETIME(:%q)", key)
}

func (asm *assembler) PushBVar_timestamp(key string) {
asm.adjustStack(1)
asm.emit(func(env *ExpressionEnv) int {
var bvar *querypb.BindVariable
bvar, env.vm.err = env.lookupBindVar(key)
if env.vm.err != nil {
return 0
}
return push_timestamp(env, bvar.Value)
}, "PUSH TIMESTAMP(:%q)", key)
}

func push_date(env *ExpressionEnv, raw []byte) int {
env.vm.stack[env.vm.sp], env.vm.err = parseDate(raw)
env.vm.sp++
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/evalengine/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func TestCompilerReference(t *testing.T) {
var supported, total int
env := evalengine.EmptyExpressionEnv(venv)

tc.Run(func(query string, row []sqltypes.Value) {
tc.Run(func(query string, row []sqltypes.Value, skipCollationCheck bool) {
env.Row = row
total++
testCompilerCase(t, query, venv, tc.Schema, env)
Expand Down
99 changes: 91 additions & 8 deletions go/vt/vtgate/evalengine/eval_temporal.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func (e *evalTemporal) ToRawBytes() []byte {
switch e.t {
case sqltypes.Date:
return e.dt.Date.Format()
case sqltypes.Datetime:
case sqltypes.Datetime, sqltypes.Timestamp:
return e.dt.Format(e.prec)
case sqltypes.Time:
return e.dt.Time.Format(e.prec)
Expand All @@ -54,7 +54,7 @@ func (e *evalTemporal) toInt64() int64 {
switch e.SQLType() {
case sqltypes.Date:
return e.dt.Date.FormatInt64()
case sqltypes.Datetime:
case sqltypes.Datetime, sqltypes.Timestamp:
return e.dt.FormatInt64()
case sqltypes.Time:
return e.dt.Time.FormatInt64()
Expand All @@ -67,7 +67,7 @@ func (e *evalTemporal) toFloat() float64 {
switch e.SQLType() {
case sqltypes.Date:
return float64(e.dt.Date.FormatInt64())
case sqltypes.Datetime:
case sqltypes.Datetime, sqltypes.Timestamp:
return e.dt.FormatFloat64()
case sqltypes.Time:
return e.dt.Time.FormatFloat64()
Expand All @@ -80,7 +80,7 @@ func (e *evalTemporal) toDecimal() decimal.Decimal {
switch e.SQLType() {
case sqltypes.Date:
return decimal.NewFromInt(e.dt.Date.FormatInt64())
case sqltypes.Datetime:
case sqltypes.Datetime, sqltypes.Timestamp:
return e.dt.FormatDecimal()
case sqltypes.Time:
return e.dt.Time.FormatDecimal()
Expand All @@ -93,7 +93,7 @@ func (e *evalTemporal) toJSON() *evalJSON {
switch e.SQLType() {
case sqltypes.Date:
return json.NewDate(hack.String(e.dt.Date.Format()))
case sqltypes.Datetime:
case sqltypes.Datetime, sqltypes.Timestamp:
return json.NewDateTime(hack.String(e.dt.Format(datetime.DefaultPrecision)))
case sqltypes.Time:
return json.NewTime(hack.String(e.dt.Time.Format(datetime.DefaultPrecision)))
Expand All @@ -104,7 +104,7 @@ func (e *evalTemporal) toJSON() *evalJSON {

func (e *evalTemporal) toDateTime(l int, now time.Time) *evalTemporal {
switch e.SQLType() {
case sqltypes.Datetime, sqltypes.Date:
case sqltypes.Datetime, sqltypes.Date, sqltypes.Timestamp:
return &evalTemporal{t: sqltypes.Datetime, dt: e.dt.Round(l), prec: uint8(l)}
case sqltypes.Time:
return &evalTemporal{t: sqltypes.Datetime, dt: e.dt.Time.Round(l).ToDateTime(now), prec: uint8(l)}
Expand All @@ -113,9 +113,23 @@ func (e *evalTemporal) toDateTime(l int, now time.Time) *evalTemporal {
}
}

func (e *evalTemporal) toTimestamp(l int, now time.Time) *evalTemporal {
switch e.SQLType() {
case sqltypes.Datetime, sqltypes.Date, sqltypes.Timestamp:
return &evalTemporal{t: sqltypes.Timestamp, dt: e.dt.Round(l), prec: uint8(l)}
case sqltypes.Time:
return &evalTemporal{t: sqltypes.Timestamp, dt: e.dt.Time.Round(l).ToDateTime(now), prec: uint8(l)}
default:
panic("unreachable")
}
}

func (e *evalTemporal) toTime(l int) *evalTemporal {
if l == -1 {
l = int(e.prec)
}
switch e.SQLType() {
case sqltypes.Datetime:
case sqltypes.Datetime, sqltypes.Timestamp:
dt := datetime.DateTime{Time: e.dt.Time.Round(l)}
return &evalTemporal{t: sqltypes.Time, dt: dt, prec: uint8(l)}
case sqltypes.Date:
Expand All @@ -130,7 +144,7 @@ func (e *evalTemporal) toTime(l int) *evalTemporal {

func (e *evalTemporal) toDate(now time.Time) *evalTemporal {
switch e.SQLType() {
case sqltypes.Datetime:
case sqltypes.Datetime, sqltypes.Timestamp:
dt := datetime.DateTime{Date: e.dt.Date}
return &evalTemporal{t: sqltypes.Date, dt: dt}
case sqltypes.Date:
Expand All @@ -148,6 +162,13 @@ func (e *evalTemporal) isZero() bool {
return e.dt.IsZero()
}

func (e *evalTemporal) compare(other *evalTemporal) int {
if other == nil {
return 1
}
return e.dt.Compare(other.dt)
}

func (e *evalTemporal) addInterval(interval *datetime.Interval, coll collations.ID, now time.Time) eval {
var tmp *evalTemporal
var ok bool
Expand Down Expand Up @@ -179,6 +200,13 @@ func newEvalDateTime(dt datetime.DateTime, l int, allowZero bool) *evalTemporal
return &evalTemporal{t: sqltypes.Datetime, dt: dt.Round(l), prec: uint8(l)}
}

func newEvalTimestamp(dt datetime.DateTime, l int, allowZero bool) *evalTemporal {
if !allowZero && dt.IsZero() {
return nil
}
return &evalTemporal{t: sqltypes.Timestamp, dt: dt.Round(l), prec: uint8(l)}
}

func newEvalDate(d datetime.Date, allowZero bool) *evalTemporal {
if !allowZero && d.IsZero() {
return nil
Expand Down Expand Up @@ -210,6 +238,14 @@ func parseDateTime(s []byte) (*evalTemporal, error) {
return newEvalDateTime(t, l, true), nil
}

func parseTimestamp(s []byte) (*evalTemporal, error) {
t, l, ok := datetime.ParseDateTime(hack.String(s), -1)
if !ok {
return nil, errIncorrectTemporal("TIMESTAMP", s)
}
return newEvalTimestamp(t, l, true), nil
}

func parseTime(s []byte) (*evalTemporal, error) {
t, l, state := datetime.ParseTime(hack.String(s), -1)
if state != datetime.TimeOK {
Expand Down Expand Up @@ -387,6 +423,53 @@ func evalToDateTime(e eval, l int, now time.Time, allowZero bool) *evalTemporal
return nil
}

func evalToTimestamp(e eval, l int, now time.Time, allowZero bool) *evalTemporal {
switch e := e.(type) {
case *evalTemporal:
return e.toTimestamp(precision(l, int(e.prec)), now)
case *evalBytes:
if t, l, _ := datetime.ParseDateTime(e.string(), l); !t.IsZero() {
return newEvalTimestamp(t, l, allowZero)
}
if d, _ := datetime.ParseDate(e.string()); !d.IsZero() {
return newEvalTimestamp(datetime.DateTime{Date: d}, precision(l, 0), allowZero)
}
case *evalInt64:
if t, ok := datetime.ParseDateTimeInt64(e.i); ok {
return newEvalTimestamp(t, precision(l, 0), allowZero)
}
if d, ok := datetime.ParseDateInt64(e.i); ok {
return newEvalTimestamp(datetime.DateTime{Date: d}, precision(l, 0), allowZero)
}
case *evalUint64:
if t, ok := datetime.ParseDateTimeInt64(int64(e.u)); ok {
return newEvalTimestamp(t, precision(l, 0), allowZero)
}
if d, ok := datetime.ParseDateInt64(int64(e.u)); ok {
return newEvalTimestamp(datetime.DateTime{Date: d}, precision(l, 0), allowZero)
}
case *evalFloat:
if t, l, ok := datetime.ParseDateTimeFloat(e.f, l); ok {
return newEvalTimestamp(t, l, allowZero)
}
if d, ok := datetime.ParseDateFloat(e.f); ok {
return newEvalTimestamp(datetime.DateTime{Date: d}, precision(l, 0), allowZero)
}
case *evalDecimal:
if t, l, ok := datetime.ParseDateTimeDecimal(e.dec, e.length, l); ok {
return newEvalTimestamp(t, l, allowZero)
}
if d, ok := datetime.ParseDateDecimal(e.dec); ok {
return newEvalTimestamp(datetime.DateTime{Date: d}, precision(l, 0), allowZero)
}
case *evalJSON:
if dt, ok := e.DateTime(); ok {
return newEvalTimestamp(dt, precision(l, datetime.DefaultPrecision), allowZero)
}
}
return nil
}

func evalToDate(e eval, now time.Time, allowZero bool) *evalTemporal {
switch e := e.(type) {
case *evalTemporal:
Expand Down
4 changes: 3 additions & 1 deletion go/vt/vtgate/evalengine/expr_bvar.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,10 @@ func (bvar *BindVariable) compile(c *compiler) (ctype, error) {
c.asm.PushNull()
case tt == sqltypes.TypeJSON:
c.asm.PushBVar_json(bvar.Key)
case tt == sqltypes.Datetime || tt == sqltypes.Timestamp:
case tt == sqltypes.Datetime:
c.asm.PushBVar_datetime(bvar.Key)
case tt == sqltypes.Timestamp:
c.asm.PushBVar_timestamp(bvar.Key)
case tt == sqltypes.Date:
c.asm.PushBVar_date(bvar.Key)
case tt == sqltypes.Time:
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/evalengine/expr_collate.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func (expr *CollateExpr) compile(c *compiler) (ctype, error) {
}
fallthrough
case sqltypes.VarBinary:
c.asm.Collate(expr.TypedCollation.Collation)
c.asm.Collate(expr.TypedCollation)
default:
c.asm.Convert_xc(1, sqltypes.VarChar, expr.TypedCollation.Collation, nil)
}
Expand Down
Loading

0 comments on commit a693845

Please sign in to comment.