Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Defined2014 committed Sep 9, 2024
1 parent ab459b2 commit 08f2f1a
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 40 deletions.
41 changes: 41 additions & 0 deletions pkg/expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,47 @@ func newBaseBuiltinCastFunc(builtinFunc baseBuiltinFunc, inUnion bool) baseBuilt
}
}

// // baseBuiltinCastStringFunc will be used in every struct created by `castAsStringFunctionClass`.
type baseBuiltinCastStringFunc struct {
baseBuiltinFunc

// isExplicitCharSet indicates whether cast function set the charset info explicit.
isExplicitCharSet bool
}

func newBaseBuiltinCastStringFunc(ctx BuildContext, funcName string, args []Expression, tp *types.FieldType, isExplicitCharSet bool) (baseBuiltinCastStringFunc, error) {
bf, err := newBaseBuiltinFunc(ctx, funcName, args, tp)
if err != nil {
return baseBuiltinCastStringFunc{}, err
}
if isExplicitCharSet {
bf.SetCharsetAndCollation(tp.GetCharset(), tp.GetCollate())
bf.setCollator(collate.GetCollator(tp.GetCollate()))
bf.SetCoercibility(CoercibilityExplicit)
if tp.GetCharset() == charset.CharsetASCII {
bf.SetRepertoire(ASCII)
} else {
bf.SetRepertoire(UNICODE)
}
}
return baseBuiltinCastStringFunc{
baseBuiltinFunc: bf,
isExplicitCharSet: isExplicitCharSet,
}, nil
}

func newBaseBuiltinCastStringFuncFromBaseBuiltinFunc(builtinFunc baseBuiltinFunc, isExplicitCharSet bool) baseBuiltinCastStringFunc {
return baseBuiltinCastStringFunc{
baseBuiltinFunc: builtinFunc,
isExplicitCharSet: isExplicitCharSet,
}
}

func (b *baseBuiltinCastStringFunc) cloneFrom(from *baseBuiltinCastStringFunc) {
b.baseBuiltinFunc.cloneFrom(&from.baseBuiltinFunc)
b.isExplicitCharSet = from.isExplicitCharSet
}

