Skip to content

Commit

Permalink
[SPARK-22705][SQL] Case, Coalesce, and In use less global variables
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

This PR accomplishes the following two items.

1. Reduce # of global variables from two to one for generated code of `Case` and `Coalesce` and remove global variables for generated code of `In`.
2. Make lifetime of global variable local within an operation

Item 1. reduces # of constant pool entries in a Java class. Item 2. ensures that an variable is not passed to arguments in a method split by `CodegenContext.splitExpressions()`, which is addressed by apache#19865.

## How was this patch tested?

Added new tests into `PredicateSuite`, `NullExpressionsSuite`, and `ConditionalExpressionSuite`.

Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com>

Closes apache#19901 from kiszk/SPARK-22705.
  • Loading branch information
kiszk authored and cloud-fan committed Dec 7, 2017
1 parent e103adf commit ea2fbf4
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,18 @@ case class CaseWhen(
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
// This variable represents whether the first successful condition is met or not.
// It is initialized to `false` and it is set to `true` when the first condition which
// evaluates to `true` is met and therefore is not needed to go on anymore on the computation
// of the following conditions.
val conditionMet = ctx.freshName("caseWhenConditionMet")
ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
ctx.addMutableState(ctx.javaType(dataType), ev.value)
// This variable holds the state of the result:
// -1 means the condition is not met yet and the result is unknown.
val NOT_MATCHED = -1
// 0 means the condition is met and result is not null.
val HAS_NONNULL = 0
// 1 means the condition is met and result is null.
val HAS_NULL = 1
// It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`,
// We won't go on anymore on the computation.
val resultState = ctx.freshName("caseWhenResultState")
val tmpResult = ctx.freshName("caseWhenTmpResult")
ctx.addMutableState(ctx.javaType(dataType), tmpResult)

// these blocks are meant to be inside a
// do {
Expand All @@ -200,9 +205,8 @@ case class CaseWhen(
|${cond.code}
|if (!${cond.isNull} && ${cond.value}) {
| ${res.code}
| ${ev.isNull} = ${res.isNull};
| ${ev.value} = ${res.value};
| $conditionMet = true;
| $resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL);
| $tmpResult = ${res.value};
| continue;
|}
""".stripMargin
Expand All @@ -212,59 +216,63 @@ case class CaseWhen(
val res = elseExpr.genCode(ctx)
s"""
|${res.code}
|${ev.isNull} = ${res.isNull};
|${ev.value} = ${res.value};
|$resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL);
|$tmpResult = ${res.value};
""".stripMargin
}

val allConditions = cases ++ elseCode

// This generates code like:
// conditionMet = caseWhen_1(i);
// if(conditionMet) {
// caseWhenResultState = caseWhen_1(i);
// if(caseWhenResultState != -1) {
// continue;
// }
// conditionMet = caseWhen_2(i);
// if(conditionMet) {
// caseWhenResultState = caseWhen_2(i);
// if(caseWhenResultState != -1) {
// continue;
// }
// ...
// and the declared methods are:
// private boolean caseWhen_1234() {
// boolean conditionMet = false;
// private byte caseWhen_1234() {
// byte caseWhenResultState = -1;
// do {
// // here the evaluation of the conditions
// } while (false);
// return conditionMet;
// return caseWhenResultState;
// }
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = allConditions,
funcName = "caseWhen",
returnType = ctx.JAVA_BOOLEAN,
returnType = ctx.JAVA_BYTE,
makeSplitFunction = func =>
s"""
|${ctx.JAVA_BOOLEAN} $conditionMet = false;
|${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED;
|do {
| $func
|} while (false);
|return $conditionMet;
|return $resultState;
""".stripMargin,
foldFunctions = _.map { funcCall =>
s"""
|$conditionMet = $funcCall;
|if ($conditionMet) {
|$resultState = $funcCall;
|if ($resultState != $NOT_MATCHED) {
| continue;
|}
""".stripMargin
}.mkString)

ev.copy(code = s"""
${ev.isNull} = true;
${ev.value} = ${ctx.defaultValue(dataType)};
${ctx.JAVA_BOOLEAN} $conditionMet = false;
do {
$codes
} while (false);""")
ev.copy(code =
s"""
|${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED;
|$tmpResult = ${ctx.defaultValue(dataType)};
|do {
| $codes
|} while (false);
|// TRUE if any condition is met and the result is null, or no any condition is met.
|final boolean ${ev.isNull} = ($resultState != $HAS_NONNULL);
|final ${ctx.javaType(dataType)} ${ev.value} = $tmpResult;
""".stripMargin)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,35 +72,39 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
ctx.addMutableState(ctx.javaType(dataType), ev.value)
val tmpIsNull = ctx.freshName("coalesceTmpIsNull")
ctx.addMutableState(ctx.JAVA_BOOLEAN, tmpIsNull)

// all the evals are meant to be in a do { ... } while (false); loop
val evals = children.map { e =>
val eval = e.genCode(ctx)
s"""
|${eval.code}
|if (!${eval.isNull}) {
| ${ev.isNull} = false;
| $tmpIsNull = false;
| ${ev.value} = ${eval.value};
| continue;
|}
""".stripMargin
}

val resultType = ctx.javaType(dataType)
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = evals,
funcName = "coalesce",
returnType = resultType,
makeSplitFunction = func =>
s"""
|$resultType ${ev.value} = ${ctx.defaultValue(dataType)};
|do {
| $func
|} while (false);
|return ${ev.value};
""".stripMargin,
foldFunctions = _.map { funcCall =>
s"""
|$funcCall;
|if (!${ev.isNull}) {
|${ev.value} = $funcCall;
|if (!$tmpIsNull) {
| continue;
|}
""".stripMargin
Expand All @@ -109,11 +113,12 @@ case class Coalesce(children: Seq[Expression]) extends Expression {

ev.copy(code =
s"""
|${ev.isNull} = true;
|${ev.value} = ${ctx.defaultValue(dataType)};
|$tmpIsNull = true;
|$resultType ${ev.value} = ${ctx.defaultValue(dataType)};
|do {
| $codes
|} while (false);
|final boolean ${ev.isNull} = $tmpIsNull;
""".stripMargin)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,37 +237,44 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
val javaDataType = ctx.javaType(value.dataType)
val valueGen = value.genCode(ctx)
val listGen = list.map(_.genCode(ctx))
ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.value)
ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
// inTmpResult has 3 possible values:
// -1 means no matches found and there is at least one value in the list evaluated to null
val HAS_NULL = -1
// 0 means no matches found and all values in the list are not null
val NOT_MATCHED = 0
// 1 means one value in the list is matched
val MATCHED = 1
val tmpResult = ctx.freshName("inTmpResult")
val valueArg = ctx.freshName("valueArg")
// All the blocks are meant to be inside a do { ... } while (false); loop.
// The evaluation of variables can be stopped when we find a matching value.
val listCode = listGen.map(x =>
s"""
|${x.code}
|if (${x.isNull}) {
| ${ev.isNull} = true;
| $tmpResult = $HAS_NULL; // ${ev.isNull} = true;
|} else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) {
| ${ev.isNull} = false;
| ${ev.value} = true;
| $tmpResult = $MATCHED; // ${ev.isNull} = false; ${ev.value} = true;
| continue;
|}
""".stripMargin)

val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = listCode,
funcName = "valueIn",
extraArguments = (javaDataType, valueArg) :: Nil,
extraArguments = (javaDataType, valueArg) :: (ctx.JAVA_BYTE, tmpResult) :: Nil,
returnType = ctx.JAVA_BYTE,
makeSplitFunction = body =>
s"""
|do {
| $body
|} while (false);
|return $tmpResult;
""".stripMargin,
foldFunctions = _.map { funcCall =>
s"""
|$funcCall;
|if (${ev.value}) {
|$tmpResult = $funcCall;
|if ($tmpResult == $MATCHED) {
| continue;
|}
""".stripMargin
Expand All @@ -276,14 +283,16 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
ev.copy(code =
s"""
|${valueGen.code}
|${ev.value} = false;
|${ev.isNull} = ${valueGen.isNull};
|if (!${ev.isNull}) {
|byte $tmpResult = $HAS_NULL;
|if (!${valueGen.isNull}) {
| $tmpResult = 0;
| $javaDataType $valueArg = ${valueGen.value};
| do {
| $codes
| } while (false);
|}
|final boolean ${ev.isNull} = ($tmpResult == $HAS_NULL);
|final boolean ${ev.value} = ($tmpResult == $MATCHED);
""".stripMargin)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.types._

class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
Expand Down Expand Up @@ -145,4 +146,10 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
IndexedSeq((Literal(12) === Literal(1), Literal(42)),
(Literal(12) === Literal(42), Literal(1))))
}

test("SPARK-22705: case when should use less global variables") {
val ctx = new CodegenContext()
CaseWhen(Seq((Literal.create(false, BooleanType), Literal(1))), Literal(-1)).genCode(ctx)
assert(ctx.mutableStates.size == 1)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -155,6 +156,12 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Coalesce(inputs), "x_1")
}

test("SPARK-22705: Coalesce should use less global variables") {
val ctx = new CodegenContext()
Coalesce(Seq(Literal("a"), Literal("b"))).genCode(ctx)
assert(ctx.mutableStates.size == 1)
}

test("AtLeastNNonNulls should not throw 64kb exception") {
val inputs = (1 to 4000).map(x => Literal(s"x_$x"))
checkEvaluation(AtLeastNNonNulls(1, inputs), true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(In(Literal(1.0D), sets), true)
}

test("SPARK-22705: In should use less global variables") {
val ctx = new CodegenContext()
In(Literal(1.0D), Seq(Literal(1.0D), Literal(2.0D))).genCode(ctx)
assert(ctx.mutableStates.isEmpty)
}

test("INSET") {
val hS = HashSet[Any]() + 1 + 2
val nS = HashSet[Any]() + 1 + 2 + null
Expand Down

0 comments on commit ea2fbf4

Please sign in to comment.