Skip to content

Commit

Permalink
Improve and Fix Distinct Aggregation planner (#13466)
Browse files Browse the repository at this point in the history
Co-authored-by: Andres Taylor <andres@planetscale.com>
  • Loading branch information
harshit-gangal and systay authored Jul 13, 2023
1 parent 4dd8022 commit 6d4d00a
Show file tree
Hide file tree
Showing 9 changed files with 355 additions and 46 deletions.
30 changes: 30 additions & 0 deletions go/test/endtoend/utils/cmp.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,33 @@ func (mcmp *MySQLCompare) ExecAndIgnore(query string) (*sqltypes.Result, error)
_, _ = mcmp.MySQLConn.ExecuteFetch(query, 1000, true)
return mcmp.VtConn.ExecuteFetch(query, 1000, true)
}

// Run executes f as a subtest named after the query. The callback receives a
// MySQLCompare bound to the subtest's *testing.T while sharing the parent's
// MySQL and Vitess connections.
func (mcmp *MySQLCompare) Run(query string, f func(mcmp *MySQLCompare)) {
	mcmp.t.Run(query, func(t *testing.T) {
		sub := &MySQLCompare{
			t:         t,
			MySQLConn: mcmp.MySQLConn,
			VtConn:    mcmp.VtConn,
		}
		f(sub)
	})
}

// ExecAllowError executes the query against both Vitess and MySQL.
// If there is no error, it compares the result.
// Return any Vitess execution error without comparing the results.
func (mcmp *MySQLCompare) ExecAllowError(query string) (*sqltypes.Result, error) {
	mcmp.t.Helper()

	vtQr, vtErr := mcmp.VtConn.ExecuteFetch(query, 1000, true)
	if vtErr != nil {
		return nil, vtErr
	}

	mysqlQr, mysqlErr := mcmp.MySQLConn.ExecuteFetch(query, 1000, true)
	if mysqlErr != nil {
		// Since we allow errors, we don't want to compare results if one of the clients failed.
		// Vitess and MySQL should always agree on whether the query returns an error or not.
		return vtQr, nil
	}
	return vtQr, compareVitessAndMySQLResults(mcmp.t, query, mcmp.VtConn, vtQr, mysqlQr, false)
}
34 changes: 34 additions & 0 deletions go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -528,3 +528,37 @@ func compareRow(t *testing.T, mRes *sqltypes.Result, vtRes *sqltypes.Result, grp
require.True(t, foundKey, "mysql and vitess result does not same row: vitess:%v, mysql:%v", vtRes.Rows, mRes.Rows)
}
}

