Skip to content

Commit

Permalink
types: always handle overflow error outside the types package (#47997)
Browse files Browse the repository at this point in the history
close #47517
  • Loading branch information
lcwangchao authored Oct 30, 2023
1 parent f135464 commit 5503eb5
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 83 deletions.
29 changes: 20 additions & 9 deletions pkg/expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,8 @@ func convertJSON2Tp(evalType types.EvalType) func(*stmtctx.StatementContext, typ
if item.TypeCode != types.JSONTypeCodeInt64 && item.TypeCode != types.JSONTypeCodeUint64 {
return nil, ErrInvalidJSONForFuncIndex
}
jsonToInt, err := types.ConvertJSONToInt(sc, item, mysql.HasUnsignedFlag(tp.GetFlag()), tp.GetType())
jsonToInt, err := types.ConvertJSONToInt(sc.TypeCtx(), item, mysql.HasUnsignedFlag(tp.GetFlag()), tp.GetType())
err = sc.HandleOverflow(err, err)
if mysql.HasUnsignedFlag(tp.GetFlag()) {
return uint64(jsonToInt), err
}
Expand Down Expand Up @@ -702,7 +703,9 @@ func (b *builtinCastIntAsDecimalSig) evalDecimal(row chunk.Row) (res *types.MyDe
} else {
res = types.NewDecFromUint(uint64(val))
}
res, err = types.ProduceDecWithSpecifiedTp(res, b.tp, b.ctx.GetSessionVars().StmtCtx)
sc := b.ctx.GetSessionVars().StmtCtx
res, err = types.ProduceDecWithSpecifiedTp(sc.TypeCtx(), res, b.tp)
err = sc.HandleOverflow(err, err)
return res, isNull, err
}

Expand Down Expand Up @@ -1018,7 +1021,9 @@ func (b *builtinCastRealAsDecimalSig) evalDecimal(row chunk.Row) (res *types.MyD
return res, false, err
}
}
res, err = types.ProduceDecWithSpecifiedTp(res, b.tp, b.ctx.GetSessionVars().StmtCtx)
sc := b.ctx.GetSessionVars().StmtCtx
res, err = types.ProduceDecWithSpecifiedTp(sc.TypeCtx(), res, b.tp)
err = sc.HandleOverflow(err, err)
return res, false, err
}

Expand Down Expand Up @@ -1130,7 +1135,8 @@ func (b *builtinCastDecimalAsDecimalSig) evalDecimal(row chunk.Row) (res *types.
*res = *evalDecimal
}
sc := b.ctx.GetSessionVars().StmtCtx
res, err = types.ProduceDecWithSpecifiedTp(res, b.tp, sc)
res, err = types.ProduceDecWithSpecifiedTp(sc.TypeCtx(), res, b.tp)
err = sc.HandleOverflow(err, err)
return res, false, err
}

Expand Down Expand Up @@ -1454,7 +1460,8 @@ func (b *builtinCastStringAsDecimalSig) evalDecimal(row chunk.Row) (res *types.M
return res, false, err
}
}
res, err = types.ProduceDecWithSpecifiedTp(res, b.tp, sc)
res, err = types.ProduceDecWithSpecifiedTp(sc.TypeCtx(), res, b.tp)
err = sc.HandleOverflow(err, err)
return res, false, err
}

Expand Down Expand Up @@ -1599,7 +1606,8 @@ func (b *builtinCastTimeAsDecimalSig) evalDecimal(row chunk.Row) (res *types.MyD
return res, isNull, err
}
sc := b.ctx.GetSessionVars().StmtCtx
res, err = types.ProduceDecWithSpecifiedTp(val.ToNumber(), b.tp, sc)
res, err = types.ProduceDecWithSpecifiedTp(sc.TypeCtx(), val.ToNumber(), b.tp)
err = sc.HandleOverflow(err, err)
return res, false, err
}

