Enables struct fields as sub expressions of grouping fields
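The queries this enables look like the new test case at the bottom of the diff. A rough usage sketch follows (hypothetical standalone session, not part of the commit; it assumes the Spark 1.x SQL API that the changed test suite itself uses):

// Hypothetical sketch: group by a struct field sub expression, which previously
// failed analysis with "Expression not in GROUP BY: ...".
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

val sparkContext = new SparkContext(new SparkConf().setMaster("local").setAppName("spark-4322-sketch"))
val sqlContext = new SQLContext(sparkContext)
import sqlContext._  // brings sql(...) and jsonRDD(...) into scope

// Register a JSON-backed temp table with a struct column `a`, mirroring the new test.
jsonRDD(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)).registerTempTable("data")

// Grouping on a.b[0].c (a struct field inside an array) now resolves;
// the expected answer is a single row containing 1.
sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c").collect().foreach(println)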
liancheng committed Nov 13, 2014
1 parent b9e1c2e commit 7f46532
Showing 3 changed files with 26 additions and 18 deletions.
@@ -60,7 +60,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
ResolveFunctions ::
GlobalAggregates ::
UnresolvedHavingClauseAttributes ::
- TrimAliases ::
+ TrimGroupingAliases ::
typeCoercionRules ++
extendedRules : _*),
Batch("Check Analysis", Once,
@@ -70,6 +70,8 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
EliminateAnalysisOperators)
)

+ private def trimAliases(e: Expression) = e.transform { case Alias(c, _) => c }
+
/**
* Makes sure all attributes and logical plans have been resolved.
*/
@@ -93,17 +95,10 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
/**
* Removes no-op Alias expressions from the plan.
*/
- object TrimAliases extends Rule[LogicalPlan] {
+ object TrimGroupingAliases extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case Aggregate(groups, aggs, child) =>
- Aggregate(
-   groups.map {
-     _ transform {
-       case Alias(c, _) => c
-     }
-   },
-   aggs,
-   child)
+ Aggregate(groups.map(trimAliases), aggs, child)
}
}

@@ -122,10 +117,10 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
case e => e.children.forall(isValidAggregateExpression)
}

- aggregateExprs.foreach { e =>
-   if (!isValidAggregateExpression(e)) {
-     throw new TreeNodeException(plan, s"Expression not in GROUP BY: $e")
-   }
+ aggregateExprs.find { e =>
+   !isValidAggregateExpression(trimAliases(e))
+ }.foreach { e =>
+   throw new TreeNodeException(plan, s"Expression not in GROUP BY: $e")
}

aggregatePlan
@@ -328,4 +323,3 @@ object EliminateAnalysisOperators extends Rule[LogicalPlan] {
case Subquery(_, child) => child
}
}

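For readers skimming the diff: the analyzer change above boils down to stripping no-op Alias wrappers from both the grouping expressions and the aggregate output before checking GROUP BY membership (the new private trimAliases helper). Below is a tiny self-contained sketch of that idea, using a made-up expression ADT rather than Spark's actual catalyst Expression classes:

// Hypothetical, simplified expression tree -- NOT Spark's catalyst classes.
sealed trait Expr
case class Attr(name: String) extends Expr                     // e.g. `a`
case class GetField(child: Expr, field: String) extends Expr   // e.g. `a.b`
case class Alias(child: Expr, name: String) extends Expr       // e.g. `a.b AS b`

// Recursively drop Alias wrappers, keeping the underlying expression.
def trimAliases(e: Expr): Expr = e match {
  case Alias(child, _)    => trimAliases(child)
  case GetField(child, f) => GetField(trimAliases(child), f)
  case other              => other
}

// An aggregate output such as `a.b AS b` and a grouping expression `a.b`
// only compare equal once the alias is stripped, which is why the GROUP BY
// check in the diff above calls trimAliases(e) before isValidAggregateExpression.
val grouping  = GetField(Attr("a"), "b")
val aggregate = Alias(GetField(Attr("a"), "b"), "b")
assert(trimAliases(aggregate) == grouping)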
@@ -146,13 +146,17 @@ object PartialAggregation {
case other => (other, Alias(other, "PartialGroup")())
}.toMap

+ def trimAliases(e: Expression) = e.transform { case Alias(c, _) => c }
+
// Replace aggregations with a new expression that computes the result from the already
// computed partial evaluations and grouping values.
val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp {
case e: Expression if partialEvaluations.contains(new TreeNodeRef(e)) =>
partialEvaluations(new TreeNodeRef(e)).finalEvaluation
case e: Expression if namedGroupingExpressions.contains(e) =>
namedGroupingExpressions(e).toAttribute
+ case e: Expression if namedGroupingExpressions.contains(trimAliases(e)) =>
+   namedGroupingExpressions(trimAliases(e)).toAttribute
}).asInstanceOf[Seq[NamedExpression]]

val partialComputation =
@@ -188,7 +192,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
logDebug(s"Considering join on: $condition")
// Find equi-join predicates that can be evaluated before the join, and thus can be used
// as join keys.
- val (joinPredicates, otherPredicates) =
+ val (joinPredicates, otherPredicates) =
condition.map(splitConjunctivePredicates).getOrElse(Nil).partition {
case EqualTo(l, r) if (canEvaluate(l, left) && canEvaluate(r, right)) ||
(canEvaluate(l, right) && canEvaluate(r, left)) => true
@@ -203,7 +207,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
val rightKeys = joinKeys.map(_._2)

if (joinKeys.nonEmpty) {
logDebug(s"leftKeys:${leftKeys} | rightKeys:${rightKeys}")
logDebug(s"leftKeys:$leftKeys | rightKeys:$rightKeys")
Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right))
} else {
None
12 changes: 11 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -544,7 +544,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
sql("SELECT * FROM upperCaseData EXCEPT SELECT * FROM upperCaseData"), Nil)
}

test("INTERSECT") {
test("INTERSECT") {
checkAnswer(
sql("SELECT * FROM lowerCaseData INTERSECT SELECT * FROM lowerCaseData"),
(1, "a") ::
@@ -942,4 +942,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
checkAnswer(sql("SELECT key FROM testData WHERE value not like '100%' order by key"),
(1 to 99).map(i => Seq(i)))
}

test("SPARK-4322 Grouping field with struct field as sub expression") {
jsonRDD(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)).registerTempTable("data")
checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), 1)
dropTempTable("data")

jsonRDD(sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data")
checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), 2)
dropTempTable("data")
}
}