// TestDistinctAggregation checks the planner's handling of multiple DISTINCT
// aggregations in one SELECT: distincts over the same expression plan fine,
// while distincts over different expressions produce a VT12001 error.
func TestDistinctAggregation(t *testing.T) {
	mcmp, closer := start(t)
	defer closer()
	mcmp.Exec("insert into t1(t1_id, `name`, `value`, shardkey) values(1,'a1','foo',100), (2,'b1','foo',200), (3,'c1','foo',300), (4,'a1','foo',100), (5,'d1','toto',200), (6,'c1','tata',893), (7,'a1','titi',2380), (8,'b1','tete',12833), (9,'e1','yoyo',783493)")

	tcases := []struct {
		query       string
		expectedErr string
	}{{
		query:       `SELECT /*vt+ PLANNER=gen4 */ COUNT(DISTINCT value), SUM(DISTINCT shardkey) FROM t1`,
		expectedErr: "VT12001: unsupported: only one DISTINCT aggregation is allowed in a SELECT: sum(distinct shardkey) (errno 1235) (sqlstate 42000)",
	}, {
		query: `SELECT /*vt+ PLANNER=gen4 */ a.t1_id, SUM(DISTINCT b.shardkey) FROM t1 a, t1 b group by a.t1_id`,
	}, {
		query: `SELECT /*vt+ PLANNER=gen4 */ a.value, SUM(DISTINCT b.shardkey) FROM t1 a, t1 b group by a.value`,
	}, {
		query:       `SELECT /*vt+ PLANNER=gen4 */ count(distinct a.value), SUM(DISTINCT b.t1_id) FROM t1 a, t1 b`,
		expectedErr: "VT12001: unsupported: only one DISTINCT aggregation is allowed in a SELECT: sum(distinct b.t1_id) (errno 1235) (sqlstate 42000)",
	}, {
		query: `SELECT /*vt+ PLANNER=gen4 */ a.value, SUM(DISTINCT b.t1_id), min(DISTINCT a.t1_id) FROM t1 a, t1 b group by a.value`,
	}}

	for _, tc := range tcases {
		mcmp.Run(tc.query, func(mcmp *utils.MySQLCompare) {
			_, err := mcmp.ExecAllowError(tc.query)
			if tc.expectedErr != "" {
				require.ErrorContains(t, err, tc.expectedErr)
				return
			}
			require.NoError(t, err)
		})
	}
}
8 changes: 8 additions & 0 deletions go/vt/sqlparser/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -2869,6 +2869,7 @@ type (

DistinctableAggr interface {
IsDistinct() bool
SetDistinct(bool)
}

Count struct {
Expand Down Expand Up @@ -3381,6 +3382,13 @@ func (avg *Avg) IsDistinct() bool { return avg.Distinct }
func (count *Count) IsDistinct() bool { return count.Distinct }
func (grpConcat *GroupConcatExpr) IsDistinct() bool { return grpConcat.Distinct }

// SetDistinct implementations for DistinctableAggr: toggle the DISTINCT
// modifier on each aggregation function type.
func (sum *Sum) SetDistinct(distinct bool) { sum.Distinct = distinct }
func (min *Min) SetDistinct(distinct bool) { min.Distinct = distinct }
func (max *Max) SetDistinct(distinct bool) { max.Distinct = distinct }
func (avg *Avg) SetDistinct(distinct bool) { avg.Distinct = distinct }
func (count *Count) SetDistinct(distinct bool) { count.Distinct = distinct }
func (grpConcat *GroupConcatExpr) SetDistinct(distinct bool) { grpConcat.Distinct = distinct }

func (*Sum) AggrName() string { return "sum" }
func (*Min) AggrName() string { return "min" }
func (*Max) AggrName() string { return "max" }
Expand Down
14 changes: 14 additions & 0 deletions go/vt/sqlparser/ast_rewriting.go
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,8 @@ func (er *astRewriter) rewriteUp(cursor *Cursor) bool {
er.rewriteShowBasic(node)
case *ExistsExpr:
er.existsRewrite(cursor, node)
case DistinctableAggr:
er.rewriteDistinctableAggr(cursor, node)
}
return true
}
Expand Down Expand Up @@ -683,6 +685,18 @@ func (er *astRewriter) existsRewrite(cursor *Cursor, node *ExistsExpr) {
sel.GroupBy = nil
}

// rewriteDistinctableAggr removes DISTINCT from MAX and MIN aggregations,
// since it does not impact their result but makes the plan simpler.
func (er *astRewriter) rewriteDistinctableAggr(cursor *Cursor, node DistinctableAggr) {
	switch node.(type) {
	case *Max, *Min:
		if node.IsDistinct() {
			node.SetDistinct(false)
			er.bindVars.NoteRewrite()
		}
	}
}

// bindVarExpression returns an expression referencing the bind variable with the given name.
func bindVarExpression(name string) Expr {
	return NewArgument(name)
}
Expand Down
3 changes: 3 additions & 0 deletions go/vt/sqlparser/ast_rewriting_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,9 @@ func TestRewrites(in *testing.T) {
}, {
in: "SELECT id, name, salary FROM user_details",
expected: "SELECT id, name, salary FROM (select user.id, user.name, user_extra.salary from user join user_extra where user.id = user_extra.user_id) as user_details",
}, {
in: "select max(distinct c1), min(distinct c2), avg(distinct c3), sum(distinct c4), count(distinct c5), group_concat(distinct c6) from tbl",
expected: "select max(c1) as `max(distinct c1)`, min(c2) as `min(distinct c2)`, avg(distinct c3), sum(distinct c4), count(distinct c5), group_concat(distinct c6) from tbl",
}, {
in: "SHOW VARIABLES",
expected: "SHOW VARIABLES",
Expand Down
88 changes: 86 additions & 2 deletions go/vt/vtgate/engine/scalar_aggregation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/test/utils"

"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/test/utils"
. "vitess.io/vitess/go/vt/vtgate/engine/opcode"
)

Expand Down Expand Up @@ -255,6 +254,91 @@ func TestScalarGroupConcatWithAggrOnEngine(t *testing.T) {
}
}

