Commit

fix tests

dtenedor committed Nov 18, 2024
1 parent 33378a6 commit 09b04ed
Showing 13 changed files with 1,334 additions and 108 deletions.
10 changes: 8 additions & 2 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -3989,9 +3989,15 @@
     ],
     "sqlState" : "42K03"
   },
-  "PIPE_OPERATOR_SELECT_CONTAINS_AGGREGATE_FUNCTION" : {
+  "PIPE_OPERATOR_AGGREGATE_EXPRESSION_CONTAINS_NO_AGGREGATE_FUNCTION" : {
     "message" : [
-      "Aggregate function <expr> is not allowed when using the pipe operator |> SELECT clause; please use the pipe operator |> AGGREGATE clause instead"
+      "Non-grouping expression <expr> is provided as an argument to the |> AGGREGATE pipe operator but does not contain any aggregate function; please update it to include an aggregate function and then retry the query again."
     ],
     "sqlState" : "0A000"
   },
+  "PIPE_OPERATOR_CONTAINS_AGGREGATE_FUNCTION" : {
+    "message" : [
+      "Aggregate function <expr> is not allowed when using the pipe operator |> <clause> clause; please use the pipe operator |> AGGREGATE clause instead."
+    ],
+    "sqlState" : "0A000"
+  },
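For illustration, both conditions are easy to reproduce; in the following sketch, the table t and its columns x and y are hypothetical stand-ins:

    -- PIPE_OPERATOR_AGGREGATE_EXPRESSION_CONTAINS_NO_AGGREGATE_FUNCTION:
    -- the expression x + 1 contains no aggregate function.
    FROM t
    |> AGGREGATE x + 1 AS y;

    -- PIPE_OPERATOR_CONTAINS_AGGREGATE_FUNCTION (here <clause> is SELECT):
    -- aggregate functions belong in |> AGGREGATE instead.
    FROM t
    |> SELECT SUM(x) AS total;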
1 change: 1 addition & 0 deletions docs/sql-ref-ansi-compliance.md
@@ -514,6 +514,7 @@ Below is a list of all the keywords in Spark SQL.
|EXISTS|non-reserved|non-reserved|reserved|
|EXPLAIN|non-reserved|non-reserved|non-reserved|
|EXPORT|non-reserved|non-reserved|non-reserved|
+|EXTEND|non-reserved|non-reserved|non-reserved|
|EXTENDED|non-reserved|non-reserved|non-reserved|
|EXTERNAL|non-reserved|non-reserved|reserved|
|EXTRACT|non-reserved|non-reserved|reserved|
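Because EXTEND is non-reserved in every mode, it should remain legal as an ordinary identifier; a quick sketch (the table and column names are hypothetical):

    -- EXTEND can still name a table or a column.
    CREATE TABLE extend(extend INT);
    SELECT extend FROM extend;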
SqlBaseLexer.g4
@@ -228,6 +228,7 @@ EXCLUDE: 'EXCLUDE';
EXISTS: 'EXISTS';
EXPLAIN: 'EXPLAIN';
EXPORT: 'EXPORT';
+EXTEND: 'EXTEND';
EXTENDED: 'EXTENDED';
EXTERNAL: 'EXTERNAL';
EXTRACT: 'EXTRACT';
SqlBaseParser.g4
@@ -1503,6 +1503,10 @@ version

operatorPipeRightSide
: selectClause windowClause?
+| EXTEND extendList=namedExpressionSeq
+| SET operatorPipeSetAssignmentSeq
+| DROP identifierSeq
+| AS errorCapturingIdentifier
// Note that the WINDOW clause is not allowed in the WHERE pipe operator, but we add it here in
// the grammar simply for purposes of catching this invalid syntax and throwing a specific
// dedicated error message.
@@ -1519,6 +1523,11 @@
| AGGREGATE namedExpressionSeq? aggregationClause?
;

operatorPipeSetAssignmentSeq
    : ident+=errorCapturingIdentifier EQ expression
      (COMMA ident+=errorCapturingIdentifier EQ expression)*
    ;

// When `SQL_standard_keyword_behavior=true`, there are 2 kinds of keywords in Spark SQL.
// - Reserved keywords:
// Keywords that are reserved and can't be used as identifiers for table, view, column,
@@ -1617,6 +1626,7 @@ ansiNonReserved
| EXISTS
| EXPLAIN
| EXPORT
+| EXTEND
| EXTENDED
| EXTERNAL
| EXTRACT
@@ -1963,6 +1973,7 @@ nonReserved
| EXISTS
| EXPLAIN
| EXPORT
+| EXTEND
| EXTENDED
| EXTERNAL
| EXTRACT
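Taken together, the new grammar alternatives allow pipelines like the following sketch, which assumes a table t with columns x and y:

    FROM t
    |> EXTEND x + 1 AS z    -- append a computed column z
    |> SET z = z * 2        -- reassign z in place
    |> DROP y               -- remove column y
    |> AS result;           -- rename the intermediate relation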
pipeOperators.scala
@@ -18,14 +18,14 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
-import org.apache.spark.sql.catalyst.trees.TreePattern.{PIPE_OPERATOR_SELECT, RUNTIME_REPLACEABLE, TreePattern}
+import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.errors.QueryCompilationErrors

/**
- * Represents a SELECT clause when used with the |> SQL pipe operator.
- * We use this to make sure that no aggregate functions exist in the SELECT expressions.
+ * Represents an expression when used with the SQL pipe operators |> SELECT, |> EXTEND, or |> SET.
+ * We use this to make sure that no aggregate functions exist in these expressions.
  */
-case class PipeSelect(child: Expression)
+case class PipeSelect(child: Expression, clause: String = "SELECT")
extends UnaryExpression with RuntimeReplaceable {
final override val nodePatterns: Seq[TreePattern] = Seq(PIPE_OPERATOR_SELECT, RUNTIME_REPLACEABLE)
override def withNewChildInternal(newChild: Expression): Expression = PipeSelect(newChild)
@@ -35,7 +35,7 @@ case class PipeSelect(child: Expression)
// If we used the pipe operator |> SELECT clause to specify an aggregate function, this is
// invalid; return an error message instructing the user to use the pipe operator
// |> AGGREGATE clause for this purpose instead.
-      throw QueryCompilationErrors.pipeOperatorSelectContainsAggregateFunction(a)
+      throw QueryCompilationErrors.pipeOperatorContainsAggregateFunction(a, clause)
case _: WindowExpression =>
// Window functions are allowed in pipe SELECT operators, so do not traverse into children.
case _ =>
@@ -46,13 +46,42 @@ case class PipeSelect(child: Expression)
}
}

/**
 * Represents an expression when used with the SQL pipe operator |> AGGREGATE.
 * We use this to make sure that at least one aggregate function exists in each of these
 * expressions.
 */
case class PipeAggregate(child: Expression) extends UnaryExpression with RuntimeReplaceable {
  final override val nodePatterns: Seq[TreePattern] = Seq(RUNTIME_REPLACEABLE)
  override def withNewChildInternal(newChild: Expression): Expression = PipeAggregate(newChild)
  override lazy val replacement: Expression = {
    var foundAggregate = false
    def visit(e: Expression): Unit = {
      e match {
        case _: AggregateFunction =>
          foundAggregate = true
        case _ =>
          e.children.foreach(visit)
      }
    }
    visit(child)
    if (!foundAggregate) {
      throw QueryCompilationErrors.pipeOperatorAggregateExpressionContainsNoAggregateFunction(child)
    }
    child
  }
}

object PipeOperators {
// These are definitions of query result clauses that can be used with the pipe operator.
val clusterByClause = "CLUSTER BY"
val distributeByClause = "DISTRIBUTE BY"
val extendClause = "EXTEND"
val limitClause = "LIMIT"
val offsetClause = "OFFSET"
val orderByClause = "ORDER BY"
val selectClause = "SELECT"
val setClause = "SET"
val sortByClause = "SORT BY"
val sortByDistributeByClause = "SORT BY ... DISTRIBUTE BY ..."
val windowClause = "WINDOW"
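The asymmetry between the two checks is easiest to see with examples; a sketch assuming a table t with columns x and y (note that PipeSelect does not traverse into WindowExpression children, so windowed functions pass):

    -- Allowed: RANK() is used as a window function, not a bare aggregate.
    FROM t
    |> SELECT x, RANK() OVER (ORDER BY x) AS r;

    -- Rejected by PipeSelect with clause = EXTEND: SUM is an aggregate function.
    FROM t
    |> EXTEND SUM(x) AS total;

    -- Accepted by PipeAggregate: the expression contains an aggregate function.
    FROM t
    |> AGGREGATE SUM(x) AS total GROUP BY y;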
AstBuilder.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.parser
import java.util.Locale
import java.util.concurrent.TimeUnit

+import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, ListBuffer, Set}
import scala.jdk.CollectionConverters._
import scala.util.{Left, Right}
@@ -5901,6 +5902,61 @@ class AstBuilder extends DataTypeAstBuilder
windowClause = ctx.windowClause,
relation = left,
isPipeOperatorSelect = true)
    }.getOrElse(Option(ctx.EXTEND).map { _ =>
      val extendExpressions: Seq[NamedExpression] =
        Option(ctx.extendList).map { n: NamedExpressionSeqContext =>
          val visited = visitNamedExpressionSeq(n)
          visited.map {
            case (a: Alias, _) =>
              a.copy(child = PipeSelect(a.child, PipeOperators.extendClause))(
                a.exprId, a.qualifier, a.explicitMetadata, a.nonInheritableMetadataKeys)
            case (e: Expression, aliasFunc) =>
              UnresolvedAlias(PipeSelect(e, PipeOperators.extendClause), aliasFunc)
          }
        }.get
      val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) ++ extendExpressions
      Project(projectList, left)
    }.getOrElse(Option(ctx.SET).map { _ =>
      val (setIdentifiers: Seq[String], setTargets: Seq[Expression]) =
        visitOperatorPipeSetAssignmentSeq(ctx.operatorPipeSetAssignmentSeq())
      var plan = left
      val visitedSetIdentifiers = mutable.Set.empty[String]
      setIdentifiers.zip(setTargets).foreach {
        case (_, _: Alias) =>
          operationNotAllowed(
            "SQL pipe syntax |> SET operator with an alias assigned with [AS] aliasName", ctx)
        case (ident, target) =>
          // Check uniqueness of the assignment keys, normalizing case when the analyzer is
          // case-insensitive so that duplicates differing only in case are still rejected.
          val checkKey = if (SQLConf.get.caseSensitiveAnalysis) {
            ident
          } else {
            ident.toLowerCase(Locale.ROOT)
          }
          if (visitedSetIdentifiers(checkKey)) {
            operationNotAllowed(
              s"SQL pipe syntax |> SET operator with duplicate assignment key $ident", ctx)
          }
          visitedSetIdentifiers += checkKey
          // Add an UnresolvedStarExcept to exclude the SET expression name from the relation and
          // add the new SET expression to the projection list.
          // Use a PipeSelect expression to make sure it does not contain any aggregate functions.
          plan = Project(
            Seq(UnresolvedStarExcept(None, Seq(Seq(ident))),
              Alias(PipeSelect(target, PipeOperators.setClause), ident)()), plan)
      }
      plan
    }.getOrElse(Option(ctx.DROP).map { _ =>
      var plan = left
      visitIdentifierSeq(ctx.identifierSeq()).foreach { ident: String =>
        plan = Project(Seq(UnresolvedStarExcept(None, Seq(Seq(ident)))), plan)
      }
      plan
    }.getOrElse(Option(ctx.AS).map { _ =>
      val child = left match {
        case s: SubqueryAlias => s.child
        case _ => left
      }
      SubqueryAlias(ctx.errorCapturingIdentifier().getText, child)
}.getOrElse(Option(ctx.whereClause).map { c =>
if (ctx.windowClause() != null) {
throw QueryParsingErrors.windowClauseInPipeOperatorWhereClauseNotAllowedError(ctx)
@@ -5927,9 +5983,17 @@ class AstBuilder extends DataTypeAstBuilder
withQueryResultClauses(c, withSubqueryAlias(), forPipeOperators = true)
}.getOrElse(
visitOperatorPipeAggregate(ctx, left)
-    ))))))))
+    ))))))))))))
}

  override def visitOperatorPipeSetAssignmentSeq(
      ctx: OperatorPipeSetAssignmentSeqContext): (Seq[String], Seq[Expression]) =
    withOrigin(ctx) {
      val setIdentifiers: Seq[String] = ctx.errorCapturingIdentifier().asScala.map(_.getText).toSeq
      val setTargets: Seq[Expression] = ctx.expression().asScala.map(typedVisit[Expression]).toSeq
      (setIdentifiers, setTargets)
    }

private def visitOperatorPipeAggregate(
ctx: OperatorPipeRightSideContext, left: LogicalPlan): LogicalPlan = {
assert(ctx.AGGREGATE != null)
@@ -5941,8 +6005,11 @@ class AstBuilder extends DataTypeAstBuilder
    val aggregateExpressions: Seq[NamedExpression] =
      Option(ctx.namedExpressionSeq()).map { n: NamedExpressionSeqContext =>
        visitNamedExpressionSeq(n).map {
-         case (e: NamedExpression, _) => e
-         case (e: Expression, aliasFunc) => UnresolvedAlias(e, aliasFunc)
+         case (a: Alias, _) =>
+           a.copy(child = PipeAggregate(a.child))(
+             a.exprId, a.qualifier, a.explicitMetadata, a.nonInheritableMetadataKeys)
+         case (e: Expression, aliasFunc) =>
+           UnresolvedAlias(PipeAggregate(e), aliasFunc)
        }
      }.getOrElse(Seq.empty)
Option(ctx.aggregationClause()).map { c: AggregationClauseContext =>
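Reading the plan construction above, each new operator desugars to an ordinary projection; the comments below are approximate SELECT-form rewrites for a hypothetical table t with columns x and y, not literal parser output:

    FROM t |> EXTEND x + 1 AS x1;  -- roughly SELECT *, x + 1 AS x1 FROM t
    FROM t |> SET x = x * 2;       -- roughly SELECT * EXCEPT (x), x * 2 AS x FROM t
    FROM t |> DROP x, y;           -- nested star-except projections, one per column

    -- Rejected at parse time by the SET branch:
    FROM t |> SET x = 1, x = 2;    -- duplicate assignment key x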
QueryCompilationErrors.scala
@@ -4129,14 +4129,23 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with CompilationErrors
)
}

-  def pipeOperatorSelectContainsAggregateFunction(expr: Expression): Throwable = {
+  def pipeOperatorAggregateExpressionContainsNoAggregateFunction(expr: Expression): Throwable = {
     new AnalysisException(
-      errorClass = "PIPE_OPERATOR_SELECT_CONTAINS_AGGREGATE_FUNCTION",
+      errorClass = "PIPE_OPERATOR_AGGREGATE_EXPRESSION_CONTAINS_NO_AGGREGATE_FUNCTION",
messageParameters = Map(
"expr" -> expr.toString),
origin = expr.origin)
}

+  def pipeOperatorContainsAggregateFunction(expr: Expression, clause: String): Throwable = {
+    new AnalysisException(
+      errorClass = "PIPE_OPERATOR_CONTAINS_AGGREGATE_FUNCTION",
+      messageParameters = Map(
+        "expr" -> expr.toString,
+        "clause" -> clause),
+      origin = expr.origin)
+  }

def inlineTableContainsScalarSubquery(inlineTable: LogicalPlan): Throwable = {
new AnalysisException(
errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.SCALAR_SUBQUERY_IN_VALUES",
