Revert unwanted code
beliefer committed Feb 2, 2020
1 parent 7316cca · commit 4f22a7a
Showing 9 changed files with 52 additions and 521 deletions.
@@ -1814,9 +1814,15 @@ class Analyzer(
}
// We get an aggregate function, we need to wrap it in an AggregateExpression.
case agg: AggregateFunction =>
if (filter.isDefined && !filter.get.deterministic) {
failAnalysis("FILTER expression is non-deterministic, " +
"it cannot be used in aggregate functions")
// TODO: SPARK-30276 Support Filter expression allows simultaneous use of DISTINCT
if (filter.isDefined) {
if (isDistinct) {
failAnalysis("DISTINCT and FILTER cannot be used in aggregate functions " +
"at the same time")
} else if (!filter.get.deterministic) {
failAnalysis("FILTER expression is non-deterministic, " +
"it cannot be used in aggregate functions")
}
}
AggregateExpression(agg, Complete, isDistinct, filter)
// This function is not an aggregate function, just return the resolved one.
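
To illustrate the effect of the new check, here is a minimal sketch (not part of the commit) assuming a local Spark build at this revision; the session, object, table, and column names are placeholders:

{{{
import org.apache.spark.sql.SparkSession

object DistinctFilterCheckDemo {
  def main(args: Array[String]): Unit = {
    // Illustrative local session; everything below is a sketch, not code from this commit.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("distinct-filter-check")
      .getOrCreate()
    import spark.implicits._

    Seq((1, "a", 10), (2, "a", 5), (3, "b", 13))
      .toDF("id", "key", "value")
      .createOrReplaceTempView("data")

    // FILTER on a non-distinct aggregate still analyzes and runs.
    spark.sql(
      "SELECT key, COUNT(value) FILTER (WHERE id > 1) AS cnt FROM data GROUP BY key").show()

    // Combining DISTINCT and FILTER now fails analysis with:
    //   "DISTINCT and FILTER cannot be used in aggregate functions at the same time"
    // spark.sql(
    //   "SELECT key, COUNT(DISTINCT value) FILTER (WHERE id > 1) FROM data GROUP BY key").show()

    spark.stop()
  }
}
}}}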
@@ -31,10 +31,10 @@ import org.apache.spark.sql.types.IntegerType
* First example: query without filter clauses (in scala):
* {{{
* val data = Seq(
* (1, "a", "ca1", "cb1", 10),
* (2, "a", "ca1", "cb2", 5),
* (3, "b", "ca1", "cb1", 13))
* .toDF("id", "key", "cat1", "cat2", "value")
* ("a", "ca1", "cb1", 10),
* ("a", "ca1", "cb2", 5),
* ("b", "ca1", "cb1", 13))
* .toDF("key", "cat1", "cat2", "value")
* data.createOrReplaceTempView("data")
*
* val agg = data.groupBy($"key")
@@ -118,66 +118,7 @@ import org.apache.spark.sql.types.IntegerType
* LocalTableScan [...]
* }}}
*
* Third example: single distinct aggregate function with filter clauses (in sql):
* {{{
* SELECT
* COUNT(DISTINCT cat1) FILTER (WHERE id > 1) as cat1_cnt1,
* COUNT(DISTINCT cat1) as cat1_cnt2,
* SUM(value) AS total
* FROM
* data
* GROUP BY
* key
* }}}
*
* This translates to the following (pseudo) logical plan:
* {{{
* Aggregate(
* key = ['key]
* functions = [COUNT(DISTINCT 'cat1) with FILTER('id > 1),
* COUNT(DISTINCT 'cat1),
* sum('value)]
* output = ['key, 'cat1_cnt1, 'cat1_cnt2, 'total])
* LocalTableScan [...]
* }}}
*
* This rule rewrites this logical plan to the following (pseudo) logical plan:
* {{{
* Aggregate(
* key = ['key]
* functions = [count(if (('gid = 1)) '_gen_distinct_1 else null),
* count(if (('gid = 2)) '_gen_distinct_2 else null),
* first(if (('gid = 0)) 'total else null) ignore nulls]
* output = ['key, 'cat1_cnt1, 'cat1_cnt2, 'total])
* Aggregate(
* key = ['key, '_gen_distinct_1, '_gen_distinct_2, 'gid]
* functions = [sum('value)]
* output = ['key, '_gen_distinct_1, '_gen_distinct_2, 'gid, 'total])
* Expand(
* projections = [('key, null, null, 0, 'value),
* ('key, '_gen_distinct_1, null, 1, null),
* ('key, null, '_gen_distinct_2, 2, null)]
* output = ['key, '_gen_distinct_1, '_gen_distinct_2, 'gid, 'value])
* Expand(
* projections = [('key, if ('id > 1) 'cat1 else null, 'cat1, cast('value as bigint))]
* output = ['key, '_gen_distinct_1, '_gen_distinct_2, 'value])
* LocalTableScan [...]
* }}}
*
* The rule serves two purposes:
* 1. Expand distinct aggregates which exists filter clause.
* 2. Rewrite when aggregate exists at least two distinct aggregates.
*
* The first child rule does the following things here:
* 1. Guaranteed to compute filter clause locally.
* 2. The attributes referenced by different distinct aggregate expressions are likely to overlap,
* and if no additional processing is performed, data loss will occur. To prevent this, we
* generate new attributes and replace the original ones.
* 3. If we apply the first rule to distinct aggregate expressions which exists filter
* clause, the aggregate after expand may have at least two distinct aggregates, so we need to
* apply the second rule too.
*
* The second child rule does the following things here:
* The rule does the following things here:
* 1. Expand the data. There are three aggregation groups in this query:
* i. the non-distinct group;
* ii. the distinct 'cat1 group;
@@ -207,106 +148,24 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
val distinctAggs = exprs.flatMap { _.collect {
case ae: AggregateExpression if ae.isDistinct => ae
}}
// This rule serves two purposes:
// One is to rewrite when there exists at least two distinct aggregates. We need at least
// two distinct aggregates for this rule because aggregation strategy can handle a single
// distinct group.
// Another is to expand distinct aggregates which exists filter clause so that we can
// evaluate the filter locally.
// We need at least two distinct aggregates for this rule because aggregation
// strategy can handle a single distinct group.
// This check can produce false-positives, e.g., SUM(DISTINCT a) & COUNT(DISTINCT a).
distinctAggs.size >= 1 || distinctAggs.exists(_.filter.isDefined)
distinctAggs.size > 1
}

def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case a: Aggregate if mayNeedtoRewrite(a.aggregateExpressions) =>
val expandAggregate = extractFiltersInDistinctAggregate(a)
rewriteDistinctAggregate(expandAggregate)
case a: Aggregate if mayNeedtoRewrite(a.aggregateExpressions) => rewrite(a)
}

private def extractFiltersInDistinctAggregate(a: Aggregate): Aggregate = {
val aggExpressions = collectAggregateExprs(a)
val (distinctAggExpressions, regularAggExpressions) = aggExpressions.partition(_.isDistinct)
if (distinctAggExpressions.exists(_.filter.isDefined)) {
// Setup expand for the 'regular' aggregate expressions. Because we will construct a new
// aggregate, the children of the distinct aggregates will be changed to the generate
// ones, so we need creates new references to avoid collisions between distinct and
// regular aggregate children.
val regularAggExprs = regularAggExpressions.filter(_.children.exists(!_.foldable))
val regularFunChildren = regularAggExprs
.flatMap(_.aggregateFunction.children.filter(!_.foldable))
val regularFilterAttrs = regularAggExprs.flatMap(_.filterAttributes)
val regularAggChildren = (regularFunChildren ++ regularFilterAttrs).distinct
val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair)
val regularAggChildAttrLookup = regularAggChildAttrMap.toMap
val regularAggMap = regularAggExprs.map {
case ae @ AggregateExpression(af, _, _, filter, _) =>
val newChildren = af.children.map(c => regularAggChildAttrLookup.getOrElse(c, c))
val raf = af.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
val filterOpt = filter.map(_.transform {
case a: Attribute => regularAggChildAttrLookup.getOrElse(a, a)
})
val aggExpr = ae.copy(aggregateFunction = raf, filter = filterOpt)
(ae, aggExpr)
}
def rewrite(a: Aggregate): Aggregate = {

// Setup expand for the 'distinct' aggregate expressions.
val distinctAggExprs = distinctAggExpressions.filter(e => e.children.exists(!_.foldable))
val (projections, expressionAttrs, aggExprPairs) = distinctAggExprs.map {
case ae @ AggregateExpression(af, _, _, filter, _) =>
// Why do we need to construct the `exprId` ?
// First, In order to reduce costs, it is better to handle the filter clause locally.
// e.g. COUNT (DISTINCT a) FILTER (WHERE id > 1), evaluate expression
// If(id > 1) 'a else null first, and use the result as output.
// Second, If at least two DISTINCT aggregate expression which may references the
// same attributes. We need to construct the generated attributes so as the output not
// lost. e.g. SUM (DISTINCT a), COUNT (DISTINCT a) FILTER (WHERE id > 1) will output
// attribute '_gen_distinct-1 and attribute '_gen_distinct-2 instead of two 'a.
// Note: We just need to illusion the expression with filter clause.
// The illusionary mechanism may result in multiple distinct aggregations uses
// different column, so we still need to call `rewrite`.
val exprId = NamedExpression.newExprId.id
val unfoldableChildren = af.children.filter(!_.foldable)
val exprAttrs = unfoldableChildren.map { e =>
(e, AttributeReference(s"_gen_distinct_$exprId", e.dataType, nullable = true)())
}
val exprAttrLookup = exprAttrs.toMap
val newChildren = af.children.map(c => exprAttrLookup.getOrElse(c, c))
val raf = af.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
val aggExpr = ae.copy(aggregateFunction = raf, filter = None)
// Expand projection
val projection = unfoldableChildren.map {
case e if filter.isDefined => If(filter.get, e, nullify(e))
case e => e
}
(projection, exprAttrs, (ae, aggExpr))
}.unzip3
val distinctAggChildAttrs = expressionAttrs.flatten.map(_._2)
val allAggAttrs = regularAggChildAttrMap.map(_._2) ++ distinctAggChildAttrs
// Construct the aggregate input projection.
val rewriteAggProjections =
Seq(a.groupingExpressions ++ regularAggChildren ++ projections.flatten)
val groupByMap = a.groupingExpressions.collect {
case ne: NamedExpression => ne -> ne.toAttribute
case e => e -> AttributeReference(e.sql, e.dataType, e.nullable)()
}
val groupByAttrs = groupByMap.map(_._2)
// Construct the expand operator.
val expand = Expand(rewriteAggProjections, groupByAttrs ++ allAggAttrs, a.child)
val rewriteAggExprLookup = (aggExprPairs ++ regularAggMap).toMap
val patchedAggExpressions = a.aggregateExpressions.map { e =>
e.transformDown {
case ae: AggregateExpression => rewriteAggExprLookup.getOrElse(ae, ae)
}.asInstanceOf[NamedExpression]
// Collect all aggregate expressions.
val aggExpressions = a.aggregateExpressions.flatMap { e =>
e.collect {
case ae: AggregateExpression => ae
}
val expandAggregate = Aggregate(groupByAttrs, patchedAggExpressions, expand)
expandAggregate
} else {
a
}
}

private def rewriteDistinctAggregate(a: Aggregate): Aggregate = {
val aggExpressions = collectAggregateExprs(a)

// Extract distinct aggregate expressions.
val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e =>
@@ -472,14 +331,6 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
}
}

private def collectAggregateExprs(a: Aggregate): Seq[AggregateExpression] = {
a.aggregateExpressions.flatMap { e =>
e.collect {
case ae: AggregateExpression => ae
}
}
}

private def nullify(e: Expression) = Literal.create(null, e.dataType)

private def expressionAttributePair(e: Expression) =
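
For the case the rule still handles (two or more distinct aggregate groups, with no FILTER attached to a distinct aggregate), the Expand/gid rewrite described in the Scaladoc above can be inspected through explain(). A hedged sketch, reusing the SparkSession and the data view from the earlier example:

{{{
// Sketch only: assumes the SparkSession `spark` and temp view `data` defined above.
// Two distinct aggregates over different columns trigger RewriteDistinctAggregates,
// which plans an Expand with a `gid` column followed by two Aggregate operators.
spark.sql(
  """SELECT key,
    |       COUNT(DISTINCT id)    AS id_cnt,
    |       COUNT(DISTINCT value) AS value_cnt,
    |       SUM(value)            AS total
    |FROM data
    |GROUP BY key""".stripMargin
).explain(true)
}}}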
@@ -190,6 +190,11 @@ class AnalysisErrorSuite extends AnalysisTest {
"FILTER (WHERE c > 1)"),
"FILTER predicate specified, but aggregate is not an aggregate function" :: Nil)

errorTest(
"DISTINCT and FILTER cannot be used in aggregate functions at the same time",
CatalystSqlParser.parsePlan("SELECT count(DISTINCT a) FILTER (WHERE c > 1) FROM TaBlE2"),
"DISTINCT and FILTER cannot be used in aggregate functions at the same time" :: Nil)

errorTest(
"FILTER expression is non-deterministic, it cannot be used in aggregate functions",
CatalystSqlParser.parsePlan("SELECT count(a) FILTER (WHERE rand(int(c)) > 1) FROM TaBlE2"),
39 changes: 14 additions & 25 deletions sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql
@@ -33,10 +33,8 @@ SELECT COUNT(id) FILTER (WHERE hiredate = date "2001-01-01") FROM emp;
SELECT COUNT(id) FILTER (WHERE hiredate = to_date('2001-01-01 00:00:00')) FROM emp;
SELECT COUNT(id) FILTER (WHERE hiredate = to_timestamp("2001-01-01 00:00:00")) FROM emp;
SELECT COUNT(id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") = "2001-01-01") FROM emp;
SELECT COUNT(DISTINCT id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") = "2001-01-01 00:00:00") FROM emp;
SELECT COUNT(DISTINCT id), COUNT(DISTINCT id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") = "2001-01-01 00:00:00") FROM emp;
SELECT COUNT(DISTINCT id) FILTER (WHERE hiredate = to_timestamp("2001-01-01 00:00:00")), COUNT(DISTINCT id) FILTER (WHERE hiredate = to_date('2001-01-01 00:00:00')) FROM emp;
SELECT SUM(salary), COUNT(DISTINCT id), COUNT(DISTINCT id) FILTER (WHERE hiredate = date "2001-01-01") FROM emp;
-- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT
-- SELECT COUNT(DISTINCT id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") = "2001-01-01 00:00:00") FROM emp;

-- Aggregate with filter and non-empty GroupBy expressions.
SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData GROUP BY a;
@@ -46,10 +44,8 @@ SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > date "2003-01-01") FROM emp GROUP BY dept_id;
SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_date("2003-01-01")) FROM emp GROUP BY dept_id;
SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_timestamp("2003-01-01 00:00:00")) FROM emp GROUP BY dept_id;
SELECT dept_id, SUM(salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") > "2003-01-01") FROM emp GROUP BY dept_id;
SELECT dept_id, SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id;
SELECT dept_id, SUM(DISTINCT salary), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id;
SELECT dept_id, SUM(DISTINCT salary) FILTER (WHERE hiredate > date "2001-01-01"), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id;
SELECT dept_id, COUNT(id), SUM(DISTINCT salary), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") > "2001-01-01") FROM emp GROUP BY dept_id;
-- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT
-- SELECT dept_id, SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id;

-- Aggregate with filter and grouped by literals.
SELECT 'foo', COUNT(a) FILTER (WHERE b <= 2) FROM testData GROUP BY 1;
@@ -62,21 +58,13 @@ select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary),
select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary), sum(salary) filter (where id + dept_id > 500) from emp group by dept_id;
select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id;
select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id + dept_id > 500) from emp group by dept_id;
select dept_id, count(distinct emp_name) filter (where id > 200), sum(salary) from emp group by dept_id;
select dept_id, count(distinct emp_name) filter (where id + dept_id > 500), sum(salary) from emp group by dept_id;
select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id > 200), sum(salary) from emp group by dept_id;
select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id + dept_id > 500), sum(salary) from emp group by dept_id;
select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id > 200), sum(salary), sum(salary) filter (where id > 200) from emp group by dept_id;
select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id + dept_id > 500), sum(salary), sum(salary) filter (where id > 200) from emp group by dept_id;
select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id;
select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id;
select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D) from emp group by dept_id;
select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id;
select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name), sum(salary) from emp group by dept_id;
select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id;
select dept_id, sum(distinct (id + dept_id))) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id;
select dept_id, sum(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id;
select dept_id, avg(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D) from emp group by dept_id;
-- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT
-- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id;
-- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id;
-- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D) from emp group by dept_id;
-- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id;
-- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name), sum(salary) from emp group by dept_id;
-- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id;

-- Aggregate with filter and grouped by literals (hash aggregate), here the input table is filtered using WHERE.
SELECT 'foo', APPROX_COUNT_DISTINCT(a) FILTER (WHERE b >= 0) FROM testData WHERE a = 0 GROUP BY 1;
@@ -90,8 +78,9 @@ SELECT a + 2, COUNT(b) FILTER (WHERE b IN (1, 2)) FROM testData GROUP BY a + 1;
SELECT a + 1 + 1, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY a + 1;

-- Aggregate with filter, foldable input and multiple distinct groups.
SELECT COUNT(DISTINCT b) FILTER (WHERE b > 0), COUNT(DISTINCT b, c) FILTER (WHERE b > 0 AND c > 2)
FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a;
-- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT
-- SELECT COUNT(DISTINCT b) FILTER (WHERE b > 0), COUNT(DISTINCT b, c) FILTER (WHERE b > 0 AND c > 2)
-- FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a;

-- Check analysis exceptions
SELECT a AS k, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY k;
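
The surviving queries in this file show what remains supported: DISTINCT aggregates and FILTER clauses may appear side by side in one SELECT list as long as no single aggregate uses both. A hedged sketch mirroring one of the remaining emp queries; the schema and rows are inferred from the column names in the tests, not copied from the commit:

{{{
// Sketch only: assumes the SparkSession `spark` and `import spark.implicits._`
// from the first example. The emp rows are made up for illustration.
Seq(
  (100, "emp 1", java.sql.Date.valueOf("2001-01-01"), 100.0, 10),
  (200, "emp 2", java.sql.Date.valueOf("2003-01-01"), 200.0, 10),
  (300, "emp 3", java.sql.Date.valueOf("2005-01-01"), 300.0, 20)
).toDF("id", "emp_name", "hiredate", "salary", "dept_id")
  .createOrReplaceTempView("emp")

// Still supported after the revert: distinct aggregates next to a separately
// filtered, non-distinct aggregate.
spark.sql(
  """SELECT dept_id,
    |       COUNT(DISTINCT emp_name) AS names,
    |       COUNT(DISTINCT hiredate) AS dates,
    |       SUM(salary) FILTER (WHERE id > 200) AS late_salary
    |FROM emp
    |GROUP BY dept_id""".stripMargin
).show()
}}}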