Skip to content

Commit

Permalink
Translate CaseKeyWhen to CaseWhen at parsing time.
Browse files Browse the repository at this point in the history
  • Loading branch information
concretevitamin committed Jun 12, 2014
1 parent 47d406a commit 7d2b7e2
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,9 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
* Refer to this link for the corresponding semantics:
* https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions
*
* The other form of case statements "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END" gets
* translated to this form at parsing time (i.e. CASE WHEN a=b THEN c ...).
*
* Note that branches are considered in consecutive pairs (cond, val), and the optional last element
* is the val for the default catch-all case (if provided). Hence, `branches` consist of at least
* two elements, and can have an odd or even length.
Expand Down Expand Up @@ -274,68 +277,3 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
firstBranch ++ otherBranches
}
}

/**
 * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END". This type
 * of case statements is separated out from the other type mainly due to performance reason: this
 * approach avoids branching (based on whether or not the key is provided) in eval().
 *
 * `branches` is read in consecutive pairs (whenValue, thenValue); an unpaired trailing element,
 * if present, is the value of the ELSE clause. When no WHEN value equals the key and no ELSE
 * value is present, the expression evaluates to null.
 *
 * @param key the expression every WHEN value is compared against
 * @param branches the flattened WHEN/THEN pairs followed by the optional ELSE value
 */
case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends Expression {
  type EvaluatedType = Any
  def children = key +: branches
  def references = children.flatMap(_.references).toSet

  /** The type of the first THEN value; only meaningful once all result types agree. */
  def dataType = {
    if (!resolved) {
      throw new UnresolvedException(this, "cannot resolve due to differing types in some branches")
    }
    branches(1).dataType
  }

  /**
   * True if any result value may be null, or if there is no ELSE value: in the latter case
   * eval() returns null whenever no WHEN value matches, even if every THEN value is non-nullable.
   */
  override def nullable = {
    // An odd number of elements means the last one is the unpaired ELSE value.
    val hasElseValue = branches.length % 2 == 1
    !hasElseValue || branches.sliding(2, 2).exists {
      case Seq(_, value) => value.nullable
      case Seq(elseValue) => elseValue.nullable
    }
  }

  /**
   * Resolved once all children are resolved and every result value (each THEN value and the
   * optional ELSE value) has the same data type.
   */
  override lazy val resolved = childrenResolved && {
    // Only inspected after childrenResolved holds, thanks to the short-circuiting &&.
    val valueTypes = branches.sliding(2, 2).map {
      case Seq(_, value) => value.dataType
      case Seq(elseValue) => elseValue.dataType
    }.toList
    valueTypes.drop(1).forall(_ == valueTypes.head)
  }

  // Array copy of the branches for fast indexed access in eval().
  private lazy val branchesArr = branches.toArray

  /**
   * Evaluates the key once, then returns the THEN value of the first pair whose WHEN value
   * equals the key. Falls back to the ELSE value if present, otherwise to null.
   */
  override def eval(input: Row): Any = {
    val evaledKey = key.eval(input)
    val len = branchesArr.length
    var i = 0
    while (i < len - 1) {
      // NOTE(review): `==` treats a null key as equal to a null WHEN value; confirm that this
      // is the intended behaviour with respect to Hive's semantics.
      if (branchesArr(i).eval(input) == evaledKey) {
        return branchesArr(i + 1).eval(input)
      }
      i += 2
    }
    // If all branches fail and an elseVal is not provided, the whole statement
    // defaults to null, according to Hive's semantics.
    if (i == len - 1) branchesArr(i).eval(input) else null
  }

  override def toString = {
    val keyString = key.toString
    val firstBranch = s"if ($keyString == ${branches(0)}) { ${branches(1)} }"
    val otherBranches = branches.sliding(2, 2).drop(1).map {
      case Seq(cond, value) => s" else if ($keyString == $cond) { $value }"
      case Seq(elseValue) => s" else { $elseValue }"
    }.mkString
    firstBranch ++ otherBranches
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -923,7 +923,14 @@ private[hive] object HiveQl {
case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) =>
CaseWhen(branches.map(nodeToExpr))
case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) =>
CaseKeyWhen(nodeToExpr(branches(0)), branches.drop(1).map(nodeToExpr))
val transformed = branches.drop(1).sliding(2, 2).map {
case Seq(condVal, value) =>
// FIXME?: the key will get evaluated multiple times in CaseWhen's eval(). Optimize?
Seq(Equals(nodeToExpr(branches(0)), nodeToExpr(condVal)),
nodeToExpr(value))
case Seq(elseVal) => Seq(nodeToExpr(elseVal))
}.toSeq.reduce(_ ++ _)
CaseWhen(transformed)

/* Complex datatype manipulation */
case Token("[", child :: ordinal :: Nil) =>
Expand Down

0 comments on commit 7d2b7e2

Please sign in to comment.