// TestScalarDistinctAggrOnEngine tests distinct aggregation on the engine,
// verifying both TryExecute and TryStreamExecute produce the same result.
func TestScalarDistinctAggrOnEngine(t *testing.T) {
	fields := sqltypes.MakeTestFields(
		"value|value",
		"int64|int64",
	)

	// Input contains duplicates; distinct values are 100, 200, 400, 600.
	fp := &fakePrimitive{results: []*sqltypes.Result{sqltypes.MakeTestResult(
		fields,
		"100|100",
		"200|200",
		"200|200",
		"400|400",
		"400|400",
		"600|600",
	)}}

	oa := &ScalarAggregate{
		Aggregates: []*AggregateParams{
			NewAggregateParam(AggregateCountDistinct, 0, "count(distinct value)"),
			NewAggregateParam(AggregateSumDistinct, 1, "sum(distinct value)"),
		},
		Input: fp,
	}
	qr, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false)
	require.NoError(t, err)
	// count(distinct) = 4, sum(distinct) = 100+200+400+600 = 1300
	require.Equal(t, `[[INT64(4) DECIMAL(1300)]]`, fmt.Sprintf("%v", qr.Rows))

	fp.rewind()
	// The streaming path must aggregate to the same single row.
	results := &sqltypes.Result{}
	err = oa.TryStreamExecute(context.Background(), &noopVCursor{}, nil, true, func(qr *sqltypes.Result) error {
		if qr.Fields != nil {
			results.Fields = qr.Fields
		}
		results.Rows = append(results.Rows, qr.Rows...)
		return nil
	})
	require.NoError(t, err)
	require.Equal(t, `[[INT64(4) DECIMAL(1300)]]`, fmt.Sprintf("%v", results.Rows))
}

// TestScalarDistinctPushedDown tests distinct aggregation that has been pushed
// below the route: the input rows are already per-shard distinct aggregates, so
// vtgate only needs to SUM them (OpCode AggregateSum with the distinct opcode
// preserved in OrigOpcode).
func TestScalarDistinctPushedDown(t *testing.T) {
	fields := sqltypes.MakeTestFields(
		"count(distinct value)|sum(distinct value)",
		"int64|decimal",
	)

	// Each row is a partial (per-shard) distinct aggregate.
	fp := &fakePrimitive{results: []*sqltypes.Result{sqltypes.MakeTestResult(
		fields,
		"2|200",
		"6|400",
		"3|700",
		"1|10",
		"7|30",
		"8|90",
	)}}

	countAggr := NewAggregateParam(AggregateSum, 0, "count(distinct value)")
	countAggr.OrigOpcode = AggregateCountDistinct
	sumAggr := NewAggregateParam(AggregateSum, 1, "sum(distinct value)")
	sumAggr.OrigOpcode = AggregateSumDistinct
	oa := &ScalarAggregate{
		Aggregates: []*AggregateParams{
			countAggr,
			sumAggr,
		},
		Input: fp,
	}
	qr, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false)
	require.NoError(t, err)
	// count = 2+6+3+1+7+8 = 27, sum = 200+400+700+10+30+90 = 1430
	require.Equal(t, `[[INT64(27) DECIMAL(1430)]]`, fmt.Sprintf("%v", qr.Rows))

	fp.rewind()
	// The streaming path must aggregate to the same single row.
	results := &sqltypes.Result{}
	err = oa.TryStreamExecute(context.Background(), &noopVCursor{}, nil, true, func(qr *sqltypes.Result) error {
		if qr.Fields != nil {
			results.Fields = qr.Fields
		}
		results.Rows = append(results.Rows, qr.Rows...)
		return nil
	})
	require.NoError(t, err)
	require.Equal(t, `[[INT64(27) DECIMAL(1430)]]`, fmt.Sprintf("%v", results.Rows))
}