Expand Down Expand Up @@ -1732,7 +1740,8 @@ func (b *builtinCastDurationAsDecimalSig) evalDecimal(row chunk.Row) (res *types
return res, false, err
}
sc := b.ctx.GetSessionVars().StmtCtx
res, err = types.ProduceDecWithSpecifiedTp(val.ToNumber(), b.tp, sc)
res, err = types.ProduceDecWithSpecifiedTp(sc.TypeCtx(), val.ToNumber(), b.tp)
err = sc.HandleOverflow(err, err)
return res, false, err
}

Expand Down Expand Up @@ -1834,7 +1843,8 @@ func (b *builtinCastJSONAsIntSig) evalInt(row chunk.Row) (res int64, isNull bool
return res, isNull, err
}
sc := b.ctx.GetSessionVars().StmtCtx
res, err = types.ConvertJSONToInt64(sc, val, mysql.HasUnsignedFlag(b.tp.GetFlag()))
res, err = types.ConvertJSONToInt64(sc.TypeCtx(), val, mysql.HasUnsignedFlag(b.tp.GetFlag()))
err = sc.HandleOverflow(err, err)
return
}

Expand Down Expand Up @@ -1878,7 +1888,8 @@ func (b *builtinCastJSONAsDecimalSig) evalDecimal(row chunk.Row) (res *types.MyD
if err != nil {
return res, false, err
}
res, err = types.ProduceDecWithSpecifiedTp(res, b.tp, sc)
res, err = types.ProduceDecWithSpecifiedTp(sc.TypeCtx(), res, b.tp)
err = sc.HandleOverflow(err, err)
return res, false, err
}

