diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 92da13df5ff13..6903467cf1a22 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -1254,6 +1254,30 @@
     ],
     "sqlState" : "42614"
   },
+  "DUPLICATE_CONDITION_IN_SCOPE" : {
+    "message" : [
+      "Found duplicate condition <condition> in the scope. Please remove one of them."
+    ],
+    "sqlState" : "42734"
+  },
+  "DUPLICATE_EXCEPTION_HANDLER" : {
+    "message" : [
+      "Found duplicate handlers. Please remove one of them."
+    ],
+    "subClass" : {
+      "CONDITION" : {
+        "message" : [
+          "Found duplicate handlers for the same condition <condition>."
+        ]
+      },
+      "SQLSTATE" : {
+        "message" : [
+          "Found duplicate handlers for the same SQLSTATE <sqlState>."
+        ]
+      }
+    },
+    "sqlState" : "42734"
+  },
   "DUPLICATE_KEY" : {
     "message" : [
       "Found duplicate keys <keyColumn>."
@@ -2440,6 +2464,29 @@
     ],
     "sqlState" : "42K05"
   },
+  "INVALID_ERROR_CONDITION_DECLARATION" : {
+    "message" : [
+      "Invalid condition declaration."
+    ],
+    "subClass" : {
+      "ONLY_AT_BEGINNING" : {
+        "message" : [
+          "Condition <conditionName> can only be declared at the beginning of the compound."
+        ]
+      },
+      "QUALIFIED_CONDITION_NAME" : {
+        "message" : [
+          "Condition <conditionName> cannot be qualified."
+        ]
+      },
+      "SPECIAL_CHARACTER_FOUND" : {
+        "message" : [
+          "Special character found in condition name <conditionName>. Only alphanumeric characters and underscores are allowed."
+        ]
+      }
+    },
+    "sqlState" : "42K0R"
+  },
   "INVALID_ESC" : {
     "message" : [
       "Found an invalid escape string: <invalidEscape>. The escape string must contain only one character."
@@ -2608,6 +2655,39 @@
     },
     "sqlState" : "HY000"
   },
+  "INVALID_HANDLER_DECLARATION" : {
+    "message" : [
+      "Invalid handler declaration."
+    ],
+    "subClass" : {
+      "CONDITION_NOT_FOUND" : {
+        "message" : [
+          "Condition <condition> not found."
+        ]
+      },
+      "DUPLICATE_CONDITION_IN_HANDLER_DECLARATION" : {
+        "message" : [
+          "Found duplicate condition <condition> in the handler declaration. Please remove one of them."
+        ]
+      },
+      "DUPLICATE_SQLSTATE_IN_HANDLER_DECLARATION" : {
+        "message" : [
+          "Found duplicate sqlState <sqlState> in the handler declaration. Please remove one of them."
+        ]
+      },
+      "INVALID_CONDITION_COMBINATION" : {
+        "message" : [
+          "Invalid combination of conditions in the handler declaration. SQLEXCEPTION and NOT FOUND cannot be used together with other condition/sqlstate values."
+        ]
+      },
+      "WRONG_PLACE_OF_DECLARATION" : {
+        "message" : [
+          "Handlers must be declared after variable and condition declarations, and before other statements."
+        ]
+      }
+    },
+    "sqlState" : "42K0Q"
+  },
   "INVALID_IDENTIFIER" : {
     "message" : [
       "The unquoted identifier <ident> is invalid and must be back quoted as: `<ident>`.",
@@ -3264,6 +3344,12 @@
     },
     "sqlState" : "42616"
   },
+  "INVALID_SQLSTATE" : {
+    "message" : [
+      "Invalid SQLSTATE value: '<sqlState>'. SQLSTATE must be exactly 5 characters long and contain only A-Z and 0-9. SQLSTATE must not start with '00', '01', or 'XX'."
+    ],
+    "sqlState" : "428B3"
+  },
   "INVALID_SQL_ARG" : {
     "message" : [
       "The argument <name> of `sql()` is invalid. Consider to replace it either by a SQL literal or by collection constructor functions such as `map()`, `array()`, `struct()`."
@@ -5492,6 +5578,11 @@
         "Attach a comment to the namespace <namespace>."
       ]
     },
+    "CONTINUE_EXCEPTION_HANDLER" : {
+      "message" : [
+        "CONTINUE exception handler is not supported. Use an EXIT handler instead."
+      ]
+    },
     "DESC_TABLE_COLUMN_JSON" : {
       "message" : [
         "DESC TABLE COLUMN AS JSON not supported for individual columns."
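A note for reviewers: the new duplicate-related conditions above are raised while a script is being parsed. A minimal script that would trip `DUPLICATE_EXCEPTION_HANDLER.CONDITION` might look like the following sketch (illustrative only, not part of the patch):

```sql
-- Two handlers in the same scope are declared for the same condition,
-- so the script fails with DUPLICATE_EXCEPTION_HANDLER.CONDITION.
BEGIN
  DECLARE EXIT HANDLER FOR DIVIDE_BY_ZERO SELECT 'first handler';
  DECLARE EXIT HANDLER FOR DIVIDE_BY_ZERO SELECT 'second handler';
  SELECT 1/0;
END
```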
diff --git a/common/utils/src/main/resources/error/error-states.json b/common/utils/src/main/resources/error/error-states.json index fb899e4eb207e..2e6c83560b014 100644 --- a/common/utils/src/main/resources/error/error-states.json +++ b/common/utils/src/main/resources/error/error-states.json @@ -4643,6 +4643,18 @@ "standard": "N", "usedBy": ["Spark"] }, + "42K0Q": { + "description": "Invalid handler declaration.", + "origin": "Spark", + "standard": "N", + "usedBy": ["Spark"] + }, + "42K0R": { + "description": "Invalid condition declaration.", + "origin": "Spark", + "standard": "N", + "usedBy": ["Spark"] + }, "42KD0": { "description": "Ambiguous name reference.", "origin": "Databricks", diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index 7c7e2a6909574..37ec8f4ac8f7c 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -459,8 +459,10 @@ Below is a list of all the keywords in Spark SQL. |COMPENSATION|non-reserved|non-reserved|non-reserved| |COMPUTE|non-reserved|non-reserved|non-reserved| |CONCATENATE|non-reserved|non-reserved|non-reserved| +|CONDITION|non-reserved|non-reserved|non-reserved| |CONSTRAINT|reserved|non-reserved|reserved| |CONTAINS|non-reserved|non-reserved|non-reserved| +|CONTINUE|non-reserved|non-reserved|non-reserved| |COST|non-reserved|non-reserved|non-reserved| |CREATE|reserved|non-reserved|reserved| |CROSS|reserved|strict-non-reserved|reserved| @@ -513,6 +515,7 @@ Below is a list of all the keywords in Spark SQL. |EXCLUDE|non-reserved|non-reserved|non-reserved| |EXECUTE|reserved|non-reserved|reserved| |EXISTS|non-reserved|non-reserved|reserved| +|EXIT|non-reserved|non-reserved|non-reserved| |EXPLAIN|non-reserved|non-reserved|non-reserved| |EXPORT|non-reserved|non-reserved|non-reserved| |EXTEND|non-reserved|non-reserved|non-reserved| @@ -531,6 +534,7 @@ Below is a list of all the keywords in Spark SQL. |FOREIGN|reserved|non-reserved|reserved| |FORMAT|non-reserved|non-reserved|non-reserved| |FORMATTED|non-reserved|non-reserved|non-reserved| +|FOUND|non-reserved|non-reserved|non-reserved| |FROM|reserved|non-reserved|reserved| |FULL|reserved|strict-non-reserved|reserved| |FUNCTION|non-reserved|non-reserved|reserved| @@ -540,6 +544,7 @@ Below is a list of all the keywords in Spark SQL. |GRANT|reserved|non-reserved|reserved| |GROUP|reserved|non-reserved|reserved| |GROUPING|non-reserved|non-reserved|reserved| +|HANDLER|non-reserved|non-reserved|non-reserved| |HAVING|reserved|non-reserved|reserved| |HOUR|non-reserved|non-reserved|non-reserved| |HOURS|non-reserved|non-reserved|non-reserved| @@ -701,6 +706,8 @@ Below is a list of all the keywords in Spark SQL. |SOURCE|non-reserved|non-reserved|non-reserved| |SPECIFIC|non-reserved|non-reserved|reserved| |SQL|reserved|non-reserved|reserved| +|SQLEXCEPTION|non-reserved|non-reserved|non-reserved| +|SQLSTATE|non-reserved|non-reserved|non-reserved| |START|non-reserved|non-reserved|reserved| |STATISTICS|non-reserved|non-reserved|non-reserved| |STORED|non-reserved|non-reserved|non-reserved| @@ -754,6 +761,7 @@ Below is a list of all the keywords in Spark SQL. 
|USE|non-reserved|non-reserved|non-reserved| |USER|reserved|non-reserved|reserved| |USING|reserved|strict-non-reserved|reserved| +|VALUE|non-reserved|non-reserved|non-reserved| |VALUES|non-reserved|non-reserved|reserved| |VARCHAR|non-reserved|non-reserved|reserved| |VAR|non-reserved|non-reserved|non-reserved| diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 index 360854d81e384..b868eea41b692 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 @@ -174,8 +174,10 @@ COMPACTIONS: 'COMPACTIONS'; COMPENSATION: 'COMPENSATION'; COMPUTE: 'COMPUTE'; CONCATENATE: 'CONCATENATE'; +CONDITION: 'CONDITION'; CONSTRAINT: 'CONSTRAINT'; CONTAINS: 'CONTAINS'; +CONTINUE: 'CONTINUE'; COST: 'COST'; CREATE: 'CREATE'; CROSS: 'CROSS'; @@ -227,6 +229,7 @@ EXCEPT: 'EXCEPT'; EXCHANGE: 'EXCHANGE'; EXCLUDE: 'EXCLUDE'; EXISTS: 'EXISTS'; +EXIT: 'EXIT'; EXPLAIN: 'EXPLAIN'; EXPORT: 'EXPORT'; EXTEND: 'EXTEND'; @@ -245,6 +248,7 @@ FOR: 'FOR'; FOREIGN: 'FOREIGN'; FORMAT: 'FORMAT'; FORMATTED: 'FORMATTED'; +FOUND: 'FOUND'; FROM: 'FROM'; FULL: 'FULL'; FUNCTION: 'FUNCTION'; @@ -254,6 +258,7 @@ GLOBAL: 'GLOBAL'; GRANT: 'GRANT'; GROUP: 'GROUP'; GROUPING: 'GROUPING'; +HANDLER: 'HANDLER'; HAVING: 'HAVING'; BINARY_HEX: 'X'; HOUR: 'HOUR'; @@ -415,6 +420,8 @@ SORTED: 'SORTED'; SOURCE: 'SOURCE'; SPECIFIC: 'SPECIFIC'; SQL: 'SQL'; +SQLEXCEPTION: 'SQLEXCEPTION'; +SQLSTATE: 'SQLSTATE'; START: 'START'; STATISTICS: 'STATISTICS'; STORED: 'STORED'; @@ -468,6 +475,7 @@ UPDATE: 'UPDATE'; USE: 'USE'; USER: 'USER'; USING: 'USING'; +VALUE: 'VALUE'; VALUES: 'VALUES'; VARCHAR: 'VARCHAR'; VAR: 'VAR'; diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 9b438a667a3e5..6cc9ece451c0b 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -63,6 +63,8 @@ compoundStatement : statement | setStatementWithOptionalVarKeyword | beginEndCompoundBlock + | declareConditionStatement + | declareHandlerStatement | ifElseStatement | caseStatement | whileStatement @@ -79,6 +81,29 @@ setStatementWithOptionalVarKeyword LEFT_PAREN query RIGHT_PAREN #setVariableWithOptionalKeyword ; +sqlStateValue + : stringLit + ; + +declareConditionStatement + : DECLARE multipartIdentifier CONDITION (FOR SQLSTATE VALUE? sqlStateValue)? + ; + +conditionValue + : SQLSTATE VALUE? sqlStateValue + | SQLEXCEPTION + | NOT FOUND + | multipartIdentifier + ; + +conditionValues + : cvList+=conditionValue (COMMA cvList+=conditionValue)* + ; + +declareHandlerStatement + : DECLARE (CONTINUE | EXIT) HANDLER FOR conditionValues (beginEndCompoundBlock | statement | setStatementWithOptionalVarKeyword) + ; + whileStatement : beginLabel? WHILE booleanExpression DO compoundBody END WHILE endLabel? 
; @@ -1607,7 +1632,9 @@ ansiNonReserved | COMPENSATION | COMPUTE | CONCATENATE + | CONDITION | CONTAINS + | CONTINUE | COST | CUBE | CURRENT @@ -1648,6 +1675,7 @@ ansiNonReserved | EXCHANGE | EXCLUDE | EXISTS + | EXIT | EXPLAIN | EXPORT | EXTEND @@ -1661,11 +1689,13 @@ ansiNonReserved | FOLLOWING | FORMAT | FORMATTED + | FOUND | FUNCTION | FUNCTIONS | GENERATED | GLOBAL | GROUPING + | HANDLER | HOUR | HOURS | IDENTIFIER_KW @@ -1798,6 +1828,8 @@ ansiNonReserved | SORTED | SOURCE | SPECIFIC + | SQLEXCEPTION + | SQLSTATE | START | STATISTICS | STORED @@ -1840,6 +1872,7 @@ ansiNonReserved | UNTIL | UPDATE | USE + | VALUE | VALUES | VARCHAR | VAR @@ -1945,8 +1978,10 @@ nonReserved | COMPENSATION | COMPUTE | CONCATENATE + | CONDITION | CONSTRAINT | CONTAINS + | CONTINUE | COST | CREATE | CUBE @@ -1997,6 +2032,7 @@ nonReserved | EXCLUDE | EXECUTE | EXISTS + | EXIT | EXPLAIN | EXPORT | EXTEND @@ -2016,6 +2052,7 @@ nonReserved | FORMAT | FORMATTED | FROM + | FOUND | FUNCTION | FUNCTIONS | GENERATED @@ -2023,6 +2060,7 @@ nonReserved | GRANT | GROUP | GROUPING + | HANDLER | HAVING | HOUR | HOURS @@ -2174,6 +2212,8 @@ nonReserved | SOURCE | SPECIFIC | SQL + | SQLEXCEPTION + | SQLSTATE | START | STATISTICS | STORED @@ -2224,6 +2264,7 @@ nonReserved | UPDATE | USE | USER + | VALUE | VALUES | VARCHAR | VAR diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 9eb2b4dcb5d9f..05b44f8643698 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.parser import java.util.Locale import java.util.concurrent.TimeUnit -import scala.collection.mutable.{ArrayBuffer, ListBuffer, Set} +import scala.collection.mutable.{ArrayBuffer, HashMap, ListBuffer, Set} import scala.jdk.CollectionConverters._ import scala.util.{Left, Right} @@ -28,7 +28,7 @@ import org.antlr.v4.runtime.{ParserRuleContext, RuleContext, Token} import org.antlr.v4.runtime.misc.Interval import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode} -import org.apache.spark.{SparkArithmeticException, SparkException, SparkIllegalArgumentException, SparkThrowable} +import org.apache.spark.{SparkArithmeticException, SparkException, SparkIllegalArgumentException, SparkThrowable, SparkThrowableHelper} import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.PARTITION_SPECIFICATION import org.apache.spark.sql.catalyst.{EvaluateUnresolvedInlineTable, FunctionIdentifier, SQLConfHelper, TableIdentifier} @@ -159,6 +159,135 @@ class AstBuilder extends DataTypeAstBuilder script } + private def assertSqlState(sqlState: String): Unit = { + val sqlStateRegex = "^[A-Za-z0-9]{5}$".r + if (sqlStateRegex.findFirstIn(sqlState.toUpperCase(Locale.ROOT)).isEmpty + || sqlState.startsWith("00") + || sqlState.startsWith("01") + || sqlState.startsWith("XX")) { + throw SqlScriptingErrors.invalidSqlStateValue(CurrentOrigin.get, sqlState) + } + } + + private def visitConditionValueImpl( + ctx: ConditionValueContext, + handlerTriggers: ExceptionHandlerTriggers): Unit = { + // Current element is SQLSTATE. 
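+    // For example, `SQLSTATE '22012'` and `SQLSTATE VALUE '22012'` are both
+    // accepted here, while `SQLSTATE '01000'` fails assertSqlState above,
+    // because the '00', '01' and 'XX' classes are reserved.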
+ Option(ctx.sqlStateValue()) + .foreach { sqlStateValueContext => + val sqlState = string(visitStringLit(sqlStateValueContext.stringLit())) + assertSqlState(sqlState) + handlerTriggers.addUniqueSqlState(sqlState) + } + + // Current element is condition. + Option(ctx.multipartIdentifier()) + .foreach { conditionContext => + val conditionNameParts = visitMultipartIdentifier(conditionContext) + val conditionNameString = conditionNameParts.mkString(".").toUpperCase(Locale.ROOT) + if (conditionNameParts.size > 1) { + if (!SparkThrowableHelper.isValidErrorClass(conditionNameString)) { + throw SqlScriptingErrors + .conditionNotFound(CurrentOrigin.get, conditionNameString) + } + } + handlerTriggers.addUniqueCondition(conditionNameString) + } + + Option(ctx.SQLEXCEPTION()) + .foreach { _ => + handlerTriggers.addUniqueSqlException() + } + + // It is sufficient to check only NOT for NOT FOUND handler. + Option(ctx.NOT()).foreach { _ => + handlerTriggers.addUniqueNotFound() + } + } + + /** + * Visit list of condition/sqlstate values in handler declaration. + */ + private def visitConditionValuesImpl( + ctx: ConditionValuesContext): ExceptionHandlerTriggers = { + + val handlerTriggers: ExceptionHandlerTriggers = new ExceptionHandlerTriggers() + + ctx.cvList.forEach { cvContext => + visitConditionValueImpl(cvContext, handlerTriggers) + } + + if (handlerTriggers.sqlException || handlerTriggers.notFound) { + if (handlerTriggers.conditions.nonEmpty || handlerTriggers.sqlStates.nonEmpty) { + throw SqlScriptingErrors + .sqlExceptionOrNotFoundCannotBeCombinedWithOtherConditions(CurrentOrigin.get) + } + + } + + handlerTriggers + } + + private def assertConditionName(condition: String): Unit = { + val conditionRegex = "^[A-Za-z0-9_]+$".r + if (conditionRegex.findFirstIn(condition).isEmpty) { + throw SqlScriptingErrors + .conditionDeclarationContainsSpecialCharacter(CurrentOrigin.get, condition) + } + } + + private def visitDeclareConditionStatementImpl( + ctx: DeclareConditionStatementContext): ErrorCondition = { + + // Qualified user defined condition name is not allowed. + if (ctx.multipartIdentifier().parts.size() > 1) { + throw SqlScriptingErrors + .conditionCannotBeQualified(CurrentOrigin.get, ctx.multipartIdentifier().getText) + } + + // If SQLSTATE is not provided, default to 45000. + val sqlState = Option(ctx.sqlStateValue()) + .map(sqlStateValueContext => string(visitStringLit(sqlStateValueContext.stringLit()))) + .getOrElse("45000") + + assertSqlState(sqlState) + + // Get condition name. + val conditionName = visitMultipartIdentifier(ctx.multipartIdentifier()).head + + assertConditionName(conditionName) + + // Convert everything to upper case. + ErrorCondition(conditionName.toUpperCase(Locale.ROOT), sqlState.toUpperCase(Locale.ROOT)) + } + + private def visitDeclareHandlerStatementImpl( + ctx: DeclareHandlerStatementContext, + labelCtx: SqlScriptingLabelContext): ExceptionHandler = { + val exceptionHandlerTriggers = visitConditionValuesImpl(ctx.conditionValues()) + + if (Option(ctx.CONTINUE()).isDefined) { + throw SqlScriptingErrors.continueHandlerNotSupported(CurrentOrigin.get) + } + + val handlerType = ExceptionHandlerType.EXIT + val body = if (Option(ctx.beginEndCompoundBlock()).isDefined) { + visitBeginEndCompoundBlockImpl( + ctx.beginEndCompoundBlock(), + labelCtx) + } else { + // If there is no compound body, then there must be a statement or set statement. 
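+      // For example, `DECLARE EXIT HANDLER FOR DIVIDE_BY_ZERO SET VAR flag = 1;`
+      // lands here, while a `BEGIN ... END` handler body takes the branch above.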
+ val statement = Option(ctx.statement().asInstanceOf[ParserRuleContext]) + .orElse(Option(ctx.setStatementWithOptionalVarKeyword().asInstanceOf[ParserRuleContext])) + .map { s => + SingleStatement(parsedPlan = visit(s).asInstanceOf[LogicalPlan]) + } + CompoundBody(Seq(statement.get), None, isScope = false) + } + + ExceptionHandler(exceptionHandlerTriggers, body, handlerType) + } + private def visitCompoundBodyImpl( ctx: CompoundBodyContext, label: Option[String], @@ -166,37 +295,51 @@ class AstBuilder extends DataTypeAstBuilder labelCtx: SqlScriptingLabelContext, isScope: Boolean): CompoundBody = { val buff = ListBuffer[CompoundPlanStatement]() - ctx.compoundStatements.forEach( - compoundStatement => buff += visitCompoundStatementImpl(compoundStatement, labelCtx)) - val compoundStatements = buff.toList + val handlers = ListBuffer[ExceptionHandler]() + val conditions = HashMap[String, String]() + + val scriptingParserContext = new SqlScriptingParsingContext() + + ctx.compoundStatements.forEach(compoundStatement => { + val stmt = visitCompoundStatementImpl(compoundStatement, labelCtx) + stmt match { + case handler: ExceptionHandler => + scriptingParserContext.handler() + // All conditions are already visited when we encounter a handler. + handler.exceptionHandlerTriggers.conditions.foreach(conditionName => { + // Everything is stored in upper case so we can make case-insensitive comparisons. + // If condition is not spark-defined error condition, check if user defined it. + if (!SparkThrowableHelper.isValidErrorClass(conditionName)) { + if (!conditions.contains(conditionName)) { + throw SqlScriptingErrors + .conditionNotFound(CurrentOrigin.get, conditionName) + } + } + }) - val candidates = if (allowVarDeclare) { - compoundStatements.dropWhile { - case SingleStatement(_: CreateVariable) => true - case _ => false + handlers += handler + case condition: ErrorCondition => + scriptingParserContext.condition(condition) + // Check for duplicate condition names in each scope. + // When conditions are visited, everything is converted to upper-case + // for case-insensitive comparisons. 
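+          // For example, declaring both `my_cond` and `MY_COND` in the same scope
+          // counts as a duplicate, since both normalize to "MY_COND".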
+ if (conditions.contains(condition.conditionName)) { + throw SqlScriptingErrors + .duplicateConditionInScope(CurrentOrigin.get, condition.conditionName) + } + conditions += condition.conditionName -> condition.sqlState + case statement => + statement match { + case SingleStatement(createVariable: CreateVariable) => + scriptingParserContext.variable(createVariable, allowVarDeclare) + case _ => scriptingParserContext.statement() + } + buff += statement } - } else { - compoundStatements - } - - val declareVarStatement = candidates.collectFirst { - case SingleStatement(c: CreateVariable) => c - } - - declareVarStatement match { - case Some(c: CreateVariable) => - if (allowVarDeclare) { - throw SqlScriptingErrors.variableDeclarationOnlyAtBeginning( - c.origin, c.name.asInstanceOf[UnresolvedIdentifier].nameParts) - } else { - throw SqlScriptingErrors.variableDeclarationNotAllowedInScope( - c.origin, c.name.asInstanceOf[UnresolvedIdentifier].nameParts) - } - case _ => - } + }) - CompoundBody(buff.toSeq, label, isScope) + CompoundBody(buff.toSeq, label, isScope, handlers.toSeq, conditions) } private def visitBeginEndCompoundBlockImpl( @@ -243,6 +386,10 @@ class AstBuilder extends DataTypeAstBuilder visitSimpleCaseStatementImpl(simpleCaseContext, labelCtx) case forStatementContext: ForStatementContext => visitForStatementImpl(forStatementContext, labelCtx) + case declareHandlerContext: DeclareHandlerStatementContext => + visitDeclareHandlerStatementImpl(declareHandlerContext, labelCtx) + case declareConditionContext: DeclareConditionStatementContext => + visitDeclareConditionStatementImpl(declareConditionContext) case stmt => visit(stmt).asInstanceOf[CompoundPlanStatement] } } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala index 1bc4f95f95daf..235f2ae70c0d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala @@ -25,7 +25,10 @@ import org.antlr.v4.runtime.{ParserRuleContext, Token} import org.antlr.v4.runtime.misc.Interval import org.antlr.v4.runtime.tree.{ParseTree, TerminalNodeImpl} +import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier import org.apache.spark.sql.catalyst.parser.SqlBaseParser.{BeginLabelContext, EndLabelContext} +import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, ErrorCondition} import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.catalyst.util.SparkParserUtils import org.apache.spark.sql.catalyst.util.SparkParserUtils.withOrigin @@ -140,6 +143,122 @@ object ParserUtils extends SparkParserUtils { } } +class SqlScriptingParsingContext { + + object State extends Enumeration { + type State = Value + val INIT, VARIABLE, CONDITION, HANDLER, STATEMENT = Value + } + + private var currentState: State.State = State.INIT + + /** Transition to VARIABLE state. */ + def variable(createVariable: CreateVariable, allowVarDeclare: Boolean): Unit = { + if (!allowVarDeclare) { + throw SqlScriptingErrors.variableDeclarationNotAllowedInScope( + createVariable.origin, createVariable.name.asInstanceOf[UnresolvedIdentifier].nameParts) + } + transitionTo(State.VARIABLE, createVariable = Some(createVariable), None) + } + + /** Transition to CONDITION state. 
*/
+  def condition(errorCondition: ErrorCondition): Unit = {
+    transitionTo(State.CONDITION, errorCondition = Some(errorCondition))
+  }
+
+  /** Transition to HANDLER state. */
+  def handler(): Unit = {
+    transitionTo(State.HANDLER)
+  }
+
+  /** Transition to STATEMENT state. */
+  def statement(): Unit = {
+    transitionTo(State.STATEMENT)
+  }
+
+  /**
+   * Helper method to transition to a new state.
+   * Statements must appear in a compound in this order:
+   *   1. VARIABLE and CONDITION declarations (freely interleaved),
+   *   2. HANDLER declarations,
+   *   3. all other STATEMENTs.
+   * A transition from a state with order number n to a state with order number m
+   * is allowed only if m >= n (VARIABLE and CONDITION share the same number).
+   *
+   * @param newState The new state to transition to.
+   * @param createVariable The variable declaration; set when transitioning to VARIABLE.
+   * @param errorCondition The condition declaration; set when transitioning to CONDITION.
+   */
+  private def transitionTo(
+      newState: State.State,
+      createVariable: Option[CreateVariable] = None,
+      errorCondition: Option[ErrorCondition] = None): Unit = {
+    (currentState, newState) match {
+      // VALID TRANSITIONS
+
+      case (State.INIT, _) => currentState = newState
+
+      // Transitions from VARIABLE to other states.
+      case (State.VARIABLE, State.VARIABLE) => // do nothing
+
+      case (State.VARIABLE, State.CONDITION) => currentState = State.CONDITION
+
+      case (State.VARIABLE, State.HANDLER) => currentState = State.HANDLER
+
+      case (State.VARIABLE, State.STATEMENT) => currentState = State.STATEMENT
+
+      // Transitions from CONDITION to other states.
+      case (State.CONDITION, State.CONDITION) => // do nothing
+
+      case (State.CONDITION, State.VARIABLE) => currentState = State.VARIABLE
+
+      case (State.CONDITION, State.HANDLER) => currentState = State.HANDLER
+
+      case (State.CONDITION, State.STATEMENT) => currentState = State.STATEMENT
+
+      // Transitions from HANDLER to other states.
+      case (State.HANDLER, State.HANDLER) => // do nothing
+
+      case (State.HANDLER, State.STATEMENT) => currentState = State.STATEMENT
+
+      // Transition from STATEMENT to itself.
+      case (State.STATEMENT, State.STATEMENT) => // do nothing
+
+      // INVALID TRANSITIONS
+
+      // Invalid transitions to VARIABLE state.
+      case (State.STATEMENT, State.VARIABLE) =>
+        throw SqlScriptingErrors.variableDeclarationOnlyAtBeginning(
+          createVariable.get.origin,
+          createVariable.get.name.asInstanceOf[UnresolvedIdentifier].nameParts)
+
+      case (State.HANDLER, State.VARIABLE) =>
+        throw SqlScriptingErrors.variableDeclarationOnlyAtBeginning(
+          createVariable.get.origin,
+          createVariable.get.name.asInstanceOf[UnresolvedIdentifier].nameParts)
+
+      // Invalid transitions to CONDITION state.
+      case (State.STATEMENT, State.CONDITION) =>
+        throw SqlScriptingErrors.conditionDeclarationOnlyAtBeginning(
+          CurrentOrigin.get,
+          errorCondition.get.conditionName)
+
+      case (State.HANDLER, State.CONDITION) =>
+        throw SqlScriptingErrors.conditionDeclarationOnlyAtBeginning(
+          CurrentOrigin.get,
+          errorCondition.get.conditionName)
+
+      // Invalid transitions to HANDLER state.
+      case (State.STATEMENT, State.HANDLER) =>
+        throw SqlScriptingErrors.handlerDeclarationInWrongPlace(CurrentOrigin.get)
+
+      // This should never happen.
+      case _ =>
+        throw SparkException.internalError(
+          s"Invalid state transition from $currentState to $newState")
+    }
+  }
+}
+
 class SqlScriptingLabelContext {
   /** Set to keep track of labels seen so far */
   private val seenLabels = Set[String]()
@@ -148,12 +267,17 @@
    * Check if the beginLabelCtx and endLabelCtx match.
    * If the labels are defined, they must follow rules:
    * - If both labels exist, they must match.
+   * - If a label is qualified, it is invalid.
    * - Begin label must exist if end label exists.
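+   * - For example, `lbl: BEGIN ... END lbl` is valid, while `a: BEGIN ... END b`
+   *   and `a.b: BEGIN ... END a.b` both raise errors.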
+ * + * @param beginLabelCtx Begin label context. + * @param endLabelCtx The end label context. */ private def checkLabels( beginLabelCtx: Option[BeginLabelContext], endLabelCtx: Option[EndLabelContext]) : Unit = { (beginLabelCtx, endLabelCtx) match { + // Throw an error if labels do not match. case (Some(bl: BeginLabelContext), Some(el: EndLabelContext)) if bl.multipartIdentifier().getText.toLowerCase(Locale.ROOT) != el.multipartIdentifier().getText.toLowerCase(Locale.ROOT) => @@ -163,6 +287,7 @@ class SqlScriptingLabelContext { bl.multipartIdentifier().getText, el.multipartIdentifier().getText) } + // Throw an error if label is qualified. case (Some(bl: BeginLabelContext), _) if bl.multipartIdentifier().parts.size() > 1 => withOrigin(bl) { @@ -171,6 +296,7 @@ class SqlScriptingLabelContext { bl.multipartIdentifier().getText.toLowerCase(Locale.ROOT) ) } + // Throw an error if end label exists without begin label. case (None, Some(el: EndLabelContext)) => withOrigin(el) { throw SqlScriptingErrors.endLabelWithoutBeginLabel( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala index b3bd86149ba91..bbbdd3b09a3c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala @@ -17,9 +17,14 @@ package org.apache.spark.sql.catalyst.plans.logical +import java.util.Locale + +import scala.collection.mutable + import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.ExceptionHandlerType.ExceptionHandlerType import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} - +import org.apache.spark.sql.errors.SqlScriptingErrors /** * Trait for all SQL Scripting logical operators that are product of parsing phase. @@ -64,11 +69,16 @@ case class SingleStatement(parsedPlan: LogicalPlan) * for example when CompoundBody is inside loop or conditional block. * @param isScope Flag indicating if the CompoundBody is a labeled scope. * Scopes are used for grouping local variables and exception handlers. + * @param handlers Collection of error handlers that are defined within the compound body. + * @param conditions Collection of error conditions that are defined within the compound body. */ case class CompoundBody( collection: Seq[CompoundPlanStatement], label: Option[String], - isScope: Boolean) extends Command with CompoundPlanStatement { + isScope: Boolean, + handlers: Seq[ExceptionHandler] = Seq.empty, + conditions: mutable.Map[String, String] = mutable.HashMap()) + extends Command with CompoundPlanStatement { override def children: Seq[LogicalPlan] = collection @@ -298,3 +308,100 @@ case class ForStatement( ForStatement(query, variableName, body, label) } } + +/** + * Logical operator for an error condition. + * @param conditionName Name of the error condition. + * @param sqlState SQLSTATE or Error Code. 
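+ * For example, `DECLARE my_cond CONDITION FOR SQLSTATE '12001'` is parsed into
+ * `ErrorCondition("MY_COND", "12001")`; both parts are stored upper-cased.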
+ */ +case class ErrorCondition( + conditionName: String, + sqlState: String) extends CompoundPlanStatement { + override def output: Seq[Attribute] = Seq.empty + + override def children: Seq[LogicalPlan] = Seq.empty + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = this.copy() +} + +object ExceptionHandlerType extends Enumeration { + type ExceptionHandlerType = Value + val EXIT, CONTINUE = Value +} + +/** + * Class holding information about what triggers the handler. + * @param sqlStates Set of sqlStates that will trigger handler. + * @param conditions Set of error condition names that will trigger handler. + * @param sqlException Flag indicating if the handler is triggered by SQLEXCEPTION. + * @param notFound Flag indicating if the handler is triggered by NOT FOUND. + */ +class ExceptionHandlerTriggers( + val sqlStates: mutable.Set[String] = mutable.Set.empty, + val conditions: mutable.Set[String] = mutable.Set.empty, + var sqlException: Boolean = false, + var notFound: Boolean = false) { + + def addUniqueSqlException(): Unit = { + if (sqlException) { + throw SqlScriptingErrors + .duplicateConditionInHandlerDeclaration(CurrentOrigin.get, "SQLEXCEPTION") + } + sqlException = true + } + + def addUniqueNotFound(): Unit = { + if (notFound) { + throw SqlScriptingErrors + .duplicateConditionInHandlerDeclaration(CurrentOrigin.get, "NOT FOUND") + } + notFound = true + } + + def addUniqueCondition(value: String): Unit = { + val uppercaseValue = value.toUpperCase(Locale.ROOT) + if (conditions.contains(uppercaseValue)) { + throw SqlScriptingErrors + .duplicateConditionInHandlerDeclaration(CurrentOrigin.get, uppercaseValue) + } + conditions += uppercaseValue + } + + def addUniqueSqlState(value: String): Unit = { + val uppercaseValue = value.toUpperCase(Locale.ROOT) + if (sqlStates.contains(uppercaseValue)) { + throw SqlScriptingErrors + .duplicateSqlStateInHandlerDeclaration(CurrentOrigin.get, uppercaseValue) + } + sqlStates += uppercaseValue + } +} + +/** + * Logical operator for an error handler. + * @param exceptionHandlerTriggers Collection of different handler triggers: + * sqlStates -> set of sqlStates that will trigger handler + * conditions -> set of conditions that will trigger handler + * sqlException -> if handler is triggered by SQLEXCEPTION + * notFound -> if handler is triggered by NotFound + * @param body CompoundBody of the handler. + * @param handlerType Type of the handler (CONTINUE or EXIT). 
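+ * Note: only EXIT handlers are currently supported; CONTINUE declarations are
+ * rejected at parse time with UNSUPPORTED_FEATURE.CONTINUE_EXCEPTION_HANDLER.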
+ */ +case class ExceptionHandler( + exceptionHandlerTriggers: ExceptionHandlerTriggers, + body: CompoundBody, + handlerType: ExceptionHandlerType) extends CompoundPlanStatement { + override def output: Seq[Attribute] = Seq.empty + + override def children: Seq[LogicalPlan] = Seq(body) + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = { + assert(newChildren.length == 1) + ExceptionHandler( + exceptionHandlerTriggers, + newChildren(0).asInstanceOf[CompoundBody], + handlerType) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala index da492cce22f2c..993efa1c8f6bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala @@ -143,4 +143,115 @@ private[sql] object SqlScriptingErrors { cause = null, messageParameters = Map("labelName" -> toSQLStmt(labelName))) } + + def conditionCannotBeQualified( + origin: Origin, + conditionName: String): Throwable = { + new SqlScriptingException( + origin = origin, + errorClass = "INVALID_ERROR_CONDITION_DECLARATION.QUALIFIED_CONDITION_NAME", + cause = null, + messageParameters = Map("conditionName" -> toSQLStmt(conditionName))) + } + + def conditionDeclarationOnlyAtBeginning( + origin: Origin, + conditionName: String): Throwable = { + new SqlScriptingException( + origin = origin, + errorClass = "INVALID_ERROR_CONDITION_DECLARATION.ONLY_AT_BEGINNING", + cause = null, + messageParameters = Map("conditionName" -> toSQLId(conditionName))) + } + + def conditionDeclarationContainsSpecialCharacter( + origin: Origin, + conditionName: String): Throwable = { + new SqlScriptingException( + origin = origin, + errorClass = "INVALID_ERROR_CONDITION_DECLARATION.SPECIAL_CHARACTER_FOUND", + cause = null, + messageParameters = Map("conditionName" -> toSQLId(conditionName))) + } + + def duplicateConditionInScope(origin: Origin, condition: String): Throwable = { + new SqlScriptingException( + origin = origin, + errorClass = "DUPLICATE_CONDITION_IN_SCOPE", + cause = null, + messageParameters = Map("condition" -> condition)) + } + + def handlerDeclarationInWrongPlace(origin: Origin): Throwable = { + new SqlScriptingException( + origin = origin, + errorClass = "INVALID_HANDLER_DECLARATION.WRONG_PLACE_OF_DECLARATION", + cause = null, + messageParameters = Map.empty) + } + + def duplicateConditionInHandlerDeclaration(origin: Origin, condition: String): Throwable = { + new SqlScriptingException( + origin = origin, + errorClass = "INVALID_HANDLER_DECLARATION.DUPLICATE_CONDITION_IN_HANDLER_DECLARATION", + cause = null, + messageParameters = Map("condition" -> condition)) + } + + def duplicateSqlStateInHandlerDeclaration(origin: Origin, sqlState: String): Throwable = { + new SqlScriptingException( + origin = origin, + errorClass = "INVALID_HANDLER_DECLARATION.DUPLICATE_SQLSTATE_IN_HANDLER_DECLARATION", + cause = null, + messageParameters = Map("sqlState" -> sqlState)) + } + + def duplicateHandlerForSameCondition(origin: Origin, condition: String): Throwable = { + new SqlScriptingException( + origin = origin, + errorClass = "DUPLICATE_EXCEPTION_HANDLER.CONDITION", + cause = null, + messageParameters = Map("condition" -> condition)) + } + + def duplicateHandlerForSameSqlState(origin: Origin, sqlState: String): Throwable = { + new SqlScriptingException( + origin = origin, + errorClass 
= "DUPLICATE_EXCEPTION_HANDLER.SQLSTATE", + cause = null, + messageParameters = Map("sqlState" -> sqlState)) + } + + + def continueHandlerNotSupported(origin: Origin): Throwable = { + new SqlScriptingException( + origin = origin, + errorClass = "UNSUPPORTED_FEATURE.CONTINUE_EXCEPTION_HANDLER", + cause = null, + messageParameters = Map.empty) + } + + def invalidSqlStateValue(origin: Origin, sqlState: String): Throwable = { + new SqlScriptingException( + origin = origin, + errorClass = "INVALID_SQLSTATE", + cause = null, + messageParameters = Map("sqlState" -> sqlState)) + } + + def sqlExceptionOrNotFoundCannotBeCombinedWithOtherConditions(origin: Origin): Throwable = { + new SqlScriptingException( + origin = origin, + errorClass = "INVALID_HANDLER_DECLARATION.INVALID_CONDITION_COMBINATION", + cause = null, + messageParameters = Map.empty) + } + + def conditionNotFound(origin: Origin, condition: String): Throwable = { + new SqlScriptingException( + origin = origin, + errorClass = "INVALID_HANDLER_DECLARATION.CONDITION_NOT_FOUND", + cause = null, + messageParameters = Map("condition" -> condition)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index 40ba7809e5cee..c3d114836b67c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Expression, In, Literal, ScalarSubquery} import org.apache.spark.sql.catalyst.plans.SQLHelper -import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CreateVariable, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LoopStatement, Project, RepeatStatement, SingleStatement, WhileStatement} +import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CreateVariable, ExceptionHandler, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LoopStatement, Project, RepeatStatement, SetVariable, SingleStatement, WhileStatement} import org.apache.spark.sql.errors.DataTypeErrors.toSQLId import org.apache.spark.sql.exceptions.SqlScriptingException import org.apache.spark.sql.internal.SQLConf @@ -2096,6 +2096,21 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { parameters = Map("labelName" -> "PART1.PART2")) } + test("qualified label name: label cannot be qualified + end label") { + val sqlScriptText = + """ + |BEGIN + | part1.part2: BEGIN + | END part1.part2; + |END""".stripMargin + checkError( + exception = intercept[SqlScriptingException] { + parsePlan(sqlScriptText) + }, + condition = "INVALID_LABEL_USAGE.QUALIFIED_LABEL_NAME", + parameters = Map("labelName" -> "PART1.PART2")) + } + test("unique label names: nested labeled scope statements") { val sqlScriptText = """BEGIN @@ -2398,6 +2413,363 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { head.asInstanceOf[SingleStatement].getText == "SELECT 3") } + test("declare condition: custom sqlstate") { + val sqlScriptText = + """ + |BEGIN + | DECLARE test CONDITION FOR SQLSTATE '12000'; + | SELECT 1; + |END""".stripMargin + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + assert(tree.conditions.size == 1) + 
assert(tree.conditions("TEST").equals("12000")) + } + + ignore("declare condition: default sqlstate") { + val sqlScriptText = + """ + |BEGIN + | DECLARE test CONDITION; + |END""".stripMargin + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + assert(tree.conditions.size == 1) + assert(tree.conditions("TEST").equals("45000")) // Default SQLSTATE + } + + test("declare condition in wrong place") { + val sqlScriptText = + """ + |BEGIN + | SELECT 1; + | DECLARE test_condition CONDITION FOR SQLSTATE '12345'; + |END""".stripMargin + val exception = intercept[SqlScriptingException] { + parsePlan(sqlScriptText) + } + checkError( + exception = exception, + condition = "INVALID_ERROR_CONDITION_DECLARATION.ONLY_AT_BEGINNING", + parameters = Map("conditionName" -> "`TEST_CONDITION`")) + assert(exception.origin.line.contains(2)) + } + + test("declare qualified condition") { + val sqlScriptText = + """ + |BEGIN + | DECLARE TEST.CONDITION CONDITION FOR SQLSTATE '12345'; + |END""".stripMargin + val exception = intercept[SqlScriptingException] { + parsePlan(sqlScriptText) + } + checkError( + exception = exception, + condition = "INVALID_ERROR_CONDITION_DECLARATION.QUALIFIED_CONDITION_NAME", + parameters = Map("conditionName" -> "TEST.CONDITION")) + assert(exception.origin.line.contains(3)) + } + + test("declare condition with special characters") { + val sqlScriptText = + """ + |BEGIN + | DECLARE `test-condition` CONDITION FOR SQLSTATE '12345'; + |END""".stripMargin + val exception = intercept[SqlScriptingException] { + parsePlan(sqlScriptText) + } + checkError( + exception = exception, + condition = "INVALID_ERROR_CONDITION_DECLARATION.SPECIAL_CHARACTER_FOUND", + parameters = Map("conditionName" -> toSQLId("test-condition"))) + assert(exception.origin.line.contains(3)) + } + + test("continue handler not supported") { + val sqlScript = + """ + |BEGIN + | DECLARE OR REPLACE flag INT = -1; + | DECLARE CONTINUE HANDLER FOR SQLSTATE '22012' + | BEGIN + | SET VAR flag = 1; + | END; + | SELECT 1/0; + | SELECT flag; + |END + |""".stripMargin + checkError( + exception = intercept[SqlScriptingException] { + parsePlan(sqlScript) + }, + condition = "UNSUPPORTED_FEATURE.CONTINUE_EXCEPTION_HANDLER", + parameters = Map.empty) + } + + test("declare handler for qualified condition name that is not supported") { + val sqlScript = + """ + |BEGIN + | DECLARE OR REPLACE flag INT = -1; + | DECLARE EXIT HANDLER FOR qualified.condition.name + | BEGIN + | SET VAR flag = 1; + | END; + | SELECT 1/0; + | SELECT flag; + |END + |""".stripMargin + checkError( + exception = intercept[SqlScriptingException] { + parsePlan(sqlScript) + }, + condition = "INVALID_HANDLER_DECLARATION.CONDITION_NOT_FOUND", + parameters = Map("condition" -> "QUALIFIED.CONDITION.NAME")) + } + + test("declare handler for undefined condition") { + val sqlScriptText = + """ + |BEGIN + | DECLARE EXIT HANDLER FOR undefined_condition BEGIN SELECT 1; END; + | SELECT 1; + |END""".stripMargin + val exception = intercept[SqlScriptingException] { + parsePlan(sqlScriptText) + } + checkError( + exception = exception, + condition = "INVALID_HANDLER_DECLARATION.CONDITION_NOT_FOUND", + parameters = Map("condition" -> "UNDEFINED_CONDITION")) + assert(exception.origin.line.contains(2)) + } + + test("declare handler in wrong place") { + val sqlScriptText = + """ + |BEGIN + | SELECT 1; + | DECLARE EXIT HANDLER FOR DIVIDE_BY_ZERO BEGIN SELECT 1; END; + |END""".stripMargin + val exception = intercept[SqlScriptingException] { + parsePlan(sqlScriptText) + } + checkError( 
+      exception = exception,
+      condition = "INVALID_HANDLER_DECLARATION.WRONG_PLACE_OF_DECLARATION",
+      parameters = Map.empty)
+    assert(exception.origin.line.contains(2))
+  }
+
+  test("duplicate condition in handler declaration") {
+    val sqlScript =
+      """
+        |BEGIN
+        |  DECLARE OR REPLACE flag INT = -1;
+        |  DECLARE DUPLICATE_CONDITION CONDITION FOR SQLSTATE '12345';
+        |  DECLARE EXIT HANDLER FOR duplicate_condition, duplicate_condition
+        |  BEGIN
+        |    SET VAR flag = 1;
+        |  END;
+        |  SELECT 1/0;
+        |  SELECT flag;
+        |END
+        |""".stripMargin
+    checkError(
+      exception = intercept[SqlScriptingException] {
+        parsePlan(sqlScript)
+      },
+      condition = "INVALID_HANDLER_DECLARATION.DUPLICATE_CONDITION_IN_HANDLER_DECLARATION",
+      parameters = Map("condition" -> "DUPLICATE_CONDITION"))
+  }
+
+  test("duplicate sqlState in handler declaration") {
+    val sqlScript =
+      """
+        |BEGIN
+        |  DECLARE OR REPLACE flag INT = -1;
+        |  DECLARE EXIT HANDLER FOR SQLSTATE '12345', SQLSTATE '12345'
+        |  BEGIN
+        |    SET VAR flag = 1;
+        |  END;
+        |  SELECT 1/0;
+        |  SELECT flag;
+        |END
+        |""".stripMargin
+    checkError(
+      exception = intercept[SqlScriptingException] {
+        parsePlan(sqlScript)
+      },
+      condition = "INVALID_HANDLER_DECLARATION.DUPLICATE_SQLSTATE_IN_HANDLER_DECLARATION",
+      parameters = Map("sqlState" -> "12345"))
+  }
+
+  test("invalid condition combination in handler declaration") {
+    val sqlScript =
+      """
+        |BEGIN
+        |  DECLARE OR REPLACE flag INT = -1;
+        |  DECLARE EXIT HANDLER FOR SQLEXCEPTION, SQLSTATE '12345'
+        |  BEGIN
+        |    SET VAR flag = 1;
+        |  END;
+        |  SELECT 1/0;
+        |  SELECT flag;
+        |END
+        |""".stripMargin
+    checkError(
+      exception = intercept[SqlScriptingException] {
+        parsePlan(sqlScript)
+      },
+      condition = "INVALID_HANDLER_DECLARATION.INVALID_CONDITION_COMBINATION",
+      parameters = Map.empty)
+  }
+
+  test("declare handler with compound body") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        |  DECLARE EXIT HANDLER FOR DIVIDE_BY_ZERO BEGIN SELECT 1; END;
+        |END""".stripMargin
+    val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
+    assert(tree.handlers.length == 1)
+    assert(tree.handlers.head.isInstanceOf[ExceptionHandler])
+    assert(tree.handlers.head.exceptionHandlerTriggers.conditions.size == 1)
+    assert(tree.handlers.head.exceptionHandlerTriggers.conditions.contains("DIVIDE_BY_ZERO"))
+    assert(tree.handlers.head.body.collection.size == 1)
+    assert(tree.handlers.head.body.collection.head.isInstanceOf[SingleStatement])
+    assert(tree.handlers.head.body.collection.head.asInstanceOf[SingleStatement]
+      .parsedPlan.isInstanceOf[Project])
+  }
+
+  // This test works because END is not a keyword here, but part of the statement:
+  // it becomes the name of the column in the returned DataFrame.
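+  // That is, `SELECT 1 END` is parsed as `SELECT 1 AS END`.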
+ test("declare handler single statement with END") { + val sqlScriptText = + """ + |BEGIN + | DECLARE EXIT HANDLER FOR DIVIDE_BY_ZERO SELECT 1 END; + |END""".stripMargin + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + assert(tree.handlers.length == 1) + assert(tree.handlers.head.isInstanceOf[ExceptionHandler]) + assert(tree.handlers.head.exceptionHandlerTriggers.conditions.size == 1) + assert(tree.handlers.head.exceptionHandlerTriggers.conditions.contains("DIVIDE_BY_ZERO")) + assert(tree.handlers.head.body.collection.size == 1) + assert(tree.handlers.head.body.collection.head.isInstanceOf[SingleStatement]) + assert(tree.handlers.head.body.collection.head.asInstanceOf[SingleStatement] + .parsedPlan.isInstanceOf[Project]) + } + + test("declare handler single statement") { + val sqlScriptText = + """ + |BEGIN + | DECLARE EXIT HANDLER FOR DIVIDE_BY_ZERO SELECT 1; + |END""".stripMargin + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + assert(tree.handlers.length == 1) + assert(tree.handlers.head.isInstanceOf[ExceptionHandler]) + assert(tree.handlers.head.exceptionHandlerTriggers.conditions.size == 1) + assert(tree.handlers.head.exceptionHandlerTriggers.conditions.contains("DIVIDE_BY_ZERO")) + assert(tree.handlers.head.body.collection.size == 1) + assert(tree.handlers.head.body.collection.head.isInstanceOf[SingleStatement]) + assert(tree.handlers.head.body.collection.head.asInstanceOf[SingleStatement] + .parsedPlan.isInstanceOf[Project]) + } + + test("declare handler set statement") { + val sqlScriptText = + """ + |BEGIN + | DECLARE EXIT HANDLER FOR DIVIDE_BY_ZERO SET test_var = 1; + |END""".stripMargin + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + assert(tree.handlers.length == 1) + assert(tree.handlers.head.isInstanceOf[ExceptionHandler]) + assert(tree.handlers.head.exceptionHandlerTriggers.conditions.size == 1) + assert(tree.handlers.head.exceptionHandlerTriggers.conditions.contains("DIVIDE_BY_ZERO")) + assert(tree.handlers.head.body.collection.size == 1) + assert(tree.handlers.head.body.collection.head.isInstanceOf[SingleStatement]) + assert(tree.handlers.head.body.collection.head.asInstanceOf[SingleStatement] + .parsedPlan.isInstanceOf[SetVariable]) + } + + test("declare handler with multiple conditions/sqlstates") { + val sqlScriptText = + """ + |BEGIN + | DECLARE TEST_CONDITION_1 CONDITION FOR SQLSTATE '12345'; + | DECLARE TEST_CONDITION_2 CONDITION FOR SQLSTATE '54321'; + | DECLARE EXIT HANDLER FOR SQLSTATE '22012', TEST_CONDITION_1, test_condition_2 + | SET test_var = 1; + |END""".stripMargin + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + assert(tree.handlers.length == 1) + assert(tree.handlers.head.isInstanceOf[ExceptionHandler]) + assert(tree.handlers.head.exceptionHandlerTriggers.conditions.size == 2) + assert(tree.handlers.head.exceptionHandlerTriggers.conditions.contains("TEST_CONDITION_1")) + assert(tree.handlers.head.exceptionHandlerTriggers.conditions.contains("TEST_CONDITION_2")) + assert(tree.handlers.head.exceptionHandlerTriggers.sqlStates.size == 1) + assert(tree.handlers.head.exceptionHandlerTriggers.sqlStates.contains("22012")) + assert(tree.handlers.head.body.collection.size == 1) + assert(tree.handlers.head.body.collection.head.isInstanceOf[SingleStatement]) + assert(tree.handlers.head.body.collection.head.asInstanceOf[SingleStatement] + .parsedPlan.isInstanceOf[SetVariable]) + } + + test("declare handler for SQLEXCEPTION") { + val sqlScriptText = + """ + |BEGIN + | DECLARE EXIT HANDLER FOR 
SQLEXCEPTION SET test_var = 1;
+        |END""".stripMargin
+    val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
+    assert(tree.handlers.length == 1)
+    assert(tree.handlers.head.isInstanceOf[ExceptionHandler])
+    assert(tree.handlers.head.exceptionHandlerTriggers.conditions.isEmpty)
+    assert(tree.handlers.head.exceptionHandlerTriggers.sqlStates.isEmpty)
+    assert(tree.handlers.head.exceptionHandlerTriggers.sqlException) // true
+    assert(!tree.handlers.head.exceptionHandlerTriggers.notFound) // false
+    assert(tree.handlers.head.body.collection.size == 1)
+  }
+
+  test("declare handler for NOT FOUND") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        |  DECLARE EXIT HANDLER FOR NOT FOUND SET test_var = 1;
+        |END""".stripMargin
+    val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
+    assert(tree.handlers.length == 1)
+    assert(tree.handlers.head.isInstanceOf[ExceptionHandler])
+    assert(tree.handlers.head.exceptionHandlerTriggers.conditions.isEmpty)
+    assert(tree.handlers.head.exceptionHandlerTriggers.sqlStates.isEmpty)
+    assert(!tree.handlers.head.exceptionHandlerTriggers.sqlException) // false
+    assert(tree.handlers.head.exceptionHandlerTriggers.notFound) // true
+    assert(tree.handlers.head.body.collection.size == 1)
+  }
+
+  test("declare handler with condition and sqlstate with same value") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        |  DECLARE K2000 CONDITION FOR SQLSTATE '12345';
+        |  DECLARE EXIT HANDLER FOR K2000, SQLSTATE VALUE 'K2000' SET test_var = 1;
+        |END""".stripMargin
+    val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
+    assert(tree.handlers.length == 1)
+    assert(tree.handlers.head.isInstanceOf[ExceptionHandler])
+    assert(tree.handlers.head.exceptionHandlerTriggers.conditions.size == 1)
+    assert(tree.handlers.head.exceptionHandlerTriggers.conditions.contains("K2000"))
+    assert(tree.handlers.head.exceptionHandlerTriggers.sqlStates.size == 1)
+    assert(tree.handlers.head.exceptionHandlerTriggers.sqlStates.contains("K2000"))
+    assert(!tree.handlers.head.exceptionHandlerTriggers.sqlException) // false
+    assert(!tree.handlers.head.exceptionHandlerTriggers.notFound) // false
+    assert(tree.handlers.head.body.collection.size == 1)
+  }
+
   // Helper methods
   def cleanupStatementString(statementStr: String): String = {
     statementStr
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
index fabd47422daf4..5eccf1bcee2f1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
@@ -17,6 +17,7 @@
 package org.apache.spark.sql.scripting
 
+import org.apache.spark.SparkThrowable
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, CompoundBody}
 import org.apache.spark.sql.classic.{DataFrame, SparkSession}
@@ -44,11 +45,13 @@
     val ctx = new SqlScriptingExecutionContext()
     val executionPlan = interpreter.buildExecutionPlan(sqlScript, args, ctx)
     // Add frame which represents SQL Script to the context.
-    ctx.frames.append(new SqlScriptingExecutionFrame(executionPlan.getTreeIterator))
+    ctx.frames.append(
+      new SqlScriptingExecutionFrame(executionPlan, SqlScriptingFrameType.SQL_SCRIPT))
     // Enter the scope of the top level compound.
- // We don't need to exit this scope explicitly as it will be done automatically - // when the frame is removed during iteration. + // We exit this scope explicitly in the getNextStatement method when there are no more + // statements to execute. executionPlan.enterScope() + // Return the context. ctx } @@ -57,7 +60,32 @@ class SqlScriptingExecution( private def getNextStatement: Option[CompoundStatementExec] = { // Remove frames that are already executed. while (context.frames.nonEmpty && !context.frames.last.hasNext) { + val lastFrame = context.frames.last + + // First frame on stack is always script frame. If there are no more statements to execute, + // exit the scope of the script frame. + // This scope was entered when the script frame was created and added to the context. + if (context.frames.size == 1 && context.frames.last.scopes.size == 1) { + context.frames.last.executionPlan.exitScope() + } + context.frames.remove(context.frames.size - 1) + + // If the last frame is a handler, set leave statement to be the next one in the + // innermost scope that should be exited. + if (lastFrame.frameType == SqlScriptingFrameType.HANDLER && context.frames.nonEmpty) { + // Remove the scope if handler is executed. + if (context.firstHandlerScopeLabel.isDefined + && lastFrame.scopeLabel.get == context.firstHandlerScopeLabel.get) { + context.firstHandlerScopeLabel = None + } + + var execPlan: CompoundBodyExec = context.frames.last.executionPlan + while (execPlan.curr.exists(_.isInstanceOf[CompoundBodyExec])) { + execPlan = execPlan.curr.get.asInstanceOf[CompoundBodyExec] + } + execPlan.curr = Some(new LeaveStatementExec(lastFrame.scopeLabel.get)) + } } // If there are still frames available, get the next statement. if (context.frames.nonEmpty) { @@ -66,18 +94,8 @@ class SqlScriptingExecution( None } - /** - * Advances through the script and executes statements until a result statement or - * end of script is encountered. - * - * To know if there is result statement available, the method has to advance through script and - * execute statements until the result statement or end of script is encountered. For that reason - * the returned result must be executed before subsequent calls. Multiple calls without executing - * the intermediate results will lead to incorrect behavior. - * - * @return Result DataFrame if it is available, otherwise None. - */ - def getNextResult: Option[DataFrame] = { + /** Helper method to get the next result statement from the script. */ + private def getNextResultInternal: Option[DataFrame] = { var currentStatement = getNextStatement // While we don't have a result statement, execute the statements. while (currentStatement.isDefined) { @@ -97,18 +115,54 @@ class SqlScriptingExecution( None } - private def handleException(e: Throwable): Unit = { - // Rethrow the exception. - // TODO: SPARK-48353 Add error handling for SQL scripts - throw e + /** + * Advances through the script and executes statements until a result statement or + * end of script is encountered. + * + * To know if there is result statement available, the method has to advance through script and + * execute statements until the result statement or end of script is encountered. For that reason + * the returned result must be executed before subsequent calls. Multiple calls without executing + * the intermediate results will lead to incorrect behavior. + * + * @return Result DataFrame if it is available, otherwise None. 
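+   * A typical caller therefore loops until this method returns None, fully
+   * consuming (e.g. collecting) each returned DataFrame before the next call.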
+ */ + def getNextResult: Option[DataFrame] = { + try { + getNextResultInternal + } catch { + case e: SparkThrowable => + handleException(e) + getNextResult // After setup for exception handling, try to get the next result again. + case throwable: Throwable => + throw throwable // Uncaught exception will be thrown. + } + } + + private def handleException(e: SparkThrowable): Unit = { + context.findHandler(e.getCondition, e.getSqlState) match { + case Some(handler) => + val handlerFrame = new SqlScriptingExecutionFrame( + handler.body, + SqlScriptingFrameType.HANDLER, + handler.scopeLabel + ) + context.frames.append( + handlerFrame + ) + handlerFrame.executionPlan.enterScope() + case None => + throw e.asInstanceOf[Throwable] + } } def withErrorHandling(f: => Unit): Unit = { try { f } catch { - case e: Throwable => - handleException(e) + case sparkThrowable: SparkThrowable => + handleException(sparkThrowable) + case throwable: Throwable => + throw throwable } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala index 94462ab828f75..2c001a196a8f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala @@ -17,22 +17,28 @@ package org.apache.spark.sql.scripting +import java.util.Locale + import scala.collection.mutable.ListBuffer import org.apache.spark.SparkException +import org.apache.spark.sql.scripting.SqlScriptingFrameType.SqlScriptingFrameType /** * SQL scripting execution context - keeps track of the current execution state. */ class SqlScriptingExecutionContext { // List of frames that are currently active. - val frames: ListBuffer[SqlScriptingExecutionFrame] = ListBuffer.empty + private[scripting] val frames: ListBuffer[SqlScriptingExecutionFrame] = ListBuffer.empty + private[scripting] var firstHandlerScopeLabel: Option[String] = None - def enterScope(label: String): Unit = { + def enterScope( + label: String, + triggerHandlerMap: TriggerToExceptionHandlerMap): Unit = { if (frames.isEmpty) { throw SparkException.internalError("Cannot enter scope: no frames.") } - frames.last.enterScope(label) + frames.last.enterScope(label, triggerHandlerMap) } def exitScope(label: String): Unit = { @@ -41,6 +47,38 @@ class SqlScriptingExecutionContext { } frames.last.exitScope(label) } + + def findHandler(condition: String, sqlState: String): Option[ExceptionHandlerExec] = { + if (frames.isEmpty) { + throw SparkException.internalError(s"Cannot find handler: no frames.") + } + + // If the last frame is a handler, try to find a handler in its body first. + if (frames.last.frameType == SqlScriptingFrameType.HANDLER) { + val handler = frames.last.findHandler(condition, sqlState, firstHandlerScopeLabel) + if (handler.isDefined) { + return handler + } + } + + // First frame is always script frame. Skip all handler frames and try to find handler in it. + // TODO: After introducing stored procedures, we need to handle the case with multiple + // script/stored procedure frames on call stack. We will have to iterate over all + // frames and skip frames representing error handlers. 
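+    // Note: if the active handler's body itself raises an error, the search below
+    // starts again from the script frame, and firstHandlerScopeLabel makes it skip
+    // the scopes whose handler is already running, so a handler cannot re-trigger
+    // itself.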
+    val scriptFrame = frames.head
+    val handler = scriptFrame.findHandler(condition, sqlState, firstHandlerScopeLabel)
+    if (handler.isDefined) {
+      firstHandlerScopeLabel = handler.get.scopeLabel
+      return handler
+    }
+
+    None
+  }
+}
+
+object SqlScriptingFrameType extends Enumeration {
+  type SqlScriptingFrameType = Value
+  val SQL_SCRIPT, HANDLER = Value
 }
 
 /**
@@ -48,22 +86,29 @@ class SqlScriptingExecutionContext {
 * This supports returning multiple result statements from a single script.
 *
 * @param executionPlan CompoundBody which need to be executed.
+ * @param frameType Type of the frame.
+ * @param scopeLabel Label of the scope where the handler is defined.
+ *                   Available only for frameType = HANDLER.
 */
 class SqlScriptingExecutionFrame(
-    executionPlan: Iterator[CompoundStatementExec]) extends Iterator[CompoundStatementExec] {
+    val executionPlan: CompoundBodyExec,
+    val frameType: SqlScriptingFrameType,
+    val scopeLabel: Option[String] = None) extends Iterator[CompoundStatementExec] {
 
   // List of scopes that are currently active.
-  private val scopes: ListBuffer[SqlScriptingExecutionScope] = ListBuffer.empty
+  private[scripting] val scopes: ListBuffer[SqlScriptingExecutionScope] = ListBuffer.empty
 
-  override def hasNext: Boolean = executionPlan.hasNext
+  override def hasNext: Boolean = executionPlan.getTreeIterator.hasNext
 
   override def next(): CompoundStatementExec = {
     if (!hasNext) throw SparkException.internalError("No more elements to iterate through.")
-    executionPlan.next()
+    executionPlan.getTreeIterator.next()
   }
 
-  def enterScope(label: String): Unit = {
-    scopes.append(new SqlScriptingExecutionScope(label))
+  def enterScope(
+      label: String,
+      triggerToExceptionHandlerMap: TriggerToExceptionHandlerMap): Unit = {
+    scopes.append(new SqlScriptingExecutionScope(label, triggerToExceptionHandlerMap))
   }
 
   def exitScope(label: String): Unit = {
@@ -81,6 +126,37 @@ class SqlScriptingExecutionFrame(
       scopes.remove(scopes.length - 1)
     }
   }
+
+  // TODO: Introduce a separate class for different frame types (Script, Stored Procedure,
+  //       Error Handler) implementing the SqlScriptingExecutionFrame interface.
+  def findHandler(
+      condition: String,
+      sqlState: String,
+      firstHandlerScopeLabel: Option[String]): Option[ExceptionHandlerExec] = {
+
+    val searchScopes = if (frameType == SqlScriptingFrameType.HANDLER) {
+      // If the frame is a handler, search for the handler in its body. Don't skip any scopes.
+      scopes.reverseIterator
+    } else if (firstHandlerScopeLabel.isEmpty) {
+      // If no handler is active, search for the handler from the current scope.
+      // Don't skip any scopes.
+      scopes.reverseIterator
+    } else {
+      // Drop all scopes down to the scope in which the currently active handler is defined,
+      // then drop that one too, so the search starts from the surrounding scope.
+      scopes.reverseIterator.dropWhile(_.label != firstHandlerScopeLabel.get).drop(1)
+    }
+
+    // In the remaining scopes, try to find the most appropriate handler.
+    searchScopes.foreach { scope =>
+      val handler = scope.findHandler(condition, sqlState)
+      if (handler.isDefined) {
+        return handler
+      }
+    }
+
+    None
+  }
 }
 
 /**
@@ -88,5 +164,62 @@ class SqlScriptingExecutionFrame(
 *
 * @param label
 *   Label of the scope.
+ * @param triggerToExceptionHandlerMap
+ *   Object holding the condition/SQLSTATE/SQLEXCEPTION/NOT FOUND to handler mappings.
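+ *   For example (illustrative sketch): a compound declaring
+ *   DECLARE EXIT HANDLER FOR DIVIDE_BY_ZERO, SQLSTATE '22012' contributes one entry to the
+ *   condition map and one to the SQLSTATE map, both pointing to the same ExceptionHandlerExec.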
 */
-class SqlScriptingExecutionScope(val label: String)
+class SqlScriptingExecutionScope(
+    val label: String,
+    val triggerToExceptionHandlerMap: TriggerToExceptionHandlerMap) {
+
+  /**
+   * Finds the most appropriate error handler for an exception, based on its condition and
+   * SQLSTATE.
+   *
+   * The method follows these rules to determine the most appropriate handler:
+   * 1. Specific named condition handlers (e.g., DIVIDE_BY_ZERO) are checked first.
+   * 2. If no specific condition handler is found, SQLSTATE handlers are checked.
+   * 3. For SQLSTATEs starting with '02', a generic NOT FOUND handler is used if available.
+   * 4. For other SQLSTATEs (except those starting with 'XX' or '02'), a generic SQLEXCEPTION
+   *    handler is used if available.
+   *
+   * Note: Only handlers defined in this scope, i.e. in the compound statement it corresponds
+   * to, are considered.
+   *
+   * @param condition Error condition of the exception to find a handler for.
+   * @param sqlState SQLSTATE of the exception to find a handler for.
+   *
+   * @return Handler for the given condition if one exists.
+   */
+  def findHandler(condition: String, sqlState: String): Option[ExceptionHandlerExec] = {
+    // Check if there is a specific handler for the given condition.
+    var errorHandler: Option[ExceptionHandlerExec] = None
+    val uppercaseCondition = condition.toUpperCase(Locale.ROOT)
+    val uppercaseSqlState = sqlState.toUpperCase(Locale.ROOT)
+
+    errorHandler = triggerToExceptionHandlerMap.getHandlerForCondition(uppercaseCondition)
+
+    if (errorHandler.isEmpty) {
+      // Check if there is a specific handler for the given SQLSTATE.
+      errorHandler = triggerToExceptionHandlerMap.getHandlerForSqlState(uppercaseSqlState)
+    }
+
+    if (errorHandler.isEmpty) {
+      errorHandler = triggerToExceptionHandlerMap.getNotFoundHandler match {
+        case Some(handler) if uppercaseSqlState.startsWith("02") => Some(handler)
+        case _ => None
+      }
+    }
+
+    if (errorHandler.isEmpty) {
+      // If a SQLEXCEPTION handler is defined, use it only for errors whose SQLSTATE class
+      // differs from 'XX' and '02'.
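+      // ('XX' is the SQLSTATE class used for internal errors, and '02' is the NOT FOUND
+      // class, which is only routed to a dedicated NOT FOUND handler above.)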
+ errorHandler = triggerToExceptionHandlerMap.getSqlExceptionHandler match { + case Some(handler) + if !uppercaseSqlState.startsWith("XX") && !uppercaseSqlState.startsWith("02") => + Some(handler) + case _ => None + } + } + + errorHandler + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index e153d38a00cfb..97f68b9ca52de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.{NameParameterizedQuery, UnresolvedAttribute, UnresolvedIdentifier} import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, CreateMap, CreateNamedStruct, Expression, Literal} import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DefaultValueExpression, DropVariable, LogicalPlan, OneRowRelation, Project, SetVariable} +import org.apache.spark.sql.catalyst.plans.logical.ExceptionHandlerType.ExceptionHandlerType import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} import org.apache.spark.sql.classic.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.errors.SqlScriptingErrors @@ -175,6 +176,38 @@ class NoOpStatementExec extends LeafStatementExec { override def reset(): Unit = () } +/** + * Class to hold mapping of condition names/sqlStates to exception handlers + * defined in a compound body. + * + * @param conditionToExceptionHandlerMap + * Map of condition names to exception handlers. + * @param sqlStateToExceptionHandlerMap + * Map of sqlStates to exception handlers. + * @param sqlExceptionHandler + * "Catch-all" exception handler. + * @param notFoundHandler + * NOT FOUND exception handler. + */ +class TriggerToExceptionHandlerMap( + conditionToExceptionHandlerMap: Map[String, ExceptionHandlerExec], + sqlStateToExceptionHandlerMap: Map[String, ExceptionHandlerExec], + sqlExceptionHandler: Option[ExceptionHandlerExec], + notFoundHandler: Option[ExceptionHandlerExec]) { + + def getHandlerForCondition(condition: String): Option[ExceptionHandlerExec] = { + conditionToExceptionHandlerMap.get(condition) + } + + def getHandlerForSqlState(sqlState: String): Option[ExceptionHandlerExec] = { + sqlStateToExceptionHandlerMap.get(sqlState) + } + + def getSqlExceptionHandler: Option[ExceptionHandlerExec] = sqlExceptionHandler + + def getNotFoundHandler: Option[ExceptionHandlerExec] = notFoundHandler +} + /** * Executable node for CompoundBody. * @param statements @@ -186,12 +219,15 @@ class NoOpStatementExec extends LeafStatementExec { * Scopes are used for grouping local variables and exception handlers. * @param context * SqlScriptingExecutionContext keeps the execution state of current script. + * @param triggerToExceptionHandlerMap + * Map of condition names/sqlstates to error handlers defined in this compound body. 
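+ *   The map is handed to the execution context when this body's scope is entered, which makes
+ *   its handlers eligible for lookup until the scope is exited.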
*/ class CompoundBodyExec( statements: Seq[CompoundStatementExec], label: Option[String] = None, isScope: Boolean, - context: SqlScriptingExecutionContext) + context: SqlScriptingExecutionContext, + triggerToExceptionHandlerMap: TriggerToExceptionHandlerMap) extends NonLeafStatementExec { private object ScopeStatus extends Enumeration { @@ -200,7 +236,8 @@ class CompoundBodyExec( } private var localIterator = statements.iterator - private var curr = if (localIterator.hasNext) Some(localIterator.next()) else None + private[scripting] var curr: Option[CompoundStatementExec] = + if (localIterator.hasNext) Some(localIterator.next()) else None private var scopeStatus = ScopeStatus.NOT_ENTERED /** @@ -210,11 +247,11 @@ class CompoundBodyExec( * iteration, but it should be executed only once when compound body that represent * scope is encountered for the first time. */ - def enterScope(): Unit = { + private[scripting] def enterScope(): Unit = { // This check makes this operation idempotent. if (isScope && scopeStatus == ScopeStatus.NOT_ENTERED) { scopeStatus = ScopeStatus.INSIDE - context.enterScope(label.get) + context.enterScope(label.get, triggerToExceptionHandlerMap) } } @@ -223,7 +260,7 @@ class CompoundBodyExec( * * Even though this operation is called exactly once, we are making it idempotent. */ - protected def exitScope(): Unit = { + private[scripting] def exitScope(): Unit = { // This check makes this operation idempotent. if (isScope && scopeStatus == ScopeStatus.INSIDE) { scopeStatus = ScopeStatus.EXITED @@ -918,7 +955,8 @@ class ForStatementExec( variablesMap.keys.toSeq.map(colName => createDropVarExec(colName)), None, isScope = false, - context + context, + new TriggerToExceptionHandlerMap(Map.empty, Map.empty, None, None) ) ForState.VariableCleanup } @@ -966,3 +1004,19 @@ class ForStatementExec( body.reset() } } + +/** + * Executable node for ExceptionHandler. + * @param body Executable CompoundBody of the exception handler. + * @param handlerType Handler type: EXIT, CONTINUE. + * @param scopeLabel Label of the scope where handler is defined. 
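+ *                   For EXIT handlers this is the label of the compound body that declared the
+ *                   handler, so that scope can be left once the handler body completes; for
+ *                   other handler types it is None.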
+ */
+class ExceptionHandlerExec(
+    val body: CompoundBodyExec,
+    val handlerType: ExceptionHandlerType,
+    val scopeLabel: Option[String]) extends NonLeafStatementExec {
+
+  override def getTreeIterator: Iterator[CompoundStatementExec] = body.getTreeIterator
+
+  override def reset(): Unit = body.reset()
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
index d98588278956f..2a446e4bbb25e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
@@ -17,11 +17,15 @@
 
 package org.apache.spark.sql.scripting
 
+import scala.collection.mutable.HashMap
+
+import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier
 import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CompoundPlanStatement, CreateVariable, DropVariable, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LogicalPlan, LoopStatement, RepeatStatement, SingleStatement, WhileStatement}
-import org.apache.spark.sql.catalyst.trees.Origin
+import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CompoundPlanStatement, CreateVariable, DropVariable, ExceptionHandlerType, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LogicalPlan, LoopStatement, RepeatStatement, SingleStatement, WhileStatement}
+import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
 import org.apache.spark.sql.classic.SparkSession
+import org.apache.spark.sql.errors.SqlScriptingErrors
 
 /**
 * SQL scripting interpreter - builds SQL script execution plan.
@@ -39,6 +43,8 @@ case class SqlScriptingInterpreter(session: SparkSession) {
   * CompoundBody for which to build the plan.
   * @param args
   * A map of parameter names to SQL literal expressions.
+   * @param context
+   * SqlScriptingExecutionContext keeps the execution state of the current script.
   * @return
   * Top level CompoundBodyExec representing SQL Script to be executed.
   */
@@ -63,6 +69,113 @@ case class SqlScriptingInterpreter(session: SparkSession) {
       case _ => None
     }
 
+  /**
+   * Transform [[CompoundBody]] into [[CompoundBodyExec]].
+   *
+   * @param compoundBody
+   *   CompoundBody to be transformed into CompoundBodyExec.
+   * @param args
+   *   A map of parameter names to SQL literal expressions.
+   * @param context
+   *   SqlScriptingExecutionContext keeps the execution state of the current script.
+   * @return
+   *   Executable version of the CompoundBody.
+   */
+  private def transformBodyIntoExec(
+      compoundBody: CompoundBody,
+      args: Map[String, Expression],
+      context: SqlScriptingExecutionContext): CompoundBodyExec = {
+    // Add drop variables to the end of the body.
+    val variables = compoundBody.collection.flatMap {
+      case st: SingleStatement => getDeclareVarNameFromPlan(st.parsedPlan)
+      case _ => None
+    }
+    val dropVariables = variables
+      .map(varName => DropVariable(varName, ifExists = true))
+      .map(new SingleStatementExec(_, Origin(), args, isInternal = true, context))
+      .reverse
+
+    // Map of conditions to their respective handlers.
+    val conditionToExceptionHandlerMap: HashMap[String, ExceptionHandlerExec] = HashMap.empty
+    // Map of SQLSTATEs to their respective handlers.
+    val sqlStateToExceptionHandlerMap: HashMap[String, ExceptionHandlerExec] = HashMap.empty
+    // NOT FOUND handler.
+    var notFoundHandler: Option[ExceptionHandlerExec] = None
+    // SQLEXCEPTION handler.
+    var sqlExceptionHandler: Option[ExceptionHandlerExec] = None
+
+    compoundBody.handlers.foreach(handler => {
+      val handlerBodyExec =
+        transformBodyIntoExec(
+          handler.body,
+          args,
+          context)
+
+      // For EXIT handlers, remember the label of the compound body that declared the handler;
+      // that is the scope to leave once the handler body completes.
+      val handlerScopeLabel = if (handler.handlerType == ExceptionHandlerType.EXIT) {
+        Some(compoundBody.label.get)
+      } else {
+        None
+      }
+
+      val handlerExec = new ExceptionHandlerExec(
+        handlerBodyExec,
+        handler.handlerType,
+        handlerScopeLabel)
+
+      // For each condition the handler is defined for, add a corresponding key-value pair
+      // to conditionToExceptionHandlerMap.
+      handler.exceptionHandlerTriggers.conditions.foreach(condition => {
+        // A condition name may textually coincide with a SQLSTATE value; it is still treated
+        // as a condition here.
+        if (conditionToExceptionHandlerMap.contains(condition)) {
+          throw SqlScriptingErrors.duplicateHandlerForSameCondition(CurrentOrigin.get, condition)
+        } else {
+          conditionToExceptionHandlerMap.put(condition, handlerExec)
+        }
+      })
+
+      // For each SQLSTATE the handler is defined for, add a corresponding key-value pair
+      // to sqlStateToExceptionHandlerMap.
+      handler.exceptionHandlerTriggers.sqlStates.foreach(sqlState => {
+        if (sqlStateToExceptionHandlerMap.contains(sqlState)) {
+          throw SqlScriptingErrors.duplicateHandlerForSameSqlState(CurrentOrigin.get, sqlState)
+        } else {
+          sqlStateToExceptionHandlerMap.put(sqlState, handlerExec)
+        }
+      })
+
+      // NOT FOUND handler; keep a handler registered by an earlier declaration intact.
+      notFoundHandler = if (handler.exceptionHandlerTriggers.notFound) {
+        Some(handlerExec)
+      } else notFoundHandler
+
+      // SQLEXCEPTION handler; keep a handler registered by an earlier declaration intact.
+      sqlExceptionHandler = if (handler.exceptionHandlerTriggers.sqlException) {
+        Some(handlerExec)
+      } else sqlExceptionHandler
+    })
+
+    // Create a trigger to exception handler map for the current CompoundBody.
+    val triggerToExceptionHandlerMap = new TriggerToExceptionHandlerMap(
+      conditionToExceptionHandlerMap = conditionToExceptionHandlerMap.toMap,
+      sqlStateToExceptionHandlerMap = sqlStateToExceptionHandlerMap.toMap,
+      sqlExceptionHandler = sqlExceptionHandler,
+      notFoundHandler = notFoundHandler)
+
+    val statements = compoundBody.collection
+      .map(st => transformTreeIntoExecutable(st, args, context)) ++ dropVariables match {
+      case Nil => Seq(new NoOpStatementExec)
+      case s => s
+    }
+
+    new CompoundBodyExec(
+      statements,
+      compoundBody.label,
+      compoundBody.isScope,
+      context,
+      triggerToExceptionHandlerMap)
+  }
+
   /**
   * Transform the parsed tree to the executable node.
   *
@@ -70,6 +183,8 @@ case class SqlScriptingInterpreter(session: SparkSession) {
   * Root node of the parsed tree.
   * @param args
   * A map of parameter names to SQL literal expressions.
+   * @param context
+   * SqlScriptingExecutionContext keeps the execution state of the current script.
   * @return
   * Executable statement.
   */
@@ -78,28 +193,9 @@ case class SqlScriptingInterpreter(session: SparkSession) {
       args: Map[String, Expression],
       context: SqlScriptingExecutionContext): CompoundStatementExec =
     node match {
-      case CompoundBody(collection, label, isScope) =>
+      case body: CompoundBody =>
         // TODO [SPARK-48530]: Current logic doesn't support scoped variables and shadowing.
- val variables = collection.flatMap { - case st: SingleStatement => getDeclareVarNameFromPlan(st.parsedPlan) - case _ => None - } - val dropVariables = variables - .map(varName => DropVariable(varName, ifExists = true)) - .map(new SingleStatementExec(_, Origin(), args, isInternal = true, context)) - .reverse - - val statements = collection - .map(st => transformTreeIntoExecutable(st, args, context)) ++ dropVariables match { - case Nil => Seq(new NoOpStatementExec) - case s => s - } - - new CompoundBodyExec( - statements, - label, - isScope, - context) + transformBodyIntoExec(body, args, context) case IfElseStatement(conditions, conditionalBodies, elseBody) => val conditionsExec = conditions.map(condition => @@ -185,5 +281,7 @@ case class SqlScriptingInterpreter(session: SparkSession) { args, isInternal = false, context) + + case _ => throw SparkException.internalError(s"Unsupported statement: $node") } } diff --git a/sql/core/src/test/resources/sql-tests/results/keywords-enforced.sql.out b/sql/core/src/test/resources/sql-tests/results/keywords-enforced.sql.out index 84a35c270b698..6ee2b12be068d 100644 --- a/sql/core/src/test/resources/sql-tests/results/keywords-enforced.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/keywords-enforced.sql.out @@ -60,8 +60,10 @@ COMPACTIONS false COMPENSATION false COMPUTE false CONCATENATE false +CONDITION false CONSTRAINT true CONTAINS false +CONTINUE false COST false CREATE true CROSS true @@ -114,6 +116,7 @@ EXCHANGE false EXCLUDE false EXECUTE true EXISTS false +EXIT false EXPLAIN false EXPORT false EXTEND false @@ -132,6 +135,7 @@ FOR true FOREIGN true FORMAT false FORMATTED false +FOUND false FROM true FULL true FUNCTION false @@ -141,6 +145,7 @@ GLOBAL false GRANT true GROUP true GROUPING false +HANDLER false HAVING true HOUR false HOURS false @@ -300,6 +305,8 @@ SORTED false SOURCE false SPECIFIC false SQL true +SQLEXCEPTION false +SQLSTATE false START false STATISTICS false STORED false @@ -351,6 +358,7 @@ UPDATE false USE false USER true USING true +VALUE false VALUES false VAR false VARCHAR false diff --git a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out index 49ea8bba3e174..16c2835bd6956 100644 --- a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out @@ -60,8 +60,10 @@ COMPACTIONS false COMPENSATION false COMPUTE false CONCATENATE false +CONDITION false CONSTRAINT false CONTAINS false +CONTINUE false COST false CREATE false CROSS false @@ -114,6 +116,7 @@ EXCHANGE false EXCLUDE false EXECUTE false EXISTS false +EXIT false EXPLAIN false EXPORT false EXTEND false @@ -132,6 +135,7 @@ FOR false FOREIGN false FORMAT false FORMATTED false +FOUND false FROM false FULL false FUNCTION false @@ -141,6 +145,7 @@ GLOBAL false GRANT false GROUP false GROUPING false +HANDLER false HAVING false HOUR false HOURS false @@ -300,6 +305,8 @@ SORTED false SOURCE false SPECIFIC false SQL false +SQLEXCEPTION false +SQLSTATE false START false STATISTICS false STORED false @@ -351,6 +358,7 @@ UPDATE false USE false USER false USING false +VALUE false VALUES false VAR false VARCHAR false diff --git a/sql/core/src/test/resources/sql-tests/results/nonansi/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/nonansi/keywords.sql.out index 49ea8bba3e174..16c2835bd6956 100644 --- a/sql/core/src/test/resources/sql-tests/results/nonansi/keywords.sql.out +++ 
b/sql/core/src/test/resources/sql-tests/results/nonansi/keywords.sql.out @@ -60,8 +60,10 @@ COMPACTIONS false COMPENSATION false COMPUTE false CONCATENATE false +CONDITION false CONSTRAINT false CONTAINS false +CONTINUE false COST false CREATE false CROSS false @@ -114,6 +116,7 @@ EXCHANGE false EXCLUDE false EXECUTE false EXISTS false +EXIT false EXPLAIN false EXPORT false EXTEND false @@ -132,6 +135,7 @@ FOR false FOREIGN false FORMAT false FORMATTED false +FOUND false FROM false FULL false FUNCTION false @@ -141,6 +145,7 @@ GLOBAL false GRANT false GROUP false GROUPING false +HANDLER false HAVING false HOUR false HOURS false @@ -300,6 +305,8 @@ SORTED false SOURCE false SPECIFIC false SQL false +SQLEXCEPTION false +SQLSTATE false START false STATISTICS false STORED false @@ -351,6 +358,7 @@ UPDATE false USE false USER false USING false +VALUE false VALUES false VAR false VARCHAR false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala index afcdfd343e33b..ef16f0485543b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala @@ -70,6 +70,33 @@ class SqlScriptingE2eSuite extends QueryTest with SharedSparkSession { } } + test("Scripting with exception handlers") { + val sqlScript = + """ + |BEGIN + | DECLARE OR REPLACE flag INT = -1; + | DECLARE EXIT HANDLER FOR DIVIDE_BY_ZERO + | BEGIN + | SELECT flag; + | SET VAR flag = 1; + | END; + | BEGIN + | DECLARE EXIT HANDLER FOR SQLSTATE '22012' + | BEGIN + | SELECT flag; + | SET VAR flag = 2; + | END; + | SELECT 5; + | SELECT 1/0; + | SELECT 6; + | END; + | SELECT 7; + | SELECT flag; + |END + |""".stripMargin + verifySqlScriptResult(sqlScript, Seq(Row(2))) + } + test("single select") { val sqlText = "SELECT 1;" verifySqlScriptResult(sqlText, Seq(Row(1))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala index 2189a0a280ca3..1393aff69c43a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -37,8 +37,9 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi statements: Seq[CompoundStatementExec], label: Option[String] = None, isScope: Boolean = false, - context: SqlScriptingExecutionContext = null) - extends CompoundBodyExec(statements, label, isScope, context) { + context: SqlScriptingExecutionContext = null, + triggerToExceptionHandlerMap: TriggerToExceptionHandlerMap = null) + extends CompoundBodyExec(statements, label, isScope, context, triggerToExceptionHandlerMap) { // No-op to remove unnecessary logic for these tests. 
override def enterScope(): Unit = () diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala index 503d38d61c7ab..964e1c2a87048 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala @@ -19,10 +19,12 @@ package org.apache.spark.sql.scripting import scala.collection.mutable.ListBuffer -import org.apache.spark.SparkConf -import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.{SparkArithmeticException, SparkConf} +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.CompoundBody +import org.apache.spark.sql.errors.DataTypeErrors.toSQLId +import org.apache.spark.sql.exceptions.SqlScriptingException import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -50,7 +52,9 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { var df = sse.getNextResult while (df.isDefined) { // Collect results from the current DataFrame. - result.append(df.get.collect()) + sse.withErrorHandling { + result.append(df.get.collect()) + } df = sse.getNextResult } result.toSeq @@ -65,6 +69,557 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { } } + // Handler tests + test("duplicate handler for the same condition") { + val sqlScript = + """ + |BEGIN + | DECLARE duplicate_condition CONDITION FOR SQLSTATE '12345'; + | DECLARE OR REPLACE flag INT = -1; + | DECLARE EXIT HANDLER FOR duplicate_condition + | BEGIN + | SET VAR flag = 1; + | END; + | DECLARE EXIT HANDLER FOR duplicate_condition + | BEGIN + | SET VAR flag = 2; + | END; + | SELECT 1/0; + | SELECT flag; + |END + |""".stripMargin + checkError( + exception = intercept[SqlScriptingException] { + verifySqlScriptResult(sqlScript, Seq.empty) + }, + condition = "DUPLICATE_EXCEPTION_HANDLER.CONDITION", + parameters = Map("condition" -> "DUPLICATE_CONDITION")) + } + + test("duplicate handler for the same sqlState") { + val sqlScript = + """ + |BEGIN + | DECLARE OR REPLACE flag INT = -1; + | DECLARE EXIT HANDLER FOR SQLSTATE '12345' + | BEGIN + | SET VAR flag = 1; + | END; + | DECLARE EXIT HANDLER FOR SQLSTATE '12345' + | BEGIN + | SET VAR flag = 2; + | END; + | SELECT 1/0; + | SELECT flag; + |END + |""".stripMargin + checkError( + exception = intercept[SqlScriptingException] { + verifySqlScriptResult(sqlScript, Seq.empty) + }, + condition = "DUPLICATE_EXCEPTION_HANDLER.SQLSTATE", + parameters = Map("sqlState" -> "12345")) + } + + test("Specific condition takes precedence over sqlState") { + val sqlScript = + """ + |BEGIN + | DECLARE OR REPLACE flag INT = -1; + | BEGIN + | DECLARE EXIT HANDLER FOR DIVIDE_BY_ZERO + | BEGIN + | SELECT flag; + | SET VAR flag = 1; + | END; + | DECLARE EXIT HANDLER FOR SQLSTATE '22012' + | BEGIN + | SELECT flag; + | SET VAR flag = 2; + | END; + | SELECT 1/0; + | END; + | SELECT flag; + |END + |""".stripMargin + val expected = Seq( + Seq(Row(-1)), // select flag + Seq(Row(1)) // select flag from the outer body + ) + verifySqlScriptResult(sqlScript, expected = expected) + } + + test("Innermost handler takes precedence over other handlers") { + val sqlScript = + """ + |BEGIN + | DECLARE OR REPLACE flag INT = -1; + | DECLARE EXIT HANDLER FOR 
DIVIDE_BY_ZERO + | BEGIN + | SELECT flag; + | SET VAR flag = 1; + | END; + | BEGIN + | DECLARE EXIT HANDLER FOR SQLSTATE '22012' + | BEGIN + | SELECT flag; + | SET VAR flag = 2; + | END; + | SELECT 1/0; + | END; + | SELECT flag; + |END + |""".stripMargin + val expected = Seq( + Seq(Row(-1)), // select flag + Seq(Row(2)) // select flag from the outer body + ) + verifySqlScriptResult(sqlScript, expected = expected) + } + + test("handler - exit resolve in the same block") { + val sqlScript = + """ + |BEGIN + | DECLARE OR REPLACE VARIABLE flag INT = -1; + | scope_to_exit: BEGIN + | DECLARE EXIT HANDLER FOR DIVIDE_BY_ZERO + | BEGIN + | SELECT flag; + | SET VAR flag = 1; + | END; + | SELECT 2; + | SELECT 3; + | SELECT 1/0; + | SELECT 4; + | SELECT 5; + | END; + | SELECT flag; + |END + |""".stripMargin + val expected = Seq( + Seq(Row(2)), // select + Seq(Row(3)), // select + Seq(Row(-1)), // select flag + Seq(Row(1)) // select flag from the outer body + ) + verifySqlScriptResult(sqlScript, expected = expected) + } + + test("handler - exit resolve in the same block when if condition fails") { + val sqlScript = + """ + |BEGIN + | DECLARE OR REPLACE VARIABLE flag INT = -1; + | scope_to_exit: BEGIN + | DECLARE EXIT HANDLER FOR SQLSTATE '22012' + | BEGIN + | SELECT flag; + | SET VAR flag = 1; + | END; + | SELECT 2; + | SELECT 3; + | IF 1 > 1/0 THEN + | SELECT 10; + | END IF; + | SELECT 4; + | SELECT 5; + | END; + | SELECT flag; + |END + |""".stripMargin + val expected = Seq( + Seq(Row(2)), // select + Seq(Row(3)), // select + Seq(Row(-1)), // select flag + Seq(Row(1)) // select flag from the outer body + ) + verifySqlScriptResult(sqlScript, expected = expected) + } + + test("handler - exit resolve in outer block") { + val sqlScript = + """ + |BEGIN + | DECLARE OR REPLACE VARIABLE flag INT = -1; + | l1: BEGIN + | DECLARE EXIT HANDLER FOR SQLSTATE '22012' + | BEGIN + | SELECT flag; + | SET VAR flag = 1; + | END; + | SELECT 2; + | SELECT 3; + | l2: BEGIN + | SELECT 4; + | SELECT 1/0; + | SELECT 5; + | END; + | SELECT 6; + | END; + | SELECT flag; + |END + |""".stripMargin + val expected = Seq( + Seq(Row(2)), // select + Seq(Row(3)), // select + Seq(Row(4)), // select + Seq(Row(-1)), // select flag + Seq(Row(1)) // select flag from the outer body + ) + verifySqlScriptResult(sqlScript, expected = expected) + } + + test("handler - chained handlers for different exceptions") { + val sqlScript = + """ + |BEGIN + | DECLARE OR REPLACE VARIABLE flag INT = -1; + | l1: BEGIN + | DECLARE EXIT HANDLER FOR UNRESOLVED_COLUMN.WITHOUT_SUGGESTION + | BEGIN + | SELECT flag; + | SET VAR flag = 2; + | END; + | l2: BEGIN + | DECLARE EXIT HANDLER FOR DIVIDE_BY_ZERO + | BEGIN + | SELECT flag; + | SET VAR flag = 1; + | select X; -- select non existing variable + | SELECT 2; + | END; + | SELECT 5; + | SELECT 1/0; -- divide by zero + | SELECT 6; + | END; + | END; + | + | SELECT flag; + |END + |""".stripMargin + val expected = Seq( + Seq(Row(5)), // select + Seq(Row(-1)), // select flag from handler in l2 + Seq(Row(1)), // select flag from handler in l1 + Seq(Row(2)) // select flag from the outer body + ) + verifySqlScriptResult(sqlScript, expected = expected) + } + + test("handler - double chained handlers") { + val sqlScript = + """ + |BEGIN + | DECLARE OR REPLACE VARIABLE flag INT = -1; + | l1: BEGIN + | DECLARE EXIT HANDLER FOR SQLEXCEPTION + | BEGIN + | SELECT flag; + | SET VAR flag = 2; + | END; + | l2: BEGIN + | DECLARE EXIT HANDLER FOR SQLEXCEPTION + | BEGIN + | SELECT flag; + | SET VAR flag = 1; + | SELECT 1/0; + | SELECT 
2; + | END; + | SELECT 5; + | SELECT 1/0; + | SELECT 6; + | END; + | END; + | + | SELECT flag; + |END + |""".stripMargin + val expected = Seq( + Seq(Row(5)), // select + Seq(Row(-1)), // select flag from handler in l2 + Seq(Row(1)), // select flag from handler in l1 + Seq(Row(2)) // select flag from the outer body + ) + verifySqlScriptResult(sqlScript, expected = expected) + } + + test("handler - triple chained handlers") { + val sqlScript = + """ + |BEGIN + | DECLARE OR REPLACE VARIABLE flag INT = -1; + | l1: BEGIN + | DECLARE EXIT HANDLER FOR SQLEXCEPTION + | BEGIN + | SELECT flag; + | SET VAR flag = 3; + | END; + | l2: BEGIN + | DECLARE EXIT HANDLER FOR SQLEXCEPTION + | BEGIN + | SELECT flag; + | SET VAR flag = 2; + | SELECT 1/0; + | SELECT 2; + | END; + | + | l3: BEGIN + | DECLARE EXIT HANDLER FOR SQLEXCEPTION + | BEGIN + | SELECT flag; + | SET VAR flag = 1; + | SELECT 1/0; + | SELECT 2; + | END; + | + | SELECT 5; + | SELECT 1/0; + | SELECT 6; + | END; + | END; + | END; + | + | SELECT flag; + |END + |""".stripMargin + val expected = Seq( + Seq(Row(5)), // select + Seq(Row(-1)), // select flag from handler in l3 + Seq(Row(1)), // select flag from handler in l2 + Seq(Row(2)), // select flag from handler in l1 + Seq(Row(3)) // select flag from the outer body + ) + verifySqlScriptResult(sqlScript, expected = expected) + } + + test("handler in handler") { + val sqlScript = + """ + |BEGIN + | DECLARE OR REPLACE VARIABLE flag INT = -1; + | lbl_0: BEGIN + | DECLARE EXIT HANDLER FOR SQLEXCEPTION + | lbl_1: BEGIN + | DECLARE EXIT HANDLER FOR SQLEXCEPTION + | lbl_2: BEGIN + | SELECT flag; + | SET VAR flag = 2; + | END; + | + | SELECT flag; + | SET VAR flag = 1; + | SELECT 1/0; + | SELECT 2; + | END; + | + | SELECT 5; + | SELECT 1/0; + | SELECT 6; + | END; + | SELECT flag; + |END + |""".stripMargin + val expected = Seq( + Seq(Row(5)), // select + Seq(Row(-1)), // select flag from outer handler + Seq(Row(1)), // select flag from inner handler + Seq(Row(2)) // select flag from the outer body + ) + verifySqlScriptResult(sqlScript, expected = expected) + } + + test("triple nested handler") { + val sqlScript = + """ + |BEGIN + | DECLARE OR REPLACE VARIABLE flag INT = -1; + | lbl_0: BEGIN + | DECLARE EXIT HANDLER FOR SQLEXCEPTION + | lbl_1: BEGIN + | DECLARE EXIT HANDLER FOR SQLEXCEPTION + | lbl_2: BEGIN + | DECLARE EXIT HANDLER FOR SQLEXCEPTION + | lbl_3: BEGIN + | SELECT flag; -- third select flag (2) + | SET VAR flag = 3; + | END; + | + | SELECT flag; -- second select flag (1) + | SET VAR flag = 2; + | SELECT 1/0; -- third error will be thrown here + | SELECT 2; + | END; + | + | SELECT flag; -- first select flag (-1) + | SET VAR flag = 1; + | SELECT 1/0; -- second error will be thrown here + | SELECT 2; + | END; + | + | SELECT 5; + | SELECT 1/0; -- first error will be thrown here + | SELECT 6; + | END; + | SELECT flag; -- fourth select flag (3) + |END + |""".stripMargin + val expected = Seq( + Seq(Row(5)), // select + Seq(Row(-1)), // select flag in handler + Seq(Row(1)), // select flag in handler + Seq(Row(2)), // select flag in handler + Seq(Row(3)) // select flag from the outer body + ) + verifySqlScriptResult(sqlScript, expected = expected) + } + + test("handler - exit catch-all in the same block") { + val sqlScript = + """ + |BEGIN + | DECLARE OR REPLACE flag INT = -1; + | l1: BEGIN + | DECLARE EXIT HANDLER FOR SQLSTATE '22012' + | BEGIN + | SELECT flag; + | SET VAR flag = 1; + | END; + | SELECT 2; + | SELECT 3; + | l2: BEGIN + | DECLARE EXIT HANDLER FOR SQLEXCEPTION + | BEGIN + | SELECT flag; + 
| SET VAR flag = 2; + | END; + | SELECT 4; + | SELECT 1/0; + | SELECT 5; + | END; + | SELECT 6; + | END; + | SELECT flag; + |END + |""".stripMargin + val expected = Seq( + Seq(Row(2)), // select + Seq(Row(3)), // select + Seq(Row(4)), // select + Seq(Row(-1)), // select flag + Seq(Row(6)), // select + Seq(Row(2)) // select flag from the outer body + ) + verifySqlScriptResult(sqlScript, expected = expected) + } + + test("handler with condition and sqlState with equal string value") { + // This test is intended to verify that the condition and sqlState are not + // treated equally if they have the same string value. Conditions are prioritized + // over sqlStates when choosing most appropriate Error Handler. + val sqlScript1 = + """ + |BEGIN + | DECLARE OR REPLACE flag INT = -1; + | DECLARE `22012` CONDITION FOR SQLSTATE '12345'; + | DECLARE EXIT HANDLER FOR `22012` + | BEGIN + | SET VAR flag = 1; + | END; + | SELECT 1/0; + | SELECT flag; + |END + |""".stripMargin + checkError( + exception = intercept[SparkArithmeticException] { + verifySqlScriptResult(sqlScript1, Seq.empty) + }, + condition = "DIVIDE_BY_ZERO", + parameters = Map("config" -> "\"spark.sql.ansi.enabled\""), + queryContext = Array(ExpectedContext("", "", 174, 176, "1/0"))) + + val sqlScript2 = + """ + |BEGIN + | DECLARE OR REPLACE flag INT = -1; + | BEGIN + | DECLARE `22012` CONDITION FOR SQLSTATE '12345'; + | DECLARE EXIT HANDLER FOR `22012` + | BEGIN + | SET VAR flag = 1; + | END; + | + | DECLARE EXIT HANDLER FOR SQLSTATE '22012' + | BEGIN + | SET VAR flag = 2; + | END; + | + | SELECT 1/0; + | END; + | SELECT flag; + |END + |""".stripMargin + val expected2 = Seq(Seq(Row(2))) // select flag from the outer body + verifySqlScriptResult(sqlScript2, expected = expected2) + + val sqlScript3 = + """ + |BEGIN + | DECLARE OR REPLACE flag INT = -1; + | BEGIN + | DECLARE `22012` CONDITION FOR SQLSTATE '12345'; + | DECLARE EXIT HANDLER FOR `22012`, SQLSTATE '22012' + | BEGIN + | SET VAR flag = 1; + | END; + | + | SELECT 1/0; + | END; + | SELECT flag; + |END + |""".stripMargin + val expected3 = Seq(Seq(Row(1))) // select flag from the outer body + verifySqlScriptResult(sqlScript3, expected = expected3) + } + + test("handler - no appropriate handler is defined") { + val sqlScript = + """ + |BEGIN + | DECLARE EXIT HANDLER FOR DIVIDE_BY_ZERO + | BEGIN + | SELECT 1; -- this will not execute + | END; + | SELECT flag; + |END + |""".stripMargin + checkError( + exception = intercept[AnalysisException] { + verifySqlScriptResult(sqlScript, Seq.empty) + }, + condition = "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + parameters = Map("objectName" -> toSQLId("flag")), + queryContext = Array(ExpectedContext("", "", 112, 115, "flag"))) + } + + test("invalid sqlState in handler declaration") { + val sqlScript = + """ + |BEGIN + | DECLARE EXIT HANDLER FOR SQLSTATE 'X22012' + | BEGIN + | SELECT 1; + | END; + |END + |""".stripMargin + checkError( + exception = intercept[SqlScriptingException] { + verifySqlScriptResult(sqlScript, Seq.empty) + }, + condition = "INVALID_SQLSTATE", + parameters = Map("sqlState" -> "X22012")) + } + // Tests test("multi statement - simple") { withTable("t") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 5f149d15c6e8c..a52f93d4fc805 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -50,7 +50,8 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { // Initialize context so scopes can be entered correctly. val context = new SqlScriptingExecutionContext() val executionPlan = interpreter.buildExecutionPlan(compoundBody, args, context) - context.frames.append(new SqlScriptingExecutionFrame(executionPlan.getTreeIterator)) + context.frames.append(new SqlScriptingExecutionFrame( + executionPlan, SqlScriptingFrameType.SQL_SCRIPT)) executionPlan.enterScope() executionPlan.getTreeIterator.flatMap { diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala index ec65886fb2c98..5cf7cfc58ef2f 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala @@ -214,7 +214,7 @@ trait ThriftServerWithSparkContextSuite extends SharedThriftServer { val sessionHandle = client.openSession(user, "") val infoValue = client.getInfo(sessionHandle, GetInfoType.CLI_ODBC_KEYWORDS) // scalastyle:off line.size.limit - assert(infoValue.getStringValue == "ADD,AFTER,AGGREGATE,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALL,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONSTRAINT,CONTAINS,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFINED,DEFINER,DELETE,DELIMITED,DESC,DESCRIBE,DETERMINISTIC,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DO,DOUBLE,DROP,ELSE,ELSEIF,END,ESCAPE,ESCAPED,EVOLUTION,EXCEPT,EXCHANGE,EXCLUDE,EXECUTE,EXISTS,EXPLAIN,EXPORT,EXTEND,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HAVING,HOUR,HOURS,IDENTIFIER,IDENTITY,IF,IGNORE,ILIKE,IMMEDIATE,IMPORT,IN,INCLUDE,INCREMENT,INDEX,INDEXES,INNER,INPATH,INPUT,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,INVOKER,IS,ITEMS,ITERATE,JOIN,JSON,KEYS,LANGUAGE,LAST,LATERAL,LAZY,LEADING,LEAVE,LEFT,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,LOOP,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MODIFIES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NONE,NOT,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,READS,REAL,RECORDREADER,RECORDWRITER,RECOVER,RECURSIVE,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEAT,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,RETURN,RETURNS,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SECURITY,SELECT,SEMI,SEPARATED,SERDE,SER
DEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,SPECIFIC,SQL,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMEDIFF,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UNTIL,UPDATE,USE,USER,USING,VALUES,VAR,VARCHAR,VARIABLE,VARIANT,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WHILE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE") + assert(infoValue.getStringValue == "ADD,AFTER,AGGREGATE,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALL,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONDITION,CONSTRAINT,CONTAINS,CONTINUE,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFINED,DEFINER,DELETE,DELIMITED,DESC,DESCRIBE,DETERMINISTIC,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DO,DOUBLE,DROP,ELSE,ELSEIF,END,ESCAPE,ESCAPED,EVOLUTION,EXCEPT,EXCHANGE,EXCLUDE,EXECUTE,EXISTS,EXIT,EXPLAIN,EXPORT,EXTEND,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FOUND,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HANDLER,HAVING,HOUR,HOURS,IDENTIFIER,IDENTITY,IF,IGNORE,ILIKE,IMMEDIATE,IMPORT,IN,INCLUDE,INCREMENT,INDEX,INDEXES,INNER,INPATH,INPUT,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,INVOKER,IS,ITEMS,ITERATE,JOIN,JSON,KEYS,LANGUAGE,LAST,LATERAL,LAZY,LEADING,LEAVE,LEFT,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,LOOP,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MODIFIES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NONE,NOT,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,READS,REAL,RECORDREADER,RECORDWRITER,RECOVER,RECURSIVE,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEAT,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,RETURN,RETURNS,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SECURITY,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,SPECIFIC,SQL,SQLEXCEPTION,SQLSTATE,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMEDIFF,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UNTIL,UPDATE,USE,USER,USING,VALUE,VALUES,VAR,VARCHAR,VARIABLE,VARIANT,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WHILE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE") // scalastyle:on 
line.size.limit } }