Skip to content

Commit

Permalink
object expressions cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
cloud-fan committed Apr 14, 2016
1 parent b481940 commit 43644ef
Showing 1 changed file with 73 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,33 +64,28 @@ case class StaticInvoke(
val argGen = arguments.map(_.gen(ctx))
val argString = argGen.map(_.value).mkString(", ")

if (propagateNull) {
val objNullCheck = if (ctx.defaultValue(dataType) == "null") {
s"${ev.isNull} = ${ev.value} == null;"
} else {
""
}

val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})"
s"""
${argGen.map(_.code).mkString("\n")}
val callFunc = s"$objectName.$functionName($argString)"

boolean ${ev.isNull} = !$argsNonNull;
$javaType ${ev.value} = ${ctx.defaultValue(dataType)};

if ($argsNonNull) {
${ev.value} = $objectName.$functionName($argString);
$objNullCheck
}
"""
val setIsNull = if (propagateNull && arguments.nonEmpty) {
s"boolean ${ev.isNull} = ${argGen.map(_.isNull).mkString(" || ")};"
} else {
s"""
${argGen.map(_.code).mkString("\n")}
s"boolean ${ev.isNull} = false;"
}

$javaType ${ev.value} = $objectName.$functionName($argString);
final boolean ${ev.isNull} = ${ev.value} == null;
"""
// If the function can return null, we do an extra check to make sure our null bit is still set
// correctly.
val postNullCheck = if (ctx.defaultValue(dataType) == "null") {
s"${ev.isNull} = ${ev.value} == null;"
} else {
""
}

s"""
${argGen.map(_.code).mkString("\n")}
$setIsNull
final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : $callFunc;
$postNullCheck
"""
}
}