// TestScalarGroupConcat tests group_concat with partial aggregation on engine.
func TestScalarGroupConcat(t *testing.T) {
fields := sqltypes.MakeTestFields(
Expand Down
92 changes: 70 additions & 22 deletions go/vt/vtgate/planbuilder/operators/aggregation_pushing.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ func (a *Aggregator) aggregateTheAggregates() {
func aggregateTheAggregate(a *Aggregator, i int) {
aggr := a.Aggregations[i]
switch aggr.OpCode {
case opcode.AggregateCount, opcode.AggregateCountStar, opcode.AggregateCountDistinct:
// All count variations turn into SUM above the Route.
case opcode.AggregateCount, opcode.AggregateCountStar, opcode.AggregateCountDistinct, opcode.AggregateSumDistinct:
// All count variations turn into SUM above the Route. This is also applied for Sum distinct when it is pushed down.
// Think of it as we are SUMming together a bunch of distributed COUNTs.
aggr.OriginalOpCode, aggr.OpCode = aggr.OpCode, opcode.AggregateSum
a.Aggregations[i] = aggr
Expand Down Expand Up @@ -115,37 +115,72 @@ func pushDownAggregationThroughRoute(

// pushDownAggregations splits aggregations between the original aggregator and the one we are pushing down
func pushDownAggregations(ctx *plancontext.PlanningContext, aggregator *Aggregator, aggrBelowRoute *Aggregator) error {
for i, aggregation := range aggregator.Aggregations {
if !aggregation.Distinct || exprHasUniqueVindex(ctx, aggregation.Func.GetArg()) {
aggrBelowRoute.Aggregations = append(aggrBelowRoute.Aggregations, aggregation)
canPushDownDistinctAggr, distinctExpr, err := checkIfWeCanPushDown(ctx, aggregator)
if err != nil {
return err
}

distinctAggrGroupByAdded := false

for i, aggr := range aggregator.Aggregations {
if !aggr.Distinct || canPushDownDistinctAggr {
aggrBelowRoute.Aggregations = append(aggrBelowRoute.Aggregations, aggr)
aggregateTheAggregate(aggregator, i)
continue
}
innerExpr := aggregation.Func.GetArg()

if aggregator.DistinctExpr != nil {
if ctx.SemTable.EqualsExpr(aggregator.DistinctExpr, innerExpr) {
// we can handle multiple distinct aggregations, as long as they are aggregating on the same expression
aggrBelowRoute.Columns[aggregation.ColOffset] = aeWrap(innerExpr)
continue
}
return vterrors.VT12001(fmt.Sprintf("only one DISTINCT aggregation is allowed in a SELECT: %s", sqlparser.String(aggregation.Original)))
}

// We handle a distinct aggregation by turning it into a group by and
// doing the aggregating on the vtgate level instead
aggregator.DistinctExpr = innerExpr
aeDistinctExpr := aeWrap(aggregator.DistinctExpr)
aeDistinctExpr := aeWrap(distinctExpr)
aggrBelowRoute.Columns[aggr.ColOffset] = aeDistinctExpr

aggrBelowRoute.Columns[aggregation.ColOffset] = aeDistinctExpr
// We handle a distinct aggregation by turning it into a group by and
// doing the aggregating on the vtgate level instead
// Adding to group by can be done only once even though there are multiple distinct aggregation with same expression.
if !distinctAggrGroupByAdded {
groupBy := NewGroupBy(distinctExpr, distinctExpr, aeDistinctExpr)
groupBy.ColOffset = aggr.ColOffset
aggrBelowRoute.Grouping = append(aggrBelowRoute.Grouping, groupBy)
distinctAggrGroupByAdded = true
}
}

groupBy := NewGroupBy(aggregator.DistinctExpr, aggregator.DistinctExpr, aeDistinctExpr)
groupBy.ColOffset = aggregation.ColOffset
aggrBelowRoute.Grouping = append(aggrBelowRoute.Grouping, groupBy)
if !canPushDownDistinctAggr {
aggregator.DistinctExpr = distinctExpr
}

return nil
}

// checkIfWeCanPushDown reports whether every DISTINCT aggregation in the
// aggregator can be pushed down to the shards, and returns the first distinct
// expression seen. When push-down is impossible and the distinct aggregations
// disagree on their expression, a VT12001 "unsupported" error is returned.
func checkIfWeCanPushDown(ctx *plancontext.PlanningContext, aggregator *Aggregator) (bool, sqlparser.Expr, error) {
	pushable := true
	var firstExpr sqlparser.Expr
	var conflicting *sqlparser.AliasedExpr

	for _, aggr := range aggregator.Aggregations {
		if !aggr.Distinct {
			continue
		}

		arg := aggr.Func.GetArg()
		// Push-down is only safe when the distinct expression is covered by a unique vindex.
		if !exprHasUniqueVindex(ctx, arg) {
			pushable = false
		}
		if firstExpr == nil {
			firstExpr = arg
		}
		if !ctx.SemTable.EqualsExpr(firstExpr, arg) {
			conflicting = aggr.Original
		}
	}

	if !pushable && conflicting != nil {
		return false, nil, vterrors.VT12001(fmt.Sprintf("only one DISTINCT aggregation is allowed in a SELECT: %s", sqlparser.String(conflicting)))
	}

	return pushable, firstExpr, nil
}

func pushDownAggregationThroughFilter(
ctx *plancontext.PlanningContext,
aggregator *Aggregator,
Expand Down Expand Up @@ -411,6 +446,18 @@ func splitAggrColumnsToLeftAndRight(
outerJoin: join.LeftJoin,
}

canPushDownDistinctAggr, distinctExpr, err := checkIfWeCanPushDown(ctx, aggregator)
if err != nil {
return nil, nil, err
}

// Distinct aggregation cannot be pushed down in the join.
// We keep node of the distinct aggregation expression to be used later for ordering.
if !canPushDownDistinctAggr {
aggregator.DistinctExpr = distinctExpr
return nil, nil, errAbortAggrPushing
}

outer:
// we prefer adding the aggregations in the same order as the columns are declared
for colIdx, col := range aggregator.Columns {
Expand Down Expand Up @@ -509,7 +556,8 @@ func (ab *aggBuilder) handleAggr(ctx *plancontext.PlanningContext, aggr Aggr) er
// this is only used for SHOW GTID queries that will never contain joins
return vterrors.VT13001("cannot do join with vgtid")
case opcode.AggregateSumDistinct, opcode.AggregateCountDistinct:
return errAbortAggrPushing
// we are not going to see values multiple times, so we don't need to multiply with the count(*) from the other side
return ab.handlePushThroughAggregation(ctx, aggr)
default:
return errHorizonNotPlanned()
}
Expand Down
4 changes: 3 additions & 1 deletion go/vt/vtgate/planbuilder/operators/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ type (
Grouping []GroupBy
Aggregations []Aggr

// We support a single distinct aggregation per aggregator. It is stored here
// We support a single distinct aggregation per aggregator. It is stored here.
// When planning the ordering that the OrderedAggregate will require,
// this needs to be the last ORDER BY expression
DistinctExpr sqlparser.Expr

// Pushed will be set to true once this aggregation has been pushed deeper in the tree
Expand Down
Loading

0 comments on commit 6d4d00a

Please sign in to comment.