Skip to content
This repository has been archived by the owner on Jan 28, 2021. It is now read-only.

Add support for set DEFAULT #493

Merged
merged 1 commit into from
Oct 25, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,8 @@ var queries = []struct {
{"sql_mode", ""},
{"gtid_mode", int32(0)},
{"collation_database", "utf8_bin"},
{"ndbinfo_version", ""},
{"sql_select_limit", math.MaxInt32},
},
},
{
Expand Down Expand Up @@ -556,6 +558,43 @@ func TestQueries(t *testing.T) {
})
}

func TestSessionDefaults(t *testing.T) {
ctx := newCtx()
ctx.Session.Set("auto_increment_increment", sql.Int64, 0)
ctx.Session.Set("max_allowed_packet", sql.Int64, 0)
ctx.Session.Set("sql_select_limit", sql.Int64, 0)
ctx.Session.Set("ndbinfo_version", sql.Text, "non default value")

q := `SET @@auto_increment_increment=DEFAULT,
@@max_allowed_packet=DEFAULT,
@@sql_select_limit=DEFAULT,
@@ndbinfo_version=DEFAULT`

e := newEngine(t)

defaults := sql.DefaultSessionConfig()
t.Run(q, func(t *testing.T) {
require := require.New(t)
_, _, err := e.Query(ctx, q)
require.NoError(err)

typ, val := ctx.Get("auto_increment_increment")
require.Equal(defaults["auto_increment_increment"].Typ, typ)
require.Equal(defaults["auto_increment_increment"].Value, val)

typ, val = ctx.Get("max_allowed_packet")
require.Equal(defaults["max_allowed_packet"].Typ, typ)
require.Equal(defaults["max_allowed_packet"].Value, val)

typ, val = ctx.Get("sql_select_limit")
require.Equal(defaults["sql_select_limit"].Typ, typ)
require.Equal(defaults["sql_select_limit"].Value, val)

typ, val = ctx.Get("ndbinfo_version")
require.Equal(defaults["ndbinfo_version"].Typ, typ)
require.Equal(defaults["ndbinfo_version"].Value, val)
})
}
func TestWarnings(t *testing.T) {
ctx := newCtx()
ctx.Session.Warn(&sql.Warning{Code: 1})
Expand Down
60 changes: 60 additions & 0 deletions sql/expression/default.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package expression

import (
"gopkg.in/src-d/go-mysql-server.v0/sql"
)

// DefaultColumn is an default expression of a column that is not yet resolved.
type DefaultColumn struct {
name string
}

// NewDefaultColumn creates a new NewDefaultColumn expression.
func NewDefaultColumn(name string) *DefaultColumn {
return &DefaultColumn{name: name}
}

// Children implements the sql.Expression interface.
// The function returns always nil
func (*DefaultColumn) Children() []sql.Expression {
return nil
}

// Resolved implements the sql.Expression interface.
// The function returns always false
func (*DefaultColumn) Resolved() bool {
return false
}

// IsNullable implements the sql.Expression interface.
// The function always panics!
func (*DefaultColumn) IsNullable() bool {
panic("default column is a placeholder node, but IsNullable was called")
}

// Type implements the sql.Expression interface.
// The function always panics!
func (*DefaultColumn) Type() sql.Type {
panic("default column is a placeholder node, but Type was called")
}

// Name implements the sql.Nameable interface.
func (c *DefaultColumn) Name() string { return c.name }

// String implements the Stringer
// The function returns column's name (can be an empty string)
func (c *DefaultColumn) String() string {
return c.name
}

// Eval implements the sql.Expression interface.
// The function always panics!
func (*DefaultColumn) Eval(ctx *sql.Context, r sql.Row) (interface{}, error) {
panic("default column is a placeholder node, but Eval was called")
}

// TransformUp implements the sql.Expression interface.
func (c *DefaultColumn) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) {
n := *c
return f(&n)
}
14 changes: 10 additions & 4 deletions sql/parse/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ func convertSet(ctx *sql.Context, n *sqlparser.Set) (sql.Node, error) {

name := strings.TrimSpace(e.Name.Lowered())
if expr, err = expr.TransformUp(func(e sql.Expression) (sql.Expression, error) {
if _, ok := e.(*expression.DefaultColumn); ok {
return e, nil
}

if !e.Resolved() || e.Type() != sql.Text {
return e, nil
}
Expand All @@ -176,13 +180,13 @@ func convertSet(ctx *sql.Context, n *sqlparser.Set) (sql.Node, error) {
}

switch strings.ToLower(val) {
case "on":
case sqlparser.KeywordString(sqlparser.ON):
return expression.NewLiteral(int64(1), sql.Int64), nil
case "true":
case sqlparser.KeywordString(sqlparser.TRUE):
return expression.NewLiteral(true, sql.Boolean), nil
case "off":
case sqlparser.KeywordString(sqlparser.OFF):
return expression.NewLiteral(int64(0), sql.Int64), nil
case "false":
case sqlparser.KeywordString(sqlparser.FALSE):
return expression.NewLiteral(false, sql.Boolean), nil
}

Expand Down Expand Up @@ -632,6 +636,8 @@ func exprToExpression(e sqlparser.Expr) (sql.Expression, error) {
switch v := e.(type) {
default:
return nil, ErrUnsupportedSyntax.New(e)
case *sqlparser.Default:
return expression.NewDefaultColumn(v.ColName), nil
case *sqlparser.SubstrExpr:
name, err := exprToExpression(v.Name)
if err != nil {
Expand Down
12 changes: 12 additions & 0 deletions sql/parse/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,18 @@ var fixtures = map[string]sql.Node{
Value: expression.NewLiteral(int64(700), sql.Int64),
},
),
`SET gtid_mode=DEFAULT`: plan.NewSet(
plan.SetVariable{
Name: "gtid_mode",
Value: expression.NewDefaultColumn(""),
},
),
`SET @@sql_select_limit=default`: plan.NewSet(
plan.SetVariable{
Name: "@@sql_select_limit",
Value: expression.NewDefaultColumn(""),
},
),
`/*!40101 SET NAMES utf8 */`: plan.Nothing,
`SELECT /*!40101 SET NAMES utf8 */ * FROM foo`: plan.NewProject(
[]sql.Expression{
Expand Down
29 changes: 24 additions & 5 deletions sql/plan/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"strings"

"gopkg.in/src-d/go-mysql-server.v0/sql"
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
"gopkg.in/src-d/go-vitess.v1/vt/sqlparser"
)

Expand All @@ -28,6 +29,9 @@ func NewSet(vars ...SetVariable) *Set {
// Resolved implements the sql.Node interface.
func (s *Set) Resolved() bool {
for _, v := range s.Variables {
if _, ok := v.Value.(*expression.DefaultColumn); ok {
continue
}
if !v.Value.Resolved() {
return false
}
Expand Down Expand Up @@ -83,10 +87,11 @@ func (s *Set) RowIter(ctx *sql.Context) (sql.RowIter, error) {
globalPrefix = sqlparser.GlobalStr + "."
)
for _, v := range s.Variables {
value, err := v.Value.Eval(ctx, nil)
if err != nil {
return nil, err
}
var (
value interface{}
typ sql.Type
err error
)

name := strings.TrimLeft(v.Name, "@")
if strings.HasPrefix(name, sessionPrefix) {
Expand All @@ -95,7 +100,21 @@ func (s *Set) RowIter(ctx *sql.Context) (sql.RowIter, error) {
name = name[len(globalPrefix):]
}

ctx.Set(name, v.Value.Type(), value)
if _, ok := v.Value.(*expression.DefaultColumn); ok {
valtyp, ok := sql.DefaultSessionConfig()[name]
if !ok {
continue
}
value, typ = valtyp.Value, valtyp.Typ
} else {
value, err = v.Value.Eval(ctx, nil)
if err != nil {
return nil, err
}
typ = v.Value.Type()
}

ctx.Set(name, typ, value)
}

return sql.RowsToRowIter(), nil
Expand Down
41 changes: 41 additions & 0 deletions sql/plan/set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,44 @@ func TestSet(t *testing.T) {
require.Equal(sql.Int64, typ)
require.Equal(int64(1), v)
}

func TestSetDesfault(t *testing.T) {
require := require.New(t)

ctx := sql.NewContext(context.Background(), sql.WithSession(sql.NewBaseSession()))

s := NewSet(
SetVariable{"auto_increment_increment", expression.NewLiteral(int64(123), sql.Int64)},
SetVariable{"@@sql_select_limit", expression.NewLiteral(int64(1), sql.Int64)},
)

_, err := s.RowIter(ctx)
require.NoError(err)

typ, v := ctx.Get("auto_increment_increment")
require.Equal(sql.Int64, typ)
require.Equal(int64(123), v)

typ, v = ctx.Get("sql_select_limit")
require.Equal(sql.Int64, typ)
require.Equal(int64(1), v)

s = NewSet(
SetVariable{"auto_increment_increment", expression.NewDefaultColumn("")},
SetVariable{"@@sql_select_limit", expression.NewDefaultColumn("")},
)

_, err = s.RowIter(ctx)
require.NoError(err)

defaults := sql.DefaultSessionConfig()

typ, v = ctx.Get("auto_increment_increment")
require.Equal(defaults["auto_increment_increment"].Typ, typ)
require.Equal(defaults["auto_increment_increment"].Value, v)

typ, v = ctx.Get("sql_select_limit")
require.Equal(defaults["sql_select_limit"].Typ, typ)
require.Equal(defaults["sql_select_limit"].Value, v)

}
9 changes: 6 additions & 3 deletions sql/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ type (
}
)

func defaultSessionConfig() map[string]TypedValue {
// DefaultSessionConfig returns default values for session variables
func DefaultSessionConfig() map[string]TypedValue {
return map[string]TypedValue{
"auto_increment_increment": TypedValue{Int64, int64(1)},
"time_zone": TypedValue{Text, time.Local.String()},
Expand All @@ -146,6 +147,8 @@ func defaultSessionConfig() map[string]TypedValue {
"sql_mode": TypedValue{Text, ""},
"gtid_mode": TypedValue{Int32, int32(0)},
"collation_database": TypedValue{Text, "utf8_bin"},
"ndbinfo_version": TypedValue{Text, ""},
"sql_select_limit": TypedValue{Int32, math.MaxInt32},
}
}

Expand All @@ -155,13 +158,13 @@ func NewSession(address string, user string, id uint32) Session {
id: id,
addr: address,
user: user,
config: defaultSessionConfig(),
config: DefaultSessionConfig(),
}
}

// NewBaseSession creates a new empty session.
func NewBaseSession() Session {
return &BaseSession{config: defaultSessionConfig()}
return &BaseSession{config: DefaultSessionConfig()}
}

// Context of the query execution.
Expand Down