Skip to content

Commit

Permalink
planner, executor: support range framed window functions (#9450)
Browse files Browse the repository at this point in the history
  • Loading branch information
alivxxx authored and zz-jason committed Feb 27, 2019
1 parent 56a79ef commit a59a5f4
Show file tree
Hide file tree
Showing 10 changed files with 226 additions and 49 deletions.
15 changes: 14 additions & 1 deletion executor/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -1918,13 +1918,26 @@ func (b *executorBuilder) buildWindow(v *plannercore.PhysicalWindow) *WindowExec
windowFunc: agg,
partialResult: agg.AllocPartialResult(),
}
} else {
} else if v.Frame.Type == ast.Rows {
processor = &rowFrameWindowProcessor{
windowFunc: agg,
partialResult: agg.AllocPartialResult(),
start: v.Frame.Start,
end: v.Frame.End,
}
} else {
cmpResult := int64(-1)
if v.OrderBy[0].Desc {
cmpResult = 1
}
processor = &rangeFrameWindowProcessor{
windowFunc: agg,
partialResult: agg.AllocPartialResult(),
start: v.Frame.Start,
end: v.Frame.End,
col: v.OrderBy[0].Col,
expectedCmpResult: cmpResult,
}
}
return &WindowExec{baseExecutor: base,
processor: processor,
Expand Down
95 changes: 95 additions & 0 deletions executor/window.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/pingcap/errors"
"github.com/pingcap/parser/ast"
"github.com/pingcap/tidb/executor/aggfuncs"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/planner/core"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -305,3 +306,97 @@ func (p *rowFrameWindowProcessor) resetPartialResult() {
p.windowFunc.ResetPartialResult(p.partialResult)
p.curRowIdx = 0
}

type rangeFrameWindowProcessor struct {
windowFunc aggfuncs.AggFunc
partialResult aggfuncs.PartialResult
start *core.FrameBound
end *core.FrameBound
curRowIdx uint64
lastStartOffset uint64
lastEndOffset uint64
col *expression.Column
// expectedCmpResult is used to decide if one value is included in the frame.
expectedCmpResult int64
}

func (p *rangeFrameWindowProcessor) getStartOffset(ctx sessionctx.Context, rows []chunk.Row) (uint64, error) {
if p.start.UnBounded {
return 0, nil
}
numRows := uint64(len(rows))
for ; p.lastStartOffset < numRows; p.lastStartOffset++ {
res, _, err := p.start.CmpFunc(ctx, p.col, p.start.CalcFunc, rows[p.lastStartOffset], rows[p.curRowIdx])
if err != nil {
return 0, err
}
// For asc, break when the current value is greater or equal to the calculated result;
// For desc, break when the current value is less or equal to the calculated result.
if res != p.expectedCmpResult {
break
}
}
return p.lastStartOffset, nil
}

func (p *rangeFrameWindowProcessor) getEndOffset(ctx sessionctx.Context, rows []chunk.Row) (uint64, error) {
numRows := uint64(len(rows))
if p.end.UnBounded {
return numRows, nil
}
for ; p.lastEndOffset < numRows; p.lastEndOffset++ {
res, _, err := p.end.CmpFunc(ctx, p.end.CalcFunc, p.col, rows[p.curRowIdx], rows[p.lastEndOffset])
if err != nil {
return 0, err
}
// For asc, break when the calculated result is greater than the current value.
// For desc, break when the calculated result is less than the current value.
if res == p.expectedCmpResult {
break
}
}
return p.lastEndOffset, nil
}

func (p *rangeFrameWindowProcessor) appendResult2Chunk(ctx sessionctx.Context, rows []chunk.Row, chk *chunk.Chunk, remained int) ([]chunk.Row, error) {
for remained > 0 {
start, err := p.getStartOffset(ctx, rows)
if err != nil {
return nil, err
}
end, err := p.getEndOffset(ctx, rows)
if err != nil {
return nil, err
}
p.curRowIdx++
remained--
if start >= end {
err := p.windowFunc.AppendFinalResult2Chunk(ctx, p.partialResult, chk)
if err != nil {
return nil, err
}
continue
}
err = p.windowFunc.UpdatePartialResult(ctx, rows[start:end], p.partialResult)
if err != nil {
return nil, err
}
err = p.windowFunc.AppendFinalResult2Chunk(ctx, p.partialResult, chk)
if err != nil {
return nil, err
}
p.windowFunc.ResetPartialResult(p.partialResult)
}
return rows, nil
}

