From 593d61704844989e5d2e61d39efe4d7b7a0e1d1a Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 3 Jun 2015 15:27:44 -0700 Subject: [PATCH 01/18] pushing codegen into Expression --- .../catalyst/expressions/BoundAttribute.scala | 9 + .../spark/sql/catalyst/expressions/Cast.scala | 37 + .../sql/catalyst/expressions/Expression.scala | 94 +++ .../sql/catalyst/expressions/arithmetic.scala | 127 ++++ .../expressions/codegen/CodeGenerator.scala | 686 +++--------------- .../codegen/GenerateMutableProjection.scala | 4 +- .../codegen/GenerateOrdering.scala | 4 +- .../codegen/GeneratePredicate.scala | 2 +- .../codegen/GenerateProjection.scala | 18 +- .../expressions/decimalFunctions.scala | 25 + .../sql/catalyst/expressions/literals.scala | 38 + .../catalyst/expressions/nullFunctions.scala | 36 + .../sql/catalyst/expressions/predicates.scala | 111 ++- .../spark/sql/catalyst/expressions/sets.scala | 42 ++ .../GeneratedEvaluationSuite.scala | 2 +- .../GeneratedMutableEvaluationSuite.scala | 2 +- 16 files changed, 650 insertions(+), 587 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 1ffc95c676f6f..1055be6e9d273 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.Logging import org.apache.spark.sql.catalyst.errors.attachTree +import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext} import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.trees @@ -41,6 +42,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def qualifiers: Seq[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException + + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + s""" + final boolean ${ev.nullTerm} = i.isNullAt($ordinal); + final ${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = ${ev.nullTerm} ? 
+ ${ctx.defaultPrimitive(dataType)} : (${ctx.getColumn(dataType, ordinal)}); + """ + } } object BindReferences extends Logging { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 21adac144112e..a986844d18e8f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext} import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ @@ -433,6 +434,42 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w val evaluated = child.eval(input) if (evaluated == null) null else cast(evaluated) } + + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = this match { + + case Cast(child @ BinaryType(), StringType) => + castOrNull (ctx, ev, c => + s"new org.apache.spark.sql.types.UTF8String().set($c)", + StringType) + + case Cast(child @ DateType(), StringType) => + castOrNull(ctx, ev, c => + s"""new org.apache.spark.sql.types.UTF8String().set( + org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""", + StringType) + + case Cast(child @ BooleanType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => + castOrNull(ctx, ev, c => s"(${ctx.primitiveForType(dt)})($c?1:0)", dt) + + case Cast(child @ DecimalType(), IntegerType) => + castOrNull(ctx, ev, c => s"($c).toInt()", IntegerType) + + case Cast(child @ DecimalType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => + castOrNull(ctx, ev, c => s"($c).to${ctx.termForType(dt)}()", dt) + + case Cast(child @ NumericType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => + castOrNull(ctx, ev, c => s"(${ctx.primitiveForType(dt)})($c)", dt) + + // Special handling required for timestamps in hive test cases since the toString function + // does not match the expected output. 
+ case Cast(e, StringType) if e.dataType != TimestampType => + castOrNull(ctx, ev, c => + s"new org.apache.spark.sql.types.UTF8String().set(String.valueOf($c))", + StringType) + + case other => + super.genSource(ctx, ev) + } } object Cast { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 3cf851aec15ea..f66f8f9ff105e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext} import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types._ @@ -51,6 +52,51 @@ abstract class Expression extends TreeNode[Expression] { /** Returns the result of evaluating this expression on a given input Row */ def eval(input: Row = null): Any + /** + * Returns an [[EvaluatedExpression]], which contains Java source code that + * can be used to generate the result of evaluating the expression on an input row. + * @param ctx a [[CodeGenContext]] + */ + def gen(ctx: CodeGenContext): EvaluatedExpression = { + val nullTerm = ctx.freshName("nullTerm") + val primitiveTerm = ctx.freshName("primitiveTerm") + val objectTerm = ctx.freshName("objectTerm") + val ve = EvaluatedExpression("", nullTerm, primitiveTerm, objectTerm) + ve.code = genSource(ctx, ve) + + // Only inject debugging code if debugging is turned on. + // val debugCode = + // if (debugLogging) { + // val localLogger = log + // val localLoggerTree = reify { localLogger } + // s""" + // $localLoggerTree.debug( + // ${this.toString} + ": " + (if (${ev.nullTerm}) "null" else ${ev.primitiveTerm}.toString)) + // """ + // } else { + // "" + // } + + ve + } + + /** + * Returns Java source code for this expression + */ + def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + val e = this.asInstanceOf[Expression] + ctx.references += e + s""" + /* expression: ${this} */ + Object ${ev.objectTerm} = expressions[${ctx.references.size - 1}].eval(i); + boolean ${ev.nullTerm} = ${ev.objectTerm} == null; + ${ctx.primitiveForType(e.dataType)} ${ev.primitiveTerm} = + ${ctx.defaultPrimitive(e.dataType)}; + if (!${ev.nullTerm}) ${ev.primitiveTerm} = + (${ctx.termForType(e.dataType)})${ev.objectTerm}; + """ + } + /** * Returns `true` if this expression and all its children have been resolved to a specific schema * and input data types checking passed, and `false` if it still contains any unresolved @@ -116,6 +162,41 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express override def nullable: Boolean = left.nullable || right.nullable override def toString: String = s"($left $symbol $right)" + + + /** + * Short hand for generating binary evaluation code, which depends on two sub-evaluations of + * the same type. If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + * + * @param f a function from two primitive term names to a tree that evaluates them. 
+ */ + def evaluate(ctx: CodeGenContext, + ev: EvaluatedExpression, + f: (String, String) => String): String = + evaluateAs(left.dataType)(ctx, ev, f) + + def evaluateAs(resultType: DataType)(ctx: CodeGenContext, + ev: EvaluatedExpression, + f: (String, String) => String): String = { + // TODO: Right now some timestamp tests fail if we enforce this... + if (left.dataType != right.dataType) { + // log.warn(s"${left.dataType} != ${right.dataType}") + } + + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm) + + eval1.code + eval2.code + + s""" + boolean ${ev.nullTerm} = ${eval1.nullTerm} || ${eval2.nullTerm}; + ${ctx.primitiveForType(resultType)} ${ev.primitiveTerm} = ${ctx.defaultPrimitive(resultType)}; + if(!${ev.nullTerm}) { + ${ev.primitiveTerm} = (${ctx.primitiveForType(resultType)})($resultCode); + } + """ + } } abstract class LeafExpression extends Expression with trees.LeafNode[Expression] { @@ -124,6 +205,19 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { self: Product => + def castOrNull(ctx: CodeGenContext, + ev: EvaluatedExpression, + f: String => String, dataType: DataType): String = { + val eval = child.gen(ctx) + eval.code + + s""" + boolean ${ev.nullTerm} = ${eval.nullTerm}; + ${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultPrimitive(dataType)}; + if (!${ev.nullTerm}) { + ${ev.primitiveTerm} = ${f(eval.primitiveTerm)}; + } + """ + } } // TODO Semantically we probably not need GroupExpression diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 2ac53f8f6613f..4320fbf51bd6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -86,6 +87,8 @@ case class Abs(child: Expression) extends UnaryArithmetic { abstract class BinaryArithmetic extends BinaryExpression { self: Product => + def decimalMethod: String = "" + override def dataType: DataType = left.dataType override def checkInputDataTypes(): TypeCheckResult = { @@ -114,12 +117,21 @@ abstract class BinaryArithmetic extends BinaryExpression { } } + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + if (left.dataType.isInstanceOf[DecimalType]) { + evaluate(ctx, ev, { case (eval1, eval2) => s"$eval1.$decimalMethod($eval2)" } ) + } else { + evaluate(ctx, ev, { case (eval1, eval2) => s"$eval1 $symbol $eval2" } ) + } + } + protected def evalInternal(evalE1: Any, evalE2: Any): Any = sys.error(s"BinaryArithmetics must override either eval or evalInternal") } case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "+" + override def decimalMethod: String = "$plus" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) @@ -134,6 +146,7 @@ case class Add(left: Expression, right: Expression) extends 
BinaryArithmetic { case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "-" + override def decimalMethod: String = "$minus" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) @@ -148,6 +161,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "*" + override def decimalMethod: String = "$times" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) @@ -162,6 +176,8 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "/" + override def decimalMethod: String = "$divide" + override def nullable: Boolean = true override lazy val resolved = @@ -188,10 +204,38 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } } } + + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val test = if (left.dataType.isInstanceOf[DecimalType]) { + s"${eval2.primitiveTerm}.isZero()" + } else { + s"${eval2.primitiveTerm} == 0" + } + val method = if (left.dataType.isInstanceOf[DecimalType]) { + s".$decimalMethod" + } else { + s"$symbol" + } + eval1.code + eval2.code + + s""" + boolean ${ev.nullTerm} = false; + ${ctx.primitiveForType(left.dataType)} ${ev.primitiveTerm} = + ${ctx.defaultPrimitive(left.dataType)}; + if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) { + ${ev.nullTerm} = true; + } else { + ${ev.primitiveTerm} = ${eval1.primitiveTerm}$method(${eval2.primitiveTerm}); + } + """ + } } case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "%" + override def decimalMethod: String = "reminder" + override def nullable: Boolean = true override lazy val resolved = @@ -218,6 +262,32 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } } } + + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val test = if (left.dataType.isInstanceOf[DecimalType]) { + s"${eval2.primitiveTerm}.isZero()" + } else { + s"${eval2.primitiveTerm} == 0" + } + val method = if (left.dataType.isInstanceOf[DecimalType]) { + s".$decimalMethod" + } else { + s"$symbol" + } + eval1.code + eval2.code + + s""" + boolean ${ev.nullTerm} = false; + ${ctx.primitiveForType(left.dataType)} ${ev.primitiveTerm} = + ${ctx.defaultPrimitive(left.dataType)}; + if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) { + ${ev.nullTerm} = true; + } else { + ${ev.primitiveTerm} = ${eval1.primitiveTerm}$method(${eval2.primitiveTerm}); + } + """ + } } /** @@ -336,6 +406,33 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { } } + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + if (ctx.isNativeType(left.dataType)) { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + eval1.code + eval2.code + s""" + boolean ${ev.nullTerm} = false; + ${ctx.primitiveForType(left.dataType)} ${ev.primitiveTerm} = + ${ctx.defaultPrimitive(left.dataType)}; + + if (${eval1.nullTerm}) { + ${ev.nullTerm} = ${eval2.nullTerm}; + ${ev.primitiveTerm} = 
${eval2.primitiveTerm}; + } else if (${eval2.nullTerm}) { + ${ev.nullTerm} = ${eval1.nullTerm}; + ${ev.primitiveTerm} = ${eval1.primitiveTerm}; + } else { + if (${eval1.primitiveTerm} > ${eval2.primitiveTerm}) { + ${ev.primitiveTerm} = ${eval1.primitiveTerm}; + } else { + ${ev.primitiveTerm} = ${eval2.primitiveTerm}; + } + } + """ + } else { + super.genSource(ctx, ev) + } + } override def toString: String = s"MaxOf($left, $right)" } @@ -363,5 +460,35 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { } } + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + if (ctx.isNativeType(left.dataType)) { + + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + + eval1.code + eval2.code + s""" + boolean ${ev.nullTerm} = false; + ${ctx.primitiveForType(left.dataType)} ${ev.primitiveTerm} = + ${ctx.defaultPrimitive(left.dataType)}; + + if (${eval1.nullTerm}) { + ${ev.nullTerm} = ${eval2.nullTerm}; + ${ev.primitiveTerm} = ${eval2.primitiveTerm}; + } else if (${eval2.nullTerm}) { + ${ev.nullTerm} = ${eval1.nullTerm}; + ${ev.primitiveTerm} = ${eval1.primitiveTerm}; + } else { + if (${eval1.primitiveTerm} < ${eval2.primitiveTerm}) { + ${ev.primitiveTerm} = ${eval1.primitiveTerm}; + } else { + ${ev.primitiveTerm} = ${eval2.primitiveTerm}; + } + } + """ + } else { + super.genSource(ctx, ev) + } + } + override def toString: String = s"MinOf($left, $right)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index cd604121b7dd9..bec1899a3aad2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -24,7 +24,6 @@ import com.google.common.cache.{CacheBuilder, CacheLoader} import org.codehaus.janino.ClassBodyEvaluator import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -33,586 +32,50 @@ class IntegerHashSet extends org.apache.spark.util.collection.OpenHashSet[Int] class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long] /** - * A base class for generators of byte code to perform expression evaluation. Includes a set of - * helpers for referring to Catalyst types and building trees that perform evaluation of individual - * expressions. + * Java source for evaluating an [[Expression]] given a [[Row]] of input. + * + * @param code The sequence of statements required to evaluate the expression. + * @param nullTerm A term that holds a boolean value representing whether the expression evaluated + * to null. + * @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not + * valid if `nullTerm` is set to `true`. + * @param objectTerm A possibly boxed version of the result of evaluating this expression. 
*/ -abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging { +case class EvaluatedExpression(var code: String, + nullTerm: String, + primitiveTerm: String, + objectTerm: String) + +/** + * A context for codegen + * @param references the expressions that don't support codegen + */ +case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) { - protected val rowType = classOf[Row].getName protected val stringType = classOf[UTF8String].getName protected val decimalType = classOf[Decimal].getName - protected val exprType = classOf[Expression].getName - protected val mutableRowType = classOf[MutableRow].getName - protected val genericMutableRowType = classOf[GenericMutableRow].getName private val curId = new java.util.concurrent.atomic.AtomicInteger() - /** - * Can be flipped on manually in the console to add (expensive) expression evaluation trace code. - */ - var debugLogging = false - - /** - * Generates a class for a given input expression. Called when there is not cached code - * already available. - */ - protected def create(in: InType): OutType - - /** - * Canonicalizes an input expression. Used to avoid double caching expressions that differ only - * cosmetically. - */ - protected def canonicalize(in: InType): InType - - /** Binds an input expression to a given input schema */ - protected def bind(in: InType, inputSchema: Seq[Attribute]): InType - - /** - * Compile the Java source code into a Java class, using Janino. - * - * It will track the time used to compile - */ - protected def compile(code: String): Class[_] = { - val startTime = System.nanoTime() - val clazz = new ClassBodyEvaluator(code).getClazz() - val endTime = System.nanoTime() - def timeMs: Double = (endTime - startTime).toDouble / 1000000 - logDebug(s"Compiled Java code (${code.size} bytes) in $timeMs ms") - clazz - } - - /** - * A cache of generated classes. - * - * From the Guava Docs: A Cache is similar to ConcurrentMap, but not quite the same. The most - * fundamental difference is that a ConcurrentMap persists all elements that are added to it until - * they are explicitly removed. A Cache on the other hand is generally configured to evict entries - * automatically, in order to constrain its memory footprint. Note that this cache does not use - * weak keys/values and thus does not respond to memory pressure. - */ - protected val cache = CacheBuilder.newBuilder() - .maximumSize(1000) - .build( - new CacheLoader[InType, OutType]() { - override def load(in: InType): OutType = { - val startTime = System.nanoTime() - val result = create(in) - val endTime = System.nanoTime() - def timeMs: Double = (endTime - startTime).toDouble / 1000000 - logInfo(s"Code generated expression $in in $timeMs ms") - result - } - }) - - /** Generates the requested evaluator binding the given expression(s) to the inputSchema. */ - def generate(expressions: InType, inputSchema: Seq[Attribute]): OutType = - generate(bind(expressions, inputSchema)) - - /** Generates the requested evaluator given already bound expression(s). */ - def generate(expressions: InType): OutType = cache.get(canonicalize(expressions)) - /** * Returns a term name that is unique within this instance of a `CodeGenerator`. * * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` * function.) 
*/ - protected def freshName(prefix: String): String = { + def freshName(prefix: String): String = { s"$prefix${curId.getAndIncrement}" } - /** - * Scala ASTs for evaluating an [[Expression]] given a [[Row]] of input. - * - * @param code The sequence of statements required to evaluate the expression. - * @param nullTerm A term that holds a boolean value representing whether the expression evaluated - * to null. - * @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not - * valid if `nullTerm` is set to `true`. - * @param objectTerm A possibly boxed version of the result of evaluating this expression. - */ - protected case class EvaluatedExpression( - code: String, - nullTerm: String, - primitiveTerm: String, - objectTerm: String) - - /** - * A context for codegen, which is used to bookkeeping the expressions those are not supported - * by codegen, then they are evaluated directly. The unsupported expression is appended at the - * end of `references`, the position of it is kept in the code, used to access and evaluate it. - */ - protected class CodeGenContext { - /** - * Holding all the expressions those do not support codegen, will be evaluated directly. - */ - val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]() - } - - /** - * Create a new codegen context for expression evaluator, used to store those - * expressions that don't support codegen - */ - def newCodeGenContext(): CodeGenContext = { - new CodeGenContext() - } - - /** - * Given an expression tree returns an [[EvaluatedExpression]], which contains Scala trees that - * can be used to determine the result of evaluating the expression on an input row. - */ - def expressionEvaluator(e: Expression, ctx: CodeGenContext): EvaluatedExpression = { - val primitiveTerm = freshName("primitiveTerm") - val nullTerm = freshName("nullTerm") - val objectTerm = freshName("objectTerm") - - implicit class Evaluate1(e: Expression) { - def castOrNull(f: String => String, dataType: DataType): String = { - val eval = expressionEvaluator(e, ctx) - eval.code + - s""" - boolean $nullTerm = ${eval.nullTerm}; - ${primitiveForType(dataType)} $primitiveTerm = ${defaultPrimitive(dataType)}; - if (!$nullTerm) { - $primitiveTerm = ${f(eval.primitiveTerm)}; - } - """ - } - } - - implicit class Evaluate2(expressions: (Expression, Expression)) { - - /** - * Short hand for generating binary evaluation code, which depends on two sub-evaluations of - * the same type. If either of the sub-expressions is null, the result of this computation - * is assumed to be null. - * - * @param f a function from two primitive term names to a tree that evaluates them. - */ - def evaluate(f: (String, String) => String): String = - evaluateAs(expressions._1.dataType)(f) - - def evaluateAs(resultType: DataType)(f: (String, String) => String): String = { - // TODO: Right now some timestamp tests fail if we enforce this... 
- if (expressions._1.dataType != expressions._2.dataType) { - log.warn(s"${expressions._1.dataType} != ${expressions._2.dataType}") - } - - val eval1 = expressionEvaluator(expressions._1, ctx) - val eval2 = expressionEvaluator(expressions._2, ctx) - val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm) - - eval1.code + eval2.code + - s""" - boolean $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm}; - ${primitiveForType(resultType)} $primitiveTerm = ${defaultPrimitive(resultType)}; - if(!$nullTerm) { - $primitiveTerm = (${primitiveForType(resultType)})($resultCode); - } - """ - } - } - - val inputTuple = "i" - - // TODO: Skip generation of null handling code when expression are not nullable. - val primitiveEvaluation: PartialFunction[Expression, String] = { - case b @ BoundReference(ordinal, dataType, nullable) => - s""" - final boolean $nullTerm = $inputTuple.isNullAt($ordinal); - final ${primitiveForType(dataType)} $primitiveTerm = $nullTerm ? - ${defaultPrimitive(dataType)} : (${getColumn(inputTuple, dataType, ordinal)}); - """ - - case expressions.Literal(null, dataType) => - s""" - final boolean $nullTerm = true; - ${primitiveForType(dataType)} $primitiveTerm = ${defaultPrimitive(dataType)}; - """ - - case expressions.Literal(value: UTF8String, StringType) => - val arr = s"new byte[]{${value.getBytes.map(_.toString).mkString(", ")}}" - s""" - final boolean $nullTerm = false; - ${stringType} $primitiveTerm = - new ${stringType}().set(${arr}); - """ - - case expressions.Literal(value, FloatType) => - s""" - final boolean $nullTerm = false; - float $primitiveTerm = ${value}f; - """ - - case expressions.Literal(value, dt @ DecimalType()) => - s""" - final boolean $nullTerm = false; - ${primitiveForType(dt)} $primitiveTerm = new ${primitiveForType(dt)}().set($value); - """ - - case expressions.Literal(value, dataType) => - s""" - final boolean $nullTerm = false; - ${primitiveForType(dataType)} $primitiveTerm = $value; - """ - - case Cast(child @ BinaryType(), StringType) => - child.castOrNull(c => - s"new ${stringType}().set($c)", - StringType) - - case Cast(child @ DateType(), StringType) => - child.castOrNull(c => - s"""new ${stringType}().set( - org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""", - StringType) - - case Cast(child @ BooleanType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => - child.castOrNull(c => s"(${primitiveForType(dt)})($c?1:0)", dt) - - case Cast(child @ DecimalType(), IntegerType) => - child.castOrNull(c => s"($c).toInt()", IntegerType) - - case Cast(child @ DecimalType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => - child.castOrNull(c => s"($c).to${termForType(dt)}()", dt) - - case Cast(child @ NumericType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => - child.castOrNull(c => s"(${primitiveForType(dt)})($c)", dt) - - // Special handling required for timestamps in hive test cases since the toString function - // does not match the expected output. 
- case Cast(e, StringType) if e.dataType != TimestampType => - e.castOrNull(c => - s"new ${stringType}().set(String.valueOf($c))", - StringType) - - case EqualTo(e1 @ BinaryType(), e2 @ BinaryType()) => - (e1, e2).evaluateAs (BooleanType) { - case (eval1, eval2) => - s"java.util.Arrays.equals((byte[])$eval1, (byte[])$eval2)" - } - - case EqualTo(e1, e2) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 == $eval2" } - - case GreaterThan(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 > $eval2" } - case GreaterThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 >= $eval2" } - case LessThan(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 < $eval2" } - case LessThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 <= $eval2" } - - case And(e1, e2) => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - s""" - ${eval1.code} - boolean $nullTerm = false; - boolean $primitiveTerm = false; - - if (!${eval1.nullTerm} && !${eval1.primitiveTerm}) { - } else { - ${eval2.code} - if (!${eval2.nullTerm} && !${eval2.primitiveTerm}) { - } else if (!${eval1.nullTerm} && !${eval2.nullTerm}) { - $primitiveTerm = true; - } else { - $nullTerm = true; - } - } - """ - - case Or(e1, e2) => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - - s""" - ${eval1.code} - boolean $nullTerm = false; - boolean $primitiveTerm = false; - - if (!${eval1.nullTerm} && ${eval1.primitiveTerm}) { - $primitiveTerm = true; - } else { - ${eval2.code} - if (!${eval2.nullTerm} && ${eval2.primitiveTerm}) { - $primitiveTerm = true; - } else if (!${eval1.nullTerm} && !${eval2.nullTerm}) { - $primitiveTerm = false; - } else { - $nullTerm = true; - } - } - """ - - case Not(child) => - // Uh, bad function name... 
- child.castOrNull(c => s"!$c", BooleanType) - - case Add(e1 @ DecimalType(), e2 @ DecimalType()) => - (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$plus($eval2)" } - case Subtract(e1 @ DecimalType(), e2 @ DecimalType()) => - (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$minus($eval2)" } - case Multiply(e1 @ DecimalType(), e2 @ DecimalType()) => - (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$times($eval2)" } - case Divide(e1 @ DecimalType(), e2 @ DecimalType()) => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - eval1.code + eval2.code + - s""" - boolean $nullTerm = false; - ${primitiveForType(e1.dataType)} $primitiveTerm = null; - if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm}.isZero()) { - $nullTerm = true; - } else { - $primitiveTerm = ${eval1.primitiveTerm}.$$div${eval2.primitiveTerm}); - } - """ - case Remainder(e1 @ DecimalType(), e2 @ DecimalType()) => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - eval1.code + eval2.code + - s""" - boolean $nullTerm = false; - ${primitiveForType(e1.dataType)} $primitiveTerm = 0; - if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm}.isZero()) { - $nullTerm = true; - } else { - $primitiveTerm = ${eval1.primitiveTerm}.remainder(${eval2.primitiveTerm}); - } - """ - - case Add(e1, e2) => - (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 + $eval2" } - case Subtract(e1, e2) => - (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 - $eval2" } - case Multiply(e1, e2) => - (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 * $eval2" } - case Divide(e1, e2) => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - eval1.code + eval2.code + - s""" - boolean $nullTerm = false; - ${primitiveForType(e1.dataType)} $primitiveTerm = 0; - if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm} == 0) { - $nullTerm = true; - } else { - $primitiveTerm = ${eval1.primitiveTerm} / ${eval2.primitiveTerm}; - } - """ - case Remainder(e1, e2) => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - eval1.code + eval2.code + - s""" - boolean $nullTerm = false; - ${primitiveForType(e1.dataType)} $primitiveTerm = 0; - if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm} == 0) { - $nullTerm = true; - } else { - $primitiveTerm = ${eval1.primitiveTerm} % ${eval2.primitiveTerm}; - } - """ - - case IsNotNull(e) => - val eval = expressionEvaluator(e, ctx) - s""" - ${eval.code} - boolean $nullTerm = false; - boolean $primitiveTerm = !${eval.nullTerm}; - """ - - case IsNull(e) => - val eval = expressionEvaluator(e, ctx) - s""" - ${eval.code} - boolean $nullTerm = false; - boolean $primitiveTerm = ${eval.nullTerm}; - """ - - case e @ Coalesce(children) => - s""" - boolean $nullTerm = true; - ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)}; - """ + - children.map { c => - val eval = expressionEvaluator(c, ctx) - s""" - if($nullTerm) { - ${eval.code} - if(!${eval.nullTerm}) { - $nullTerm = false; - $primitiveTerm = ${eval.primitiveTerm}; - } - } - """ - }.mkString("\n") - - case e @ expressions.If(condition, trueValue, falseValue) => - val condEval = expressionEvaluator(condition, ctx) - val trueEval = expressionEvaluator(trueValue, ctx) - val falseEval = expressionEvaluator(falseValue, ctx) - - s""" - boolean $nullTerm = false; - ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)}; - 
${condEval.code} - if(!${condEval.nullTerm} && ${condEval.primitiveTerm}) { - ${trueEval.code} - $nullTerm = ${trueEval.nullTerm}; - $primitiveTerm = ${trueEval.primitiveTerm}; - } else { - ${falseEval.code} - $nullTerm = ${falseEval.nullTerm}; - $primitiveTerm = ${falseEval.primitiveTerm}; - } - """ - - case NewSet(elementType) => - s""" - boolean $nullTerm = false; - ${hashSetForType(elementType)} $primitiveTerm = new ${hashSetForType(elementType)}(); - """ - - case AddItemToSet(item, set) => - val itemEval = expressionEvaluator(item, ctx) - val setEval = expressionEvaluator(set, ctx) - - val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType - val htype = hashSetForType(elementType) - - itemEval.code + setEval.code + - s""" - if (!${itemEval.nullTerm} && !${setEval.nullTerm}) { - (($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm}); - } - boolean $nullTerm = false; - ${htype} $primitiveTerm = ($htype)${setEval.primitiveTerm}; - """ - - case CombineSets(left, right) => - val leftEval = expressionEvaluator(left, ctx) - val rightEval = expressionEvaluator(right, ctx) - - val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType - val htype = hashSetForType(elementType) - - leftEval.code + rightEval.code + - s""" - boolean $nullTerm = false; - ${htype} $primitiveTerm = - (${htype})${leftEval.primitiveTerm}; - $primitiveTerm.union((${htype})${rightEval.primitiveTerm}); - """ - - case MaxOf(e1, e2) if !e1.dataType.isInstanceOf[DecimalType] => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - - eval1.code + eval2.code + - s""" - boolean $nullTerm = false; - ${primitiveForType(e1.dataType)} $primitiveTerm = ${defaultPrimitive(e1.dataType)}; - - if (${eval1.nullTerm}) { - $nullTerm = ${eval2.nullTerm}; - $primitiveTerm = ${eval2.primitiveTerm}; - } else if (${eval2.nullTerm}) { - $nullTerm = ${eval1.nullTerm}; - $primitiveTerm = ${eval1.primitiveTerm}; - } else { - if (${eval1.primitiveTerm} > ${eval2.primitiveTerm}) { - $primitiveTerm = ${eval1.primitiveTerm}; - } else { - $primitiveTerm = ${eval2.primitiveTerm}; - } - } - """ - - case MinOf(e1, e2) if !e1.dataType.isInstanceOf[DecimalType] => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - - eval1.code + eval2.code + - s""" - boolean $nullTerm = false; - ${primitiveForType(e1.dataType)} $primitiveTerm = ${defaultPrimitive(e1.dataType)}; - - if (${eval1.nullTerm}) { - $nullTerm = ${eval2.nullTerm}; - $primitiveTerm = ${eval2.primitiveTerm}; - } else if (${eval2.nullTerm}) { - $nullTerm = ${eval1.nullTerm}; - $primitiveTerm = ${eval1.primitiveTerm}; - } else { - if (${eval1.primitiveTerm} < ${eval2.primitiveTerm}) { - $primitiveTerm = ${eval1.primitiveTerm}; - } else { - $primitiveTerm = ${eval2.primitiveTerm}; - } - } - """ - - case UnscaledValue(child) => - val childEval = expressionEvaluator(child, ctx) - - childEval.code + - s""" - boolean $nullTerm = ${childEval.nullTerm}; - long $primitiveTerm = $nullTerm ? 
-1 : ${childEval.primitiveTerm}.toUnscaledLong(); - """ - - case MakeDecimal(child, precision, scale) => - val eval = expressionEvaluator(child, ctx) - - eval.code + - s""" - boolean $nullTerm = ${eval.nullTerm}; - org.apache.spark.sql.types.Decimal $primitiveTerm = ${defaultPrimitive(DecimalType())}; - - if (!$nullTerm) { - $primitiveTerm = new org.apache.spark.sql.types.Decimal(); - $primitiveTerm = $primitiveTerm.setOrNull(${eval.primitiveTerm}, $precision, $scale); - $nullTerm = $primitiveTerm == null; - } - """ - } - - // If there was no match in the partial function above, we fall back on calling the interpreted - // expression evaluator. - val code: String = - primitiveEvaluation.lift.apply(e).getOrElse { - logError(s"No rules to generate $e") - ctx.references += e - s""" - /* expression: ${e} */ - Object $objectTerm = expressions[${ctx.references.size - 1}].eval(i); - boolean $nullTerm = $objectTerm == null; - ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)}; - if (!$nullTerm) $primitiveTerm = (${termForType(e.dataType)})$objectTerm; - """ - } - - EvaluatedExpression(code, nullTerm, primitiveTerm, objectTerm) - } - - protected def getColumn(inputRow: String, dataType: DataType, ordinal: Int) = { + def getColumn(dataType: DataType, ordinal: Int): String = { dataType match { - case StringType => s"(${stringType})$inputRow.apply($ordinal)" - case dt: DataType if isNativeType(dt) => s"$inputRow.${accessorForType(dt)}($ordinal)" - case _ => s"(${termForType(dataType)})$inputRow.apply($ordinal)" + case StringType => s"(org.apache.spark.sql.types.UTF8String)i.apply($ordinal)" + case dt: DataType if isNativeType(dt) => s"i.${accessorForType(dt)}($ordinal)" + case _ => s"(${termForType(dataType)})i.apply($ordinal)" } } - protected def setColumn( - destinationRow: String, - dataType: DataType, - ordinal: Int, - value: String): String = { + def setColumn(destinationRow: String, dataType: DataType, ordinal: Int, value: String): String = { dataType match { case StringType => s"$destinationRow.update($ordinal, $value)" case dt: DataType if isNativeType(dt) => @@ -621,24 +84,24 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin } } - protected def accessorForType(dt: DataType) = dt match { + def accessorForType(dt: DataType): String = dt match { case IntegerType => "getInt" case other => s"get${termForType(dt)}" } - protected def mutatorForType(dt: DataType) = dt match { + def mutatorForType(dt: DataType): String = dt match { case IntegerType => "setInt" case other => s"set${termForType(dt)}" } - protected def hashSetForType(dt: DataType): String = dt match { + def hashSetForType(dt: DataType): String = dt match { case IntegerType => classOf[IntegerHashSet].getName case LongType => classOf[LongHashSet].getName case unsupportedType => sys.error(s"Code generation not support for hashset of type $unsupportedType") } - protected def primitiveForType(dt: DataType): String = dt match { + def primitiveForType(dt: DataType): String = dt match { case IntegerType => "int" case LongType => "long" case ShortType => "short" @@ -654,7 +117,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin case _ => "Object" } - protected def defaultPrimitive(dt: DataType): String = dt match { + def defaultPrimitive(dt: DataType): String = dt match { case BooleanType => "false" case FloatType => "-1.0f" case ShortType => "-1" @@ -668,7 +131,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin 
case _ => "null" } - protected def termForType(dt: DataType): String = dt match { + def termForType(dt: DataType): String = dt match { case IntegerType => "Integer" case LongType => "Long" case ShortType => "Short" @@ -687,11 +150,96 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin /** * List of data types that have special accessors and setters in [[Row]]. */ - protected val nativeTypes = + val nativeTypes = Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType) /** * Returns true if the data type has a special accessor and setter in [[Row]]. */ - protected def isNativeType(dt: DataType) = nativeTypes.contains(dt) + def isNativeType(dt: DataType): Boolean = nativeTypes.contains(dt) +} + +/** + * A base class for generators of byte code to perform expression evaluation. Includes a set of + * helpers for referring to Catalyst types and building trees that perform evaluation of individual + * expressions. + */ +abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging { + + protected val rowType = classOf[Row].getName + protected val exprType = classOf[Expression].getName + protected val mutableRowType = classOf[MutableRow].getName + protected val genericMutableRowType = classOf[GenericMutableRow].getName + + /** + * Can be flipped on manually in the console to add (expensive) expression evaluation trace code. + */ + var debugLogging = false + + /** + * Generates a class for a given input expression. Called when there is not cached code + * already available. + */ + protected def create(in: InType): OutType + + /** + * Canonicalizes an input expression. Used to avoid double caching expressions that differ only + * cosmetically. + */ + protected def canonicalize(in: InType): InType + + /** Binds an input expression to a given input schema */ + protected def bind(in: InType, inputSchema: Seq[Attribute]): InType + + /** + * Compile the Java source code into a Java class, using Janino. + * + * It will track the time used to compile + */ + protected def compile(code: String): Class[_] = { + val startTime = System.nanoTime() + val clazz = new ClassBodyEvaluator(code).getClazz() + val endTime = System.nanoTime() + def timeMs: Double = (endTime - startTime).toDouble / 1000000 + logDebug(s"Code (${code.size} bytes) compiled in $timeMs ms") + clazz + } + + /** + * A cache of generated classes. + * + * From the Guava Docs: A Cache is similar to ConcurrentMap, but not quite the same. The most + * fundamental difference is that a ConcurrentMap persists all elements that are added to it until + * they are explicitly removed. A Cache on the other hand is generally configured to evict entries + * automatically, in order to constrain its memory footprint. Note that this cache does not use + * weak keys/values and thus does not respond to memory pressure. + */ + protected val cache = CacheBuilder.newBuilder() + .maximumSize(1000) + .build( + new CacheLoader[InType, OutType]() { + override def load(in: InType): OutType = { + val startTime = System.nanoTime() + val result = create(in) + val endTime = System.nanoTime() + def timeMs: Double = (endTime - startTime).toDouble / 1000000 + logInfo(s"Code generated expression $in in $timeMs ms") + result + } + }) + + /** Generates the requested evaluator binding the given expression(s) to the inputSchema. 
*/ + def generate(expressions: InType, inputSchema: Seq[Attribute]): OutType = + generate(bind(expressions, inputSchema)) + + /** Generates the requested evaluator given already bound expression(s). */ + def generate(expressions: InType): OutType = cache.get(canonicalize(expressions)) + + /** + * Create a new codegen context for expression evaluator, used to store those + * expressions that don't support codegen + */ + def newCodeGenContext(): CodeGenContext = { + new CodeGenContext(new mutable.ArrayBuffer[Expression]()) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 638b53fe0fe2f..02b7d3fae6767 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -37,13 +37,13 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu protected def create(expressions: Seq[Expression]): (() => MutableProjection) = { val ctx = newCodeGenContext() val projectionCode = expressions.zipWithIndex.map { case (e, i) => - val evaluationCode = expressionEvaluator(e, ctx) + val evaluationCode = e.gen(ctx) evaluationCode.code + s""" if(${evaluationCode.nullTerm}) mutableRow.setNullAt($i); else - ${setColumn("mutableRow", e.dataType, i, evaluationCode.primitiveTerm)}; + ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitiveTerm)}; """ }.mkString("\n") val code = s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 0ff840dab393c..d3c219fddc53c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -52,8 +52,8 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit val ctx = newCodeGenContext() val comparisons = ordering.zipWithIndex.map { case (order, i) => - val evalA = expressionEvaluator(order.child, ctx) - val evalB = expressionEvaluator(order.child, ctx) + val evalA = order.child.gen(ctx) + val evalB = order.child.gen(ctx) val asc = order.direction == Ascending val compare = order.child.dataType match { case BinaryType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index fb18769f00da3..dd4474de05df9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -38,7 +38,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] { protected def create(predicate: Expression): ((Row) => Boolean) = { val ctx = newCodeGenContext() - val eval = expressionEvaluator(predicate, ctx) + val eval = predicate.gen(ctx) val code = s""" import org.apache.spark.sql.Row; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index d5be1fc12e0f0..0e8ad76f65bad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -45,12 +45,12 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val ctx = newCodeGenContext() val columns = expressions.zipWithIndex.map { case (e, i) => - s"private ${primitiveForType(e.dataType)} c$i = ${defaultPrimitive(e.dataType)};\n" + s"private ${ctx.primitiveForType(e.dataType)} c$i = ${ctx.defaultPrimitive(e.dataType)};\n" }.mkString("\n ") val initColumns = expressions.zipWithIndex.map { case (e, i) => - val eval = expressionEvaluator(e, ctx) + val eval = e.gen(ctx) s""" { // column$i @@ -68,10 +68,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { }.mkString("\n ") val updateCases = expressions.zipWithIndex.map { case (e, i) => - s"case $i: { c$i = (${termForType(e.dataType)})value; return;}" + s"case $i: { c$i = (${ctx.termForType(e.dataType)})value; return;}" }.mkString("\n ") - val specificAccessorFunctions = nativeTypes.map { dataType => + val specificAccessorFunctions = ctx.nativeTypes.map { dataType => val cases = expressions.zipWithIndex.map { case (e, i) if e.dataType == dataType => s"case $i: return c$i;" @@ -80,21 +80,21 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { if (cases.count(_ != '\n') > 0) { s""" @Override - public ${primitiveForType(dataType)} ${accessorForType(dataType)}(int i) { + public ${ctx.primitiveForType(dataType)} ${ctx.accessorForType(dataType)}(int i) { if (isNullAt(i)) { - return ${defaultPrimitive(dataType)}; + return ${ctx.defaultPrimitive(dataType)}; } switch (i) { $cases } - return ${defaultPrimitive(dataType)}; + return ${ctx.defaultPrimitive(dataType)}; }""" } else { "" } }.mkString("\n") - val specificMutatorFunctions = nativeTypes.map { dataType => + val specificMutatorFunctions = ctx.nativeTypes.map { dataType => val cases = expressions.zipWithIndex.map { case (e, i) if e.dataType == dataType => s"case $i: { c$i = value; return; }" @@ -103,7 +103,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { if (cases.count(_ != '\n') > 0) { s""" @Override - public void ${mutatorForType(dataType)}(int i, ${primitiveForType(dataType)} value) { + public void ${ctx.mutatorForType(dataType)}(int i, ${ctx.primitiveForType(dataType)} value) { nullBits[i] = false; switch (i) { $cases diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index 65ba18924afe1..76273a5b7ee68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext} import org.apache.spark.sql.types._ /** Return the unscaled Long value of a Decimal, assuming it fits in a Long */ @@ -35,6 +36,14 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { childResult.asInstanceOf[Decimal].toUnscaledLong } } + + override def genSource(ctx: 
CodeGenContext, ev: EvaluatedExpression): String = { + val eval = child.gen(ctx) + eval.code +s""" + boolean ${ev.nullTerm} = ${eval.nullTerm}; + long ${ev.primitiveTerm} = ${ev.nullTerm} ? -1 : ${eval.primitiveTerm}.toUnscaledLong(); + """ + } } /** Create a Decimal from an unscaled Long value */ @@ -53,4 +62,20 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un new Decimal().setOrNull(childResult.asInstanceOf[Long], precision, scale) } } + + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + val eval = child.gen(ctx) + eval.code + s""" + boolean ${ev.nullTerm} = ${eval.nullTerm}; + org.apache.spark.sql.types.Decimal ${ev.primitiveTerm} = + ${ctx.defaultPrimitive(DecimalType())}; + + if (!${ev.nullTerm}) { + ${ev.primitiveTerm} = new org.apache.spark.sql.types.Decimal(); + ${ev.primitiveTerm} = + ${ev.primitiveTerm}.setOrNull(${eval.primitiveTerm}, $precision, $scale); + ${ev.nullTerm} = ${ev.primitiveTerm} == null; + } + """ + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index d3ca3d9a4b18b..d9fbda9511a5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, EvaluatedExpression} import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ @@ -79,6 +80,43 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres override def toString: String = if (value != null) value.toString else "null" override def eval(input: Row): Any = value + + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + if (value == null) { + s""" + final boolean ${ev.nullTerm} = true; + ${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultPrimitive(dataType)}; + """ + } else { + dataType match { + case StringType => + val v = value.asInstanceOf[UTF8String] + val arr = s"new byte[]{${v.getBytes.map(_.toString).mkString(", ")}}" + s""" + final boolean ${ev.nullTerm} = false; + org.apache.spark.sql.types.UTF8String ${ev.primitiveTerm} = + new org.apache.spark.sql.types.UTF8String().set(${arr}); + """ + case FloatType => + s""" + final boolean ${ev.nullTerm} = false; + float ${ev.primitiveTerm} = ${value}f; + """ + case dt: DecimalType => + s""" + final boolean ${ev.nullTerm} = false; + ${ctx.primitiveForType(dt)} ${ev.primitiveTerm} = new ${ctx.primitiveForType(dt)}().set($value); + """ + case dt: NumericType => + s""" + final boolean ${ev.nullTerm} = false; + ${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = $value; + """ + case other => + super.genSource(ctx, ev) + } + } + } } // TODO: Specialize diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 5070570b4740d..2af0f96146c1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -17,6 +17,7 @@ package 
org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext} import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.types.DataType @@ -51,6 +52,25 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } result } + + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + s""" + boolean ${ev.nullTerm} = true; + ${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultPrimitive(dataType)}; + """ + + children.map { e => + val eval = e.gen(ctx) + s""" + if(${ev.nullTerm}) { + ${eval.code} + if(!${eval.nullTerm}) { + ${ev.nullTerm} = false; + ${ev.primitiveTerm} = ${eval.primitiveTerm}; + } + } + """ + }.mkString("\n") + } } case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] { @@ -61,6 +81,14 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr child.eval(input) == null } + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + val eval = child.gen(ctx) + eval.code + s""" + final boolean ${ev.nullTerm} = false; + final boolean ${ev.primitiveTerm} = ${eval.nullTerm}; + """ + } + override def toString: String = s"IS NULL $child" } @@ -72,6 +100,14 @@ case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[E override def eval(input: Row): Any = { child.eval(input) != null } + + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + val eval = child.gen(ctx) + eval.code + s""" + boolean ${ev.nullTerm} = false; + boolean ${ev.primitiveTerm} = !${eval.nullTerm}; + """ + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 807021d50e8e0..b6b2c7db28960 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils -import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType} +import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.types._ object InterpretedPredicate { def create(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = @@ -82,6 +83,11 @@ case class Not(child: Expression) extends UnaryExpression with Predicate with Ex case b: Boolean => !b } } + + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + // Uh, bad function name... 
+ castOrNull(ctx, ev, c => s"!($c)", BooleanType) + } } /** @@ -141,6 +147,26 @@ case class And(left: Expression, right: Expression) } } } + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + s""" + ${eval1.code} + boolean ${ev.nullTerm} = false; + boolean ${ev.primitiveTerm} = false; + + if (!${eval1.nullTerm} && !${eval1.primitiveTerm}) { + } else { + ${eval2.code} + if (!${eval2.nullTerm} && !${eval2.primitiveTerm}) { + } else if (!${eval1.nullTerm} && !${eval2.nullTerm}) { + ${ev.primitiveTerm} = true; + } else { + ${ev.nullTerm} = true; + } + } + """ + } } case class Or(left: Expression, right: Expression) @@ -167,10 +193,44 @@ case class Or(left: Expression, right: Expression) } } } + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + s""" + ${eval1.code} + boolean ${ev.nullTerm} = false; + boolean ${ev.primitiveTerm} = false; + + if (!${eval1.nullTerm} && ${eval1.primitiveTerm}) { + ${ev.primitiveTerm} = true; + } else { + ${eval2.code} + if (!${eval2.nullTerm} && ${eval2.primitiveTerm}) { + ${ev.primitiveTerm} = true; + } else if (!${eval1.nullTerm} && !${eval2.nullTerm}) { + ${ev.primitiveTerm} = false; + } else { + ${ev.nullTerm} = true; + } + } + """ + } } abstract class BinaryComparison extends BinaryExpression with Predicate { self: Product => + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + left.dataType match { + case dt: NumericType => evaluateAs(BooleanType) (ctx, ev, { + (eval1, eval2) => s"$eval1 $symbol $eval2" + }) + case dt: TimestampType => + super.genSource(ctx, ev) + case other => evaluateAs(BooleanType) (ctx, ev, { + (eval1, eval2) => s"$eval1.compare($eval2) $symbol 0" + }) + } + } override def checkInputDataTypes(): TypeCheckResult = { if (left.dataType != right.dataType) { @@ -216,6 +276,17 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison if (left.dataType != BinaryType) l == r else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) } + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression) = { + left.dataType match { + case BinaryType() => + evaluateAs (BooleanType) (ctx, ev, { + case (eval1, eval2) => + s"java.util.Arrays.equals((byte[])$eval1, (byte[])$eval2)" + }) + case other => + evaluateAs (BooleanType) (ctx, ev, { case (eval1, eval2) => s"$eval1 == $eval2" }) + } + } } case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison { @@ -236,6 +307,22 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp l == r } } + + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val cmpCode = if (left.dataType.isInstanceOf[BinaryType]) { + s"java.util.Arrays.equals((byte[])${eval1.primitiveTerm}, (byte[])${eval2.primitiveTerm})" + } else { + s"${eval1.primitiveTerm} == ${eval2.primitiveTerm}" + } + eval1.code + eval2.code + + s""" + final boolean ${ev.nullTerm} = false; + final boolean ${ev.primitiveTerm} = (${eval1.nullTerm} && ${eval2.nullTerm}) || + (!${eval1.nullTerm} && $cmpCode); + """ + } } case class LessThan(left: Expression, right: Expression) extends BinaryComparison { @@ -309,6 +396,26 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi falseValue.eval(input) } } + override def genSource(ctx: 
CodeGenContext, ev: EvaluatedExpression): String = { + val condEval = predicate.gen(ctx) + val trueEval = trueValue.gen(ctx) + val falseEval = falseValue.gen(ctx) + + s""" + boolean ${ev.nullTerm} = false; + ${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultPrimitive(dataType)}; + ${condEval.code} + if(!${condEval.nullTerm} && ${condEval.primitiveTerm}) { + ${trueEval.code} + ${ev.nullTerm} = ${trueEval.nullTerm}; + ${ev.primitiveTerm} = ${trueEval.primitiveTerm}; + } else { + ${falseEval.code} + ${ev.nullTerm} = ${falseEval.nullTerm}; + ${ev.primitiveTerm} = ${falseEval.primitiveTerm}; + } + """ + } override def toString: String = s"if ($predicate) $trueValue else $falseValue" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index b65bf165f21db..e6ae81c2aad52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext} import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet @@ -60,6 +61,14 @@ case class NewSet(elementType: DataType) extends LeafExpression { new OpenHashSet[Any]() } + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + s""" + boolean ${ev.nullTerm} = false; + ${ctx.hashSetForType(elementType)} ${ev.primitiveTerm} = + new ${ctx.hashSetForType(elementType)}(); + """ + } + override def toString: String = s"new Set($dataType)" } @@ -91,6 +100,23 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { } } + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + val itemEval = item.gen(ctx) + val setEval = set.gen(ctx) + + val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType + val htype = ctx.hashSetForType(elementType) + + itemEval.code + setEval.code + + s""" + if (!${itemEval.nullTerm} && !${setEval.nullTerm}) { + (($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm}); + } + boolean ${ev.nullTerm} = false; + ${htype} ${ev.primitiveTerm} = ($htype)${setEval.primitiveTerm}; + """ + } + override def toString: String = s"$set += $item" } @@ -124,6 +150,22 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres null } } + + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + val leftEval = left.gen(ctx) + val rightEval = right.gen(ctx) + + val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType + val htype = ctx.hashSetForType(elementType) + + leftEval.code + rightEval.code + + s""" + boolean ${ev.nullTerm} = false; + ${htype} ${ev.primitiveTerm} = + (${htype})${leftEval.primitiveTerm}; + ${ev.primitiveTerm}.union((${htype})${rightEval.primitiveTerm}); + """ + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala index 8cfd853afa35f..b577de1d5aab9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala @@ 
-33,7 +33,7 @@ class GeneratedEvaluationSuite extends ExpressionEvaluationSuite { } catch { case e: Throwable => val ctx = GenerateProjection.newCodeGenContext() - val evaluated = GenerateProjection.expressionEvaluator(expression, ctx) + val evaluated = expression.gen(ctx) fail( s""" |Code generation of $expression failed: diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala index 9ab1f7d7ad0db..9da72521ec3ec 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala @@ -29,7 +29,7 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { expected: Any, inputRow: Row = EmptyRow): Unit = { val ctx = GenerateProjection.newCodeGenContext() - lazy val evaluated = GenerateProjection.expressionEvaluator(expression, ctx) + lazy val evaluated = expression.gen(ctx) val plan = try { GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil) From 3ff25f81a8fc6840b5c6dc75377fc89e41454586 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 3 Jun 2015 17:45:19 -0700 Subject: [PATCH 02/18] refactor --- .../catalyst/expressions/BoundAttribute.scala | 4 +- .../spark/sql/catalyst/expressions/Cast.scala | 19 +++-- .../sql/catalyst/expressions/Expression.scala | 51 +++++++------ .../sql/catalyst/expressions/arithmetic.scala | 16 ++--- .../expressions/codegen/CodeGenerator.scala | 31 ++++++-- .../codegen/GenerateProjection.scala | 12 ++-- .../expressions/decimalFunctions.scala | 2 +- .../sql/catalyst/expressions/literals.scala | 28 ++++---- .../catalyst/expressions/nullFunctions.scala | 23 +++++- .../sql/catalyst/expressions/predicates.scala | 37 ++++------ .../spark/sql/catalyst/expressions/sets.scala | 71 +++++++++++-------- 11 files changed, 163 insertions(+), 131 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 1055be6e9d273..1d7f3b766a160 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -46,8 +46,8 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { s""" final boolean ${ev.nullTerm} = i.isNullAt($ordinal); - final ${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = ${ev.nullTerm} ? - ${ctx.defaultPrimitive(dataType)} : (${ctx.getColumn(dataType, ordinal)}); + final ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ev.nullTerm} ? 
+ ${ctx.defaultValue(dataType)} : (${ctx.getColumn(dataType, ordinal)}); """ } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index a986844d18e8f..bf8642cdde535 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -439,33 +439,30 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case Cast(child @ BinaryType(), StringType) => castOrNull (ctx, ev, c => - s"new org.apache.spark.sql.types.UTF8String().set($c)", - StringType) + s"new org.apache.spark.sql.types.UTF8String().set($c)") case Cast(child @ DateType(), StringType) => castOrNull(ctx, ev, c => s"""new org.apache.spark.sql.types.UTF8String().set( - org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""", - StringType) + org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""") - case Cast(child @ BooleanType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => - castOrNull(ctx, ev, c => s"(${ctx.primitiveForType(dt)})($c?1:0)", dt) + case Cast(child @ BooleanType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => + castOrNull(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c?1:0)") case Cast(child @ DecimalType(), IntegerType) => - castOrNull(ctx, ev, c => s"($c).toInt()", IntegerType) + castOrNull(ctx, ev, c => s"($c).toInt()") case Cast(child @ DecimalType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => - castOrNull(ctx, ev, c => s"($c).to${ctx.termForType(dt)}()", dt) + castOrNull(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()") case Cast(child @ NumericType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => - castOrNull(ctx, ev, c => s"(${ctx.primitiveForType(dt)})($c)", dt) + castOrNull(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c)") // Special handling required for timestamps in hive test cases since the toString function // does not match the expected output. 
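For intuition about the template these cast branches expand to, here is a minimal, self-contained sketch (not part of the patch) of the null-propagating pattern behind castOrNull after this change, where the result type comes from the expression's own dataType rather than an extra parameter. The helper name and the fixed isNull/value identifiers are illustrative stand-ins for terms that would normally come from freshName:

object CastOrNullSketch {
  // childCode computes the child's null flag and primitive value; cast turns the
  // child's value expression into the target Java type.
  def castOrNull(
      childCode: String,
      childIsNull: String,
      childValue: String,
      javaType: String,
      defaultValue: String,
      cast: String => String): String = {
    s"""
       |$childCode
       |boolean isNull = $childIsNull;
       |$javaType value = $defaultValue;
       |if (!isNull) {
       |  value = ${cast(childValue)};
       |}""".stripMargin
  }

  def main(args: Array[String]): Unit = {
    // Roughly the shape the boolean-to-numeric branch above would emit.
    val generated = castOrNull(
      childCode = "final boolean childIsNull = i.isNullAt(0);\n" +
        "final boolean childValue = childIsNull ? false : i.getBoolean(0);",
      childIsNull = "childIsNull",
      childValue = "childValue",
      javaType = "int",
      defaultValue = "-1",
      cast = c => s"(int)($c ? 1 : 0)")
    println(generated)
  }
}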
case Cast(e, StringType) if e.dataType != TimestampType => castOrNull(ctx, ev, c => - s"new org.apache.spark.sql.types.UTF8String().set(String.valueOf($c))", - StringType) + s"new org.apache.spark.sql.types.UTF8String().set(String.valueOf($c))") case other => super.genSource(ctx, ev) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index f66f8f9ff105e..9b89a4bc744c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -90,10 +90,10 @@ abstract class Expression extends TreeNode[Expression] { /* expression: ${this} */ Object ${ev.objectTerm} = expressions[${ctx.references.size - 1}].eval(i); boolean ${ev.nullTerm} = ${ev.objectTerm} == null; - ${ctx.primitiveForType(e.dataType)} ${ev.primitiveTerm} = - ${ctx.defaultPrimitive(e.dataType)}; + ${ctx.primitiveType(e.dataType)} ${ev.primitiveTerm} = + ${ctx.defaultValue(e.dataType)}; if (!${ev.nullTerm}) ${ev.primitiveTerm} = - (${ctx.termForType(e.dataType)})${ev.objectTerm}; + (${ctx.boxedType(e.dataType)})${ev.objectTerm}; """ } @@ -173,12 +173,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express */ def evaluate(ctx: CodeGenContext, ev: EvaluatedExpression, - f: (String, String) => String): String = - evaluateAs(left.dataType)(ctx, ev, f) - - def evaluateAs(resultType: DataType)(ctx: CodeGenContext, - ev: EvaluatedExpression, - f: (String, String) => String): String = { + f: (String, String) => String): String = { // TODO: Right now some timestamp tests fail if we enforce this... if (left.dataType != right.dataType) { // log.warn(s"${left.dataType} != ${right.dataType}") @@ -188,14 +183,19 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express val eval2 = right.gen(ctx) val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm) - eval1.code + eval2.code + - s""" - boolean ${ev.nullTerm} = ${eval1.nullTerm} || ${eval2.nullTerm}; - ${ctx.primitiveForType(resultType)} ${ev.primitiveTerm} = ${ctx.defaultPrimitive(resultType)}; - if(!${ev.nullTerm}) { - ${ev.primitiveTerm} = (${ctx.primitiveForType(resultType)})($resultCode); - } - """ + s""" + ${eval1.code} + boolean ${ev.nullTerm} = ${eval1.nullTerm}; + ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)}; + if (!${ev.nullTerm}) { + ${eval2.code} + if(!${eval2.nullTerm}) { + ${ev.primitiveTerm} = (${ctx.primitiveType(dataType)})($resultCode); + } else { + ${ev.nullTerm} = true; + } + } + """ } } @@ -207,16 +207,15 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio self: Product => def castOrNull(ctx: CodeGenContext, ev: EvaluatedExpression, - f: String => String, dataType: DataType): String = { + f: String => String): String = { val eval = child.gen(ctx) - eval.code + - s""" - boolean ${ev.nullTerm} = ${eval.nullTerm}; - ${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultPrimitive(dataType)}; - if (!${ev.nullTerm}) { - ${ev.primitiveTerm} = ${f(eval.primitiveTerm)}; - } - """ + eval.code + s""" + boolean ${ev.nullTerm} = ${eval.nullTerm}; + ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)}; + if (!${ev.nullTerm}) { + ${ev.primitiveTerm} = ${f(eval.primitiveTerm)}; + } + """ } } diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 4320fbf51bd6d..79350dd3d65f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -221,8 +221,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic eval1.code + eval2.code + s""" boolean ${ev.nullTerm} = false; - ${ctx.primitiveForType(left.dataType)} ${ev.primitiveTerm} = - ${ctx.defaultPrimitive(left.dataType)}; + ${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = + ${ctx.defaultValue(left.dataType)}; if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) { ${ev.nullTerm} = true; } else { @@ -279,8 +279,8 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet eval1.code + eval2.code + s""" boolean ${ev.nullTerm} = false; - ${ctx.primitiveForType(left.dataType)} ${ev.primitiveTerm} = - ${ctx.defaultPrimitive(left.dataType)}; + ${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = + ${ctx.defaultValue(left.dataType)}; if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) { ${ev.nullTerm} = true; } else { @@ -412,8 +412,8 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { val eval2 = right.gen(ctx) eval1.code + eval2.code + s""" boolean ${ev.nullTerm} = false; - ${ctx.primitiveForType(left.dataType)} ${ev.primitiveTerm} = - ${ctx.defaultPrimitive(left.dataType)}; + ${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = + ${ctx.defaultValue(left.dataType)}; if (${eval1.nullTerm}) { ${ev.nullTerm} = ${eval2.nullTerm}; @@ -468,8 +468,8 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { eval1.code + eval2.code + s""" boolean ${ev.nullTerm} = false; - ${ctx.primitiveForType(left.dataType)} ${ev.primitiveTerm} = - ${ctx.defaultPrimitive(left.dataType)}; + ${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = + ${ctx.defaultValue(left.dataType)}; if (${eval1.nullTerm}) { ${ev.nullTerm} = ${eval2.nullTerm}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index bec1899a3aad2..4f21a1892df25 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -71,7 +71,7 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) { dataType match { case StringType => s"(org.apache.spark.sql.types.UTF8String)i.apply($ordinal)" case dt: DataType if isNativeType(dt) => s"i.${accessorForType(dt)}($ordinal)" - case _ => s"(${termForType(dataType)})i.apply($ordinal)" + case _ => s"(${boxedType(dataType)})i.apply($ordinal)" } } @@ -86,12 +86,12 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) { def accessorForType(dt: DataType): String = dt match { case IntegerType => "getInt" - case other => s"get${termForType(dt)}" + case other => s"get${boxedType(dt)}" } def mutatorForType(dt: DataType): String = dt match { case IntegerType => "setInt" - case other => s"set${termForType(dt)}" + case other => s"set${boxedType(dt)}" } def hashSetForType(dt: DataType): String = dt match { @@ -101,7 +101,10 @@ case class 
CodeGenContext(references: mutable.ArrayBuffer[Expression]) { sys.error(s"Code generation not support for hashset of type $unsupportedType") } - def primitiveForType(dt: DataType): String = dt match { + /** + * Return the primitive type for a DataType + */ + def primitiveType(dt: DataType): String = dt match { case IntegerType => "int" case LongType => "long" case ShortType => "short" @@ -117,7 +120,10 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) { case _ => "Object" } - def defaultPrimitive(dt: DataType): String = dt match { + /** + * Return the representation of default value for given DataType + */ + def defaultValue(dt: DataType): String = dt match { case BooleanType => "false" case FloatType => "-1.0f" case ShortType => "-1" @@ -131,7 +137,10 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) { case _ => "null" } - def termForType(dt: DataType): String = dt match { + /** + * Return the boxed type in Java + */ + def boxedType(dt: DataType): String = dt match { case IntegerType => "Integer" case LongType => "Long" case ShortType => "Short" @@ -147,6 +156,15 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) { case _ => "Object" } + /** + * Returns a function to generate equal expression in Java + */ + def equalFunc(dataType: DataType): ((String, String) => String) = dataType match { + case BinaryType => { case (eval1, eval2) => s"java.util.Arrays.equals($eval1, $eval2)" } + case dt if isNativeType(dt) => { case (eval1, eval2) => s"$eval1 == $eval2" } + case other => { case (eval1, eval2) => s"$eval1.equals($eval2)" } + } + /** * List of data types that have special accessors and setters in [[Row]]. */ @@ -166,7 +184,6 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) { */ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging { - protected val rowType = classOf[Row].getName protected val exprType = classOf[Expression].getName protected val mutableRowType = classOf[MutableRow].getName protected val genericMutableRowType = classOf[GenericMutableRow].getName diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 0e8ad76f65bad..00c856dc02ba1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -45,7 +45,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val ctx = newCodeGenContext() val columns = expressions.zipWithIndex.map { case (e, i) => - s"private ${ctx.primitiveForType(e.dataType)} c$i = ${ctx.defaultPrimitive(e.dataType)};\n" + s"private ${ctx.primitiveType(e.dataType)} c$i = ${ctx.defaultValue(e.dataType)};\n" }.mkString("\n ") val initColumns = expressions.zipWithIndex.map { @@ -68,7 +68,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { }.mkString("\n ") val updateCases = expressions.zipWithIndex.map { case (e, i) => - s"case $i: { c$i = (${ctx.termForType(e.dataType)})value; return;}" + s"case $i: { c$i = (${ctx.boxedType(e.dataType)})value; return;}" }.mkString("\n ") val specificAccessorFunctions = ctx.nativeTypes.map { dataType => @@ -80,14 +80,14 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { if (cases.count(_ 
!= '\n') > 0) { s""" @Override - public ${ctx.primitiveForType(dataType)} ${ctx.accessorForType(dataType)}(int i) { + public ${ctx.primitiveType(dataType)} ${ctx.accessorForType(dataType)}(int i) { if (isNullAt(i)) { - return ${ctx.defaultPrimitive(dataType)}; + return ${ctx.defaultValue(dataType)}; } switch (i) { $cases } - return ${ctx.defaultPrimitive(dataType)}; + return ${ctx.defaultValue(dataType)}; }""" } else { "" @@ -103,7 +103,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { if (cases.count(_ != '\n') > 0) { s""" @Override - public void ${ctx.mutatorForType(dataType)}(int i, ${ctx.primitiveForType(dataType)} value) { + public void ${ctx.mutatorForType(dataType)}(int i, ${ctx.primitiveType(dataType)} value) { nullBits[i] = false; switch (i) { $cases diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index 76273a5b7ee68..68daea725cd40 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -68,7 +68,7 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un eval.code + s""" boolean ${ev.nullTerm} = ${eval.nullTerm}; org.apache.spark.sql.types.Decimal ${ev.primitiveTerm} = - ${ctx.defaultPrimitive(DecimalType())}; + ${ctx.defaultValue(DecimalType())}; if (!${ev.nullTerm}) { ${ev.primitiveTerm} = new org.apache.spark.sql.types.Decimal(); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index d9fbda9511a5e..366e1083eb687 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -85,7 +85,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres if (value == null) { s""" final boolean ${ev.nullTerm} = true; - ${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultPrimitive(dataType)}; + ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)}; """ } else { dataType match { @@ -93,25 +93,25 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres val v = value.asInstanceOf[UTF8String] val arr = s"new byte[]{${v.getBytes.map(_.toString).mkString(", ")}}" s""" - final boolean ${ev.nullTerm} = false; - org.apache.spark.sql.types.UTF8String ${ev.primitiveTerm} = - new org.apache.spark.sql.types.UTF8String().set(${arr}); - """ + final boolean ${ev.nullTerm} = false; + org.apache.spark.sql.types.UTF8String ${ev.primitiveTerm} = + new org.apache.spark.sql.types.UTF8String().set(${arr}); + """ case FloatType => s""" - final boolean ${ev.nullTerm} = false; - float ${ev.primitiveTerm} = ${value}f; - """ + final boolean ${ev.nullTerm} = false; + float ${ev.primitiveTerm} = ${value}f; + """ case dt: DecimalType => s""" - final boolean ${ev.nullTerm} = false; - ${ctx.primitiveForType(dt)} ${ev.primitiveTerm} = new ${ctx.primitiveForType(dt)}().set($value); - """ + final boolean ${ev.nullTerm} = false; + ${ctx.primitiveType(dt)} ${ev.primitiveTerm} = new ${ctx.primitiveType(dt)}().set($value); + """ case dt: NumericType => s""" - final boolean ${ev.nullTerm} = false; - 
${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = $value; - """ + final boolean ${ev.nullTerm} = false; + ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = $value; + """ case other => super.genSource(ctx, ev) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 2af0f96146c1f..79c97f651f540 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -56,7 +56,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { s""" boolean ${ev.nullTerm} = true; - ${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultPrimitive(dataType)}; + ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)}; """ + children.map { e => val eval = e.gen(ctx) @@ -131,4 +131,25 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate } numNonNulls >= n } + + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + val nonnull = ctx.freshName("nonnull") + val code = children.map { e => + val eval = e.gen(ctx) + s""" + if($nonnull < $n) { + ${eval.code} + if(!${eval.nullTerm}) { + $nonnull += 1; + } + } + """ + }.mkString("\n") + s""" + int $nonnull = 0; + $code + boolean ${ev.nullTerm} = false; + boolean ${ev.primitiveTerm} = $nonnull >= $n; + """ + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index b6b2c7db28960..3c1eeb07a91a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -85,8 +85,7 @@ case class Not(child: Expression) extends UnaryExpression with Predicate with Ex } override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { - // Uh, bad function name... 
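The AtLeastNNonNulls code above folds its children into a counting loop that skips further evaluation once the threshold is reached. Below is a self-contained sketch of that shape; the counter name and the (code, isNull) pairs are illustrative placeholders for gen(ctx) results, not the patch's actual API:

object AtLeastNNonNullsSketch {
  // children: (code that evaluates the child, name of its null flag)
  def genAtLeastN(n: Int, children: Seq[(String, String)]): String = {
    val counter = "nonnull0" // the real generator draws this from freshName("nonnull")
    val checks = children.map { case (code, isNull) =>
      s"""
         |if ($counter < $n) {
         |  $code
         |  if (!$isNull) {
         |    $counter += 1;
         |  }
         |}""".stripMargin
    }.mkString("\n")
    s"""
       |int $counter = 0;
       |$checks
       |final boolean isNull = false;
       |final boolean value = $counter >= $n;""".stripMargin
  }

  def main(args: Array[String]): Unit = {
    println(genAtLeastN(2, Seq(
      ("final boolean isNull1 = i.isNullAt(0);", "isNull1"),
      ("final boolean isNull2 = i.isNullAt(1);", "isNull2"))))
  }
}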
- castOrNull(ctx, ev, c => s"!($c)", BooleanType) + castOrNull(ctx, ev, c => s"!($c)") } } @@ -221,13 +220,14 @@ abstract class BinaryComparison extends BinaryExpression with Predicate { self: Product => override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { left.dataType match { - case dt: NumericType => evaluateAs(BooleanType) (ctx, ev, { - (eval1, eval2) => s"$eval1 $symbol $eval2" + case dt: NumericType if ctx.isNativeType(dt) => evaluate (ctx, ev, { + (c1, c3) => s"$c1 $symbol $c3" }) - case dt: TimestampType => + case TimestampType => + // java.sql.Timestamp does not have compare() super.genSource(ctx, ev) - case other => evaluateAs(BooleanType) (ctx, ev, { - (eval1, eval2) => s"$eval1.compare($eval2) $symbol 0" + case other => evaluate (ctx, ev, { + (c1, c2) => s"$c1.compare($c2) $symbol 0" }) } } @@ -277,15 +277,7 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) } override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression) = { - left.dataType match { - case BinaryType() => - evaluateAs (BooleanType) (ctx, ev, { - case (eval1, eval2) => - s"java.util.Arrays.equals((byte[])$eval1, (byte[])$eval2)" - }) - case other => - evaluateAs (BooleanType) (ctx, ev, { case (eval1, eval2) => s"$eval1 == $eval2" }) - } + evaluate(ctx, ev, ctx.equalFunc(left.dataType)) } } @@ -311,16 +303,11 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) - val cmpCode = if (left.dataType.isInstanceOf[BinaryType]) { - s"java.util.Arrays.equals((byte[])${eval1.primitiveTerm}, (byte[])${eval2.primitiveTerm})" - } else { - s"${eval1.primitiveTerm} == ${eval2.primitiveTerm}" - } - eval1.code + eval2.code + - s""" + val equalCode = ctx.equalFunc(left.dataType)(eval1.primitiveTerm, eval2.primitiveTerm) + eval1.code + eval2.code + s""" final boolean ${ev.nullTerm} = false; final boolean ${ev.primitiveTerm} = (${eval1.nullTerm} && ${eval2.nullTerm}) || - (!${eval1.nullTerm} && $cmpCode); + (!${eval1.nullTerm} && $equalCode); """ } } @@ -403,7 +390,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi s""" boolean ${ev.nullTerm} = false; - ${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultPrimitive(dataType)}; + ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)}; ${condEval.code} if(!${condEval.nullTerm} && ${condEval.primitiveTerm}) { ${trueEval.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index e6ae81c2aad52..22755b6ecb7e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -62,11 +62,15 @@ case class NewSet(elementType: DataType) extends LeafExpression { } override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { - s""" - boolean ${ev.nullTerm} = false; - ${ctx.hashSetForType(elementType)} ${ev.primitiveTerm} = - new ${ctx.hashSetForType(elementType)}(); - """ + elementType match { + case IntegerType | LongType => + s""" + boolean ${ev.nullTerm} = false; + ${ctx.hashSetForType(elementType)} ${ev.primitiveTerm} = + new 
${ctx.hashSetForType(elementType)}(); + """ + case _ => super.genSource(ctx, ev) + } } override def toString: String = s"new Set($dataType)" @@ -101,20 +105,22 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { } override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { - val itemEval = item.gen(ctx) - val setEval = set.gen(ctx) - val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType - val htype = ctx.hashSetForType(elementType) - - itemEval.code + setEval.code + - s""" - if (!${itemEval.nullTerm} && !${setEval.nullTerm}) { - (($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm}); - } - boolean ${ev.nullTerm} = false; - ${htype} ${ev.primitiveTerm} = ($htype)${setEval.primitiveTerm}; - """ + elementType match { + case IntegerType | LongType => + val itemEval = item.gen(ctx) + val setEval = set.gen(ctx) + val htype = ctx.hashSetForType(elementType) + + itemEval.code + setEval.code + s""" + if (!${itemEval.nullTerm} && !${setEval.nullTerm}) { + (($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm}); + } + boolean ${ev.nullTerm} = false; + ${htype} ${ev.primitiveTerm} = ($htype)${setEval.primitiveTerm}; + """ + case _ => super.genSource(ctx, ev) + } } override def toString: String = s"$set += $item" @@ -152,19 +158,20 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres } override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { - val leftEval = left.gen(ctx) - val rightEval = right.gen(ctx) - val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType - val htype = ctx.hashSetForType(elementType) - - leftEval.code + rightEval.code + - s""" - boolean ${ev.nullTerm} = false; - ${htype} ${ev.primitiveTerm} = - (${htype})${leftEval.primitiveTerm}; - ${ev.primitiveTerm}.union((${htype})${rightEval.primitiveTerm}); - """ + elementType match { + case IntegerType | LongType => + val leftEval = left.gen(ctx) + val rightEval = right.gen(ctx) + val htype = ctx.hashSetForType(elementType) + + leftEval.code + rightEval.code + s""" + boolean ${ev.nullTerm} = false; + ${htype} ${ev.primitiveTerm} = ${leftEval.primitiveTerm}; + ${ev.primitiveTerm}.union(${rightEval.primitiveTerm}); + """ + case _ => super.genSource(ctx, ev) + } } } @@ -184,5 +191,9 @@ case class CountSet(child: Expression) extends UnaryExpression { } } + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + castOrNull(ctx, ev, c => s"$c.size().toLong()") + } + override def toString: String = s"$child.count()" } From e57959d60bb841851623898790a5cb1cba314cdd Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 4 Jun 2015 17:00:17 -0700 Subject: [PATCH 03/18] add type alias --- .../catalyst/expressions/BoundAttribute.scala | 4 +- .../spark/sql/catalyst/expressions/Cast.scala | 6 +-- .../sql/catalyst/expressions/Expression.scala | 20 ++------- .../sql/catalyst/expressions/arithmetic.scala | 16 +++---- .../expressions/codegen/CodeGenerator.scala | 42 +++++++++++-------- .../expressions/codegen/package.scala | 3 ++ .../expressions/decimalFunctions.scala | 6 +-- .../sql/catalyst/expressions/literals.scala | 6 +-- .../catalyst/expressions/nullFunctions.scala | 8 ++-- .../sql/catalyst/expressions/predicates.scala | 18 ++++---- .../spark/sql/catalyst/expressions/sets.scala | 14 +++---- 11 files changed, 69 insertions(+), 74 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 1d7f3b766a160..5978d1c931f37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.Logging import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, Code, CodeGenContext} import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.trees @@ -43,7 +43,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def exprId: ExprId = throw new UnsupportedOperationException - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { s""" final boolean ${ev.nullTerm} = i.isNullAt($ordinal); final ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ev.nullTerm} ? diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index bf8642cdde535..bcd7781c09e00 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, Code, CodeGenContext} import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ @@ -435,7 +435,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w if (evaluated == null) null else cast(evaluated) } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = this match { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = this match { case Cast(child @ BinaryType(), StringType) => castOrNull (ctx, ev, c => @@ -465,7 +465,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w s"new org.apache.spark.sql.types.UTF8String().set(String.valueOf($c))") case other => - super.genSource(ctx, ev) + super.genCode(ctx, ev) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 9b89a4bc744c3..6efa08626795e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, Code, CodeGenContext} import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.trees.TreeNode 
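The Term and Code names this patch introduces are plain String aliases, so existing string interpolation keeps working and the aliases only document whether a given String is a generated variable name or a fragment of generated Java. A small self-contained sketch of that idea (the helper below is illustrative, not code from the patch):

object CodegenAliasSketch {
  type Term = String // a generated variable name, e.g. "primitiveTerm4"
  type Code = String // a fragment of generated Java source

  // Emits Java that assigns `value` to `target` only when the null flag is clear.
  def nullSafeAssign(target: Term, isNull: Term, value: Term): Code =
    s"if (!$isNull) { $target = $value; }"

  def main(args: Array[String]): Unit =
    println(nullSafeAssign("c0", "nullTerm1", "primitiveTerm2"))
}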
import org.apache.spark.sql.types._ @@ -62,28 +62,14 @@ abstract class Expression extends TreeNode[Expression] { val primitiveTerm = ctx.freshName("primitiveTerm") val objectTerm = ctx.freshName("objectTerm") val ve = EvaluatedExpression("", nullTerm, primitiveTerm, objectTerm) - ve.code = genSource(ctx, ve) - - // Only inject debugging code if debugging is turned on. - // val debugCode = - // if (debugLogging) { - // val localLogger = log - // val localLoggerTree = reify { localLogger } - // s""" - // $localLoggerTree.debug( - // ${this.toString} + ": " + (if (${ev.nullTerm}) "null" else ${ev.primitiveTerm}.toString)) - // """ - // } else { - // "" - // } - + ve.code = genCode(ctx, ve) ve } /** * Returns Java source code for this expression */ - def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { val e = this.asInstanceOf[Expression] ctx.references += e s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 79350dd3d65f2..6ae815e1d0096 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{Code, EvaluatedExpression, CodeGenContext} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -117,7 +117,7 @@ abstract class BinaryArithmetic extends BinaryExpression { } } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { if (left.dataType.isInstanceOf[DecimalType]) { evaluate(ctx, ev, { case (eval1, eval2) => s"$eval1.$decimalMethod($eval2)" } ) } else { @@ -205,7 +205,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val test = if (left.dataType.isInstanceOf[DecimalType]) { @@ -263,7 +263,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val test = if (left.dataType.isInstanceOf[DecimalType]) { @@ -406,7 +406,7 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { } } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { if (ctx.isNativeType(left.dataType)) { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) @@ -430,7 +430,7 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { } """ } else { - super.genSource(ctx, ev) + super.genCode(ctx, ev) } } override def toString: String = s"MaxOf($left, $right)" @@ -460,7 +460,7 @@ case class 
MinOf(left: Expression, right: Expression) extends BinaryArithmetic { } } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { if (ctx.isNativeType(left.dataType)) { val eval1 = left.gen(ctx) @@ -486,7 +486,7 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { } """ } else { - super.genSource(ctx, ev) + super.genCode(ctx, ev) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 4f21a1892df25..c87258c622664 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -41,16 +41,22 @@ class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long] * valid if `nullTerm` is set to `true`. * @param objectTerm A possibly boxed version of the result of evaluating this expression. */ -case class EvaluatedExpression(var code: String, - nullTerm: String, - primitiveTerm: String, - objectTerm: String) +case class EvaluatedExpression(var code: Code, + nullTerm: Term, + primitiveTerm: Term, + objectTerm: Term) /** - * A context for codegen - * @param references the expressions that don't support codegen + * A context for codegen, which is used to bookkeeping the expressions those are not supported + * by codegen, then they are evaluated directly. The unsupported expression is appended at the + * end of `references`, the position of it is kept in the code, used to access and evaluate it. */ -case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) { +class CodeGenContext { + + /** + * Holding all the expressions those do not support codegen, will be evaluated directly. + */ + val references: Seq[Expression] = new mutable.ArrayBuffer[Expression]() protected val stringType = classOf[UTF8String].getName protected val decimalType = classOf[Decimal].getName @@ -63,11 +69,11 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) { * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` * function.) 
*/ - def freshName(prefix: String): String = { + def freshName(prefix: String): Term = { s"$prefix${curId.getAndIncrement}" } - def getColumn(dataType: DataType, ordinal: Int): String = { + def getColumn(dataType: DataType, ordinal: Int): Code = { dataType match { case StringType => s"(org.apache.spark.sql.types.UTF8String)i.apply($ordinal)" case dt: DataType if isNativeType(dt) => s"i.${accessorForType(dt)}($ordinal)" @@ -75,7 +81,7 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) { } } - def setColumn(destinationRow: String, dataType: DataType, ordinal: Int, value: String): String = { + def setColumn(destinationRow: Term, dataType: DataType, ordinal: Int, value: Term): Code = { dataType match { case StringType => s"$destinationRow.update($ordinal, $value)" case dt: DataType if isNativeType(dt) => @@ -84,17 +90,17 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) { } } - def accessorForType(dt: DataType): String = dt match { + def accessorForType(dt: DataType): Term = dt match { case IntegerType => "getInt" case other => s"get${boxedType(dt)}" } - def mutatorForType(dt: DataType): String = dt match { + def mutatorForType(dt: DataType): Term = dt match { case IntegerType => "setInt" case other => s"set${boxedType(dt)}" } - def hashSetForType(dt: DataType): String = dt match { + def hashSetForType(dt: DataType): Term = dt match { case IntegerType => classOf[IntegerHashSet].getName case LongType => classOf[LongHashSet].getName case unsupportedType => @@ -104,7 +110,7 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) { /** * Return the primitive type for a DataType */ - def primitiveType(dt: DataType): String = dt match { + def primitiveType(dt: DataType): Term = dt match { case IntegerType => "int" case LongType => "long" case ShortType => "short" @@ -123,7 +129,7 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) { /** * Return the representation of default value for given DataType */ - def defaultValue(dt: DataType): String = dt match { + def defaultValue(dt: DataType): Term = dt match { case BooleanType => "false" case FloatType => "-1.0f" case ShortType => "-1" @@ -140,7 +146,7 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) { /** * Return the boxed type in Java */ - def boxedType(dt: DataType): String = dt match { + def boxedType(dt: DataType): Term = dt match { case IntegerType => "Integer" case LongType => "Long" case ShortType => "Short" @@ -159,7 +165,7 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) { /** * Returns a function to generate equal expression in Java */ - def equalFunc(dataType: DataType): ((String, String) => String) = dataType match { + def equalFunc(dataType: DataType): ((Term, Term) => Code) = dataType match { case BinaryType => { case (eval1, eval2) => s"java.util.Arrays.equals($eval1, $eval2)" } case dt if isNativeType(dt) => { case (eval1, eval2) => s"$eval1 == $eval2" } case other => { case (eval1, eval2) => s"$eval1.equals($eval2)" } @@ -257,6 +263,6 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin * expressions that don't support codegen */ def newCodeGenContext(): CodeGenContext = { - new CodeGenContext(new mutable.ArrayBuffer[Expression]()) + new CodeGenContext } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala index 
7f1b12cdd5800..6f9589d20445e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala @@ -27,6 +27,9 @@ import org.apache.spark.util.Utils */ package object codegen { + type Term = String + type Code = String + /** Canonicalizes an expression so those that differ only by names can reuse the same code. */ object ExpressionCanonicalizer extends rules.RuleExecutor[Expression] { val batches = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index 68daea725cd40..250fe00b174bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, Code, CodeGenContext} import org.apache.spark.sql.types._ /** Return the unscaled Long value of a Decimal, assuming it fits in a Long */ @@ -37,7 +37,7 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { } } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { val eval = child.gen(ctx) eval.code +s""" boolean ${ev.nullTerm} = ${eval.nullTerm}; @@ -63,7 +63,7 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un } } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { val eval = child.gen(ctx) eval.code + s""" boolean ${ev.nullTerm} = ${eval.nullTerm}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 366e1083eb687..159df36ececff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, EvaluatedExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{Code, CodeGenContext, EvaluatedExpression} import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ @@ -81,7 +81,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres override def eval(input: Row): Any = value - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { if (value == null) { s""" final boolean ${ev.nullTerm} = true; @@ -113,7 +113,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = $value; """ case other => - super.genSource(ctx, ev) + super.genCode(ctx, ev) } } } diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 79c97f651f540..46582173e93b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -53,7 +53,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { result } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { s""" boolean ${ev.nullTerm} = true; ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)}; @@ -81,7 +81,7 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr child.eval(input) == null } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { val eval = child.gen(ctx) eval.code + s""" final boolean ${ev.nullTerm} = false; @@ -101,7 +101,7 @@ case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[E child.eval(input) != null } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { val eval = child.gen(ctx) eval.code + s""" boolean ${ev.nullTerm} = false; @@ -132,7 +132,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate numNonNulls >= n } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { val nonnull = ctx.freshName("nonnull") val code = children.map { e => val eval = e.gen(ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 3c1eeb07a91a4..4cd8bff0f4d47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.util.TypeUtils -import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, Code, CodeGenContext} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types._ @@ -84,7 +84,7 @@ case class Not(child: Expression) extends UnaryExpression with Predicate with Ex } } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { castOrNull(ctx, ev, c => s"!($c)") } } @@ -146,7 +146,7 @@ case class And(left: Expression, right: Expression) } } } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) s""" @@ -192,7 +192,7 @@ case class Or(left: Expression, right: Expression) } } } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): 
String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) s""" @@ -218,14 +218,14 @@ case class Or(left: Expression, right: Expression) abstract class BinaryComparison extends BinaryExpression with Predicate { self: Product => - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { left.dataType match { case dt: NumericType if ctx.isNativeType(dt) => evaluate (ctx, ev, { (c1, c3) => s"$c1 $symbol $c3" }) case TimestampType => // java.sql.Timestamp does not have compare() - super.genSource(ctx, ev) + super.genCode(ctx, ev) case other => evaluate (ctx, ev, { (c1, c2) => s"$c1.compare($c2) $symbol 0" }) @@ -276,7 +276,7 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison if (left.dataType != BinaryType) l == r else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression) = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression) = { evaluate(ctx, ev, ctx.equalFunc(left.dataType)) } } @@ -300,7 +300,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val equalCode = ctx.equalFunc(left.dataType)(eval1.primitiveTerm, eval2.primitiveTerm) @@ -383,7 +383,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi falseValue.eval(input) } } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { val condEval = predicate.gen(ctx) val trueEval = trueValue.gen(ctx) val falseEval = falseValue.gen(ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index 22755b6ecb7e9..d62212d669276 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -61,7 +61,7 @@ case class NewSet(elementType: DataType) extends LeafExpression { new OpenHashSet[Any]() } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { elementType match { case IntegerType | LongType => s""" @@ -69,7 +69,7 @@ case class NewSet(elementType: DataType) extends LeafExpression { ${ctx.hashSetForType(elementType)} ${ev.primitiveTerm} = new ${ctx.hashSetForType(elementType)}(); """ - case _ => super.genSource(ctx, ev) + case _ => super.genCode(ctx, ev) } } @@ -104,7 +104,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { } } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType elementType match { case IntegerType | LongType => @@ -119,7 +119,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { boolean ${ev.nullTerm} = false; ${htype} 
${ev.primitiveTerm} = ($htype)${setEval.primitiveTerm}; """ - case _ => super.genSource(ctx, ev) + case _ => super.genCode(ctx, ev) } } @@ -157,7 +157,7 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres } } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType elementType match { case IntegerType | LongType => @@ -170,7 +170,7 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres ${htype} ${ev.primitiveTerm} = ${leftEval.primitiveTerm}; ${ev.primitiveTerm}.union(${rightEval.primitiveTerm}); """ - case _ => super.genSource(ctx, ev) + case _ => super.genCode(ctx, ev) } } } @@ -191,7 +191,7 @@ case class CountSet(child: Expression) extends UnaryExpression { } } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { castOrNull(ctx, ev, c => s"$c.size().toLong()") } From b1450476fb355699326aaedcc5f9bac2806e4cd5 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 4 Jun 2015 17:05:42 -0700 Subject: [PATCH 04/18] fix style --- .../spark/sql/catalyst/expressions/decimalFunctions.scala | 2 +- .../org/apache/spark/sql/catalyst/expressions/literals.scala | 3 ++- .../org/apache/spark/sql/catalyst/expressions/predicates.scala | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index 250fe00b174bf..d88cdc7dd2c12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -39,7 +39,7 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { val eval = child.gen(ctx) - eval.code +s""" + eval.code + s""" boolean ${ev.nullTerm} = ${eval.nullTerm}; long ${ev.primitiveTerm} = ${ev.nullTerm} ? 
-1 : ${eval.primitiveTerm}.toUnscaledLong(); """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 159df36ececff..5cb3f26e9dc50 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -105,7 +105,8 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres case dt: DecimalType => s""" final boolean ${ev.nullTerm} = false; - ${ctx.primitiveType(dt)} ${ev.primitiveTerm} = new ${ctx.primitiveType(dt)}().set($value); + ${ctx.primitiveType(dt)} ${ev.primitiveTerm} = + new ${ctx.primitiveType(dt)}().set($value); """ case dt: NumericType => s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 4cd8bff0f4d47..1a89f5bdb4dea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -276,7 +276,7 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison if (left.dataType != BinaryType) l == r else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) } - override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression) = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { evaluate(ctx, ev, ctx.equalFunc(left.dataType)) } } From 8c6d82d61fdc81755b8971fdaa73729617d6df2f Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 4 Jun 2015 17:09:05 -0700 Subject: [PATCH 05/18] update docs --- .../spark/sql/catalyst/expressions/Expression.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 6efa08626795e..f4f866331f569 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -55,7 +55,9 @@ abstract class Expression extends TreeNode[Expression] { /** * Returns an [[EvaluatedExpression]], which contains Java source code that * can be used to generate the result of evaluating the expression on an input row. + * * @param ctx a [[CodeGenContext]] + * @return [[EvaluatedExpression]] */ def gen(ctx: CodeGenContext): EvaluatedExpression = { val nullTerm = ctx.freshName("nullTerm") @@ -67,7 +69,11 @@ abstract class Expression extends TreeNode[Expression] { } /** - * Returns Java source code for this expression + * Returns Java source code for this expression. + * + * @param ctx a [[CodeGenContext]] + * @param ev an [[EvaluatedExpression]] with unique terms. 
+ * @return Java source code */ def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { val e = this.asInstanceOf[Expression] From c5fb5146ee476b4f4b70ac34dbca0cafd9e249dd Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 4 Jun 2015 17:18:40 -0700 Subject: [PATCH 06/18] rename --- .../catalyst/expressions/BoundAttribute.scala | 4 ++-- .../spark/sql/catalyst/expressions/Cast.scala | 10 +++++----- .../sql/catalyst/expressions/Expression.scala | 20 +++++++++---------- .../sql/catalyst/expressions/arithmetic.scala | 12 +++++------ .../expressions/codegen/CodeGenerator.scala | 14 ++++++------- .../expressions/decimalFunctions.scala | 6 +++--- .../sql/catalyst/expressions/literals.scala | 7 +++---- .../catalyst/expressions/nullFunctions.scala | 10 +++++----- .../sql/catalyst/expressions/predicates.scala | 16 +++++++-------- .../spark/sql/catalyst/expressions/sets.scala | 10 +++++----- 10 files changed, 54 insertions(+), 55 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 5978d1c931f37..478ee997a96a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.Logging import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, Code, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext} import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.trees @@ -43,7 +43,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def exprId: ExprId = throw new UnsupportedOperationException - override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { s""" final boolean ${ev.nullTerm} = i.isNullAt($ordinal); final ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ev.nullTerm} ? 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index bcd7781c09e00..d31e004b9c348 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, Code, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext} import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ @@ -435,15 +435,15 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w if (evaluated == null) null else cast(evaluated) } - override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = this match { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = this match { case Cast(child @ BinaryType(), StringType) => castOrNull (ctx, ev, c => - s"new org.apache.spark.sql.types.UTF8String().set($c)") + s"new ${ctx.stringType}().set($c)") case Cast(child @ DateType(), StringType) => castOrNull(ctx, ev, c => - s"""new org.apache.spark.sql.types.UTF8String().set( + s"""new ${ctx.stringType}().set( org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""") case Cast(child @ BooleanType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => @@ -462,7 +462,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // does not match the expected output. case Cast(e, StringType) if e.dataType != TimestampType => castOrNull(ctx, ev, c => - s"new org.apache.spark.sql.types.UTF8String().set(String.valueOf($c))") + s"new ${ctx.stringType}().set(String.valueOf($c))") case other => super.genCode(ctx, ev) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index f4f866331f569..1f1a2fc9694af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, Code, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext} import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types._ @@ -53,17 +53,17 @@ abstract class Expression extends TreeNode[Expression] { def eval(input: Row = null): Any /** - * Returns an [[EvaluatedExpression]], which contains Java source code that + * Returns an [[GeneratedExpressionCode]], which contains Java source code that * can be used to generate the result of evaluating the expression on an input row. 
- * + * * @param ctx a [[CodeGenContext]] - * @return [[EvaluatedExpression]] + * @return [[GeneratedExpressionCode]] */ - def gen(ctx: CodeGenContext): EvaluatedExpression = { + def gen(ctx: CodeGenContext): GeneratedExpressionCode = { val nullTerm = ctx.freshName("nullTerm") val primitiveTerm = ctx.freshName("primitiveTerm") val objectTerm = ctx.freshName("objectTerm") - val ve = EvaluatedExpression("", nullTerm, primitiveTerm, objectTerm) + val ve = GeneratedExpressionCode("", nullTerm, primitiveTerm, objectTerm) ve.code = genCode(ctx, ve) ve } @@ -72,10 +72,10 @@ abstract class Expression extends TreeNode[Expression] { * Returns Java source code for this expression. * * @param ctx a [[CodeGenContext]] - * @param ev an [[EvaluatedExpression]] with unique terms. + * @param ev an [[GeneratedExpressionCode]] with unique terms. * @return Java source code */ - def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { + def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val e = this.asInstanceOf[Expression] ctx.references += e s""" @@ -164,7 +164,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express * @param f a function from two primitive term names to a tree that evaluates them. */ def evaluate(ctx: CodeGenContext, - ev: EvaluatedExpression, + ev: GeneratedExpressionCode, f: (String, String) => String): String = { // TODO: Right now some timestamp tests fail if we enforce this... if (left.dataType != right.dataType) { @@ -198,7 +198,7 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { self: Product => def castOrNull(ctx: CodeGenContext, - ev: EvaluatedExpression, + ev: GeneratedExpressionCode, f: String => String): String = { val eval = child.gen(ctx) eval.code + s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 6ae815e1d0096..aad8479dafe41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{Code, EvaluatedExpression, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{Code, GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -117,7 +117,7 @@ abstract class BinaryArithmetic extends BinaryExpression { } } - override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { if (left.dataType.isInstanceOf[DecimalType]) { evaluate(ctx, ev, { case (eval1, eval2) => s"$eval1.$decimalMethod($eval2)" } ) } else { @@ -205,7 +205,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } } - override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val test = if (left.dataType.isInstanceOf[DecimalType]) { @@ -263,7 +263,7 @@ case class Remainder(left: Expression, right: 
Expression) extends BinaryArithmet } } - override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val test = if (left.dataType.isInstanceOf[DecimalType]) { @@ -406,7 +406,7 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { } } - override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { if (ctx.isNativeType(left.dataType)) { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) @@ -460,7 +460,7 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { } } - override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { if (ctx.isNativeType(left.dataType)) { val eval1 = left.gen(ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index c87258c622664..0a47957bec23c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -41,10 +41,10 @@ class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long] * valid if `nullTerm` is set to `true`. * @param objectTerm A possibly boxed version of the result of evaluating this expression. */ -case class EvaluatedExpression(var code: Code, - nullTerm: Term, - primitiveTerm: Term, - objectTerm: Term) +case class GeneratedExpressionCode(var code: Code, + nullTerm: Term, + primitiveTerm: Term, + objectTerm: Term) /** * A context for codegen, which is used to bookkeeping the expressions those are not supported @@ -58,8 +58,8 @@ class CodeGenContext { */ val references: Seq[Expression] = new mutable.ArrayBuffer[Expression]() - protected val stringType = classOf[UTF8String].getName - protected val decimalType = classOf[Decimal].getName + val stringType = classOf[UTF8String].getName + val decimalType = classOf[Decimal].getName private val curId = new java.util.concurrent.atomic.AtomicInteger() @@ -75,7 +75,7 @@ class CodeGenContext { def getColumn(dataType: DataType, ordinal: Int): Code = { dataType match { - case StringType => s"(org.apache.spark.sql.types.UTF8String)i.apply($ordinal)" + case StringType => s"($stringType)i.apply($ordinal)" case dt: DataType if isNativeType(dt) => s"i.${accessorForType(dt)}($ordinal)" case _ => s"(${boxedType(dataType)})i.apply($ordinal)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index d88cdc7dd2c12..80c51cb3588ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, Code, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext} import org.apache.spark.sql.types._ /** Return 
the unscaled Long value of a Decimal, assuming it fits in a Long */ @@ -37,7 +37,7 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { } } - override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val eval = child.gen(ctx) eval.code + s""" boolean ${ev.nullTerm} = ${eval.nullTerm}; @@ -63,7 +63,7 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un } } - override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val eval = child.gen(ctx) eval.code + s""" boolean ${ev.nullTerm} = ${eval.nullTerm}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 5cb3f26e9dc50..21e21000c9437 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.codegen.{Code, CodeGenContext, EvaluatedExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{Code, CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ @@ -81,7 +81,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres override def eval(input: Row): Any = value - override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { if (value == null) { s""" final boolean ${ev.nullTerm} = true; @@ -94,8 +94,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres val arr = s"new byte[]{${v.getBytes.map(_.toString).mkString(", ")}}" s""" final boolean ${ev.nullTerm} = false; - org.apache.spark.sql.types.UTF8String ${ev.primitiveTerm} = - new org.apache.spark.sql.types.UTF8String().set(${arr}); + ${ctx.stringType} ${ev.primitiveTerm} = new ${ctx.stringType}().set(${arr}); """ case FloatType => s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 46582173e93b0..d4b35edb33b4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.types.DataType @@ -53,7 +53,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { result } - override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { s""" boolean 
${ev.nullTerm} = true; ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)}; @@ -81,7 +81,7 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr child.eval(input) == null } - override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val eval = child.gen(ctx) eval.code + s""" final boolean ${ev.nullTerm} = false; @@ -101,7 +101,7 @@ case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[E child.eval(input) != null } - override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val eval = child.gen(ctx) eval.code + s""" boolean ${ev.nullTerm} = false; @@ -132,7 +132,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate numNonNulls >= n } - override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val nonnull = ctx.freshName("nonnull") val code = children.map { e => val eval = e.gen(ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 1a89f5bdb4dea..ad4535a09e04e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.util.TypeUtils -import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, Code, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types._ @@ -84,7 +84,7 @@ case class Not(child: Expression) extends UnaryExpression with Predicate with Ex } } - override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { castOrNull(ctx, ev, c => s"!($c)") } } @@ -146,7 +146,7 @@ case class And(left: Expression, right: Expression) } } } - override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) s""" @@ -192,7 +192,7 @@ case class Or(left: Expression, right: Expression) } } } - override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) s""" @@ -218,7 +218,7 @@ case class Or(left: Expression, right: Expression) abstract class BinaryComparison extends BinaryExpression with Predicate { self: Product => - override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { left.dataType match { case dt: NumericType if ctx.isNativeType(dt) => evaluate (ctx, ev, { (c1, c3) => s"$c1 $symbol $c3" @@ -276,7 +276,7 @@ case class EqualTo(left: Expression, right: Expression) extends 
BinaryComparison if (left.dataType != BinaryType) l == r else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) } - override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { evaluate(ctx, ev, ctx.equalFunc(left.dataType)) } } @@ -300,7 +300,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } } - override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val equalCode = ctx.equalFunc(left.dataType)(eval1.primitiveTerm, eval2.primitiveTerm) @@ -383,7 +383,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi falseValue.eval(input) } } - override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val condEval = predicate.gen(ctx) val trueEval = trueValue.gen(ctx) val falseEval = falseValue.gen(ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index d62212d669276..55fd748f96b12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet @@ -61,7 +61,7 @@ case class NewSet(elementType: DataType) extends LeafExpression { new OpenHashSet[Any]() } - override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { elementType match { case IntegerType | LongType => s""" @@ -104,7 +104,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { } } - override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType elementType match { case IntegerType | LongType => @@ -157,7 +157,7 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres } } - override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType elementType match { case IntegerType | LongType => @@ -191,7 +191,7 @@ case class CountSet(child: Expression) extends UnaryExpression { } } - override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { castOrNull(ctx, ev, c => s"$c.size().toLong()") } From 12ff88a4297871ec8a39bddedf0c377a46361c1b Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 4 Jun 2015 17:35:12 -0700 Subject: [PATCH 07/18] fix build --- .../spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 2 +- 
.../apache/spark/sql/catalyst/expressions/nullFunctions.scala | 2 +- .../scala/org/apache/spark/sql/catalyst/expressions/sets.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 0a47957bec23c..4885eec08fca9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -56,7 +56,7 @@ class CodeGenContext { /** * Holding all the expressions those do not support codegen, will be evaluated directly. */ - val references: Seq[Expression] = new mutable.ArrayBuffer[Expression]() + val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]() val stringType = classOf[UTF8String].getName val decimalType = classOf[Decimal].getName diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index d4b35edb33b4c..7b26bd2697195 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext} import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.types.DataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index 55fd748f96b12..46cad9b019584 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext} import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet From 2344bc0d48fc2a3ec91de69a6233665a0ae3635e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 4 Jun 2015 19:43:12 -0700 Subject: [PATCH 08/18] fix test --- .../expressions/codegen/CodeGenerator.scala | 17 +++++++++++++---- .../spark/sql/catalyst/expressions/sets.scala | 8 ++------ 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 4885eec08fca9..06cc6e1024b01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -166,9 +166,12 @@ class CodeGenContext { * Returns a function to generate equal expression in Java */ def 
equalFunc(dataType: DataType): ((Term, Term) => Code) = dataType match { - case BinaryType => { case (eval1, eval2) => s"java.util.Arrays.equals($eval1, $eval2)" } - case dt if isNativeType(dt) => { case (eval1, eval2) => s"$eval1 == $eval2" } - case other => { case (eval1, eval2) => s"$eval1.equals($eval2)" } + case BinaryType => { case (eval1, eval2) => + s"java.util.Arrays.equals($eval1, $eval2)" } + case IntegerType | BooleanType | LongType | DoubleType | FloatType | ShortType | ByteType => + { case (eval1, eval2) => s"$eval1 == $eval2" } + case other => + { case (eval1, eval2) => s"$eval1.equals($eval2)" } } /** @@ -221,7 +224,13 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin */ protected def compile(code: String): Class[_] = { val startTime = System.nanoTime() - val clazz = new ClassBodyEvaluator(code).getClazz() + val clazz = try { + new ClassBodyEvaluator(code).getClazz() + } catch { + case e: Exception => + logError(s"failed to compile:\n $code", e) + throw e + } val endTime = System.nanoTime() def timeMs: Double = (endTime - startTime).toDouble / 1000000 logDebug(s"Code (${code.size} bytes) compiled in $timeMs ms") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index 46cad9b019584..a0dae40d964e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -167,8 +167,8 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres leftEval.code + rightEval.code + s""" boolean ${ev.nullTerm} = false; - ${htype} ${ev.primitiveTerm} = ${leftEval.primitiveTerm}; - ${ev.primitiveTerm}.union(${rightEval.primitiveTerm}); + ${htype} ${ev.primitiveTerm} = (${htype})${leftEval.primitiveTerm}; + ${ev.primitiveTerm}.union((${htype})${rightEval.primitiveTerm}); """ case _ => super.genCode(ctx, ev) } @@ -191,9 +191,5 @@ case class CountSet(child: Expression) extends UnaryExpression { } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { - castOrNull(ctx, ev, c => s"$c.size().toLong()") - } - override def toString: String = s"$child.count()" } From 48c454ff529f045e6c24b6f73d85d09dd4b4279a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 5 Jun 2015 01:08:56 -0700 Subject: [PATCH 09/18] Some code gen update. 
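
Rework Cast.genCode to pattern match on (child.dataType, dataType) pairs and
cover more Decimal conversions, rename the binary/unary codegen helpers
(`evaluate` / `castOrNull`) to a protected `defineCodeGen` with Scaladoc, and
tidy the whitespace of the generated Java (`if(` becomes `if (`).

For reference, a rough, self-contained sketch of the shape of the unary
helper; the types and term names below are illustrative stand-ins, not the
real Catalyst classes:

    object DefineCodeGenSketch {
      // Simplified stand-in for codegen.GeneratedExpressionCode.
      case class GenCode(var code: String, nullTerm: String, primitiveTerm: String)

      // Mirrors the shape of UnaryExpression.defineCodeGen: emit the child's code,
      // then apply `f` to the child's primitive term only when it is not null.
      def defineCodeGen(child: GenCode, ev: GenCode, javaType: String, default: String)
          (f: String => String): String = {
        child.code + s"""
          boolean ${ev.nullTerm} = ${child.nullTerm};
          $javaType ${ev.primitiveTerm} = $default;
          if (!${ev.nullTerm}) {
            ${ev.primitiveTerm} = ${f(child.primitiveTerm)};
          }
        """
      }

      def main(args: Array[String]): Unit = {
        // Pretend the child already produced these terms (names are made up).
        val child = GenCode(
          "boolean nullTerm1 = i.isNullAt(0);\nboolean primitiveTerm1 = i.getBoolean(0);\n",
          "nullTerm1", "primitiveTerm1")
        val ev = GenCode("", "nullTerm2", "primitiveTerm2")
        // A NOT expression only has to supply the one-liner `c => s"!($c)"`.
        println(defineCodeGen(child, ev, "boolean", "false")(c => s"!($c)"))
      }
    }
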
--- .../spark/sql/catalyst/expressions/Cast.scala | 64 ++++++++++++------- .../sql/catalyst/expressions/Expression.scala | 41 ++++++++---- .../sql/catalyst/expressions/arithmetic.scala | 19 ++++-- .../expressions/decimalFunctions.scala | 11 ++-- .../sql/catalyst/expressions/literals.scala | 5 +- .../catalyst/expressions/nullFunctions.scala | 8 +-- .../sql/catalyst/expressions/predicates.scala | 10 +-- 7 files changed, 99 insertions(+), 59 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d31e004b9c348..634750dca2158 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -435,37 +435,57 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w if (evaluated == null) null else cast(evaluated) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = this match { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + // TODO(cg): Add support for more data types. + (child.dataType, dataType) match { - case Cast(child @ BinaryType(), StringType) => - castOrNull (ctx, ev, c => - s"new ${ctx.stringType}().set($c)") + case (BinaryType, StringType) => + defineCodeGen (ctx, ev, c => + s"new ${ctx.stringType}().set($c)") - case Cast(child @ DateType(), StringType) => - castOrNull(ctx, ev, c => - s"""new ${ctx.stringType}().set( + case (DateType, StringType) => + defineCodeGen(ctx, ev, c => + s"""new ${ctx.stringType}().set( org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""") - case Cast(child @ BooleanType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => - castOrNull(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c?1:0)") + case (BooleanType, dt: NumericType) if !dt.isInstanceOf[DecimalType] => + defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c ? 1 : 0)") - case Cast(child @ DecimalType(), IntegerType) => - castOrNull(ctx, ev, c => s"($c).toInt()") + case (_: NumericType, dt: NumericType) if !dt.isInstanceOf[DecimalType] => + defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c)") - case Cast(child @ DecimalType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => - castOrNull(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()") + case (_: DecimalType, ByteType) => + defineCodeGen(ctx, ev, c => s"($c).toByte()") - case Cast(child @ NumericType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => - castOrNull(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c)") + case (_: DecimalType, ShortType) => + defineCodeGen(ctx, ev, c => s"($c).toShort()") - // Special handling required for timestamps in hive test cases since the toString function - // does not match the expected output. 
- case Cast(e, StringType) if e.dataType != TimestampType => - castOrNull(ctx, ev, c => - s"new ${ctx.stringType}().set(String.valueOf($c))") + case (_: DecimalType, IntegerType) => + defineCodeGen(ctx, ev, c => s"($c).toInt()") - case other => - super.genCode(ctx, ev) + case (_: DecimalType, LongType) => + defineCodeGen(ctx, ev, c => s"($c).toLong()") + + case (_: DecimalType, FloatType) => + defineCodeGen(ctx, ev, c => s"($c).toFloat()") + + case (_: DecimalType, DoubleType) => + defineCodeGen(ctx, ev, c => s"($c).toDouble()") + + case (_: DecimalType, dt: NumericType) if !dt.isInstanceOf[DecimalType] => + defineCodeGen(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()") + + // Special handling required for timestamps in hive test cases since the toString function + // does not match the expected output. + case (TimestampType, StringType) => + super.genCode(ctx, ev) + + case (_, StringType) => + defineCodeGen(ctx, ev, c => s"new ${ctx.stringType}().set(String.valueOf($c))") + + case other => + super.genCode(ctx, ev) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 1f1a2fc9694af..db085c8c277ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -69,7 +69,9 @@ abstract class Expression extends TreeNode[Expression] { } /** - * Returns Java source code for this expression. + * Returns Java source code that can be compiled to evaluate this expression. + * The default behavior is to call the eval method of the expression. Concrete expression + * implementations should override this to do actual code generation. * * @param ctx a [[CodeGenContext]] * @param ev an [[GeneratedExpressionCode]] with unique terms. @@ -82,10 +84,10 @@ abstract class Expression extends TreeNode[Expression] { /* expression: ${this} */ Object ${ev.objectTerm} = expressions[${ctx.references.size - 1}].eval(i); boolean ${ev.nullTerm} = ${ev.objectTerm} == null; - ${ctx.primitiveType(e.dataType)} ${ev.primitiveTerm} = - ${ctx.defaultValue(e.dataType)}; - if (!${ev.nullTerm}) ${ev.primitiveTerm} = - (${ctx.boxedType(e.dataType)})${ev.objectTerm}; + ${ctx.primitiveType(e.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(e.dataType)}; + if (!${ev.nullTerm}) { + ${ev.primitiveTerm} = (${ctx.boxedType(e.dataType)})${ev.objectTerm}; + } """ } @@ -155,17 +157,17 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express override def toString: String = s"($left $symbol $right)" - /** * Short hand for generating binary evaluation code, which depends on two sub-evaluations of * the same type. If either of the sub-expressions is null, the result of this computation * is assumed to be null. * - * @param f a function from two primitive term names to a tree that evaluates them. + * @param f accepts two variable names and returns Java code to compute the output. */ - def evaluate(ctx: CodeGenContext, - ev: GeneratedExpressionCode, - f: (String, String) => String): String = { + protected def defineCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String) => String): String = { // TODO: Right now some timestamp tests fail if we enforce this... 
if (left.dataType != right.dataType) { // log.warn(s"${left.dataType} != ${right.dataType}") @@ -197,9 +199,22 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { self: Product => - def castOrNull(ctx: CodeGenContext, - ev: GeneratedExpressionCode, - f: String => String): String = { + + /** + * Called by unary expressions to generate a code block that returns null if its parent returns + * null, and if not not null, use `f` to generate the expression. + * + * As an example, the following does a boolean inversion (i.e. NOT). + * {{{ + * defineCodeGen(ctx, ev, c => s"!($c)") + * }}} + * + * @param f function that accepts a variable name and returns Java code to compute the output. + */ + protected def defineCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: String => String): String = { val eval = child.gen(ctx) eval.code + s""" boolean ${ev.nullTerm} = ${eval.nullTerm}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index aad8479dafe41..a049f8878ed32 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -87,6 +87,7 @@ case class Abs(child: Expression) extends UnaryArithmetic { abstract class BinaryArithmetic extends BinaryExpression { self: Product => + /** Name of the function for this expression on a [[Decimal]] type. */ def decimalMethod: String = "" override def dataType: DataType = left.dataType @@ -119,9 +120,9 @@ abstract class BinaryArithmetic extends BinaryExpression { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { if (left.dataType.isInstanceOf[DecimalType]) { - evaluate(ctx, ev, { case (eval1, eval2) => s"$eval1.$decimalMethod($eval2)" } ) + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") } else { - evaluate(ctx, ev, { case (eval1, eval2) => s"$eval1 $symbol $eval2" } ) + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") } } @@ -205,6 +206,9 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } } + /** + * Special case handling due to division by 0 => null. + */ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) @@ -221,8 +225,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic eval1.code + eval2.code + s""" boolean ${ev.nullTerm} = false; - ${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = - ${ctx.defaultValue(left.dataType)}; + ${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(left.dataType)}; if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) { ${ev.nullTerm} = true; } else { @@ -263,6 +266,9 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } } + /** + * Special case handling for x % 0 ==> null. 
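+   * The generated code tests the divisor explicitly so that a zero divisor yields null
+   * rather than a runtime error.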
+ */ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) @@ -279,8 +285,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet eval1.code + eval2.code + s""" boolean ${ev.nullTerm} = false; - ${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = - ${ctx.defaultValue(left.dataType)}; + ${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(left.dataType)}; if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) { ${ev.nullTerm} = true; } else { @@ -337,7 +342,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet } /** - * A function that calculates bitwise xor(^) of two numbers. + * A function that calculates bitwise xor of two numbers. */ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "^" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index 80c51cb3588ad..f1d8313b5f1dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -67,14 +67,13 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un val eval = child.gen(ctx) eval.code + s""" boolean ${ev.nullTerm} = ${eval.nullTerm}; - org.apache.spark.sql.types.Decimal ${ev.primitiveTerm} = - ${ctx.defaultValue(DecimalType())}; + org.apache.spark.sql.types.Decimal ${ev.primitiveTerm} = ${ctx.defaultValue(DecimalType())}; if (!${ev.nullTerm}) { - ${ev.primitiveTerm} = new org.apache.spark.sql.types.Decimal(); - ${ev.primitiveTerm} = - ${ev.primitiveTerm}.setOrNull(${eval.primitiveTerm}, $precision, $scale); - ${ev.nullTerm} = ${ev.primitiveTerm} == null; + ${ev.primitiveTerm} = new org.apache.spark.sql.types.Decimal(); + ${ev.primitiveTerm} = + ${ev.primitiveTerm}.setOrNull(${eval.primitiveTerm}, $precision, $scale); + ${ev.nullTerm} = ${ev.primitiveTerm} == null; } """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 21e21000c9437..1899c47613aae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -88,6 +88,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)}; """ } else { + // TODO(cg): Add support for more data types. 
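+      // For example, Literal(1.5f, FloatType) generates roughly (term names are illustrative):
+      //   final boolean nullTerm1 = false;
+      //   float primitiveTerm1 = 1.5f;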
dataType match { case StringType => val v = value.asInstanceOf[UTF8String] @@ -96,12 +97,12 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres final boolean ${ev.nullTerm} = false; ${ctx.stringType} ${ev.primitiveTerm} = new ${ctx.stringType}().set(${arr}); """ - case FloatType => + case FloatType => // This must go before NumericType s""" final boolean ${ev.nullTerm} = false; float ${ev.primitiveTerm} = ${value}f; """ - case dt: DecimalType => + case dt: DecimalType => // This must go before NumericType s""" final boolean ${ev.nullTerm} = false; ${ctx.primitiveType(dt)} ${ev.primitiveTerm} = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 7b26bd2697195..e380eafc3fc2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -61,9 +61,9 @@ case class Coalesce(children: Seq[Expression]) extends Expression { children.map { e => val eval = e.gen(ctx) s""" - if(${ev.nullTerm}) { + if (${ev.nullTerm}) { ${eval.code} - if(!${eval.nullTerm}) { + if (!${eval.nullTerm}) { ${ev.nullTerm} = false; ${ev.primitiveTerm} = ${eval.primitiveTerm}; } @@ -137,9 +137,9 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate val code = children.map { e => val eval = e.gen(ctx) s""" - if($nonnull < $n) { + if ($nonnull < $n) { ${eval.code} - if(!${eval.nullTerm}) { + if (!${eval.nullTerm}) { $nonnull += 1; } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index ad4535a09e04e..67cac26fd0d55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -85,7 +85,7 @@ case class Not(child: Expression) extends UnaryExpression with Predicate with Ex } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { - castOrNull(ctx, ev, c => s"!($c)") + defineCodeGen(ctx, ev, c => s"!($c)") } } @@ -220,13 +220,13 @@ abstract class BinaryComparison extends BinaryExpression with Predicate { self: Product => override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { left.dataType match { - case dt: NumericType if ctx.isNativeType(dt) => evaluate (ctx, ev, { + case dt: NumericType if ctx.isNativeType(dt) => defineCodeGen (ctx, ev, { (c1, c3) => s"$c1 $symbol $c3" }) case TimestampType => // java.sql.Timestamp does not have compare() super.genCode(ctx, ev) - case other => evaluate (ctx, ev, { + case other => defineCodeGen (ctx, ev, { (c1, c2) => s"$c1.compare($c2) $symbol 0" }) } @@ -277,7 +277,7 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { - evaluate(ctx, ev, ctx.equalFunc(left.dataType)) + defineCodeGen(ctx, ev, ctx.equalFunc(left.dataType)) } } @@ -392,7 +392,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi boolean ${ev.nullTerm} = false; ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = 
${ctx.defaultValue(dataType)}; ${condEval.code} - if(!${condEval.nullTerm} && ${condEval.primitiveTerm}) { + if (!${condEval.nullTerm} && ${condEval.primitiveTerm}) { ${trueEval.code} ${ev.nullTerm} = ${trueEval.nullTerm}; ${ev.primitiveTerm} = ${trueEval.primitiveTerm}; From 02262c91747bcd12579ebc7784d5047f6b73a268 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 5 Jun 2015 13:04:01 -0700 Subject: [PATCH 10/18] address comments --- .../spark/sql/catalyst/expressions/Cast.scala | 48 ++++------- .../sql/catalyst/expressions/Expression.scala | 29 +++---- .../expressions/codegen/CodeGenerator.scala | 82 +++++++++---------- .../codegen/GenerateMutableProjection.scala | 2 +- .../expressions/decimalFunctions.scala | 7 +- .../sql/catalyst/expressions/literals.scala | 23 ++---- .../spark/sql/catalyst/expressions/sets.scala | 17 ++-- 7 files changed, 89 insertions(+), 119 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 634750dca2158..a6f805da242ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -442,47 +442,35 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case (BinaryType, StringType) => defineCodeGen (ctx, ev, c => s"new ${ctx.stringType}().set($c)") - case (DateType, StringType) => defineCodeGen(ctx, ev, c => s"""new ${ctx.stringType}().set( org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""") - - case (BooleanType, dt: NumericType) if !dt.isInstanceOf[DecimalType] => - defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c ? 1 : 0)") - - case (_: NumericType, dt: NumericType) if !dt.isInstanceOf[DecimalType] => - defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c)") - - case (_: DecimalType, ByteType) => - defineCodeGen(ctx, ev, c => s"($c).toByte()") - - case (_: DecimalType, ShortType) => - defineCodeGen(ctx, ev, c => s"($c).toShort()") - - case (_: DecimalType, IntegerType) => - defineCodeGen(ctx, ev, c => s"($c).toInt()") - - case (_: DecimalType, LongType) => - defineCodeGen(ctx, ev, c => s"($c).toLong()") - - case (_: DecimalType, FloatType) => - defineCodeGen(ctx, ev, c => s"($c).toFloat()") - - case (_: DecimalType, DoubleType) => - defineCodeGen(ctx, ev, c => s"($c).toDouble()") - - case (_: DecimalType, dt: NumericType) if !dt.isInstanceOf[DecimalType] => - defineCodeGen(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()") - // Special handling required for timestamps in hive test cases since the toString function // does not match the expected output. case (TimestampType, StringType) => super.genCode(ctx, ev) - case (_, StringType) => defineCodeGen(ctx, ev, c => s"new ${ctx.stringType}().set(String.valueOf($c))") + // fallback for DecimalType, this must be before other numeric types + case (_, dt: DecimalType) => + super.genCode(ctx, ev) + + case (BooleanType, dt: NumericType) => + defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c ? 
1 : 0)") + case (dt: DecimalType, BooleanType) => + defineCodeGen(ctx, ev, c => s"$c.isZero()") + case (dt: NumericType, BooleanType) => + defineCodeGen(ctx, ev, c => s"$c != 0") + + case (_: DecimalType, IntegerType) => + defineCodeGen(ctx, ev, c => s"($c).toInt()") + case (_: DecimalType, dt: NumericType) => + defineCodeGen(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()") + case (_: NumericType, dt: NumericType) => + defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c)") + case other => super.genCode(ctx, ev) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index db085c8c277ea..2df6737adb42b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext, Term} import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types._ @@ -62,8 +62,7 @@ abstract class Expression extends TreeNode[Expression] { def gen(ctx: CodeGenContext): GeneratedExpressionCode = { val nullTerm = ctx.freshName("nullTerm") val primitiveTerm = ctx.freshName("primitiveTerm") - val objectTerm = ctx.freshName("objectTerm") - val ve = GeneratedExpressionCode("", nullTerm, primitiveTerm, objectTerm) + val ve = GeneratedExpressionCode("", nullTerm, primitiveTerm) ve.code = genCode(ctx, ve) ve } @@ -77,17 +76,18 @@ abstract class Expression extends TreeNode[Expression] { * @param ev an [[GeneratedExpressionCode]] with unique terms. * @return Java source code */ - def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val e = this.asInstanceOf[Expression] ctx.references += e + val objectTerm = ctx.freshName("obj") s""" - /* expression: ${this} */ - Object ${ev.objectTerm} = expressions[${ctx.references.size - 1}].eval(i); - boolean ${ev.nullTerm} = ${ev.objectTerm} == null; - ${ctx.primitiveType(e.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(e.dataType)}; - if (!${ev.nullTerm}) { - ${ev.primitiveTerm} = (${ctx.boxedType(e.dataType)})${ev.objectTerm}; - } + /* expression: ${this} */ + final Object ${objectTerm} = expressions[${ctx.references.size - 1}].eval(i); + final boolean ${ev.nullTerm} = ${objectTerm} == null; + ${ctx.primitiveType(e.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(e.dataType)}; + if (!${ev.nullTerm}) { + ${ev.primitiveTerm} = (${ctx.boxedType(e.dataType)})${objectTerm}; + } """ } @@ -167,7 +167,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express protected def defineCodeGen( ctx: CodeGenContext, ev: GeneratedExpressionCode, - f: (String, String) => String): String = { + f: (Term, Term) => Code): String = { // TODO: Right now some timestamp tests fail if we enforce this... 
if (left.dataType != right.dataType) { // log.warn(s"${left.dataType} != ${right.dataType}") @@ -214,10 +214,11 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio protected def defineCodeGen( ctx: CodeGenContext, ev: GeneratedExpressionCode, - f: String => String): String = { + f: Term => Code): Code = { val eval = child.gen(ctx) + // reuse the previous nullTerm + ev.nullTerm = eval.nullTerm eval.code + s""" - boolean ${ev.nullTerm} = ${eval.nullTerm}; ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)}; if (!${ev.nullTerm}) { ${ev.primitiveTerm} = ${f(eval.primitiveTerm)}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 06cc6e1024b01..c963971d28cb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -39,12 +39,8 @@ class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long] * to null. * @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not * valid if `nullTerm` is set to `true`. - * @param objectTerm A possibly boxed version of the result of evaluating this expression. */ -case class GeneratedExpressionCode(var code: Code, - nullTerm: Term, - primitiveTerm: Term, - objectTerm: Term) +case class GeneratedExpressionCode(var code: Code, var nullTerm: Term, primitiveTerm: Term) /** * A context for codegen, which is used to bookkeeping the expressions those are not supported @@ -73,40 +69,44 @@ class CodeGenContext { s"$prefix${curId.getAndIncrement}" } + /** + * Return the code to access a column for given DataType + */ def getColumn(dataType: DataType, ordinal: Int): Code = { - dataType match { - case StringType => s"($stringType)i.apply($ordinal)" - case dt: DataType if isNativeType(dt) => s"i.${accessorForType(dt)}($ordinal)" - case _ => s"(${boxedType(dataType)})i.apply($ordinal)" + if (isNativeType(dataType)) { + s"i.${accessorForType(dataType)}($ordinal)" + } else { + s"(${boxedType(dataType)})i.apply($ordinal)" } } - def setColumn(destinationRow: Term, dataType: DataType, ordinal: Int, value: Term): Code = { - dataType match { - case StringType => s"$destinationRow.update($ordinal, $value)" - case dt: DataType if isNativeType(dt) => - s"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" - case _ => s"$destinationRow.update($ordinal, $value)" + /** + * Return the code to update a column in Row for given DataType + */ + def setColumn(dataType: DataType, ordinal: Int, value: Term): Code = { + if (isNativeType(dataType)) { + s"${mutatorForType(dataType)}($ordinal, $value)" + } else { + s"update($ordinal, $value)" } } + /** + * Return the name of accessor in Row for a DataType + */ def accessorForType(dt: DataType): Term = dt match { case IntegerType => "getInt" case other => s"get${boxedType(dt)}" } + /** + * Return the name of mutator in Row for a DataType + */ def mutatorForType(dt: DataType): Term = dt match { case IntegerType => "setInt" case other => s"set${boxedType(dt)}" } - def hashSetForType(dt: DataType): Term = dt match { - case IntegerType => classOf[IntegerHashSet].getName - case LongType => classOf[LongHashSet].getName - case unsupportedType => - sys.error(s"Code generation not support for hashset of type 
$unsupportedType") - } - /** * Return the primitive type for a DataType */ @@ -123,9 +123,26 @@ class CodeGenContext { case StringType => stringType case DateType => "int" case TimestampType => "java.sql.Timestamp" + case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName + case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName case _ => "Object" } + /** + * Return the boxed type in Java + */ + def boxedType(dt: DataType): Term = dt match { + case IntegerType => "Integer" + case LongType => "Long" + case ShortType => "Short" + case ByteType => "Byte" + case DoubleType => "Double" + case FloatType => "Float" + case BooleanType => "Boolean" + case DateType => "Integer" + case _ => primitiveType(dt) + } + /** * Return the representation of default value for given DataType */ @@ -138,30 +155,9 @@ class CodeGenContext { case DoubleType => "-1.0" case IntegerType => "-1" case DateType => "-1" - case dt: DecimalType => "null" - case StringType => "null" case _ => "null" } - /** - * Return the boxed type in Java - */ - def boxedType(dt: DataType): Term = dt match { - case IntegerType => "Integer" - case LongType => "Long" - case ShortType => "Short" - case ByteType => "Byte" - case DoubleType => "Double" - case FloatType => "Float" - case BooleanType => "Boolean" - case dt: DecimalType => decimalType - case BinaryType => "byte[]" - case StringType => stringType - case DateType => "Integer" - case TimestampType => "java.sql.Timestamp" - case _ => "Object" - } - /** * Returns a function to generate equal expression in Java */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 02b7d3fae6767..4b641701008c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -43,7 +43,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu if(${evaluationCode.nullTerm}) mutableRow.setNullAt($i); else - ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitiveTerm)}; + mutableRow.${ctx.setColumn(e.dataType, i, evaluationCode.primitiveTerm)}; """ }.mkString("\n") val code = s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index f1d8313b5f1dc..f7df2340edb6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -67,12 +67,11 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un val eval = child.gen(ctx) eval.code + s""" boolean ${ev.nullTerm} = ${eval.nullTerm}; - org.apache.spark.sql.types.Decimal ${ev.primitiveTerm} = ${ctx.defaultValue(DecimalType())}; + ${ctx.decimalType} ${ev.primitiveTerm} = null; if (!${ev.nullTerm}) { - ${ev.primitiveTerm} = new org.apache.spark.sql.types.Decimal(); - ${ev.primitiveTerm} = - ${ev.primitiveTerm}.setOrNull(${eval.primitiveTerm}, $precision, $scale); + ${ev.primitiveTerm} = (new ${ctx.decimalType}()).setOrNull( + ${eval.primitiveTerm}, $precision, $scale); ${ev.nullTerm} 
= ${ev.primitiveTerm} == null; } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 1899c47613aae..86cdc7a6e914f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -85,34 +85,21 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres if (value == null) { s""" final boolean ${ev.nullTerm} = true; - ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)}; + final ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)}; """ } else { - // TODO(cg): Add support for more data types. dataType match { - case StringType => - val v = value.asInstanceOf[UTF8String] - val arr = s"new byte[]{${v.getBytes.map(_.toString).mkString(", ")}}" - s""" - final boolean ${ev.nullTerm} = false; - ${ctx.stringType} ${ev.primitiveTerm} = new ${ctx.stringType}().set(${arr}); - """ case FloatType => // This must go before NumericType s""" final boolean ${ev.nullTerm} = false; - float ${ev.primitiveTerm} = ${value}f; - """ - case dt: DecimalType => // This must go before NumericType - s""" - final boolean ${ev.nullTerm} = false; - ${ctx.primitiveType(dt)} ${ev.primitiveTerm} = - new ${ctx.primitiveType(dt)}().set($value); + final float ${ev.primitiveTerm} = ${value}f; """ - case dt: NumericType => + case dt: NumericType if !dt.isInstanceOf[DecimalType]=> s""" final boolean ${ev.nullTerm} = false; - ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = $value; + final ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = $value; """ + // eval() version may be faster for non-primitive types case other => super.genCode(ctx, ev) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index a0dae40d964e6..ef1c2bc5836e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -66,8 +66,7 @@ case class NewSet(elementType: DataType) extends LeafExpression { case IntegerType | LongType => s""" boolean ${ev.nullTerm} = false; - ${ctx.hashSetForType(elementType)} ${ev.primitiveTerm} = - new ${ctx.hashSetForType(elementType)}(); + ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = new ${ctx.primitiveType(dataType)}(); """ case _ => super.genCode(ctx, ev) } @@ -110,14 +109,14 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { case IntegerType | LongType => val itemEval = item.gen(ctx) val setEval = set.gen(ctx) - val htype = ctx.hashSetForType(elementType) + val htype = ctx.primitiveType(dataType) itemEval.code + setEval.code + s""" - if (!${itemEval.nullTerm} && !${setEval.nullTerm}) { - (($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm}); - } - boolean ${ev.nullTerm} = false; - ${htype} ${ev.primitiveTerm} = ($htype)${setEval.primitiveTerm}; + if (!${itemEval.nullTerm} && !${setEval.nullTerm}) { + (($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm}); + } + boolean ${ev.nullTerm} = false; + ${htype} ${ev.primitiveTerm} = ($htype)${setEval.primitiveTerm}; """ case _ => super.genCode(ctx, ev) } @@ -163,7 +162,7 @@ case class CombineSets(left: Expression, right: 
Expression) extends BinaryExpres case IntegerType | LongType => val leftEval = left.gen(ctx) val rightEval = right.gen(ctx) - val htype = ctx.hashSetForType(elementType) + val htype = ctx.primitiveType(dataType) leftEval.code + rightEval.code + s""" boolean ${ev.nullTerm} = false; From 86fac2c6dc7add53b694ccaab6cf676d312a1da8 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 5 Jun 2015 13:42:37 -0700 Subject: [PATCH 11/18] fix style --- .../org/apache/spark/sql/catalyst/expressions/literals.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 86cdc7a6e914f..4f00cb6bec586 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -94,7 +94,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres final boolean ${ev.nullTerm} = false; final float ${ev.primitiveTerm} = ${value}f; """ - case dt: NumericType if !dt.isInstanceOf[DecimalType]=> + case dt: NumericType if !dt.isInstanceOf[DecimalType] => s""" final boolean ${ev.nullTerm} = false; final ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = $value; From e03edaaf9c235c2acd86e97e8f6c4efb21487437 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 5 Jun 2015 14:30:41 -0700 Subject: [PATCH 12/18] consts fold --- .../sql/catalyst/expressions/Expression.scala | 2 +- .../sql/catalyst/expressions/arithmetic.scala | 11 +++++--- .../expressions/codegen/CodeGenerator.scala | 2 +- .../sql/catalyst/expressions/literals.scala | 26 ++++++++++--------- .../catalyst/expressions/nullFunctions.scala | 14 +++++----- .../sql/catalyst/expressions/predicates.scala | 2 +- .../spark/sql/catalyst/expressions/sets.scala | 6 ++--- 7 files changed, 33 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 2df6737adb42b..6866b1182e0da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -184,7 +184,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express if (!${ev.nullTerm}) { ${eval2.code} if(!${eval2.nullTerm}) { - ${ev.primitiveTerm} = (${ctx.primitiveType(dataType)})($resultCode); + ${ev.primitiveTerm} = $resultCode; } else { ${ev.nullTerm} = true; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index a049f8878ed32..0923ab6f59564 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -118,12 +118,15 @@ abstract class BinaryArithmetic extends BinaryExpression { } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { - if (left.dataType.isInstanceOf[DecimalType]) { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = dataType match { + case dt: DecimalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") - 
} else { + // byte and short are casted into int when add, minus, times or divide + case ByteType | ShortType => + defineCodeGen(ctx, ev, (eval1, eval2) => + s"(${ctx.primitiveType(dataType)})($eval1 $symbol $eval2)") + case _ => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") - } } protected def evalInternal(evalE1: Any, evalE2: Any): Any = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index c963971d28cb1..f6a2a2be1c89f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -40,7 +40,7 @@ class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long] * @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not * valid if `nullTerm` is set to `true`. */ -case class GeneratedExpressionCode(var code: Code, var nullTerm: Term, primitiveTerm: Term) +case class GeneratedExpressionCode(var code: Code, var nullTerm: Term, var primitiveTerm: Term) /** * A context for codegen, which is used to bookkeeping the expressions those are not supported diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 4f00cb6bec586..e121d39e1d9b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -82,23 +82,25 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres override def eval(input: Row): Any = value override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + // change the nullTerm and primitiveTerm to consts, to inline them if (value == null) { - s""" - final boolean ${ev.nullTerm} = true; - final ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)}; - """ + ev.nullTerm = "true" + ev.primitiveTerm = ctx.defaultValue(dataType) + "" } else { dataType match { + case BooleanType => + ev.nullTerm = "false" + ev.primitiveTerm = value.toString + "" case FloatType => // This must go before NumericType - s""" - final boolean ${ev.nullTerm} = false; - final float ${ev.primitiveTerm} = ${value}f; - """ + ev.nullTerm = "false" + ev.primitiveTerm = s"${value}f" + "" case dt: NumericType if !dt.isInstanceOf[DecimalType] => - s""" - final boolean ${ev.nullTerm} = false; - final ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = $value; - """ + ev.nullTerm = "false" + ev.primitiveTerm = value.toString + "" // eval() version may be faster for non-primitive types case other => super.genCode(ctx, ev) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index e380eafc3fc2a..e3c3489d11aea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -83,10 +83,9 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr override def genCode(ctx: CodeGenContext, ev: 
GeneratedExpressionCode): Code = { val eval = child.gen(ctx) - eval.code + s""" - final boolean ${ev.nullTerm} = false; - final boolean ${ev.primitiveTerm} = ${eval.nullTerm}; - """ + ev.nullTerm = "false" + ev.primitiveTerm = eval.nullTerm + eval.code } override def toString: String = s"IS NULL $child" @@ -103,10 +102,9 @@ case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[E override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val eval = child.gen(ctx) - eval.code + s""" - boolean ${ev.nullTerm} = false; - boolean ${ev.primitiveTerm} = !${eval.nullTerm}; - """ + ev.nullTerm = "false" + ev.primitiveTerm = s"(!(${eval.nullTerm}))" + eval.code } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 67cac26fd0d55..846fc9d90a86b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -304,8 +304,8 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val equalCode = ctx.equalFunc(left.dataType)(eval1.primitiveTerm, eval2.primitiveTerm) + ev.nullTerm = "false" eval1.code + eval2.code + s""" - final boolean ${ev.nullTerm} = false; final boolean ${ev.primitiveTerm} = (${eval1.nullTerm} && ${eval2.nullTerm}) || (!${eval1.nullTerm} && $equalCode); """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index ef1c2bc5836e0..a0c81473ec050 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -64,8 +64,8 @@ case class NewSet(elementType: DataType) extends LeafExpression { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { elementType match { case IntegerType | LongType => + ev.nullTerm = "false" s""" - boolean ${ev.nullTerm} = false; ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = new ${ctx.primitiveType(dataType)}(); """ case _ => super.genCode(ctx, ev) @@ -111,11 +111,11 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { val setEval = set.gen(ctx) val htype = ctx.primitiveType(dataType) + ev.nullTerm = "false" itemEval.code + setEval.code + s""" if (!${itemEval.nullTerm} && !${setEval.nullTerm}) { (($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm}); } - boolean ${ev.nullTerm} = false; ${htype} ${ev.primitiveTerm} = ($htype)${setEval.primitiveTerm}; """ case _ => super.genCode(ctx, ev) @@ -164,8 +164,8 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres val rightEval = right.gen(ctx) val htype = ctx.primitiveType(dataType) + ev.nullTerm = "false" leftEval.code + rightEval.code + s""" - boolean ${ev.nullTerm} = false; ${htype} ${ev.primitiveTerm} = (${htype})${leftEval.primitiveTerm}; ${ev.primitiveTerm}.union((${htype})${rightEval.primitiveTerm}); """ From bad682819a2013eff8c9e3c9a8c106dc6f8b924f Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 5 Jun 2015 15:01:26 -0700 Subject: [PATCH 13/18] address comments --- .../expressions/decimalFunctions.scala | 6 +-- .../sql/catalyst/expressions/predicates.scala | 39 
+++++++++++-------- .../spark/sql/catalyst/expressions/sets.scala | 14 +++---- 3 files changed, 30 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index f7df2340edb6c..21f8c812c9ce5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -38,11 +38,7 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { - val eval = child.gen(ctx) - eval.code + s""" - boolean ${ev.nullTerm} = ${eval.nullTerm}; - long ${ev.primitiveTerm} = ${ev.nullTerm} ? -1 : ${eval.primitiveTerm}.toUnscaledLong(); - """ + defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 846fc9d90a86b..d69324acf0e5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -146,9 +146,12 @@ case class And(left: Expression, right: Expression) } } } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) + + // The result should be `false`, if any of them is `false` whenever the other is null or not. s""" ${eval1.code} boolean ${ev.nullTerm} = false; @@ -192,20 +195,21 @@ case class Or(left: Expression, right: Expression) } } } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) + + // The result should be `true`, if any of them is `true` whenever the other is null or not. 
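The comments added here spell out SQL's three-valued logic. As a standalone reference for the semantics the generated branches implement (plain Scala with null modelled as None; not Spark code):

    // SQL three-valued AND/OR with null modelled as None.
    def sqlAnd(a: Option[Boolean], b: Option[Boolean]): Option[Boolean] = (a, b) match {
      case (Some(false), _) | (_, Some(false)) => Some(false) // false wins even against null
      case (Some(true), Some(true))            => Some(true)
      case _                                   => None        // otherwise unknown
    }

    def sqlOr(a: Option[Boolean], b: Option[Boolean]): Option[Boolean] = (a, b) match {
      case (Some(true), _) | (_, Some(true)) => Some(true)    // true wins even against null
      case (Some(false), Some(false))        => Some(false)
      case _                                 => None
    }

Note that the generated Java goes one step further: when the left operand already decides the result (false for AND, true for OR), the right child's code is never executed.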
s""" ${eval1.code} boolean ${ev.nullTerm} = false; - boolean ${ev.primitiveTerm} = false; + boolean ${ev.primitiveTerm} = true; if (!${eval1.nullTerm} && ${eval1.primitiveTerm}) { - ${ev.primitiveTerm} = true; } else { ${eval2.code} if (!${eval2.nullTerm} && ${eval2.primitiveTerm}) { - ${ev.primitiveTerm} = true; } else if (!${eval1.nullTerm} && !${eval2.nullTerm}) { ${ev.primitiveTerm} = false; } else { @@ -218,19 +222,6 @@ case class Or(left: Expression, right: Expression) abstract class BinaryComparison extends BinaryExpression with Predicate { self: Product => - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { - left.dataType match { - case dt: NumericType if ctx.isNativeType(dt) => defineCodeGen (ctx, ev, { - (c1, c3) => s"$c1 $symbol $c3" - }) - case TimestampType => - // java.sql.Timestamp does not have compare() - super.genCode(ctx, ev) - case other => defineCodeGen (ctx, ev, { - (c1, c2) => s"$c1.compare($c2) $symbol 0" - }) - } - } override def checkInputDataTypes(): TypeCheckResult = { if (left.dataType != right.dataType) { @@ -258,6 +249,20 @@ abstract class BinaryComparison extends BinaryExpression with Predicate { } } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + left.dataType match { + case dt: NumericType if ctx.isNativeType(dt) => defineCodeGen (ctx, ev, { + (c1, c3) => s"$c1 $symbol $c3" + }) + case TimestampType => + // java.sql.Timestamp does not have compare() + super.genCode(ctx, ev) + case other => defineCodeGen (ctx, ev, { + (c1, c2) => s"$c1.compare($c2) $symbol 0" + }) + } + } + protected def evalInternal(evalE1: Any, evalE2: Any): Any = sys.error(s"BinaryComparisons must override either eval or evalInternal") } @@ -389,9 +394,9 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi val falseEval = falseValue.gen(ctx) s""" + ${condEval.code} boolean ${ev.nullTerm} = false; ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)}; - ${condEval.code} if (!${condEval.nullTerm} && ${condEval.primitiveTerm}) { ${trueEval.code} ${ev.nullTerm} = ${trueEval.nullTerm}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index a0c81473ec050..40107c5985481 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -112,11 +112,11 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { val htype = ctx.primitiveType(dataType) ev.nullTerm = "false" + ev.primitiveTerm = setEval.primitiveTerm itemEval.code + setEval.code + s""" if (!${itemEval.nullTerm} && !${setEval.nullTerm}) { (($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm}); } - ${htype} ${ev.primitiveTerm} = ($htype)${setEval.primitiveTerm}; """ case _ => super.genCode(ctx, ev) } @@ -147,10 +147,8 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres val rightValue = iterator.next() leftEval.add(rightValue) } - leftEval - } else { - null } + leftEval } else { null } @@ -164,10 +162,12 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres val rightEval = right.gen(ctx) val htype = ctx.primitiveType(dataType) - ev.nullTerm = "false" + ev.nullTerm = leftEval.nullTerm + ev.primitiveTerm = leftEval.primitiveTerm leftEval.code + rightEval.code + s""" - 
${htype} ${ev.primitiveTerm} = (${htype})${leftEval.primitiveTerm}; - ${ev.primitiveTerm}.union((${htype})${rightEval.primitiveTerm}); + if (!${leftEval.nullTerm} && !${rightEval.nullTerm}) { + ${leftEval.primitiveTerm}.union((${htype})${rightEval.primitiveTerm}); + } """ case _ => super.genCode(ctx, ev) } From f42c732febd5fa720ab72dfa0fc442c847012b8c Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 6 Jun 2015 00:23:31 -0700 Subject: [PATCH 14/18] improve coverage and tests --- .../catalyst/expressions/BoundAttribute.scala | 4 +- .../sql/catalyst/expressions/Expression.scala | 40 ++--- .../sql/catalyst/expressions/arithmetic.scala | 96 +++++++----- .../expressions/codegen/CodeGenerator.scala | 14 +- .../codegen/GenerateMutableProjection.scala | 4 +- .../codegen/GenerateOrdering.scala | 16 +- .../codegen/GeneratePredicate.scala | 2 +- .../codegen/GenerateProjection.scala | 8 +- .../expressions/decimalFunctions.scala | 12 +- .../sql/catalyst/expressions/literals.scala | 44 ++++-- .../expressions/mathfuncs/binary.scala | 24 ++- .../expressions/mathfuncs/unary.scala | 30 +++- .../expressions/namedExpressions.scala | 6 +- .../catalyst/expressions/nullFunctions.scala | 26 ++-- .../sql/catalyst/expressions/predicates.scala | 139 ++++++++++++++---- .../spark/sql/catalyst/expressions/sets.scala | 20 +-- .../expressions/stringOperations.scala | 18 +++ .../ExpressionEvaluationSuite.scala | 86 ++++++++++- .../GeneratedEvaluationSuite.scala | 27 +--- .../GeneratedMutableEvaluationSuite.scala | 61 -------- 20 files changed, 440 insertions(+), 237 deletions(-) delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 478ee997a96a2..00fd7294a8966 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -45,8 +45,8 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { s""" - final boolean ${ev.nullTerm} = i.isNullAt($ordinal); - final ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ev.nullTerm} ? + boolean ${ev.isNull} = i.isNullAt($ordinal); + ${ctx.primitiveType(dataType)} ${ev.primitive} = ${ev.isNull} ? 
${ctx.defaultValue(dataType)} : (${ctx.getColumn(dataType, ordinal)}); """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 6866b1182e0da..87f864a7c0d9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -60,9 +60,9 @@ abstract class Expression extends TreeNode[Expression] { * @return [[GeneratedExpressionCode]] */ def gen(ctx: CodeGenContext): GeneratedExpressionCode = { - val nullTerm = ctx.freshName("nullTerm") - val primitiveTerm = ctx.freshName("primitiveTerm") - val ve = GeneratedExpressionCode("", nullTerm, primitiveTerm) + val isNull = ctx.freshName("isNull") + val primitive = ctx.freshName("primitive") + val ve = GeneratedExpressionCode("", isNull, primitive) ve.code = genCode(ctx, ve) ve } @@ -82,11 +82,11 @@ abstract class Expression extends TreeNode[Expression] { val objectTerm = ctx.freshName("obj") s""" /* expression: ${this} */ - final Object ${objectTerm} = expressions[${ctx.references.size - 1}].eval(i); - final boolean ${ev.nullTerm} = ${objectTerm} == null; - ${ctx.primitiveType(e.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(e.dataType)}; - if (!${ev.nullTerm}) { - ${ev.primitiveTerm} = (${ctx.boxedType(e.dataType)})${objectTerm}; + Object ${objectTerm} = expressions[${ctx.references.size - 1}].eval(i); + boolean ${ev.isNull} = ${objectTerm} == null; + ${ctx.primitiveType(e.dataType)} ${ev.primitive} = ${ctx.defaultValue(e.dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = (${ctx.boxedType(e.dataType)})${objectTerm}; } """ } @@ -175,18 +175,18 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) - val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm) + val resultCode = f(eval1.primitive, eval2.primitive) s""" ${eval1.code} - boolean ${ev.nullTerm} = ${eval1.nullTerm}; - ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)}; - if (!${ev.nullTerm}) { + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { ${eval2.code} - if(!${eval2.nullTerm}) { - ${ev.primitiveTerm} = $resultCode; + if(!${eval2.isNull}) { + ${ev.primitive} = $resultCode; } else { - ${ev.nullTerm} = true; + ${ev.isNull} = true; } } """ @@ -216,12 +216,12 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio ev: GeneratedExpressionCode, f: Term => Code): Code = { val eval = child.gen(ctx) - // reuse the previous nullTerm - ev.nullTerm = eval.nullTerm + // reuse the previous isNull + ev.isNull = eval.isNull eval.code + s""" - ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)}; - if (!${ev.nullTerm}) { - ${ev.primitiveTerm} = ${f(eval.primitiveTerm)}; + ${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = ${f(eval.primitive)}; } """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 0923ab6f59564..c161a514fcd4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -50,6 +50,11 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic { private lazy val numeric = TypeUtils.getNumeric(dataType) + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = dataType match { + case dt: DecimalType => defineCodeGen(ctx, ev, c => s"c.unary_$$minus()") + case dt: NumericType => defineCodeGen(ctx, ev, c => s"-($c)") + } + protected override def evalInternal(evalE: Any) = numeric.negate(evalE) } @@ -68,6 +73,21 @@ case class Sqrt(child: Expression) extends UnaryArithmetic { if (value < 0) null else math.sqrt(value) } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val eval = child.gen(ctx) + eval.code + s""" + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + if (${eval.primitive} < 0.0) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = java.lang.Math.sqrt(${eval.primitive}); + } + } + """ + } } /** @@ -216,9 +236,9 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val test = if (left.dataType.isInstanceOf[DecimalType]) { - s"${eval2.primitiveTerm}.isZero()" + s"${eval2.primitive}.isZero()" } else { - s"${eval2.primitiveTerm} == 0" + s"${eval2.primitive} == 0" } val method = if (left.dataType.isInstanceOf[DecimalType]) { s".$decimalMethod" @@ -227,12 +247,12 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } eval1.code + eval2.code + s""" - boolean ${ev.nullTerm} = false; - ${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(left.dataType)}; - if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) { - ${ev.nullTerm} = true; + boolean ${ev.isNull} = false; + ${ctx.primitiveType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)}; + if (${eval1.isNull} || ${eval2.isNull} || $test) { + ${ev.isNull} = true; } else { - ${ev.primitiveTerm} = ${eval1.primitiveTerm}$method(${eval2.primitiveTerm}); + ${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive}); } """ } @@ -276,9 +296,9 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val test = if (left.dataType.isInstanceOf[DecimalType]) { - s"${eval2.primitiveTerm}.isZero()" + s"${eval2.primitive}.isZero()" } else { - s"${eval2.primitiveTerm} == 0" + s"${eval2.primitive} == 0" } val method = if (left.dataType.isInstanceOf[DecimalType]) { s".$decimalMethod" @@ -287,12 +307,12 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } eval1.code + eval2.code + s""" - boolean ${ev.nullTerm} = false; - ${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(left.dataType)}; - if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) { - ${ev.nullTerm} = true; + boolean ${ev.isNull} = false; + ${ctx.primitiveType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)}; + if (${eval1.isNull} || ${eval2.isNull} || $test) { + ${ev.isNull} = true; } else { - ${ev.primitiveTerm} = ${eval1.primitiveTerm}$method(${eval2.primitiveTerm}); + ${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive}); } """ } @@ -387,6 +407,10 @@ case class BitwiseNot(child: Expression) extends UnaryArithmetic { ((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any] } + override def 
genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dataType)})~($c)") + } + protected override def evalInternal(evalE: Any) = not(evalE) } @@ -419,21 +443,21 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) eval1.code + eval2.code + s""" - boolean ${ev.nullTerm} = false; - ${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = + boolean ${ev.isNull} = false; + ${ctx.primitiveType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)}; - if (${eval1.nullTerm}) { - ${ev.nullTerm} = ${eval2.nullTerm}; - ${ev.primitiveTerm} = ${eval2.primitiveTerm}; - } else if (${eval2.nullTerm}) { - ${ev.nullTerm} = ${eval1.nullTerm}; - ${ev.primitiveTerm} = ${eval1.primitiveTerm}; + if (${eval1.isNull}) { + ${ev.isNull} = ${eval2.isNull}; + ${ev.primitive} = ${eval2.primitive}; + } else if (${eval2.isNull}) { + ${ev.isNull} = ${eval1.isNull}; + ${ev.primitive} = ${eval1.primitive}; } else { - if (${eval1.primitiveTerm} > ${eval2.primitiveTerm}) { - ${ev.primitiveTerm} = ${eval1.primitiveTerm}; + if (${eval1.primitive} > ${eval2.primitive}) { + ${ev.primitive} = ${eval1.primitive}; } else { - ${ev.primitiveTerm} = ${eval2.primitiveTerm}; + ${ev.primitive} = ${eval2.primitive}; } } """ @@ -475,21 +499,21 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { val eval2 = right.gen(ctx) eval1.code + eval2.code + s""" - boolean ${ev.nullTerm} = false; - ${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = + boolean ${ev.isNull} = false; + ${ctx.primitiveType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)}; - if (${eval1.nullTerm}) { - ${ev.nullTerm} = ${eval2.nullTerm}; - ${ev.primitiveTerm} = ${eval2.primitiveTerm}; - } else if (${eval2.nullTerm}) { - ${ev.nullTerm} = ${eval1.nullTerm}; - ${ev.primitiveTerm} = ${eval1.primitiveTerm}; + if (${eval1.isNull}) { + ${ev.isNull} = ${eval2.isNull}; + ${ev.primitive} = ${eval2.primitive}; + } else if (${eval2.isNull}) { + ${ev.isNull} = ${eval1.isNull}; + ${ev.primitive} = ${eval1.primitive}; } else { - if (${eval1.primitiveTerm} < ${eval2.primitiveTerm}) { - ${ev.primitiveTerm} = ${eval1.primitiveTerm}; + if (${eval1.primitive} < ${eval2.primitive}) { + ${ev.primitive} = ${eval1.primitive}; } else { - ${ev.primitiveTerm} = ${eval2.primitiveTerm}; + ${ev.primitive} = ${eval2.primitive}; } } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index f6a2a2be1c89f..94b1b4808d759 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -35,12 +35,12 @@ class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long] * Java source for evaluating an [[Expression]] given a [[Row]] of input. * * @param code The sequence of statements required to evaluate the expression. - * @param nullTerm A term that holds a boolean value representing whether the expression evaluated + * @param isNull A term that holds a boolean value representing whether the expression evaluated * to null. - * @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not - * valid if `nullTerm` is set to `true`. 
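The MaxOf and MinOf branches generated above pick the non-null side when exactly one input is null, and compare primitives otherwise. The same semantics in plain Scala, as a standalone illustration with null modelled as None (not Spark code):

    // Mirrors the generated MaxOf branches: a null input yields the other
    // operand together with its null flag; otherwise take the larger value.
    def maxOf[A](a: Option[A], b: Option[A])(implicit ord: Ordering[A]): Option[A] =
      (a, b) match {
        case (None, _)          => b
        case (_, None)          => a
        case (Some(x), Some(y)) => Some(ord.max(x, y))
      }

    // MinOf is identical with ord.min in the last case.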
+ * @param primitive A term for a possible primitive value of the result of the evaluation. Not + * valid if `isNull` is set to `true`. */ -case class GeneratedExpressionCode(var code: Code, var nullTerm: Term, var primitiveTerm: Term) +case class GeneratedExpressionCode(var code: Code, var isNull: Term, var primitive: Term) /** * A context for codegen, which is used to bookkeeping the expressions those are not supported @@ -149,9 +149,9 @@ class CodeGenContext { def defaultValue(dt: DataType): Term = dt match { case BooleanType => "false" case FloatType => "-1.0f" - case ShortType => "-1" - case LongType => "-1" - case ByteType => "-1" + case ShortType => "(short)-1" + case LongType => "-1L" + case ByteType => "(byte)-1" case DoubleType => "-1.0" case IntegerType => "-1" case DateType => "-1" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 4b641701008c3..e5ee2accd8a84 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -40,10 +40,10 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu val evaluationCode = e.gen(ctx) evaluationCode.code + s""" - if(${evaluationCode.nullTerm}) + if(${evaluationCode.isNull}) mutableRow.setNullAt($i); else - mutableRow.${ctx.setColumn(e.dataType, i, evaluationCode.primitiveTerm)}; + mutableRow.${ctx.setColumn(e.dataType, i, evaluationCode.primitive)}; """ }.mkString("\n") val code = s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index d3c219fddc53c..36e155d164a40 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -59,8 +59,8 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit case BinaryType => s""" { - byte[] x = ${if (asc) evalA.primitiveTerm else evalB.primitiveTerm}; - byte[] y = ${if (!asc) evalB.primitiveTerm else evalA.primitiveTerm}; + byte[] x = ${if (asc) evalA.primitive else evalB.primitive}; + byte[] y = ${if (!asc) evalB.primitive else evalA.primitive}; int j = 0; while (j < x.length && j < y.length) { if (x[j] != y[j]) return x[j] - y[j]; @@ -73,8 +73,8 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit }""" case _: NumericType => s""" - if (${evalA.primitiveTerm} != ${evalB.primitiveTerm}) { - if (${evalA.primitiveTerm} > ${evalB.primitiveTerm}) { + if (${evalA.primitive} != ${evalB.primitive}) { + if (${evalA.primitive} > ${evalB.primitive}) { return ${if (asc) "1" else "-1"}; } else { return ${if (asc) "-1" else "1"}; @@ -82,7 +82,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit }""" case _ => s""" - int comp = ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm}); + int comp = ${evalA.primitive}.compare(${evalB.primitive}); if (comp != 0) { return ${if (asc) "comp" else "-comp"}; }""" @@ -93,11 +93,11 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit 
${evalA.code} i = $b; ${evalB.code} - if (${evalA.nullTerm} && ${evalB.nullTerm}) { + if (${evalA.isNull} && ${evalB.isNull}) { // Nothing - } else if (${evalA.nullTerm}) { + } else if (${evalA.isNull}) { return ${if (order.direction == Ascending) "-1" else "1"}; - } else if (${evalB.nullTerm}) { + } else if (${evalB.isNull}) { return ${if (order.direction == Ascending) "1" else "-1"}; } else { $compare diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index dd4474de05df9..4a547b5ce9543 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -55,7 +55,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] { @Override public boolean eval(Row i) { ${eval.code} - return !${eval.nullTerm} && ${eval.primitiveTerm}; + return !${eval.isNull} && ${eval.primitive}; } }""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 00c856dc02ba1..f621c894833c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -55,9 +55,9 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { { // column$i ${eval.code} - nullBits[$i] = ${eval.nullTerm}; - if(!${eval.nullTerm}) { - c$i = ${eval.primitiveTerm}; + nullBits[$i] = ${eval.isNull}; + if (!${eval.isNull}) { + c$i = ${eval.primitive}; } } """ @@ -122,7 +122,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { case LongType => s"$col ^ ($col >>> 32)" case FloatType => s"Float.floatToIntBits($col)" case DoubleType => - s"Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32)" + s"(int)(Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32))" case _ => s"$col.hashCode()" } s"isNullAt($i) ? 
0 : ($nonNull)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index 21f8c812c9ce5..ddfadf314f838 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -62,13 +62,13 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val eval = child.gen(ctx) eval.code + s""" - boolean ${ev.nullTerm} = ${eval.nullTerm}; - ${ctx.decimalType} ${ev.primitiveTerm} = null; + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.decimalType} ${ev.primitive} = null; - if (!${ev.nullTerm}) { - ${ev.primitiveTerm} = (new ${ctx.decimalType}()).setOrNull( - ${eval.primitiveTerm}, $precision, $scale); - ${ev.nullTerm} = ${ev.primitiveTerm} == null; + if (!${ev.isNull}) { + ${ev.primitive} = (new ${ctx.decimalType}()).setOrNull( + ${eval.primitive}, $precision, $scale); + ${ev.isNull} = ${ev.primitive} == null; } """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index e121d39e1d9b4..bce96bd3c1309 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -79,27 +79,53 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres override def toString: String = if (value != null) value.toString else "null" + override def equals(other: Any): Boolean = other match { + case o: Literal => + dataType.equals(o.dataType) && + (value == null && null == o.value || value != null && value.equals(o.value)) + case _ => false + } + override def eval(input: Row): Any = value override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { - // change the nullTerm and primitiveTerm to consts, to inline them + // change the isNull and primitive to consts, to inline them if (value == null) { - ev.nullTerm = "true" - ev.primitiveTerm = ctx.defaultValue(dataType) + ev.isNull = "true" + ev.primitive = ctx.defaultValue(dataType) "" } else { dataType match { case BooleanType => - ev.nullTerm = "false" - ev.primitiveTerm = value.toString + ev.isNull = "false" + ev.primitive = value.toString "" case FloatType => // This must go before NumericType - ev.nullTerm = "false" - ev.primitiveTerm = s"${value}f" + val v = value.asInstanceOf[Float] + if (v.isNaN || v.isInfinite) { + super.genCode(ctx, ev) + } else { + ev.isNull = "false" + ev.primitive = s"${value}f" + "" + } + case DoubleType => // This must go before NumericType + val v = value.asInstanceOf[Double] + if (v.isNaN || v.isInfinite) { + super.genCode(ctx, ev) + } else { + ev.isNull = "false" + ev.primitive = s"${value}" + "" + } + + case ByteType | ShortType => // This must go before NumericType + ev.isNull = "false" + ev.primitive = s"(${ctx.primitiveType(dataType)})$value" "" case dt: NumericType if !dt.isInstanceOf[DecimalType] => - ev.nullTerm = "false" - ev.primitiveTerm = value.toString + ev.isNull = "false" + ev.primitive = value.toString "" // eval() version may be faster for non-primitive types case other => diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala index db853a2b97fad..88211acd7713c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.mathfuncs +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, BinaryExpression, Expression, Row} import org.apache.spark.sql.types._ @@ -49,6 +50,10 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) } } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.${name.toLowerCase}($c1, $c2)") + } } case class Atan2(left: Expression, right: Expression) @@ -70,9 +75,26 @@ case class Atan2(left: Expression, right: Expression) } } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s""" + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + """ + } } case class Hypot(left: Expression, right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") -case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") +case class Pow(left: Expression, right: Expression) + extends BinaryMathExpression(math.pow, "POWER") { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s""" + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + """ + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala index 41b422346a02d..ad49c376e981e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.mathfuncs +import org.apache.spark.sql.catalyst.expressions.codegen.{Code, CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, Row, UnaryExpression} import org.apache.spark.sql.types._ @@ -44,6 +45,23 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) if (result.isNaN) null else result } } + + // name of function in java.lang.Math + def funcName: String = name.toLowerCase + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val eval = child.gen(ctx) + eval.code + s""" + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = java.lang.Math.${funcName}(${eval.primitive}); + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + } + """ + } } case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS") @@ -72,7 +90,9 @@ case class Log10(child: Expression) extends UnaryMathExpression(math.log10, "LOG case class 
Log1p(child: Expression) extends UnaryMathExpression(math.log1p, "LOG1P") -case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") +case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") { + override def funcName: String = "rint" +} case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM") @@ -84,6 +104,10 @@ case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN") case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH") -case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegrees, "DEGREES") +case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegrees, "DEGREES") { + override def funcName: String = "toDegrees" +} -case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadians, "RADIANS") +case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadians, "RADIANS") { + override def funcName: String = "toRadians" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 00565ec651a59..2e4b9ba678433 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.trees.LeafNode +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.types._ object NamedExpression { @@ -116,6 +116,8 @@ case class Alias(child: Expression, name: String)( override def eval(input: Row): Any = child.eval(input) + override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) + override def dataType: DataType = child.dataType override def nullable: Boolean = child.nullable override def metadata: Metadata = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index e3c3489d11aea..ea216b1d0d9f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -55,17 +55,17 @@ case class Coalesce(children: Seq[Expression]) extends Expression { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { s""" - boolean ${ev.nullTerm} = true; - ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)}; + boolean ${ev.isNull} = true; + ${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; """ + children.map { e => val eval = e.gen(ctx) s""" - if (${ev.nullTerm}) { + if (${ev.isNull}) { ${eval.code} - if (!${eval.nullTerm}) { - ${ev.nullTerm} = false; - ${ev.primitiveTerm} = ${eval.primitiveTerm}; + if (!${eval.isNull}) { + ${ev.isNull} = false; + ${ev.primitive} = ${eval.primitive}; } } """ @@ -83,8 +83,8 @@ case class IsNull(child: Expression) extends Predicate with 
trees.UnaryNode[Expr override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val eval = child.gen(ctx) - ev.nullTerm = "false" - ev.primitiveTerm = eval.nullTerm + ev.isNull = "false" + ev.primitive = eval.isNull eval.code } @@ -102,8 +102,8 @@ case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[E override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val eval = child.gen(ctx) - ev.nullTerm = "false" - ev.primitiveTerm = s"(!(${eval.nullTerm}))" + ev.isNull = "false" + ev.primitive = s"(!(${eval.isNull}))" eval.code } } @@ -137,7 +137,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate s""" if ($nonnull < $n) { ${eval.code} - if (!${eval.nullTerm}) { + if (!${eval.isNull}) { $nonnull += 1; } } @@ -146,8 +146,8 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate s""" int $nonnull = 0; $code - boolean ${ev.nullTerm} = false; - boolean ${ev.primitiveTerm} = $nonnull >= $n; + boolean ${ev.isNull} = false; + boolean ${ev.primitive} = $nonnull >= $n; """ } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index d69324acf0e5a..75af8d71dbd31 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -154,17 +154,17 @@ case class And(left: Expression, right: Expression) // The result should be `false`, if any of them is `false` whenever the other is null or not. s""" ${eval1.code} - boolean ${ev.nullTerm} = false; - boolean ${ev.primitiveTerm} = false; + boolean ${ev.isNull} = false; + boolean ${ev.primitive} = false; - if (!${eval1.nullTerm} && !${eval1.primitiveTerm}) { + if (!${eval1.isNull} && !${eval1.primitive}) { } else { ${eval2.code} - if (!${eval2.nullTerm} && !${eval2.primitiveTerm}) { - } else if (!${eval1.nullTerm} && !${eval2.nullTerm}) { - ${ev.primitiveTerm} = true; + if (!${eval2.isNull} && !${eval2.primitive}) { + } else if (!${eval1.isNull} && !${eval2.isNull}) { + ${ev.primitive} = true; } else { - ${ev.nullTerm} = true; + ${ev.isNull} = true; } } """ @@ -203,17 +203,17 @@ case class Or(left: Expression, right: Expression) // The result should be `true`, if any of them is `true` whenever the other is null or not. 
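The null-handling expressions reworked in this patch reduce to small Option computations; a standalone Scala reference for what the generated code computes (illustrative only, null modelled as None):

    def isNull(v: Option[Any]): Boolean    = v.isEmpty
    def isNotNull(v: Option[Any]): Boolean = v.isDefined

    // First non-null child, or None if every child is null. The generated code
    // additionally skips evaluating later children once a value has been found.
    def coalesce[A](children: Seq[Option[A]]): Option[A] =
      children.collectFirst { case Some(v) => v }

    // True when at least n children evaluate to non-null values.
    def atLeastNNonNulls(n: Int, children: Seq[Option[Any]]): Boolean =
      children.count(_.isDefined) >= n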
s""" ${eval1.code} - boolean ${ev.nullTerm} = false; - boolean ${ev.primitiveTerm} = true; + boolean ${ev.isNull} = false; + boolean ${ev.primitive} = true; - if (!${eval1.nullTerm} && ${eval1.primitiveTerm}) { + if (!${eval1.isNull} && ${eval1.primitive}) { } else { ${eval2.code} - if (!${eval2.nullTerm} && ${eval2.primitiveTerm}) { - } else if (!${eval1.nullTerm} && !${eval2.nullTerm}) { - ${ev.primitiveTerm} = false; + if (!${eval2.isNull} && ${eval2.primitive}) { + } else if (!${eval1.isNull} && !${eval2.isNull}) { + ${ev.primitive} = false; } else { - ${ev.nullTerm} = true; + ${ev.isNull} = true; } } """ @@ -308,11 +308,11 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) - val equalCode = ctx.equalFunc(left.dataType)(eval1.primitiveTerm, eval2.primitiveTerm) - ev.nullTerm = "false" + val equalCode = ctx.equalFunc(left.dataType)(eval1.primitive, eval2.primitive) + ev.isNull = "false" eval1.code + eval2.code + s""" - final boolean ${ev.primitiveTerm} = (${eval1.nullTerm} && ${eval2.nullTerm}) || - (!${eval1.nullTerm} && $equalCode); + boolean ${ev.primitive} = (${eval1.isNull} && ${eval2.isNull}) || + (!${eval1.isNull} && $equalCode); """ } } @@ -388,6 +388,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi falseValue.eval(input) } } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val condEval = predicate.gen(ctx) val trueEval = trueValue.gen(ctx) @@ -395,16 +396,16 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi s""" ${condEval.code} - boolean ${ev.nullTerm} = false; - ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)}; - if (!${condEval.nullTerm} && ${condEval.primitiveTerm}) { + boolean ${ev.isNull} = false; + ${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${condEval.isNull} && ${condEval.primitive}) { ${trueEval.code} - ${ev.nullTerm} = ${trueEval.nullTerm}; - ${ev.primitiveTerm} = ${trueEval.primitiveTerm}; + ${ev.isNull} = ${trueEval.isNull}; + ${ev.primitive} = ${trueEval.primitive}; } else { ${falseEval.code} - ${ev.nullTerm} = ${falseEval.nullTerm}; - ${ev.primitiveTerm} = ${falseEval.primitiveTerm}; + ${ev.isNull} = ${falseEval.isNull}; + ${ev.primitive} = ${falseEval.primitive}; } """ } @@ -493,6 +494,48 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { return res } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val len = branchesArr.length + val got = ctx.freshName("got") + + val cases = (0 until len/2).map { i => + val cond = branchesArr(i * 2).gen(ctx) + val res = branchesArr(i * 2 + 1).gen(ctx) + s""" + if (!$got) { + ${cond.code} + if (!${cond.isNull} && ${cond.primitive}) { + $got = true; + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.primitive} = ${res.primitive}; + } + } + """ + }.mkString("\n") + + val other = if (len % 2 == 1) { + val res = branchesArr(len - 1).gen(ctx) + s""" + if (!$got) { + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.primitive} = ${res.primitive}; + } + """ + } else { + "" + } + + s""" + boolean $got = false; + boolean ${ev.isNull} = true; + ${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + $cases + $other + """ + } + override def toString: String = { "CASE" + branches.sliding(2, 2).map { case 
     "CASE" + branches.sliding(2, 2).map { case Seq(cond, value) => s" WHEN $cond THEN $value"
@@ -544,6 +587,52 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
     return res
   }
 
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+    val keyEval = key.gen(ctx)
+    val len = branchesArr.length
+    val got = ctx.freshName("got")
+
+    val cases = (0 until len/2).map { i =>
+      val cond = branchesArr(i * 2).gen(ctx)
+      val res = branchesArr(i * 2 + 1).gen(ctx)
+      s"""
+        if (!$got) {
+          ${cond.code}
+          if (${keyEval.isNull} && ${cond.isNull} ||
+            !${keyEval.isNull} && !${cond.isNull}
+             && ${ctx.equalFunc(key.dataType)(keyEval.primitive, cond.primitive)}) {
+            $got = true;
+            ${res.code}
+            ${ev.isNull} = ${res.isNull};
+            ${ev.primitive} = ${res.primitive};
+          }
+        }
+      """
+    }.mkString("\n")
+
+    val other = if (len % 2 == 1) {
+      val res = branchesArr(len - 1).gen(ctx)
+      s"""
+        if (!$got) {
+          ${res.code}
+          ${ev.isNull} = ${res.isNull};
+          ${ev.primitive} = ${res.primitive};
+        }
+      """
+    } else {
+      ""
+    }
+
+    s"""
+      boolean $got = false;
+      boolean ${ev.isNull} = true;
+      ${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+      ${keyEval.code}
+      $cases
+      $other
+    """
+  }
+
   private def equalNullSafe(l: Any, r: Any) = {
     if (l == null && r == null) {
       true
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
index 40107c5985481..1038e7a653358 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
@@ -64,9 +64,9 @@ case class NewSet(elementType: DataType) extends LeafExpression {
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
     elementType match {
       case IntegerType | LongType =>
-        ev.nullTerm = "false"
+        ev.isNull = "false"
         s"""
-          ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = new ${ctx.primitiveType(dataType)}();
+          ${ctx.primitiveType(dataType)} ${ev.primitive} = new ${ctx.primitiveType(dataType)}();
         """
       case _ => super.genCode(ctx, ev)
     }
@@ -111,11 +111,11 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
         val setEval = set.gen(ctx)
         val htype = ctx.primitiveType(dataType)
-        ev.nullTerm = "false"
-        ev.primitiveTerm = setEval.primitiveTerm
+        ev.isNull = "false"
+        ev.primitive = setEval.primitive
         itemEval.code + setEval.code + s"""
-          if (!${itemEval.nullTerm} && !${setEval.nullTerm}) {
-            (($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm});
+          if (!${itemEval.isNull} && !${setEval.isNull}) {
+            (($htype)${setEval.primitive}).add(${itemEval.primitive});
          }
        """
       case _ => super.genCode(ctx, ev)
@@ -162,11 +162,11 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres
         val rightEval = right.gen(ctx)
         val htype = ctx.primitiveType(dataType)
-        ev.nullTerm = leftEval.nullTerm
-        ev.primitiveTerm = leftEval.primitiveTerm
+        ev.isNull = leftEval.isNull
+        ev.primitive = leftEval.primitive
         leftEval.code + rightEval.code + s"""
-          if (!${leftEval.nullTerm} && !${rightEval.nullTerm}) {
-            ${leftEval.primitiveTerm}.union((${htype})${rightEval.primitiveTerm});
+          if (!${leftEval.isNull} && !${rightEval.isNull}) {
+            ${leftEval.primitive}.union((${htype})${rightEval.primitive});
          }
        """
       case _ => super.genCode(ctx, ev)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index c4ef9c30907f1..78adb509b470b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
 import java.util.regex.Pattern
 
 import org.apache.spark.sql.catalyst.analysis.UnresolvedException
+import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.types._
 
 trait StringRegexExpression extends ExpectsInputTypes {
@@ -137,6 +138,10 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE
   override def convert(v: UTF8String): UTF8String = v.toUpperCase()
 
   override def toString: String = s"Upper($child)"
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+    defineCodeGen(ctx, ev, c => s"($c).toUpperCase()")
+  }
 }
 
 /**
@@ -147,6 +152,10 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE
   override def convert(v: UTF8String): UTF8String = v.toLowerCase()
 
   override def toString: String = s"Lower($child)"
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+    defineCodeGen(ctx, ev, c => s"($c).toLowerCase()")
+  }
 }
 
 /** A base trait for functions that compare two strings, returning a boolean. */
@@ -181,6 +190,9 @@ trait StringComparison extends ExpectsInputTypes {
 case class Contains(left: Expression, right: Expression)
     extends BinaryExpression with Predicate with StringComparison {
   override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r)
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+    defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)")
+  }
 }
 
 /**
@@ -189,6 +201,9 @@ case class StartsWith(left: Expression, right: Expression)
     extends BinaryExpression with Predicate with StringComparison {
   override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r)
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+    defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)")
+  }
 }
 
 /**
@@ -197,6 +212,9 @@ case class EndsWith(left: Expression, right: Expression)
     extends BinaryExpression with Predicate with StringComparison {
   override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r)
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+    defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)")
+  }
 }
 
 /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 5df528770ca6e..bc29f80dede19 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -28,6 +28,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
 import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection}
 import org.apache.spark.sql.catalyst.expressions.mathfuncs._
 import org.apache.spark.sql.catalyst.util.DateUtils
 import org.apache.spark.sql.types._
@@ -35,11 +36,20 @@ import org.apache.spark.sql.types._
 
 class ExpressionEvaluationBaseSuite extends SparkFunSuite {
 
+  def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = {
+    checkEvaluationWithoutCodegen(expression, expected, inputRow)
+    checkEvaluationWithGeneratedMutableProjection(expression, expected, inputRow)
+    checkEvaluationWithGeneratedProjection(expression, expected, inputRow)
+  }
+
   def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = {
     expression.eval(inputRow)
   }
 
-  def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = {
+  def checkEvaluationWithoutCodegen(
+      expression: Expression,
+      expected: Any,
+      inputRow: Row = EmptyRow): Unit = {
     val actual = try evaluate(expression, inputRow) catch {
       case e: Exception => fail(s"Exception evaluating $expression", e)
     }
@@ -49,6 +59,68 @@ class ExpressionEvaluationBaseSuite extends SparkFunSuite {
     }
   }
 
+  def checkEvaluationWithGeneratedMutableProjection(
+      expression: Expression,
+      expected: Any,
+      inputRow: Row = EmptyRow): Unit = {
+
+    val plan = try {
+      GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)()
+    } catch {
+      case e: Throwable =>
+        val ctx = GenerateProjection.newCodeGenContext()
+        val evaluated = expression.gen(ctx)
+        fail(
+          s"""
+            |Code generation of $expression failed:
+            |${evaluated.code}
+            |$e
+          """.stripMargin)
+    }
+
+    val actual = plan(inputRow).apply(0)
+    if (actual != expected) {
+      val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
+      fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
+    }
+  }
+
+  def checkEvaluationWithGeneratedProjection(
+      expression: Expression,
+      expected: Any,
+      inputRow: Row = EmptyRow): Unit = {
+    val ctx = GenerateProjection.newCodeGenContext()
+    lazy val evaluated = expression.gen(ctx)
+
+    val plan = try {
+      GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)
+    } catch {
+      case e: Throwable =>
+        fail(
+          s"""
+            |Code generation of $expression failed:
+            |${evaluated.code}
+            |$e
+          """.stripMargin)
+    }
+
+    val actual = plan(inputRow)
+    val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected)))
+    if (actual.hashCode() != expectedRow.hashCode()) {
+      fail(
+        s"""
+          |Mismatched hashCodes for values: $actual, $expectedRow
+          |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()}
+          |Expressions: ${expression}
+          |Code: ${evaluated}
+        """.stripMargin)
+    }
+    if (actual != expectedRow) {
+      val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
+      fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
+    }
+  }
+
   def checkDoubleEvaluation(
       expression: Expression,
       expected: Spread[Double],
@@ -69,8 +141,16 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
 
   test("literals") {
     checkEvaluation(Literal(1), 1)
     checkEvaluation(Literal(true), true)
+    checkEvaluation(Literal(false), false)
     checkEvaluation(Literal(0L), 0L)
+    List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach {
+      d => {
+        checkEvaluation(Literal(d), d)
+        checkEvaluation(Literal(d.toFloat), d.toFloat)
+      }
+    }
     checkEvaluation(Literal("test"), "test")
+    checkEvaluation(Literal.create(null, StringType), null)
     checkEvaluation(Literal(1) + Literal(1), 2)
   }
 
@@ -1367,6 +1447,10 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
 
 // TODO: Make the tests work with codegen.
 class ExpressionEvaluationWithoutCodeGenSuite extends ExpressionEvaluationBaseSuite {
+  override def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow) = {
+    checkEvaluationWithoutCodegen(expression, expected, inputRow)
+  }
+
   test("CreateStruct") {
     val row = Row(1, 2, 3)
     val c1 = 'a.int.at(0).as("a")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala
index b577de1d5aab9..371a73181dad7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala
@@ -21,34 +21,9 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 
 /**
- * Overrides our expression evaluation tests to use code generation for evaluation.
+ * Additional tests for code generation.
 */
 class GeneratedEvaluationSuite extends ExpressionEvaluationSuite {
-  override def checkEvaluation(
-      expression: Expression,
-      expected: Any,
-      inputRow: Row = EmptyRow): Unit = {
-    val plan = try {
-      GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)()
-    } catch {
-      case e: Throwable =>
-        val ctx = GenerateProjection.newCodeGenContext()
-        val evaluated = expression.gen(ctx)
-        fail(
-          s"""
-            |Code generation of $expression failed:
-            |${evaluated.code}
-            |$e
-          """.stripMargin)
-    }
-
-    val actual = plan(inputRow).apply(0)
-    if (actual != expected) {
-      val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
-      fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
-    }
-  }
-
   test("multithreaded eval") {
     import scala.concurrent._
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala
deleted file mode 100644
index 9da72521ec3ec..0000000000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala
+++ /dev/null
@@ -1,61 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
-import org.apache.spark.sql.catalyst.expressions.codegen._
-
-/**
- * Overrides our expression evaluation tests to use generated code on mutable rows.
- */
-class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite {
-  override def checkEvaluation(
-      expression: Expression,
-      expected: Any,
-      inputRow: Row = EmptyRow): Unit = {
-    val ctx = GenerateProjection.newCodeGenContext()
-    lazy val evaluated = expression.gen(ctx)
-
-    val plan = try {
-      GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)
-    } catch {
-      case e: Throwable =>
-        fail(
-          s"""
-            |Code generation of $expression failed:
-            |${evaluated.code}
-            |$e
-          """.stripMargin)
-    }
-
-    val actual = plan(inputRow)
-    val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected)))
-    if (actual.hashCode() != expectedRow.hashCode()) {
-      fail(
-        s"""
-          |Mismatched hashCodes for values: $actual, $expectedRow
-          |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()}
-          |${evaluated.code}
-        """.stripMargin)
-    }
-    if (actual != expectedRow) {
-      val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
-      fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
-    }
-  }
-}

From 9adaeaf9963afa42c0c9aa2ddf562df5d6a8d9ba Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Sat, 6 Jun 2015 00:38:28 -0700
Subject: [PATCH 15/18] address comments

---
 .../catalyst/expressions/BoundAttribute.scala | 2 +-
 .../spark/sql/catalyst/expressions/Cast.scala | 4 ++--
 .../sql/catalyst/expressions/Expression.scala | 9 ++++-----
 .../sql/catalyst/expressions/arithmetic.scala | 14 +++++++-------
 .../expressions/codegen/CodeGenerator.scala | 16 ++++++++--------
 .../expressions/codegen/GenerateProjection.scala | 6 +++---
 .../sql/catalyst/expressions/literals.scala | 2 +-
 .../catalyst/expressions/mathfuncs/unary.scala | 2 +-
 .../sql/catalyst/expressions/nullFunctions.scala | 2 +-
 .../sql/catalyst/expressions/predicates.scala | 6 +++---
 .../spark/sql/catalyst/expressions/sets.scala | 6 +++---
 11 files changed, 34 insertions(+), 35 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 00fd7294a8966..005de3166095f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -46,7 +46,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
     s"""
       boolean ${ev.isNull} = i.isNullAt($ordinal);
-      ${ctx.primitiveType(dataType)} ${ev.primitive} = ${ev.isNull} ?
+      ${ctx.javaType(dataType)} ${ev.primitive} = ${ev.isNull} ?
         ${ctx.defaultValue(dataType)} : (${ctx.getColumn(dataType, ordinal)});
     """
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index a6f805da242ab..5f76a512679a4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -458,7 +458,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
       super.genCode(ctx, ev)
 
     case (BooleanType, dt: NumericType) =>
-      defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c ? 1 : 0)")
+      defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 1 : 0)")
     case (dt: DecimalType, BooleanType) =>
       defineCodeGen(ctx, ev, c => s"$c.isZero()")
     case (dt: NumericType, BooleanType) =>
@@ -469,7 +469,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
     case (_: DecimalType, dt: NumericType) =>
       defineCodeGen(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()")
     case (_: NumericType, dt: NumericType) =>
-      defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c)")
+      defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c)")
 
     case other =>
       super.genCode(ctx, ev)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 87f864a7c0d9c..ba489f7dde59d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -77,14 +77,13 @@ abstract class Expression extends TreeNode[Expression] {
    * @return Java source code
    */
   protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
-    val e = this.asInstanceOf[Expression]
-    ctx.references += e
+    ctx.references += this
     val objectTerm = ctx.freshName("obj")
     s"""
       /* expression: ${this} */
      Object ${objectTerm} = expressions[${ctx.references.size - 1}].eval(i);
      boolean ${ev.isNull} = ${objectTerm} == null;
-      ${ctx.primitiveType(e.dataType)} ${ev.primitive} = ${ctx.defaultValue(e.dataType)};
+      ${ctx.javaType(e.dataType)} ${ev.primitive} = ${ctx.defaultValue(e.dataType)};
      if (!${ev.isNull}) {
        ${ev.primitive} = (${ctx.boxedType(e.dataType)})${objectTerm};
      }
@@ -180,7 +179,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
     s"""
       ${eval1.code}
       boolean ${ev.isNull} = ${eval1.isNull};
-      ${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
       if (!${ev.isNull}) {
         ${eval2.code}
         if(!${eval2.isNull}) {
@@ -219,7 +218,7 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
     // reuse the previous isNull
     ev.isNull = eval.isNull
     eval.code + s"""
-      ${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
       if (!${ev.isNull}) {
         ${ev.primitive} = ${f(eval.primitive)};
       }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index c161a514fcd4e..c983898660c0c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -78,7 +78,7 @@ case class Sqrt(child: Expression) extends UnaryArithmetic {
     val eval = child.gen(ctx)
     eval.code + s"""
       boolean ${ev.isNull} = ${eval.isNull};
-      ${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
       if (!${ev.isNull}) {
         if (${eval.primitive} < 0.0) {
           ${ev.isNull} = true;
@@ -144,7 +144,7 @@ abstract class BinaryArithmetic extends BinaryExpression {
       // byte and short are casted into int when add, minus, times or divide
       case ByteType | ShortType =>
         defineCodeGen(ctx, ev, (eval1, eval2) =>
-          s"(${ctx.primitiveType(dataType)})($eval1 $symbol $eval2)")
+          s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
       case _ =>
         defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
     }
@@ -248,7 +248,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
     eval1.code + eval2.code +
       s"""
         boolean ${ev.isNull} = false;
-        ${ctx.primitiveType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)};
+        ${ctx.javaType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)};
         if (${eval1.isNull} || ${eval2.isNull} || $test) {
           ${ev.isNull} = true;
         } else {
@@ -308,7 +308,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
     eval1.code + eval2.code +
       s"""
         boolean ${ev.isNull} = false;
-        ${ctx.primitiveType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)};
+        ${ctx.javaType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)};
         if (${eval1.isNull} || ${eval2.isNull} || $test) {
           ${ev.isNull} = true;
         } else {
@@ -408,7 +408,7 @@ case class BitwiseNot(child: Expression) extends UnaryArithmetic {
   }
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
-    defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dataType)})~($c)")
+    defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)})~($c)")
   }
 
   protected override def evalInternal(evalE: Any) = not(evalE)
@@ -444,7 +444,7 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
     val eval2 = right.gen(ctx)
     eval1.code + eval2.code + s"""
       boolean ${ev.isNull} = false;
-      ${ctx.primitiveType(left.dataType)} ${ev.primitive} =
+      ${ctx.javaType(left.dataType)} ${ev.primitive} =
         ${ctx.defaultValue(left.dataType)};
 
       if (${eval1.isNull}) {
@@ -500,7 +500,7 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
     eval1.code + eval2.code + s"""
       boolean ${ev.isNull} = false;
-      ${ctx.primitiveType(left.dataType)} ${ev.primitive} =
+      ${ctx.javaType(left.dataType)} ${ev.primitive} =
        ${ctx.defaultValue(left.dataType)};
 
      if (${eval1.isNull}) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 94b1b4808d759..c8d0aaf79f5f2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -54,8 +54,8 @@ class CodeGenContext {
    */
   val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]()
 
-  val stringType = classOf[UTF8String].getName
-  val decimalType = classOf[Decimal].getName
+  val stringType: String = classOf[UTF8String].getName
+  val decimalType: String = classOf[Decimal].getName
 
   private val curId = new java.util.concurrent.atomic.AtomicInteger()
 
@@ -108,9 +108,9 @@ class CodeGenContext {
   }
 
   /**
-   * Return the primitive type for a DataType
+   * Return the Java type for a DataType
    */
-  def primitiveType(dt: DataType): Term = dt match {
+  def javaType(dt: DataType): Term = dt match {
     case IntegerType => "int"
     case LongType => "long"
     case ShortType => "short"
@@ -140,7 +140,7 @@ class CodeGenContext {
     case FloatType => "Float"
     case BooleanType => "Boolean"
     case DateType => "Integer"
-    case _ => primitiveType(dt)
+    case _ => javaType(dt)
   }
 
   /**
@@ -189,9 +189,9 @@ class CodeGenContext {
  */
 abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging {
 
-  protected val exprType = classOf[Expression].getName
-  protected val mutableRowType = classOf[MutableRow].getName
-  protected val genericMutableRowType = classOf[GenericMutableRow].getName
+  protected val exprType: String = classOf[Expression].getName
+  protected val mutableRowType: String = classOf[MutableRow].getName
+  protected val genericMutableRowType: String = classOf[GenericMutableRow].getName
 
   /**
    * Can be flipped on manually in the console to add (expensive) expression evaluation trace code.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index f621c894833c9..7caf4aaab88bb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -45,7 +45,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
     val ctx = newCodeGenContext()
     val columns = expressions.zipWithIndex.map {
       case (e, i) =>
-        s"private ${ctx.primitiveType(e.dataType)} c$i = ${ctx.defaultValue(e.dataType)};\n"
+        s"private ${ctx.javaType(e.dataType)} c$i = ${ctx.defaultValue(e.dataType)};\n"
     }.mkString("\n ")
 
     val initColumns = expressions.zipWithIndex.map {
@@ -80,7 +80,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
       if (cases.count(_ != '\n') > 0) {
         s"""
       @Override
-      public ${ctx.primitiveType(dataType)} ${ctx.accessorForType(dataType)}(int i) {
+      public ${ctx.javaType(dataType)} ${ctx.accessorForType(dataType)}(int i) {
         if (isNullAt(i)) {
           return ${ctx.defaultValue(dataType)};
         }
@@ -103,7 +103,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
       if (cases.count(_ != '\n') > 0) {
         s"""
       @Override
-      public void ${ctx.mutatorForType(dataType)}(int i, ${ctx.primitiveType(dataType)} value) {
+      public void ${ctx.mutatorForType(dataType)}(int i, ${ctx.javaType(dataType)} value) {
         nullBits[i] = false;
         switch (i) {
         $cases
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index bce96bd3c1309..3a9271678bc9c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -121,7 +121,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres
       case ByteType | ShortType =>
         // This must go before NumericType
         ev.isNull = "false"
-        ev.primitive = s"(${ctx.primitiveType(dataType)})$value"
+        ev.primitive = s"(${ctx.javaType(dataType)})$value"
         ""
       case dt: NumericType if !dt.isInstanceOf[DecimalType] =>
         ev.isNull = "false"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala
index ad49c376e981e..5563cd94bf86d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala
@@ -53,7 +53,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
     val eval = child.gen(ctx)
     eval.code + s"""
       boolean ${ev.isNull} = ${eval.isNull};
-      ${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
       if (!${ev.isNull}) {
         ${ev.primitive} = java.lang.Math.${funcName}(${eval.primitive});
         if (Double.valueOf(${ev.primitive}).isNaN()) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
index ea216b1d0d9f4..9ecfb3ccc262f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
@@ -56,7 +56,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
     s"""
       boolean ${ev.isNull} = true;
-      ${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
     """ +
     children.map { e =>
       val eval = e.gen(ctx)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 75af8d71dbd31..57486640c90f9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -397,7 +397,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
     s"""
       ${condEval.code}
       boolean ${ev.isNull} = false;
-      ${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
       if (!${condEval.isNull} && ${condEval.primitive}) {
         ${trueEval.code}
         ${ev.isNull} = ${trueEval.isNull};
@@ -530,7 +530,7 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike {
     s"""
       boolean $got = false;
       boolean ${ev.isNull} = true;
-      ${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
       $cases
       $other
     """
@@ -626,7 +626,7 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
     s"""
       boolean $got = false;
       boolean ${ev.isNull} = true;
-      ${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
       ${keyEval.code}
       $cases
       $other
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
index 1038e7a653358..b39349b988389 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
@@ -66,7 +66,7 @@ case class NewSet(elementType: DataType) extends LeafExpression {
       case IntegerType | LongType =>
         ev.isNull = "false"
         s"""
-          ${ctx.primitiveType(dataType)} ${ev.primitive} = new ${ctx.primitiveType(dataType)}();
+          ${ctx.javaType(dataType)} ${ev.primitive} = new ${ctx.javaType(dataType)}();
         """
       case _ => super.genCode(ctx, ev)
     }
@@ -109,7 +109,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
       case IntegerType | LongType =>
         val itemEval = item.gen(ctx)
         val setEval = set.gen(ctx)
-        val htype = ctx.primitiveType(dataType)
+        val htype = ctx.javaType(dataType)
 
         ev.isNull = "false"
         ev.primitive = setEval.primitive
@@ -160,7 +160,7 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres
       case IntegerType | LongType =>
         val leftEval = left.gen(ctx)
         val rightEval = right.gen(ctx)
-        val htype = ctx.primitiveType(dataType)
+        val htype = ctx.javaType(dataType)
 
         ev.isNull = leftEval.isNull
         ev.primitive = leftEval.primitive

From 19d643580ba45211ea21248d86b9c59d0ffa4c7f Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Sat, 6 Jun 2015 23:13:30 -0700
Subject: [PATCH 16/18] Fixed style violation.

---
 .../sql/catalyst/expressions/ExpressionEvaluationSuite.scala | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index bc29f80dede19..eea2edc323eea 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -1447,7 +1447,8 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
 
 // TODO: Make the tests work with codegen.
 class ExpressionEvaluationWithoutCodeGenSuite extends ExpressionEvaluationBaseSuite {
-  override def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow) = {
+  override def checkEvaluation(
+      expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = {
     checkEvaluationWithoutCodegen(expression, expected, inputRow)
   }
 
From 73db80ee0cc4f7f5ab8aaaceeb527e2ea81a2d04 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Sun, 7 Jun 2015 01:13:56 -0700
Subject: [PATCH 17/18] Fixed compilation failure.
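
For context: the "address comments" commit above replaced `ctx.references += e` (where `e` was
just an alias for `this`) with `ctx.references += this`, but the interpolated string in the
interpreted fallback of Expression.genCode still referred to `e.dataType`, so Expression.scala
no longer compiled. The hunk below switches those references to `this.dataType`. After the fix
the fallback reads roughly as in this sketch, which is reassembled from the hunks in this series
and shown only for orientation, not as the exact file contents:

    protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
      // Fall back to interpreted evaluation: register this expression and call eval() from
      // the generated code, then unbox the result into the primitive slot.
      ctx.references += this
      val objectTerm = ctx.freshName("obj")
      s"""
        /* expression: ${this} */
        Object ${objectTerm} = expressions[${ctx.references.size - 1}].eval(i);
        boolean ${ev.isNull} = ${objectTerm} == null;
        ${ctx.javaType(this.dataType)} ${ev.primitive} = ${ctx.defaultValue(this.dataType)};
        if (!${ev.isNull}) {
          ${ev.primitive} = (${ctx.boxedType(this.dataType)})${objectTerm};
        }
      """
    }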
---
 .../apache/spark/sql/catalyst/expressions/Expression.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index ba489f7dde59d..92cbb79a2a3b4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -83,9 +83,9 @@ abstract class Expression extends TreeNode[Expression] {
       /* expression: ${this} */
      Object ${objectTerm} = expressions[${ctx.references.size - 1}].eval(i);
      boolean ${ev.isNull} = ${objectTerm} == null;
-      ${ctx.javaType(e.dataType)} ${ev.primitive} = ${ctx.defaultValue(e.dataType)};
+      ${ctx.javaType(this.dataType)} ${ev.primitive} = ${ctx.defaultValue(this.dataType)};
      if (!${ev.isNull}) {
-        ${ev.primitive} = (${ctx.boxedType(e.dataType)})${objectTerm};
+        ${ev.primitive} = (${ctx.boxedType(this.dataType)})${objectTerm};
      }
    """
  }

From e1368c2ceafb6d7638cf6eed5ede30a9025df913 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Sun, 7 Jun 2015 12:11:04 -0700
Subject: [PATCH 18/18] Fixed tests.

---
 .../spark/sql/parquet/ParquetPartitionDiscoverySuite.scala | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala
index 3b29979452ad9..2df178dac1b51 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala
@@ -53,7 +53,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
     check("10", Literal.create(10, IntegerType))
     check("1000000000000000", Literal.create(1000000000000000L, LongType))
-    check("1.5", Literal.create(1.5, FloatType))
+    check("1.5", Literal.create(1.5f, FloatType))
     check("hello", Literal.create("hello", StringType))
     check(defaultPartitionName, Literal.create(null, NullType))
   }
@@ -83,13 +83,13 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
       ArrayBuffer(
         Literal.create(10, IntegerType),
         Literal.create("hello", StringType),
-        Literal.create(1.5, FloatType)))
+        Literal.create(1.5f, FloatType)))
     })
 
     check("file://path/a=10/b_hello/c=1.5", Some {
       PartitionValues(
         ArrayBuffer("c"),
-        ArrayBuffer(Literal.create(1.5, FloatType)))
+        ArrayBuffer(Literal.create(1.5f, FloatType)))
     })
 
     check("file:///", None)
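
As a quick orientation for how the reworked codegen path is exercised end to end, the sketch
below mirrors the checkEvaluationWithGeneratedMutableProjection helper added to
ExpressionEvaluationSuite above: generate a mutable projection over an aliased expression, then
evaluate the resulting projection against an input row. The wrapper object and value names are
illustrative only and are not part of the patch series; the expression and expected result are
taken from the "literals" test.

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection

    object CodegenSketch {
      def main(args: Array[String]): Unit = {
        // Same pattern as the generated-mutable-projection test helper shown above.
        val expr = Add(Literal(1), Literal(1))
        val projection = GenerateMutableProjection.generate(
          Alias(expr, s"Optimized($expr)")() :: Nil)()
        // The generated projection evaluates the expression against an input row.
        assert(projection(EmptyRow).apply(0) == 2)
      }
    }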