From d985c8f2c8394ce1419dcc12c0ed8a518c8fb9b6 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Fri, 12 Jun 2020 19:56:55 +0530 Subject: [PATCH] fix udv: return null type for non-existing udv Signed-off-by: Harshit Gangal --- .../endtoend/vtgate/setstatement/udv_test.go | 7 +++- go/vt/sqlparser/expression_rewriting.go | 35 +++++++++---------- go/vt/sqlparser/expression_rewriting_test.go | 16 +++++---- go/vt/vtgate/evalengine/expressions.go | 2 ++ go/vt/vtgate/executor.go | 13 ++++--- go/vt/vtgate/executor_select_test.go | 15 ++++++-- 6 files changed, 56 insertions(+), 32 deletions(-) diff --git a/go/test/endtoend/vtgate/setstatement/udv_test.go b/go/test/endtoend/vtgate/setstatement/udv_test.go index a60a6a1646f..5d2779c2fa8 100644 --- a/go/test/endtoend/vtgate/setstatement/udv_test.go +++ b/go/test/endtoend/vtgate/setstatement/udv_test.go @@ -21,6 +21,8 @@ import ( "fmt" "testing" + "github.com/stretchr/testify/assert" + "vitess.io/vitess/go/test/utils" "github.com/google/go-cmp/cmp" @@ -44,6 +46,9 @@ func TestSetUDV(t *testing.T) { } queries := []queriesWithExpectations{{ + query: "select @foo", + expectedRows: "[[NULL]]", rowsAffected: 1, + }, { query: "set @foo = 'abc', @bar = 42, @baz = 30.5, @tablet = concat('foo','bar')", expectedRows: "", rowsAffected: 0, }, { @@ -101,7 +106,7 @@ func TestSetUDV(t *testing.T) { t.Run(fmt.Sprintf("%d-%s", i, q.query), func(t *testing.T) { qr, err := exec(t, conn, q.query) require.NoError(t, err) - require.Equal(t, uint64(q.rowsAffected), qr.RowsAffected, "rows affected wrong for query: %s", q.query) + assert.Equal(t, uint64(q.rowsAffected), qr.RowsAffected, "rows affected wrong for query: %s", q.query) if q.expectedRows != "" { result := fmt.Sprintf("%v", qr.Rows) if diff := cmp.Diff(q.expectedRows, result); diff != "" { diff --git a/go/vt/sqlparser/expression_rewriting.go b/go/vt/sqlparser/expression_rewriting.go index 79df8d1ae23..0276e73c9f9 100644 --- a/go/vt/sqlparser/expression_rewriting.go +++ b/go/vt/sqlparser/expression_rewriting.go @@ -39,7 +39,7 @@ type BindVarNeeds struct { NeedDatabase bool NeedFoundRows bool NeedRowCount bool - NeedUserDefinedVariables bool + NeedUserDefinedVariables []string } // RewriteAST rewrites the whole AST, replacing function calls and adding column aliases to queries @@ -58,22 +58,20 @@ func RewriteAST(in Statement) (*RewriteASTResult, error) { r := &RewriteASTResult{ AST: out, } - if _, ok := er.bindVars[LastInsertIDName]; ok { - r.NeedLastInsertID = true - } - if _, ok := er.bindVars[DBVarName]; ok { - r.NeedDatabase = true - } - if _, ok := er.bindVars[FoundRowsName]; ok { - r.NeedFoundRows = true - } - if _, ok := er.bindVars[RowCountName]; ok { - r.NeedRowCount = true - } - if _, ok := er.bindVars[UserDefinedVariableName]; ok { - r.NeedUserDefinedVariables = true + for k := range er.bindVars { + switch k { + case LastInsertIDName: + r.NeedLastInsertID = true + case DBVarName: + r.NeedDatabase = true + case FoundRowsName: + r.NeedFoundRows = true + case RowCountName: + r.NeedRowCount = true + default: + r.NeedUserDefinedVariables = append(r.NeedUserDefinedVariables, k) + } } - return r, nil } @@ -159,8 +157,9 @@ func (er *expressionRewriter) goingDown(cursor *Cursor) bool { er.funcRewrite(cursor, node) case *ColName: if node.Name.at == SingleAt { - cursor.Replace(bindVarExpression(UserDefinedVariableName + strings.ToLower(node.Name.CompliantName()))) - er.needBindVarFor(UserDefinedVariableName) + udv := strings.ToLower(node.Name.CompliantName()) + cursor.Replace(bindVarExpression(UserDefinedVariableName + udv)) + er.needBindVarFor(udv) } } return true diff --git a/go/vt/sqlparser/expression_rewriting_test.go b/go/vt/sqlparser/expression_rewriting_test.go index 4e10c0418fd..110c4883373 100644 --- a/go/vt/sqlparser/expression_rewriting_test.go +++ b/go/vt/sqlparser/expression_rewriting_test.go @@ -23,8 +23,9 @@ import ( ) type myTestCase struct { - in, expected string - liid, db, foundRows, udv, rowCount bool + in, expected string + liid, db, foundRows, rowCount bool + udv int } func TestRewrites(in *testing.T) { @@ -97,17 +98,17 @@ func TestRewrites(in *testing.T) { { in: "select @`x y`", expected: "select :__vtudvx_y as `@``x y``` from dual", - udv: true, + udv: 1, }, { - in: "select id from t where id = @x", - expected: "select id from t where id = :__vtudvx", - db: false, udv: true, + in: "select id from t where id = @x and val = @y", + expected: "select id from t where id = :__vtudvx and val = :__vtudvy", + db: false, udv: 2, }, { in: "insert into t(id) values(@xyx)", expected: "insert into t(id) values(:__vtudvxyx)", - db: false, udv: true, + db: false, udv: 1, }, { in: "select row_count()", @@ -138,6 +139,7 @@ func TestRewrites(in *testing.T) { require.Equal(t, tc.db, result.NeedDatabase, "should need database name") require.Equal(t, tc.foundRows, result.NeedFoundRows, "should need found rows") require.Equal(t, tc.rowCount, result.NeedRowCount, "should need row count") + require.Equal(t, tc.udv, len(result.NeedUserDefinedVariables), "should need row count") }) } } diff --git a/go/vt/vtgate/evalengine/expressions.go b/go/vt/vtgate/evalengine/expressions.go index 9581dd9067e..9679a240312 100644 --- a/go/vt/vtgate/evalengine/expressions.go +++ b/go/vt/vtgate/evalengine/expressions.go @@ -287,6 +287,8 @@ func evaluateByType(val *querypb.BindVariable) (EvalResult, error) { return evalResult{typ: sqltypes.Float64, fval: fval}, nil case sqltypes.VarChar, sqltypes.Text, sqltypes.VarBinary: return evalResult{typ: sqltypes.VarBinary, bytes: val.Value}, nil + case sqltypes.Null: + return evalResult{typ: sqltypes.Null}, nil } return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Type is not supported: %s", val.Type.String()) } diff --git a/go/vt/vtgate/executor.go b/go/vt/vtgate/executor.go index 42273fdcd9d..a3e5942c6f6 100644 --- a/go/vt/vtgate/executor.go +++ b/go/vt/vtgate/executor.go @@ -245,11 +245,16 @@ func (e *Executor) addNeededBindVars(bindVarNeeds sqlparser.BindVarNeeds, bindVa bindVars[sqlparser.LastInsertIDName] = sqltypes.Uint64BindVariable(session.GetLastInsertId()) } - // todo: do we need to check this map for nil? - if bindVarNeeds.NeedUserDefinedVariables && session.UserDefinedVariables != nil { - for k, v := range session.UserDefinedVariables { - bindVars[sqlparser.UserDefinedVariableName+k] = v + udvMap := session.UserDefinedVariables + if udvMap == nil { + udvMap = map[string]*querypb.BindVariable{} + } + for _, udv := range bindVarNeeds.NeedUserDefinedVariables { + val := udvMap[udv] + if val == nil { + val = sqltypes.NullBindVariable } + bindVars[sqlparser.UserDefinedVariableName+udv] = val } if bindVarNeeds.NeedFoundRows { diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index 6f760832710..53d82ca3ed0 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -249,10 +249,22 @@ func TestSelectUserDefindVariable(t *testing.T) { defer QueryLogger.Unsubscribe(logChan) sql := "select @foo" - masterSession = &vtgatepb.Session{UserDefinedVariables: createMap([]string{"foo"}, []interface{}{"bar"})} result, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) require.NoError(t, err) wantResult := &sqltypes.Result{ + Fields: []*querypb.Field{ + {Name: "@foo", Type: sqltypes.Null}, + }, + Rows: [][]sqltypes.Value{{ + sqltypes.NULL, + }}, + } + utils.MustMatch(t, result, wantResult, "Mismatch") + + masterSession = &vtgatepb.Session{UserDefinedVariables: createMap([]string{"foo"}, []interface{}{"bar"})} + result, err = executorExec(executor, sql, map[string]*querypb.BindVariable{}) + require.NoError(t, err) + wantResult = &sqltypes.Result{ Fields: []*querypb.Field{ {Name: "@foo", Type: sqltypes.VarBinary}, }, @@ -260,7 +272,6 @@ func TestSelectUserDefindVariable(t *testing.T) { sqltypes.NewVarBinary("bar"), }}, } - require.NoError(t, err) utils.MustMatch(t, result, wantResult, "Mismatch") }