From 06a01e6bebaab8bcd33ef33664cf3385f01fd79b Mon Sep 17 00:00:00 2001
From: lgbo-ustc
Date: Thu, 6 Feb 2025 17:05:58 +0800
Subject: [PATCH] Coalesce a union of structure-matched aggregations into a
 single aggregation

When the clauses of a UNION ALL are aggregations that share the same
source and the same aggregate structure, and differ only in their filter
conditions, rewrite them into a single scan followed by a single
aggregation. The rewrite is gated by the new setting
`enable.coalesce.aggregation.union` (default: true).
---
 .../backendsapi/clickhouse/CHBackend.scala    |   3 +
 .../backendsapi/clickhouse/CHRuleApi.scala    |   1 +
 .../extension/CoalesceAggregationUnion.scala  | 365 ++++++++----------
 .../GlutenCoalesceAggregationUnionSuite.scala |  44 ++-
 4 files changed, 204 insertions(+), 209 deletions(-)

diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
index e883f8c454ce..f47aa0416cc3 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
@@ -156,6 +156,9 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
     CHConf.prefixOf("convert.left.anti_semi.to.right")
   val GLUTEN_CLICKHOUSE_CONVERT_LEFT_ANTI_SEMI_TO_RIGHT_DEFAULT_VALUE: String = "false"
 
+  val GLUTEN_ENABLE_COALESCE_AGGREGATION_UNION: String =
+    CHConf.prefixOf("enable.coalesce.aggregation.union")
+
   def affinityMode: String = {
     SparkEnv.get.conf
       .get(
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
index 40344e96e768..ecd7e5a24107 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
@@ -60,6 +60,7 @@ object CHRuleApi {
       (spark, parserInterface) => new GlutenCacheFilesSqlParser(spark, parserInterface))
     injector.injectParser(
       (spark, parserInterface) => new GlutenClickhouseSqlParser(spark, parserInterface))
+    injector.injectResolutionRule(spark => new CoalesceAggregationUnion(spark))
     injector.injectResolutionRule(spark => new RewriteToDateExpresstionRule(spark))
     injector.injectResolutionRule(spark => new RewriteDateTimestampComparisonRule(spark))
     injector.injectResolutionRule(spark => new CollapseGetJsonObjectExpressionRule(spark))
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala
index 4a592d03a34c..0e59004ba4ee 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala
@@ -16,6 +16,7 @@
  */
 package org.apache.gluten.extension
 
+import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings
 import org.apache.gluten.exception.GlutenNotSupportException
 
 import org.apache.spark.internal.Logging
@@ -93,7 +94,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
 
   case class AggregateAnalzyInfo(originalAggregate: Aggregate) {
 
-    protected def buildAttributesToExpressionsMap(
+    protected def createAttributeToExpressionMap(
         attributes: Seq[Attribute],
         expressions: Seq[Expression]): Map[ExprId, Expression] = {
       val map = new mutable.HashMap[ExprId, Expression]()
@@ -104,7 +105,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
       map.toMap
     }
 
-    protected def replaceAttributes(
+    protected def replaceAttributesInExpression(
        expression: Expression,
        replaceMap: Map[ExprId, 
Expression]): Expression = { expression.transform { @@ -113,20 +114,16 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } } - protected def getFilter(): Option[Filter] = { + protected def extractFilter(): Option[Filter] = { originalAggregate.child match { case filter: Filter => Some(filter) case project @ Project(_, filter: Filter) => Some(filter) case subquery: SubqueryAlias => - logError(s"xxx subquery child. ${subquery.child.getClass}") subquery.child match { case filter: Filter => Some(filter) case project @ Project(_, filter: Filter) => Some(filter) case relation: LogicalRelation => Some(Filter(Literal(true, BooleanType), subquery)) case nestedRelation: SubqueryAlias => - logError( - s"xxx nestedRelation child. ${nestedRelation.child.getClass}" + - s"\n$nestedRelation") if (nestedRelation.child.isInstanceOf[LogicalRelation]) { Some(Filter(Literal(true, BooleanType), nestedRelation)) } else { @@ -139,8 +136,8 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } // Try to make the plan simple, contain only three steps, source, filter, aggregate. - lazy val sourcePlan = { - val filter = getFilter() + lazy val extractedSourcePlan = { + val filter = extractFilter() if (!filter.isDefined) { None } else { @@ -151,9 +148,9 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } } - lazy val filterPlan = { - val filter = getFilter() - if (!filter.isDefined || !sourcePlan.isDefined) { + lazy val constructedFilterPlan = { + val filter = extractFilter() + if (!filter.isDefined || !extractedSourcePlan.isDefined) { None } else { val project = filter.get.child match { @@ -161,19 +158,19 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi case other => None } - val replacedFilter = project match { + val newFilter = project match { case Some(project) => - val replaceMap = buildAttributesToExpressionsMap(project.output, project.child.output) - val replacedCondition = replaceAttributes(filter.get.condition, replaceMap) - Filter(replacedCondition, sourcePlan.get) - case None => filter.get.withNewChildren(Seq(sourcePlan.get)) + val replaceMap = createAttributeToExpressionMap(project.output, project.child.output) + val newCondition = replaceAttributesInExpression(filter.get.condition, replaceMap) + Filter(newCondition, extractedSourcePlan.get) + case None => filter.get.withNewChildren(Seq(extractedSourcePlan.get)) } - Some(replacedFilter) + Some(newFilter) } } - lazy val aggregatePlan = { - if (!filterPlan.isDefined) { + lazy val constructedAggregatePlan = { + if (!constructedFilterPlan.isDefined) { None } else { val project = originalAggregate.child match { @@ -186,20 +183,20 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi case _ => None } - val replacedAggregate = project match { + val newAggregate = project match { case Some(innerProject) => val replaceMap = - buildAttributesToExpressionsMap(innerProject.output, innerProject.projectList) - val groupExpressions = originalAggregate.groupingExpressions.map { - e => replaceAttributes(e, replaceMap) + createAttributeToExpressionMap(innerProject.output, innerProject.projectList) + val newGroupExpressions = originalAggregate.groupingExpressions.map { + e => replaceAttributesInExpression(e, replaceMap) } - val aggregateExpressions = originalAggregate.aggregateExpressions.map { - e => replaceAttributes(e, replaceMap).asInstanceOf[NamedExpression] + val newAggregateExpressions = 
originalAggregate.aggregateExpressions.map {
+            e => replaceAttributesInExpression(e, replaceMap).asInstanceOf[NamedExpression]
           }
-          Aggregate(groupExpressions, aggregateExpressions, filterPlan.get)
-        case None => originalAggregate.withNewChildren(Seq(filterPlan.get))
+          Aggregate(newGroupExpressions, newAggregateExpressions, constructedFilterPlan.get)
+        case None => originalAggregate.withNewChildren(Seq(constructedFilterPlan.get))
       }
-      Some(replacedAggregate)
+      Some(newAggregate)
     }
   }
 
@@ -208,7 +205,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
     }
 
     // The output results which are not aggregate expressions.
-    lazy val resultGroupingExpressions = aggregatePlan match {
+    lazy val resultGroupingExpressions = constructedAggregatePlan match {
       case Some(agg) =>
         agg.asInstanceOf[Aggregate].aggregateExpressions.filter(e => !hasAggregateExpression(e))
       case None => Seq.empty
@@ -223,7 +220,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
       // `select k1 + k2, count(1) from t group by k1, k2`.
       resultGroupingExpressions.map {
         e =>
-          val aggregate = aggregatePlan.get.asInstanceOf[Aggregate]
+          val aggregate = constructedAggregatePlan.get.asInstanceOf[Aggregate]
           e match {
             case literal @ Alias(_: Literal, _) =>
               var idx = aggregate.groupingExpressions.indexOf(e)
@@ -254,20 +251,19 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
   case class AnalyzedPlan(plan: LogicalPlan, analyzedInfo: Option[AggregateAnalzyInfo])
 
   override def apply(plan: LogicalPlan): LogicalPlan = {
-    if (plan.resolved) {
-      logError(s"xxx visit plan:\n$plan")
-      val newPlan = visitPlan(plan)
-      logError(s"xxx output attributes:\n${newPlan.output}\n${plan.output}")
-      logError(s"xxx rewrite plan:\n$newPlan")
-      newPlan
+    if (
+      plan.resolved && spark.conf
+        .get(CHBackendSettings.GLUTEN_ENABLE_COALESCE_AGGREGATION_UNION, "true")
+        .toBoolean
+    ) {
+      visitPlan(plan)
     } else {
-      logError(s"xxx plan not resolved:\n$plan")
       plan
     }
   }
 
   def visitPlan(plan: LogicalPlan): LogicalPlan = {
-    val newPlan = plan match {
+    plan match {
       case union: Union =>
         val planGroups = groupStructureMatchedAggregate(union)
         val newUnionClauses = planGroups.map {
@@ -276,42 +272,40 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
               groupedPlans.head.plan
             } else {
               val firstAggregateAnalzyInfo = groupedPlans.head.analyzedInfo.get
-              val aggregates = groupedPlans.map(_.analyzedInfo.get.aggregatePlan.get)
+              val aggregates = groupedPlans.map(_.analyzedInfo.get.constructedAggregatePlan.get)
               val replaceAttributes = collectReplaceAttributes(aggregates)
               val filterConditions = buildAggregateCasesConditions(aggregates, replaceAttributes)
-              logError(s"xxx filterConditions. ${filterConditions.length},\n$filterConditions")
               val firstAggregateFilter =
-                firstAggregateAnalzyInfo.filterPlan.get.asInstanceOf[Filter]
+                firstAggregateAnalzyInfo.constructedFilterPlan.get.asInstanceOf[Filter]
 
-              // Add a filter step with condition `cond1 or cond2 or ...`, `cond_i` comes from each
-              // union clause. Apply this filter on the source plan.
+              // Add a filter step with condition `cond1 or cond2 or ...`, where `cond_i` comes
+              // from each union clause. Apply this filter on the source plan.
               val unionFilter = Filter(
                 buildUnionConditionForAggregateSource(filterConditions),
-                firstAggregateAnalzyInfo.sourcePlan.get)
+                firstAggregateAnalzyInfo.extractedSourcePlan.get)
 
               // Wrap all the attributes into a single structure attribute. 
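+              // Illustrative sketch of the whole rewrite (hypothetical table and
+              // column names, not taken from the actual plans):
+              //   select a, count(x) from t where c = 1 group by a
+              //   union all
+              //   select a, count(x) from t where c = 2 group by a
+              // becomes, roughly, a single scan filtered by `c = 1 or c = 2`,
+              // where each row is folded into one struct per clause whose
+              // condition it satisfies, the structs are collected into an
+              // array, the array is exploded, null entries are dropped, and one
+              // aggregate groups by the original keys plus the clause index.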
-              val wrappedAttributesProject = buildStructWrapperProject(
+              val wrappedAttributesProject = buildProjectFoldIntoStruct(
                 unionFilter,
                 groupedPlans,
                 filterConditions,
                 replaceAttributes)
 
-              // Build an array which element are response to each union clause.
-              val arrayProject = buildArrayProject(wrappedAttributesProject, filterConditions)
+              // Build an array whose elements correspond to the union clauses.
+              val arrayProject = buildProjectBranchArray(wrappedAttributesProject, filterConditions)
 
               // Explode the array
-              val explode = buildArrayExplode(arrayProject)
+              val explode = buildExplodeBranchArray(arrayProject)
 
               // Null value means that the union clause does not have the corresponding data.
               val notNullFilter = Filter(IsNotNull(explode.output.head), explode)
 
-              // Destruct the struct attribute.
-              val destructStructProject = buildDestructStructProject(notNullFilter)
+              // Unfold the struct attribute.
+              val destructStructProject = buildProjectUnfoldStruct(notNullFilter)
 
               buildAggregateWithGroupId(destructStructProject, groupedPlans)
             }
         }
-        logError(s"xxx newUnionClauses. ${newUnionClauses.length},\n$newUnionClauses")
 
         val coalesePlan = if (newUnionClauses.length == 1) {
           newUnionClauses.head
         } else {
@@ -321,7 +315,6 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
           }
           Union(firstUnionChild, newUnionClauses.last)
         }
-        logError(s"xxx coalesePlan:$coalesePlan")
 
-        // We need to keep the output atrributes same as the original plan.
+        // We need to keep the output attributes the same as the original plan.
         val outputAttrPairs = coalesePlan.output.zip(union.output)
@@ -340,42 +333,23 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
       }
       case _ => plan.withNewChildren(plan.children.map(visitPlan))
     }
-    // newPlan.copyTagsFrom(plan)
-    newPlan
   }
 
   def isSupportedAggregate(info: AggregateAnalzyInfo): Boolean = {
-    if (info.hasAggregateWithFilter) {
-      return false
-    }
-    if (!info.aggregatePlan.isDefined) {
-      return false
-    }
-
-    if (info.positionInGroupingKeys.exists(_ < 0)) {
-      return false
-    }
-
-    // `agg_fun1(x) + agg_fun2(y)` is supported, but `agg_fun1(x) + y` is not supported.
-    if (
-      info.originalAggregate.aggregateExpressions.exists {
-        e =>
-          val innerExpr = removeAlias(e)
-          if (hasAggregateExpression(innerExpr)) {
-            !innerExpr.isInstanceOf[AggregateExpression] &&
-            !innerExpr.children.forall(e => isAggregateExpression(e))
-          } else {
-            false
-          }
-      }
-    ) {
-      return false
-    }
-
-    if (!info.aggregatePlan.isDefined) {
-      return false
+    !info.hasAggregateWithFilter &&
+    info.constructedAggregatePlan.isDefined &&
+    info.positionInGroupingKeys.forall(_ >= 0) &&
+    info.originalAggregate.aggregateExpressions.forall {
+      e =>
+        val innerExpr = removeAlias(e)
+        // `agg_fun1(x) + agg_fun2(y)` is supported, but `agg_fun1(x) + y` is not supported.
+        if (hasAggregateExpression(innerExpr)) {
+          innerExpr.isInstanceOf[AggregateExpression] ||
+          innerExpr.children.forall(e => isAggregateExpression(e))
+        } else {
+          true
+        }
     }
-    true
   }
 
   /**
@@ -392,61 +366,28 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
    *   True if the two instances have the same structure, false otherwise.
    */
   def areStructureMatchedAggregate(l: AggregateAnalzyInfo, r: AggregateAnalzyInfo): Boolean = {
-    val lAggregate = l.aggregatePlan.get.asInstanceOf[Aggregate]
-    val rAggregate = r.aggregatePlan.get.asInstanceOf[Aggregate]
-
-    // Check aggregate result expressions. need same schema. 
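+    // An illustrative example of "structure matched" clauses (hypothetical SQL):
+    //   select a, count(x) from t where c = 1 group by a
+    //   select a, count(x) from t where c = 2 group by a
+    // match each other, while `select a, sum(x) from t group by a` would not
+    // (different aggregate function), nor would a clause reading another table.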
- if (lAggregate.aggregateExpressions.length != rAggregate.aggregateExpressions.length) { - logError(s"xxx not equal 1") - - return false - } - val allAggregateExpressionAreSame = - lAggregate.aggregateExpressions.zip(rAggregate.aggregateExpressions).forall { - case (lExpr, rExpr) => - if (!lExpr.dataType.equals(rExpr.dataType)) { - false - } else { - (hasAggregateExpression(lExpr), hasAggregateExpression(rExpr)) match { - case (true, true) => - areStructureMatchedExpressions(lExpr, rExpr) - case (false, true) => false - case (true, false) => false - case (false, false) => true - } + val lAggregate = l.constructedAggregatePlan.get.asInstanceOf[Aggregate] + val rAggregate = r.constructedAggregatePlan.get.asInstanceOf[Aggregate] + lAggregate.aggregateExpressions.length == rAggregate.aggregateExpressions.length && + lAggregate.aggregateExpressions.zip(rAggregate.aggregateExpressions).forall { + case (lExpr, rExpr) => + if (!lExpr.dataType.equals(rExpr.dataType)) { + false + } else { + (hasAggregateExpression(lExpr), hasAggregateExpression(rExpr)) match { + case (true, true) => areStructureMatchedExpressions(lExpr, rExpr) + case (false, true) => false + case (true, false) => false + case (false, false) => true } - } - if (!allAggregateExpressionAreSame) { - return false - } - - // Check grouping expressions, need same schema. - if (lAggregate.groupingExpressions.length != rAggregate.groupingExpressions.length) { - logError(s"xxx not equal 3") - return false - } - if ( - l.positionInGroupingKeys.length != - r.positionInGroupingKeys.length - ) { - logError(s"xxx not equal 4") - return false - } - val allSameGroupingKeysRef = l.positionInGroupingKeys - .zip(r.positionInGroupingKeys) - .forall { case (lPos, rPos) => lPos == rPos } - if (!allSameGroupingKeysRef) { - logError(s"xxx not equal 5") - return false - } - - // Must come from same source. - if (!areSameAggregateSource(l.sourcePlan.get, r.sourcePlan.get)) { - logError(s"xxx not same source. 
${l.sourcePlan.get}\n${r.sourcePlan.get}")
-      return false
-    }
-
-    true
+          }
+      } &&
+      lAggregate.groupingExpressions.length == rAggregate.groupingExpressions.length &&
+      l.positionInGroupingKeys.length == r.positionInGroupingKeys.length &&
+      l.positionInGroupingKeys.zip(r.positionInGroupingKeys).forall {
+        case (lPos, rPos) => lPos == rPos
+      } &&
+      areSameAggregateSource(l.extractedSourcePlan.get, r.extractedSourcePlan.get)
   }
 
   /*
@@ -493,27 +434,46 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
   }
 
   def groupStructureMatchedAggregate(union: Union): ArrayBuffer[ArrayBuffer[AnalyzedPlan]] = {
-    val groupResults = ArrayBuffer[ArrayBuffer[AnalyzedPlan]]()
-    collectAllUnionClauses(union).foreach {
-      case agg: Aggregate =>
-        val analyzedInfo = AggregateAnalzyInfo(agg)
-        if (isSupportedAggregate(analyzedInfo)) {
-          if (groupResults.isEmpty) {
-            groupResults += ArrayBuffer(AnalyzedPlan(agg, Some(analyzedInfo)))
-          } else {
-            val idx = findStructureMatchedAggregate(groupResults, analyzedInfo)
-            if (idx != -1) {
-              groupResults(idx) += AnalyzedPlan(agg, Some(analyzedInfo))
-            } else {
-              groupResults += ArrayBuffer(AnalyzedPlan(agg, Some(analyzedInfo)))
-            }
-          }
-        } else {
-          logError(s"xxx not supported. $agg")
-          groupResults += ArrayBuffer(AnalyzedPlan(agg, None))
-        }
-      case other =>
-        groupResults += ArrayBuffer(AnalyzedPlan(other, None))
-    }
+
+    def tryPutToGroup(
+        groupResults: ArrayBuffer[ArrayBuffer[AnalyzedPlan]],
+        agg: Aggregate): Unit = {
+      val analyzedInfo = AggregateAnalzyInfo(agg)
+      if (isSupportedAggregate(analyzedInfo)) {
+        if (groupResults.isEmpty) {
+          groupResults += ArrayBuffer(
+            AnalyzedPlan(analyzedInfo.originalAggregate, Some(analyzedInfo)))
+        } else {
+          val idx = findStructureMatchedAggregate(groupResults, analyzedInfo)
+          if (idx != -1) {
+            groupResults(idx) += AnalyzedPlan(
+              analyzedInfo.constructedAggregatePlan.get,
+              Some(analyzedInfo))
+          } else {
+            groupResults += ArrayBuffer(
+              AnalyzedPlan(analyzedInfo.constructedAggregatePlan.get, Some(analyzedInfo)))
+          }
+        }
+      } else {
+        val rewrittenPlan = visitPlan(agg)
+        groupResults += ArrayBuffer(AnalyzedPlan(rewrittenPlan, None))
+      }
+    }
+
+    val groupResults = ArrayBuffer[ArrayBuffer[AnalyzedPlan]]()
+    collectAllUnionClauses(union).foreach {
+      case project @ Project(projectList, agg: Aggregate) =>
+        if (projectList.forall(e => e.isInstanceOf[Alias])) {
+          tryPutToGroup(groupResults, agg)
+        } else {
+          val rewrittenPlan = visitPlan(project)
+          groupResults += ArrayBuffer(AnalyzedPlan(rewrittenPlan, None))
+        }
+      case agg: Aggregate =>
+        tryPutToGroup(groupResults, agg)
+      case other =>
+        val rewrittenPlan = visitPlan(other)
+        groupResults += ArrayBuffer(AnalyzedPlan(rewrittenPlan, None))
+    }
     groupResults
   }
@@ -521,17 +481,15 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
   def areStructureMatchedExpressions(l: Expression, r: Expression): Boolean = {
     (l, r) match {
       case (lAttr: Attribute, rAttr: Attribute) =>
-        logError(s"xxx attr equal: ${lAttr.qualifiedName}, ${rAttr.qualifiedName}")
-        lAttr.qualifiedName == rAttr.qualifiedName
+        // The qualifier may be overwritten by a subquery alias, making this check fail. 
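+        // For example (hypothetical query), `select a from t` unioned with
+        // `select a from (select a from t) s` resolves the same column as
+        // `t.a` in one clause and `s.a` in the other, so the name-based
+        // comparison below can miss otherwise identical expressions.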
+        lAttr.qualifiedName.equals(rAttr.qualifiedName)
       case (lLiteral: Literal, rLiteral: Literal) =>
         lLiteral.value.equals(rLiteral.value)
       case _ =>
-        if (l.children.length != r.children.length || l.getClass != r.getClass) {
-          false
-        } else {
-          l.children.zip(r.children).forall {
-            case (lChild, rChild) => areStructureMatchedExpressions(lChild, rChild)
-          }
+        l.children.length == r.children.length &&
+        l.getClass == r.getClass &&
+        l.children.zip(r.children).forall {
+          case (lChild, rChild) => areStructureMatchedExpressions(lChild, rChild)
         }
     }
   }
@@ -544,19 +502,22 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
       case (lRel: LogicalRelation, rRel: LogicalRelation) =>
         val lTable = lRel.catalogTable.map(_.identifier.unquotedString).getOrElse("")
         val rTable = rRel.catalogTable.map(_.identifier.unquotedString).getOrElse("")
-        logError(s"xxx table equal: $lTable, $rTable")
         lTable.equals(rTable) && lTable.nonEmpty
+      case (lRef: CTERelationRef, rRef: CTERelationRef) =>
+        lRef.cteId == rRef.cteId
       case (lSubQuery: SubqueryAlias, rSubQuery: SubqueryAlias) =>
         areSameAggregateSource(lSubQuery.child, rSubQuery.child)
-      case (lChild, rChild) => false
+      case (lChild, rChild) =>
+        false
     }
   }
 
-  def collectReplaceAttributes(groupedPlans: ArrayBuffer[LogicalPlan]): Map[String, Attribute] = {
-    def findFirstRelation(plan: LogicalPlan): LogicalRelation = {
-      if (plan.isInstanceOf[LogicalRelation]) {
-        return plan.asInstanceOf[LogicalRelation]
+  def collectReplaceAttributes(
+      groupedPlans: ArrayBuffer[LogicalPlan]): Map[String, Attribute] = {
+    def findFirstRelation(plan: LogicalPlan): LogicalPlan = {
+      if (plan.isInstanceOf[LogicalRelation] || plan.isInstanceOf[CTERelationRef]) {
+        return plan
       } else if (plan.children.isEmpty) {
         return null
       } else {
@@ -586,7 +547,9 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
     replaceMap.toMap
   }
 
-  def replaceAttributes(expression: Expression, replaceMap: Map[String, Attribute]): Expression = {
+  def replaceAttributesInExpression(
+      expression: Expression,
+      replaceMap: Map[String, Attribute]): Expression = {
     expression.transform {
       case attr: Attribute =>
         replaceMap.get(attr.qualifiedName) match {
@@ -602,7 +565,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
     groupedPlans.map {
       plan =>
         val filter = plan.asInstanceOf[Aggregate].child.asInstanceOf[Filter]
-        replaceAttributes(filter.condition, replaceMap)
+        replaceAttributesInExpression(filter.condition, replaceMap)
     }
   }
@@ -619,13 +582,13 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
     groupedPlans.zipWithIndex.foreach {
       case (aggregateCase, case_index) =>
         val analyzedInfo = aggregateCase.analyzedInfo.get
-        val aggregate = analyzedInfo.aggregatePlan.get.asInstanceOf[Aggregate]
+        val aggregate = analyzedInfo.constructedAggregatePlan.get.asInstanceOf[Aggregate]
         val structFields = ArrayBuffer[Expression]()
         var fieldIndex: Int = 0
         aggregate.groupingExpressions.foreach {
           e =>
             structFields += Literal(UTF8String.fromString(s"$structPrefix$fieldIndex"), StringType)
-            structFields += replaceAttributes(e, replaceMap)
+            structFields += replaceAttributesInExpression(e, replaceMap)
             fieldIndex += 1
         }
         for (i <- 0 until analyzedInfo.positionInGroupingKeys.length) {
           val position = analyzedInfo.positionInGroupingKeys(i)
           if (position >= fieldIndex) {
             val expr = analyzedInfo.resultGroupingExpressions(i)
             structFields += 
Literal(UTF8String.fromString(s"$structPrefix$fieldIndex"), StringType) - structFields += replaceAttributes(analyzedInfo.resultGroupingExpressions(i), replaceMap) + structFields += replaceAttributesInExpression( + analyzedInfo.resultGroupingExpressions(i), + replaceMap) fieldIndex += 1 } } @@ -652,7 +617,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi structFields += Literal( UTF8String.fromString(s"$structPrefix$fieldIndex"), StringType) - structFields += replaceAttributes(child, replaceMap) + structFields += replaceAttributesInExpression(child, replaceMap) fieldIndex += 1 } case combineAgg if hasAggregateExpression(combineAgg) => @@ -663,7 +628,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi structFields += Literal( UTF8String.fromString(s"$structPrefix$fieldIndex"), StringType) - structFields += replaceAttributes(other, replaceMap) + structFields += replaceAttributesInExpression(other, replaceMap) fieldIndex += 1 } } @@ -676,7 +641,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi structAttributes } - def buildStructWrapperProject( + def buildProjectFoldIntoStruct( child: LogicalPlan, groupedPlans: ArrayBuffer[AnalyzedPlan], conditions: ArrayBuffer[Expression], @@ -690,7 +655,9 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi Project(ifAttributes, child) } - def buildArrayProject(child: LogicalPlan, conditions: ArrayBuffer[Expression]): LogicalPlan = { + def buildProjectBranchArray( + child: LogicalPlan, + conditions: ArrayBuffer[Expression]): LogicalPlan = { assert( child.output.length == conditions.length, s"Expected same length of output and conditions") @@ -698,7 +665,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi Project(Seq(array), child) } - def buildArrayExplode(child: LogicalPlan): LogicalPlan = { + def buildExplodeBranchArray(child: LogicalPlan): LogicalPlan = { assert(child.output.length == 1, s"Expected single output from $child") val array = child.output.head.asInstanceOf[Expression] assert(array.dataType.isInstanceOf[ArrayType], s"Expected ArrayType from $array") @@ -714,18 +681,6 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi child) } - def buildGroupConditions( - groupedPlans: ArrayBuffer[LogicalPlan], - replaceMap: Map[String, Attribute]): (ArrayBuffer[Expression], Expression) = { - val conditions = groupedPlans.map { - plan => - val filter = plan.asInstanceOf[Aggregate].child.asInstanceOf[Filter] - replaceAttributes(filter.condition, replaceMap) - } - val unionCond = conditions.reduce(Or) - (conditions, unionCond) - } - def makeAlias(e: Expression, name: String): NamedExpression = { Alias(e, name)( NamedExpression.newExprId, @@ -737,7 +692,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi Seq.empty) } - def buildDestructStructProject(child: LogicalPlan): LogicalPlan = { + def buildProjectUnfoldStruct(child: LogicalPlan): LogicalPlan = { assert(child.output.length == 1, s"Expected single output from $child") val structedData = child.output.head assert( @@ -759,7 +714,8 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi groupedPlans: ArrayBuffer[AnalyzedPlan]): LogicalPlan = { val attributes = child.output val firstAggregateAnalzyInfo = groupedPlans.head.analyzedInfo.get - val aggregateTemplate = firstAggregateAnalzyInfo.aggregatePlan.get.asInstanceOf[Aggregate] + val 
aggregateTemplate = + firstAggregateAnalzyInfo.constructedAggregatePlan.get.asInstanceOf[Aggregate] val analyzedInfo = groupedPlans.head.analyzedInfo.get val totalGroupingExpressionsCount = @@ -779,11 +735,10 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi e => removeAlias(e) match { case aggExpr if hasAggregateExpression(aggExpr) => - aggregateExpressions += makeAlias( - constructAggregateExpression(aggExpr, attributes, aggregateExpressionIndex), - e.name) - .asInstanceOf[NamedExpression] - aggregateExpressionIndex += aggExpr.children.length + val (newAggExpr, count) = + constructAggregateExpression(aggExpr, attributes, aggregateExpressionIndex) + aggregateExpressions += makeAlias(newAggExpr, e.name).asInstanceOf[NamedExpression] + aggregateExpressionIndex += count case other => val position = normalExpressionPosition(normalExpressionCount) val attr = attributes(position) @@ -798,7 +753,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi def constructAggregateExpression( aggExpr: Expression, attributes: Seq[Attribute], - index: Int): Expression = { + index: Int): (Expression, Int) = { aggExpr match { case singleAggExpr: AggregateExpression => val aggFunc = singleAggExpr.aggregateFunction @@ -808,18 +763,24 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } val newAggFunc = aggFunc.withNewChildren(newAggFuncArgs).asInstanceOf[AggregateFunction] - AggregateExpression( + val res = AggregateExpression( newAggFunc, singleAggExpr.mode, singleAggExpr.isDistinct, singleAggExpr.filter, singleAggExpr.resultId) + (res, 1) case combineAggExpr if hasAggregateExpression(combineAggExpr) => - combineAggExpr.withNewChildren( - combineAggExpr.children.map(constructAggregateExpression(_, attributes, index))) - case _ => - val normalExpr = attributes(index) - normalExpr + val childrenExpressions = ArrayBuffer[Expression]() + var totalCount = 0 + combineAggExpr.children.foreach { + child => + val (expr, count) = constructAggregateExpression(child, attributes, totalCount + index) + childrenExpressions += expr + totalCount += count + } + (combineAggExpr.withNewChildren(childrenExpressions.toSeq), totalCount) + case _ => (attributes(index), 1) } } } diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala index ae5b4b7d5a92..456d95d4b515 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala @@ -16,9 +16,6 @@ */ package org.apache.gluten.execution -import org.apache.gluten.config.GlutenConfig -import org.apache.gluten.utils.UTSystemParameters - import org.apache.spark.SparkConf import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig @@ -42,7 +39,6 @@ class GlutenCoalesceAggregationUnionSuite extends GlutenClickHouseWholeStageTran .set("spark.databricks.delta.properties.defaults.checkpointInterval", "5") .set("spark.databricks.delta.stalenessLimit", "3600000") .set(ClickHouseConfig.CLICKHOUSE_WORKER_ID, "1") - .set(GlutenConfig.GLUTEN_LIB_PATH, UTSystemParameters.clickHouseLibPath) .set("spark.gluten.sql.columnar.iterator", "true") .set("spark.gluten.sql.columnar.hashagg.enablefinal", "true") 
.set("spark.gluten.sql.enable.native.validation", "false") @@ -154,12 +150,12 @@ class GlutenCoalesceAggregationUnionSuite extends GlutenClickHouseWholeStageTran val sql = """ |select * from ( - | select a, 1 as t, count(x) + sum(y) as y from coalesce_union_t1 where b % 3 = 0 + | select a, 1 as t, count(x) + sum(y) as n from coalesce_union_t1 where b % 3 = 0 | group by a | union all - | select a, 2 as t, count(x) + sum(y) as y from coalesce_union_t1 where b % 3 = 1 + | select a, 2 as t, count(x) + sum(y) as n from coalesce_union_t1 where b % 3 = 1 | group by a - |) order by a, t, y + |) order by a, t, n |""".stripMargin compareResultsAgainstVanillaSpark(sql, true, checkNoUnion, true) } @@ -240,6 +236,40 @@ class GlutenCoalesceAggregationUnionSuite extends GlutenClickHouseWholeStageTran compareResultsAgainstVanillaSpark(sql, true, checkNoUnion, true) } + test("coalesce aggregation union. case 10") { + val sql = + """ + |select * from ( + | select a as a, sum(y) as y from ( + | select concat(a, "x") as a, y from coalesce_union_t1 where b % 3 = 0 + | ) group by a + | union all + | select x as a , sum(y) as y from coalesce_union_t1 where b % 3 = 1 + | group by x + |) order by a, y + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, checkNoUnion, true) + } + + test("coalesce aggregation union. case 11") { + val sql = + """ + |select t1.a, t1.y, t2.x from ( + | select a as a, sum(y) as y from ( + | select concat(a, "x") as a, y from coalesce_union_t1 where b % 3 = 0 + | ) group by a + | union all + | select x as a , sum(y) as y from coalesce_union_t1 where b % 3 = 1 + | group by x + |) as t1 + |left join ( + | select a, x from coalesce_union_t2 + |) as t2 + |on t1.a = t2.a + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, checkNoUnion, true) + } + test("no coalesce aggregation union. case 1") { val sql = """