diff --git a/engine_test.go b/engine_test.go index 100419ed7..81824b0db 100644 --- a/engine_test.go +++ b/engine_test.go @@ -448,6 +448,8 @@ var queries = []struct { {"sql_mode", ""}, {"gtid_mode", int32(0)}, {"collation_database", "utf8_bin"}, + {"ndbinfo_version", ""}, + {"sql_select_limit", math.MaxInt32}, }, }, { @@ -556,6 +558,43 @@ func TestQueries(t *testing.T) { }) } +func TestSessionDefaults(t *testing.T) { + ctx := newCtx() + ctx.Session.Set("auto_increment_increment", sql.Int64, 0) + ctx.Session.Set("max_allowed_packet", sql.Int64, 0) + ctx.Session.Set("sql_select_limit", sql.Int64, 0) + ctx.Session.Set("ndbinfo_version", sql.Text, "non default value") + + q := `SET @@auto_increment_increment=DEFAULT, + @@max_allowed_packet=DEFAULT, + @@sql_select_limit=DEFAULT, + @@ndbinfo_version=DEFAULT` + + e := newEngine(t) + + defaults := sql.DefaultSessionConfig() + t.Run(q, func(t *testing.T) { + require := require.New(t) + _, _, err := e.Query(ctx, q) + require.NoError(err) + + typ, val := ctx.Get("auto_increment_increment") + require.Equal(defaults["auto_increment_increment"].Typ, typ) + require.Equal(defaults["auto_increment_increment"].Value, val) + + typ, val = ctx.Get("max_allowed_packet") + require.Equal(defaults["max_allowed_packet"].Typ, typ) + require.Equal(defaults["max_allowed_packet"].Value, val) + + typ, val = ctx.Get("sql_select_limit") + require.Equal(defaults["sql_select_limit"].Typ, typ) + require.Equal(defaults["sql_select_limit"].Value, val) + + typ, val = ctx.Get("ndbinfo_version") + require.Equal(defaults["ndbinfo_version"].Typ, typ) + require.Equal(defaults["ndbinfo_version"].Value, val) + }) +} func TestWarnings(t *testing.T) { ctx := newCtx() ctx.Session.Warn(&sql.Warning{Code: 1}) diff --git a/sql/expression/default.go b/sql/expression/default.go new file mode 100644 index 000000000..7d1153beb --- /dev/null +++ b/sql/expression/default.go @@ -0,0 +1,60 @@ +package expression + +import ( + "gopkg.in/src-d/go-mysql-server.v0/sql" +) + +// DefaultColumn is an default expression of a column that is not yet resolved. +type DefaultColumn struct { + name string +} + +// NewDefaultColumn creates a new NewDefaultColumn expression. +func NewDefaultColumn(name string) *DefaultColumn { + return &DefaultColumn{name: name} +} + +// Children implements the sql.Expression interface. +// The function returns always nil +func (*DefaultColumn) Children() []sql.Expression { + return nil +} + +// Resolved implements the sql.Expression interface. +// The function returns always false +func (*DefaultColumn) Resolved() bool { + return false +} + +// IsNullable implements the sql.Expression interface. +// The function always panics! +func (*DefaultColumn) IsNullable() bool { + panic("default column is a placeholder node, but IsNullable was called") +} + +// Type implements the sql.Expression interface. +// The function always panics! +func (*DefaultColumn) Type() sql.Type { + panic("default column is a placeholder node, but Type was called") +} + +// Name implements the sql.Nameable interface. +func (c *DefaultColumn) Name() string { return c.name } + +// String implements the Stringer +// The function returns column's name (can be an empty string) +func (c *DefaultColumn) String() string { + return c.name +} + +// Eval implements the sql.Expression interface. +// The function always panics! +func (*DefaultColumn) Eval(ctx *sql.Context, r sql.Row) (interface{}, error) { + panic("default column is a placeholder node, but Eval was called") +} + +// TransformUp implements the sql.Expression interface. +func (c *DefaultColumn) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { + n := *c + return f(&n) +} diff --git a/sql/parse/parse.go b/sql/parse/parse.go index dd1867e68..6f5589496 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -161,6 +161,10 @@ func convertSet(ctx *sql.Context, n *sqlparser.Set) (sql.Node, error) { name := strings.TrimSpace(e.Name.Lowered()) if expr, err = expr.TransformUp(func(e sql.Expression) (sql.Expression, error) { + if _, ok := e.(*expression.DefaultColumn); ok { + return e, nil + } + if !e.Resolved() || e.Type() != sql.Text { return e, nil } @@ -176,13 +180,13 @@ func convertSet(ctx *sql.Context, n *sqlparser.Set) (sql.Node, error) { } switch strings.ToLower(val) { - case "on": + case sqlparser.KeywordString(sqlparser.ON): return expression.NewLiteral(int64(1), sql.Int64), nil - case "true": + case sqlparser.KeywordString(sqlparser.TRUE): return expression.NewLiteral(true, sql.Boolean), nil - case "off": + case sqlparser.KeywordString(sqlparser.OFF): return expression.NewLiteral(int64(0), sql.Int64), nil - case "false": + case sqlparser.KeywordString(sqlparser.FALSE): return expression.NewLiteral(false, sql.Boolean), nil } @@ -632,6 +636,8 @@ func exprToExpression(e sqlparser.Expr) (sql.Expression, error) { switch v := e.(type) { default: return nil, ErrUnsupportedSyntax.New(e) + case *sqlparser.Default: + return expression.NewDefaultColumn(v.ColName), nil case *sqlparser.SubstrExpr: name, err := exprToExpression(v.Name) if err != nil { diff --git a/sql/parse/parse_test.go b/sql/parse/parse_test.go index cce64a83a..d06b256e9 100644 --- a/sql/parse/parse_test.go +++ b/sql/parse/parse_test.go @@ -729,6 +729,18 @@ var fixtures = map[string]sql.Node{ Value: expression.NewLiteral(int64(700), sql.Int64), }, ), + `SET gtid_mode=DEFAULT`: plan.NewSet( + plan.SetVariable{ + Name: "gtid_mode", + Value: expression.NewDefaultColumn(""), + }, + ), + `SET @@sql_select_limit=default`: plan.NewSet( + plan.SetVariable{ + Name: "@@sql_select_limit", + Value: expression.NewDefaultColumn(""), + }, + ), `/*!40101 SET NAMES utf8 */`: plan.Nothing, `SELECT /*!40101 SET NAMES utf8 */ * FROM foo`: plan.NewProject( []sql.Expression{ diff --git a/sql/plan/set.go b/sql/plan/set.go index ae47d353f..0c852587a 100644 --- a/sql/plan/set.go +++ b/sql/plan/set.go @@ -5,6 +5,7 @@ import ( "strings" "gopkg.in/src-d/go-mysql-server.v0/sql" + "gopkg.in/src-d/go-mysql-server.v0/sql/expression" "gopkg.in/src-d/go-vitess.v1/vt/sqlparser" ) @@ -28,6 +29,9 @@ func NewSet(vars ...SetVariable) *Set { // Resolved implements the sql.Node interface. func (s *Set) Resolved() bool { for _, v := range s.Variables { + if _, ok := v.Value.(*expression.DefaultColumn); ok { + continue + } if !v.Value.Resolved() { return false } @@ -83,10 +87,11 @@ func (s *Set) RowIter(ctx *sql.Context) (sql.RowIter, error) { globalPrefix = sqlparser.GlobalStr + "." ) for _, v := range s.Variables { - value, err := v.Value.Eval(ctx, nil) - if err != nil { - return nil, err - } + var ( + value interface{} + typ sql.Type + err error + ) name := strings.TrimLeft(v.Name, "@") if strings.HasPrefix(name, sessionPrefix) { @@ -95,7 +100,21 @@ func (s *Set) RowIter(ctx *sql.Context) (sql.RowIter, error) { name = name[len(globalPrefix):] } - ctx.Set(name, v.Value.Type(), value) + if _, ok := v.Value.(*expression.DefaultColumn); ok { + valtyp, ok := sql.DefaultSessionConfig()[name] + if !ok { + continue + } + value, typ = valtyp.Value, valtyp.Typ + } else { + value, err = v.Value.Eval(ctx, nil) + if err != nil { + return nil, err + } + typ = v.Value.Type() + } + + ctx.Set(name, typ, value) } return sql.RowsToRowIter(), nil diff --git a/sql/plan/set_test.go b/sql/plan/set_test.go index 3f8b97034..7241b64ff 100644 --- a/sql/plan/set_test.go +++ b/sql/plan/set_test.go @@ -30,3 +30,44 @@ func TestSet(t *testing.T) { require.Equal(sql.Int64, typ) require.Equal(int64(1), v) } + +func TestSetDesfault(t *testing.T) { + require := require.New(t) + + ctx := sql.NewContext(context.Background(), sql.WithSession(sql.NewBaseSession())) + + s := NewSet( + SetVariable{"auto_increment_increment", expression.NewLiteral(int64(123), sql.Int64)}, + SetVariable{"@@sql_select_limit", expression.NewLiteral(int64(1), sql.Int64)}, + ) + + _, err := s.RowIter(ctx) + require.NoError(err) + + typ, v := ctx.Get("auto_increment_increment") + require.Equal(sql.Int64, typ) + require.Equal(int64(123), v) + + typ, v = ctx.Get("sql_select_limit") + require.Equal(sql.Int64, typ) + require.Equal(int64(1), v) + + s = NewSet( + SetVariable{"auto_increment_increment", expression.NewDefaultColumn("")}, + SetVariable{"@@sql_select_limit", expression.NewDefaultColumn("")}, + ) + + _, err = s.RowIter(ctx) + require.NoError(err) + + defaults := sql.DefaultSessionConfig() + + typ, v = ctx.Get("auto_increment_increment") + require.Equal(defaults["auto_increment_increment"].Typ, typ) + require.Equal(defaults["auto_increment_increment"].Value, v) + + typ, v = ctx.Get("sql_select_limit") + require.Equal(defaults["sql_select_limit"].Typ, typ) + require.Equal(defaults["sql_select_limit"].Value, v) + +} diff --git a/sql/session.go b/sql/session.go index 3753380bd..53ca7d246 100644 --- a/sql/session.go +++ b/sql/session.go @@ -137,7 +137,8 @@ type ( } ) -func defaultSessionConfig() map[string]TypedValue { +// DefaultSessionConfig returns default values for session variables +func DefaultSessionConfig() map[string]TypedValue { return map[string]TypedValue{ "auto_increment_increment": TypedValue{Int64, int64(1)}, "time_zone": TypedValue{Text, time.Local.String()}, @@ -146,6 +147,8 @@ func defaultSessionConfig() map[string]TypedValue { "sql_mode": TypedValue{Text, ""}, "gtid_mode": TypedValue{Int32, int32(0)}, "collation_database": TypedValue{Text, "utf8_bin"}, + "ndbinfo_version": TypedValue{Text, ""}, + "sql_select_limit": TypedValue{Int32, math.MaxInt32}, } } @@ -155,13 +158,13 @@ func NewSession(address string, user string, id uint32) Session { id: id, addr: address, user: user, - config: defaultSessionConfig(), + config: DefaultSessionConfig(), } } // NewBaseSession creates a new empty session. func NewBaseSession() Session { - return &BaseSession{config: defaultSessionConfig()} + return &BaseSession{config: DefaultSessionConfig()} } // Context of the query execution.