Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Feb 10, 2025
1 parent 849fbcb commit 85083f0
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,15 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
}

def replaceAttributes(e: Expression, replaceMap: Map[ExprId, Expression]): Expression = {
e.transform {
e match {
case attr: Attribute =>
replaceMap.get(attr.exprId) match {
case Some(replaceAttr) => replaceAttr
case None =>
throw new GlutenNotSupportException(s"Not found attribute: ${attr.qualifiedName}")
throw new GlutenNotSupportException(s"Not found attribute: $attr ${attr.qualifiedName}")
}
case _ =>
e.withNewChildren(e.children.map(replaceAttributes(_, replaceMap)))
}
}

Expand All @@ -137,15 +139,25 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
}
}

/**
 * Checks whether `plan` is an acceptable source subtree.
 *
 * A plan qualifies when `isRelation` accepts it, or when it is a Project, Filter or
 * SubqueryAlias whose children all qualify in turn. Any other operator disqualifies it.
 *
 * @param plan the logical plan subtree to inspect
 * @return true if the subtree is built only from the shapes described above
 */
def isValidSource(plan: LogicalPlan): Boolean = {
  if (isRelation(plan)) {
    // The relation check takes precedence over the operator-type check below.
    true
  } else {
    plan match {
      case _: Project | _: Filter | _: SubqueryAlias =>
        plan.children.forall(child => isValidSource(child))
      case _ => false
    }
  }
}

// Try to make the plan simple, contain only three steps, source, filter, aggregate.
lazy val extractedSourcePlan = {
val filter = extractFilter()
if (!filter.isDefined) {
None
} else {
filter.get.child match {
case project: Project => Some(project.child)
case other => Some(other)
case project: Project if isValidSource(project.child) => Some(project.child)
case other if isValidSource(other) => Some(other)
case _ => None
}
}
}
Expand Down Expand Up @@ -255,19 +267,24 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
*/
case class AnalyzedPlan(plan: LogicalPlan, analyzedInfo: Option[AggregateAnalzyInfo])

/**
 * Reports whether `plan` counts as resolved.
 *
 * For an InsertIntoStatement only the `query` being inserted is inspected; every other
 * plan is judged by its own `resolved` value.
 *
 * @param plan the logical plan to inspect
 * @return the `resolved` value of the inspected plan
 */
def isResolvedPlan(plan: LogicalPlan): Boolean = {
  // Pick the plan whose resolution state matters, then read it once.
  val planToCheck = plan match {
    case insert: InsertIntoStatement => insert.query
    case other => other
  }
  planToCheck.resolved
}

override def apply(plan: LogicalPlan): LogicalPlan = {
if (
plan.resolved && spark.conf
spark.conf
.get(CHBackendSettings.GLUTEN_ENABLE_COALESCE_AGGREGATION_UNION, "true")
.toBoolean
.toBoolean && isResolvedPlan(plan)
) {
Try {
visitPlan(plan)
} match {
case Success(res) => res
case Failure(e) =>
logError(s"Failed to rewrite plan. ", e)
plan
case Failure(e) => plan
}
} else {
plan
Expand Down Expand Up @@ -375,6 +392,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
}

def isSupportedAggregate(info: AggregateAnalzyInfo): Boolean = {

!info.hasAggregateWithFilter &&
info.constructedAggregatePlan.isDefined &&
info.positionInGroupingKeys.forall(_ >= 0) &&
Expand All @@ -388,7 +406,8 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
} else {
true
}
}
} &&
info.extractedSourcePlan.isDefined
}

/**
Expand Down Expand Up @@ -455,9 +474,7 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
planWithIndex._1.head.analyzedInfo.get,
analyzedInfo)) match {
case Some((_, i)) => i
case None =>
logError(s"xxx not found match plan: ${analyzedInfo.originalAggregate}")
-1
case None => -1
}

}
Expand Down Expand Up @@ -496,7 +513,6 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
}
}
} else {
logError(s"xxx not supported plan: $agg")
val rewrittenPlan = visitPlan(agg)
groupResults += ArrayBuffer(AnalyzedPlan(rewrittenPlan, None))
}
Expand Down Expand Up @@ -528,6 +544,9 @@ class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] wi
lAttr.qualifiedName.equals(rAttr.qualifiedName)
case (lLiteral: Literal, rLiteral: Literal) =>
lLiteral.value == rLiteral.value
case (lagg: AggregateExpression, ragg: AggregateExpression) =>
lagg.isDistinct == ragg.isDistinct &&
areStructureMatchedExpressions(lagg.aggregateFunction, ragg.aggregateFunction)
case _ =>
l.children.length == r.children.length &&
l.getClass == r.getClass &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,18 @@ class GlutenCoalesceAggregationUnionSuite extends GlutenClickHouseWholeStageTran
compareResultsAgainstVanillaSpark(sql, true, checkNoUnion, true)
}

// Both union branches run the same count(distinct x) / sum(y) aggregation over the same
// table; the result is compared against vanilla Spark with `checkNoUnion` as the plan check.
test("coalesce aggregation union. case 11") {
  val identicalDistinctBranches =
    """
      |select a, x, y from (
      | select a, count(distinct x) as x, sum(y) as y from coalesce_union_t1 group by a
      | union all
      | select a, count(distinct x) as x, sum(y) as y from coalesce_union_t1 group by a
      |) order by a, x, y
      |""".stripMargin
  compareResultsAgainstVanillaSpark(identicalDistinctBranches, true, checkNoUnion, true)
}

test("no coalesce aggregation union. case 1") {
val sql =
"""
Expand Down Expand Up @@ -368,4 +380,16 @@ class GlutenCoalesceAggregationUnionSuite extends GlutenClickHouseWholeStageTran
compareResultsAgainstVanillaSpark(sql, true, checkHasUnion, true)
}

// One branch uses count(distinct x) and the other plain count(x); the test passes
// `checkHasUnion` as the plan check, so it belongs to the "no coalesce" series.
// The previous name, "coalesce aggregation union. case 8", used the positive-series prefix
// and most likely duplicated an existing test name, which ScalaTest rejects at suite
// construction.
// NOTE(review): confirm that 8 is the next free index in the "no coalesce" series.
test("no coalesce aggregation union. case 8") {
  val sql =
    """
      |select a, x, y from (
      | select a, count(distinct x) as x, sum(y) as y from coalesce_union_t1 group by a
      | union all
      | select a, count(x) as x, sum(y) as y from coalesce_union_t1 group by a
      |) order by a, x, y
      |""".stripMargin
  compareResultsAgainstVanillaSpark(sql, true, checkHasUnion, true)
}

}

0 comments on commit 85083f0

Please sign in to comment.