Skip to content

Commit

Permalink
[fix](nereids)support group_concat with distinct and order by (#38871)
Browse files Browse the repository at this point in the history
## Proposed changes

pick from master #38080

<!--Describe your changes.-->
  • Loading branch information
starocean999 authored Aug 5, 2024
1 parent bf1c7a1 commit 40567b5
Show file tree
Hide file tree
Showing 22 changed files with 327 additions and 191 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,9 @@ public Expr visitIsNull(IsNull isNull, PlanTranslatorContext context) {

@Override
public Expr visitStateCombinator(StateCombinator combinator, PlanTranslatorContext context) {
List<Expr> arguments = combinator.getArguments().stream().map(arg -> arg.accept(this, context))
List<Expr> arguments = combinator.getArguments().stream().map(arg -> arg instanceof OrderExpression
? translateOrderExpression((OrderExpression) arg, context).getExpr()
: arg.accept(this, context))
.collect(Collectors.toList());
return Function.convertToStateCombinator(
new FunctionCallExpr(visitAggregateFunction(combinator.getNestedFunction(), context).getFn(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,6 @@
import org.apache.doris.nereids.trees.expressions.WindowFrame;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Array;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayRange;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayRangeDayUnit;
Expand Down Expand Up @@ -2096,11 +2095,10 @@ public Expression visitFunctionCallExpression(DorisParser.FunctionCallExpression
return ParserUtils.withOrigin(ctx, () -> {
String functionName = ctx.functionIdentifier().functionNameIdentifier().getText();
boolean isDistinct = ctx.DISTINCT() != null;
List<Expression> params = visit(ctx.expression(), Expression.class);
List<Expression> params = Lists.newArrayList();
params.addAll(visit(ctx.expression(), Expression.class));
List<OrderKey> orderKeys = visit(ctx.sortItem(), OrderKey.class);
if (!orderKeys.isEmpty()) {
return parseFunctionWithOrderKeys(functionName, isDistinct, params, orderKeys, ctx);
}
params.addAll(orderKeys.stream().map(OrderExpression::new).collect(Collectors.toList()));

List<UnboundStar> unboundStars = ExpressionUtils.collectAll(params, UnboundStar.class::isInstance);
if (!unboundStars.isEmpty()) {
Expand Down Expand Up @@ -3471,23 +3469,6 @@ public StructField visitComplexColType(ComplexColTypeContext ctx) {
return new StructField(ctx.identifier().getText(), typedVisit(ctx.dataType()), true, comment);
}

private Expression parseFunctionWithOrderKeys(String functionName, boolean isDistinct,
List<Expression> params, List<OrderKey> orderKeys, ParserRuleContext ctx) {
if (functionName.equalsIgnoreCase("group_concat")) {
OrderExpression[] orderExpressions = orderKeys.stream()
.map(OrderExpression::new)
.toArray(OrderExpression[]::new);
if (params.size() == 1) {
return new GroupConcat(isDistinct, params.get(0), orderExpressions);
} else if (params.size() == 2) {
return new GroupConcat(isDistinct, params.get(0), params.get(1), orderExpressions);
} else {
throw new ParseException("group_concat requires one or two parameters: " + params, ctx);
}
}
throw new ParseException("Unsupported function with order expressions" + ctx.getText(), ctx);
}

private String parseConstant(ConstantContext context) {
Object constant = visit(context);
if (constant instanceof Literal && ((Literal) constant).isStringLikeLiteral()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,15 @@ public Boolean visitPhysicalHashAggregate(PhysicalHashAggregate<? extends Plan>
distinctChildColumns, ShuffleType.REQUIRE);
if ((!groupByColumns.isEmpty() && distributionSpecHash.satisfy(groupByRequire))
|| (groupByColumns.isEmpty() && distributionSpecHash.satisfy(distinctChildRequire))) {
return false;
if (!agg.mustUseMultiDistinctAgg()) {
return false;
}
}
}
// if distinct without group by key, we prefer three or four stage distinct agg
// because the second phase of multi-distinct only have one instance, and it is slow generally.
if (agg.getOutputExpressions().size() == 1 && agg.getGroupByExpressions().isEmpty()) {
if (agg.getOutputExpressions().size() == 1 && agg.getGroupByExpressions().isEmpty()
&& !agg.mustUseMultiDistinctAgg()) {
return false;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction;
Expand Down Expand Up @@ -148,7 +149,7 @@ private void checkAggregate(LogicalAggregate<? extends Plan> aggregate) {
continue;
}
for (int i = 1; i < func.arity(); i++) {
if (!func.child(i).getInputSlots().isEmpty()) {
if (!func.child(i).getInputSlots().isEmpty() && !(func.child(i) instanceof OrderExpression)) {
// think about group_concat(distinct col_1, ',')
distinctMultiColumns = true;
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,7 @@ public Expression visitUnboundFunction(UnboundFunction unboundFunction, Expressi
}

Pair<? extends Expression, ? extends BoundFunction> buildResult = builder.build(functionName, arguments);
buildResult.second.checkOrderExprIsValid();
Optional<SqlCacheContext> sqlCacheContext = Optional.empty();
if (wantToParseSqlFromSqlCache) {
StatementContext statementContext = context.cascadesContext.getStatementContext();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotNotFromChildren;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinction;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
Expand All @@ -54,6 +56,7 @@
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
* normalize aggregate's group keys and AggregateFunction's child to SlotReference
Expand Down Expand Up @@ -170,6 +173,7 @@ private LogicalPlan normalizeAgg(LogicalAggregate<Plan> aggregate, Optional<Logi
// should not push down literal under aggregate
// e.g. group_concat(distinct xxx, ','), the ',' literal show stay in aggregate
.filter(arg -> !(arg instanceof Literal))
.flatMap(arg -> arg instanceof OrderExpression ? arg.getInputSlots().stream() : Stream.of(arg))
.collect(
Collectors.groupingBy(
child -> !(child instanceof SlotReference),
Expand Down Expand Up @@ -255,7 +259,15 @@ private LogicalPlan normalizeAgg(LogicalAggregate<Plan> aggregate, Optional<Logi
normalizedAggFuncsToSlotContext.pushDownToNamedExpression(normalizedAggFuncs)
);
// create new agg node
ImmutableList<NamedExpression> normalizedAggOutput = normalizedAggOutputBuilder.build();
ImmutableList<NamedExpression> aggOutput = normalizedAggOutputBuilder.build();
ImmutableList.Builder<NamedExpression> newAggOutputBuilder
= ImmutableList.builderWithExpectedSize(aggOutput.size());
for (NamedExpression output : aggOutput) {
Expression rewrittenExpr = output.rewriteDownShortCircuit(
e -> e instanceof MultiDistinction ? ((MultiDistinction) e).withMustUseMultiDistinctAgg(true) : e);
newAggOutputBuilder.add((NamedExpression) rewrittenExpr);
}
ImmutableList<NamedExpression> normalizedAggOutput = newAggOutputBuilder.build();
LogicalAggregate<?> newAggregate =
aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutput, bottomPlan);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
Expand Down Expand Up @@ -296,6 +297,7 @@ && couldConvertToMulti(agg))
RuleType.THREE_PHASE_AGGREGATE_WITH_DISTINCT.build(
basePattern
.when(agg -> agg.getDistinctArguments().size() == 1)
.whenNot(agg -> agg.mustUseMultiDistinctAgg())
.thenApplyMulti(ctx -> threePhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
),
/*
Expand All @@ -319,6 +321,7 @@ && couldConvertToMulti(agg))
basePattern
.when(agg -> agg.getDistinctArguments().size() == 1)
.when(agg -> agg.getGroupByExpressions().isEmpty())
.whenNot(agg -> agg.mustUseMultiDistinctAgg())
.thenApplyMulti(ctx -> {
Function<List<Expression>, RequireProperties> secondPhaseRequireDistinctHash =
groupByAndDistinct -> RequireProperties.of(
Expand Down Expand Up @@ -1940,7 +1943,7 @@ private boolean couldConvertToMulti(LogicalAggregate<? extends Plan> aggregate)
}
for (int i = 1; i < func.arity(); i++) {
// think about group_concat(distinct col_1, ',')
if (!func.child(i).getInputSlots().isEmpty()) {
if (!(func.child(i) instanceof OrderExpression) && !func.child(i).getInputSlots().isEmpty()) {
return false;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,12 @@ public WindowExpression withFunction(Expression function) {
.orElseGet(() -> new WindowExpression(function, partitionKeys, orderKeys));
}

public WindowExpression withFunctionPartitionKeysOrderKeys(Expression function,
List<Expression> partitionKeys, List<OrderExpression> orderKeys) {
return windowFrame.map(frame -> new WindowExpression(function, partitionKeys, orderKeys, frame))
.orElseGet(() -> new WindowExpression(function, partitionKeys, orderKeys));
}

@Override
public boolean nullable() {
return function.nullable();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ public Pair<BoundFunction, AggregateFunction> build(String name, List<?> argumen
String nestedName = getNestedName(name);
if (combinatorSuffix.equalsIgnoreCase(STATE)) {
AggregateFunction nestedFunction = buildState(nestedName, arguments);
// distinct will be passed as 1st boolean true arg. remove it
if (!arguments.isEmpty() && arguments.get(0) instanceof Boolean && (Boolean) arguments.get(0)) {
arguments = arguments.subList(1, arguments.size());
}
return Pair.of(new StateCombinator((List<Expression>) arguments, nestedFunction), nestedFunction);
} else if (combinatorSuffix.equalsIgnoreCase(MERGE)) {
AggregateFunction nestedFunction = buildMergeOrUnion(nestedName, arguments);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@
package org.apache.doris.nereids.trees.expressions.functions;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat;
import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctGroupConcat;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.util.Utils;

Expand Down Expand Up @@ -98,4 +102,17 @@ public String toString() {
.collect(Collectors.joining(", "));
return getName() + "(" + args + ")";
}

/**
* checkOrderExprIsValid.
*/
public void checkOrderExprIsValid() {
for (Expression child : children) {
if (child instanceof OrderExpression
&& !(this instanceof GroupConcat || this instanceof MultiDistinctGroupConcat)) {
throw new AnalysisException(
String.format("%s doesn't support order by expression", getName()));
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,8 @@ public String toString() {
public List<Expression> getDistinctArguments() {
return distinct ? getArguments() : ImmutableList.of();
}

public boolean mustUseMultiDistinctAgg() {
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,57 +52,30 @@ public class GroupConcat extends NullableAggregateFunction
/**
* constructor with 1 argument.
*/
public GroupConcat(boolean distinct, boolean alwaysNullable, Expression arg, OrderExpression... orders) {
super("group_concat", distinct, alwaysNullable, ExpressionUtils.mergeArguments(arg, orders));
this.nonOrderArguments = 1;
checkArguments();
public GroupConcat(boolean distinct, boolean alwaysNullable, Expression arg, Expression... others) {
this(distinct, alwaysNullable, ExpressionUtils.mergeArguments(arg, others));
}

/**
* constructor with 1 argument.
*/
public GroupConcat(boolean distinct, Expression arg, OrderExpression... orders) {
this(distinct, false, arg, orders);
public GroupConcat(boolean distinct, Expression arg, Expression... others) {
this(distinct, false, arg, others);
}

/**
* constructor with 1 argument, use for function search.
*/
public GroupConcat(Expression arg, OrderExpression... orders) {
this(false, arg, orders);
}

/**
* constructor with 2 arguments.
*/
public GroupConcat(boolean distinct, boolean alwaysNullable,
Expression arg0, Expression arg1, OrderExpression... orders) {
super("group_concat", distinct, alwaysNullable, ExpressionUtils.mergeArguments(arg0, arg1, orders));
this.nonOrderArguments = 2;
checkArguments();
}

/**
* constructor with 2 arguments.
*/
public GroupConcat(boolean distinct, Expression arg0, Expression arg1, OrderExpression... orders) {
this(distinct, false, arg0, arg1, orders);
}

/**
* constructor with 2 arguments, use for function search.
*/
public GroupConcat(Expression arg0, Expression arg1, OrderExpression... orders) {
this(false, arg0, arg1, orders);
public GroupConcat(Expression arg, Expression... others) {
this(false, arg, others);
}

/**
* constructor for always nullable.
*/
public GroupConcat(boolean distinct, boolean alwaysNullable, int nonOrderArguments, List<Expression> args) {
public GroupConcat(boolean distinct, boolean alwaysNullable, List<Expression> args) {
super("group_concat", distinct, alwaysNullable, args);
this.nonOrderArguments = nonOrderArguments;
checkArguments();
this.nonOrderArguments = findOrderExprIndex(children);
}

@Override
Expand Down Expand Up @@ -139,38 +112,15 @@ public void checkLegalityBeforeTypeCoercion() {

@Override
public GroupConcat withAlwaysNullable(boolean alwaysNullable) {
return new GroupConcat(distinct, alwaysNullable, nonOrderArguments, children);
return new GroupConcat(distinct, alwaysNullable, children);
}

/**
* withDistinctAndChildren.
*/
@Override
public GroupConcat withDistinctAndChildren(boolean distinct, List<Expression> children) {
Preconditions.checkArgument(children().size() >= 1);
boolean foundOrderExpr = false;
int firstOrderExrIndex = 0;
for (int i = 0; i < children.size(); i++) {
Expression child = children.get(i);
if (child instanceof OrderExpression) {
foundOrderExpr = true;
} else if (!foundOrderExpr) {
firstOrderExrIndex++;
} else {
throw new AnalysisException("invalid group_concat parameters: " + children);
}
}

List<OrderExpression> orders = (List) children.subList(firstOrderExrIndex, children.size());
if (firstOrderExrIndex == 1) {
return new GroupConcat(distinct, alwaysNullable,
children.get(0), orders.toArray(new OrderExpression[0]));
} else if (firstOrderExrIndex == 2) {
return new GroupConcat(distinct, alwaysNullable,
children.get(0), children.get(1), orders.toArray(new OrderExpression[0]));
} else {
throw new AnalysisException("group_concat requires one or two parameters: " + children);
}
return new GroupConcat(distinct, alwaysNullable, children);
}

@Override
Expand All @@ -186,15 +136,34 @@ public List<FunctionSignature> getSignatures() {
public MultiDistinctGroupConcat convertToMultiDistinct() {
Preconditions.checkArgument(distinct,
"can't convert to multi_distinct_group_concat because there is no distinct args");
return new MultiDistinctGroupConcat(alwaysNullable, nonOrderArguments, children);
return new MultiDistinctGroupConcat(alwaysNullable, children);
}

// TODO: because of current be's limitation, we have to thow exception for now
// remove this after be support new method of multi distinct functions
private void checkArguments() {
if (isDistinct() && children().stream().anyMatch(expression -> expression instanceof OrderExpression)) {
@Override
public boolean mustUseMultiDistinctAgg() {
return distinct && children.stream().anyMatch(OrderExpression.class::isInstance);
}

private int findOrderExprIndex(List<Expression> children) {
Preconditions.checkArgument(children().size() >= 1, "children's size should >= 1");
boolean foundOrderExpr = false;
int firstOrderExrIndex = 0;
for (int i = 0; i < children.size(); i++) {
Expression child = children.get(i);
if (child instanceof OrderExpression) {
foundOrderExpr = true;
} else if (!foundOrderExpr) {
firstOrderExrIndex++;
} else {
throw new AnalysisException(
"invalid multi_distinct_group_concat parameters: " + children);
}
}

if (firstOrderExrIndex > 2) {
throw new AnalysisException(
"group_concat don't support using distinct with order by together");
"multi_distinct_group_concat requires one or two parameters: " + children);
}
return firstOrderExrIndex;
}
}
Loading

0 comments on commit 40567b5

Please sign in to comment.