Expand Down
34 changes: 18 additions & 16 deletions pkg/expression/builtin_cast_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,8 @@ func (b *builtinCastTimeAsDecimalSig) vecEvalDecimal(input *chunk.Chunk, result
}
*dec = types.MyDecimal{}
times[i].FillNumber(dec)
dec, err = types.ProduceDecWithSpecifiedTp(dec, b.tp, sc)
if err != nil {
dec, err = types.ProduceDecWithSpecifiedTp(sc.TypeCtx(), dec, b.tp)
if err = sc.HandleOverflow(err, err); err != nil {
return err
}
decs[i] = *dec
Expand Down Expand Up @@ -623,8 +623,8 @@ func (b *builtinCastDecimalAsDecimalSig) vecEvalDecimal(input *chunk.Chunk, resu
if !(conditionUnionAndUnsigned && decs[i].IsNegative()) {
*dec = decs[i]
}
dec, err := types.ProduceDecWithSpecifiedTp(dec, b.tp, sc)
if err != nil {
dec, err := types.ProduceDecWithSpecifiedTp(sc.TypeCtx(), dec, b.tp)
if err = sc.HandleOverflow(err, err); err != nil {
return err
}
decs[i] = *dec
Expand Down Expand Up @@ -899,6 +899,7 @@ func (b *builtinCastRealAsDecimalSig) vecEvalDecimal(input *chunk.Chunk, result
result.MergeNulls(buf)
bufreal := buf.Float64s()
resdecimal := result.Decimals()
sc := b.ctx.GetSessionVars().StmtCtx
for i := 0; i < n; i++ {
if result.IsNull(i) {
continue
Expand All @@ -917,8 +918,8 @@ func (b *builtinCastRealAsDecimalSig) vecEvalDecimal(input *chunk.Chunk, result
}
}
}
dec, err := types.ProduceDecWithSpecifiedTp(&resdecimal[i], b.tp, b.ctx.GetSessionVars().StmtCtx)
if err != nil {
dec, err := types.ProduceDecWithSpecifiedTp(sc.TypeCtx(), &resdecimal[i], b.tp)
if err = sc.HandleOverflow(err, err); err != nil {
return err
}
resdecimal[i] = *dec
Expand Down Expand Up @@ -1058,8 +1059,8 @@ func (b *builtinCastDurationAsDecimalSig) vecEvalDecimal(input *chunk.Chunk, res
}
duration.Duration = ds[i]
duration.Fsp = fsp
res, err := types.ProduceDecWithSpecifiedTp(duration.ToNumber(), b.tp, sc)
if err != nil {
res, err := types.ProduceDecWithSpecifiedTp(sc.TypeCtx(), duration.ToNumber(), b.tp)
if err = sc.HandleOverflow(err, err); err != nil {
return err
}
d64s[i] = *res
Expand Down Expand Up @@ -1104,8 +1105,8 @@ func (b *builtinCastIntAsDecimalSig) vecEvalDecimal(input *chunk.Chunk, result *
dec.FromUint(uint64(nums[i]))
}

dec, err = types.ProduceDecWithSpecifiedTp(dec, b.tp, sc)
if err != nil {
dec, err = types.ProduceDecWithSpecifiedTp(sc.TypeCtx(), dec, b.tp)
if err = sc.HandleOverflow(err, err); err != nil {
return err
}
decs[i] = *dec
Expand Down Expand Up @@ -1254,12 +1255,13 @@ func (b *builtinCastJSONAsIntSig) vecEvalInt(input *chunk.Chunk, result *chunk.C
result.MergeNulls(buf)
i64s := result.Int64s()
sc := b.ctx.GetSessionVars().StmtCtx
tc := sc.TypeCtx()
for i := 0; i < n; i++ {
if result.IsNull(i) {
continue
}
i64s[i], err = types.ConvertJSONToInt64(sc, buf.GetJSON(i), mysql.HasUnsignedFlag(b.tp.GetFlag()))
if err != nil {
i64s[i], err = types.ConvertJSONToInt64(tc, buf.GetJSON(i), mysql.HasUnsignedFlag(b.tp.GetFlag()))
if err = sc.HandleOverflow(err, err); err != nil {
return err
}
}
Expand Down Expand Up @@ -1643,8 +1645,8 @@ func (b *builtinCastJSONAsDecimalSig) vecEvalDecimal(input *chunk.Chunk, result
if err != nil {
return err
}
tempres, err = types.ProduceDecWithSpecifiedTp(tempres, b.tp, sc)
if err != nil {
tempres, err = types.ProduceDecWithSpecifiedTp(sc.TypeCtx(), tempres, b.tp)
if err = sc.HandleOverflow(err, err); err != nil {
return err
}
res[i] = *tempres
Expand Down Expand Up @@ -1733,8 +1735,8 @@ func (b *builtinCastStringAsDecimalSig) vecEvalDecimal(input *chunk.Chunk, resul
if err := stmtCtx.HandleTruncate(dec.FromString([]byte(val))); err != nil {
return err
}
dec, err := types.ProduceDecWithSpecifiedTp(dec, b.tp, stmtCtx)
if err != nil {
dec, err := types.ProduceDecWithSpecifiedTp(stmtCtx.TypeCtx(), dec, b.tp)
if err = stmtCtx.HandleOverflow(err, err); err != nil {
return err
}
res[i] = *dec
Expand Down
4 changes: 2 additions & 2 deletions pkg/sessionctx/stmtctx/stmtctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,8 @@ func TestSetStmtCtxTypeFlags(t *testing.T) {
require.Equal(t, typectx.FlagAllowNegativeToUnsigned|typectx.FlagSkipASCIICheck, sc.TypeFlags())
require.Equal(t, sc.TypeFlags(), sc.TypeFlags())

sc.SetTypeFlags(typectx.FlagSkipASCIICheck | typectx.FlagSkipUTF8Check | typectx.FlagInvalidDateAsWarning)
require.Equal(t, typectx.FlagSkipASCIICheck|typectx.FlagSkipUTF8Check|typectx.FlagInvalidDateAsWarning, sc.TypeFlags())
sc.SetTypeFlags(typectx.FlagSkipASCIICheck | typectx.FlagSkipUTF8Check | typectx.FlagTruncateAsWarning)
require.Equal(t, typectx.FlagSkipASCIICheck|typectx.FlagSkipUTF8Check|typectx.FlagTruncateAsWarning, sc.TypeFlags())
require.Equal(t, sc.TypeFlags(), sc.TypeFlags())
}

Expand Down
18 changes: 3 additions & 15 deletions pkg/types/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,10 @@ const (
FlagTruncateAsWarning
// FlagAllowNegativeToUnsigned indicates to allow the casting from negative to unsigned int.
// When this flag is not set by default, casting a negative value to unsigned results an overflow error.
// The overflow will also be controlled by `FlagIgnoreOverflowError` and `FlagOverflowAsWarning`. When any of them is set,
// a zero value is returned instead.
// Whe this flag is set, casting a negative value to unsigned will be allowed. And the negative value will be cast to
// a positive value by adding the max value of the unsigned type.
// Otherwise, a negative value will be cast to the corresponding unsigned value without any error.
// For example, when casting -1 to an unsigned bigint with `FlagAllowNegativeToUnsigned` set,
// we will get `18446744073709551615` which is the biggest unsigned value.
FlagAllowNegativeToUnsigned
// FlagIgnoreOverflowError indicates to ignore the overflow error.
// If this flag is set, `FlagOverflowAsWarning` will be ignored.
FlagIgnoreOverflowError
// FlagOverflowAsWarning indicates to append the overflow error to warnings instead of returning it to user.
FlagOverflowAsWarning
// FlagIgnoreZeroDateErr indicates to ignore the zero-date error.
// See: https://dev.mysql.com/doc/refman/8.0/en/sql-mode.html#sqlmode_no_zero_date for details about the "zero-date" error.
// If this flag is set, `FlagZeroDateAsWarning` will be ignored.
Expand All @@ -55,18 +49,12 @@ const (
// This flag is the reverse of `NoZeroDate` in #30507. It's set to `true` for most context, and is only set to
// `false` for `alter` (and `create`) statements.
FlagIgnoreZeroDateErr
// FlagZeroDateAsWarning indicates to append the zero-date error to warnings instead of returning it to user.
FlagZeroDateAsWarning
// FlagIgnoreZeroInDateErr indicates to ignore the zero-in-date error.
// See: https://dev.mysql.com/doc/refman/8.0/en/sql-mode.html#sqlmode_no_zero_in_date for details about the "zero-in-date" error.
FlagIgnoreZeroInDateErr
// FlagZeroInDateAsWarning indicates to append the zero-in-date error to warnings instead of returning it to user.
FlagZeroInDateAsWarning
// FlagIgnoreInvalidDateErr indicates to ignore the invalid-date error.
// See: https://dev.mysql.com/doc/refman/8.0/en/sql-mode.html#sqlmode_allow_invalid_dates for details about the "invalid-date" error.
FlagIgnoreInvalidDateErr
// FlagInvalidDateAsWarning indicates to append the invalid-date error to warnings instead of returning it to user.
FlagInvalidDateAsWarning
// FlagSkipASCIICheck indicates to skip the ASCII check when converting the value to an ASCII string.
FlagSkipASCIICheck
// FlagSkipUTF8Check indicates to skip the UTF8 check when converting the value to an UTF8MB3 string.
Expand Down
35 changes: 16 additions & 19 deletions pkg/types/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -575,71 +575,68 @@ func StrToFloat(ctx Context, str string, isFuncCast bool) (float64, error) {
}

// ConvertJSONToInt64 casts JSON into int64.
func ConvertJSONToInt64(sc *stmtctx.StatementContext, j BinaryJSON, unsigned bool) (int64, error) {
return ConvertJSONToInt(sc, j, unsigned, mysql.TypeLonglong)
func ConvertJSONToInt64(ctx Context, j BinaryJSON, unsigned bool) (int64, error) {
return ConvertJSONToInt(ctx, j, unsigned, mysql.TypeLonglong)
}

// ConvertJSONToInt casts JSON into int by type.
func ConvertJSONToInt(sc *stmtctx.StatementContext, j BinaryJSON, unsigned bool, tp byte) (int64, error) {
func ConvertJSONToInt(ctx Context, j BinaryJSON, unsigned bool, tp byte) (int64, error) {
switch j.TypeCode {
case JSONTypeCodeObject, JSONTypeCodeArray, JSONTypeCodeOpaque, JSONTypeCodeDate, JSONTypeCodeDatetime, JSONTypeCodeTimestamp, JSONTypeCodeDuration:
return 0, sc.HandleTruncate(ErrTruncatedWrongVal.GenWithStackByArgs("INTEGER", j.String()))
return 0, ctx.HandleTruncate(ErrTruncatedWrongVal.GenWithStackByArgs("INTEGER", j.String()))
case JSONTypeCodeLiteral:
switch j.Value[0] {
case JSONLiteralFalse:
return 0, nil
case JSONLiteralNil:
return 0, sc.HandleTruncate(ErrTruncatedWrongVal.GenWithStackByArgs("INTEGER", j.String()))
return 0, ctx.HandleTruncate(ErrTruncatedWrongVal.GenWithStackByArgs("INTEGER", j.String()))
default:
return 1, nil
}
case JSONTypeCodeInt64:
i := j.GetInt64()
if unsigned {
uBound := IntergerUnsignedUpperBound(tp)
u, err := ConvertIntToUint(sc.TypeFlags(), i, uBound, tp)
return int64(u), sc.HandleOverflow(err, err)
u, err := ConvertIntToUint(ctx.Flags(), i, uBound, tp)
return int64(u), err
}

lBound := IntergerSignedLowerBound(tp)
uBound := IntergerSignedUpperBound(tp)
i, err := ConvertIntToInt(i, lBound, uBound, tp)
return i, sc.HandleOverflow(err, err)
return ConvertIntToInt(i, lBound, uBound, tp)
case JSONTypeCodeUint64:
u := j.GetUint64()
if unsigned {
uBound := IntergerUnsignedUpperBound(tp)
u, err := ConvertUintToUint(u, uBound, tp)
return int64(u), sc.HandleOverflow(err, err)
return int64(u), err
}

uBound := IntergerSignedUpperBound(tp)
i, err := ConvertUintToInt(u, uBound, tp)
return i, sc.HandleOverflow(err, err)
return ConvertUintToInt(u, uBound, tp)
case JSONTypeCodeFloat64:
f := j.GetFloat64()
if !unsigned {
lBound := IntergerSignedLowerBound(tp)
uBound := IntergerSignedUpperBound(tp)
u, e := ConvertFloatToInt(f, lBound, uBound, tp)
return u, sc.HandleOverflow(e, e)
return u, e
}
bound := IntergerUnsignedUpperBound(tp)
u, err := ConvertFloatToUint(sc.TypeFlags(), f, bound, tp)
return int64(u), sc.HandleOverflow(err, err)
u, err := ConvertFloatToUint(ctx.Flags(), f, bound, tp)
return int64(u), err
case JSONTypeCodeString:
str := string(hack.String(j.GetString()))
// The behavior of casting json string as an integer is consistent with casting a string as an integer.
// See the `builtinCastStringAsIntSig` in `expression` pkg. The only difference is that this function
// doesn't append any warning. This behavior is compatible with MySQL.
isNegative := len(str) > 1 && str[0] == '-'
if !isNegative {
r, err := StrToUint(sc.TypeCtxOrDefault(), str, false)
return int64(r), sc.HandleOverflow(err, err)
r, err := StrToUint(ctx, str, false)
return int64(r), err
}

r, err := StrToInt(sc.TypeCtxOrDefault(), str, false)
return r, sc.HandleOverflow(err, err)
return StrToInt(ctx, str, false)
}
return 0, errors.New("Unknown type code in JSON")
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/types/convert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1084,7 +1084,7 @@ func TestConvertJSONToInt(t *testing.T) {
j, err := ParseBinaryJSONFromString(tt.in)
require.NoError(t, err)

casted, err := ConvertJSONToInt64(stmtctx.NewStmtCtx(), j, false)
casted, err := ConvertJSONToInt64(stmtctx.NewStmtCtx().TypeCtx(), j, false)
if tt.err {
require.Error(t, err, tt)
} else {
Expand Down
Loading

0 comments on commit 5503eb5

Please sign in to comment.