[SPARK-47637][SQL] Use errorCapturingIdentifier in more places
### What changes were proposed in this pull request?

The errorCapturingIdentifier grammar rule parses identifiers that contain a '-' so that the parser can raise an INVALID_IDENTIFIER error instead of a generic SYNTAX_ERROR for non-delimited identifiers that include a hyphen.

It is meant to be used wherever the context is not that of an expression. This PR switches a few remaining uses of the plain identifier rule over to errorCapturingIdentifier.
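
For context, errorCapturingIdentifier already exists in the SqlBaseParser grammar; its shape is roughly the following (paraphrased here for illustration, not part of this diff):

```antlr
// Approximate shape of the pre-existing rules (not modified by this PR):
// a regular identifier optionally followed by MINUS-separated parts, so a
// hyphenated name still parses and can later be rejected with INVALID_IDENTIFIER.
errorCapturingIdentifier
    : identifier errorCapturingIdentifierExtra
    ;

errorCapturingIdentifierExtra
    : (MINUS identifier)+    #errorIdent
    |                        #realIdent
    ;
```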

### Why are the changes needed?

Improve error messages for non-delimited identifiers that contain a hyphen.
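
As a sketch of the intended behavior (error output paraphrased, not copied from Spark), a hyphenated name in one of the newly covered positions now fails with INVALID_IDENTIFIER rather than a generic syntax error:

```sql
-- Non-delimited identifier with a hyphen: previously a generic SYNTAX_ERROR
-- at the '-' token, now INVALID_IDENTIFIER with ident = test-test.
SET CATALOG test-test;

-- Backquoting the name makes it a single delimited identifier, so it parses.
SET CATALOG `test-test`;
```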

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Added unit tests in ErrorParserSuite.scala
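The new cases can typically be exercised with something like `build/sbt "catalyst/testOnly *ErrorParserSuite"` (module and command assumed here, not taken from this PR).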

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#45764 from srielau/SPARK-47637-errorCapturingIdentifier.

Authored-by: Serge Rielau <serge@rielau.com>
Signed-off-by: Gengliang Wang <gengliang@apache.org>
srielau authored and gengliangwang committed Mar 29, 2024
1 parent d182810 commit db14be8
Showing 5 changed files with 40 additions and 19 deletions.
SqlBaseParser.g4
@@ -76,7 +76,7 @@ statement
| ctes? dmlStatementNoWith #dmlStatement
| USE identifierReference #use
| USE namespace identifierReference #useNamespace
- | SET CATALOG (identifier | stringLit) #setCatalog
+ | SET CATALOG (errorCapturingIdentifier | stringLit) #setCatalog
| CREATE namespace (IF NOT EXISTS)? identifierReference
(commentSpec |
locationSpec |
@@ -392,7 +392,7 @@ describeFuncName
;

describeColName
- : nameParts+=identifier (DOT nameParts+=identifier)*
+ : nameParts+=errorCapturingIdentifier (DOT nameParts+=errorCapturingIdentifier)*
;

ctes
@@ -429,7 +429,7 @@ property
;

propertyKey
- : identifier (DOT identifier)*
+ : errorCapturingIdentifier (DOT errorCapturingIdentifier)*
| stringLit
;

@@ -683,18 +683,18 @@ pivotClause
;

pivotColumn
- : identifiers+=identifier
- | LEFT_PAREN identifiers+=identifier (COMMA identifiers+=identifier)* RIGHT_PAREN
+ : identifiers+=errorCapturingIdentifier
+ | LEFT_PAREN identifiers+=errorCapturingIdentifier (COMMA identifiers+=errorCapturingIdentifier)* RIGHT_PAREN
;

pivotValue
- : expression (AS? identifier)?
+ : expression (AS? errorCapturingIdentifier)?
;

unpivotClause
: UNPIVOT nullOperator=unpivotNullClause? LEFT_PAREN
operator=unpivotOperator
- RIGHT_PAREN (AS? identifier)?
+ RIGHT_PAREN (AS? errorCapturingIdentifier)?
;

unpivotNullClause
@@ -736,7 +736,7 @@ unpivotColumn
;

unpivotAlias
- : AS? identifier
+ : AS? errorCapturingIdentifier
;

lateralView
@@ -1188,7 +1188,7 @@ complexColTypeList
;

complexColType
- : identifier COLON? dataType (NOT NULL)? commentSpec?
+ : errorCapturingIdentifier COLON? dataType (NOT NULL)? commentSpec?
;

whenClause
DataTypeAstBuilder.scala
@@ -201,7 +201,7 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
override def visitComplexColType(ctx: ComplexColTypeContext): StructField = withOrigin(ctx) {
import ctx._
val structField = StructField(
- name = identifier.getText,
+ name = errorCapturingIdentifier.getText,
dataType = typedVisit(dataType()),
nullable = NULL == null)
Option(commentSpec).map(visitCommentSpec).map(structField.withComment).getOrElse(structField)
AstBuilder.scala
@@ -1255,7 +1255,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
.flatMap(_.namedExpression.asScala)
.map(typedVisit[Expression])
val pivotColumn = if (ctx.pivotColumn.identifiers.size == 1) {
- UnresolvedAttribute.quoted(ctx.pivotColumn.identifier.getText)
+ UnresolvedAttribute.quoted(ctx.pivotColumn.errorCapturingIdentifier.getText)
} else {
CreateStruct(
ctx.pivotColumn.identifiers.asScala.map(
@@ -1270,8 +1270,8 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
*/
override def visitPivotValue(ctx: PivotValueContext): Expression = withOrigin(ctx) {
val e = expression(ctx.expression)
- if (ctx.identifier != null) {
-   Alias(e, ctx.identifier.getText)()
+ if (ctx.errorCapturingIdentifier != null) {
+   Alias(e, ctx.errorCapturingIdentifier.getText)()
} else {
e
}
@@ -1334,8 +1334,8 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
}

// alias unpivot result
- if (ctx.identifier() != null) {
-   val alias = ctx.identifier().getText
+ if (ctx.errorCapturingIdentifier() != null) {
+   val alias = ctx.errorCapturingIdentifier().getText
SubqueryAlias(alias, filtered)
} else {
filtered
@@ -1355,7 +1355,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
override def visitUnpivotColumnAndAlias(ctx: UnpivotColumnAndAliasContext):
(NamedExpression, Option[String]) = withOrigin(ctx) {
val attr = visitUnpivotColumn(ctx.unpivotColumn())
- val alias = Option(ctx.unpivotAlias()).map(_.identifier().getText)
+ val alias = Option(ctx.unpivotAlias()).map(_.errorCapturingIdentifier().getText)
(attr, alias)
}

@@ -1367,7 +1367,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
(Seq[NamedExpression], Option[String]) =
withOrigin(ctx) {
val exprs = ctx.unpivotColumns.asScala.map(visitUnpivotColumn).toSeq
- val alias = Option(ctx.unpivotAlias()).map(_.identifier().getText)
+ val alias = Option(ctx.unpivotAlias()).map(_.errorCapturingIdentifier().getText)
(exprs, alias)
}

ErrorParserSuite.scala
@@ -61,6 +61,10 @@ class ErrorParserSuite extends AnalysisTest {
exception = parseException("USE test-test"),
errorClass = "INVALID_IDENTIFIER",
parameters = Map("ident" -> "test-test"))
+ checkError(
+   exception = parseException("SET CATALOG test-test"),
+   errorClass = "INVALID_IDENTIFIER",
+   parameters = Map("ident" -> "test-test"))
checkError(
exception = parseException("CREATE DATABASE IF NOT EXISTS my-database"),
errorClass = "INVALID_IDENTIFIER",
@@ -167,6 +171,10 @@
exception = parseException("ANALYZE TABLE test-table PARTITION (part1)"),
errorClass = "INVALID_IDENTIFIER",
parameters = Map("ident" -> "test-table"))
+ checkError(
+   exception = parseException("CREATE TABLE t(c1 struct<test-test INT, c2 INT>)"),
+   errorClass = "INVALID_IDENTIFIER",
+   parameters = Map("ident" -> "test-test"))
checkError(
exception = parseException("LOAD DATA INPATH \"path\" INTO TABLE my-tab"),
errorClass = "INVALID_IDENTIFIER",
@@ -276,6 +284,19 @@
""".stripMargin),
errorClass = "INVALID_IDENTIFIER",
parameters = Map("ident" -> "test-table"))
+ checkError(
+   exception = parseException(
+     """
+       |SELECT * FROM (
+       | SELECT year, course, earnings FROM courseSales
+       |)
+       |PIVOT (
+       | sum(earnings)
+       | FOR test-test IN ('dotNET', 'Java')
+       |);
+     """.stripMargin),
+   errorClass = "INVALID_IDENTIFIER",
+   parameters = Map("ident" -> "test-test"))
}

test("datatype not supported") {
SparkSqlAstBuilder
@@ -265,8 +265,8 @@ class SparkSqlAstBuilder extends AstBuilder {
* Create a [[SetCatalogCommand]] logical command.
*/
override def visitSetCatalog(ctx: SetCatalogContext): LogicalPlan = withOrigin(ctx) {
- if (ctx.identifier() != null) {
-   SetCatalogCommand(ctx.identifier().getText)
+ if (ctx.errorCapturingIdentifier() != null) {
+   SetCatalogCommand(ctx.errorCapturingIdentifier().getText)
} else if (ctx.stringLit() != null) {
SetCatalogCommand(string(visitStringLit(ctx.stringLit())))
} else {
