diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 2245ba086ac30..b956a0fd856a5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -208,6 +208,9 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
  * Refer to this link for the corresponding semantics:
  * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions
  *
+ * The other form of case statements "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END" gets
+ * translated to this form at parsing time (i.e. CASE WHEN a=b THEN c ...).
+ *
  * Note that branches are considered in consecutive pairs (cond, val), and the optional last element
  * is the val for the default catch-all case (if provided). Hence, `branches` consist of at least
  * two elements, and can have an odd or even length.
@@ -274,68 +277,3 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
     firstBranch ++ otherBranches
   }
 }
-
-/**
- * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END". This type
- * of case statements is separated out from the other type mainly due to performance reason: this
- * approach avoids branching (based on whether or not the key is provided) in eval().
- */
-case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends Expression {
-  type EvaluatedType = Any
-  def children = key +: branches
-  def references = children.flatMap(_.references).toSet
-  def dataType = {
-    if (!resolved) {
-      throw new UnresolvedException(this, "cannot resolve due to differing types in some branches")
-    }
-    branches(1).dataType
-  }
-
-  override def nullable = branches.sliding(2, 2).map {
-    case Seq(cond, value) => value.nullable
-    case Seq(elseValue) => elseValue.nullable
-  }.reduce(_ || _)
-
-
-  override lazy val resolved = {
-    lazy val dataTypes = branches.sliding(2, 2).map {
-      case Seq(cond, value) => value.dataType
-      case Seq(elseValue) => elseValue.dataType
-    }.toSeq
-    lazy val dataTypesEqual =
-      if (dataTypes.size <= 1) true else dataTypes.drop(1).map(_ == dataTypes(0)).reduce(_ && _)
-    if (!childrenResolved) false else dataTypesEqual
-  }
-
-  private lazy val branchesArr = branches.toArray
-
-  override def eval(input: Row): Any = {
-    val evaledKey = key.eval(input)
-    val len = branchesArr.length
-    var i = 0
-    // If all branches fail and an elseVal is not provided, the whole statement
-    // defaults to null, according to Hive's semantics.
-    var res: Any = null
-    while (i < len - 1) {
-      if (branchesArr(i).eval(input) == evaledKey) {
-        res = branchesArr(i + 1).eval(input)
-        return res
-      }
-      i += 2
-    }
-    if (i == len - 1) {
-      res = branchesArr(i).eval(input)
-    }
-    res
-  }
-
-  override def toString = {
-    val keyString = key.toString
-    val firstBranch = s"if ($keyString == ${branches(0)}) { ${branches(1)} }"
-    val otherBranches = branches.sliding(2, 2).drop(1).map {
-      case Seq(cond, value) => s" else if ($keyString == $cond) { $value }"
-      case Seq(elseValue) => s" else { $elseValue }"
-    }.mkString
-    firstBranch ++ otherBranches
-  }
-}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index b876832364645..b58e3352c740d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -923,7 +923,14 @@ private[hive] object HiveQl {
     case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) =>
      CaseWhen(branches.map(nodeToExpr))
     case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) =>
-      CaseKeyWhen(nodeToExpr(branches(0)), branches.drop(1).map(nodeToExpr))
+      val transformed = branches.drop(1).sliding(2, 2).map {
+        case Seq(condVal, value) =>
+          // FIXME?: the key will get evaluated multiple times in CaseWhen's eval(). Optimize?
+          Seq(Equals(nodeToExpr(branches(0)), nodeToExpr(condVal)),
+            nodeToExpr(value))
+        case Seq(elseVal) => Seq(nodeToExpr(elseVal))
+      }.toSeq.reduce(_ ++ _)
+      CaseWhen(transformed)
 
     /* Complex datatype manipulation */
     case Token("[", child :: ordinal :: Nil) =>