From 5826c5e6a9feb913ef3704f816f4b000a819640a Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Wed, 22 Jan 2025 14:23:14 +0800 Subject: [PATCH 1/9] stage --- .../extension/CoalesceAggregationUnion.scala | 401 ++++++++++++++++++ 1 file changed, 401 insertions(+) create mode 100644 backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala new file mode 100644 index 000000000000..eb8af1020ea3 --- /dev/null +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala @@ -0,0 +1,401 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.extension + +import org.apache.gluten.exception.GlutenNotSupportException + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] with Logging { + override def apply(plan: LogicalPlan): LogicalPlan = { + logError(s"xxx plan is resolved: ${plan.resolved}") + if (plan.resolved) { + logError(s"xxx visit plan:\n$plan") + visitPlan(plan) + } else { + plan + } + } + + def visitPlan(plan: LogicalPlan): LogicalPlan = { + val newPlan = plan match { + case union: Union => + logError(s"xxx is union node, children: ${union.children}") + val groups = groupUnionAggregations(union) + val rewrittenGroups = groups.map { + group => + if (group.length == 1) { + group.head + } else { + val replaceMap = collectReplaceAttributes(group) + val filter = group.head.asInstanceOf[Aggregate].child.asInstanceOf[Filter] + logError(s"xxx replace map: $replaceMap") + val (groupConds, unionCond) = buildGroupConditions(group, replaceMap) + logError(s"xxx replace condition: $groupConds, $unionCond") + val explodeResult = explodeBranches(group, groupConds, replaceMap) + val unionFilter = Filter(unionCond, filter.child) + val project = Project(Seq(explodeResult), unionFilter) + logError(s"xxx new project\n$project") + val nullFilter = Filter(IsNotNull(explodeResult), project) + logError(s"xxx new null filter\n$nullFilter") + group.head + } + } + if (rewrittenGroups.length == 1) { + rewrittenGroups.head + } else { + 
union.withNewChildren(rewrittenGroups) + } + case project: Project => + project.projectList.foreach( + e => logError(s"xxx project expression: ${removeAlias(e).getClass}, $e")) + project.withNewChildren(project.children.map(visitPlan)) + case _ => plan.withNewChildren(plan.children.map(visitPlan)) + } + newPlan.copyTagsFrom(plan) + newPlan + } + + def removeAlias(e: Expression): Expression = { + e match { + case alias: Alias => alias.child + case _ => e + } + } + + def groupUnionAggregations(union: Union): ArrayBuffer[ArrayBuffer[LogicalPlan]] = { + val groupResults = ArrayBuffer[ArrayBuffer[LogicalPlan]]() + union.children.foreach { + case agg @ Aggregate(_, _, filter: Filter) => + agg.groupingExpressions.foreach( + e => logError(s"xxx grouping expression: ${e.getClass}, $e")) + agg.aggregateExpressions.foreach( + e => logError(s"xxx aggregate expression: ${e.getClass}, $e")) + if ( + groupResults.isEmpty && agg.aggregateExpressions.exists( + e => + removeAlias(e) match { + case aggExpr: AggregateExpression => !aggExpr.filter.isDefined + case _ => true + }) + ) { + groupResults += ArrayBuffer(agg) + } else { + groupResults.find { + group => + group.head match { + case toMatchAgg @ Aggregate(_, _, toMatchFilter: Filter) => + areSameSchemaAggregations(toMatchAgg, agg) && + areSameAggregationSource(toMatchFilter.child, filter.child) + case _ => false + } + } match { + case Some(group) => + group += agg.asInstanceOf[LogicalPlan] + case None => + val newGroup = ArrayBuffer(agg.asInstanceOf[LogicalPlan]) + groupResults += newGroup + } + } + case other => + // Other plans will be remained as union clauses. + val singlePlan = ArrayBuffer(other) + groupResults += singlePlan + } + groupResults + } + + def areEqualExpressions(l: Expression, r: Expression): Boolean = { + (l, r) match { + case (lAttr: Attribute, rAttr: Attribute) => + lAttr.qualifiedName == rAttr.qualifiedName + case (lLiteral: Literal, rLiteral: Literal) => + lLiteral.value.equals(rLiteral.value) + case _ => + if (l.children.length != r.children.length || l.getClass != r.getClass) { + false + } else { + l.children.zip(r.children).forall { + case (lChild, rChild) => areEqualExpressions(lChild, rChild) + } + } + } + } + + def areSameSchemaAggregations(agg1: Aggregate, agg2: Aggregate): Boolean = { + if ( + agg1.aggregateExpressions.length != agg2.aggregateExpressions.length || + agg1.groupingExpressions.length != agg2.groupingExpressions.length + ) { + false + } else { + agg1.aggregateExpressions.zip(agg2.aggregateExpressions).forall { + case (_ @Alias(_: Literal, lName), _ @Alias(_: Literal, rName)) => + lName.equals(rName) + case (l, r) => areEqualExpressions(l, r) + } && + agg1.groupingExpressions.zip(agg2.groupingExpressions).forall { + case (l, r) => areEqualExpressions(l, r) + } + } + } + + def areSameAggregationSource(lPlan: LogicalPlan, rPlan: LogicalPlan): Boolean = { + if (lPlan.children.length != rPlan.children.length || lPlan.getClass != rPlan.getClass) { + false + } else { + lPlan.children.zip(rPlan.children).forall { + case (lRel: LogicalRelation, rRel: LogicalRelation) => + val lTable = lRel.catalogTable.map(_.identifier.unquotedString).getOrElse("") + val rTable = rRel.catalogTable.map(_.identifier.unquotedString).getOrElse("") + lRel.output.foreach(attr => logError(s"xxx l attr: $attr, ${attr.qualifiedName}")) + rRel.output.foreach(attr => logError(s"xxx r attr: $attr, ${attr.qualifiedName}")) + logError(s"xxx table: $lTable, $rTable") + lTable.equals(rTable) && lTable.nonEmpty + case (lSubQuery: SubqueryAlias, 
rSubQuery: SubqueryAlias) => + areSameAggregationSource(lSubQuery.child, rSubQuery.child) + case (lChild, rChild) => false + } + } + } + + def collectReplaceAttributes(group: ArrayBuffer[LogicalPlan]): Map[String, Attribute] = { + def findFirstRelation(plan: LogicalPlan): LogicalRelation = { + if (plan.isInstanceOf[LogicalRelation]) { + return plan.asInstanceOf[LogicalRelation] + } else if (plan.children.isEmpty) { + return null + } else { + plan.children.foreach { + child => + val rel = findFirstRelation(child) + if (rel != null) { + return rel + } + } + return null + } + } + val replaceMap = new mutable.HashMap[String, Attribute]() + val firstFilter = group.head.asInstanceOf[Aggregate].child.asInstanceOf[Filter] + val qualifierPrefix = + firstFilter.output.find(e => e.qualifier.nonEmpty).head.qualifier.mkString(".") + val firstRelation = findFirstRelation(firstFilter.child) + if (firstRelation == null) { + throw new GlutenNotSupportException(s"Not found relation in plan: $firstFilter") + } + firstRelation.output.foreach { + attr => + val qualifiedName = s"$qualifierPrefix.${attr.name}" + replaceMap.put(qualifiedName, attr) + } + replaceMap.toMap + } + + def replaceAttributes(expression: Expression, replaceMap: Map[String, Attribute]): Expression = { + expression.transform { + case attr: Attribute => + replaceMap.get(attr.qualifiedName) match { + case Some(replaceAttr) => replaceAttr + case None => attr + } + } + } + + def buildGroupConditions( + group: ArrayBuffer[LogicalPlan], + replaceMap: Map[String, Attribute]): (ArrayBuffer[Expression], Expression) = { + val conditions = group.map { + plan => + val filter = plan.asInstanceOf[Aggregate].child.asInstanceOf[Filter] + replaceAttributes(filter.condition, replaceMap) + } + val unionCond = conditions.reduce(Or) + (conditions, unionCond) + } + + def buildStructType(attributes: Seq[Attribute]): StructType = { + val fields = attributes.map { + attr => StructField(attr.name, attr.dataType, attr.nullable, attr.metadata) + } + StructType(fields :+ StructField(CoalesceAggregationUnion.UNION_TAG_FIELD, IntegerType, false)) + } + + def buildStructType1( + group: ArrayBuffer[LogicalPlan], + replaceMap: Map[String, Attribute]): StructType = { + var attributesSet: AttributeSet = null + val attributes = group.foreach { + plan => + val agg = plan.asInstanceOf[Aggregate] + val filter = agg.child.asInstanceOf[Filter] + agg.groupingExpressions + .map(e => replaceAttributes(e, replaceMap)) + .foreach( + e => + if (attributesSet == null) { + attributesSet = e.references + } else { + attributesSet ++= e.references + }) + agg.aggregateExpressions + .map(e => replaceAttributes(e, replaceMap)) + .foreach(e => attributesSet ++= e.references) + filter.output + .map(e => replaceAttributes(e, replaceMap)) + .foreach(e => attributesSet ++= e.references) + } + logError(s"xxx all needed attributes: $attributesSet") + val fields = ArrayBuffer[StructField]() + attributesSet.foreach { + attr => fields += StructField(attr.name, attr.dataType, attr.nullable, attr.metadata) + } + StructType(fields :+ StructField(CoalesceAggregationUnion.UNION_TAG_FIELD, IntegerType, false)) + } + + def collectGroupRequiredAttributes( + group: ArrayBuffer[LogicalPlan], + replaceMap: Map[String, Attribute]): AttributeSet = { + var attributesSet: AttributeSet = null + group.foreach { + plan => + val agg = plan.asInstanceOf[Aggregate] + val filter = agg.child.asInstanceOf[Filter] + agg.groupingExpressions + .map(e => replaceAttributes(e, replaceMap)) + .foreach( + e => + if (attributesSet == 
null) { + attributesSet = e.references + } else { + attributesSet ++= e.references + }) + agg.aggregateExpressions + .map(e => replaceAttributes(e, replaceMap)) + .foreach(e => attributesSet ++= e.references) + filter.output + .map(e => replaceAttributes(e, replaceMap)) + .foreach(e => attributesSet ++= e.references) + } + attributesSet + } + + def concatAllNotAggregateExpressionsInAggregate( + group: ArrayBuffer[LogicalPlan], + replaceMap: Map[String, Attribute]): Seq[NamedExpression] = { + + def makeAlias(e: Expression, name: String): NamedExpression = { + Alias(e, name)( + NamedExpression.newExprId, + e match { + case ne: NamedExpression => ne.qualifier + case _ => Seq.empty + }, + None, + Seq.empty) + } + + val projectionExpressions = ArrayBuffer[NamedExpression]() + val columnNamePrefix = "agg_expr_" + val structFieldNamePrefix = "field_" + var groupIndex: Int = 0 + group.foreach { + plan => + var fieldIndex: Int = 0 + val agg = plan.asInstanceOf[Aggregate] + val structFields = ArrayBuffer[Expression]() + agg.groupingExpressions.map(e => replaceAttributes(e, replaceMap)).foreach { + e => + structFields += Literal( + UTF8String.fromString(s"$structFieldNamePrefix$fieldIndex"), + StringType) + structFields += e + fieldIndex += 1 + } + agg.aggregateExpressions.map(e => replaceAttributes(e, replaceMap)).foreach { + e => + removeAlias(e) match { + case aggExpr: AggregateExpression => + val aggFunc = aggExpr.aggregateFunction + aggFunc.children.foreach { + child => + structFields += Literal( + UTF8String.fromString(s"$structFieldNamePrefix$fieldIndex"), + StringType) + structFields += child + fieldIndex += 1 + } + case notAggExpr => + structFields += Literal( + UTF8String.fromString(s"$structFieldNamePrefix$fieldIndex"), + StringType) + structFields += e + fieldIndex += 1 + } + } + projectionExpressions += makeAlias( + CreateNamedStruct(structFields), + s"$columnNamePrefix$groupIndex") + groupIndex += 1 + } + projectionExpressions + } + + def buildBranchesArray( + group: ArrayBuffer[LogicalPlan], + groupConditions: ArrayBuffer[Expression], + replaceMap: Map[String, Attribute]): Expression = { + val groupStructs = concatAllNotAggregateExpressionsInAggregate(group, replaceMap) + val arrrayFields = ArrayBuffer[Expression]() + for (i <- 0 until groupStructs.length) { + val structData = CreateNamedStruct( + Seq[Expression]( + Literal(UTF8String.fromString("f1"), StringType), + groupStructs(i), + Literal(UTF8String.fromString("f2"), StringType), + Literal(i, IntegerType) + )) + val field = If(groupConditions(i), structData, Literal(null, groupStructs(i).dataType)) + arrrayFields += field + } + CreateArray(arrrayFields.toSeq) + } + + def explodeBranches( + group: ArrayBuffer[LogicalPlan], + groupConditions: ArrayBuffer[Expression], + replaceMap: Map[String, Attribute]): NamedExpression = { + Alias( + Explode(buildBranchesArray(group, groupConditions, replaceMap)), + "explode_branches_result")() + } +} + +object CoalesceAggregationUnion { + val UNION_TAG_FIELD: String = "union_tag_" +} From 627bc945f1faed70b2fdcf2abae0946b525b23aa Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Thu, 23 Jan 2025 10:37:21 +0800 Subject: [PATCH 2/9] wip --- .../extension/CoalesceAggregationUnion.scala | 205 +++++++++++------- 1 file changed, 121 insertions(+), 84 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala index eb8af1020ea3..5893b2c09888 100644 --- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala @@ -36,7 +36,9 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi logError(s"xxx plan is resolved: ${plan.resolved}") if (plan.resolved) { logError(s"xxx visit plan:\n$plan") - visitPlan(plan) + val newPlan = visitPlan(plan) + logError(s"xxx rewritten plan:\n$newPlan") + newPlan } else { plan } @@ -57,13 +59,34 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi logError(s"xxx replace map: $replaceMap") val (groupConds, unionCond) = buildGroupConditions(group, replaceMap) logError(s"xxx replace condition: $groupConds, $unionCond") - val explodeResult = explodeBranches(group, groupConds, replaceMap) val unionFilter = Filter(unionCond, filter.child) - val project = Project(Seq(explodeResult), unionFilter) - logError(s"xxx new project\n$project") - val nullFilter = Filter(IsNotNull(explodeResult), project) - logError(s"xxx new null filter\n$nullFilter") - group.head + val branchesArray = buildBranchesArray(group, groupConds, replaceMap) + logError(s"xxx branches array: $branchesArray") + val arrayProject = Project(Seq(Alias(branchesArray, "branch_arrays_")()), unionFilter) + logError(s"xxx array project\n$arrayProject") + val explodeExpr = Explode(arrayProject.output.head) + logError(s"xxx explode expression: $explodeExpr") + val generateOutputAttribute = AttributeReference( + "generate_output", + arrayProject.output.head.dataType.asInstanceOf[ArrayType].elementType)() + val generate = Generate( + explodeExpr, + Seq(0), + false, + None, + Seq(generateOutputAttribute), + arrayProject) + logError(s"xxx generate ${generate.resolved}\n${generate.output}\n$generate") + val nullFilter = Filter(IsNotNull(generate.output.head), generate) + logError(s"xxx null filter ${nullFilter.resolved}\n$nullFilter") + val newAgg = buildAggregateWithGroupId( + nullFilter.output.head, + nullFilter, + group.head.asInstanceOf[Aggregate]) + logError(s"xxx new agg\n$newAgg") + + // group.head + newAgg } } if (rewrittenGroups.length == 1) { @@ -71,6 +94,11 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } else { union.withNewChildren(rewrittenGroups) } + case generate: Generate => + logError( + s"xxx generate:\nunrequiredChildIndex: ${generate.unrequiredChildIndex}" + + s"\nuter:${generate.outer} \neneratorOutput: ${generate.generatorOutput}") + generate case project: Project => project.projectList.foreach( e => logError(s"xxx project expression: ${removeAlias(e).getClass}, $e")) @@ -239,87 +267,27 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi (conditions, unionCond) } - def buildStructType(attributes: Seq[Attribute]): StructType = { - val fields = attributes.map { - attr => StructField(attr.name, attr.dataType, attr.nullable, attr.metadata) - } - StructType(fields :+ StructField(CoalesceAggregationUnion.UNION_TAG_FIELD, IntegerType, false)) + def makeAlias(e: Expression, name: String): NamedExpression = { + Alias(e, name)( + NamedExpression.newExprId, + e match { + case ne: NamedExpression => ne.qualifier + case _ => Seq.empty + }, + None, + Seq.empty) } - def buildStructType1( - group: ArrayBuffer[LogicalPlan], - replaceMap: Map[String, Attribute]): StructType = { - var attributesSet: AttributeSet = null - val attributes = group.foreach { - plan => - val agg = plan.asInstanceOf[Aggregate] 
- val filter = agg.child.asInstanceOf[Filter] - agg.groupingExpressions - .map(e => replaceAttributes(e, replaceMap)) - .foreach( - e => - if (attributesSet == null) { - attributesSet = e.references - } else { - attributesSet ++= e.references - }) - agg.aggregateExpressions - .map(e => replaceAttributes(e, replaceMap)) - .foreach(e => attributesSet ++= e.references) - filter.output - .map(e => replaceAttributes(e, replaceMap)) - .foreach(e => attributesSet ++= e.references) - } - logError(s"xxx all needed attributes: $attributesSet") - val fields = ArrayBuffer[StructField]() - attributesSet.foreach { - attr => fields += StructField(attr.name, attr.dataType, attr.nullable, attr.metadata) - } - StructType(fields :+ StructField(CoalesceAggregationUnion.UNION_TAG_FIELD, IntegerType, false)) - } + def wrapAllColumnsIntoStruct( + aggGroup: ArrayBuffer[LogicalPlan], + replaceMap: Map[String, Attribute]): (Seq[Int], Seq[Int], Seq[Int], Seq[NamedExpression]) = { - def collectGroupRequiredAttributes( - group: ArrayBuffer[LogicalPlan], - replaceMap: Map[String, Attribute]): AttributeSet = { - var attributesSet: AttributeSet = null - group.foreach { - plan => - val agg = plan.asInstanceOf[Aggregate] - val filter = agg.child.asInstanceOf[Filter] - agg.groupingExpressions - .map(e => replaceAttributes(e, replaceMap)) - .foreach( - e => - if (attributesSet == null) { - attributesSet = e.references - } else { - attributesSet ++= e.references - }) - agg.aggregateExpressions - .map(e => replaceAttributes(e, replaceMap)) - .foreach(e => attributesSet ++= e.references) - filter.output - .map(e => replaceAttributes(e, replaceMap)) - .foreach(e => attributesSet ++= e.references) - } - attributesSet } def concatAllNotAggregateExpressionsInAggregate( group: ArrayBuffer[LogicalPlan], replaceMap: Map[String, Attribute]): Seq[NamedExpression] = { - def makeAlias(e: Expression, name: String): NamedExpression = { - Alias(e, name)( - NamedExpression.newExprId, - e match { - case ne: NamedExpression => ne.qualifier - case _ => Seq.empty - }, - None, - Seq.empty) - } - val projectionExpressions = ArrayBuffer[NamedExpression]() val columnNamePrefix = "agg_expr_" val structFieldNamePrefix = "field_" @@ -363,6 +331,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi s"$columnNamePrefix$groupIndex") groupIndex += 1 } + projectionExpressions.foreach(e => logError(s"xxx 11cprojection expression: ${e.resolved}, $e")) projectionExpressions } @@ -380,19 +349,87 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi Literal(UTF8String.fromString("f2"), StringType), Literal(i, IntegerType) )) - val field = If(groupConditions(i), structData, Literal(null, groupStructs(i).dataType)) + val field = If(groupConditions(i), structData, Literal(null, structData.dataType)) + logError(s"xxx if field: ${field.resolved} $field") + logError(s"xxx if field type: ${field.dataType}") arrrayFields += field } - CreateArray(arrrayFields.toSeq) + val res = CreateArray(arrrayFields.toSeq) + logError(s"xxx array fields: ${res.resolved} $res") + res } def explodeBranches( group: ArrayBuffer[LogicalPlan], groupConditions: ArrayBuffer[Expression], replaceMap: Map[String, Attribute]): NamedExpression = { - Alias( - Explode(buildBranchesArray(group, groupConditions, replaceMap)), - "explode_branches_result")() + val explode = Explode(buildBranchesArray(group, groupConditions, replaceMap)) + logError(s"xxx explode. 
resolved: ${explode.resolved}, $explode\n${explode.dataType}") + val alias = Alias(explode, "explode_branches_result")() + logError( + s"xxx explode alias. resolved: ${alias.resolved},${alias.childrenResolved}," + + s"${alias.checkInputDataTypes().isSuccess} \n$alias") + alias + } + + def buildAggregateWithGroupId( + structedData: Expression, + child: LogicalPlan, + templateAgg: Aggregate): LogicalPlan = { + logError(s"xxx struct data: $structedData") + logError(s"xxx struct type: ${structedData.dataType}") + val structType = structedData.dataType.asInstanceOf[StructType] + val flattenAtrributes = ArrayBuffer[NamedExpression]() + var fieldIndex: Int = 0 + structType.fields(0).dataType.asInstanceOf[StructType].fields.foreach { + field => + flattenAtrributes += Alias( + GetStructField(GetStructField(structedData, 0), fieldIndex), + field.name)() + fieldIndex += 1 + } + flattenAtrributes += Alias( + GetStructField(structedData, 1), + CoalesceAggregationUnion.UNION_TAG_FIELD)() + logError(s"xxx flatten attributes: ${flattenAtrributes.length}\n$flattenAtrributes") + + val flattenProject = Project(flattenAtrributes, child) + val groupingExpessions = ArrayBuffer[Expression]() + for (i <- 0 until templateAgg.groupingExpressions.length) { + groupingExpessions += flattenProject.output(i) + } + groupingExpessions += flattenProject.output.last + + var aggregateExpressionIndex = templateAgg.groupingExpressions.length + val aggregateExpressions = ArrayBuffer[NamedExpression]() + for (i <- 0 until templateAgg.aggregateExpressions.length) { + logError( + s"xxx field index: $aggregateExpressionIndex, i: $i, " + + s"len:${templateAgg.aggregateExpressions.length}") + removeAlias(templateAgg.aggregateExpressions(i)) match { + case aggExpr: AggregateExpression => + val aggregateFunctionArgs = aggExpr.aggregateFunction.children.zipWithIndex.map { + case (e, index) => flattenProject.output(aggregateExpressionIndex + index) + } + aggregateExpressionIndex += aggExpr.aggregateFunction.children.length + val aggregateFunction = aggExpr.aggregateFunction + .withNewChildren(aggregateFunctionArgs) + .asInstanceOf[AggregateFunction] + val newAggregateExpression = AggregateExpression( + aggregateFunction, + aggExpr.mode, + aggExpr.isDistinct, + aggExpr.filter, + aggExpr.resultId) + aggregateExpressions += Alias( + newAggregateExpression, + templateAgg.aggregateExpressions(i).asInstanceOf[Alias].name)() + case notAggExpr => + aggregateExpressions += flattenProject.output(aggregateExpressionIndex) + aggregateExpressionIndex += 1 + } + } + Aggregate(groupingExpessions, aggregateExpressions, flattenProject) } } From ffe29359e428f3565742b3fdbed8fda9ce72bac8 Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Thu, 23 Jan 2025 17:56:04 +0800 Subject: [PATCH 3/9] wip --- .../extension/CoalesceAggregationUnion.scala | 440 ++++++++++++++++-- 1 file changed, 389 insertions(+), 51 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala index 5893b2c09888..906372454e0c 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala @@ -32,6 +32,66 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] with Logging { + def removeAlias(e: 
Expression): Expression = { + e match { + case alias: Alias => alias.child + case _ => e + } + } + + case class AnalyzedAggregteInfo(aggregate: Aggregate) { + lazy val resultGroupingExpressions = aggregate.aggregateExpressions.filter { + e => + removeAlias(e) match { + case aggExpr: AggregateExpression => false + case _ => true + } + } + + lazy val constResultGroupingExpressions = resultGroupingExpressions.filter { + e => + removeAlias(e) match { + case literal: Literal => true + case _ => false + } + } + + lazy val nonConstResultGroupingExpressions = resultGroupingExpressions.filter { + e => + removeAlias(e) match { + case literal: Literal => false + case _ => true + } + } + + lazy val hasAggregateWithFilter = aggregate.aggregateExpressions.exists { + e => + removeAlias(e) match { + case aggExpr: AggregateExpression => aggExpr.filter.isDefined + case _ => false + } + } + + lazy val resultPositionInGroupingKeys = { + var i = 0 + resultGroupingExpressions.map { + e => + e match { + case literal @ Alias(_: Literal, _) => + var idx = aggregate.groupingExpressions.indexOf(e) + if (idx == -1) { + idx = aggregate.groupingExpressions.length + i + i += 1 + } + idx + case _ => aggregate.groupingExpressions.indexOf(removeAlias(e)) + } + } + } + } + + case class GroupPlanResult(plan: LogicalPlan, analyzedAggregateInfo: Option[AnalyzedAggregteInfo]) + override def apply(plan: LogicalPlan): LogicalPlan = { logError(s"xxx plan is resolved: ${plan.resolved}") if (plan.resolved) { @@ -47,46 +107,54 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi def visitPlan(plan: LogicalPlan): LogicalPlan = { val newPlan = plan match { case union: Union => - logError(s"xxx is union node, children: ${union.children}") - val groups = groupUnionAggregations(union) + val groups = groupSameStructureAggregate(union) + groups.zipWithIndex.foreach { + case (g, i) => + g.foreach( + e => + logError( + s"xxx group $i plan: ${e.plan}\n" + + s"positions:" + + s"${e.analyzedAggregateInfo.get.resultPositionInGroupingKeys}")) + } val rewrittenGroups = groups.map { group => if (group.length == 1) { - group.head + group.head.plan } else { - val replaceMap = collectReplaceAttributes(group) - val filter = group.head.asInstanceOf[Aggregate].child.asInstanceOf[Filter] - logError(s"xxx replace map: $replaceMap") - val (groupConds, unionCond) = buildGroupConditions(group, replaceMap) - logError(s"xxx replace condition: $groupConds, $unionCond") - val unionFilter = Filter(unionCond, filter.child) - val branchesArray = buildBranchesArray(group, groupConds, replaceMap) - logError(s"xxx branches array: $branchesArray") - val arrayProject = Project(Seq(Alias(branchesArray, "branch_arrays_")()), unionFilter) - logError(s"xxx array project\n$arrayProject") - val explodeExpr = Explode(arrayProject.output.head) - logError(s"xxx explode expression: $explodeExpr") - val generateOutputAttribute = AttributeReference( - "generate_output", - arrayProject.output.head.dataType.asInstanceOf[ArrayType].elementType)() - val generate = Generate( - explodeExpr, - Seq(0), - false, - None, - Seq(generateOutputAttribute), - arrayProject) - logError(s"xxx generate ${generate.resolved}\n${generate.output}\n$generate") - val nullFilter = Filter(IsNotNull(generate.output.head), generate) - logError(s"xxx null filter ${nullFilter.resolved}\n$nullFilter") - val newAgg = buildAggregateWithGroupId( - nullFilter.output.head, - nullFilter, - group.head.asInstanceOf[Aggregate]) - logError(s"xxx new agg\n$newAgg") - - // group.head - newAgg + 
val aggregates = group.map(_.plan) + val replaceAttributes = collectReplaceAttributes(aggregates) + val filterConditions = buildAggregateCasesConditions(aggregates, replaceAttributes) + val firstAggregateFilter = + group.head.plan.asInstanceOf[Aggregate].child.asInstanceOf[Filter] + + // Concat all filter conditions with `or` and apply it on the source node + val unionFilter = Filter( + buildUnionConditionForAggregateSource(filterConditions), + firstAggregateFilter.child) + logError(s"xxx union filter:\n$unionFilter") + + val wrappedAttributes = wrapAggregatesAttributesInStructs(group, replaceAttributes) + logError(s"xxx wrapped attributes:\n$wrappedAttributes") + val wrappedAttributesProject = Project(wrappedAttributes, unionFilter) + logError(s"xxx wrapped attributes project:\n$wrappedAttributesProject") + + val arrayProject = buildArrayProject(wrappedAttributesProject) + logError(s"xxx array project:\n$arrayProject") + + val explode = buildArrayExplode(arrayProject) + logError(s"xxx explode:\n$explode") + + val notNullFilter = Filter(IsNotNull(explode.output.head), explode) + logError(s"xxx not null filter:\n$notNullFilter") + + val destructStructProject = buildDestructStructProject(notNullFilter) + logError(s"xxx destruct struct project:\n$destructStructProject") + + val resultAgg = buildAggregateWithGroupId(destructStructProject, group) + logError(s"xxx result agg:\n$resultAgg") + + group.head.plan } } if (rewrittenGroups.length == 1) { @@ -109,21 +177,134 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi newPlan } - def removeAlias(e: Expression): Expression = { - e match { - case alias: Alias => alias.child - case _ => e + def isSupportedAggregate(info: AnalyzedAggregteInfo): Boolean = { + if (info.hasAggregateWithFilter) { + false + } else { + true + } + } + + def areSameStructureAggregate(l: AnalyzedAggregteInfo, r: AnalyzedAggregteInfo): Boolean = { + val lAggregate = l.aggregate + val rAggregate = r.aggregate + + // Check aggregate result expressions. need same schema. + if (lAggregate.aggregateExpressions.length != rAggregate.aggregateExpressions.length) { + return false + } + + lAggregate.aggregateExpressions.zip(rAggregate.aggregateExpressions).foreach { + case (lExpr, rExpr) => + if (!lExpr.dataType.equals(rExpr.dataType)) { + return false + } + (removeAlias(lExpr), removeAlias(rExpr)) match { + case (lAggExpr: AggregateExpression, rAggExpr: AggregateExpression) => + if (lAggExpr.aggregateFunction.getClass != rAggExpr.aggregateFunction.getClass) { + return false + } + if ( + lAggExpr.aggregateFunction.children.length != + rAggExpr.aggregateFunction.children.length + ) { + return false + } + case _ => + } + } + + // Check grouping expressions, need same schema. + if (lAggregate.groupingExpressions.length != rAggregate.groupingExpressions.length) { + return false + } + lAggregate.groupingExpressions.zip(rAggregate.groupingExpressions).foreach { + case (lExpr, rExpr) => + if (!lExpr.dataType.equals(rExpr.dataType)) { + return false + } + } + + // All result expressions which come from grouping keys must refer to the same position in + // grouping keys + if ( + l.resultPositionInGroupingKeys.length != + r.resultPositionInGroupingKeys.length + ) { + return false + } + l.resultPositionInGroupingKeys + .zip(r.resultPositionInGroupingKeys) + .foreach { + case (lPos, rPos) => + if (lPos != rPos) { + return false + } + } + + // Must come from same source. 
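+      // Two sources match only when they are structurally identical plans over the
+      // same catalog table (compared through their LogicalRelation, unwrapping any
+      // SubqueryAlias nodes along the way).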
+ if ( + areSameAggregationSource( + lAggregate.child.asInstanceOf[Filter].child, + rAggregate.child.asInstanceOf[Filter].child) + ) { + return true + } + + true + } + + // If returns -1, not found same structure aggregate. + def findSameStructureAggregate( + groups: ArrayBuffer[ArrayBuffer[GroupPlanResult]], + analyzedAggregateInfo: AnalyzedAggregteInfo): Int = { + groups.zipWithIndex.foreach { + case (group, i) => + if ( + group.head.analyzedAggregateInfo.isDefined && + areSameStructureAggregate(group.head.analyzedAggregateInfo.get, analyzedAggregateInfo) + ) { + return i + } + } + -1 + } + + def groupSameStructureAggregate(union: Union): ArrayBuffer[ArrayBuffer[GroupPlanResult]] = { + val groupResults = ArrayBuffer[ArrayBuffer[GroupPlanResult]]() + union.children.foreach { + case agg @ Aggregate(_, _, filter: Filter) => + val analyzedInfo = AnalyzedAggregteInfo(agg) + if (isSupportedAggregate(analyzedInfo)) { + if (groupResults.isEmpty) { + groupResults += ArrayBuffer(GroupPlanResult(agg, Some(analyzedInfo))) + } else { + val idx = findSameStructureAggregate(groupResults, analyzedInfo) + if (idx != -1) { + groupResults(idx) += GroupPlanResult(agg, Some(analyzedInfo)) + } else { + groupResults += ArrayBuffer(GroupPlanResult(agg, Some(analyzedInfo))) + } + } + } else { + groupResults += ArrayBuffer(GroupPlanResult(agg, None)) + } + case other => + groupResults += ArrayBuffer(GroupPlanResult(other, None)) } + groupResults } - def groupUnionAggregations(union: Union): ArrayBuffer[ArrayBuffer[LogicalPlan]] = { + def groupSameStructureAggregations(union: Union): ArrayBuffer[ArrayBuffer[LogicalPlan]] = { val groupResults = ArrayBuffer[ArrayBuffer[LogicalPlan]]() union.children.foreach { case agg @ Aggregate(_, _, filter: Filter) => - agg.groupingExpressions.foreach( - e => logError(s"xxx grouping expression: ${e.getClass}, $e")) - agg.aggregateExpressions.foreach( - e => logError(s"xxx aggregate expression: ${e.getClass}, $e")) + val analyzedInfo = AnalyzedAggregteInfo(agg) + if (isSupportedAggregate(analyzedInfo)) { + if (groupResults.isEmpty) { + groupResults += ArrayBuffer(agg) + } else {} + } if ( groupResults.isEmpty && agg.aggregateExpressions.exists( e => @@ -255,6 +436,100 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } } + def buildAggregateCasesConditions( + group: ArrayBuffer[LogicalPlan], + replaceMap: Map[String, Attribute]): ArrayBuffer[Expression] = { + group.map { + plan => + val filter = plan.asInstanceOf[Aggregate].child.asInstanceOf[Filter] + replaceAttributes(filter.condition, replaceMap) + } + } + + def buildUnionConditionForAggregateSource(conditions: ArrayBuffer[Expression]): Expression = { + conditions.reduce(Or); + } + + def wrapAggregatesAttributesInStructs( + aggregateGroup: ArrayBuffer[GroupPlanResult], + replaceMap: Map[String, Attribute]): Seq[NamedExpression] = { + val structAttributes = ArrayBuffer[NamedExpression]() + val casePrefix = "case_" + val structPrefix = "field_" + aggregateGroup.zipWithIndex.foreach { + case (aggregateCase, case_index) => + val aggregate = aggregateCase.plan.asInstanceOf[Aggregate] + val analyzedInfo = aggregateCase.analyzedAggregateInfo.get + val structFields = ArrayBuffer[Expression]() + var fieldIndex: Int = 0 + aggregate.groupingExpressions.foreach { + e => + structFields += Literal(UTF8String.fromString(s"$structPrefix$fieldIndex"), StringType) + structFields += e + fieldIndex += 1 + } + logError(s"xxx 1) struct fields: $structFields") + for (i <- 0 until 
analyzedInfo.resultPositionInGroupingKeys.length) { + val position = analyzedInfo.resultPositionInGroupingKeys(i) + logError(s"xxx position: $position, structFields: ${structFields.length}, f:$fieldIndex") + if (position >= fieldIndex) { + val expr = analyzedInfo.resultGroupingExpressions(i) + structFields += Literal(UTF8String.fromString(s"$structPrefix$fieldIndex"), StringType) + structFields += analyzedInfo.resultGroupingExpressions(i) + fieldIndex += 1 + } + } + logError(s"xxx 2) struct fields: $structFields") + + aggregate.aggregateExpressions + .filter( + e => + removeAlias(e) match { + case aggExpr: AggregateExpression => true + case _ => false + }) + .foreach { + e => + val aggFunction = removeAlias(e).asInstanceOf[AggregateExpression].aggregateFunction + aggFunction.children.foreach { + child => + structFields += Literal( + UTF8String.fromString(s"$structPrefix$fieldIndex"), + StringType) + structFields += child + fieldIndex += 1 + } + } + logError(s"xxx 3) struct fields: $structFields") + structFields += Literal(UTF8String.fromString(s"$structPrefix$fieldIndex"), StringType) + structFields += Literal(case_index, IntegerType) + structAttributes += makeAlias(CreateNamedStruct(structFields), s"$casePrefix$case_index") + } + structAttributes + } + + def buildArrayProject(child: LogicalPlan): Project = { + val outputs = child.output.map(_.asInstanceOf[Expression]) + val array = makeAlias(CreateArray(outputs), "array") + Project(Seq(array), child) + } + + def buildArrayExplode(child: LogicalPlan): LogicalPlan = { + assert(child.output.length == 1, s"Expected single output from $child") + val array = child.output.head.asInstanceOf[Expression] + assert(array.dataType.isInstanceOf[ArrayType], s"Expected ArrayType from $array") + val explodeExpr = Explode(array) + val exploadOutput = + AttributeReference("generate_output", array.dataType.asInstanceOf[ArrayType].elementType)() + Generate( + explodeExpr, + unrequiredChildIndex = Seq(0), + outer = false, + qualifier = None, + generatorOutput = Seq(exploadOutput), + child) + } + def buildGroupConditions( group: ArrayBuffer[LogicalPlan], replaceMap: Map[String, Attribute]): (ArrayBuffer[Expression], Expression) = { @@ -278,12 +553,6 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi Seq.empty) } - def wrapAllColumnsIntoStruct( - aggGroup: ArrayBuffer[LogicalPlan], - replaceMap: Map[String, Attribute]): (Seq[Int], Seq[Int], Seq[Int], Seq[NamedExpression]) = { - - } - def concatAllNotAggregateExpressionsInAggregate( group: ArrayBuffer[LogicalPlan], replaceMap: Map[String, Attribute]): Seq[NamedExpression] = { @@ -372,7 +641,76 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi alias } + def buildDestructStructProject(child: LogicalPlan): LogicalPlan = { + assert(child.output.length == 1, s"Expected single output from $child") + val structedData = child.output.head + assert( + structedData.dataType.isInstanceOf[StructType], + s"Expected StructType from $structedData") + val structType = structedData.dataType.asInstanceOf[StructType] + val attributes = ArrayBuffer[NamedExpression]() + var index = 0 + structType.fields.foreach { + field => + attributes += Alias(GetStructField(structedData, index), field.name)() + index += 1 + } + Project(attributes, child) + } + def buildAggregateWithGroupId( + child: LogicalPlan, + aggregateGroup: ArrayBuffer[GroupPlanResult]): LogicalPlan = { + val attributes = child.output + val aggregateTemplate = aggregateGroup.head.plan.asInstanceOf[Aggregate] + 
val analyzedAggregateInfo = aggregateGroup.head.analyzedAggregateInfo.get + + val totalGroupingExpressionsCount = + math.max( + aggregateTemplate.groupingExpressions.length, + analyzedAggregateInfo.resultPositionInGroupingKeys.max + 1) + + val groupingExpressions = attributes + .slice(0, totalGroupingExpressionsCount) + .map(_.asInstanceOf[Expression]) :+ attributes.last + + val normalExpressionPosition = analyzedAggregateInfo.resultPositionInGroupingKeys + var normalExpressionCount = 0 + var aggregateExpressionIndex = totalGroupingExpressionsCount + val aggregateExpressions = ArrayBuffer[NamedExpression]() + aggregateTemplate.aggregateExpressions.foreach { + e => + removeAlias(e) match { + case aggExpr: AggregateExpression => + val aggFunc = aggExpr.aggregateFunction + val newAggFuncArgs = aggFunc.children.zipWithIndex.map { + case (arg, i) => + logError(s"xxx agg expr: $arg, $i, $aggregateExpressionIndex") + attributes(aggregateExpressionIndex + i) + } + aggregateExpressionIndex += aggFunc.children.length + val newAggFunc = + aggFunc.withNewChildren(newAggFuncArgs).asInstanceOf[AggregateFunction] + val newAggExpr = AggregateExpression( + newAggFunc, + aggExpr.mode, + aggExpr.isDistinct, + aggExpr.filter, + aggExpr.resultId) + aggregateExpressions += makeAlias(newAggExpr, e.name) + + case other => + val position = normalExpressionPosition(normalExpressionCount) + val attr = attributes(position) + normalExpressionCount += 1 + aggregateExpressions += makeAlias(attr, e.name) + .asInstanceOf[NamedExpression] + } + } + Aggregate(groupingExpressions, aggregateExpressions, child) + } + + def buildAggregateWithGroupId1( structedData: Expression, child: LogicalPlan, templateAgg: Aggregate): LogicalPlan = { From 4c89412a87e538f34b58bf0de33ba3f974394e70 Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Fri, 24 Jan 2025 12:03:20 +0800 Subject: [PATCH 4/9] wip --- .../extension/CoalesceAggregationUnion.scala | 422 +++++------------- 1 file changed, 113 insertions(+), 309 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala index 906372454e0c..97266767d044 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala @@ -31,6 +31,28 @@ import org.apache.spark.unsafe.types.UTF8String import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +/* + * Example: + * Rewrite query + * SELECT a, b, sum(c) FROM t WHERE d = 1 GROUP BY a,b + * UNION ALL + * SELECT a, b, sum(c) FROM t WHERE d = 2 GROUP BY a,b + * into + * SELECT a, b, sum(c) FROM ( + * SELECT s.a as a, s.b as b, s.c as c, s.id as group_id FROM ( + * SELECT explode(s) as s FROM ( + * SELECT array( + * if(d = 1, named_struct('a', a, 'b', b, 'c', c, 'id', 0), null), + * if(d = 2, named_struct('a', a, 'b', b, 'c', c, 'id', 1), null)) as s + * FROM t WHERE d = 1 OR d = 2 + * ) + * ) WHERE s is not null + * ) GROUP BY a,b, group_id + * + * The first query need to scan `t` multiple times, when the output of scan is large, the query is + * really slow. The rewritten query only scan `t` once, and the performance is much better. 
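+ *
+ * Each branch's columns are wrapped into a named struct guarded by
+ * `if(<branch filter>, named_struct(...), null)`, so after `explode` the rows that do
+ * not satisfy a given branch's filter are removed again by the `s is not null`
+ * predicate. The appended `group_id` becomes an extra grouping key, which keeps rows
+ * originating from different branches in separate aggregation groups and thereby
+ * preserves the semantics of the original UNION ALL.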
+ */ + class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] with Logging { def removeAlias(e: Expression): Expression = { e match { @@ -48,22 +70,6 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } } - lazy val constResultGroupingExpressions = resultGroupingExpressions.filter { - e => - removeAlias(e) match { - case literal: Literal => true - case _ => false - } - } - - lazy val nonConstResultGroupingExpressions = resultGroupingExpressions.filter { - e => - removeAlias(e) match { - case literal: Literal => false - case _ => true - } - } - lazy val hasAggregateWithFilter = aggregate.aggregateExpressions.exists { e => removeAlias(e) match { @@ -90,10 +96,9 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } } - case class GroupPlanResult(plan: LogicalPlan, analyzedAggregateInfo: Option[AnalyzedAggregteInfo]) + case class AnalyzedPlan(plan: LogicalPlan, analyzedAggregateInfo: Option[AnalyzedAggregteInfo]) override def apply(plan: LogicalPlan): LogicalPlan = { - logError(s"xxx plan is resolved: ${plan.resolved}") if (plan.resolved) { logError(s"xxx visit plan:\n$plan") val newPlan = visitPlan(plan) @@ -107,26 +112,17 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi def visitPlan(plan: LogicalPlan): LogicalPlan = { val newPlan = plan match { case union: Union => - val groups = groupSameStructureAggregate(union) - groups.zipWithIndex.foreach { - case (g, i) => - g.foreach( - e => - logError( - s"xxx group $i plan: ${e.plan}\n" + - s"positions:" + - s"${e.analyzedAggregateInfo.get.resultPositionInGroupingKeys}")) - } - val rewrittenGroups = groups.map { - group => - if (group.length == 1) { - group.head.plan + val planGroups = groupSameStructureAggregate(union) + val newUnionClauses = planGroups.map { + groupedPlans => + if (groupedPlans.length == 1) { + groupedPlans.head.plan } else { - val aggregates = group.map(_.plan) + val aggregates = groupedPlans.map(_.plan) val replaceAttributes = collectReplaceAttributes(aggregates) val filterConditions = buildAggregateCasesConditions(aggregates, replaceAttributes) val firstAggregateFilter = - group.head.plan.asInstanceOf[Aggregate].child.asInstanceOf[Filter] + groupedPlans.head.plan.asInstanceOf[Aggregate].child.asInstanceOf[Filter] // Concat all filter conditions with `or` and apply it on the source node val unionFilter = Filter( @@ -134,12 +130,15 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi firstAggregateFilter.child) logError(s"xxx union filter:\n$unionFilter") - val wrappedAttributes = wrapAggregatesAttributesInStructs(group, replaceAttributes) - logError(s"xxx wrapped attributes:\n$wrappedAttributes") - val wrappedAttributesProject = Project(wrappedAttributes, unionFilter) + val wrappedAttributesProject = + buildStructWrapperProject( + unionFilter, + groupedPlans, + filterConditions, + replaceAttributes) logError(s"xxx wrapped attributes project:\n$wrappedAttributesProject") - val arrayProject = buildArrayProject(wrappedAttributesProject) + val arrayProject = buildArrayProject(wrappedAttributesProject, filterConditions) logError(s"xxx array project:\n$arrayProject") val explode = buildArrayExplode(arrayProject) @@ -151,26 +150,17 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi val destructStructProject = buildDestructStructProject(notNullFilter) logError(s"xxx destruct struct project:\n$destructStructProject") - val resultAgg = 
buildAggregateWithGroupId(destructStructProject, group) - logError(s"xxx result agg:\n$resultAgg") + val singleAggregate = buildAggregateWithGroupId(destructStructProject, groupedPlans) + logError(s"xxx single agg:\n$singleAggregate") - group.head.plan + singleAggregate } } - if (rewrittenGroups.length == 1) { - rewrittenGroups.head + if (newUnionClauses.length == 1) { + newUnionClauses.head } else { - union.withNewChildren(rewrittenGroups) + union.withNewChildren(newUnionClauses) } - case generate: Generate => - logError( - s"xxx generate:\nunrequiredChildIndex: ${generate.unrequiredChildIndex}" + - s"\nuter:${generate.outer} \neneratorOutput: ${generate.generatorOutput}") - generate - case project: Project => - project.projectList.foreach( - e => logError(s"xxx project expression: ${removeAlias(e).getClass}, $e")) - project.withNewChildren(project.children.map(visitPlan)) case _ => plan.withNewChildren(plan.children.map(visitPlan)) } newPlan.copyTagsFrom(plan) @@ -218,15 +208,33 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi if (lAggregate.groupingExpressions.length != rAggregate.groupingExpressions.length) { return false } + + def hasAggregateExpression(e: Expression): Boolean = { + if (e.children.isEmpty && !e.isInstanceOf[AggregateExpression]) { + return false + } + e match { + case _: AggregateExpression => true + case _ => e.children.exists(hasAggregateExpression(_)) + } + } + + // 1. data types must be same. + // 2. The aggregate expressions must have same structure and at the same position. lAggregate.groupingExpressions.zip(rAggregate.groupingExpressions).foreach { case (lExpr, rExpr) => if (!lExpr.dataType.equals(rExpr.dataType)) { return false } + val lHasAgg = hasAggregateExpression(lExpr) + val rHasAgg = hasAggregateExpression(rExpr) + if (lHasAgg != rHasAgg) { + return false + } else if (lHasAgg && !areEqualExpressions(lExpr, rExpr)) { + return false + } } - // All result expressions which come from grouping keys must refer to the same position in - // grouping keys if ( l.resultPositionInGroupingKeys.length != r.resultPositionInGroupingKeys.length @@ -244,7 +252,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi // Must come from same source. if ( - areSameAggregationSource( + areSameAggregateSource( lAggregate.child.asInstanceOf[Filter].child, rAggregate.child.asInstanceOf[Filter].child) ) { @@ -256,13 +264,15 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi // If returns -1, not found same structure aggregate. 
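  // Otherwise it is the index of the first group in `planGroups` whose head aggregate
  // has the same structure as the given aggregate.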
def findSameStructureAggregate( - groups: ArrayBuffer[ArrayBuffer[GroupPlanResult]], + planGroups: ArrayBuffer[ArrayBuffer[AnalyzedPlan]], analyzedAggregateInfo: AnalyzedAggregteInfo): Int = { - groups.zipWithIndex.foreach { - case (group, i) => + planGroups.zipWithIndex.foreach { + case (groupedPlans, i) => if ( - group.head.analyzedAggregateInfo.isDefined && - areSameStructureAggregate(group.head.analyzedAggregateInfo.get, analyzedAggregateInfo) + groupedPlans.head.analyzedAggregateInfo.isDefined && + areSameStructureAggregate( + groupedPlans.head.analyzedAggregateInfo.get, + analyzedAggregateInfo) ) { return i } @@ -270,71 +280,27 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi -1 } - def groupSameStructureAggregate(union: Union): ArrayBuffer[ArrayBuffer[GroupPlanResult]] = { - val groupResults = ArrayBuffer[ArrayBuffer[GroupPlanResult]]() + def groupSameStructureAggregate(union: Union): ArrayBuffer[ArrayBuffer[AnalyzedPlan]] = { + val groupResults = ArrayBuffer[ArrayBuffer[AnalyzedPlan]]() union.children.foreach { case agg @ Aggregate(_, _, filter: Filter) => val analyzedInfo = AnalyzedAggregteInfo(agg) if (isSupportedAggregate(analyzedInfo)) { if (groupResults.isEmpty) { - groupResults += ArrayBuffer(GroupPlanResult(agg, Some(analyzedInfo))) + groupResults += ArrayBuffer(AnalyzedPlan(agg, Some(analyzedInfo))) } else { val idx = findSameStructureAggregate(groupResults, analyzedInfo) if (idx != -1) { - groupResults(idx) += GroupPlanResult(agg, Some(analyzedInfo)) + groupResults(idx) += AnalyzedPlan(agg, Some(analyzedInfo)) } else { - groupResults += ArrayBuffer(GroupPlanResult(agg, Some(analyzedInfo))) + groupResults += ArrayBuffer(AnalyzedPlan(agg, Some(analyzedInfo))) } } } else { - groupResults += ArrayBuffer(GroupPlanResult(agg, None)) - } - case other => - groupResults += ArrayBuffer(GroupPlanResult(other, None)) - } - groupResults - } - - def groupSameStructureAggregations(union: Union): ArrayBuffer[ArrayBuffer[LogicalPlan]] = { - val groupResults = ArrayBuffer[ArrayBuffer[LogicalPlan]]() - union.children.foreach { - case agg @ Aggregate(_, _, filter: Filter) => - val analyzedInfo = AnalyzedAggregteInfo(agg) - if (isSupportedAggregate(analyzedInfo)) { - if (groupResults.isEmpty) { - groupResults += ArrayBuffer(agg) - } else {} - } - if ( - groupResults.isEmpty && agg.aggregateExpressions.exists( - e => - removeAlias(e) match { - case aggExpr: AggregateExpression => !aggExpr.filter.isDefined - case _ => true - }) - ) { - groupResults += ArrayBuffer(agg) - } else { - groupResults.find { - group => - group.head match { - case toMatchAgg @ Aggregate(_, _, toMatchFilter: Filter) => - areSameSchemaAggregations(toMatchAgg, agg) && - areSameAggregationSource(toMatchFilter.child, filter.child) - case _ => false - } - } match { - case Some(group) => - group += agg.asInstanceOf[LogicalPlan] - case None => - val newGroup = ArrayBuffer(agg.asInstanceOf[LogicalPlan]) - groupResults += newGroup - } + groupResults += ArrayBuffer(AnalyzedPlan(agg, None)) } case other => - // Other plans will be remained as union clauses. 
- val singlePlan = ArrayBuffer(other) - groupResults += singlePlan + groupResults += ArrayBuffer(AnalyzedPlan(other, None)) } groupResults } @@ -356,25 +322,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } } - def areSameSchemaAggregations(agg1: Aggregate, agg2: Aggregate): Boolean = { - if ( - agg1.aggregateExpressions.length != agg2.aggregateExpressions.length || - agg1.groupingExpressions.length != agg2.groupingExpressions.length - ) { - false - } else { - agg1.aggregateExpressions.zip(agg2.aggregateExpressions).forall { - case (_ @Alias(_: Literal, lName), _ @Alias(_: Literal, rName)) => - lName.equals(rName) - case (l, r) => areEqualExpressions(l, r) - } && - agg1.groupingExpressions.zip(agg2.groupingExpressions).forall { - case (l, r) => areEqualExpressions(l, r) - } - } - } - - def areSameAggregationSource(lPlan: LogicalPlan, rPlan: LogicalPlan): Boolean = { + def areSameAggregateSource(lPlan: LogicalPlan, rPlan: LogicalPlan): Boolean = { if (lPlan.children.length != rPlan.children.length || lPlan.getClass != rPlan.getClass) { false } else { @@ -382,18 +330,15 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi case (lRel: LogicalRelation, rRel: LogicalRelation) => val lTable = lRel.catalogTable.map(_.identifier.unquotedString).getOrElse("") val rTable = rRel.catalogTable.map(_.identifier.unquotedString).getOrElse("") - lRel.output.foreach(attr => logError(s"xxx l attr: $attr, ${attr.qualifiedName}")) - rRel.output.foreach(attr => logError(s"xxx r attr: $attr, ${attr.qualifiedName}")) - logError(s"xxx table: $lTable, $rTable") lTable.equals(rTable) && lTable.nonEmpty case (lSubQuery: SubqueryAlias, rSubQuery: SubqueryAlias) => - areSameAggregationSource(lSubQuery.child, rSubQuery.child) + areSameAggregateSource(lSubQuery.child, rSubQuery.child) case (lChild, rChild) => false } } } - def collectReplaceAttributes(group: ArrayBuffer[LogicalPlan]): Map[String, Attribute] = { + def collectReplaceAttributes(groupedPlans: ArrayBuffer[LogicalPlan]): Map[String, Attribute] = { def findFirstRelation(plan: LogicalPlan): LogicalRelation = { if (plan.isInstanceOf[LogicalRelation]) { return plan.asInstanceOf[LogicalRelation] @@ -411,7 +356,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } } val replaceMap = new mutable.HashMap[String, Attribute]() - val firstFilter = group.head.asInstanceOf[Aggregate].child.asInstanceOf[Filter] + val firstFilter = groupedPlans.head.asInstanceOf[Aggregate].child.asInstanceOf[Filter] val qualifierPrefix = firstFilter.output.find(e => e.qualifier.nonEmpty).head.qualifier.mkString(".") val firstRelation = findFirstRelation(firstFilter.child) @@ -437,9 +382,9 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } def buildAggregateCasesConditions( - group: ArrayBuffer[LogicalPlan], + groupedPlans: ArrayBuffer[LogicalPlan], replaceMap: Map[String, Attribute]): ArrayBuffer[Expression] = { - group.map { + groupedPlans.map { plan => val filter = plan.asInstanceOf[Aggregate].child.asInstanceOf[Filter] replaceAttributes(filter.condition, replaceMap) @@ -451,12 +396,12 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } def wrapAggregatesAttributesInStructs( - aggregateGroup: ArrayBuffer[GroupPlanResult], + groupedPlans: ArrayBuffer[AnalyzedPlan], replaceMap: Map[String, Attribute]): Seq[NamedExpression] = { val structAttributes = ArrayBuffer[NamedExpression]() val casePrefix = "case_" val structPrefix = 
"field_" - aggregateGroup.zipWithIndex.foreach { + groupedPlans.zipWithIndex.foreach { case (aggregateCase, case_index) => val aggregate = aggregateCase.plan.asInstanceOf[Aggregate] val analyzedInfo = aggregateCase.analyzedAggregateInfo.get @@ -465,21 +410,18 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi aggregate.groupingExpressions.foreach { e => structFields += Literal(UTF8String.fromString(s"$structPrefix$fieldIndex"), StringType) - structFields += e + structFields += replaceAttributes(e, replaceMap) fieldIndex += 1 } - logError(s"xxx 1) struct fields: $structFields") for (i <- 0 until analyzedInfo.resultPositionInGroupingKeys.length) { val position = analyzedInfo.resultPositionInGroupingKeys(i) - logError(s"xxx position: $position, structFields: ${structFields.length}, f:$fieldIndex") if (position >= fieldIndex) { val expr = analyzedInfo.resultGroupingExpressions(i) structFields += Literal(UTF8String.fromString(s"$structPrefix$fieldIndex"), StringType) - structFields += analyzedInfo.resultGroupingExpressions(i) + structFields += replaceAttributes(analyzedInfo.resultGroupingExpressions(i), replaceMap) fieldIndex += 1 } } - logError(s"xxx 2) struct fields: $structFields") aggregate.aggregateExpressions .filter( @@ -496,11 +438,10 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi structFields += Literal( UTF8String.fromString(s"$structPrefix$fieldIndex"), StringType) - structFields += child + structFields += replaceAttributes(child, replaceMap) fieldIndex += 1 } } - logError(s"xxx 3) struct fields: $structFields") structFields += Literal(UTF8String.fromString(s"$structPrefix$fieldIndex"), StringType) structFields += Literal(case_index, IntegerType) structAttributes += makeAlias(CreateNamedStruct(structFields), s"$casePrefix$case_index") @@ -508,9 +449,25 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi structAttributes } - def buildArrayProject(child: LogicalPlan): Project = { - val outputs = child.output.map(_.asInstanceOf[Expression]) - val array = makeAlias(CreateArray(outputs), "array") + def buildStructWrapperProject( + child: LogicalPlan, + groupedPlans: ArrayBuffer[AnalyzedPlan], + conditions: ArrayBuffer[Expression], + replaceMap: Map[String, Attribute]): LogicalPlan = { + val wrappedAttributes = wrapAggregatesAttributesInStructs(groupedPlans, replaceMap) + val ifAttributes = wrappedAttributes.zip(conditions).map { + case (attr, condition) => + makeAlias(If(condition, attr, Literal(null, attr.dataType)), attr.name) + .asInstanceOf[NamedExpression] + } + Project(ifAttributes, child) + } + + def buildArrayProject(child: LogicalPlan, conditions: ArrayBuffer[Expression]): LogicalPlan = { + assert( + child.output.length == conditions.length, + s"Expected same length of output and conditions") + val array = makeAlias(CreateArray(child.output), "array") Project(Seq(array), child) } @@ -531,9 +488,9 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } def buildGroupConditions( - group: ArrayBuffer[LogicalPlan], + groupedPlans: ArrayBuffer[LogicalPlan], replaceMap: Map[String, Attribute]): (ArrayBuffer[Expression], Expression) = { - val conditions = group.map { + val conditions = groupedPlans.map { plan => val filter = plan.asInstanceOf[Aggregate].child.asInstanceOf[Filter] replaceAttributes(filter.condition, replaceMap) @@ -553,94 +510,6 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi Seq.empty) } - def 
concatAllNotAggregateExpressionsInAggregate( - group: ArrayBuffer[LogicalPlan], - replaceMap: Map[String, Attribute]): Seq[NamedExpression] = { - - val projectionExpressions = ArrayBuffer[NamedExpression]() - val columnNamePrefix = "agg_expr_" - val structFieldNamePrefix = "field_" - var groupIndex: Int = 0 - group.foreach { - plan => - var fieldIndex: Int = 0 - val agg = plan.asInstanceOf[Aggregate] - val structFields = ArrayBuffer[Expression]() - agg.groupingExpressions.map(e => replaceAttributes(e, replaceMap)).foreach { - e => - structFields += Literal( - UTF8String.fromString(s"$structFieldNamePrefix$fieldIndex"), - StringType) - structFields += e - fieldIndex += 1 - } - agg.aggregateExpressions.map(e => replaceAttributes(e, replaceMap)).foreach { - e => - removeAlias(e) match { - case aggExpr: AggregateExpression => - val aggFunc = aggExpr.aggregateFunction - aggFunc.children.foreach { - child => - structFields += Literal( - UTF8String.fromString(s"$structFieldNamePrefix$fieldIndex"), - StringType) - structFields += child - fieldIndex += 1 - } - case notAggExpr => - structFields += Literal( - UTF8String.fromString(s"$structFieldNamePrefix$fieldIndex"), - StringType) - structFields += e - fieldIndex += 1 - } - } - projectionExpressions += makeAlias( - CreateNamedStruct(structFields), - s"$columnNamePrefix$groupIndex") - groupIndex += 1 - } - projectionExpressions.foreach(e => logError(s"xxx 11cprojection expression: ${e.resolved}, $e")) - projectionExpressions - } - - def buildBranchesArray( - group: ArrayBuffer[LogicalPlan], - groupConditions: ArrayBuffer[Expression], - replaceMap: Map[String, Attribute]): Expression = { - val groupStructs = concatAllNotAggregateExpressionsInAggregate(group, replaceMap) - val arrrayFields = ArrayBuffer[Expression]() - for (i <- 0 until groupStructs.length) { - val structData = CreateNamedStruct( - Seq[Expression]( - Literal(UTF8String.fromString("f1"), StringType), - groupStructs(i), - Literal(UTF8String.fromString("f2"), StringType), - Literal(i, IntegerType) - )) - val field = If(groupConditions(i), structData, Literal(null, structData.dataType)) - logError(s"xxx if field: ${field.resolved} $field") - logError(s"xxx if field type: ${field.dataType}") - arrrayFields += field - } - val res = CreateArray(arrrayFields.toSeq) - logError(s"xxx array fields: ${res.resolved} $res") - res - } - - def explodeBranches( - group: ArrayBuffer[LogicalPlan], - groupConditions: ArrayBuffer[Expression], - replaceMap: Map[String, Attribute]): NamedExpression = { - val explode = Explode(buildBranchesArray(group, groupConditions, replaceMap)) - logError(s"xxx explode. resolved: ${explode.resolved}, $explode\n${explode.dataType}") - val alias = Alias(explode, "explode_branches_result")() - logError( - s"xxx explode alias. 
resolved: ${alias.resolved},${alias.childrenResolved}," + - s"${alias.checkInputDataTypes().isSuccess} \n$alias") - alias - } - def buildDestructStructProject(child: LogicalPlan): LogicalPlan = { assert(child.output.length == 1, s"Expected single output from $child") val structedData = child.output.head @@ -660,10 +529,10 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi def buildAggregateWithGroupId( child: LogicalPlan, - aggregateGroup: ArrayBuffer[GroupPlanResult]): LogicalPlan = { + groupedPlans: ArrayBuffer[AnalyzedPlan]): LogicalPlan = { val attributes = child.output - val aggregateTemplate = aggregateGroup.head.plan.asInstanceOf[Aggregate] - val analyzedAggregateInfo = aggregateGroup.head.analyzedAggregateInfo.get + val aggregateTemplate = groupedPlans.head.plan.asInstanceOf[Aggregate] + val analyzedAggregateInfo = groupedPlans.head.analyzedAggregateInfo.get val totalGroupingExpressionsCount = math.max( @@ -685,7 +554,6 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi val aggFunc = aggExpr.aggregateFunction val newAggFuncArgs = aggFunc.children.zipWithIndex.map { case (arg, i) => - logError(s"xxx agg expr: $arg, $i, $aggregateExpressionIndex") attributes(aggregateExpressionIndex + i) } aggregateExpressionIndex += aggFunc.children.length @@ -709,68 +577,4 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } Aggregate(groupingExpressions, aggregateExpressions, child) } - - def buildAggregateWithGroupId1( - structedData: Expression, - child: LogicalPlan, - templateAgg: Aggregate): LogicalPlan = { - logError(s"xxx struct data: $structedData") - logError(s"xxx struct type: ${structedData.dataType}") - val structType = structedData.dataType.asInstanceOf[StructType] - val flattenAtrributes = ArrayBuffer[NamedExpression]() - var fieldIndex: Int = 0 - structType.fields(0).dataType.asInstanceOf[StructType].fields.foreach { - field => - flattenAtrributes += Alias( - GetStructField(GetStructField(structedData, 0), fieldIndex), - field.name)() - fieldIndex += 1 - } - flattenAtrributes += Alias( - GetStructField(structedData, 1), - CoalesceAggregationUnion.UNION_TAG_FIELD)() - logError(s"xxx flatten attributes: ${flattenAtrributes.length}\n$flattenAtrributes") - - val flattenProject = Project(flattenAtrributes, child) - val groupingExpessions = ArrayBuffer[Expression]() - for (i <- 0 until templateAgg.groupingExpressions.length) { - groupingExpessions += flattenProject.output(i) - } - groupingExpessions += flattenProject.output.last - - var aggregateExpressionIndex = templateAgg.groupingExpressions.length - val aggregateExpressions = ArrayBuffer[NamedExpression]() - for (i <- 0 until templateAgg.aggregateExpressions.length) { - logError( - s"xxx field index: $aggregateExpressionIndex, i: $i, " + - s"len:${templateAgg.aggregateExpressions.length}") - removeAlias(templateAgg.aggregateExpressions(i)) match { - case aggExpr: AggregateExpression => - val aggregateFunctionArgs = aggExpr.aggregateFunction.children.zipWithIndex.map { - case (e, index) => flattenProject.output(aggregateExpressionIndex + index) - } - aggregateExpressionIndex += aggExpr.aggregateFunction.children.length - val aggregateFunction = aggExpr.aggregateFunction - .withNewChildren(aggregateFunctionArgs) - .asInstanceOf[AggregateFunction] - val newAggregateExpression = AggregateExpression( - aggregateFunction, - aggExpr.mode, - aggExpr.isDistinct, - aggExpr.filter, - aggExpr.resultId) - aggregateExpressions += Alias( - 
newAggregateExpression, - templateAgg.aggregateExpressions(i).asInstanceOf[Alias].name)() - case notAggExpr => - aggregateExpressions += flattenProject.output(aggregateExpressionIndex) - aggregateExpressionIndex += 1 - } - } - Aggregate(groupingExpessions, aggregateExpressions, flattenProject) - } -} - -object CoalesceAggregationUnion { - val UNION_TAG_FIELD: String = "union_tag_" } From 52084f93414fd82c12bf6074770d1d51bbc8f251 Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Fri, 24 Jan 2025 17:44:07 +0800 Subject: [PATCH 5/9] wip --- .../extension/CoalesceAggregationUnion.scala | 271 ++++++++++++------ 1 file changed, 187 insertions(+), 84 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala index 97266767d044..56756eebe544 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala @@ -61,27 +61,146 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } } - case class AnalyzedAggregteInfo(aggregate: Aggregate) { - lazy val resultGroupingExpressions = aggregate.aggregateExpressions.filter { - e => - removeAlias(e) match { - case aggExpr: AggregateExpression => false - case _ => true + def hasAggregateExpression(e: Expression): Boolean = { + if (e.children.isEmpty && !e.isInstanceOf[AggregateExpression]) { + return false + } + e match { + case _: AggregateExpression => true + case _ => e.children.exists(hasAggregateExpression(_)) + } + } + + def hasAggregateExpressionsWithFilter(e: Expression): Boolean = { + if (e.children.isEmpty && !e.isInstanceOf[AggregateExpression]) { + return false + } + e match { + case aggExpr: AggregateExpression => + aggExpr.filter.isDefined + case _ => e.children.exists(hasAggregateExpressionsWithFilter(_)) + } + } + + case class AggregateAnalzyInfo(originalAggregate: Aggregate) { + + protected def buildAttributesToExpressionsMap( + attributes: Seq[Attribute], + expressions: Seq[Expression]): Map[ExprId, Expression] = { + val map = new mutable.HashMap[ExprId, Expression]() + attributes.zip(expressions).foreach { + case (attr, expr) => + map.put(attr.exprId, expr) + } + map.toMap + } + + protected def replaceAttributes( + expression: Expression, + replaceMap: Map[ExprId, Expression]): Expression = { + expression.transform { + case attr: Attribute => + logError(s"xxx replace attr:$attr") + replaceMap.getOrElse(attr.exprId, attr.asInstanceOf[Expression]) + } + } + + protected def getFilter(): Option[Filter] = { + originalAggregate.child match { + case filter: Filter => Some(filter) + case project @ Project(_, filter: Filter) => Some(filter) + case subquery: SubqueryAlias => + subquery.child match { + case filter: Filter => Some(filter) + case project @ Project(_, filter: Filter) => Some(filter) + case _ => None + } + } + } + + lazy val sourcePlan = { + val filter = getFilter() + if (!filter.isDefined) { + None + } else { + filter.get.child match { + case project: Project => Some(project.child) + case other => Some(other) } + } } - lazy val hasAggregateWithFilter = aggregate.aggregateExpressions.exists { - e => - removeAlias(e) match { - case aggExpr: AggregateExpression => aggExpr.filter.isDefined - case _ => false + lazy val filterPlan = { + val filter = getFilter() + if (!filter.isDefined || !sourcePlan.isDefined) { + None + } 
else { + val project = filter.get.child match { + case project: Project => Some(project) + case other => None + } + val replacedFilter = project match { + case Some(project) => + val replaceMap = buildAttributesToExpressionsMap(project.output, project.child.output) + val replacedCondition = replaceAttributes(filter.get.condition, replaceMap) + Filter(replacedCondition, sourcePlan.get) + case None => filter.get.withNewChildren(Seq(sourcePlan.get)) + } + Some(replacedFilter) + } + } + + lazy val aggregatePlan = { + if (!filterPlan.isDefined) { + None + } else { + + val project = originalAggregate.child match { + case p: Project => Some(p) + case subquery: SubqueryAlias => + subquery.child match { + case p: Project => Some(p) + case _ => None + } + case _ => None } + + val replacedAggregate = project match { + case Some(innerProject) => + val replaceMap = + buildAttributesToExpressionsMap(innerProject.output, innerProject.projectList) + logError(s"xxx replace map:\n$replaceMap") + val groupExpressions = originalAggregate.groupingExpressions.map { + e => replaceAttributes(e, replaceMap) + } + val aggregateExpressions = originalAggregate.aggregateExpressions.map { + e => replaceAttributes(e, replaceMap).asInstanceOf[NamedExpression] + } + logError( + s"xxx group expressions:$groupExpressions\n" + + s"aggregateExpressions:$aggregateExpressions") + Aggregate(groupExpressions, aggregateExpressions, filterPlan.get) + case None => originalAggregate.withNewChildren(Seq(filterPlan.get)) + } + Some(replacedAggregate) + } } - lazy val resultPositionInGroupingKeys = { + lazy val hasAggregateWithFilter = originalAggregate.aggregateExpressions.exists { + e => hasAggregateExpressionsWithFilter(e) + } + + lazy val resultGroupingExpressions = aggregatePlan match { + case Some(agg) => + agg.asInstanceOf[Aggregate].aggregateExpressions.filter(e => !hasAggregateExpression(e)) + case None => Seq.empty + } + + lazy val positionInGroupingKeys = { var i = 0 resultGroupingExpressions.map { e => + val aggregate = aggregatePlan.get.asInstanceOf[Aggregate] e match { case literal @ Alias(_: Literal, _) => var idx = aggregate.groupingExpressions.indexOf(e) @@ -90,13 +209,21 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi i += 1 } idx - case _ => aggregate.groupingExpressions.indexOf(removeAlias(e)) + case _ => + var idx = aggregate.groupingExpressions.indexOf(removeAlias(e)) + idx = if (idx == -1) { + aggregate.groupingExpressions.indexOf(e) + } else { + idx + } + assert(idx != -1, s"Expected $e in ${aggregate.groupingExpressions}") + idx } } } } - case class AnalyzedPlan(plan: LogicalPlan, analyzedAggregateInfo: Option[AnalyzedAggregteInfo]) + case class AnalyzedPlan(plan: LogicalPlan, analyzedInfo: Option[AggregateAnalzyInfo]) override def apply(plan: LogicalPlan): LogicalPlan = { if (plan.resolved) { @@ -118,16 +245,17 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi if (groupedPlans.length == 1) { groupedPlans.head.plan } else { - val aggregates = groupedPlans.map(_.plan) + val firstAggregateAnalzyInfo = groupedPlans.head.analyzedInfo.get + val aggregates = groupedPlans.map(_.analyzedInfo.get.aggregatePlan.get) val replaceAttributes = collectReplaceAttributes(aggregates) val filterConditions = buildAggregateCasesConditions(aggregates, replaceAttributes) val firstAggregateFilter = - groupedPlans.head.plan.asInstanceOf[Aggregate].child.asInstanceOf[Filter] + firstAggregateAnalzyInfo.filterPlan.get.asInstanceOf[Filter] // Concat all filter conditions 
with `or` and apply it on the source node val unionFilter = Filter( buildUnionConditionForAggregateSource(filterConditions), - firstAggregateFilter.child) + firstAggregateAnalzyInfo.sourcePlan.get) logError(s"xxx union filter:\n$unionFilter") val wrappedAttributesProject = @@ -167,17 +295,20 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi newPlan } - def isSupportedAggregate(info: AnalyzedAggregteInfo): Boolean = { + def isSupportedAggregate(info: AggregateAnalzyInfo): Boolean = { if (info.hasAggregateWithFilter) { - false - } else { - true + return false } + + if (!info.aggregatePlan.isDefined) { + return false + } + true } - def areSameStructureAggregate(l: AnalyzedAggregteInfo, r: AnalyzedAggregteInfo): Boolean = { - val lAggregate = l.aggregate - val rAggregate = r.aggregate + def areSameStructureAggregate(l: AggregateAnalzyInfo, r: AggregateAnalzyInfo): Boolean = { + val lAggregate = l.aggregatePlan.get.asInstanceOf[Aggregate] + val rAggregate = r.aggregatePlan.get.asInstanceOf[Aggregate] // Check aggregate result expressions. need same schema. if (lAggregate.aggregateExpressions.length != rAggregate.aggregateExpressions.length) { @@ -189,18 +320,16 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi if (!lExpr.dataType.equals(rExpr.dataType)) { return false } - (removeAlias(lExpr), removeAlias(rExpr)) match { - case (lAggExpr: AggregateExpression, rAggExpr: AggregateExpression) => - if (lAggExpr.aggregateFunction.getClass != rAggExpr.aggregateFunction.getClass) { - return false - } - if ( - lAggExpr.aggregateFunction.children.length != - rAggExpr.aggregateFunction.children.length - ) { + (hasAggregateExpression(lExpr), hasAggregateExpression(rExpr)) match { + case (true, true) => + if (!areEqualExpressions(lExpr, rExpr)) { return false } - case _ => + case (false, true) => + return false + case (true, false) => + return false + case (false, false) => } } @@ -209,40 +338,14 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi return false } - def hasAggregateExpression(e: Expression): Boolean = { - if (e.children.isEmpty && !e.isInstanceOf[AggregateExpression]) { - return false - } - e match { - case _: AggregateExpression => true - case _ => e.children.exists(hasAggregateExpression(_)) - } - } - - // 1. data types must be same. - // 2. The aggregate expressions must have same structure and at the same position. - lAggregate.groupingExpressions.zip(rAggregate.groupingExpressions).foreach { - case (lExpr, rExpr) => - if (!lExpr.dataType.equals(rExpr.dataType)) { - return false - } - val lHasAgg = hasAggregateExpression(lExpr) - val rHasAgg = hasAggregateExpression(rExpr) - if (lHasAgg != rHasAgg) { - return false - } else if (lHasAgg && !areEqualExpressions(lExpr, rExpr)) { - return false - } - } - if ( - l.resultPositionInGroupingKeys.length != - r.resultPositionInGroupingKeys.length + l.positionInGroupingKeys.length != + r.positionInGroupingKeys.length ) { return false } - l.resultPositionInGroupingKeys - .zip(r.resultPositionInGroupingKeys) + l.positionInGroupingKeys + .zip(r.positionInGroupingKeys) .foreach { case (lPos, rPos) => if (lPos != rPos) { @@ -251,11 +354,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } // Must come from same source. 
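// [Editor's sketch, illustrative only -- `sameResultShape`, `lExprs`, `rExprs` and
// `isAgg` are hypothetical names, not part of this patch] The structural check in
// this method reduces to a pairwise comparison of the two clauses' result columns,
// plus the same-source test below:
import org.apache.spark.sql.catalyst.expressions.NamedExpression
def sameResultShape(
    lExprs: Seq[NamedExpression],
    rExprs: Seq[NamedExpression],
    isAgg: NamedExpression => Boolean): Boolean =
  lExprs.length == rExprs.length &&
    lExprs.zip(rExprs).forall {
      case (l, r) => l.dataType == r.dataType && isAgg(l) == isAgg(r)
    }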
- if ( - areSameAggregateSource( - lAggregate.child.asInstanceOf[Filter].child, - rAggregate.child.asInstanceOf[Filter].child) - ) { + if (areSameAggregateSource(l.sourcePlan.get, r.sourcePlan.get)) { return true } @@ -265,14 +364,12 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi // If returns -1, not found same structure aggregate. def findSameStructureAggregate( planGroups: ArrayBuffer[ArrayBuffer[AnalyzedPlan]], - analyzedAggregateInfo: AnalyzedAggregteInfo): Int = { + analyzedInfo: AggregateAnalzyInfo): Int = { planGroups.zipWithIndex.foreach { case (groupedPlans, i) => if ( - groupedPlans.head.analyzedAggregateInfo.isDefined && - areSameStructureAggregate( - groupedPlans.head.analyzedAggregateInfo.get, - analyzedAggregateInfo) + groupedPlans.head.analyzedInfo.isDefined && + areSameStructureAggregate(groupedPlans.head.analyzedInfo.get, analyzedInfo) ) { return i } @@ -283,8 +380,8 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi def groupSameStructureAggregate(union: Union): ArrayBuffer[ArrayBuffer[AnalyzedPlan]] = { val groupResults = ArrayBuffer[ArrayBuffer[AnalyzedPlan]]() union.children.foreach { - case agg @ Aggregate(_, _, filter: Filter) => - val analyzedInfo = AnalyzedAggregteInfo(agg) + case agg: Aggregate => + val analyzedInfo = AggregateAnalzyInfo(agg) if (isSupportedAggregate(analyzedInfo)) { if (groupResults.isEmpty) { groupResults += ArrayBuffer(AnalyzedPlan(agg, Some(analyzedInfo))) @@ -403,8 +500,8 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi val structPrefix = "field_" groupedPlans.zipWithIndex.foreach { case (aggregateCase, case_index) => - val aggregate = aggregateCase.plan.asInstanceOf[Aggregate] - val analyzedInfo = aggregateCase.analyzedAggregateInfo.get + val analyzedInfo = aggregateCase.analyzedInfo.get + val aggregate = analyzedInfo.aggregatePlan.get.asInstanceOf[Aggregate] val structFields = ArrayBuffer[Expression]() var fieldIndex: Int = 0 aggregate.groupingExpressions.foreach { @@ -413,8 +510,8 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi structFields += replaceAttributes(e, replaceMap) fieldIndex += 1 } - for (i <- 0 until analyzedInfo.resultPositionInGroupingKeys.length) { - val position = analyzedInfo.resultPositionInGroupingKeys(i) + for (i <- 0 until analyzedInfo.positionInGroupingKeys.length) { + val position = analyzedInfo.positionInGroupingKeys(i) if (position >= fieldIndex) { val expr = analyzedInfo.resultGroupingExpressions(i) structFields += Literal(UTF8String.fromString(s"$structPrefix$fieldIndex"), StringType) @@ -531,19 +628,22 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi child: LogicalPlan, groupedPlans: ArrayBuffer[AnalyzedPlan]): LogicalPlan = { val attributes = child.output - val aggregateTemplate = groupedPlans.head.plan.asInstanceOf[Aggregate] - val analyzedAggregateInfo = groupedPlans.head.analyzedAggregateInfo.get + val firstAggregateAnalzyInfo = groupedPlans.head.analyzedInfo.get + val aggregateTemplate = firstAggregateAnalzyInfo.aggregatePlan.get.asInstanceOf[Aggregate] + val analyzedInfo = groupedPlans.head.analyzedInfo.get val totalGroupingExpressionsCount = math.max( aggregateTemplate.groupingExpressions.length, - analyzedAggregateInfo.resultPositionInGroupingKeys.max + 1) + analyzedInfo.positionInGroupingKeys.max + 1) val groupingExpressions = attributes .slice(0, totalGroupingExpressionsCount) .map(_.asInstanceOf[Expression]) :+ attributes.last 
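// [Editor's note, illustrative only] The trailing attribute appended above is the
// per-clause union tag (`group_id` in the SQL sketch at the head of this file).
// It must remain a grouping key: rows from different union clauses can share the
// same key values, and (a='x', tag=0) and (a='x', tag=1) must stay two separate
// groups, one per original union clause.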
- val normalExpressionPosition = analyzedAggregateInfo.resultPositionInGroupingKeys + val normalExpressionPosition = analyzedInfo.positionInGroupingKeys + logError(s"xxx normalExpressionPosition:$normalExpressionPosition") + logError(s"xxx aggregateExpressions:${aggregateTemplate.aggregateExpressions}") var normalExpressionCount = 0 var aggregateExpressionIndex = totalGroupingExpressionsCount val aggregateExpressions = ArrayBuffer[NamedExpression]() @@ -568,6 +668,9 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi aggregateExpressions += makeAlias(newAggExpr, e.name) case other => + logError( + s"xxx normalExpressionPosition.len:${normalExpressionPosition.length}" + + s", normalExpressionCount:$normalExpressionCount") val position = normalExpressionPosition(normalExpressionCount) val attr = attributes(position) normalExpressionCount += 1 From 7d02490b6a3000167dc30441f2b37b9e179bb8fa Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Thu, 6 Feb 2025 09:48:25 +0800 Subject: [PATCH 6/9] wip --- .../extension/CoalesceAggregationUnion.scala | 362 ++++++++++++------ .../GlutenCoalesceAggregationUnionSuite.scala | 341 +++++++++++++++++ 2 files changed, 593 insertions(+), 110 deletions(-) create mode 100644 backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala index 56756eebe544..4a592d03a34c 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala @@ -49,7 +49,7 @@ import scala.collection.mutable.ArrayBuffer * ) WHERE s is not null * ) GROUP BY a,b, group_id * - * The first query need to scan `t` multiple times, when the output of scan is large, the query is + * The first query needs to scan `t` multiple times; when the output of the scan is large, the query is * really slow. The rewritten query only scan `t` once, and the performance is much better. */ @@ -61,27 +61,146 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } } + def isAggregateExpression(e: Expression): Boolean = { + e match { + case cast: Cast => isAggregateExpression(cast.child) + case alias: Alias => isAggregateExpression(alias.child) + case agg: AggregateExpression => true + case _ => false + } + } + def hasAggregateExpressionsWithFilter(e: Expression): Boolean = { if (e.children.isEmpty && !e.isInstanceOf[AggregateExpression]) { return false } @@ -100,7 +109,6 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi replaceMap: Map[ExprId, Expression]): Expression = { expression.transform { case attr: Attribute => - logError(s"xxx replace attr:$attr") replaceMap.getOrElse(attr.exprId, attr.asInstanceOf[Expression]) } } @@ -110,14 +118,27 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi case filter: Filter => Some(filter) case project @ Project(_, filter: Filter) => Some(filter) case subquery: SubqueryAlias => + logError(s"xxx subquery child. 
${subquery.child.getClass}") subquery.child match { case filter: Filter => Some(filter) case project @ Project(_, filter: Filter) => Some(filter) + case relation: LogicalRelation => Some(Filter(Literal(true, BooleanType), subquery)) + case nestedRelation: SubqueryAlias => + logError( + s"xxx nestedRelation child. ${nestedRelation.child.getClass}" + + s"\n$nestedRelation") + if (nestedRelation.child.isInstanceOf[LogicalRelation]) { + Some(Filter(Literal(true, BooleanType), nestedRelation)) + } else { + None + } case _ => None } + case _ => None } } + // Try to simplify the plan so that it contains only three steps: source, filter, aggregate. lazy val sourcePlan = { val filter = getFilter() if (!filter.isDefined) { @@ -137,7 +158,8 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } else { val project = filter.get.child match { case project: Project => Some(project) - case other => None + case other => + None } val replacedFilter = project match { case Some(project) => @@ -154,7 +176,6 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi if (!filterPlan.isDefined) { None } else { - val project = originalAggregate.child match { case p: Project => Some(p) case subquery: SubqueryAlias => @@ -169,16 +190,12 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi case Some(innerProject) => val replaceMap = buildAttributesToExpressionsMap(innerProject.output, innerProject.projectList) - logError(s"xxx replace map:\n$replaceMap") val groupExpressions = originalAggregate.groupingExpressions.map { e => replaceAttributes(e, replaceMap) } val aggregateExpressions = originalAggregate.aggregateExpressions.map { e => replaceAttributes(e, replaceMap).asInstanceOf[NamedExpression] } - logError( - s"xxx group expressions:$groupExpressions\n" + - s"aggregateExpressions:$aggregateExpressions") Aggregate(groupExpressions, aggregateExpressions, filterPlan.get) case None => originalAggregate.withNewChildren(Seq(filterPlan.get)) } @@ -190,6 +207,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi e => hasAggregateExpressionsWithFilter(e) } + // The output result expressions which are not aggregate expressions. lazy val resultGroupingExpressions = aggregatePlan match { case Some(agg) => agg.asInstanceOf[Aggregate].aggregateExpressions.filter(e => !hasAggregateExpression(e)) @@ -198,6 +216,11 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi lazy val positionInGroupingKeys = { var i = 0 + // In most cases, an expression which is not an aggregate result can be matched with one of + // the grouping keys. There are some exceptions: + // 1. The expression is a literal. The grouping keys do not contain the literal. + // 2. The expression is an expression built from grouping keys. For example, + // `select k1 + k2, count(1) from t group by k1, k2`. resultGroupingExpressions.map { e => val aggregate = aggregatePlan.get.asInstanceOf[Aggregate] e match { case literal @ Alias(_: Literal, _) => var idx = aggregate.groupingExpressions.indexOf(e) @@ -216,22 +239,29 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } else { idx } - assert(idx != -1, s"Expected $e in ${aggregate.groupingExpressions}") idx } } } } + /* + * Case class representing an analyzed plan. + * + * @param plan The logical plan to be analyzed. + * @param analyzedInfo Optional information about the aggregate analysis. 
+ */ case class AnalyzedPlan(plan: LogicalPlan, analyzedInfo: Option[AggregateAnalzyInfo]) override def apply(plan: LogicalPlan): LogicalPlan = { if (plan.resolved) { logError(s"xxx visit plan:\n$plan") val newPlan = visitPlan(plan) - logError(s"xxx rewritten plan:\n$newPlan") + logError(s"xxx output attributes:\n${newPlan.output}\n${plan.output}") + logError(s"xxx rewrite plan:\n$newPlan") newPlan } else { + logError(s"xxx plan not resolved:\n$plan") plan } } @@ -239,7 +269,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi def visitPlan(plan: LogicalPlan): LogicalPlan = { val newPlan = plan match { case union: Union => - val planGroups = groupSameStructureAggregate(union) + val planGroups = groupStructureMatchedAggregate(union) val newUnionClauses = planGroups.map { groupedPlans => if (groupedPlans.length == 1) { @@ -249,49 +279,68 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi groupedPlans.head.plan } else { val firstAggregateAnalzyInfo = groupedPlans.head.analyzedInfo.get val aggregates = groupedPlans.map(_.analyzedInfo.get.aggregatePlan.get) val replaceAttributes = collectReplaceAttributes(aggregates) val filterConditions = buildAggregateCasesConditions(aggregates, replaceAttributes) + logError(s"xxx filterConditions. ${filterConditions.length},\n$filterConditions") val firstAggregateFilter = firstAggregateAnalzyInfo.filterPlan.get.asInstanceOf[Filter] - // Concat all filter conditions with `or` and apply it on the source node + // Add a filter step with condition `cond1 or cond2 or ...`, where each `cond_i` comes + // from one union clause. Apply this filter on the source plan. val unionFilter = Filter( buildUnionConditionForAggregateSource(filterConditions), firstAggregateAnalzyInfo.sourcePlan.get) - logError(s"xxx union filter:\n$unionFilter") - val wrappedAttributesProject = - buildStructWrapperProject( - unionFilter, - groupedPlans, - filterConditions, - replaceAttributes) - logError(s"xxx wrapped attributes project:\n$wrappedAttributesProject") + // Wrap all the attributes into a single structure attribute. + val wrappedAttributesProject = buildStructWrapperProject( + unionFilter, + groupedPlans, + filterConditions, + replaceAttributes) + // Build an array whose elements correspond to the union clauses. val arrayProject = buildArrayProject(wrappedAttributesProject, filterConditions) - logError(s"xxx array project:\n$arrayProject") + // Explode the array. val explode = buildArrayExplode(arrayProject) - logError(s"xxx explode:\n$explode") + // A null value means that the union clause does not have the corresponding data. val notNullFilter = Filter(IsNotNull(explode.output.head), explode) - logError(s"xxx not null filter:\n$notNullFilter") + // Destructure the struct attribute. val destructStructProject = buildDestructStructProject(notNullFilter) - logError(s"xxx destruct struct project:\n$destructStructProject") - val singleAggregate = buildAggregateWithGroupId(destructStructProject, groupedPlans) - logError(s"xxx single agg:\n$singleAggregate") - - singleAggregate + buildAggregateWithGroupId(destructStructProject, groupedPlans) } } - if (newUnionClauses.length == 1) { + logError(s"xxx newUnionClauses. 
${newUnionClauses.length},\n$newUnionClauses") val coalesePlan = if (newUnionClauses.length == 1) { newUnionClauses.head } else { - union.withNewChildren(newUnionClauses) + var firstUnionChild = newUnionClauses.head + for (i <- 1 until newUnionClauses.length - 1) { + firstUnionChild = Union(firstUnionChild, newUnionClauses(i)) + } + Union(firstUnionChild, newUnionClauses.last) } logError(s"xxx coalesePlan:$coalesePlan") + + // We need to keep the output attributes the same as the original plan's. + val outputAttrPairs = coalesePlan.output.zip(union.output) + if (outputAttrPairs.forall(pair => pair._1.semanticEquals(pair._2))) { + coalesePlan + } else { + val reprejectOutputs = outputAttrPairs.map { + case (newAttr, oldAttr) => + if (newAttr.exprId == oldAttr.exprId) { + newAttr + } else { + Alias(newAttr, oldAttr.name)(oldAttr.exprId, oldAttr.qualifier, None, Seq.empty) + } + } + Project(reprejectOutputs, coalesePlan) } case _ => plan.withNewChildren(plan.children.map(visitPlan)) } - newPlan.copyTagsFrom(plan) + // newPlan.copyTagsFrom(plan) newPlan } @@ -299,6 +348,29 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi if (info.hasAggregateWithFilter) { return false } + if (!info.aggregatePlan.isDefined) { + return false + } + + if (info.positionInGroupingKeys.exists(_ < 0)) { + return false + } + + // `agg_fun1(x) + agg_fun2(y)` is supported, but `agg_fun1(x) + y` is not supported. + if ( + info.originalAggregate.aggregateExpressions.exists { + e => + val innerExpr = removeAlias(e) + if (hasAggregateExpression(innerExpr)) { + !innerExpr.isInstanceOf[AggregateExpression] && + !innerExpr.children.forall(e => isAggregateExpression(e)) + } else { + false + } + } + ) { + return false + } if (!info.aggregatePlan.isDefined) { return false } true } - def areSameStructureAggregate(l: AggregateAnalzyInfo, r: AggregateAnalzyInfo): Boolean = { + /** + * Checks if two AggregateAnalzyInfo instances have the same structure. + * + * This method compares the aggregate expressions, grouping expressions, and the source plans of + * the two AggregateAnalzyInfo instances to determine if they have the same structure. + * + * @param l + * The first AggregateAnalzyInfo instance. + * @param r + * The second AggregateAnalzyInfo instance. + * @return + * True if the two instances have the same structure, false otherwise. + */ + def areStructureMatchedAggregate(l: AggregateAnalzyInfo, r: AggregateAnalzyInfo): Boolean = { val lAggregate = l.aggregatePlan.get.asInstanceOf[Aggregate] val rAggregate = r.aggregatePlan.get.asInstanceOf[Aggregate] // Check aggregate result expressions. They need the same schema. 
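// [Editor's example, hypothetical schemas] Positionally, (sum(y): long, a: string)
// can match (sum(z): long, b: string), but not (a: string, sum(y): long) -- the
// pairwise comparison below walks columns at the same position -- and not
// (avg(y): double, a: string), whose data type and aggregate function differ.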
if (lAggregate.aggregateExpressions.length != rAggregate.aggregateExpressions.length) { + logError(s"xxx not equal 1") + return false } - - lAggregate.aggregateExpressions.zip(rAggregate.aggregateExpressions).foreach { - case (lExpr, rExpr) => - if (!lExpr.dataType.equals(rExpr.dataType)) { - return false - } - (hasAggregateExpression(lExpr), hasAggregateExpression(rExpr)) match { - case (true, true) => - if (!areEqualExpressions(lExpr, rExpr)) { - return false + val allAggregateExpressionAreSame = + lAggregate.aggregateExpressions.zip(rAggregate.aggregateExpressions).forall { + case (lExpr, rExpr) => + if (!lExpr.dataType.equals(rExpr.dataType)) { + false + } else { + (hasAggregateExpression(lExpr), hasAggregateExpression(rExpr)) match { + case (true, true) => + areStructureMatchedExpressions(lExpr, rExpr) + case (false, true) => false + case (true, false) => false + case (false, false) => true } - case (false, true) => - return false - case (true, false) => - return false - case (false, false) => - } + } + } + if (!allAggregateExpressionAreSame) { + return false } // Check grouping expressions, need same schema. if (lAggregate.groupingExpressions.length != rAggregate.groupingExpressions.length) { + logError(s"xxx not equal 3") return false } - if ( l.positionInGroupingKeys.length != r.positionInGroupingKeys.length ) { + logError(s"xxx not equal 4") return false } - l.positionInGroupingKeys + val allSameGroupingKeysRef = l.positionInGroupingKeys .zip(r.positionInGroupingKeys) - .foreach { - case (lPos, rPos) => - if (lPos != rPos) { - return false - } - } + .forall { case (lPos, rPos) => lPos == rPos } + if (!allSameGroupingKeysRef) { + logError(s"xxx not equal 5") + return false + } // Must come from same source. - if (areSameAggregateSource(l.sourcePlan.get, r.sourcePlan.get)) { - return true + if (!areSameAggregateSource(l.sourcePlan.get, r.sourcePlan.get)) { + logError(s"xxx not same source. ${l.sourcePlan.get}\n${r.sourcePlan.get}") + return false } true } - // If returns -1, not found same structure aggregate. - def findSameStructureAggregate( + /* + * Finds the index of the first group in `planGroups` that has the same structure as the given + * `analyzedInfo`. + * + * This method iterates over the `planGroups` and checks if the first `AnalyzedPlan` in each group + * has an `analyzedInfo` that matches the structure of the provided `analyzedInfo`. If a match is + * found, the index of the group is returned. If no match is found, -1 is returned. + * + * @param planGroups + * An ArrayBuffer of ArrayBuffers, where each inner ArrayBuffer contains `AnalyzedPlan` + * instances. + * @param analyzedInfo + * The `AggregateAnalzyInfo` to match against the groups in `planGroups`. + * @return + * The index of the first group with a matching structure, or -1 if no match is found. 
+ */ def findStructureMatchedAggregate( planGroups: ArrayBuffer[ArrayBuffer[AnalyzedPlan]], analyzedInfo: AggregateAnalzyInfo): Int = { - planGroups.zipWithIndex.foreach { - case (groupedPlans, i) => - if ( - groupedPlans.head.analyzedInfo.isDefined && - areSameStructureAggregate(groupedPlans.head.analyzedInfo.get, analyzedInfo) - ) { - return i - } + planGroups.zipWithIndex.find( + planWithIndex => + planWithIndex._1.head.analyzedInfo.isDefined && + areStructureMatchedAggregate( + planWithIndex._1.head.analyzedInfo.get, + analyzedInfo)) match { + case Some((_, i)) => i + case None => -1 } - -1 + } - def groupSameStructureAggregate(union: Union): ArrayBuffer[ArrayBuffer[AnalyzedPlan]] = { - val groupResults = ArrayBuffer[ArrayBuffer[AnalyzedPlan]]() + // Union only has two children. Its children may also be Unions. + def collectAllUnionClauses(union: Union): ArrayBuffer[LogicalPlan] = { + val unionClauses = ArrayBuffer[LogicalPlan]() union.children.foreach { + case u: Union => + unionClauses ++= collectAllUnionClauses(u) + case other => + unionClauses += other + } + unionClauses + } + + def groupStructureMatchedAggregate(union: Union): ArrayBuffer[ArrayBuffer[AnalyzedPlan]] = { + val groupResults = ArrayBuffer[ArrayBuffer[AnalyzedPlan]]() + collectAllUnionClauses(union).foreach { case agg: Aggregate => val analyzedInfo = AggregateAnalzyInfo(agg) if (isSupportedAggregate(analyzedInfo)) { if (groupResults.isEmpty) { groupResults += ArrayBuffer(AnalyzedPlan(agg, Some(analyzedInfo))) } else { - val idx = findSameStructureAggregate(groupResults, analyzedInfo) + val idx = findStructureMatchedAggregate(groupResults, analyzedInfo) if (idx != -1) { groupResults(idx) += AnalyzedPlan(agg, Some(analyzedInfo)) } else { @@ -394,6 +509,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } } } else { + logError(s"xxx not supported. 
$agg") groupResults += ArrayBuffer(AnalyzedPlan(agg, None)) } case other => @@ -402,9 +518,10 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi groupResults } - def areEqualExpressions(l: Expression, r: Expression): Boolean = { + def areStructureMatchedExpressions(l: Expression, r: Expression): Boolean = { (l, r) match { case (lAttr: Attribute, rAttr: Attribute) => + logError(s"xxx attr equal: ${lAttr.qualifiedName}, ${rAttr.qualifiedName}") lAttr.qualifiedName == rAttr.qualifiedName case (lLiteral: Literal, rLiteral: Literal) => lLiteral.value.equals(rLiteral.value) @@ -413,7 +530,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi false } else { l.children.zip(r.children).forall { - case (lChild, rChild) => areEqualExpressions(lChild, rChild) + case (lChild, rChild) => areStructureMatchedExpressions(lChild, rChild) } } } @@ -427,6 +544,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi case (lRel: LogicalRelation, rRel: LogicalRelation) => val lTable = lRel.catalogTable.map(_.identifier.unquotedString).getOrElse("") val rTable = rRel.catalogTable.map(_.identifier.unquotedString).getOrElse("") + logError(s"xxx table equal: $lTable, $rTable") lTable.equals(rTable) && lTable.nonEmpty case (lSubQuery: SubqueryAlias, rSubQuery: SubqueryAlias) => areSameAggregateSource(lSubQuery.child, rSubQuery.child) @@ -521,23 +639,35 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } aggregate.aggregateExpressions - .filter( - e => - removeAlias(e) match { - case aggExpr: AggregateExpression => true - case _ => false - }) + .filter(e => hasAggregateExpression(e)) .foreach { e => - val aggFunction = removeAlias(e).asInstanceOf[AggregateExpression].aggregateFunction - aggFunction.children.foreach { - child => - structFields += Literal( - UTF8String.fromString(s"$structPrefix$fieldIndex"), - StringType) - structFields += replaceAttributes(child, replaceMap) - fieldIndex += 1 + def collectExpressionsInAggregateExpression(aggExpr: Expression): Unit = { + aggExpr match { + case aggExpr: AggregateExpression => + val aggFunction = + removeAlias(aggExpr).asInstanceOf[AggregateExpression].aggregateFunction + aggFunction.children.foreach { + child => + structFields += Literal( + UTF8String.fromString(s"$structPrefix$fieldIndex"), + StringType) + structFields += replaceAttributes(child, replaceMap) + fieldIndex += 1 + } + case combineAgg if hasAggregateExpression(combineAgg) => + combineAgg.children.foreach { + combindAggchild => collectExpressionsInAggregateExpression(combindAggchild) + } + case other => + structFields += Literal( + UTF8String.fromString(s"$structPrefix$fieldIndex"), + StringType) + structFields += replaceAttributes(other, replaceMap) + fieldIndex += 1 + } } + collectExpressionsInAggregateExpression(e) } structFields += Literal(UTF8String.fromString(s"$structPrefix$fieldIndex"), StringType) structFields += Literal(case_index, IntegerType) @@ -642,35 +772,19 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi .map(_.asInstanceOf[Expression]) :+ attributes.last val normalExpressionPosition = analyzedInfo.positionInGroupingKeys - logError(s"xxx normalExpressionPosition:$normalExpressionPosition") - logError(s"xxx aggregateExpressions:${aggregateTemplate.aggregateExpressions}") var normalExpressionCount = 0 var aggregateExpressionIndex = totalGroupingExpressionsCount val aggregateExpressions = ArrayBuffer[NamedExpression]() 
aggregateTemplate.aggregateExpressions.foreach { e => removeAlias(e) match { - case aggExpr: AggregateExpression => - val aggFunc = aggExpr.aggregateFunction - val newAggFuncArgs = aggFunc.children.zipWithIndex.map { - case (arg, i) => - attributes(aggregateExpressionIndex + i) - } - aggregateExpressionIndex += aggFunc.children.length - val newAggFunc = - aggFunc.withNewChildren(newAggFuncArgs).asInstanceOf[AggregateFunction] - val newAggExpr = AggregateExpression( - newAggFunc, - aggExpr.mode, - aggExpr.isDistinct, - aggExpr.filter, - aggExpr.resultId) - aggregateExpressions += makeAlias(newAggExpr, e.name) - + case aggExpr if hasAggregateExpression(aggExpr) => + aggregateExpressions += makeAlias( + constructAggregateExpression(aggExpr, attributes, aggregateExpressionIndex), + e.name) + .asInstanceOf[NamedExpression] + aggregateExpressionIndex += aggExpr.children.length case other => - logError( - s"xxx normalExpressionPosition.len:${normalExpressionPosition.length}" + - s", normalExpressionCount:$normalExpressionCount") val position = normalExpressionPosition(normalExpressionCount) val attr = attributes(position) normalExpressionCount += 1 @@ -680,4 +794,32 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } Aggregate(groupingExpressions, aggregateExpressions, child) } + + def constructAggregateExpression( + aggExpr: Expression, + attributes: Seq[Attribute], + index: Int): Expression = { + aggExpr match { + case singleAggExpr: AggregateExpression => + val aggFunc = singleAggExpr.aggregateFunction + val newAggFuncArgs = aggFunc.children.zipWithIndex.map { + case (arg, i) => + attributes(index + i) + } + val newAggFunc = + aggFunc.withNewChildren(newAggFuncArgs).asInstanceOf[AggregateFunction] + AggregateExpression( + newAggFunc, + singleAggExpr.mode, + singleAggExpr.isDistinct, + singleAggExpr.filter, + singleAggExpr.resultId) + case combineAggExpr if hasAggregateExpression(combineAggExpr) => + combineAggExpr.withNewChildren( + combineAggExpr.children.map(constructAggregateExpression(_, attributes, index))) + case _ => + val normalExpr = attributes(index) + normalExpr + } + } } diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala new file mode 100644 index 000000000000..ae5b4b7d5a92 --- /dev/null +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala @@ -0,0 +1,341 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.gluten.execution + +import org.apache.gluten.config.GlutenConfig +import org.apache.gluten.utils.UTSystemParameters + +import org.apache.spark.SparkConf +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig +import org.apache.spark.sql.types._ + +import java.nio.file.Files + +class GlutenCoalesceAggregationUnionSuite extends GlutenClickHouseWholeStageTransformerSuite { + override protected def sparkConf: SparkConf = { + super.sparkConf + .set("spark.sql.files.maxPartitionBytes", "1g") + .set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") + .set("spark.sql.shuffle.partitions", "5") + .set("spark.sql.adaptive.enabled", "false") + .set("spark.sql.files.minPartitionNum", "1") + .set( + "spark.sql.catalog.spark_catalog", + "org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseSparkCatalog") + .set("spark.databricks.delta.maxSnapshotLineageLength", "20") + .set("spark.databricks.delta.snapshotPartitions", "1") + .set("spark.databricks.delta.properties.defaults.checkpointInterval", "5") + .set("spark.databricks.delta.stalenessLimit", "3600000") + .set(ClickHouseConfig.CLICKHOUSE_WORKER_ID, "1") + .set(GlutenConfig.GLUTEN_LIB_PATH, UTSystemParameters.clickHouseLibPath) + .set("spark.gluten.sql.columnar.iterator", "true") + .set("spark.gluten.sql.columnar.hashagg.enablefinal", "true") + .set("spark.gluten.sql.enable.native.validation", "false") + .set("spark.sql.warehouse.dir", warehouse) + .set("spark.shuffle.manager", "sort") + .set("spark.io.compression.codec", "snappy") + .set("spark.sql.shuffle.partitions", "5") + .set("spark.sql.autoBroadcastJoinThreshold", "10MB") + } + + def createTestTable(tableName: String, data: DataFrame): Unit = { + val tempFile = Files.createTempFile("", ".parquet").toFile + tempFile.deleteOnExit() + val tempFilePath = tempFile.getAbsolutePath + data.coalesce(1).write.format("parquet").mode("overwrite").parquet(tempFilePath) + spark.catalog.createTable(tableName, tempFilePath, "parquet") + } + + override def beforeAll(): Unit = { + super.beforeAll() + + val schema = StructType( + Array( + StructField("a", StringType, nullable = true), + StructField("b", IntegerType, nullable = true), + StructField("x", StringType, nullable = true), + StructField("y", IntegerType, nullable = true) + )) + val data = sparkContext.parallelize( + Seq( + Row("a", 1, null, 1), + Row("a", 2, "a", 2), + Row("a", 3, "b", 3), + Row("a", 4, "c", 4), + Row("b", 1, "d", 5), + Row("b", 2, "e", 6), + Row("b", 3, "f", 7), + Row("b", 4, "g", null) + )) + + val dataFrame = spark.createDataFrame(data, schema) + createTestTable("coalesce_union_t1", dataFrame) + createTestTable("coalesce_union_t2", dataFrame) + } + + def checkNoUnion(df: DataFrame): Unit = { + val unions = collectWithSubqueries(df.queryExecution.executedPlan) { + case e: ColumnarUnionExec => e + } + assert(unions.isEmpty) + } + + def checkHasUnion(df: DataFrame): Unit = { + val unions = collectWithSubqueries(df.queryExecution.executedPlan) { + case e: ColumnarUnionExec => e + } + assert(unions.size == 1) + } + + test("coalesce aggregation union. 
case 1") { + val sql = + """ + |select a, x + 1 as x, y from ( + | select a, count(x) as x, sum(y) as y from coalesce_union_t1 where b % 3 = 0 + | group by a + | union all + | select a, count(x) as x, sum(y) as y from coalesce_union_t1 where b % 3 = 1 + | group by a + | union all + | select a, count(x) as x, sum(y) as y from coalesce_union_t1 where b % 3 = 2 + | group by a + |) order by a, x, y + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, checkNoUnion, true) + } + + test("coalesce aggregation union. case 2") { + val sql = + """ + |select a, x + 1 as x, y from ( + | select a, count(x) as x, sum(y) as y from coalesce_union_t1 where b % 3 = 0 + | group by a + | union all + | select a, count(x) as x, sum(y) as y from coalesce_union_t1 where b % 3 = 1 + | group by a + | union all + | select a, count(x) as x, sum(y) as y from coalesce_union_t1 where b > 1 + | group by a + |) order by a, x, y + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, checkNoUnion, true) + } + + test("coalesce aggregation union. case 3") { + val sql = + """ + |select a, x + 1 as x, y from ( + | select a, 1 as t, count(x) as x, sum(y) as y from coalesce_union_t1 where b % 3 = 0 + | group by a + | union all + | select a, 2 as t, count(x) as x, sum(y) as y from coalesce_union_t1 where b % 3 = 1 + | group by a + |) order by a, t, x, y + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, checkNoUnion, true) + } + + test("coalesce aggregation union. case 4") { + val sql = + """ + |select * from ( + | select a, 1 as t, count(x) + sum(y) as y from coalesce_union_t1 where b % 3 = 0 + | group by a + | union all + | select a, 2 as t, count(x) + sum(y) as y from coalesce_union_t1 where b % 3 = 1 + | group by a + |) order by a, t, y + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, checkNoUnion, true) + } + + test("coalesce aggregation union. case 5") { + val sql = + """ + |select * from ( + | select a, count(x) as x, sum(y) as y from coalesce_union_t1 where b % 3 = 0 + | group by a, b + | union all + | select a, count(x) as x, sum(y) as y from coalesce_union_t1 where b % 3 = 1 + | group by a, b + |) order by a, x, y + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, checkNoUnion, true) + } + + test("coalesce aggregation union. case 6") { + val sql = + """ + |select * from ( + | select y + 1 as y , count(x) as x from coalesce_union_t1 where b % 3 = 0 + | group by y + 1 + | union all + | select y + 1 as y, count(x) as x from coalesce_union_t1 where b % 3 = 1 + | group by y + 1 + |) order by y, x + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, checkNoUnion, true) + } + + test("coalesce aggregation union. case 7") { + val sql = + """ + |select * from ( + | select a, count(x) as x, sum(y + 1) as y from coalesce_union_t1 where b % 3 = 0 + | group by a, b + | union all + | select a, count(x) as x, sum(y + 1) as y from coalesce_union_t1 where b % 3 = 1 + | group by a, b + |) order by a, x, y + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, checkNoUnion, true) + } + + test("coalesce aggregation union. case 8") { + val sql = + """ + |select * from ( + | select a as a, sum(y) as y from coalesce_union_t1 where b % 3 = 0 + | group by a + | union all + | select x as a , sum(y) as y from coalesce_union_t1 where b % 3 = 1 + | group by x + |) order by a, y + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, checkNoUnion, true) + } + + test("coalesce aggregation union. 
case 9") { + val sql = + """ + |select a, x + 1 as x, y from ( + | select a, count(x) as x, sum(y) as y from coalesce_union_t1 where b % 3 = 0 + | group by a + | union all + | select a, count(x) as x, sum(y) as y from coalesce_union_t1 where b % 3 = 1 + | group by a + | union all + | select a, count(x) as x, sum(y) as y from coalesce_union_t1 where b % 3 = 2 + | group by a + | union all + | select a, count(x) as x, sum(y) as y from coalesce_union_t1 where b % 3 = 3 + | group by a + |) order by a, x, y + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, checkNoUnion, true) + } + + test("no coalesce aggregation union. case 1") { + val sql = + """ + |select * from ( + | select a, count(x) + 1 as x, sum(y) as y from coalesce_union_t1 where b % 3 = 0 + | group by a + | union all + | select a, count(x) + 1 as x, sum(y) as y from coalesce_union_t1 where b % 3 = 1 + | group by a + |) order by a, x, y + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, checkHasUnion, true) + } + + test("no coalesce aggregation union. case 2") { + val sql = + """ + |select * from ( + | select y + 1 as y, count(x) as x from coalesce_union_t1 where b % 3 = 0 + | group by y + | union all + | select y + 1 as y, count(x) as x from coalesce_union_t1 where b % 3 = 1 + | group by y + |) order by y, x + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, checkHasUnion, true) + } + + test("no coalesce aggregation union. case 3") { + val sql = + """ + |select * from ( + | select a, count(x) as x from coalesce_union_t1 where b % 3 = 0 + | group by a + | union all + | select a, count(y) as x from coalesce_union_t1 where b % 3 = 1 + | group by a + |) order by a, x + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, checkHasUnion, true) + } + + test("no coalesce aggregation union. case 4") { + val sql = + """ + |select * from ( + | select a, 1 as b, count(x) as x from coalesce_union_t1 where b % 3 = 0 + | group by a, 1 + | union all + | select a, b, count(x) as x from coalesce_union_t1 where b % 3 = 1 + | group by a, b + |) order by a, b, x + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, checkHasUnion, true) + } + + test("no coalesce aggregation union. case 5") { + val sql = + """ + |select * from ( + | select a, b, count(x) as x from coalesce_union_t1 where b % 3 = 0 + | group by a, b + | union all + | select a, b, count(x) as x from coalesce_union_t2 where b % 3 = 1 + | group by a, b + |) order by a, b, x + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, checkHasUnion, true) + } + + test("no coalesce aggregation union. case 6") { + val sql = + """ + |select * from ( + | select a as k1, x as k2, count(y) as c from coalesce_union_t1 where b % 3 = 0 + | group by a, x + | union all + | select x as k1, a as k2, count(y) as c from coalesce_union_t1 where b % 3 = 1 + | group by a, x + |) order by k1, k2, c + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, checkHasUnion, true) + } + + test("no coalesce aggregation union. 
case 7") { + val sql = + """ + |select * from ( + | select a, count(y) as y from coalesce_union_t1 where b % 3 = 0 + | group by a + | union all + | select a, count(y) as y from coalesce_union_t2 + | group by a + |) order by a, y + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, checkHasUnion, true) + } + +} From 2b41a7e9756ef76eecb6d8d2f523a719e6b2d8db Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Thu, 6 Feb 2025 17:05:58 +0800 Subject: [PATCH 7/9] wip --- .../backendsapi/clickhouse/CHBackend.scala | 3 + .../backendsapi/clickhouse/CHRuleApi.scala | 1 + .../extension/CoalesceAggregationUnion.scala | 484 ++++++++---------- .../GlutenCoalesceAggregationUnionSuite.scala | 44 +- 4 files changed, 267 insertions(+), 265 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala index e883f8c454ce..f47aa0416cc3 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala @@ -156,6 +156,9 @@ object CHBackendSettings extends BackendSettingsApi with Logging { CHConf.prefixOf("convert.left.anti_semi.to.right") val GLUTEN_CLICKHOUSE_CONVERT_LEFT_ANTI_SEMI_TO_RIGHT_DEFAULT_VALUE: String = "false" + val GLUTEN_ENABLE_COALESCE_AGGREGATION_UNION: String = + CHConf.prefixOf("enable.coalesce.aggregation.union") + def affinityMode: String = { SparkEnv.get.conf .get( diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala index 40344e96e768..ecd7e5a24107 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala @@ -60,6 +60,7 @@ object CHRuleApi { (spark, parserInterface) => new GlutenCacheFilesSqlParser(spark, parserInterface)) injector.injectParser( (spark, parserInterface) => new GlutenClickhouseSqlParser(spark, parserInterface)) + injector.injectResolutionRule(spark => new CoalesceAggregationUnion(spark)) injector.injectResolutionRule(spark => new RewriteToDateExpresstionRule(spark)) injector.injectResolutionRule(spark => new RewriteDateTimestampComparisonRule(spark)) injector.injectResolutionRule(spark => new CollapseGetJsonObjectExpressionRule(spark)) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala index 4a592d03a34c..073ccf5d9dda 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala @@ -16,6 +16,7 @@ */ package org.apache.gluten.extension +import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings import org.apache.gluten.exception.GlutenNotSupportException import org.apache.spark.internal.Logging @@ -93,7 +94,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi case class AggregateAnalzyInfo(originalAggregate: Aggregate) { - protected def buildAttributesToExpressionsMap( + protected def createAttributeToExpressionMap( attributes: Seq[Attribute], expressions: 
Seq[Expression]): Map[ExprId, Expression] = { val map = new mutable.HashMap[ExprId, Expression]() @@ -104,7 +105,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi map.toMap } - protected def replaceAttributes( + protected def replaceAttributesInExpression( expression: Expression, replaceMap: Map[ExprId, Expression]): Expression = { expression.transform { @@ -113,20 +114,16 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } } - protected def getFilter(): Option[Filter] = { + protected def extractFilter(): Option[Filter] = { originalAggregate.child match { case filter: Filter => Some(filter) case project @ Project(_, filter: Filter) => Some(filter) case subquery: SubqueryAlias => - logError(s"xxx subquery child. ${subquery.child.getClass}") subquery.child match { case filter: Filter => Some(filter) case project @ Project(_, filter: Filter) => Some(filter) case relation: LogicalRelation => Some(Filter(Literal(true, BooleanType), subquery)) case nestedRelation: SubqueryAlias => - logError( - s"xxx nestedRelation child. ${nestedRelation.child.getClass}" + - s"\n$nestedRelation") if (nestedRelation.child.isInstanceOf[LogicalRelation]) { Some(Filter(Literal(true, BooleanType), nestedRelation)) } else { @@ -139,8 +136,8 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } // Try to make the plan simple, contain only three steps, source, filter, aggregate. - lazy val sourcePlan = { - val filter = getFilter() + lazy val extractedSourcePlan = { + val filter = extractFilter() if (!filter.isDefined) { None } else { @@ -151,9 +148,9 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } } - lazy val filterPlan = { - val filter = getFilter() - if (!filter.isDefined || !sourcePlan.isDefined) { + lazy val constructedFilterPlan = { + val filter = extractFilter() + if (!filter.isDefined || !extractedSourcePlan.isDefined) { None } else { val project = filter.get.child match { @@ -161,19 +158,19 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi case other => None } - val replacedFilter = project match { + val newFilter = project match { case Some(project) => - val replaceMap = buildAttributesToExpressionsMap(project.output, project.child.output) - val replacedCondition = replaceAttributes(filter.get.condition, replaceMap) - Filter(replacedCondition, sourcePlan.get) - case None => filter.get.withNewChildren(Seq(sourcePlan.get)) + val replaceMap = createAttributeToExpressionMap(project.output, project.child.output) + val newCondition = replaceAttributesInExpression(filter.get.condition, replaceMap) + Filter(newCondition, extractedSourcePlan.get) + case None => filter.get.withNewChildren(Seq(extractedSourcePlan.get)) } - Some(replacedFilter) + Some(newFilter) } } - lazy val aggregatePlan = { - if (!filterPlan.isDefined) { + lazy val constructedAggregatePlan = { + if (!constructedFilterPlan.isDefined) { None } else { val project = originalAggregate.child match { @@ -186,20 +183,20 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi case _ => None } - val replacedAggregate = project match { + val newAggregate = project match { case Some(innerProject) => val replaceMap = - buildAttributesToExpressionsMap(innerProject.output, innerProject.projectList) - val groupExpressions = originalAggregate.groupingExpressions.map { - e => replaceAttributes(e, replaceMap) + createAttributeToExpressionMap(innerProject.output, 
innerProject.projectList) + val newGroupExpressions = originalAggregate.groupingExpressions.map { + e => replaceAttributesInExpression(e, replaceMap) } - val aggregateExpressions = originalAggregate.aggregateExpressions.map { - e => replaceAttributes(e, replaceMap).asInstanceOf[NamedExpression] + val newAggregateExpressions = originalAggregate.aggregateExpressions.map { + e => replaceAttributesInExpression(e, replaceMap).asInstanceOf[NamedExpression] } - Aggregate(groupExpressions, aggregateExpressions, filterPlan.get) - case None => originalAggregate.withNewChildren(Seq(filterPlan.get)) + Aggregate(newGroupExpressions, newAggregateExpressions, constructedFilterPlan.get) + case None => originalAggregate.withNewChildren(Seq(constructedFilterPlan.get)) } - Some(replacedAggregate) + Some(newAggregate) } } @@ -208,7 +205,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } // The output results which are not aggregate expressions. - lazy val resultGroupingExpressions = aggregatePlan match { + lazy val resultGroupingExpressions = constructedAggregatePlan match { case Some(agg) => agg.asInstanceOf[Aggregate].aggregateExpressions.filter(e => !hasAggregateExpression(e)) case None => Seq.empty @@ -223,7 +220,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi // `select k1 + k2, count(1) from t group by k1, k2`. resultGroupingExpressions.map { e => - val aggregate = aggregatePlan.get.asInstanceOf[Aggregate] + val aggregate = constructedAggregatePlan.get.asInstanceOf[Aggregate] e match { case literal @ Alias(_: Literal, _) => var idx = aggregate.groupingExpressions.indexOf(e) @@ -254,128 +251,110 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi case class AnalyzedPlan(plan: LogicalPlan, analyzedInfo: Option[AggregateAnalzyInfo]) override def apply(plan: LogicalPlan): LogicalPlan = { - if (plan.resolved) { - logError(s"xxx visit plan:\n$plan") - val newPlan = visitPlan(plan) - logError(s"xxx output attributes:\n${newPlan.output}\n${plan.output}") - logError(s"xxx rewrite plan:\n$newPlan") - newPlan + if ( + plan.resolved && spark.conf + .get(CHBackendSettings.GLUTEN_ENABLE_COALESCE_AGGREGATION_UNION, "true") + .toBoolean + ) { + visitPlan(plan) } else { - logError(s"xxx plan not resolved:\n$plan") plan } } def visitPlan(plan: LogicalPlan): LogicalPlan = { - val newPlan = plan match { + plan match { case union: Union => val planGroups = groupStructureMatchedAggregate(union) - val newUnionClauses = planGroups.map { - groupedPlans => - if (groupedPlans.length == 1) { - groupedPlans.head.plan - } else { - val firstAggregateAnalzyInfo = groupedPlans.head.analyzedInfo.get - val aggregates = groupedPlans.map(_.analyzedInfo.get.aggregatePlan.get) - val replaceAttributes = collectReplaceAttributes(aggregates) - val filterConditions = buildAggregateCasesConditions(aggregates, replaceAttributes) - logError(s"xxx filterConditions. ${filterConditions.length},\n$filterConditions") - val firstAggregateFilter = - firstAggregateAnalzyInfo.filterPlan.get.asInstanceOf[Filter] - - // Add a filter step with condition `cond1 or cond2 or ...`, `cond_i` comes from each - // union clause. Apply this filter on the source plan. - val unionFilter = Filter( - buildUnionConditionForAggregateSource(filterConditions), - firstAggregateAnalzyInfo.sourcePlan.get) - - // Wrap all the attributes into a single structure attribute. 
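For intuition about the branch predicates collected just above: both the old code removed here and its replacement below reduce the per-clause filter conditions into one disjunction and apply it to the shared source, so the table is scanned once. A minimal, self-contained sketch of that step (hypothetical data and predicates; only the conditions.reduce pattern mirrors what buildUnionConditionForAggregateSource does):

import org.apache.spark.sql.SparkSession

object UnionFilterSketch {
  def main(args: Array[String]): Unit = {
    val spark =
      SparkSession.builder().master("local[1]").appName("union-filter-sketch").getOrCreate()
    import spark.implicits._

    // Stand-in for coalesce_union_t1: columns a, b, y.
    val t1 = Seq((1, 0, 10), (2, 1, 20), (3, 2, 30)).toDF("a", "b", "y")
    // One predicate per union clause (hypothetical examples).
    val conditions = Seq($"b" % 3 === 0, $"b" % 3 === 1)
    // Single scan keeping any row that matches at least one branch,
    // the DataFrame analogue of Filter(cond1 OR cond2, source).
    t1.filter(conditions.reduce(_ || _)).show()

    spark.stop()
  }
}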
-            val wrappedAttributesProject = buildStructWrapperProject(
-              unionFilter,
-              groupedPlans,
-              filterConditions,
-              replaceAttributes)
-
-            // Build an array which element are response to each union clause.
-            val arrayProject = buildArrayProject(wrappedAttributesProject, filterConditions)
-
-            // Explode the array
-            val explode = buildArrayExplode(arrayProject)
-
-            // Null value means that the union clause does not have the corresponding data.
-            val notNullFilter = Filter(IsNotNull(explode.output.head), explode)
-
-            // Destruct the struct attribute.
-            val destructStructProject = buildDestructStructProject(notNullFilter)
-
-            buildAggregateWithGroupId(destructStructProject, groupedPlans)
-          }
-      }
-      logError(s"xxx newUnionClauses. ${newUnionClauses.length},\n$newUnionClauses")
-      val coalesePlan = if (newUnionClauses.length == 1) {
-        newUnionClauses.head
-      } else {
-        var firstUnionChild = newUnionClauses.head
-        for (i <- 1 until newUnionClauses.length - 1) {
-          firstUnionChild = Union(firstUnionChild, newUnionClauses(i))
-        }
-        Union(firstUnionChild, newUnionClauses.last)
-      }
-      logError(s"xxx coalesePlan:$coalesePlan")
-
-      // We need to keep the output atrributes same as the original plan.
-      val outputAttrPairs = coalesePlan.output.zip(union.output)
-      if (outputAttrPairs.forall(pair => pair._1.semanticEquals(pair._2))) {
-        coalesePlan
+        if (planGroups.forall(group => group.length == 1)) {
+          plan.withNewChildren(plan.children.map(visitPlan))
        } else {
-        val reprejectOutputs = outputAttrPairs.map {
-          case (newAttr, oldAttr) =>
-            if (newAttr.exprId == oldAttr.exprId) {
-              newAttr
+          val newUnionClauses = planGroups.map {
+            groupedPlans =>
+              if (groupedPlans.length == 1) {
+                groupedPlans.head.plan
              } else {
-              Alias(newAttr, oldAttr.name)(oldAttr.exprId, oldAttr.qualifier, None, Seq.empty)
+                val firstAggregateAnalzyInfo = groupedPlans.head.analyzedInfo.get
+                val aggregates = groupedPlans.map(_.analyzedInfo.get.constructedAggregatePlan.get)
+                val replaceAttributes = collectreplaceAttributesInExpression(aggregates)
+                val filterConditions = buildAggregateCasesConditions(aggregates, replaceAttributes)
+                val firstAggregateFilter =
+                  firstAggregateAnalzyInfo.constructedFilterPlan.get.asInstanceOf[Filter]
+
+                // Add a filter step with condition `cond1 or cond2 or ...`, `cond_i` comes from
+                // each union clause. Apply this filter on the source plan.
+                val unionFilter = Filter(
+                  buildUnionConditionForAggregateSource(filterConditions),
+                  firstAggregateAnalzyInfo.extractedSourcePlan.get)
+
+                // Wrap all the attributes into a single structure attribute.
+                val wrappedAttributesProject = buildProjectFoldIntoStruct(
+                  unionFilter,
+                  groupedPlans,
+                  filterConditions,
+                  replaceAttributes)
+
+                // Build an array whose elements correspond to each union clause.
+                val arrayProject =
+                  buildProjectBranchArray(wrappedAttributesProject, filterConditions)
+
+                // Explode the array
+                val explode = buildExplodeBranchArray(arrayProject)
+
+                // Null value means that the union clause does not have the corresponding data.
+                val notNullFilter = Filter(IsNotNull(explode.output.head), explode)
+
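To make the chain above concrete, here is roughly what the rewritten plan computes for a two-branch union over a hypothetical table t. The field_N names and the trailing case index follow the rule's structPrefix and casePrefix conventions; the SQL is only an illustration, since the rule emits Catalyst operators directly:

val rewrittenShape =
  """
    |select s.field_0 as a, count(s.field_1) as cnt
    |from (
    |  select explode(array(
    |    if(b % 3 = 0, named_struct('field_0', a, 'field_1', x, 'field_2', 0), null),
    |    if(b % 3 = 1, named_struct('field_0', a, 'field_1', x, 'field_2', 1), null))) as s
    |  from t
    |  where b % 3 = 0 or b % 3 = 1
    |) where s is not null
    |group by s.field_0, s.field_2
    |""".stripMargin

Each row contributes one non-null struct per branch whose predicate it satisfies; grouping by the trailing case index (field_2 here) keeps the per-branch aggregations separate while the source is read only once.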
+                // Destruct the struct attribute.
+                val destructStructProject = buildProjectUnfoldStruct(notNullFilter)
+
+                buildAggregateWithGroupId(destructStructProject, groupedPlans)
+              }
+          }
+          val coalescePlan = if (newUnionClauses.length == 1) {
+            newUnionClauses.head
+          } else {
+            var firstUnionChild = newUnionClauses.head
+            for (i <- 1 until newUnionClauses.length - 1) {
+              firstUnionChild = Union(firstUnionChild, newUnionClauses(i))
+            }
+            Union(firstUnionChild, newUnionClauses.last)
+          }
+
+          // We need to keep the output attributes same as the original plan.
+          val outputAttrPairs = coalescePlan.output.zip(union.output)
+          if (outputAttrPairs.forall(pair => pair._1.semanticEquals(pair._2))) {
+            coalescePlan
+          } else {
+            val reprojectedOutputs = outputAttrPairs.map {
+              case (newAttr, oldAttr) =>
+                if (newAttr.exprId == oldAttr.exprId) {
+                  newAttr
+                } else {
+                  Alias(newAttr, oldAttr.name)(oldAttr.exprId, oldAttr.qualifier, None, Seq.empty)
+                }
+            }
+            Project(reprojectedOutputs, coalescePlan)
+          }
+        }
      case _ => plan.withNewChildren(plan.children.map(visitPlan))
    }
-    // newPlan.copyTagsFrom(plan)
-    newPlan
  }

  def isSupportedAggregate(info: AggregateAnalzyInfo): Boolean = {
-    if (info.hasAggregateWithFilter) {
-      return false
-    }
-    if (!info.aggregatePlan.isDefined) {
-      return false
-    }
-
-    if (info.positionInGroupingKeys.exists(_ < 0)) {
-      return false
-    }
-
-    // `agg_fun1(x) + agg_fun2(y)` is supported, but `agg_fun1(x) + y` is not supported.
-    if (
-      info.originalAggregate.aggregateExpressions.exists {
-        e =>
-          val innerExpr = removeAlias(e)
-          if (hasAggregateExpression(innerExpr)) {
-            !innerExpr.isInstanceOf[AggregateExpression] &&
-            !innerExpr.children.forall(e => isAggregateExpression(e))
-          } else {
-            false
-          }
-      }
-    ) {
-      return false
-    }
-
-    if (!info.aggregatePlan.isDefined) {
-      return false
+    !info.hasAggregateWithFilter &&
+    info.constructedAggregatePlan.isDefined &&
+    info.positionInGroupingKeys.forall(_ >= 0) &&
+    info.originalAggregate.aggregateExpressions.forall {
+      e =>
+        val innerExpr = removeAlias(e)
+        // `agg_fun1(x) + agg_fun2(y)` is supported, but `agg_fun1(x) + y` is not supported.
+        if (hasAggregateExpression(innerExpr)) {
+          innerExpr.isInstanceOf[AggregateExpression] ||
+          innerExpr.children.forall(e => isAggregateExpression(e))
+        } else {
+          true
+        }
    }
-    true
  }

  /**
@@ -392,61 +371,28 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
   * True if the two instances have the same structure, false otherwise.
   */
  def areStructureMatchedAggregate(l: AggregateAnalzyInfo, r: AggregateAnalzyInfo): Boolean = {
-    val lAggregate = l.aggregatePlan.get.asInstanceOf[Aggregate]
-    val rAggregate = r.aggregatePlan.get.asInstanceOf[Aggregate]
-
-    // Check aggregate result expressions. need same schema.
- if (lAggregate.aggregateExpressions.length != rAggregate.aggregateExpressions.length) { - logError(s"xxx not equal 1") - - return false - } - val allAggregateExpressionAreSame = - lAggregate.aggregateExpressions.zip(rAggregate.aggregateExpressions).forall { - case (lExpr, rExpr) => - if (!lExpr.dataType.equals(rExpr.dataType)) { - false - } else { - (hasAggregateExpression(lExpr), hasAggregateExpression(rExpr)) match { - case (true, true) => - areStructureMatchedExpressions(lExpr, rExpr) - case (false, true) => false - case (true, false) => false - case (false, false) => true - } + val lAggregate = l.constructedAggregatePlan.get.asInstanceOf[Aggregate] + val rAggregate = r.constructedAggregatePlan.get.asInstanceOf[Aggregate] + lAggregate.aggregateExpressions.length == rAggregate.aggregateExpressions.length && + lAggregate.aggregateExpressions.zip(rAggregate.aggregateExpressions).forall { + case (lExpr, rExpr) => + if (!lExpr.dataType.equals(rExpr.dataType)) { + false + } else { + (hasAggregateExpression(lExpr), hasAggregateExpression(rExpr)) match { + case (true, true) => areStructureMatchedExpressions(lExpr, rExpr) + case (false, true) => false + case (true, false) => false + case (false, false) => true } - } - if (!allAggregateExpressionAreSame) { - return false - } - - // Check grouping expressions, need same schema. - if (lAggregate.groupingExpressions.length != rAggregate.groupingExpressions.length) { - logError(s"xxx not equal 3") - return false - } - if ( - l.positionInGroupingKeys.length != - r.positionInGroupingKeys.length - ) { - logError(s"xxx not equal 4") - return false - } - val allSameGroupingKeysRef = l.positionInGroupingKeys - .zip(r.positionInGroupingKeys) - .forall { case (lPos, rPos) => lPos == rPos } - if (!allSameGroupingKeysRef) { - logError(s"xxx not equal 5") - return false - } - - // Must come from same source. - if (!areSameAggregateSource(l.sourcePlan.get, r.sourcePlan.get)) { - logError(s"xxx not same source. 
${l.sourcePlan.get}\n${r.sourcePlan.get}") - return false - } - - true + } + } && + lAggregate.groupingExpressions.length == rAggregate.groupingExpressions.length && + l.positionInGroupingKeys.length == r.positionInGroupingKeys.length && + l.positionInGroupingKeys.zip(r.positionInGroupingKeys).forall { + case (lPos, rPos) => lPos == rPos + } && + areSameAggregateSource(l.extractedSourcePlan.get, r.extractedSourcePlan.get) } /* @@ -493,27 +439,46 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } def groupStructureMatchedAggregate(union: Union): ArrayBuffer[ArrayBuffer[AnalyzedPlan]] = { - val groupResults = ArrayBuffer[ArrayBuffer[AnalyzedPlan]]() - collectAllUnionClauses(union).foreach { - case agg: Aggregate => - val analyzedInfo = AggregateAnalzyInfo(agg) - if (isSupportedAggregate(analyzedInfo)) { - if (groupResults.isEmpty) { - groupResults += ArrayBuffer(AnalyzedPlan(agg, Some(analyzedInfo))) + + def tryPutToGroup( + groupResults: ArrayBuffer[ArrayBuffer[AnalyzedPlan]], + agg: Aggregate): Unit = { + val analyzedInfo = AggregateAnalzyInfo(agg) + if (isSupportedAggregate(analyzedInfo)) { + if (groupResults.isEmpty) { + groupResults += ArrayBuffer( + AnalyzedPlan(analyzedInfo.originalAggregate, Some(analyzedInfo))) + } else { + val idx = findStructureMatchedAggregate(groupResults, analyzedInfo) + if (idx != -1) { + groupResults(idx) += AnalyzedPlan( + analyzedInfo.constructedAggregatePlan.get, + Some(analyzedInfo)) } else { - val idx = findStructureMatchedAggregate(groupResults, analyzedInfo) - if (idx != -1) { - groupResults(idx) += AnalyzedPlan(agg, Some(analyzedInfo)) - } else { - groupResults += ArrayBuffer(AnalyzedPlan(agg, Some(analyzedInfo))) - } + groupResults += ArrayBuffer( + AnalyzedPlan(analyzedInfo.constructedAggregatePlan.get, Some(analyzedInfo))) } + } + } else { + val rewrittenPlan = visitPlan(agg) + groupResults += ArrayBuffer(AnalyzedPlan(rewrittenPlan, None)) + } + } + + val groupResults = ArrayBuffer[ArrayBuffer[AnalyzedPlan]]() + collectAllUnionClauses(union).foreach { + case project @ Project(projectList, agg: Aggregate) => + if (projectList.forall(e => e.isInstanceOf[Alias])) { + tryPutToGroup(groupResults, agg) } else { - logError(s"xxx not supported. $agg") - groupResults += ArrayBuffer(AnalyzedPlan(agg, None)) + val rewrittenPlan = visitPlan(project) + groupResults += ArrayBuffer(AnalyzedPlan(rewrittenPlan, None)) } + case agg: Aggregate => + tryPutToGroup(groupResults, agg) case other => - groupResults += ArrayBuffer(AnalyzedPlan(other, None)) + val rewrittenPlan = visitPlan(other) + groupResults += ArrayBuffer(AnalyzedPlan(rewrittenPlan, None)) } groupResults } @@ -521,17 +486,15 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi def areStructureMatchedExpressions(l: Expression, r: Expression): Boolean = { (l, r) match { case (lAttr: Attribute, rAttr: Attribute) => - logError(s"xxx attr equal: ${lAttr.qualifiedName}, ${rAttr.qualifiedName}") - lAttr.qualifiedName == rAttr.qualifiedName + // The the qualifier may be overwritten by a subquery alias, and make this check fail. 
+ lAttr.qualifiedName.equals(rAttr.qualifiedName) case (lLiteral: Literal, rLiteral: Literal) => lLiteral.value.equals(rLiteral.value) case _ => - if (l.children.length != r.children.length || l.getClass != r.getClass) { - false - } else { - l.children.zip(r.children).forall { - case (lChild, rChild) => areStructureMatchedExpressions(lChild, rChild) - } + l.children.length == r.children.length && + l.getClass == r.getClass && + l.children.zip(r.children).forall { + case (lChild, rChild) => areStructureMatchedExpressions(lChild, rChild) } } } @@ -544,19 +507,22 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi case (lRel: LogicalRelation, rRel: LogicalRelation) => val lTable = lRel.catalogTable.map(_.identifier.unquotedString).getOrElse("") val rTable = rRel.catalogTable.map(_.identifier.unquotedString).getOrElse("") - logError(s"xxx table equal: $lTable, $rTable") lTable.equals(rTable) && lTable.nonEmpty + case (lRef: CTERelationRef, rRelf: CTERelationRef) => + lRef.cteId == rRelf.cteId case (lSubQuery: SubqueryAlias, rSubQuery: SubqueryAlias) => areSameAggregateSource(lSubQuery.child, rSubQuery.child) - case (lChild, rChild) => false + case (lChild, rChild) => + false } } } - def collectReplaceAttributes(groupedPlans: ArrayBuffer[LogicalPlan]): Map[String, Attribute] = { - def findFirstRelation(plan: LogicalPlan): LogicalRelation = { - if (plan.isInstanceOf[LogicalRelation]) { - return plan.asInstanceOf[LogicalRelation] + def collectreplaceAttributesInExpression( + groupedPlans: ArrayBuffer[LogicalPlan]): Map[String, Attribute] = { + def findFirstRelation(plan: LogicalPlan): LogicalPlan = { + if (plan.isInstanceOf[LogicalRelation] || plan.isInstanceOf[CTERelationRef]) { + return plan } else if (plan.children.isEmpty) { return null } else { @@ -586,7 +552,9 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi replaceMap.toMap } - def replaceAttributes(expression: Expression, replaceMap: Map[String, Attribute]): Expression = { + def replaceAttributesInExpression( + expression: Expression, + replaceMap: Map[String, Attribute]): Expression = { expression.transform { case attr: Attribute => replaceMap.get(attr.qualifiedName) match { @@ -602,7 +570,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi groupedPlans.map { plan => val filter = plan.asInstanceOf[Aggregate].child.asInstanceOf[Filter] - replaceAttributes(filter.condition, replaceMap) + replaceAttributesInExpression(filter.condition, replaceMap) } } @@ -619,13 +587,13 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi groupedPlans.zipWithIndex.foreach { case (aggregateCase, case_index) => val analyzedInfo = aggregateCase.analyzedInfo.get - val aggregate = analyzedInfo.aggregatePlan.get.asInstanceOf[Aggregate] + val aggregate = analyzedInfo.constructedAggregatePlan.get.asInstanceOf[Aggregate] val structFields = ArrayBuffer[Expression]() var fieldIndex: Int = 0 aggregate.groupingExpressions.foreach { e => structFields += Literal(UTF8String.fromString(s"$structPrefix$fieldIndex"), StringType) - structFields += replaceAttributes(e, replaceMap) + structFields += replaceAttributesInExpression(e, replaceMap) fieldIndex += 1 } for (i <- 0 until analyzedInfo.positionInGroupingKeys.length) { @@ -633,7 +601,9 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi if (position >= fieldIndex) { val expr = analyzedInfo.resultGroupingExpressions(i) structFields += 
Literal(UTF8String.fromString(s"$structPrefix$fieldIndex"), StringType) - structFields += replaceAttributes(analyzedInfo.resultGroupingExpressions(i), replaceMap) + structFields += replaceAttributesInExpression( + analyzedInfo.resultGroupingExpressions(i), + replaceMap) fieldIndex += 1 } } @@ -652,7 +622,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi structFields += Literal( UTF8String.fromString(s"$structPrefix$fieldIndex"), StringType) - structFields += replaceAttributes(child, replaceMap) + structFields += replaceAttributesInExpression(child, replaceMap) fieldIndex += 1 } case combineAgg if hasAggregateExpression(combineAgg) => @@ -663,7 +633,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi structFields += Literal( UTF8String.fromString(s"$structPrefix$fieldIndex"), StringType) - structFields += replaceAttributes(other, replaceMap) + structFields += replaceAttributesInExpression(other, replaceMap) fieldIndex += 1 } } @@ -671,12 +641,14 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } structFields += Literal(UTF8String.fromString(s"$structPrefix$fieldIndex"), StringType) structFields += Literal(case_index, IntegerType) - structAttributes += makeAlias(CreateNamedStruct(structFields), s"$casePrefix$case_index") + structAttributes += makeAlias( + CreateNamedStruct(structFields.toSeq), + s"$casePrefix$case_index") } - structAttributes + structAttributes.toSeq } - def buildStructWrapperProject( + def buildProjectFoldIntoStruct( child: LogicalPlan, groupedPlans: ArrayBuffer[AnalyzedPlan], conditions: ArrayBuffer[Expression], @@ -690,7 +662,9 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi Project(ifAttributes, child) } - def buildArrayProject(child: LogicalPlan, conditions: ArrayBuffer[Expression]): LogicalPlan = { + def buildProjectBranchArray( + child: LogicalPlan, + conditions: ArrayBuffer[Expression]): LogicalPlan = { assert( child.output.length == conditions.length, s"Expected same length of output and conditions") @@ -698,7 +672,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi Project(Seq(array), child) } - def buildArrayExplode(child: LogicalPlan): LogicalPlan = { + def buildExplodeBranchArray(child: LogicalPlan): LogicalPlan = { assert(child.output.length == 1, s"Expected single output from $child") val array = child.output.head.asInstanceOf[Expression] assert(array.dataType.isInstanceOf[ArrayType], s"Expected ArrayType from $array") @@ -714,18 +688,6 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi child) } - def buildGroupConditions( - groupedPlans: ArrayBuffer[LogicalPlan], - replaceMap: Map[String, Attribute]): (ArrayBuffer[Expression], Expression) = { - val conditions = groupedPlans.map { - plan => - val filter = plan.asInstanceOf[Aggregate].child.asInstanceOf[Filter] - replaceAttributes(filter.condition, replaceMap) - } - val unionCond = conditions.reduce(Or) - (conditions, unionCond) - } - def makeAlias(e: Expression, name: String): NamedExpression = { Alias(e, name)( NamedExpression.newExprId, @@ -737,7 +699,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi Seq.empty) } - def buildDestructStructProject(child: LogicalPlan): LogicalPlan = { + def buildProjectUnfoldStruct(child: LogicalPlan): LogicalPlan = { assert(child.output.length == 1, s"Expected single output from $child") val structedData = child.output.head assert( @@ 
-751,7 +713,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi attributes += Alias(GetStructField(structedData, index), field.name)() index += 1 } - Project(attributes, child) + Project(attributes.toSeq, child) } def buildAggregateWithGroupId( @@ -759,7 +721,8 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi groupedPlans: ArrayBuffer[AnalyzedPlan]): LogicalPlan = { val attributes = child.output val firstAggregateAnalzyInfo = groupedPlans.head.analyzedInfo.get - val aggregateTemplate = firstAggregateAnalzyInfo.aggregatePlan.get.asInstanceOf[Aggregate] + val aggregateTemplate = + firstAggregateAnalzyInfo.constructedAggregatePlan.get.asInstanceOf[Aggregate] val analyzedInfo = groupedPlans.head.analyzedInfo.get val totalGroupingExpressionsCount = @@ -779,11 +742,10 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi e => removeAlias(e) match { case aggExpr if hasAggregateExpression(aggExpr) => - aggregateExpressions += makeAlias( - constructAggregateExpression(aggExpr, attributes, aggregateExpressionIndex), - e.name) - .asInstanceOf[NamedExpression] - aggregateExpressionIndex += aggExpr.children.length + val (newAggExpr, count) = + constructAggregateExpression(aggExpr, attributes, aggregateExpressionIndex) + aggregateExpressions += makeAlias(newAggExpr, e.name).asInstanceOf[NamedExpression] + aggregateExpressionIndex += count case other => val position = normalExpressionPosition(normalExpressionCount) val attr = attributes(position) @@ -792,13 +754,13 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi .asInstanceOf[NamedExpression] } } - Aggregate(groupingExpressions, aggregateExpressions, child) + Aggregate(groupingExpressions.toSeq, aggregateExpressions.toSeq, child) } def constructAggregateExpression( aggExpr: Expression, attributes: Seq[Attribute], - index: Int): Expression = { + index: Int): (Expression, Int) = { aggExpr match { case singleAggExpr: AggregateExpression => val aggFunc = singleAggExpr.aggregateFunction @@ -808,18 +770,24 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } val newAggFunc = aggFunc.withNewChildren(newAggFuncArgs).asInstanceOf[AggregateFunction] - AggregateExpression( + val res = AggregateExpression( newAggFunc, singleAggExpr.mode, singleAggExpr.isDistinct, singleAggExpr.filter, singleAggExpr.resultId) + (res, 1) case combineAggExpr if hasAggregateExpression(combineAggExpr) => - combineAggExpr.withNewChildren( - combineAggExpr.children.map(constructAggregateExpression(_, attributes, index))) - case _ => - val normalExpr = attributes(index) - normalExpr + val childrenExpressions = ArrayBuffer[Expression]() + var totalCount = 0 + combineAggExpr.children.foreach { + child => + val (expr, count) = constructAggregateExpression(child, attributes, totalCount + index) + childrenExpressions += expr + totalCount += count + } + (combineAggExpr.withNewChildren(childrenExpressions.toSeq), totalCount) + case _ => (attributes(index), 1) } } } diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala index ae5b4b7d5a92..456d95d4b515 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala +++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala @@ -16,9 +16,6 @@ */ package org.apache.gluten.execution -import org.apache.gluten.config.GlutenConfig -import org.apache.gluten.utils.UTSystemParameters - import org.apache.spark.SparkConf import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig @@ -42,7 +39,6 @@ class GlutenCoalesceAggregationUnionSuite extends GlutenClickHouseWholeStageTran .set("spark.databricks.delta.properties.defaults.checkpointInterval", "5") .set("spark.databricks.delta.stalenessLimit", "3600000") .set(ClickHouseConfig.CLICKHOUSE_WORKER_ID, "1") - .set(GlutenConfig.GLUTEN_LIB_PATH, UTSystemParameters.clickHouseLibPath) .set("spark.gluten.sql.columnar.iterator", "true") .set("spark.gluten.sql.columnar.hashagg.enablefinal", "true") .set("spark.gluten.sql.enable.native.validation", "false") @@ -154,12 +150,12 @@ class GlutenCoalesceAggregationUnionSuite extends GlutenClickHouseWholeStageTran val sql = """ |select * from ( - | select a, 1 as t, count(x) + sum(y) as y from coalesce_union_t1 where b % 3 = 0 + | select a, 1 as t, count(x) + sum(y) as n from coalesce_union_t1 where b % 3 = 0 | group by a | union all - | select a, 2 as t, count(x) + sum(y) as y from coalesce_union_t1 where b % 3 = 1 + | select a, 2 as t, count(x) + sum(y) as n from coalesce_union_t1 where b % 3 = 1 | group by a - |) order by a, t, y + |) order by a, t, n |""".stripMargin compareResultsAgainstVanillaSpark(sql, true, checkNoUnion, true) } @@ -240,6 +236,40 @@ class GlutenCoalesceAggregationUnionSuite extends GlutenClickHouseWholeStageTran compareResultsAgainstVanillaSpark(sql, true, checkNoUnion, true) } + test("coalesce aggregation union. case 10") { + val sql = + """ + |select * from ( + | select a as a, sum(y) as y from ( + | select concat(a, "x") as a, y from coalesce_union_t1 where b % 3 = 0 + | ) group by a + | union all + | select x as a , sum(y) as y from coalesce_union_t1 where b % 3 = 1 + | group by x + |) order by a, y + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, checkNoUnion, true) + } + + test("coalesce aggregation union. case 11") { + val sql = + """ + |select t1.a, t1.y, t2.x from ( + | select a as a, sum(y) as y from ( + | select concat(a, "x") as a, y from coalesce_union_t1 where b % 3 = 0 + | ) group by a + | union all + | select x as a , sum(y) as y from coalesce_union_t1 where b % 3 = 1 + | group by x + |) as t1 + |left join ( + | select a, x from coalesce_union_t2 + |) as t2 + |on t1.a = t2.a + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, checkNoUnion, true) + } + test("no coalesce aggregation union. 
case 1") { val sql = """ From a18b54873685d6d08df8db76b459635f2adfa9cf Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Fri, 7 Feb 2025 14:08:35 +0800 Subject: [PATCH 8/9] wip --- .../extension/CoalesceAggregationUnion.scala | 248 +++++++++--------- 1 file changed, 127 insertions(+), 121 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala index 073ccf5d9dda..cada97955b59 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala @@ -21,6 +21,8 @@ import org.apache.gluten.exception.GlutenNotSupportException import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation +import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -31,6 +33,7 @@ import org.apache.spark.unsafe.types.UTF8String import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.util.{Failure, Success, Try} /* * Example: @@ -92,28 +95,30 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } } - case class AggregateAnalzyInfo(originalAggregate: Aggregate) { - - protected def createAttributeToExpressionMap( - attributes: Seq[Attribute], - expressions: Seq[Expression]): Map[ExprId, Expression] = { - val map = new mutable.HashMap[ExprId, Expression]() - attributes.zip(expressions).foreach { - case (attr, expr) => - map.put(attr.exprId, expr) - } - map.toMap - } + def buildAttributesMap( + attributes: Seq[Attribute], + expressions: Seq[Expression]): Map[ExprId, Expression] = { + assert(attributes.length == expressions.length) + val map = new mutable.HashMap[ExprId, Expression]() + attributes.zip(expressions).foreach { + case (attr, expr) => + map.put(attr.exprId, expr) + } + map.toMap + } - protected def replaceAttributesInExpression( - expression: Expression, - replaceMap: Map[ExprId, Expression]): Expression = { - expression.transform { - case attr: Attribute => - replaceMap.getOrElse(attr.exprId, attr.asInstanceOf[Expression]) - } + def replaceAttributes(e: Expression, replaceMap: Map[ExprId, Expression]): Expression = { + e.transform { + case attr: Attribute => + replaceMap.get(attr.exprId) match { + case Some(replaceAttr) => replaceAttr + case None => + throw new GlutenNotSupportException(s"Not found attribute: ${attr.qualifiedName}") + } } + } + case class AggregateAnalzyInfo(originalAggregate: Aggregate) { protected def extractFilter(): Option[Filter] = { originalAggregate.child match { case filter: Filter => Some(filter) @@ -122,13 +127,10 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi subquery.child match { case filter: Filter => Some(filter) case project @ Project(_, filter: Filter) => Some(filter) - case relation: LogicalRelation => Some(Filter(Literal(true, BooleanType), subquery)) - case nestedRelation: SubqueryAlias => - if (nestedRelation.child.isInstanceOf[LogicalRelation]) { - Some(Filter(Literal(true, BooleanType), nestedRelation)) - } else { - None - } + case relation if isRelation(relation) => + Some(Filter(Literal(true, BooleanType), subquery)) + 
case nestedRelation: SubqueryAlias if (isRelation(nestedRelation.child)) =>
+              Some(Filter(Literal(true, BooleanType), nestedRelation))
            case _ => None
          }
        case _ => None
      }
    }
@@ -160,8 +162,10 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
        }
        val newFilter = project match {
          case Some(project) =>
-            val replaceMap = createAttributeToExpressionMap(project.output, project.child.output)
-            val newCondition = replaceAttributesInExpression(filter.get.condition, replaceMap)
+            val replaceMap = buildAttributesMap(
+              project.output,
+              project.child.output.map(_.asInstanceOf[Expression]))
+            val newCondition = replaceAttributes(filter.get.condition, replaceMap)
            Filter(newCondition, extractedSourcePlan.get)
          case None => filter.get.withNewChildren(Seq(extractedSourcePlan.get))
        }
@@ -185,13 +189,14 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi

        val newAggregate = project match {
          case Some(innerProject) =>
-            val replaceMap =
-              createAttributeToExpressionMap(innerProject.output, innerProject.projectList)
+            val replaceMap = buildAttributesMap(
+              innerProject.output,
+              innerProject.projectList.map(_.asInstanceOf[Expression]))
            val newGroupExpressions = originalAggregate.groupingExpressions.map {
-              e => replaceAttributesInExpression(e, replaceMap)
+              e => replaceAttributes(e, replaceMap)
            }
            val newAggregateExpressions = originalAggregate.aggregateExpressions.map {
-              e => replaceAttributesInExpression(e, replaceMap).asInstanceOf[NamedExpression]
+              e => replaceAttributes(e, replaceMap).asInstanceOf[NamedExpression]
            }
            Aggregate(newGroupExpressions, newAggregateExpressions, constructedFilterPlan.get)
          case None => originalAggregate.withNewChildren(Seq(constructedFilterPlan.get))
@@ -256,7 +261,14 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
        .get(CHBackendSettings.GLUTEN_ENABLE_COALESCE_AGGREGATION_UNION, "true")
        .toBoolean
    ) {
-      visitPlan(plan)
+      Try {
+        visitPlan(plan)
+      } match {
+        case Success(res) => res
+        case Failure(e) =>
+          logError(s"Failed to rewrite plan. ", e)
+          plan
+      }
    } else {
      plan
    }
@@ -276,8 +288,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
          } else {
            val firstAggregateAnalzyInfo = groupedPlans.head.analyzedInfo.get
            val aggregates = groupedPlans.map(_.analyzedInfo.get.constructedAggregatePlan.get)
-            val replaceAttributes = collectreplaceAttributesInExpression(aggregates)
-            val filterConditions = buildAggregateCasesConditions(aggregates, replaceAttributes)
+            val filterConditions = buildAggregateCasesConditions(groupedPlans)
            val firstAggregateFilter =
              firstAggregateAnalzyInfo.constructedFilterPlan.get.asInstanceOf[Filter]
@@ -288,11 +299,8 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
              firstAggregateAnalzyInfo.extractedSourcePlan.get)

            // Wrap all the attributes into a single structure attribute.
-            val wrappedAttributesProject = buildProjectFoldIntoStruct(
-              unionFilter,
-              groupedPlans,
-              filterConditions,
-              replaceAttributes)
+            val wrappedAttributesProject =
+              buildProjectFoldIntoStruct(unionFilter, groupedPlans, filterConditions)

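A note on the Try guard this patch adds to apply above: any exception thrown during the rewrite, such as the GlutenNotSupportException raised by replaceAttributes on an unmapped attribute, now degrades to the original plan instead of failing analysis. Reduced to a sketch with hypothetical names:

import scala.util.{Failure, Success, Try}

// applySafely is a hypothetical helper, not part of this patch.
def applySafely[P](plan: P)(rewrite: P => P): P =
  Try(rewrite(plan)) match {
    case Success(rewritten) => rewritten
    // Fall back to the input so a rewrite bug can never break the query.
    case Failure(_) => plan
  }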
            // Build an array whose elements correspond to each union clause.
            val arrayProject =
@@ -340,6 +348,32 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
    }
  }

+  def isRelation(plan: LogicalPlan): Boolean = {
+    plan.isInstanceOf[MultiInstanceRelation]
+  }
+
+  def areSameRelation(l: LogicalPlan, r: LogicalPlan): Boolean = {
+    (l, r) match {
+      case (lRelation: LogicalRelation, rRelation: LogicalRelation) =>
+        val lTable = lRelation.catalogTable.map(_.identifier.unquotedString).getOrElse("")
+        val rTable = rRelation.catalogTable.map(_.identifier.unquotedString).getOrElse("")
+        lRelation.output.length == rRelation.output.length &&
+        lRelation.output.zip(rRelation.output).forall {
+          case (lAttr, rAttr) =>
+            lAttr.dataType.equals(rAttr.dataType) && lAttr.name.equals(rAttr.name)
+        } &&
+        lTable.equals(rTable) && lTable.nonEmpty
+      case (lCTE: CTERelationRef, rCTE: CTERelationRef) =>
+        lCTE.cteId == rCTE.cteId
+      case (lHiveTable: HiveTableRelation, rHiveTable: HiveTableRelation) =>
+        lHiveTable.tableMeta.identifier.unquotedString
+          .equals(rHiveTable.tableMeta.identifier.unquotedString)
+      case (_, _) =>
+        logInfo(s"xxx unknown relation: ${l.getClass}, ${r.getClass}")
+        false
+    }
+  }
+
  def isSupportedAggregate(info: AggregateAnalzyInfo): Boolean = {
    !info.hasAggregateWithFilter &&
    info.constructedAggregatePlan.isDefined &&
@@ -421,7 +455,9 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
            planWithIndex._1.head.analyzedInfo.get,
            analyzedInfo)) match {
      case Some((_, i)) => i
-      case None => -1
+      case None =>
+        logError(s"xxx not found match plan: ${analyzedInfo.originalAggregate}")
+        -1
    }
  }
@@ -460,6 +496,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
          }
        }
      } else {
+        logError(s"xxx not supported plan: $agg")
        val rewrittenPlan = visitPlan(agg)
        groupResults += ArrayBuffer(AnalyzedPlan(rewrittenPlan, None))
      }
@@ -484,18 +521,22 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
  }

  def areStructureMatchedExpressions(l: Expression, r: Expression): Boolean = {
-    (l, r) match {
-      case (lAttr: Attribute, rAttr: Attribute) =>
-        // The the qualifier may be overwritten by a subquery alias, and make this check fail.
-        lAttr.qualifiedName.equals(rAttr.qualifiedName)
-      case (lLiteral: Literal, rLiteral: Literal) =>
-        lLiteral.value.equals(rLiteral.value)
-      case _ =>
-        l.children.length == r.children.length &&
-        l.getClass == r.getClass &&
-        l.children.zip(r.children).forall {
-          case (lChild, rChild) => areStructureMatchedExpressions(lChild, rChild)
-        }
+    if (l.dataType.equals(r.dataType)) {
+      (l, r) match {
+        case (lAttr: Attribute, rAttr: Attribute) =>
+          // The qualifier may be overwritten by a subquery alias, which can make this check fail.
+ lAttr.qualifiedName.equals(rAttr.qualifiedName) + case (lLiteral: Literal, rLiteral: Literal) => + lLiteral.value == rLiteral.value + case _ => + l.children.length == r.children.length && + l.getClass == r.getClass && + l.children.zip(r.children).forall { + case (lChild, rChild) => areStructureMatchedExpressions(lChild, rChild) + } + } + } else { + false } } @@ -504,73 +545,36 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi false } else { lPlan.children.zip(rPlan.children).forall { - case (lRel: LogicalRelation, rRel: LogicalRelation) => - val lTable = lRel.catalogTable.map(_.identifier.unquotedString).getOrElse("") - val rTable = rRel.catalogTable.map(_.identifier.unquotedString).getOrElse("") - lTable.equals(rTable) && lTable.nonEmpty - case (lRef: CTERelationRef, rRelf: CTERelationRef) => - lRef.cteId == rRelf.cteId + case (lRelation, rRelation) if (isRelation(lRelation) && isRelation(rRelation)) => + areSameRelation(lRelation, rRelation) case (lSubQuery: SubqueryAlias, rSubQuery: SubqueryAlias) => areSameAggregateSource(lSubQuery.child, rSubQuery.child) - case (lChild, rChild) => - false + case (lproject: Project, rproject: Project) => + lproject.projectList.length == rproject.projectList.length && + lproject.projectList.zip(rproject.projectList).forall { + case (lExpr, rExpr) => areStructureMatchedExpressions(lExpr, rExpr) + } && + areSameAggregateSource(lproject.child, rproject.child) + case (lFilter: Filter, rFilter: Filter) => + areStructureMatchedExpressions(lFilter.condition, rFilter.condition) && + areSameAggregateSource(lFilter.child, rFilter.child) + case (lChild, rChild) => false } } } - def collectreplaceAttributesInExpression( - groupedPlans: ArrayBuffer[LogicalPlan]): Map[String, Attribute] = { - def findFirstRelation(plan: LogicalPlan): LogicalPlan = { - if (plan.isInstanceOf[LogicalRelation] || plan.isInstanceOf[CTERelationRef]) { - return plan - } else if (plan.children.isEmpty) { - return null - } else { - plan.children.foreach { - child => - val rel = findFirstRelation(child) - if (rel != null) { - return rel - } - } - return null - } - } - val replaceMap = new mutable.HashMap[String, Attribute]() - val firstFilter = groupedPlans.head.asInstanceOf[Aggregate].child.asInstanceOf[Filter] - val qualifierPrefix = - firstFilter.output.find(e => e.qualifier.nonEmpty).head.qualifier.mkString(".") - val firstRelation = findFirstRelation(firstFilter.child) - if (firstRelation == null) { - throw new GlutenNotSupportException(s"Not found relation in plan: $firstFilter") - } - firstRelation.output.foreach { - attr => - val qualifiedName = s"$qualifierPrefix.${attr.name}" - replaceMap.put(qualifiedName, attr) - } - replaceMap.toMap - } - - def replaceAttributesInExpression( - expression: Expression, - replaceMap: Map[String, Attribute]): Expression = { - expression.transform { - case attr: Attribute => - replaceMap.get(attr.qualifiedName) match { - case Some(replaceAttr) => replaceAttr - case None => attr - } - } - } - def buildAggregateCasesConditions( - groupedPlans: ArrayBuffer[LogicalPlan], - replaceMap: Map[String, Attribute]): ArrayBuffer[Expression] = { + groupedPlans: ArrayBuffer[AnalyzedPlan]): ArrayBuffer[Expression] = { + val firstPlanSourceOutputAttrs = + groupedPlans.head.analyzedInfo.get.extractedSourcePlan.get.output groupedPlans.map { plan => - val filter = plan.asInstanceOf[Aggregate].child.asInstanceOf[Filter] - replaceAttributesInExpression(filter.condition, replaceMap) + val attrsMap = + buildAttributesMap( + 
plan.analyzedInfo.get.extractedSourcePlan.get.output, + firstPlanSourceOutputAttrs) + val filter = plan.analyzedInfo.get.constructedFilterPlan.get.asInstanceOf[Filter] + replaceAttributes(filter.condition, attrsMap) } } @@ -579,21 +583,24 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } def wrapAggregatesAttributesInStructs( - groupedPlans: ArrayBuffer[AnalyzedPlan], - replaceMap: Map[String, Attribute]): Seq[NamedExpression] = { + groupedPlans: ArrayBuffer[AnalyzedPlan]): Seq[NamedExpression] = { val structAttributes = ArrayBuffer[NamedExpression]() val casePrefix = "case_" val structPrefix = "field_" + val firstSourceAttrs = groupedPlans.head.analyzedInfo.get.extractedSourcePlan.get.output groupedPlans.zipWithIndex.foreach { case (aggregateCase, case_index) => val analyzedInfo = aggregateCase.analyzedInfo.get val aggregate = analyzedInfo.constructedAggregatePlan.get.asInstanceOf[Aggregate] val structFields = ArrayBuffer[Expression]() var fieldIndex: Int = 0 + val attrReplaceMap = buildAttributesMap( + aggregateCase.analyzedInfo.get.extractedSourcePlan.get.output, + firstSourceAttrs) aggregate.groupingExpressions.foreach { e => structFields += Literal(UTF8String.fromString(s"$structPrefix$fieldIndex"), StringType) - structFields += replaceAttributesInExpression(e, replaceMap) + structFields += replaceAttributes(e, attrReplaceMap) fieldIndex += 1 } for (i <- 0 until analyzedInfo.positionInGroupingKeys.length) { @@ -601,9 +608,9 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi if (position >= fieldIndex) { val expr = analyzedInfo.resultGroupingExpressions(i) structFields += Literal(UTF8String.fromString(s"$structPrefix$fieldIndex"), StringType) - structFields += replaceAttributesInExpression( + structFields += replaceAttributes( analyzedInfo.resultGroupingExpressions(i), - replaceMap) + attrReplaceMap) fieldIndex += 1 } } @@ -622,7 +629,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi structFields += Literal( UTF8String.fromString(s"$structPrefix$fieldIndex"), StringType) - structFields += replaceAttributesInExpression(child, replaceMap) + structFields += replaceAttributes(child, attrReplaceMap) fieldIndex += 1 } case combineAgg if hasAggregateExpression(combineAgg) => @@ -633,7 +640,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi structFields += Literal( UTF8String.fromString(s"$structPrefix$fieldIndex"), StringType) - structFields += replaceAttributesInExpression(other, replaceMap) + structFields += replaceAttributes(other, attrReplaceMap) fieldIndex += 1 } } @@ -651,9 +658,8 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi def buildProjectFoldIntoStruct( child: LogicalPlan, groupedPlans: ArrayBuffer[AnalyzedPlan], - conditions: ArrayBuffer[Expression], - replaceMap: Map[String, Attribute]): LogicalPlan = { - val wrappedAttributes = wrapAggregatesAttributesInStructs(groupedPlans, replaceMap) + conditions: ArrayBuffer[Expression]): LogicalPlan = { + val wrappedAttributes = wrapAggregatesAttributesInStructs(groupedPlans) val ifAttributes = wrappedAttributes.zip(conditions).map { case (attr, condition) => makeAlias(If(condition, attr, Literal(null, attr.dataType)), attr.name) From f358239c9aaca44fa0936541eacce956890d5864 Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Mon, 10 Feb 2025 11:51:51 +0800 Subject: [PATCH 9/9] wip --- .../extension/CoalesceAggregationUnion.scala | 47 +++++++++++++------ 
 .../GlutenCoalesceAggregationUnionSuite.scala |  24 ++++++++++
 2 files changed, 57 insertions(+), 14 deletions(-)

diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala
index cada97955b59..4a83830e5137 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala
@@ -108,13 +108,15 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
  }

  def replaceAttributes(e: Expression, replaceMap: Map[ExprId, Expression]): Expression = {
-    e.transform {
+    e match {
      case attr: Attribute =>
        replaceMap.get(attr.exprId) match {
          case Some(replaceAttr) => replaceAttr
          case None =>
-            throw new GlutenNotSupportException(s"Not found attribute: ${attr.qualifiedName}")
+            throw new GlutenNotSupportException(s"Not found attribute: $attr ${attr.qualifiedName}")
        }
+      case _ =>
+        e.withNewChildren(e.children.map(replaceAttributes(_, replaceMap)))
    }
  }

@@ -137,6 +139,15 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
      }
    }

+    def isValidSource(plan: LogicalPlan): Boolean = {
+      plan match {
+        case relation if isRelation(relation) => true
+        case _: Project | _: Filter | _: SubqueryAlias =>
+          plan.children.forall(isValidSource)
+        case _ => false
+      }
+    }
+
    // Try to make the plan simple, contain only three steps, source, filter, aggregate.
    lazy val extractedSourcePlan = {
      val filter = extractFilter()
@@ -144,8 +155,9 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
        None
      } else {
        filter.get.child match {
-          case project: Project => Some(project.child)
-          case other => Some(other)
+          case project: Project if isValidSource(project.child) => Some(project.child)
+          case other if isValidSource(other) => Some(other)
+          case _ => None
        }
      }
    }
@@ -255,19 +267,24 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
   */
  case class AnalyzedPlan(plan: LogicalPlan, analyzedInfo: Option[AggregateAnalzyInfo])

+  def isResolvedPlan(plan: LogicalPlan): Boolean = {
+    plan match {
+      case insert: InsertIntoStatement => insert.query.resolved
+      case _ => plan.resolved
+    }
+  }
+
  override def apply(plan: LogicalPlan): LogicalPlan = {
    if (
-      plan.resolved && spark.conf
+      spark.conf
        .get(CHBackendSettings.GLUTEN_ENABLE_COALESCE_AGGREGATION_UNION, "true")
-        .toBoolean
+        .toBoolean && isResolvedPlan(plan)
    ) {
      Try {
        visitPlan(plan)
      } match {
        case Success(res) => res
-        case Failure(e) =>
-          logError(s"Failed to rewrite plan. 
", e) - plan + case Failure(e) => plan } } else { plan @@ -375,6 +392,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } def isSupportedAggregate(info: AggregateAnalzyInfo): Boolean = { + !info.hasAggregateWithFilter && info.constructedAggregatePlan.isDefined && info.positionInGroupingKeys.forall(_ >= 0) && @@ -388,7 +406,8 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } else { true } - } + } && + info.extractedSourcePlan.isDefined } /** @@ -455,9 +474,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi planWithIndex._1.head.analyzedInfo.get, analyzedInfo)) match { case Some((_, i)) => i - case None => - logError(s"xxx not found match plan: ${analyzedInfo.originalAggregate}") - -1 + case None => -1 } } @@ -496,7 +513,6 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi } } } else { - logError(s"xxx not supported plan: $agg") val rewrittenPlan = visitPlan(agg) groupResults += ArrayBuffer(AnalyzedPlan(rewrittenPlan, None)) } @@ -528,6 +544,9 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi lAttr.qualifiedName.equals(rAttr.qualifiedName) case (lLiteral: Literal, rLiteral: Literal) => lLiteral.value == rLiteral.value + case (lagg: AggregateExpression, ragg: AggregateExpression) => + lagg.isDistinct == ragg.isDistinct && + areStructureMatchedExpressions(lagg.aggregateFunction, ragg.aggregateFunction) case _ => l.children.length == r.children.length && l.getClass == r.getClass && diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala index 456d95d4b515..23c6022727eb 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala @@ -270,6 +270,18 @@ class GlutenCoalesceAggregationUnionSuite extends GlutenClickHouseWholeStageTran compareResultsAgainstVanillaSpark(sql, true, checkNoUnion, true) } + test("coalesce aggregation union. case 12") { + val sql = + """ + |select a, x, y from ( + | select a, count(distinct x) as x, sum(y) as y from coalesce_union_t1 group by a + | union all + | select a, count(distinct x) as x, sum(y) as y from coalesce_union_t1 group by a + |) order by a, x, y + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, checkNoUnion, true) + } + test("no coalesce aggregation union. case 1") { val sql = """ @@ -368,4 +380,16 @@ class GlutenCoalesceAggregationUnionSuite extends GlutenClickHouseWholeStageTran compareResultsAgainstVanillaSpark(sql, true, checkHasUnion, true) } + test("no coalesce aggregation union. case 8") { + val sql = + """ + |select a, x, y from ( + | select a, count(distinct x) as x, sum(y) as y from coalesce_union_t1 group by a + | union all + | select a, count(x) as x, sum(y) as y from coalesce_union_t1 group by a + |) order by a, x, y + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, checkHasUnion, true) + } + }