Expand All @@ -111,7 +106,8 @@ case class Invoke(
targetObject: Expression,
functionName: String,
dataType: DataType,
arguments: Seq[Expression] = Nil) extends Expression with NonSQLExpression {
arguments: Seq[Expression] = Nil,
propagateNull: Boolean = true) extends Expression with NonSQLExpression {

override def nullable: Boolean = true
override def children: Seq[Expression] = targetObject +: arguments
Expand All @@ -130,60 +126,52 @@ case class Invoke(
case _ => None
}

lazy val unboxer = (dataType, method.map(_.getReturnType.getName).getOrElse("")) match {
case (IntegerType, "java.lang.Object") => (s: String) =>
s"((java.lang.Integer)$s).intValue()"
case (LongType, "java.lang.Object") => (s: String) =>
s"((java.lang.Long)$s).longValue()"
case (FloatType, "java.lang.Object") => (s: String) =>
s"((java.lang.Float)$s).floatValue()"
case (ShortType, "java.lang.Object") => (s: String) =>
s"((java.lang.Short)$s).shortValue()"
case (ByteType, "java.lang.Object") => (s: String) =>
s"((java.lang.Byte)$s).byteValue()"
case (DoubleType, "java.lang.Object") => (s: String) =>
s"((java.lang.Double)$s).doubleValue()"
case (BooleanType, "java.lang.Object") => (s: String) =>
s"((java.lang.Boolean)$s).booleanValue()"
case _ => identity[String] _
}

override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
val javaType = ctx.javaType(dataType)
val obj = targetObject.gen(ctx)
val argGen = arguments.map(_.gen(ctx))
val argString = argGen.map(_.value).mkString(", ")

// If the function can return null, we do an extra check to make sure our null bit is still set
// correctly.
val objNullCheck = if (ctx.defaultValue(dataType) == "null") {
s"boolean ${ev.isNull} = ${ev.value} == null;"
val callFunc = if (method.isDefined && method.get.getReturnType.isPrimitive) {
s"${obj.value}.$functionName($argString)"
} else {
ev.isNull = obj.isNull
""
s"(${ctx.boxedType(javaType)}) ${obj.value}.$functionName($argString)"
}

val value = unboxer(s"${obj.value}.$functionName($argString)")
val setIsNull = if (propagateNull && arguments.nonEmpty) {
s"boolean ${ev.isNull} = ${obj.isNull} || ${argGen.map(_.isNull).mkString(" || ")};"
} else {
s"boolean ${ev.isNull} = ${obj.isNull};"
}

val evaluate = if (method.forall(_.getExceptionTypes.isEmpty)) {
s"$javaType ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType) $value;"
s"final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : $callFunc;"
} else {
s"""
$javaType ${ev.value} = ${ctx.defaultValue(javaType)};
try {
${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(javaType)} : ($javaType) $value;
${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $callFunc;
} catch (Exception e) {
org.apache.spark.unsafe.Platform.throwException(e);
}
"""
}

// If the function can return null, we do an extra check to make sure our null bit is still set
// correctly.
val postNullCheck = if (ctx.defaultValue(dataType) == "null") {
s"${ev.isNull} = ${ev.value} == null;"
} else {
""
}

s"""
${obj.code}
${argGen.map(_.code).mkString("\n")}
$setIsNull
$evaluate
$objNullCheck
"""
$postNullCheck
"""
}

override def toString: String = s"$targetObject.$functionName"
Expand Down Expand Up @@ -246,39 +234,25 @@ case class NewInstance(

val outer = outerPointer.map(func => Literal.fromObject(func()).gen(ctx))

val setup =
s"""
${argGen.map(_.code).mkString("\n")}
${outer.map(_.code).getOrElse("")}
""".stripMargin
val setIsNull = if (propagateNull && arguments.nonEmpty) {
s"final boolean ${ev.isNull} = ${argGen.map(_.isNull).mkString(" || ")};"
} else {
ev.isNull = "false"
""
}

val constructorCall = outer.map { gen =>
s"""${gen.value}.new ${cls.getSimpleName}($argString)"""
}.getOrElse {
s"new $className($argString)"
}

if (propagateNull && argGen.nonEmpty) {
val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})"

s"""
$setup

boolean ${ev.isNull} = true;
$javaType ${ev.value} = ${ctx.defaultValue(dataType)};
if ($argsNonNull) {
${ev.value} = $constructorCall;
${ev.isNull} = false;
}
"""
} else {
s"""
$setup

final $javaType ${ev.value} = $constructorCall;
final boolean ${ev.isNull} = false;
"""
}
s"""
${argGen.map(_.code).mkString("\n")}
${outer.map(_.code).getOrElse("")}
$setIsNull
final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $constructorCall;
"""
}

override def toString: String = s"newInstance($cls)"
Expand Down Expand Up @@ -309,9 +283,9 @@ case class UnwrapOption(
s"""
${inputObject.code}

boolean ${ev.isNull} = ${inputObject.value} == null || ${inputObject.value}.isEmpty();
final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty();
$javaType ${ev.value} =
${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType)${inputObject.value}.get();
${ev.isNull} ? ${ctx.defaultValue(javaType)} : ($javaType) ${inputObject.value}.get();
"""
}
}
Expand All @@ -336,11 +310,11 @@ case class WrapOption(child: Expression, optType: DataType)

override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
val inputObject = child.gen(ctx)
ev.isNull = "false"

s"""
${inputObject.code}

boolean ${ev.isNull} = false;
scala.Option ${ev.value} =
${inputObject.isNull} ?
scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value});
Expand Down Expand Up @@ -538,10 +512,12 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType)
}
"""
}

val childrenCode = ctx.splitExpressions(ctx.INPUT_ROW, childrenCodes)
val schemaField = ctx.addReferenceObj("schema", schema)
ev.isNull = "false"

s"""
boolean ${ev.isNull} = false;
$values = new Object[${children.size}];
$childrenCode
final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, this.$schemaField);
Expand Down Expand Up @@ -577,13 +553,13 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean)

// Code to serialize.
val input = child.gen(ctx)
ev.isNull = input.isNull
val javaType = ctx.javaType(dataType)
val serialize = s"$serializer.serialize(${input.value}, null).array()"

s"""
${input.code}
final boolean ${ev.isNull} = ${input.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = $serializer.serialize(${input.value}, null).array();
}
final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $serialize;
"""
}

Expand Down Expand Up @@ -614,16 +590,16 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B
serializer,
s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();")

// Code to serialize.
// Code to deserialize.
val input = child.gen(ctx)
ev.isNull = input.isNull
val javaType = ctx.javaType(dataType)
val deserialize =
s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)"

s"""
${input.code}
final boolean ${ev.isNull} = ${input.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = (${ctx.javaType(dataType)})
$serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null);
}
final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $deserialize;
"""
}

Expand Down Expand Up @@ -693,8 +669,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
"If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
"please try to use scala.Option[_] or other nullable types " +
"(e.g. java.lang.Integer instead of int/scala.Int)."
val idx = ctx.references.length
ctx.references += errMsg
val errMsgField = ctx.addReferenceObj("errMsg", errMsg)

ev.isNull = "false"
ev.value = childGen.value
Expand All @@ -703,7 +678,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
${childGen.code}

if (${childGen.isNull}) {
throw new RuntimeException((String) references[$idx]);
throw new RuntimeException(this.$errMsgField);
}
"""
}
Expand Down

0 comments on commit 43644ef

Please sign in to comment.