func (p *rangeFrameWindowProcessor) consumeGroupRows(ctx sessionctx.Context, rows []chunk.Row) ([]chunk.Row, error) {
return rows, nil
}

func (p *rangeFrameWindowProcessor) resetPartialResult() {
p.windowFunc.ResetPartialResult(p.partialResult)
p.curRowIdx = 0
p.lastStartOffset = 0
p.lastEndOffset = 0
}
12 changes: 12 additions & 0 deletions executor/window_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,16 @@ func (s *testSuite2) TestWindowFunctions(c *C) {
result.Check(testkit.Rows("1 5", "4 7", "2 6"))
result = tk.MustQuery("select a, sum(a) over(rows between unbounded preceding and 1 preceding) from t")
result.Check(testkit.Rows("1 <nil>", "4 1", "2 5"))

tk.MustExec("drop table t")
tk.MustExec("create table t(a int, b date)")
tk.MustExec("insert into t values (null,null),(1,20190201),(2,20190202),(3,20190203),(5,20190205)")
result = tk.MustQuery("select a, sum(a) over(order by a range between 1 preceding and 2 following) from t")
result.Check(testkit.Rows("<nil> <nil>", "1 6", "2 6", "3 10", "5 5"))
result = tk.MustQuery("select a, sum(a) over(order by a desc range between 1 preceding and 2 following) from t")
result.Check(testkit.Rows("5 8", "3 6", "2 6", "1 3", "<nil> <nil>"))
result = tk.MustQuery("select a, b, sum(a) over(order by b range between interval 1 day preceding and interval 2 day following) from t")
result.Check(testkit.Rows("<nil> <nil> <nil>", "1 2019-02-01 6", "2 2019-02-02 6", "3 2019-02-03 10", "5 2019-02-05 5"))
result = tk.MustQuery("select a, b, sum(a) over(order by b desc range between interval 1 day preceding and interval 2 day following) from t")
result.Check(testkit.Rows("5 2019-02-05 8", "3 2019-02-03 6", "2 2019-02-02 6", "1 2019-02-01 3", "<nil> <nil> <nil>"))
}
21 changes: 21 additions & 0 deletions expression/builtin_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -1035,6 +1035,27 @@ func GetAccurateCmpType(lhs, rhs Expression) types.EvalType {
return cmpType
}

// GetCmpFunction get the compare function according to two arguments.
func GetCmpFunction(lhs, rhs Expression) CompareFunc {
switch GetAccurateCmpType(lhs, rhs) {
case types.ETInt:
return CompareInt
case types.ETReal:
return CompareReal
case types.ETDecimal:
return CompareDecimal
case types.ETString:
return CompareString
case types.ETDuration:
return CompareDuration
case types.ETDatetime, types.ETTimestamp:
return CompareTime
case types.ETJson:
return CompareJSON
}
return nil
}

