From 69d26d997997fe25b108b411d525a8efafb77b1a Mon Sep 17 00:00:00 2001 From: Miguel Molina Date: Wed, 10 Apr 2019 11:14:06 +0200 Subject: [PATCH] sql: add support for intervals, DATE_SUB and DATE_ADD Fixes #663 This PR introduces a few new features: - Correctly parsing interval expressions. - Allow using the + and - operators to subtract from and add to dates. e.g. `'2019-04-10' - INTERVAL 1 DAY`. - New `DATE_ADD` function, which is essentially the same as `DATE + INTERVAL`. - New `DATE_SUB` function, which is essentially the same as `DATE - INTERVAL`. - Validation rule to ensure intervals are only used in certain specific places, such as DATE_SUB, DATE_ADD, + and -. Using it anywhere else is not valid SQL, but vitess does not catch those errors. Plus, even if it's an expression for convenience, its `Eval` method is a stub and panics, so it should not be used unless it's in an expression that knows how to deal with intervals. Signed-off-by: Miguel Molina --- README.md | 2 + SUPPORTED.md | 7 +- engine_test.go | 16 ++ sql/analyzer/validation_rules.go | 38 ++++ sql/analyzer/validation_rules_test.go | 124 +++++++++++ sql/expression/arithmetic.go | 88 ++++++-- sql/expression/arithmetic_test.go | 38 ++++ sql/expression/function/date.go | 177 ++++++++++++++++ sql/expression/function/date_test.go | 87 ++++++++ sql/expression/function/registry.go | 2 + sql/expression/interval.go | 286 ++++++++++++++++++++++++++ sql/expression/interval_test.go | 285 +++++++++++++++++++++++++ sql/parse/parse.go | 21 ++ sql/parse/parse_test.go | 57 +++++ 14 files changed, 1206 insertions(+), 22 deletions(-) create mode 100644 sql/expression/function/date.go create mode 100644 sql/expression/function/date_test.go create mode 100644 sql/expression/interval.go create mode 100644 sql/expression/interval_test.go diff --git a/README.md b/README.md index bc09e3f53..c2f62cba7 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,8 @@ We support and actively test against certain third-party clients to ensure compa |`CONCAT_WS(sep, ...)`|Concatenate any group of fields into a single string. The first argument is the separator for the rest of the arguments. The separator is added between the strings to be concatenated. The separator can be a string, as can the rest of the arguments. If the separator is NULL, the result is NULL.| |`CONNECTION_ID()`|Return the current connection ID.| |`COUNT(expr)`| Returns a count of the number of non-NULL values of expr in the rows retrieved by a SELECT statement.| +|`DATE_ADD(date, interval)`|Adds the interval to the given date.| +|`DATE_SUB(date, interval)`|Subtracts the interval from the given date.| |`DAY(date)`|Returns the day of the given date.| |`DAYOFWEEK(date)`|Returns the day of the week of the given date.| |`DAYOFYEAR(date)`|Returns the day of the year of the given date.| diff --git a/SUPPORTED.md b/SUPPORTED.md index 84d33a0e0..01f79f7a2 100644 --- a/SUPPORTED.md +++ b/SUPPORTED.md @@ -50,6 +50,7 @@ - USE - SHOW DATABASES - SHOW WARNINGS +- INTERVALS ## Index expressions - CREATE INDEX (an index can be created using either column names or a single arbitrary expression). @@ -67,8 +68,8 @@ - OR ## Arithmetic expressions -- \+ -- \- +- \+ (including between dates and intervals) +- \- (including between dates and intervals) - \* - \\ - << @@ -121,3 +122,5 @@ - SECOND - YEAR - NOW +- DATE_ADD +- DATE_SUB diff --git a/engine_test.go b/engine_test.go index dc9495031..ce7763a14 100644 --- a/engine_test.go +++ b/engine_test.go @@ -877,6 +877,22 @@ var queries = []struct { "SELECT substring(s, 1, 1), count(*) FROM mytable GROUP BY substring(s, 1, 1)", []sql.Row{{"f", int64(1)}, {"s", int64(1)}, {"t", int64(1)}}, }, + { + "SELECT DATE_ADD('2018-05-02', INTERVAL 1 DAY)", + []sql.Row{{time.Date(2018, time.May, 3, 0, 0, 0, 0, time.UTC)}}, + }, + { + "SELECT DATE_SUB('2018-05-02', INTERVAL 1 DAY)", + []sql.Row{{time.Date(2018, time.May, 1, 0, 0, 0, 0, time.UTC)}}, + }, + { + "SELECT '2018-05-02' + INTERVAL 1 DAY", + []sql.Row{{time.Date(2018, time.May, 3, 0, 0, 0, 0, time.UTC)}}, + }, + { + "SELECT '2018-05-02' - INTERVAL 1 DAY", + []sql.Row{{time.Date(2018, time.May, 1, 0, 0, 0, 0, time.UTC)}}, + }, } func TestQueries(t *testing.T) { diff --git a/sql/analyzer/validation_rules.go b/sql/analyzer/validation_rules.go index 500b3608c..c1cbadcc5 100644 --- a/sql/analyzer/validation_rules.go +++ b/sql/analyzer/validation_rules.go @@ -6,6 +6,7 @@ import ( errors "gopkg.in/src-d/go-errors.v1" "gopkg.in/src-d/go-mysql-server.v0/sql" "gopkg.in/src-d/go-mysql-server.v0/sql/expression" + "gopkg.in/src-d/go-mysql-server.v0/sql/expression/function" "gopkg.in/src-d/go-mysql-server.v0/sql/plan" ) @@ -17,6 +18,7 @@ const ( validateProjectTuplesRule = "validate_project_tuples" validateIndexCreationRule = "validate_index_creation" validateCaseResultTypesRule = "validate_case_result_types" + validateIntervalUsageRule = "validate_interval_usage" ) var ( @@ -43,6 +45,12 @@ var ( "expecting all case branches to return values of type %s, " + "but found value %q of type %s on %s", ) + // ErrIntervalInvalidUse is returned when an interval expression is not + // correctly used. + ErrIntervalInvalidUse = errors.NewKind( + "invalid use of an interval, which can only be used with DATE_ADD, " + + "DATE_SUB and +/- operators to subtract from or add to a date", + ) ) // DefaultValidationRules to apply while analyzing nodes. @@ -54,6 +62,7 @@ var DefaultValidationRules = []Rule{ {validateProjectTuplesRule, validateProjectTuples}, {validateIndexCreationRule, validateIndexCreation}, {validateCaseResultTypesRule, validateCaseResultTypes}, + {validateIntervalUsageRule, validateIntervalUsage}, } func validateIsResolved(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { @@ -243,6 +252,35 @@ func validateCaseResultTypes(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Nod return n, nil } +func validateIntervalUsage(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + var invalid bool + plan.InspectExpressions(n, func(e sql.Expression) bool { + // If it's already invalid just skip everything else. + if invalid { + return false + } + + switch e := e.(type) { + case *function.DateAdd, *function.DateSub: + return false + case *expression.Arithmetic: + if e.Op == "+" || e.Op == "-" { + return false + } + case *expression.Interval: + invalid = true + } + + return true + }) + + if invalid { + return nil, ErrIntervalInvalidUse.New() + } + + return n, nil +} + func stringContains(strs []string, target string) bool { for _, s := range strs { if s == target { diff --git a/sql/analyzer/validation_rules_test.go b/sql/analyzer/validation_rules_test.go index 8715e36f9..85993b84e 100644 --- a/sql/analyzer/validation_rules_test.go +++ b/sql/analyzer/validation_rules_test.go @@ -6,6 +6,7 @@ import ( "gopkg.in/src-d/go-mysql-server.v0/mem" "gopkg.in/src-d/go-mysql-server.v0/sql" "gopkg.in/src-d/go-mysql-server.v0/sql/expression" + "gopkg.in/src-d/go-mysql-server.v0/sql/expression/function" "gopkg.in/src-d/go-mysql-server.v0/sql/expression/function/aggregation" "gopkg.in/src-d/go-mysql-server.v0/sql/plan" @@ -431,6 +432,129 @@ func TestValidateCaseResultTypes(t *testing.T) { } } +func mustFunc(e sql.Expression, err error) sql.Expression { + if err != nil { + panic(err) + } + return e +} + +func TestValidateIntervalUsage(t *testing.T) { + testCases := []struct { + name string + node sql.Node + ok bool + }{ + { + "date add", + plan.NewProject( + []sql.Expression{ + mustFunc(function.NewDateAdd( + expression.NewLiteral("2018-05-01", sql.Text), + expression.NewInterval( + expression.NewLiteral(int64(1), sql.Int64), + "DAY", + ), + )), + }, + plan.NewUnresolvedTable("dual", ""), + ), + true, + }, + { + "date sub", + plan.NewProject( + []sql.Expression{ + mustFunc(function.NewDateSub( + expression.NewLiteral("2018-05-01", sql.Text), + expression.NewInterval( + expression.NewLiteral(int64(1), sql.Int64), + "DAY", + ), + )), + }, + plan.NewUnresolvedTable("dual", ""), + ), + true, + }, + { + "+ op", + plan.NewProject( + []sql.Expression{ + expression.NewPlus( + expression.NewLiteral("2018-05-01", sql.Text), + expression.NewInterval( + expression.NewLiteral(int64(1), sql.Int64), + "DAY", + ), + ), + }, + plan.NewUnresolvedTable("dual", ""), + ), + true, + }, + { + "- op", + plan.NewProject( + []sql.Expression{ + expression.NewMinus( + expression.NewLiteral("2018-05-01", sql.Text), + expression.NewInterval( + expression.NewLiteral(int64(1), sql.Int64), + "DAY", + ), + ), + }, + plan.NewUnresolvedTable("dual", ""), + ), + true, + }, + { + "invalid", + plan.NewProject( + []sql.Expression{ + expression.NewInterval( + expression.NewLiteral(int64(1), sql.Int64), + "DAY", + ), + }, + plan.NewUnresolvedTable("dual", ""), + ), + false, + }, + { + "alias", + plan.NewProject( + []sql.Expression{ + expression.NewAlias( + expression.NewInterval( + expression.NewLiteral(int64(1), sql.Int64), + "DAY", + ), + "foo", + ), + }, + plan.NewUnresolvedTable("dual", ""), + ), + false, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + _, err := validateIntervalUsage(sql.NewEmptyContext(), nil, tt.node) + if tt.ok { + require.NoError(err) + } else { + require.Error(err) + require.True(ErrIntervalInvalidUse.Is(err)) + } + }) + } +} + type dummyNode struct{ resolved bool } func (n dummyNode) String() string { return "dummynode" } diff --git a/sql/expression/arithmetic.go b/sql/expression/arithmetic.go index 005bd6260..a0ccfe2e5 100644 --- a/sql/expression/arithmetic.go +++ b/sql/expression/arithmetic.go @@ -3,6 +3,7 @@ package expression import ( "fmt" "reflect" + "time" errors "gopkg.in/src-d/go-errors.v1" "gopkg.in/src-d/go-vitess.v1/vt/sqlparser" @@ -21,7 +22,7 @@ var ( // Arithmetic expressions (+, -, *, /, ...) type Arithmetic struct { BinaryExpression - op string + Op string } // NewArithmetic creates a new Arithmetic sql.Expression. @@ -85,13 +86,19 @@ func NewMod(left, right sql.Expression) *Arithmetic { } func (a *Arithmetic) String() string { - return fmt.Sprintf("%s %s %s", a.Left, a.op, a.Right) + return fmt.Sprintf("%s %s %s", a.Left, a.Op, a.Right) } // Type returns the greatest type for given operation. func (a *Arithmetic) Type() sql.Type { - switch a.op { + switch a.Op { case sqlparser.PlusStr, sqlparser.MinusStr, sqlparser.MultStr, sqlparser.DivStr: + _, lok := a.Left.(*Interval) + _, rok := a.Right.(*Interval) + if lok || rok { + return sql.Timestamp + } + if sql.IsInteger(a.Left.Type()) && sql.IsInteger(a.Right.Type()) { if sql.IsUnsigned(a.Left.Type()) && sql.IsUnsigned(a.Right.Type()) { return sql.Uint64 @@ -126,7 +133,7 @@ func (a *Arithmetic) TransformUp(f sql.TransformExprFunc) (sql.Expression, error return nil, err } - return f(NewArithmetic(l, r, a.op)) + return f(NewArithmetic(l, r, a.Op)) } // Eval implements the Expression interface. @@ -141,7 +148,7 @@ func (a *Arithmetic) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } - switch a.op { + switch a.Op { case sqlparser.PlusStr: return plus(lval, rval) case sqlparser.MinusStr: @@ -166,37 +173,63 @@ func (a *Arithmetic) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return mod(lval, rval) } - return nil, errUnableToEval.New(lval, a.op, rval) + return nil, errUnableToEval.New(lval, a.Op, rval) } func (a *Arithmetic) evalLeftRight(ctx *sql.Context, row sql.Row) (interface{}, interface{}, error) { - lval, err := a.Left.Eval(ctx, row) - if err != nil { - return nil, nil, err + var lval, rval interface{} + var err error + + if i, ok := a.Left.(*Interval); ok { + lval, err = i.EvalDelta(ctx, row) + if err != nil { + return nil, nil, err + } + } else { + lval, err = a.Left.Eval(ctx, row) + if err != nil { + return nil, nil, err + } } - rval, err := a.Right.Eval(ctx, row) - if err != nil { - return nil, nil, err + if i, ok := a.Right.(*Interval); ok { + rval, err = i.EvalDelta(ctx, row) + if err != nil { + return nil, nil, err + } + } else { + rval, err = a.Right.Eval(ctx, row) + if err != nil { + return nil, nil, err + } } return lval, rval, nil } -func (a *Arithmetic) convertLeftRight(lval interface{}, rval interface{}) (interface{}, interface{}, error) { +func (a *Arithmetic) convertLeftRight(left interface{}, right interface{}) (interface{}, interface{}, error) { + var err error typ := a.Type() - lval64, err := typ.Convert(lval) - if err != nil { - return nil, nil, err + if i, ok := left.(*TimeDelta); ok { + left = i + } else { + left, err = typ.Convert(left) + if err != nil { + return nil, nil, err + } } - rval64, err := typ.Convert(rval) - if err != nil { - return nil, nil, err + if i, ok := right.(*TimeDelta); ok { + right = i + } else { + right, err = typ.Convert(right) + if err != nil { + return nil, nil, err + } } - return lval64, rval64, nil + return left, right, nil } func plus(lval, rval interface{}) (interface{}, error) { @@ -218,6 +251,16 @@ func plus(lval, rval interface{}) (interface{}, error) { case float64: return l + r, nil } + case time.Time: + switch r := rval.(type) { + case *TimeDelta: + return r.Add(l), nil + } + case *TimeDelta: + switch r := rval.(type) { + case time.Time: + return l.Add(r), nil + } } return nil, errUnableToCast.New(lval, rval) @@ -242,6 +285,11 @@ func minus(lval, rval interface{}) (interface{}, error) { case float64: return l - r, nil } + case time.Time: + switch r := rval.(type) { + case *TimeDelta: + return r.Sub(l), nil + } } return nil, errUnableToCast.New(lval, rval) diff --git a/sql/expression/arithmetic_test.go b/sql/expression/arithmetic_test.go index d56047504..44b1e757c 100644 --- a/sql/expression/arithmetic_test.go +++ b/sql/expression/arithmetic_test.go @@ -2,6 +2,7 @@ package expression import ( "testing" + "time" "github.com/stretchr/testify/require" "gopkg.in/src-d/go-mysql-server.v0/sql" @@ -38,6 +39,29 @@ func TestPlus(t *testing.T) { require.Equal(float64(5), result) } +func TestPlusInterval(t *testing.T) { + require := require.New(t) + + expected := time.Date(2018, time.May, 2, 0, 0, 0, 0, time.UTC) + op := NewPlus( + NewLiteral("2018-05-01", sql.Text), + NewInterval(NewLiteral(int64(1), sql.Int64), "DAY"), + ) + + result, err := op.Eval(sql.NewEmptyContext(), nil) + require.NoError(err) + require.Equal(expected, result) + + op = NewPlus( + NewInterval(NewLiteral(int64(1), sql.Int64), "DAY"), + NewLiteral("2018-05-01", sql.Text), + ) + + result, err = op.Eval(sql.NewEmptyContext(), nil) + require.NoError(err) + require.Equal(expected, result) +} + func TestMinus(t *testing.T) { var testCases = []struct { name string @@ -69,6 +93,20 @@ func TestMinus(t *testing.T) { require.Equal(float64(0), result) } +func TestMinusInterval(t *testing.T) { + require := require.New(t) + + expected := time.Date(2018, time.May, 1, 0, 0, 0, 0, time.UTC) + op := NewMinus( + NewLiteral("2018-05-02", sql.Text), + NewInterval(NewLiteral(int64(1), sql.Int64), "DAY"), + ) + + result, err := op.Eval(sql.NewEmptyContext(), nil) + require.NoError(err) + require.Equal(expected, result) +} + func TestMult(t *testing.T) { var testCases = []struct { name string diff --git a/sql/expression/function/date.go b/sql/expression/function/date.go new file mode 100644 index 000000000..116eaead6 --- /dev/null +++ b/sql/expression/function/date.go @@ -0,0 +1,177 @@ +package function + +import ( + "fmt" + "time" + + "gopkg.in/src-d/go-mysql-server.v0/sql" + "gopkg.in/src-d/go-mysql-server.v0/sql/expression" +) + +// DateAdd adds an interval to a date. +type DateAdd struct { + Date sql.Expression + Interval *expression.Interval +} + +// NewDateAdd creates a new date add function. +func NewDateAdd(args ...sql.Expression) (sql.Expression, error) { + if len(args) != 2 { + return nil, sql.ErrInvalidArgumentNumber.New("DATE_ADD", 2, len(args)) + } + + i, ok := args[1].(*expression.Interval) + if !ok { + return nil, fmt.Errorf("DATE_ADD expects an interval as second parameter") + } + + return &DateAdd{args[0], i}, nil +} + +// Children implements the sql.Expression interface. +func (d *DateAdd) Children() []sql.Expression { + return []sql.Expression{d.Date, d.Interval} +} + +// Resolved implements the sql.Expression interface. +func (d *DateAdd) Resolved() bool { + return d.Date.Resolved() && d.Interval.Resolved() +} + +// IsNullable implements the sql.Expression interface. +func (d *DateAdd) IsNullable() bool { + return d.Date.IsNullable() || d.Interval.IsNullable() +} + +// Type implements the sql.Expression interface. +func (d *DateAdd) Type() sql.Type { return sql.Date } + +// TransformUp implements the sql.Expression interface. +func (d *DateAdd) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { + date, err := d.Date.TransformUp(f) + if err != nil { + return nil, err + } + interval, err := d.Interval.TransformUp(f) + if err != nil { + return nil, err + } + + return &DateAdd{date, interval.(*expression.Interval)}, nil +} + +// Eval implements the sql.Expression interface. +func (d *DateAdd) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + date, err := d.Date.Eval(ctx, row) + if err != nil { + return nil, err + } + + if date == nil { + return nil, nil + } + + date, err = sql.Timestamp.Convert(date) + if err != nil { + return nil, err + } + + delta, err := d.Interval.EvalDelta(ctx, row) + if err != nil { + return nil, err + } + + if delta == nil { + return nil, nil + } + + return delta.Add(date.(time.Time)), nil +} + +func (d *DateAdd) String() string { + return fmt.Sprintf("DATE_ADD(%s, %s)", d.Date, d.Interval) +} + +// DateSub subtracts an interval from a date. +type DateSub struct { + Date sql.Expression + Interval *expression.Interval +} + +// NewDateSub creates a new date add function. +func NewDateSub(args ...sql.Expression) (sql.Expression, error) { + if len(args) != 2 { + return nil, sql.ErrInvalidArgumentNumber.New("DATE_SUB", 2, len(args)) + } + + i, ok := args[1].(*expression.Interval) + if !ok { + return nil, fmt.Errorf("DATE_SUB expects an interval as second parameter") + } + + return &DateSub{args[0], i}, nil +} + +// Children implements the sql.Expression interface. +func (d *DateSub) Children() []sql.Expression { + return []sql.Expression{d.Date, d.Interval} +} + +// Resolved implements the sql.Expression interface. +func (d *DateSub) Resolved() bool { + return d.Date.Resolved() && d.Interval.Resolved() +} + +// IsNullable implements the sql.Expression interface. +func (d *DateSub) IsNullable() bool { + return d.Date.IsNullable() || d.Interval.IsNullable() +} + +// Type implements the sql.Expression interface. +func (d *DateSub) Type() sql.Type { return sql.Date } + +// TransformUp implements the sql.Expression interface. +func (d *DateSub) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { + date, err := d.Date.TransformUp(f) + if err != nil { + return nil, err + } + interval, err := d.Interval.TransformUp(f) + if err != nil { + return nil, err + } + + return &DateSub{date, interval.(*expression.Interval)}, nil +} + +// Eval implements the sql.Expression interface. +func (d *DateSub) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + date, err := d.Date.Eval(ctx, row) + if err != nil { + return nil, err + } + + if date == nil { + return nil, nil + } + + date, err = sql.Timestamp.Convert(date) + if err != nil { + return nil, err + } + + delta, err := d.Interval.EvalDelta(ctx, row) + if err != nil { + return nil, err + } + + if delta == nil { + return nil, nil + } + + return delta.Sub(date.(time.Time)), nil +} + +func (d *DateSub) String() string { + return fmt.Sprintf("DATE_SUB(%s, %s)", d.Date, d.Interval) +} diff --git a/sql/expression/function/date_test.go b/sql/expression/function/date_test.go new file mode 100644 index 000000000..bc1ca61f1 --- /dev/null +++ b/sql/expression/function/date_test.go @@ -0,0 +1,87 @@ +package function + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + "gopkg.in/src-d/go-mysql-server.v0/sql" + "gopkg.in/src-d/go-mysql-server.v0/sql/expression" +) + +func TestDateAdd(t *testing.T) { + require := require.New(t) + + _, err := NewDateAdd() + require.Error(err) + + _, err = NewDateAdd(expression.NewLiteral("2018-05-02", sql.Text)) + require.Error(err) + + _, err = NewDateAdd( + expression.NewLiteral("2018-05-02", sql.Text), + expression.NewLiteral(int64(1), sql.Int64), + ) + require.Error(err) + + f, err := NewDateAdd( + expression.NewGetField(0, sql.Text, "foo", false), + expression.NewInterval( + expression.NewLiteral(int64(1), sql.Int64), + "DAY", + ), + ) + require.NoError(err) + + ctx := sql.NewEmptyContext() + expected := time.Date(2018, time.May, 3, 0, 0, 0, 0, time.UTC) + + result, err := f.Eval(ctx, sql.Row{"2018-05-02"}) + require.NoError(err) + require.Equal(expected, result) + + result, err = f.Eval(ctx, sql.Row{nil}) + require.NoError(err) + require.Nil(result) + + _, err = f.Eval(ctx, sql.Row{"asdasdasd"}) + require.Error(err) +} +func TestDateSub(t *testing.T) { + require := require.New(t) + + _, err := NewDateSub() + require.Error(err) + + _, err = NewDateSub(expression.NewLiteral("2018-05-02", sql.Text)) + require.Error(err) + + _, err = NewDateSub( + expression.NewLiteral("2018-05-02", sql.Text), + expression.NewLiteral(int64(1), sql.Int64), + ) + require.Error(err) + + f, err := NewDateSub( + expression.NewGetField(0, sql.Text, "foo", false), + expression.NewInterval( + expression.NewLiteral(int64(1), sql.Int64), + "DAY", + ), + ) + require.NoError(err) + + ctx := sql.NewEmptyContext() + expected := time.Date(2018, time.May, 1, 0, 0, 0, 0, time.UTC) + + result, err := f.Eval(ctx, sql.Row{"2018-05-02"}) + require.NoError(err) + require.Equal(expected, result) + + result, err = f.Eval(ctx, sql.Row{nil}) + require.NoError(err) + require.Nil(result) + + _, err = f.Eval(ctx, sql.Row{"asdasdasd"}) + require.Error(err) +} diff --git a/sql/expression/function/registry.go b/sql/expression/function/registry.go index 4248e6874..35964e2d8 100644 --- a/sql/expression/function/registry.go +++ b/sql/expression/function/registry.go @@ -74,4 +74,6 @@ var Defaults = []sql.Function{ sql.Function2{Name: "ifnull", Fn: NewIfNull}, sql.Function2{Name: "nullif", Fn: NewNullIf}, sql.Function0{Name: "now", Fn: NewNow}, + sql.FunctionN{Name: "date_add", Fn: NewDateAdd}, + sql.FunctionN{Name: "date_sub", Fn: NewDateSub}, } diff --git a/sql/expression/interval.go b/sql/expression/interval.go new file mode 100644 index 000000000..0053b8391 --- /dev/null +++ b/sql/expression/interval.go @@ -0,0 +1,286 @@ +package expression + +import ( + "fmt" + "regexp" + "strconv" + "strings" + "time" + + errors "gopkg.in/src-d/go-errors.v1" + "gopkg.in/src-d/go-mysql-server.v0/sql" +) + +// Interval defines a time duration. +type Interval struct { + UnaryExpression + Unit string +} + +// NewInterval creates a new interval expression. +func NewInterval(child sql.Expression, unit string) *Interval { + return &Interval{UnaryExpression{Child: child}, strings.ToUpper(unit)} +} + +// Type implements the sql.Expression interface. +func (i *Interval) Type() sql.Type { return sql.Uint64 } + +// IsNullable implements the sql.Expression interface. +func (i *Interval) IsNullable() bool { return i.Child.IsNullable() } + +// Eval implements the sql.Expression interface. +func (i *Interval) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + panic("Interval.Eval is just a placeholder method and should not be called directly") +} + +var ( + errInvalidIntervalUnit = errors.NewKind("invalid interval unit: %s") + errInvalidIntervalFormat = errors.NewKind("invalid interval format for %q: %s") +) + +// EvalDelta evaluates the expression returning a TimeDelta. This method should +// be used instead of Eval, as this expression returns a TimeDelta, which is not +// a valid value that can be returned in Eval. +func (i *Interval) EvalDelta(ctx *sql.Context, row sql.Row) (*TimeDelta, error) { + val, err := i.Child.Eval(ctx, row) + if err != nil { + return nil, err + } + + if val == nil { + return nil, nil + } + + var td TimeDelta + + if r, ok := unitTextFormats[i.Unit]; ok { + val, err = sql.Text.Convert(val) + if err != nil { + return nil, err + } + + text := val.(string) + if !r.MatchString(text) { + return nil, errInvalidIntervalFormat.New(i.Unit, text) + } + + parts := textFormatParts(text, r) + + switch i.Unit { + case "DAY_HOUR": + td.Days = parts[0] + td.Hours = parts[1] + case "DAY_MICROSECOND": + td.Days = parts[0] + td.Hours = parts[1] + td.Minutes = parts[2] + td.Seconds = parts[3] + td.Microseconds = parts[4] + case "DAY_MINUTE": + td.Days = parts[0] + td.Hours = parts[1] + td.Minutes = parts[2] + case "DAY_SECOND": + td.Days = parts[0] + td.Hours = parts[1] + td.Minutes = parts[2] + td.Seconds = parts[3] + case "HOUR_MICROSECOND": + td.Hours = parts[0] + td.Minutes = parts[1] + td.Seconds = parts[2] + td.Microseconds = parts[3] + case "HOUR_SECOND": + td.Hours = parts[0] + td.Minutes = parts[1] + td.Seconds = parts[2] + case "HOUR_MINUTE": + td.Hours = parts[0] + td.Minutes = parts[1] + case "MINUTE_MICROSECOND": + td.Minutes = parts[0] + td.Seconds = parts[1] + td.Microseconds = parts[2] + case "MINUTE_SECOND": + td.Minutes = parts[0] + td.Seconds = parts[1] + case "SECOND_MICROSECOND": + td.Seconds = parts[0] + td.Microseconds = parts[1] + case "YEAR_MONTH": + td.Years = parts[0] + td.Months = parts[1] + default: + return nil, errInvalidIntervalUnit.New(i.Unit) + } + } else { + val, err = sql.Int64.Convert(val) + if err != nil { + return nil, err + } + + num := val.(int64) + + switch i.Unit { + case "DAY": + td.Days = num + case "HOUR": + td.Hours = num + case "MINUTE": + td.Minutes = num + case "SECOND": + td.Seconds = num + case "MICROSECOND": + td.Microseconds = num + case "QUARTER": + td.Months = num * 3 + case "MONTH": + td.Months = num + case "WEEK": + td.Days = num * 7 + case "YEAR": + td.Years = num + default: + return nil, errInvalidIntervalUnit.New(i.Unit) + } + } + + return &td, nil +} + +// TransformUp implements the sql.Expression interface. +func (i *Interval) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { + child, err := i.Child.TransformUp(f) + if err != nil { + return nil, err + } + + return NewInterval(child, i.Unit), nil +} + +func (i *Interval) String() string { + return fmt.Sprintf("INTERVAL %s %s", i.Child, i.Unit) +} + +var unitTextFormats = map[string]*regexp.Regexp{ + "DAY_HOUR": regexp.MustCompile(`^(\d+)\s+(\d+)$`), + "DAY_MICROSECOND": regexp.MustCompile(`^(\d+)\s+(\d+):(\d+):(\d+).(\d+)$`), + "DAY_MINUTE": regexp.MustCompile(`^(\d+)\s+(\d+):(\d+)$`), + "DAY_SECOND": regexp.MustCompile(`^(\d+)\s+(\d+):(\d+):(\d+)$`), + "HOUR_MICROSECOND": regexp.MustCompile(`^(\d+):(\d+):(\d+).(\d+)$`), + "HOUR_SECOND": regexp.MustCompile(`^(\d+):(\d+):(\d+)$`), + "HOUR_MINUTE": regexp.MustCompile(`^(\d+):(\d+)$`), + "MINUTE_MICROSECOND": regexp.MustCompile(`^(\d+):(\d+).(\d+)$`), + "MINUTE_SECOND": regexp.MustCompile(`^(\d+):(\d+)$`), + "SECOND_MICROSECOND": regexp.MustCompile(`^(\d+).(\d+)$`), + "YEAR_MONTH": regexp.MustCompile(`^(\d+)-(\d+)$`), +} + +func textFormatParts(text string, r *regexp.Regexp) []int64 { + parts := r.FindStringSubmatch(text) + var result []int64 + for _, p := range parts[1:] { + // It is safe to igore the error here, because at this point we know + // the string matches the regexp, and that means it can't be an + // invalid number. + n, _ := strconv.ParseInt(p, 10, 64) + result = append(result, n) + } + return result +} + +// TimeDelta is the difference between a time and another time. +type TimeDelta struct { + Years int64 + Months int64 + Days int64 + Hours int64 + Minutes int64 + Seconds int64 + Microseconds int64 +} + +// Add returns the given time plus the time delta. +func (td TimeDelta) Add(t time.Time) time.Time { + return td.apply(t, 1) +} + +// Sub returns the given time minus the time delta. +func (td TimeDelta) Sub(t time.Time) time.Time { + return td.apply(t, -1) +} + +const ( + day = 24 * time.Hour + week = 7 * day +) + +func (td TimeDelta) apply(t time.Time, sign int64) time.Time { + y := int64(t.Year()) + mo := int64(t.Month()) + d := t.Day() + h := t.Hour() + min := t.Minute() + s := t.Second() + ns := t.Nanosecond() + + if td.Years != 0 { + y += td.Years * sign + } + + if td.Months != 0 { + m := mo + td.Months*sign + if m < 1 { + mo = 12 + (m % 12) + y += m/12 - 1 + } else if m > 12 { + mo = m % 12 + y += m / 12 + } else { + mo = m + } + + // Due to the operations done before, month may be zero, which means it's + // december. + if mo == 0 { + mo = 12 + } + } + + if days := daysInMonth(time.Month(mo), int(y)); days < d { + d = days + } + + date := time.Date(int(y), time.Month(mo), d, h, min, s, ns, t.Location()) + + if td.Days != 0 { + date = date.Add(time.Duration(td.Days) * day * time.Duration(sign)) + } + + if td.Hours != 0 { + date = date.Add(time.Duration(td.Hours) * time.Hour * time.Duration(sign)) + } + + if td.Minutes != 0 { + date = date.Add(time.Duration(td.Minutes) * time.Minute * time.Duration(sign)) + } + + if td.Seconds != 0 { + date = date.Add(time.Duration(td.Seconds) * time.Second * time.Duration(sign)) + } + + if td.Microseconds != 0 { + date = date.Add(time.Duration(td.Microseconds) * time.Microsecond * time.Duration(sign)) + } + + return date +} + +func daysInMonth(month time.Month, year int) int { + if month == time.December { + return 31 + } + + date := time.Date(year, month+time.Month(1), 1, 0, 0, 0, 0, time.Local) + return date.Add(-1 * day).Day() +} diff --git a/sql/expression/interval_test.go b/sql/expression/interval_test.go new file mode 100644 index 000000000..9c5442d4d --- /dev/null +++ b/sql/expression/interval_test.go @@ -0,0 +1,285 @@ +package expression + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + "gopkg.in/src-d/go-mysql-server.v0/sql" +) + +func TestTimeDelta(t *testing.T) { + leapYear := date(2004, time.February, 29, 0, 0, 0, 0) + testCases := []struct { + name string + delta TimeDelta + date time.Time + output time.Time + }{ + { + "leap year minus one year", + TimeDelta{Years: -1}, + leapYear, + date(2003, time.February, 28, 0, 0, 0, 0), + }, + { + "leap year plus one year", + TimeDelta{Years: 1}, + leapYear, + date(2005, time.February, 28, 0, 0, 0, 0), + }, + { + "plus overflowing months", + TimeDelta{Months: 13}, + leapYear, + date(2005, time.March, 29, 0, 0, 0, 0), + }, + { + "plus overflowing until december", + TimeDelta{Months: 22}, + leapYear, + date(2006, time.December, 29, 0, 0, 0, 0), + }, + { + "minus overflowing months", + TimeDelta{Months: -13}, + leapYear, + date(2003, time.January, 29, 0, 0, 0, 0), + }, + { + "minus overflowing until december", + TimeDelta{Months: -14}, + leapYear, + date(2002, time.December, 29, 0, 0, 0, 0), + }, + { + "minus months", + TimeDelta{Months: -1}, + leapYear, + date(2004, time.January, 29, 0, 0, 0, 0), + }, + { + "plus months", + TimeDelta{Months: 1}, + leapYear, + date(2004, time.March, 29, 0, 0, 0, 0), + }, + { + "minus days", + TimeDelta{Days: -2}, + leapYear, + date(2004, time.February, 27, 0, 0, 0, 0), + }, + { + "plus days", + TimeDelta{Days: 1}, + leapYear, + date(2004, time.March, 1, 0, 0, 0, 0), + }, + { + "minus hours", + TimeDelta{Hours: -2}, + leapYear, + date(2004, time.February, 28, 22, 0, 0, 0), + }, + { + "plus hours", + TimeDelta{Hours: 26}, + leapYear, + date(2004, time.March, 1, 2, 0, 0, 0), + }, + { + "minus minutes", + TimeDelta{Minutes: -2}, + leapYear, + date(2004, time.February, 28, 23, 58, 0, 0), + }, + { + "plus minutes", + TimeDelta{Minutes: 26}, + leapYear, + date(2004, time.February, 29, 0, 26, 0, 0), + }, + { + "minus seconds", + TimeDelta{Seconds: -2}, + leapYear, + date(2004, time.February, 28, 23, 59, 58, 0), + }, + { + "plus seconds", + TimeDelta{Seconds: 26}, + leapYear, + date(2004, time.February, 29, 0, 0, 26, 0), + }, + { + "minus microseconds", + TimeDelta{Microseconds: -2}, + leapYear, + date(2004, time.February, 28, 23, 59, 59, 999998), + }, + { + "plus microseconds", + TimeDelta{Microseconds: 26}, + leapYear, + date(2004, time.February, 29, 0, 0, 0, 26), + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + result := tt.delta.Add(tt.date) + require.Equal(t, tt.output, result) + }) + } +} + +func TestIntervalEvalDelta(t *testing.T) { + testCases := []struct { + expr sql.Expression + unit string + row sql.Row + expected TimeDelta + }{ + { + NewGetField(0, sql.Int64, "foo", false), + "DAY", + sql.Row{int64(2)}, + TimeDelta{Days: 2}, + }, + { + NewLiteral(int64(2), sql.Int64), + "DAY", + nil, + TimeDelta{Days: 2}, + }, + { + NewLiteral(int64(2), sql.Int64), + "MONTH", + nil, + TimeDelta{Months: 2}, + }, + { + NewLiteral(int64(2), sql.Int64), + "YEAR", + nil, + TimeDelta{Years: 2}, + }, + { + NewLiteral(int64(2), sql.Int64), + "QUARTER", + nil, + TimeDelta{Months: 6}, + }, + { + NewLiteral(int64(2), sql.Int64), + "WEEK", + nil, + TimeDelta{Days: 14}, + }, + { + NewLiteral(int64(2), sql.Int64), + "HOUR", + nil, + TimeDelta{Hours: 2}, + }, + { + NewLiteral(int64(2), sql.Int64), + "MINUTE", + nil, + TimeDelta{Minutes: 2}, + }, + { + NewLiteral(int64(2), sql.Int64), + "SECOND", + nil, + TimeDelta{Seconds: 2}, + }, + { + NewLiteral(int64(2), sql.Int64), + "MICROSECOND", + nil, + TimeDelta{Microseconds: 2}, + }, + { + NewLiteral("2 3", sql.Text), + "DAY_HOUR", + nil, + TimeDelta{Days: 2, Hours: 3}, + }, + { + NewLiteral("2 3:04:05.06", sql.Text), + "DAY_MICROSECOND", + nil, + TimeDelta{Days: 2, Hours: 3, Minutes: 4, Seconds: 5, Microseconds: 6}, + }, + { + NewLiteral("2 3:04:05", sql.Text), + "DAY_SECOND", + nil, + TimeDelta{Days: 2, Hours: 3, Minutes: 4, Seconds: 5}, + }, + { + NewLiteral("2 3:04", sql.Text), + "DAY_MINUTE", + nil, + TimeDelta{Days: 2, Hours: 3, Minutes: 4}, + }, + { + NewLiteral("3:04:05.06", sql.Text), + "HOUR_MICROSECOND", + nil, + TimeDelta{Hours: 3, Minutes: 4, Seconds: 5, Microseconds: 6}, + }, + { + NewLiteral("3:04:05", sql.Text), + "HOUR_SECOND", + nil, + TimeDelta{Hours: 3, Minutes: 4, Seconds: 5}, + }, + { + NewLiteral("3:04", sql.Text), + "HOUR_MINUTE", + nil, + TimeDelta{Hours: 3, Minutes: 4}, + }, + { + NewLiteral("04:05.06", sql.Text), + "MINUTE_MICROSECOND", + nil, + TimeDelta{Minutes: 4, Seconds: 5, Microseconds: 6}, + }, + { + NewLiteral("04:05", sql.Text), + "MINUTE_SECOND", + nil, + TimeDelta{Minutes: 4, Seconds: 5}, + }, + { + NewLiteral("04.05", sql.Text), + "SECOND_MICROSECOND", + nil, + TimeDelta{Seconds: 4, Microseconds: 5}, + }, + { + NewLiteral("1-5", sql.Text), + "YEAR_MONTH", + nil, + TimeDelta{Years: 1, Months: 5}, + }, + } + + for _, tt := range testCases { + interval := NewInterval(tt.expr, tt.unit) + t.Run(interval.String(), func(t *testing.T) { + require := require.New(t) + result, err := interval.EvalDelta(sql.NewEmptyContext(), tt.row) + require.NoError(err) + require.Equal(tt.expected, *result) + }) + } +} + +func date(year int, month time.Month, day, hour, min, sec, micro int) time.Time { + return time.Date(year, month, day, hour, min, sec, micro*int(time.Microsecond), time.Local) +} diff --git a/sql/parse/parse.go b/sql/parse/parse.go index f3a80275c..a222be31c 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -826,6 +826,8 @@ func exprToExpression(e sqlparser.Expr) (sql.Expression, error) { return nil, ErrUnsupportedSubqueryExpression.New() case *sqlparser.CaseExpr: return caseExprToExpression(v) + case *sqlparser.IntervalExpr: + return intervalExprToExpression(v) } } @@ -1011,6 +1013,16 @@ func binaryExprToExpression(be *sqlparser.BinaryExpr) (sql.Expression, error) { return nil, err } + _, lok := l.(*expression.Interval) + _, rok := r.(*expression.Interval) + if lok && be.Operator == "-" { + return nil, ErrUnsupportedSyntax.New("subtracting from an interval") + } else if (lok || rok) && be.Operator != "+" && be.Operator != "-" { + return nil, ErrUnsupportedSyntax.New("only + and - can be used to add of subtract intervals from dates") + } else if lok && rok { + return nil, ErrUnsupportedSyntax.New("intervals cannot be added or subtracted from other intervals") + } + return expression.NewArithmetic(l, r, be.Operator), nil default: @@ -1058,6 +1070,15 @@ func caseExprToExpression(e *sqlparser.CaseExpr) (sql.Expression, error) { return expression.NewCase(expr, branches, elseExpr), nil } +func intervalExprToExpression(e *sqlparser.IntervalExpr) (sql.Expression, error) { + expr, err := exprToExpression(e.Expr) + if err != nil { + return nil, err + } + + return expression.NewInterval(expr, e.Unit), nil +} + func removeComments(s string) string { r := bufio.NewReader(strings.NewReader(s)) var result []rune diff --git a/sql/parse/parse_test.go b/sql/parse/parse_test.go index 08c35a429..23ec9e3d7 100644 --- a/sql/parse/parse_test.go +++ b/sql/parse/parse_test.go @@ -1010,6 +1010,57 @@ var fixtures = map[string]sql.Node{ `ROLLBACK`: plan.NewRollback(), "SHOW CREATE TABLE `mytable`": plan.NewShowCreateTable("", nil, "mytable"), "SHOW CREATE TABLE `mydb`.`mytable`": plan.NewShowCreateTable("mydb", nil, "mytable"), + `SELECT '2018-05-01' + INTERVAL 1 DAY`: plan.NewProject( + []sql.Expression{expression.NewArithmetic( + expression.NewLiteral("2018-05-01", sql.Text), + expression.NewInterval( + expression.NewLiteral(int64(1), sql.Int64), + "DAY", + ), + "+", + )}, + plan.NewUnresolvedTable("dual", ""), + ), + `SELECT '2018-05-01' - INTERVAL 1 DAY`: plan.NewProject( + []sql.Expression{expression.NewArithmetic( + expression.NewLiteral("2018-05-01", sql.Text), + expression.NewInterval( + expression.NewLiteral(int64(1), sql.Int64), + "DAY", + ), + "-", + )}, + plan.NewUnresolvedTable("dual", ""), + ), + `SELECT INTERVAL 1 DAY + '2018-05-01'`: plan.NewProject( + []sql.Expression{expression.NewArithmetic( + expression.NewInterval( + expression.NewLiteral(int64(1), sql.Int64), + "DAY", + ), + expression.NewLiteral("2018-05-01", sql.Text), + "+", + )}, + plan.NewUnresolvedTable("dual", ""), + ), + `SELECT '2018-05-01' + INTERVAL 1 DAY + INTERVAL 1 DAY`: plan.NewProject( + []sql.Expression{expression.NewArithmetic( + expression.NewArithmetic( + expression.NewLiteral("2018-05-01", sql.Text), + expression.NewInterval( + expression.NewLiteral(int64(1), sql.Int64), + "DAY", + ), + "+", + ), + expression.NewInterval( + expression.NewLiteral(int64(1), sql.Int64), + "DAY", + ), + "+", + )}, + plan.NewUnresolvedTable("dual", ""), + ), } func TestParse(t *testing.T) { @@ -1035,6 +1086,12 @@ var fixturesErrors = map[string]*errors.Kind{ JOIN commit_files JOIN refs `: ErrUnsupportedSyntax, + `SELECT INTERVAL 1 DAY - '2018-05-01'`: ErrUnsupportedSyntax, + `SELECT INTERVAL 1 DAY * '2018-05-01'`: ErrUnsupportedSyntax, + `SELECT '2018-05-01' * INTERVAL 1 DAY`: ErrUnsupportedSyntax, + `SELECT '2018-05-01' / INTERVAL 1 DAY`: ErrUnsupportedSyntax, + `SELECT INTERVAL 1 DAY + INTERVAL 1 DAY`: ErrUnsupportedSyntax, + `SELECT '2018-05-01' + (INTERVAL 1 DAY + INTERVAL 1 DAY)`: ErrUnsupportedSyntax, } func TestParseErrors(t *testing.T) {