Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plan for Set Statement with User Defined Variables #6035

Merged
merged 5 commits into from
Apr 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions go/vt/sqlparser/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,8 @@ func NewPlanValue(node Expr) (sqltypes.PlanValue, error) {
return sqltypes.PlanValue{}, err
}
return sqltypes.PlanValue{Value: n}, nil
case FloatVal:
return sqltypes.PlanValue{Value: sqltypes.MakeTrusted(sqltypes.Float64, node.Val)}, nil
case StrVal:
return sqltypes.PlanValue{Value: sqltypes.MakeTrusted(sqltypes.VarBinary, node.Val)}, nil
case HexVal:
Expand Down
8 changes: 7 additions & 1 deletion go/vt/sqlparser/analyzer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,12 @@ func TestNewPlanValue(t *testing.T) {
}, {
in: &NullVal{},
out: sqltypes.PlanValue{},
}, {
in: &SQLVal{
Type: FloatVal,
Val: []byte("2.1"),
},
out: sqltypes.PlanValue{Value: sqltypes.NewFloat64(2.1)},
}}
for _, tc := range tcases {
got, err := NewPlanValue(tc.in)
Expand All @@ -423,7 +429,7 @@ func TestNewPlanValue(t *testing.T) {
t.Error(err)
continue
}
if !reflect.DeepEqual(got, tc.out) {
if !reflect.DeepEqual(tc.out, got) {
t.Errorf("NewPlanValue(%s): %v, want %v", String(tc.in), got, tc.out)
}
}
Expand Down
7 changes: 6 additions & 1 deletion go/vt/sqlparser/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,7 @@ type ColIdent struct {
// last field in the struct.
_ [0]struct{ _ []byte }
val, lowered string
at atCount
at AtCount
}

// TableIdent is a case sensitive SQL identifier. It will be escaped with
Expand Down Expand Up @@ -1750,3 +1750,8 @@ func (node ColIdent) Format(buf *TrackedBuffer) {
func (node TableIdent) Format(buf *TrackedBuffer) {
formatID(buf, node.v, strings.ToLower(node.v), NoAt)
}

// AtCount return the '@' count present in ColIdent Name
func (node ColIdent) AtCount() AtCount {
return node.at
}
11 changes: 6 additions & 5 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ func NewColIdent(str string) ColIdent {
}

// NewColIdentWithAt makes a new ColIdent.
func NewColIdentWithAt(str string, at atCount) ColIdent {
func NewColIdentWithAt(str string, at AtCount) ColIdent {
return ColIdent{
val: str,
at: at,
Expand Down Expand Up @@ -639,7 +639,7 @@ func (node *TableIdent) UnmarshalJSON(b []byte) error {
return nil
}

func containEscapableChars(s string, at atCount) bool {
func containEscapableChars(s string, at AtCount) bool {
isDbSystemVariable := at != NoAt

for i, c := range s {
Expand All @@ -660,7 +660,7 @@ func isKeyword(s string) bool {
return isKeyword
}

func formatID(buf *TrackedBuffer, original, lowered string, at atCount) {
func formatID(buf *TrackedBuffer, original, lowered string, at AtCount) {
if containEscapableChars(original, at) || isKeyword(lowered) {
writeEscapedString(buf, original)
} else {
Expand Down Expand Up @@ -755,11 +755,12 @@ func (node *Union) SetLimit(limit *Limit) {
node.Limit = limit
}

type atCount int
// AtCount represents the '@' count in ColIdent
type AtCount int

const (
// NoAt represents no @
NoAt atCount = iota
NoAt AtCount = iota
// SingleAt represents @
SingleAt
// DoubleAt represnts @@
Expand Down
9 changes: 9 additions & 0 deletions go/vt/vtgate/engine/fake_vcursor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ var _ VCursor = (*noopVCursor)(nil)
type noopVCursor struct {
}

func (t noopVCursor) SetUDV(key string, value interface{}) error {
panic("implement me")
}

func (t noopVCursor) ExecuteVSchema(keyspace string, vschemaDDL *sqlparser.DDL) error {
panic("implement me")
}
Expand Down Expand Up @@ -123,6 +127,11 @@ type loggingVCursor struct {
log []string
}

func (f *loggingVCursor) SetUDV(key string, value interface{}) error {
f.log = append(f.log, fmt.Sprintf("UDV set with (%s,%v)", key, value))
return nil
}

func (f *loggingVCursor) ExecuteVSchema(keyspace string, vschemaDDL *sqlparser.DDL) error {
panic("implement me")
}
Expand Down
3 changes: 3 additions & 0 deletions go/vt/vtgate/engine/primitive.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ type VCursor interface {
// Resolver methods, from key.Destination to srvtopo.ResolvedShard.
// Will replace all of the Topo functions.
ResolveDestinations(keyspace string, ids []*querypb.Value, destinations []key.Destination) ([]*srvtopo.ResolvedShard, [][]*querypb.Value, error)

SetTarget(target string) error
SetUDV(key string, value interface{}) error

ExecuteVSchema(keyspace string, vschemaDDL *sqlparser.DDL) error
}

Expand Down
7 changes: 3 additions & 4 deletions go/vt/vtgate/engine/send.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package engine
import (
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/key"
"vitess.io/vitess/go/vt/proto/query"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vtgate/vindexes"
Expand Down Expand Up @@ -74,7 +73,7 @@ func (s *Send) GetTableName() string {
}

// Execute implements Primitive interface
func (s *Send) Execute(vcursor VCursor, bindVars map[string]*query.BindVariable, _ bool) (*sqltypes.Result, error) {
func (s *Send) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, _ bool) (*sqltypes.Result, error) {
rss, _, err := vcursor.ResolveDestinations(s.Keyspace.Name, nil, []key.Destination{s.TargetDestination})
if err != nil {
return nil, vterrors.Wrap(err, "sendExecute")
Expand Down Expand Up @@ -111,12 +110,12 @@ func (s *Send) Execute(vcursor VCursor, bindVars map[string]*query.BindVariable,
}

// StreamExecute implements Primitive interface
func (s *Send) StreamExecute(vcursor VCursor, bindVars map[string]*query.BindVariable, wantields bool, callback func(*sqltypes.Result) error) error {
func (s *Send) StreamExecute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantields bool, callback func(*sqltypes.Result) error) error {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "not reachable") // TODO: systay - this should work
}

// GetFields implements Primitive interface
func (s *Send) GetFields(vcursor VCursor, bindVars map[string]*query.BindVariable) (*sqltypes.Result, error) {
func (s *Send) GetFields(vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "not reachable")
}

Expand Down
108 changes: 108 additions & 0 deletions go/vt/vtgate/engine/set.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
Copyright 2020 The Vitess Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package engine

import (
"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
)

type (
// Set contains the instructions to perform set.
Set struct {
Ops []SetOp
noTxNeeded
noInputs
}

// SetOp is an interface that different type of set operations implements.
SetOp interface {
Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable) error
VariableName() string
}

// UserDefinedVariable implements the SetOp interface to execute user defined variables.
UserDefinedVariable struct {
Name string
PlanValue sqltypes.PlanValue
}
)

var _ Primitive = (*Set)(nil)

//RouteType implements the Primitive interface method.
func (s *Set) RouteType() string {
return "Set"
}

//GetKeyspaceName implements the Primitive interface method.
func (s *Set) GetKeyspaceName() string {
return ""
}

//GetTableName implements the Primitive interface method.
func (s *Set) GetTableName() string {
return ""
}

//Execute implements the Primitive interface method.
func (s *Set) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
for _, setOp := range s.Ops {
err := setOp.Execute(vcursor, bindVars)
if err != nil {
return nil, err
}
}
return &sqltypes.Result{}, nil
}

//StreamExecute implements the Primitive interface method.
func (s *Set) StreamExecute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantields bool, callback func(*sqltypes.Result) error) error {
panic("implement me")
}

//GetFields implements the Primitive interface method.
func (s *Set) GetFields(vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
panic("implement me")
}

func (s *Set) description() PrimitiveDescription {
other := map[string]interface{}{
"Ops": s.Ops,
}
return PrimitiveDescription{
OperatorType: "Set",
Variant: "",
Other: other,
}
}

var _ SetOp = (*UserDefinedVariable)(nil)

//VariableName implements the SetOp interface method.
func (u *UserDefinedVariable) VariableName() string {
return u.Name
}

//Execute implements the SetOp interface method.
func (u *UserDefinedVariable) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable) error {
value, err := u.PlanValue.ResolveValue(bindVars)
if err != nil {
return err
}
return vcursor.SetUDV(u.Name, value)
}
10 changes: 5 additions & 5 deletions go/vt/vtgate/executor_select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ func TestSelectUserDefindVariable(t *testing.T) {
require.NoError(t, err)
wantQueries := []*querypb.BoundQuery{{
Sql: "select :__vtudvfoo as `@foo` from dual",
BindVariables: map[string]*querypb.BindVariable{"__vtudvfoo": sqltypes.StringBindVariable("bar")},
BindVariables: map[string]*querypb.BindVariable{"__vtudvfoo": sqltypes.BytesBindVariable([]byte("bar"))},
}}

assert.Equal(t, wantQueries, sbc1.Queries)
Expand Down Expand Up @@ -409,15 +409,15 @@ func TestSelectBindvars(t *testing.T) {
// Test with StringBindVariable
sql = "select id from user where name in (:name1, :name2)"
_, err = executorExec(executor, sql, map[string]*querypb.BindVariable{
"name1": sqltypes.StringBindVariable("foo1"),
"name2": sqltypes.StringBindVariable("foo2"),
"name1": sqltypes.BytesBindVariable([]byte("foo1")),
"name2": sqltypes.BytesBindVariable([]byte("foo2")),
})
require.NoError(t, err)
wantQueries = []*querypb.BoundQuery{{
Sql: "select id from user where name in ::__vals",
BindVariables: map[string]*querypb.BindVariable{
"name1": sqltypes.StringBindVariable("foo1"),
"name2": sqltypes.StringBindVariable("foo2"),
"name1": sqltypes.BytesBindVariable([]byte("foo1")),
"name2": sqltypes.BytesBindVariable([]byte("foo2")),
"__vals": sqltypes.TestBindVariable([]interface{}{"foo1", "foo2"}),
},
}}
Expand Down
29 changes: 7 additions & 22 deletions go/vt/vtgate/executor_set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ package vtgate
import (
"testing"

"vitess.io/vitess/go/test/utils"

"vitess.io/vitess/go/vt/vterrors"

"context"

"github.com/golang/protobuf/proto"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/vtgate/vschemaacl"

Expand All @@ -35,18 +36,6 @@ import (
"github.com/stretchr/testify/require"
)

func createMap(keys []string, values []interface{}) map[string]*querypb.BindVariable {
result := make(map[string]*querypb.BindVariable)
for i, key := range keys {
variable, err := sqltypes.BuildBindVariable(values[i])
if err != nil {
panic(err)
}
result[key] = variable
}
return result
}

func TestExecutorSet(t *testing.T) {
executor, _, _, _ := createExecutorEnv()

Expand Down Expand Up @@ -262,21 +251,17 @@ func TestExecutorSet(t *testing.T) {
in: "set @foo = 2",
out: &vtgatepb.Session{UserDefinedVariables: createMap([]string{"foo"}, []interface{}{2}), Autocommit: true},
}, {
in: "set @foo = 2.0, @bar = 'baz'",
out: &vtgatepb.Session{UserDefinedVariables: createMap([]string{"foo", "bar"}, []interface{}{2.0, "baz"}), Autocommit: true},
in: "set @foo = 2.1, @bar = 'baz'",
out: &vtgatepb.Session{UserDefinedVariables: createMap([]string{"foo", "bar"}, []interface{}{2.1, "baz"}), Autocommit: true},
}}
for _, tcase := range testcases {
t.Run(tcase.in, func(t *testing.T) {
session := NewSafeSession(&vtgatepb.Session{Autocommit: true})
_, err := executor.Execute(context.Background(), "TestExecute", session, tcase.in, nil)
if err != nil {
if err.Error() != tcase.err {
t.Errorf("%s error: %v, want %s", tcase.in, err, tcase.err)
}
return
}
if !proto.Equal(session.Session, tcase.out) {
t.Errorf("%s: %v, want %s", tcase.in, session.Session, tcase.out)
require.EqualError(t, err, tcase.err)
} else {
utils.MustMatch(t, tcase.out, session.Session, "session output was not as expected")
}
})
}
Expand Down
10 changes: 5 additions & 5 deletions go/vt/vtgate/plan_executor_select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ func TestPlanSelectUserDefindVariable(t *testing.T) {
require.NoError(t, err)
wantQueries := []*querypb.BoundQuery{{
Sql: "select :__vtudvfoo as `@foo` from dual",
BindVariables: map[string]*querypb.BindVariable{"__vtudvfoo": sqltypes.StringBindVariable("bar")},
BindVariables: map[string]*querypb.BindVariable{"__vtudvfoo": sqltypes.BytesBindVariable([]byte("bar"))},
}}

assert.Equal(t, wantQueries, sbc1.Queries)
Expand Down Expand Up @@ -408,15 +408,15 @@ func TestPlanSelectBindvars(t *testing.T) {
// Test with StringBindVariable
sql = "select id from user where name in (:name1, :name2)"
_, err = executorExec(executor, sql, map[string]*querypb.BindVariable{
"name1": sqltypes.StringBindVariable("foo1"),
"name2": sqltypes.StringBindVariable("foo2"),
"name1": sqltypes.BytesBindVariable([]byte("foo1")),
"name2": sqltypes.BytesBindVariable([]byte("foo2")),
})
require.NoError(t, err)
wantQueries = []*querypb.BoundQuery{{
Sql: "select id from user where name in ::__vals",
BindVariables: map[string]*querypb.BindVariable{
"name1": sqltypes.StringBindVariable("foo1"),
"name2": sqltypes.StringBindVariable("foo2"),
"name1": sqltypes.BytesBindVariable([]byte("foo1")),
"name2": sqltypes.BytesBindVariable([]byte("foo2")),
"__vals": sqltypes.TestBindVariable([]interface{}{"foo1", "foo2"}),
},
}}
Expand Down
Loading