// isTemporalColumn checks if a expression is a temporal column,
// temporal column indicates time column or duration column.
func isTemporalColumn(expr Expression) bool {
Expand Down
17 changes: 1 addition & 16 deletions planner/core/exhaust_physical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,22 +236,7 @@ func (p *LogicalJoin) getEnforcedMergeJoin(prop *property.PhysicalProperty) []Ph
func (p *PhysicalMergeJoin) initCompareFuncs() {
p.CompareFuncs = make([]expression.CompareFunc, 0, len(p.LeftKeys))
for i := range p.LeftKeys {
switch expression.GetAccurateCmpType(p.LeftKeys[i], p.RightKeys[i]) {
case types.ETInt:
p.CompareFuncs = append(p.CompareFuncs, expression.CompareInt)
case types.ETReal:
p.CompareFuncs = append(p.CompareFuncs, expression.CompareReal)
case types.ETDecimal:
p.CompareFuncs = append(p.CompareFuncs, expression.CompareDecimal)
case types.ETString:
p.CompareFuncs = append(p.CompareFuncs, expression.CompareString)
case types.ETDuration:
p.CompareFuncs = append(p.CompareFuncs, expression.CompareDuration)
case types.ETDatetime, types.ETTimestamp:
p.CompareFuncs = append(p.CompareFuncs, expression.CompareTime)
case types.ETJson:
p.CompareFuncs = append(p.CompareFuncs, expression.CompareJSON)
}
p.CompareFuncs = append(p.CompareFuncs, expression.GetCmpFunction(p.LeftKeys[i], p.RightKeys[i]))
}
}

Expand Down
14 changes: 10 additions & 4 deletions planner/core/explain.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,10 +306,16 @@ func (p *PhysicalWindow) formatFrameBound(buffer *bytes.Buffer, bound *FrameBoun
}
if bound.UnBounded {
buffer.WriteString("unbounded")
} else if bound.DateCalcFunc != nil {
sf := bound.DateCalcFunc.(*expression.ScalarFunction)
// for `interval '2:30' minute_second`.
fmt.Fprintf(buffer, "interval %s %s", sf.GetArgs()[1].ExplainInfo(), sf.GetArgs()[2].ExplainInfo())
} else if bound.CalcFunc != nil {
sf := bound.CalcFunc.(*expression.ScalarFunction)
switch sf.FuncName.L {
case ast.DateAdd, ast.DateSub:
// For `interval '2:30' minute_second`.
fmt.Fprintf(buffer, "interval %s %s", sf.GetArgs()[1].ExplainInfo(), sf.GetArgs()[2].ExplainInfo())
case ast.Plus, ast.Minus:
// For `1 preceding` of range frame.
fmt.Fprintf(buffer, "%s", sf.GetArgs()[1].ExplainInfo())
}
} else {
fmt.Fprintf(buffer, "%d", bound.Num)
}
Expand Down
77 changes: 52 additions & 25 deletions planner/core/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -2785,11 +2785,14 @@ func (b *PlanBuilder) buildProjectionForWindow(p LogicalPlan, expr *ast.WindowFu
func (b *PlanBuilder) buildWindowFunctionFrameBound(spec *ast.WindowSpec, orderByItems []property.Item, boundClause *ast.FrameBound) (*FrameBound, error) {
frameType := spec.Frame.Type
bound := &FrameBound{Type: boundClause.Type, UnBounded: boundClause.UnBounded}
if bound.UnBounded || boundClause.Type == ast.CurrentRow {
if bound.UnBounded {
return bound, nil
}

if frameType == ast.Rows {
if bound.Type == ast.CurrentRow {
return bound, nil
}
// Rows type does not support interval range.
if boundClause.Unit != nil {
return nil, ErrWindowRowsIntervalUse.GenWithStackByArgs(spec.Name)
Expand All @@ -2805,50 +2808,74 @@ func (b *PlanBuilder) buildWindowFunctionFrameBound(spec *ast.WindowSpec, orderB
if len(orderByItems) != 1 {
return nil, ErrWindowRangeFrameOrderType.GenWithStackByArgs(spec.Name)
}

if bound.Type == ast.CurrentRow {
bound.CalcFunc = orderByItems[0].Col
bound.CmpFunc = expression.GetCmpFunction(orderByItems[0].Col, orderByItems[0].Col)
return bound, nil
}
col := orderByItems[0].Col
isNumeric, isTemporal := types.IsTypeNumeric(col.RetType.Tp), types.IsTypeTemporal(col.RetType.Tp)
if !isNumeric && !isTemporal {
return nil, ErrWindowRangeFrameOrderType.GenWithStackByArgs(spec.Name)
}
if boundClause.Unit != nil {
// Interval bounds only support order by temporal types.
if isNumeric {
return nil, ErrWindowRangeFrameNumericType.GenWithStackByArgs(spec.Name)
}
// Interval bounds only support order by temporal types.
if boundClause.Unit != nil && isNumeric {
return nil, ErrWindowRangeFrameNumericType.GenWithStackByArgs(spec.Name)
}
// Non-interval bound only support order by numeric types.
if boundClause.Unit == nil && !isNumeric {
return nil, ErrWindowRangeFrameTemporalType.GenWithStackByArgs(spec.Name)
}

// TODO: We also need to raise error for non-deterministic expressions, like rand().
val, err := evalAstExpr(b.ctx, boundClause.Expr)
if err != nil {
return nil, ErrWindowRangeBoundNotConstant.GenWithStackByArgs(spec.Name)
}
expr := expression.Constant{Value: val, RetType: boundClause.Expr.GetType()}
uVal, isNull, err := expr.EvalInt(b.ctx, chunk.Row{})
if uVal < 0 || isNull || err != nil {
return nil, ErrWindowFrameIllegal.GenWithStackByArgs(spec.Name)
}
// TODO: We also need to raise error for non-deterministic expressions, like rand().
val, err := evalAstExpr(b.ctx, boundClause.Expr)
if err != nil {
return nil, ErrWindowRangeBoundNotConstant.GenWithStackByArgs(spec.Name)
}
expr := expression.Constant{Value: val, RetType: boundClause.Expr.GetType()}

// Do not raise warnings for truncate.
oriIgnoreTruncate := b.ctx.GetSessionVars().StmtCtx.IgnoreTruncate
b.ctx.GetSessionVars().StmtCtx.IgnoreTruncate = true
uVal, isNull, err := expr.EvalInt(b.ctx, chunk.Row{})
b.ctx.GetSessionVars().StmtCtx.IgnoreTruncate = oriIgnoreTruncate
if uVal < 0 || isNull || err != nil {
return nil, ErrWindowFrameIllegal.GenWithStackByArgs(spec.Name)
}

desc := orderByItems[0].Desc
if boundClause.Unit != nil {
// It can be guaranteed by the parser.
unitVal := boundClause.Unit.(*driver.ValueExpr)
unit := expression.Constant{Value: unitVal.Datum, RetType: unitVal.GetType()}

// When the order is asc:
// `+` for following, and `-` for the preceding
// When the order is desc, `+` becomes `-` and vice-versa.
funcName := ast.DateAdd
if bound.Type == ast.Preceding {
if (!desc && bound.Type == ast.Preceding) || (desc && bound.Type == ast.Following) {
funcName = ast.DateSub
}
bound.DateCalcFunc, err = expression.NewFunctionBase(b.ctx, funcName, col.RetType, col, &expr, &unit)
bound.CalcFunc, err = expression.NewFunctionBase(b.ctx, funcName, col.RetType, col, &expr, &unit)
if err != nil {
return nil, err
}
bound.CmpFunc = expression.GetCmpFunction(orderByItems[0].Col, bound.CalcFunc)
return bound, nil
}
// Non-interval bound only support order by numeric types.
if isTemporal {
return nil, ErrWindowRangeFrameTemporalType.GenWithStackByArgs(spec.Name)
// When the order is asc:
// `+` for following, and `-` for the preceding
// When the order is desc, `+` becomes `-` and vice-versa.
funcName := ast.Plus
if (!desc && bound.Type == ast.Preceding) || (desc && bound.Type == ast.Following) {
funcName = ast.Minus
}
num, isNull, isExpectedType := getUintFromNode(b.ctx, boundClause.Expr)
if isNull || !isExpectedType {
return nil, ErrWindowFrameIllegal.GenWithStackByArgs(spec.Name)
bound.CalcFunc, err = expression.NewFunctionBase(b.ctx, funcName, col.RetType, col, &expr)
if err != nil {
return nil, err
}
bound.Num = num
bound.CmpFunc = expression.GetCmpFunction(orderByItems[0].Col, bound.CalcFunc)
return bound, nil
}

Expand Down
2 changes: 1 addition & 1 deletion planner/core/logical_plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2189,7 +2189,7 @@ func (s *testPlanSuite) TestWindowFunction(c *C) {
},
{
sql: "select sum(a) over(order by a range between 1.0 preceding and 1 following) from t",
result: "[planner:3586]Window '<unnamed window>': frame start or end is negative, NULL or of non-integral type",
result: "TableReader(Table(t))->Window(sum(cast(test.t.a)) over(order by test.t.a asc range between 1.0 preceding and 1 following))->Projection",
},
}

Expand Down
8 changes: 6 additions & 2 deletions planner/core/logical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -663,8 +663,12 @@ type FrameBound struct {
Type ast.BoundType
UnBounded bool
Num uint64
// For `INTERVAL '2:30' MINUTE_SECOND FOLLOWING`, we will build the date_add or date_sub functions.
DateCalcFunc expression.Expression
// CalcFunc is used for range framed windows.
// We will build the date_add or date_sub functions for frames like `INTERVAL '2:30' MINUTE_SECOND FOLLOWING`,
// and plus or minus for frames like `1 preceding`.
CalcFunc expression.Expression
// CmpFunc is used to decide whether one row is included in the current frame.
CmpFunc expression.CompareFunc
}

// LogicalWindow represents a logical window function plan.
Expand Down
14 changes: 14 additions & 0 deletions planner/core/resolve_indices.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,20 @@ func (p *PhysicalWindow) ResolveIndices() (err error) {
return err
}
}
if p.Frame != nil {
if p.Frame.Start.CalcFunc != nil {
p.Frame.Start.CalcFunc, err = p.Frame.Start.CalcFunc.ResolveIndices(p.children[0].Schema())
if err != nil {
return err
}
}
if p.Frame.End.CalcFunc != nil {
p.Frame.End.CalcFunc, err = p.Frame.End.CalcFunc.ResolveIndices(p.children[0].Schema())
if err != nil {
return err
}
}
}
return nil
}

Expand Down

0 comments on commit a59a5f4

Please sign in to comment.