Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: use a relax json comparison cast rule #37404

Merged
merged 2 commits into from
Aug 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions expression/builtin_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -1340,8 +1340,7 @@ func GetAccurateCmpType(lhs, rhs Expression) types.EvalType {
lhsFieldType, rhsFieldType := lhs.GetType(), rhs.GetType()
lhsEvalType, rhsEvalType := lhsFieldType.EvalType(), rhsFieldType.EvalType()
cmpType := getBaseCmpType(lhsEvalType, rhsEvalType, lhsFieldType, rhsFieldType)
if (lhsEvalType.IsStringKind() && rhsFieldType.GetType() == mysql.TypeJSON) ||
(lhsFieldType.GetType() == mysql.TypeJSON && rhsEvalType.IsStringKind()) {
if (lhsEvalType.IsStringKind() && lhsFieldType.GetType() == mysql.TypeJSON) || (rhsEvalType.IsStringKind() && rhsFieldType.GetType() == mysql.TypeJSON) {
cmpType = types.ETJson
} else if cmpType == types.ETString && (types.IsTypeTime(lhsFieldType.GetType()) || types.IsTypeTime(rhsFieldType.GetType())) {
// date[time] <cmp> date[time]
Expand Down
8 changes: 8 additions & 0 deletions expression/builtin_compare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,14 @@ func TestCompare(t *testing.T) {
args = bf.getArgs()
require.Equal(t, mysql.TypeDatetime, args[0].GetType().GetType())
require.Equal(t, mysql.TypeDatetime, args[1].GetType().GetType())

// test <json column> <cmp> <const int expression>
jsonCol, intCon := &Column{RetType: types.NewFieldType(mysql.TypeJSON)}, &Constant{RetType: types.NewFieldType(mysql.TypeLong)}
bf, err = funcs[ast.LT].getFunction(ctx, []Expression{jsonCol, intCon})
require.NoError(t, err)
args = bf.getArgs()
require.Equal(t, mysql.TypeJSON, args[0].GetType().GetType())
require.Equal(t, mysql.TypeJSON, args[1].GetType().GetType())
}

func TestCoalesce(t *testing.T) {
Expand Down
16 changes: 7 additions & 9 deletions expression/distsql_builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -1123,9 +1123,9 @@ func PBToExpr(expr *tipb.Expr, tps []*types.FieldType, sc *stmtctx.StatementCont
case tipb.ExprType_Null:
return &Constant{Value: types.Datum{}, RetType: types.NewFieldType(mysql.TypeNull)}, nil
case tipb.ExprType_Int64:
return convertInt(expr.Val)
return convertInt(expr.Val, expr.FieldType)
case tipb.ExprType_Uint64:
return convertUint(expr.Val)
return convertUint(expr.Val, expr.FieldType)
case tipb.ExprType_String:
return convertString(expr.Val, expr.FieldType)
case tipb.ExprType_Bytes:
Expand Down Expand Up @@ -1215,32 +1215,30 @@ func decodeValueList(data []byte) ([]Expression, error) {
return result, nil
}

func convertInt(val []byte) (*Constant, error) {
func convertInt(val []byte, tp *tipb.FieldType) (*Constant, error) {
var d types.Datum
_, i, err := codec.DecodeInt(val)
if err != nil {
return nil, errors.Errorf("invalid int % x", val)
}
d.SetInt64(i)
return &Constant{Value: d, RetType: types.NewFieldType(mysql.TypeLonglong)}, nil
return &Constant{Value: d, RetType: PbTypeToFieldType(tp)}, nil
}

func convertUint(val []byte) (*Constant, error) {
func convertUint(val []byte, tp *tipb.FieldType) (*Constant, error) {
var d types.Datum
_, u, err := codec.DecodeUint(val)
if err != nil {
return nil, errors.Errorf("invalid uint % x", val)
}
d.SetUint64(u)
ftp := types.NewFieldTypeBuilder()
ftp.SetType(mysql.TypeLonglong).SetFlag(mysql.UnsignedFlag)
return &Constant{Value: d, RetType: ftp.BuildP()}, nil
return &Constant{Value: d, RetType: PbTypeToFieldType(tp)}, nil
}

func convertString(val []byte, tp *tipb.FieldType) (*Constant, error) {
var d types.Datum
d.SetBytesAsString(val, collate.ProtoToCollation(tp.Collate), uint32(tp.Flen))
return &Constant{Value: d, RetType: types.NewFieldType(mysql.TypeVarString)}, nil
return &Constant{Value: d, RetType: PbTypeToFieldType(tp)}, nil
}

func convertFloat(val []byte, f32 bool) (*Constant, error) {
Expand Down
2 changes: 2 additions & 0 deletions expression/distsql_builtin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -874,9 +874,11 @@ func datumExpr(t *testing.T, d types.Datum) *tipb.Expr {
switch d.Kind() {
case types.KindInt64:
expr.Tp = tipb.ExprType_Int64
expr.FieldType = toPBFieldType(types.NewFieldType(mysql.TypeLonglong))
expr.Val = codec.EncodeInt(nil, d.GetInt64())
case types.KindUint64:
expr.Tp = tipb.ExprType_Uint64
expr.FieldType = toPBFieldType(types.NewFieldTypeBuilder().SetType(mysql.TypeLonglong).SetFlag(mysql.UnsignedFlag).BuildP())
expr.Val = codec.EncodeUint(nil, d.GetUint64())
case types.KindString:
expr.Tp = tipb.ExprType_String
Expand Down
11 changes: 11 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7460,3 +7460,14 @@ func TestCastJSONOpaqueValueToNumeric(t *testing.T) {
tk.MustQuery("select cast(json_extract(json_objectagg('a', b'010101'), '$.a') as double);").Check(testkit.Rows("0"))
tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1292 Truncated incorrect FLOAT value: '\"base64:type253:FQ==\"'"))
}

func TestCompareJSONWithOtherType(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
tk.MustExec("create table t(a JSON)")
tk.MustExec("insert into t values ('{}'), ('true'), ('5')")
tk.MustQuery("select * from t where a = TRUE;").Check(testkit.Rows("true"))
tk.MustQuery("select * from t where a < 6;").Check(testkit.Rows("5"))
tk.MustQuery("select * from t where a > 5;").Check(testkit.Rows("{}", "true"))
}