Skip to content

Commit

Permalink
[SPARK-12791][SQL] Simplify CaseWhen by breaking "branches" into "con…
Browse files Browse the repository at this point in the history
…ditions" and "values"

This pull request rewrites the CaseWhen expression to break the single, monolithic "branches" field into a sequence of tuples (Seq[(condition, value)]) and an explicit, optional elseValue field.

Prior to this pull request, each even position in "branches" represented the condition of a branch, and each odd position represented the value of that branch; when the sequence had an odd length, its last element was the else value. Using this layout has been pretty confusing, requiring a lot of sliding-window or grouped(2) calls.

Author: Reynold Xin <rxin@databricks.com>

Closes #10734 from rxin/simplify-case.
  • Loading branch information
rxin committed Jan 13, 2016
1 parent c2ea79f commit cbbcd8e
Show file tree
Hide file tree
Showing 12 changed files with 156 additions and 138 deletions.
24 changes: 12 additions & 12 deletions python/pyspark/sql/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,12 +368,12 @@ def when(self, condition, value):
>>> from pyspark.sql import functions as F
>>> df.select(df.name, F.when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).show()
+-----+--------------------------------------------------------+
| name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0|
+-----+--------------------------------------------------------+
|Alice| -1|
| Bob| 1|
+-----+--------------------------------------------------------+
+-----+------------------------------------------------------------+
| name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0 END|
+-----+------------------------------------------------------------+
|Alice| -1|
| Bob| 1|
+-----+------------------------------------------------------------+
"""
if not isinstance(condition, Column):
raise TypeError("condition should be a Column")
Expand All @@ -393,12 +393,12 @@ def otherwise(self, value):
>>> from pyspark.sql import functions as F
>>> df.select(df.name, F.when(df.age > 3, 1).otherwise(0)).show()
+-----+---------------------------------+
| name|CASE WHEN (age > 3) THEN 1 ELSE 0|
+-----+---------------------------------+
|Alice| 0|
| Bob| 1|
+-----+---------------------------------+
+-----+-------------------------------------+
| name|CASE WHEN (age > 3) THEN 1 ELSE 0 END|
+-----+-------------------------------------+
|Alice| 0|
| Bob| 1|
+-----+-------------------------------------+
"""
v = value._jc if isinstance(value, Column) else value
jc = self._jc.otherwise(v)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C