// vecBuiltinFunc contains all vectorized methods for a builtin function.
type vecBuiltinFunc interface {
// vectorized returns if this builtin function itself supports vectorized evaluation.
Expand Down
45 changes: 17 additions & 28 deletions pkg/expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ import (
"github.com/pingcap/tidb/pkg/sessionctx/stmtctx"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/collate"
"github.com/pingcap/tipb/go-tipb"
)

Expand Down Expand Up @@ -297,20 +296,10 @@ func (c *castAsStringFunctionClass) getFunction(ctx BuildContext, args []Express
if err := c.verifyArgs(args); err != nil {
return nil, err
}
bf, err := newBaseBuiltinFunc(ctx, c.funcName, args, c.tp)
bf, err := newBaseBuiltinCastStringFunc(ctx, c.funcName, args, c.tp, c.isExplicitCharSet)
if err != nil {
return nil, err
}
if c.isExplicitCharSet {
bf.SetCharsetAndCollation(c.tp.GetCharset(), c.tp.GetCollate())
bf.setCollator(collate.GetCollator(c.tp.GetCollate()))
bf.SetCoercibility(CoercibilityExplicit)
if c.tp.GetCharset() == charset.CharsetASCII {
bf.SetRepertoire(ASCII)
} else {
bf.SetRepertoire(UNICODE)
}
}
if args[0].GetType(ctx.GetEvalCtx()).Hybrid() {
sig = &builtinCastStringAsStringSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_CastStringAsString)
Expand Down Expand Up @@ -910,12 +899,12 @@ func (b *builtinCastIntAsDecimalSig) evalDecimal(ctx EvalContext, row chunk.Row)
}

type builtinCastIntAsStringSig struct {
baseBuiltinFunc
baseBuiltinCastStringFunc
}

func (b *builtinCastIntAsStringSig) Clone() builtinFunc {
newSig := &builtinCastIntAsStringSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.cloneFrom(&b.baseBuiltinCastStringFunc)
return newSig
}

Expand Down Expand Up @@ -1227,12 +1216,12 @@ func (b *builtinCastRealAsDecimalSig) evalDecimal(ctx EvalContext, row chunk.Row
}

type builtinCastRealAsStringSig struct {
baseBuiltinFunc
baseBuiltinCastStringFunc
}

func (b *builtinCastRealAsStringSig) Clone() builtinFunc {
newSig := &builtinCastRealAsStringSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.cloneFrom(&b.baseBuiltinCastStringFunc)
return newSig
}

Expand Down Expand Up @@ -1382,12 +1371,12 @@ func (b *builtinCastDecimalAsIntSig) evalInt(ctx EvalContext, row chunk.Row) (re
}

type builtinCastDecimalAsStringSig struct {
baseBuiltinFunc
baseBuiltinCastStringFunc
}

func (b *builtinCastDecimalAsStringSig) Clone() builtinFunc {
newSig := &builtinCastDecimalAsStringSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.cloneFrom(&b.baseBuiltinCastStringFunc)
return newSig
}

Expand Down Expand Up @@ -1492,12 +1481,12 @@ func (b *builtinCastDecimalAsDurationSig) evalDuration(ctx EvalContext, row chun
}

type builtinCastStringAsStringSig struct {
baseBuiltinFunc
baseBuiltinCastStringFunc
}

func (b *builtinCastStringAsStringSig) Clone() builtinFunc {
newSig := &builtinCastStringAsStringSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.cloneFrom(&b.baseBuiltinCastStringFunc)
return newSig
}

Expand Down Expand Up @@ -1809,12 +1798,12 @@ func (b *builtinCastTimeAsDecimalSig) evalDecimal(ctx EvalContext, row chunk.Row
}

type builtinCastTimeAsStringSig struct {
baseBuiltinFunc
baseBuiltinCastStringFunc
}

func (b *builtinCastTimeAsStringSig) Clone() builtinFunc {
newSig := &builtinCastTimeAsStringSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.cloneFrom(&b.baseBuiltinCastStringFunc)
return newSig
}

Expand Down Expand Up @@ -1948,12 +1937,12 @@ func (b *builtinCastDurationAsDecimalSig) evalDecimal(ctx EvalContext, row chunk
}

type builtinCastDurationAsStringSig struct {
baseBuiltinFunc
baseBuiltinCastStringFunc
}

func (b *builtinCastDurationAsStringSig) Clone() builtinFunc {
newSig := &builtinCastDurationAsStringSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.cloneFrom(&b.baseBuiltinCastStringFunc)
return newSig
}

Expand Down Expand Up @@ -2090,12 +2079,12 @@ func (b *builtinCastJSONAsDecimalSig) evalDecimal(ctx EvalContext, row chunk.Row
}

type builtinCastJSONAsStringSig struct {
baseBuiltinFunc
baseBuiltinCastStringFunc
}

func (b *builtinCastJSONAsStringSig) Clone() builtinFunc {
newSig := &builtinCastJSONAsStringSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.cloneFrom(&b.baseBuiltinCastStringFunc)
return newSig
}

Expand All @@ -2112,12 +2101,12 @@ func (b *builtinCastJSONAsStringSig) evalString(ctx EvalContext, row chunk.Row)
}

type builtinCastVectorFloat32AsStringSig struct {
baseBuiltinFunc
baseBuiltinCastStringFunc
}

func (b *builtinCastVectorFloat32AsStringSig) Clone() builtinFunc {
newSig := &builtinCastVectorFloat32AsStringSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.cloneFrom(&b.baseBuiltinCastStringFunc)
return newSig
}

