From 4bb44d8775a1263b53c0c9126f651a9236a18225 Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Sat, 17 Feb 2024 14:19:25 +0100 Subject: [PATCH] evalengine: Implement REPLACE This implements the REPLACE function. This function is documented as case sensitive, but in practice is also byte sensitive. This means that equal characters under collation rules are not replaced. Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/evalengine/cached_size.go | 12 ++ go/vt/vtgate/evalengine/compiler_asm.go | 13 ++ go/vt/vtgate/evalengine/compiler_test.go | 4 + go/vt/vtgate/evalengine/fn_string.go | 171 ++++++++++++++++--- go/vt/vtgate/evalengine/testcases/cases.go | 28 +++ go/vt/vtgate/evalengine/translate_builtin.go | 5 + 6 files changed, 211 insertions(+), 22 deletions(-) diff --git a/go/vt/vtgate/evalengine/cached_size.go b/go/vt/vtgate/evalengine/cached_size.go index ce90aadea8a..b386d3dc915 100644 --- a/go/vt/vtgate/evalengine/cached_size.go +++ b/go/vt/vtgate/evalengine/cached_size.go @@ -1463,6 +1463,18 @@ func (cached *builtinRepeat) CachedSize(alloc bool) int64 { size += cached.CallExpr.CachedSize(false) return size } +func (cached *builtinReplace) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(48) + } + // field CallExpr vitess.io/vitess/go/vt/vtgate/evalengine.CallExpr + size += cached.CallExpr.CachedSize(false) + return size +} func (cached *builtinReverse) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index cddf0790ea7..5eeb9a6300d 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -3022,6 +3022,19 @@ func (asm *assembler) Locate2(collation colldata.Collation) { }, "LOCATE VARCHAR(SP-2), VARCHAR(SP-1) COLLATE '%s'", collation.Name()) } +func (asm *assembler) Replace() { + asm.adjustStack(-2) + + asm.emit(func(env *ExpressionEnv) int { + str := env.vm.stack[env.vm.sp-3].(*evalBytes) + from := env.vm.stack[env.vm.sp-2].(*evalBytes) + to := env.vm.stack[env.vm.sp-1].(*evalBytes) + env.vm.sp -= 2 + str.bytes = replace(str.bytes, from.bytes, to.bytes) + return 1 + }, "REPLACE VARCHAR(SP-3), VARCHAR(SP-2) VARCHAR(SP-1)") +} + func (asm *assembler) Strcmp(collation collations.TypedCollation) { asm.adjustStack(-1) diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index 09e08ad0d48..f101bf61c64 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -627,6 +627,10 @@ func TestCompilerSingle(t *testing.T) { expression: `locate("", "😊😂🤢", 3)`, result: `INT64(3)`, }, + { + expression: `REPLACE('www.mysql.com', '', 'Ww')`, + result: `VARCHAR("www.mysql.com")`, + }, } tz, _ := time.LoadLocation("Europe/Madrid") diff --git a/go/vt/vtgate/evalengine/fn_string.go b/go/vt/vtgate/evalengine/fn_string.go index 11a18c95300..e0887037c0a 100644 --- a/go/vt/vtgate/evalengine/fn_string.go +++ b/go/vt/vtgate/evalengine/fn_string.go @@ -114,6 +114,31 @@ type ( CallExpr collate collations.ID } + + builtinChar struct { + CallExpr + collate collations.ID + } + + builtinRepeat struct { + CallExpr + collate collations.ID + } + + builtinConcat struct { + CallExpr + collate collations.ID + } + + builtinConcatWs struct { + CallExpr + collate collations.ID + } + + builtinReplace struct { + CallExpr + collate collations.ID + } ) var _ IR = (*builtinInsert)(nil) @@ -129,7 +154,15 @@ var _ IR = (*builtinCollation)(nil) var _ IR = (*builtinWeightString)(nil) var _ IR = (*builtinLeftRight)(nil) var _ IR = (*builtinPad)(nil) +var _ IR = (*builtinStrcmp)(nil) var _ IR = (*builtinTrim)(nil) +var _ IR = (*builtinSubstring)(nil) +var _ IR = (*builtinLocate)(nil) +var _ IR = (*builtinChar)(nil) +var _ IR = (*builtinRepeat)(nil) +var _ IR = (*builtinConcat)(nil) +var _ IR = (*builtinConcatWs)(nil) +var _ IR = (*builtinReplace)(nil) func insert(str, newstr *evalBytes, pos, l int) []byte { pos-- @@ -555,11 +588,6 @@ func (call *builtinOrd) compile(c *compiler) (ctype, error) { // - `> max_allowed_packet`, no error and returns `NULL`. const maxRepeatLength = 1073741824 -type builtinRepeat struct { - CallExpr - collate collations.ID -} - func (call *builtinRepeat) eval(env *ExpressionEnv) (eval, error) { arg1, arg2, err := call.arg2(env) if err != nil { @@ -1374,11 +1402,6 @@ func (call *builtinLocate) compile(c *compiler) (ctype, error) { return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagNullable}, nil } -type builtinConcat struct { - CallExpr - collate collations.ID -} - func concatSQLType(arg sqltypes.Type, tt sqltypes.Type) sqltypes.Type { if arg == sqltypes.TypeJSON { return sqltypes.Blob @@ -1507,11 +1530,6 @@ func (call *builtinConcat) compile(c *compiler) (ctype, error) { return ctype{Type: tt, Flag: f, Col: tc}, nil } -type builtinConcatWs struct { - CallExpr - collate collations.ID -} - func (call *builtinConcatWs) eval(env *ExpressionEnv) (eval, error) { var ca collationAggregation tt := sqltypes.VarChar @@ -1643,13 +1661,6 @@ func (call *builtinConcatWs) compile(c *compiler) (ctype, error) { return ctype{Type: tt, Flag: args[0].Flag, Col: tc}, nil } -type builtinChar struct { - CallExpr - collate collations.ID -} - -var _ IR = (*builtinChar)(nil) - func (call *builtinChar) eval(env *ExpressionEnv) (eval, error) { vals := make([]eval, 0, len(call.Arguments)) for _, arg := range call.Arguments { @@ -1726,3 +1737,119 @@ func encodeChar(buf []byte, i uint32) []byte { } return buf } + +func (call *builtinReplace) eval(env *ExpressionEnv) (eval, error) { + str, err := call.Arguments[0].eval(env) + if err != nil || str == nil { + return nil, err + } + + fromStr, err := call.Arguments[1].eval(env) + if err != nil || fromStr == nil { + return nil, err + } + + toStr, err := call.Arguments[2].eval(env) + if err != nil || toStr == nil { + return nil, err + } + + if _, ok := str.(*evalBytes); !ok { + str, err = evalToVarchar(str, call.collate, true) + if err != nil { + return nil, err + } + } + + col := str.(*evalBytes).col + fromStr, err = evalToVarchar(fromStr, col.Collation, true) + if err != nil { + return nil, err + } + + toStr, err = evalToVarchar(toStr, col.Collation, true) + if err != nil { + return nil, err + } + + strBytes := str.(*evalBytes).bytes + fromBytes := fromStr.(*evalBytes).bytes + toBytes := toStr.(*evalBytes).bytes + + out := replace(strBytes, fromBytes, toBytes) + return newEvalRaw(str.SQLType(), out, col), nil +} + +func (call *builtinReplace) compile(c *compiler) (ctype, error) { + str, err := call.Arguments[0].compile(c) + if err != nil { + return ctype{}, err + } + + fromStr, err := call.Arguments[1].compile(c) + if err != nil { + return ctype{}, err + } + + toStr, err := call.Arguments[2].compile(c) + if err != nil { + return ctype{}, err + } + + skip := c.compileNullCheck3(str, fromStr, toStr) + if !str.isTextual() { + c.asm.Convert_xce(3, sqltypes.VarChar, c.collation) + str.Col = collations.TypedCollation{ + Collation: c.collation, + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireASCII, + } + } + + fromCharset := colldata.Lookup(fromStr.Col.Collation).Charset() + toCharset := colldata.Lookup(toStr.Col.Collation).Charset() + strCharset := colldata.Lookup(str.Col.Collation).Charset() + if !fromStr.isTextual() || (fromCharset != strCharset && !strCharset.IsSuperset(fromCharset)) { + c.asm.Convert_xce(2, sqltypes.VarChar, str.Col.Collation) + fromStr.Col = collations.TypedCollation{ + Collation: str.Col.Collation, + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireASCII, + } + } + + if !toStr.isTextual() || (toCharset != strCharset && !strCharset.IsSuperset(toCharset)) { + c.asm.Convert_xce(1, sqltypes.VarChar, str.Col.Collation) + toStr.Col = collations.TypedCollation{ + Collation: str.Col.Collation, + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireASCII, + } + } + + c.asm.Replace() + c.asm.jumpDestination(skip) + return ctype{Type: sqltypes.VarChar, Col: str.Col, Flag: flagNullable}, nil +} + +func replace(str, from, to []byte) []byte { + if len(from) == 0 { + return str + } + n := bytes.Count(str, from) + if n == 0 { + return str + } + + out := make([]byte, len(str)+n*(len(to)-len(from))) + end := 0 + start := 0 + for i := 0; i < n; i++ { + pos := start + bytes.Index(str[start:], from) + end += copy(out[end:], str[start:pos]) + end += copy(out[end:], to) + start = pos + len(from) + } + end += copy(out[end:], str[start:]) + return out[0:end] +} diff --git a/go/vt/vtgate/evalengine/testcases/cases.go b/go/vt/vtgate/evalengine/testcases/cases.go index c53faa7f217..b55f6c2c18d 100644 --- a/go/vt/vtgate/evalengine/testcases/cases.go +++ b/go/vt/vtgate/evalengine/testcases/cases.go @@ -83,6 +83,7 @@ var Cases = []TestCase{ {Run: FnTrim}, {Run: FnSubstr}, {Run: FnLocate}, + {Run: FnReplace}, {Run: FnConcat}, {Run: FnConcatWs}, {Run: FnChar}, @@ -1577,6 +1578,33 @@ func FnLocate(yield Query) { } } +func FnReplace(yield Query) { + cases := []string{ + `REPLACE('www.mysql.com', 'w', 'Ww')`, + // MySQL doesn't do collation matching for replace, only + // byte equivalence, but make sure to check. + `REPLACE('straße', 'ss', 'b')`, + `REPLACE('straße', 'ß', 'b')`, + // From / to strings are converted into the collation of + // the input string. + `REPLACE('fooÿbar', _latin1 0xFF, _latin1 0xFE)`, + // First occurence is replaced + `replace('fff', 'ff', 'gg')`, + } + + for _, q := range cases { + yield(q, nil) + } + + for _, substr := range inputStrings { + for _, str := range inputStrings { + for _, i := range inputStrings { + yield(fmt.Sprintf("REPLACE(%s, %s, %s)", substr, str, i), nil) + } + } + } +} + func FnConcat(yield Query) { for _, str := range inputStrings { yield(fmt.Sprintf("CONCAT(%s)", str), nil) diff --git a/go/vt/vtgate/evalengine/translate_builtin.go b/go/vt/vtgate/evalengine/translate_builtin.go index 71eff66bc2b..cabe406bfb6 100644 --- a/go/vt/vtgate/evalengine/translate_builtin.go +++ b/go/vt/vtgate/evalengine/translate_builtin.go @@ -627,6 +627,11 @@ func (ast *astCompiler) translateFuncExpr(fn *sqlparser.FuncExpr) (IR, error) { } call = CallExpr{Arguments: []IR{call.Arguments[1], call.Arguments[0]}, Method: method} return &builtinLocate{CallExpr: call, collate: ast.cfg.Collation}, nil + case "replace": + if len(args) != 3 { + return nil, argError(method) + } + return &builtinReplace{CallExpr: call, collate: ast.cfg.Collation}, nil default: return nil, translateExprNotSupported(fn) }