/* Case statements */
case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) =>
CaseWhen(branches.map(nodeToExpr))
CaseWhen.createFromParser(branches.map(nodeToExpr))
case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) =>
val keyExpr = nodeToExpr(branches.head)
CaseKeyWhen(keyExpr, branches.drop(1).map(nodeToExpr))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,8 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser {
throw new AnalysisException(s"invalid function approximate($s) $udfName")
}
}
| CASE ~> whenThenElse ^^ CaseWhen
| CASE ~> whenThenElse ^^
{ case branches => CaseWhen.createFromParser(branches) }
| CASE ~> expression ~ whenThenElse ^^
{ case keyPart ~ branches => CaseKeyWhen(keyPart, branches) }
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -621,14 +621,24 @@ object HiveTypeCoercion {
case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual =>
val maybeCommonType = findWiderCommonType(c.valueTypes)
maybeCommonType.map { commonType =>
val castedBranches = c.branches.grouped(2).map {
case Seq(when, value) if value.dataType != commonType =>
Seq(when, Cast(value, commonType))
case Seq(elseVal) if elseVal.dataType != commonType =>
Seq(Cast(elseVal, commonType))
case other => other
}.reduce(_ ++ _)
CaseWhen(castedBranches)
var changed = false
val newBranches = c.branches.map { case (condition, value) =>
if (value.dataType.sameType(commonType)) {
(condition, value)
} else {
changed = true
(condition, Cast(value, commonType))
}
}
val newElseValue = c.elseValue.map { value =>
if (value.dataType.sameType(commonType)) {
value
} else {
changed = true
Cast(value, commonType)
}
}
if (changed) CaseWhen(newBranches, newElseValue) else c
}.getOrElse(c)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,44 +81,39 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
/**
* Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
* When a = true, returns b; when c = true, returns d; else returns e.
*
* @param branches seq of (branch condition, branch value)
* @param elseValue optional value for the else branch
*/
case class CaseWhen(branches: Seq[Expression]) extends Expression {

// Use private[this] Array to speed up evaluation.
@transient private[this] lazy val branchesArr = branches.toArray

override def children: Seq[Expression] = branches

@transient lazy val whenList =
branches.sliding(2, 2).collect { case Seq(whenExpr, _) => whenExpr }.toSeq

@transient lazy val thenList =
branches.sliding(2, 2).collect { case Seq(_, thenExpr) => thenExpr }.toSeq
case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[Expression] = None)
extends Expression {

val elseValue = if (branches.length % 2 == 0) None else Option(branches.last)
override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue

// both then and else expressions should be considered.
def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType)
def valueTypes: Seq[DataType] = branches.map(_._2.dataType) ++ elseValue.map(_.dataType)

def valueTypesEqual: Boolean = valueTypes.size <= 1 || valueTypes.sliding(2, 1).forall {
case Seq(dt1, dt2) => dt1.sameType(dt2)
}

override def dataType: DataType = thenList.head.dataType
override def dataType: DataType = branches.head._2.dataType

override def nullable: Boolean = {
// If no value is nullable and no elseValue is provided, the whole statement defaults to null.
thenList.exists(_.nullable) || elseValue.map(_.nullable).getOrElse(true)
// Result is nullable if any branch value is nullable, if the else value is nullable, or if
// there is no else value at all (a row matching no branch then evaluates to null).
branches.exists(_._2.nullable) || elseValue.map(_.nullable).getOrElse(true)
}

override def checkInputDataTypes(): TypeCheckResult = {
// Make sure all branch conditions are boolean types.
if (valueTypesEqual) {
if (whenList.forall(_.dataType == BooleanType)) {
if (branches.forall(_._1.dataType == BooleanType)) {
TypeCheckResult.TypeCheckSuccess
} else {
val index = whenList.indexWhere(_.dataType != BooleanType)
val index = branches.indexWhere(_._1.dataType != BooleanType)
TypeCheckResult.TypeCheckFailure(
s"WHEN expressions in CaseWhen should all be boolean type, " +
s"but the ${index + 1}th when expression's type is ${whenList(index)}")
s"but the ${index + 1}th when expression's type is ${branches(index)._1}")
}
} else {
TypeCheckResult.TypeCheckFailure(
Expand All @@ -127,31 +122,26 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
}

override def eval(input: InternalRow): Any = {
// Written in imperative fashion for performance considerations
val len = branchesArr.length
var i = 0
// If all branches fail and an elseVal is not provided, the whole statement
// defaults to null, according to Hive's semantics.
while (i < len - 1) {
if (branchesArr(i).eval(input) == true) {
return branchesArr(i + 1).eval(input)
while (i < branches.size) {
if (java.lang.Boolean.TRUE.equals(branches(i)._1.eval(input))) {
return branches(i)._2.eval(input)
}
i += 2
i += 1
}
var res: Any = null
if (i == len - 1) {
res = branchesArr(i).eval(input)
if (elseValue.isDefined) {
return elseValue.get.eval(input)
} else {
return null
}
return res
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val len = branchesArr.length
val got = ctx.freshName("got")

val cases = (0 until len/2).map { i =>
val cond = branchesArr(i * 2).gen(ctx)
val res = branchesArr(i * 2 + 1).gen(ctx)
val cases = branches.map { case (condition, value) =>
val cond = condition.gen(ctx)
val res = value.gen(ctx)
s"""
if (!$got) {
${cond.code}
Expand All @@ -165,50 +155,62 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
"""
}.mkString("\n")

val other = if (len % 2 == 1) {
val res = branchesArr(len - 1).gen(ctx)
s"""
val elseCase = {
if (elseValue.isDefined) {
val res = elseValue.get.gen(ctx)
s"""
if (!$got) {
${res.code}
${ev.isNull} = ${res.isNull};
${ev.value} = ${res.value};
}
"""
} else {
""
"""
} else {
""
}
}

s"""
boolean $got = false;
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
$cases
$other
$elseCase
"""
}

override def toString: String = {
"CASE" + branches.sliding(2, 2).map {
case Seq(cond, value) => s" WHEN $cond THEN $value"
case Seq(elseValue) => s" ELSE $elseValue"
}.mkString
val cases = branches.map { case (c, v) => s" WHEN $c THEN $v" }.mkString
val elseCase = elseValue.map(" ELSE " + _).getOrElse("")
"CASE" + cases + elseCase + " END"
}

override def sql: String = {
val branchesSQL = branches.map(_.sql)
val (cases, maybeElse) = if (branches.length % 2 == 0) {
(branchesSQL, None)
} else {
(branchesSQL.init, Some(branchesSQL.last))
}
val cases = branches.map { case (c, v) => s" WHEN ${c.sql} THEN ${v.sql}" }.mkString
val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("")
"CASE" + cases + elseCase + " END"
}
}

val head = s"CASE "
val tail = maybeElse.map(e => s" ELSE $e").getOrElse("") + " END"
val body = cases.grouped(2).map {
case Seq(whenExpr, thenExpr) => s"WHEN $whenExpr THEN $thenExpr"
}.mkString(" ")
/** Factory methods for CaseWhen. */
object CaseWhen {

head + body + tail
/**
 * Builds a [[CaseWhen]] from an else value that is passed directly instead of wrapped in an
 * Option.
 *
 * @param branches seq of (branch condition, branch value)
 * @param elseValue value for the else branch; `null` is treated as "no else branch" because
 *                  `Option(null)` is `None`
 */
def apply(branches: Seq[(Expression, Expression)], elseValue: Expression): CaseWhen = {
  val maybeElse: Option[Expression] = Option(elseValue)
  new CaseWhen(branches, maybeElse)
}

/**
 * A factory method to facilitate the creation of this expression when used in parsers.
 *
 * @param branches Expressions at even position are the branch conditions, and expressions at odd
 *                 position are branch values. If the number of expressions is odd, the last
 *                 expression is the else value.
 */
def createFromParser(branches: Seq[Expression]): CaseWhen = {
  // Match with the Seq extractor rather than `::`: the chunks produced by `grouped` are built
  // by the input collection's own builder, so `::` patterns would throw a MatchError for any
  // Seq that is not a List. A trailing single-element chunk (the else value) is skipped here
  // and picked up separately below.
  val cases = branches.grouped(2).collect {
    case Seq(cond, value) => (cond, value)
  }.toArray.toSeq // force materialization to make the seq serializable
  val elseValue = if (branches.size % 2 == 1) Some(branches.last) else None
  CaseWhen(cases, elseValue)
}
}

Expand All @@ -218,17 +220,12 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
*/
object CaseKeyWhen {
def apply(key: Expression, branches: Seq[Expression]): CaseWhen = {
val newBranches = branches.zipWithIndex.map { case (expr, i) =>
if (i % 2 == 0 && i != branches.size - 1) {
// If this expression is at even position, then it is either a branch condition, or
// the very last value that is the "else value". The "i != branches.size - 1" makes
// sure we are not adding an EqualTo to the "else value".
EqualTo(key, expr)
} else {
expr
}
}
CaseWhen(newBranches)
val cases = branches.grouped(2).flatMap {
case cond :: value :: Nil => Some((EqualTo(key, cond), value))
case value :: Nil => None
}.toArray.toSeq // force materialization to make the seq serializable
val elseValue = if (branches.size % 2 == 1) Some(branches.last) else None
CaseWhen(cases, elseValue)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
} else {
arg
}
case tuple @ (arg1: TreeNode[_], arg2: TreeNode[_]) =>
val newChild1 = nextOperation(arg1.asInstanceOf[BaseType], rule)
val newChild2 = nextOperation(arg2.asInstanceOf[BaseType], rule)
if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
changed = true
(newChild1, newChild2)
} else {
tuple
}
case other => other
}
case nonChild: AnyRef => nonChild
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ class AnalysisSuite extends AnalysisTest {

test("SPARK-12102: Ignore nullablity when comparing two sides of case") {
val relation = LocalRelation('a.struct('x.int), 'b.struct('x.int.withNullability(false)))
val plan = relation.select(CaseWhen(Seq(Literal(true), 'a, 'b)).as("val"))
val plan = relation.select(CaseWhen(Seq((Literal(true), 'a.attr)), 'b).as("val"))
assertAnalysisSuccess(plan)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField))

assertError(
CaseWhen(Seq('booleanField, 'intField, 'booleanField, 'mapField)),
CaseWhen(Seq(('booleanField.attr, 'intField.attr), ('booleanField.attr, 'mapField.attr))),
"THEN and ELSE expressions should all be same type or coercible to a common type")
assertError(
CaseKeyWhen('intField, Seq('intField, 'stringField, 'intField, 'mapField)),
"THEN and ELSE expressions should all be same type or coercible to a common type")
assertError(
CaseWhen(Seq('booleanField, 'intField, 'intField, 'intField)),
CaseWhen(Seq(('booleanField.attr, 'intField.attr), ('intField.attr, 'intField.attr))),
"WHEN expressions in CaseWhen should all be boolean type")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,15 +308,14 @@ class HiveTypeCoercionSuite extends PlanTest {
CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a")))
)
ruleTest(HiveTypeCoercion.CaseWhenCoercion,
CaseWhen(Seq(Literal(true), Literal(1.2), Literal.create(1, DecimalType(7, 2)))),
CaseWhen(Seq(
Literal(true), Literal(1.2), Cast(Literal.create(1, DecimalType(7, 2)), DoubleType)))
CaseWhen(Seq((Literal(true), Literal(1.2))), Literal.create(1, DecimalType(7, 2))),
CaseWhen(Seq((Literal(true), Literal(1.2))),
Cast(Literal.create(1, DecimalType(7, 2)), DoubleType))
)
ruleTest(HiveTypeCoercion.CaseWhenCoercion,
CaseWhen(Seq(Literal(true), Literal(100L), Literal.create(1, DecimalType(7, 2)))),
CaseWhen(Seq(
Literal(true), Cast(Literal(100L), DecimalType(22, 2)),
Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2))))
CaseWhen(Seq((Literal(true), Literal(100L))), Literal.create(1, DecimalType(7, 2))),
CaseWhen(Seq((Literal(true), Cast(Literal(100L), DecimalType(22, 2)))),
Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2)))
)
}

Expand Down Expand Up @@ -452,7 +451,7 @@ class HiveTypeCoercionSuite extends PlanTest {
val expectedTypes = Seq(DecimalType(10, 5), DecimalType(10, 5), DecimalType(15, 5),
DecimalType(25, 5), DoubleType, DoubleType)

rightTypes.zip(expectedTypes).map { case (rType, expectedType) =>
rightTypes.zip(expectedTypes).foreach { case (rType, expectedType) =>
val plan2 = LocalRelation(
AttributeReference("r", rType)())

Expand Down
Loading

0 comments on commit cbbcd8e

Please sign in to comment.