Expand Down
6 changes: 3 additions & 3 deletions pkg/expression/builtin_cast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,7 @@ func TestCastFuncSig(t *testing.T) {
tp := types.NewFieldType(mysql.TypeVarString)
tp.SetCharset(charset.CharsetBin)
args := []Expression{c.before}
stringFunc, err := newBaseBuiltinFunc(ctx, "", args, tp)
stringFunc, err := newBaseBuiltinCastStringFunc(ctx, "", args, tp, false)
require.NoError(t, err)
switch i {
case 0:
Expand Down Expand Up @@ -742,7 +742,7 @@ func TestCastFuncSig(t *testing.T) {
tp := types.NewFieldType(mysql.TypeVarString)
tp.SetFlen(c.flen)
tp.SetCharset(charset.CharsetBin)
stringFunc, err := newBaseBuiltinFunc(ctx, "", args, tp)
stringFunc, err := newBaseBuiltinCastStringFunc(ctx, "", args, tp, false)
require.NoError(t, err)
switch i {
case 0:
Expand Down Expand Up @@ -1099,7 +1099,7 @@ func TestCastFuncSig(t *testing.T) {
// null case
args := []Expression{&Column{RetType: types.NewFieldType(mysql.TypeDouble), Index: 0}}
row := chunk.MutRowFromDatums([]types.Datum{types.NewDatum(nil)})
bf, err := newBaseBuiltinFunc(ctx, "", args, types.NewFieldType(mysql.TypeVarString))
bf, err := newBaseBuiltinCastStringFunc(ctx, "", args, types.NewFieldType(mysql.TypeVarString), false)
require.NoError(t, err)
sig = &builtinCastRealAsStringSig{bf}
sRes, err := evalBuiltinFunc(sig, ctx, row.ToRow())
Expand Down
16 changes: 8 additions & 8 deletions pkg/expression/distsql_builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func getSignatureByPB(ctx BuildContext, sigCode tipb.ScalarFuncSig, tp *tipb.Fie
case tipb.ScalarFuncSig_CastIntAsReal:
f = &builtinCastIntAsRealSig{newBaseBuiltinCastFunc(base, false)}
case tipb.ScalarFuncSig_CastIntAsString:
f = &builtinCastIntAsStringSig{base}
f = &builtinCastIntAsStringSig{newBaseBuiltinCastStringFuncFromBaseBuiltinFunc(base, false)}
case tipb.ScalarFuncSig_CastIntAsDecimal:
f = &builtinCastIntAsDecimalSig{newBaseBuiltinCastFunc(base, false)}
case tipb.ScalarFuncSig_CastIntAsTime:
Expand All @@ -62,7 +62,7 @@ func getSignatureByPB(ctx BuildContext, sigCode tipb.ScalarFuncSig, tp *tipb.Fie
case tipb.ScalarFuncSig_CastRealAsReal:
f = &builtinCastRealAsRealSig{newBaseBuiltinCastFunc(base, false)}
case tipb.ScalarFuncSig_CastRealAsString:
f = &builtinCastRealAsStringSig{base}
f = &builtinCastRealAsStringSig{newBaseBuiltinCastStringFuncFromBaseBuiltinFunc(base, false)}
case tipb.ScalarFuncSig_CastRealAsDecimal:
f = &builtinCastRealAsDecimalSig{newBaseBuiltinCastFunc(base, false)}
case tipb.ScalarFuncSig_CastRealAsTime:
Expand All @@ -76,7 +76,7 @@ func getSignatureByPB(ctx BuildContext, sigCode tipb.ScalarFuncSig, tp *tipb.Fie
case tipb.ScalarFuncSig_CastDecimalAsReal:
f = &builtinCastDecimalAsRealSig{newBaseBuiltinCastFunc(base, false)}
case tipb.ScalarFuncSig_CastDecimalAsString:
f = &builtinCastDecimalAsStringSig{base}
f = &builtinCastDecimalAsStringSig{newBaseBuiltinCastStringFuncFromBaseBuiltinFunc(base, false)}
case tipb.ScalarFuncSig_CastDecimalAsDecimal:
f = &builtinCastDecimalAsDecimalSig{newBaseBuiltinCastFunc(base, false)}
case tipb.ScalarFuncSig_CastDecimalAsTime:
Expand All @@ -90,7 +90,7 @@ func getSignatureByPB(ctx BuildContext, sigCode tipb.ScalarFuncSig, tp *tipb.Fie
case tipb.ScalarFuncSig_CastStringAsReal:
f = &builtinCastStringAsRealSig{newBaseBuiltinCastFunc(base, false)}
case tipb.ScalarFuncSig_CastStringAsString:
f = &builtinCastStringAsStringSig{base}
f = &builtinCastStringAsStringSig{newBaseBuiltinCastStringFuncFromBaseBuiltinFunc(base, false)}
case tipb.ScalarFuncSig_CastStringAsDecimal:
f = &builtinCastStringAsDecimalSig{newBaseBuiltinCastFunc(base, false)}
case tipb.ScalarFuncSig_CastStringAsTime:
Expand All @@ -104,7 +104,7 @@ func getSignatureByPB(ctx BuildContext, sigCode tipb.ScalarFuncSig, tp *tipb.Fie
case tipb.ScalarFuncSig_CastTimeAsReal:
f = &builtinCastTimeAsRealSig{newBaseBuiltinCastFunc(base, false)}
case tipb.ScalarFuncSig_CastTimeAsString:
f = &builtinCastTimeAsStringSig{base}
f = &builtinCastTimeAsStringSig{newBaseBuiltinCastStringFuncFromBaseBuiltinFunc(base, false)}
case tipb.ScalarFuncSig_CastTimeAsDecimal:
f = &builtinCastTimeAsDecimalSig{newBaseBuiltinCastFunc(base, false)}
case tipb.ScalarFuncSig_CastTimeAsTime:
Expand All @@ -118,7 +118,7 @@ func getSignatureByPB(ctx BuildContext, sigCode tipb.ScalarFuncSig, tp *tipb.Fie
case tipb.ScalarFuncSig_CastDurationAsReal:
f = &builtinCastDurationAsRealSig{newBaseBuiltinCastFunc(base, false)}
case tipb.ScalarFuncSig_CastDurationAsString:
f = &builtinCastDurationAsStringSig{base}
f = &builtinCastDurationAsStringSig{newBaseBuiltinCastStringFuncFromBaseBuiltinFunc(base, false)}
case tipb.ScalarFuncSig_CastDurationAsDecimal:
f = &builtinCastDurationAsDecimalSig{newBaseBuiltinCastFunc(base, false)}
case tipb.ScalarFuncSig_CastDurationAsTime:
Expand All @@ -132,7 +132,7 @@ func getSignatureByPB(ctx BuildContext, sigCode tipb.ScalarFuncSig, tp *tipb.Fie
case tipb.ScalarFuncSig_CastJsonAsReal:
f = &builtinCastJSONAsRealSig{newBaseBuiltinCastFunc(base, false)}
case tipb.ScalarFuncSig_CastJsonAsString:
f = &builtinCastJSONAsStringSig{base}
f = &builtinCastJSONAsStringSig{newBaseBuiltinCastStringFuncFromBaseBuiltinFunc(base, false)}
case tipb.ScalarFuncSig_CastJsonAsDecimal:
f = &builtinCastJSONAsDecimalSig{newBaseBuiltinCastFunc(base, false)}
case tipb.ScalarFuncSig_CastJsonAsTime:
Expand Down Expand Up @@ -1077,7 +1077,7 @@ func getSignatureByPB(ctx BuildContext, sigCode tipb.ScalarFuncSig, tp *tipb.Fie
// TODO: set the `cannotConvertStringAsWarning` accordingly
f = &builtinInternalFromBinarySig{base, false}
case tipb.ScalarFuncSig_CastVectorFloat32AsString:
f = &builtinCastVectorFloat32AsStringSig{base}
f = &builtinCastVectorFloat32AsStringSig{newBaseBuiltinCastStringFuncFromBaseBuiltinFunc(base, false)}
case tipb.ScalarFuncSig_CastVectorFloat32AsVectorFloat32:
f = &builtinCastVectorFloat32AsVectorFloat32Sig{base}
case tipb.ScalarFuncSig_LTVectorFloat32:
Expand Down
6 changes: 5 additions & 1 deletion pkg/expression/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,11 @@ func ColumnSubstituteImpl(ctx BuildContext, expr Expression, schema *Schema, new
flag := v.RetType.GetFlag()
var e Expression
if v.FuncName.L == ast.Cast {
e = BuildCastFunctionExplicitCharset(ctx, newArg, v.RetType)
if f, ok := v.Function.(*baseBuiltinCastStringFunc); ok && f.isExplicitCharSet {
e = BuildCastFunctionExplicitCharset(ctx, newArg, v.RetType)
} else {
e = BuildCastFunction(ctx, newArg, v.RetType)
}
} else {
// for grouping function recreation, use clone (meta included) instead of newFunction
e = v.Clone()
Expand Down

0 comments on commit 08f2f1a

Please sign in to comment.