Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-14637][SQL] object expressions cleanup #12399

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -64,33 +64,29 @@ case class StaticInvoke(
val argGen = arguments.map(_.genCode(ctx))
val argString = argGen.map(_.value).mkString(", ")

if (propagateNull) {
val objNullCheck = if (ctx.defaultValue(dataType) == "null") {
s"${ev.isNull} = ${ev.value} == null;"
} else {
""
}

val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})"
ev.copy(code = s"""
${argGen.map(_.code).mkString("\n")}

boolean ${ev.isNull} = !$argsNonNull;
$javaType ${ev.value} = ${ctx.defaultValue(dataType)};
val callFunc = s"$objectName.$functionName($argString)"

if ($argsNonNull) {
${ev.value} = $objectName.$functionName($argString);
$objNullCheck
}
""")
val setIsNull = if (propagateNull && arguments.nonEmpty) {
s"boolean ${ev.isNull} = ${argGen.map(_.isNull).mkString(" || ")};"
} else {
ev.copy(code = s"""
${argGen.map(_.code).mkString("\n")}
s"boolean ${ev.isNull} = false;"
}

$javaType ${ev.value} = $objectName.$functionName($argString);
final boolean ${ev.isNull} = ${ev.value} == null;
""")
// If the function can return null, we do an extra check to make sure our null bit is still set
// correctly.
val postNullCheck = if (ctx.defaultValue(dataType) == "null") {
s"${ev.isNull} = ${ev.value} == null;"
} else {
""
}

val code = s"""
${argGen.map(_.code).mkString("\n")}
$setIsNull
final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : $callFunc;
$postNullCheck
"""
ev.copy(code = code)
}
}

Expand All @@ -111,7 +107,8 @@ case class Invoke(
targetObject: Expression,
functionName: String,
dataType: DataType,
arguments: Seq[Expression] = Nil) extends Expression with NonSQLExpression {
arguments: Seq[Expression] = Nil,
propagateNull: Boolean = true) extends Expression with NonSQLExpression {

override def nullable: Boolean = true
override def children: Seq[Expression] = targetObject +: arguments
Expand All @@ -130,60 +127,53 @@ case class Invoke(
case _ => None
}

lazy val unboxer = (dataType, method.map(_.getReturnType.getName).getOrElse("")) match {
case (IntegerType, "java.lang.Object") => (s: String) =>
s"((java.lang.Integer)$s).intValue()"
case (LongType, "java.lang.Object") => (s: String) =>
s"((java.lang.Long)$s).longValue()"
case (FloatType, "java.lang.Object") => (s: String) =>
s"((java.lang.Float)$s).floatValue()"
case (ShortType, "java.lang.Object") => (s: String) =>
s"((java.lang.Short)$s).shortValue()"
case (ByteType, "java.lang.Object") => (s: String) =>
s"((java.lang.Byte)$s).byteValue()"
case (DoubleType, "java.lang.Object") => (s: String) =>
s"((java.lang.Double)$s).doubleValue()"
case (BooleanType, "java.lang.Object") => (s: String) =>
s"((java.lang.Boolean)$s).booleanValue()"
case _ => identity[String] _
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val javaType = ctx.javaType(dataType)
val obj = targetObject.genCode(ctx)
val argGen = arguments.map(_.genCode(ctx))
val argString = argGen.map(_.value).mkString(", ")

// If the function can return null, we do an extra check to make sure our null bit is still set
// correctly.
val objNullCheck = if (ctx.defaultValue(dataType) == "null") {
s"boolean ${ev.isNull} = ${ev.value} == null;"
val callFunc = if (method.isDefined && method.get.getReturnType.isPrimitive) {
s"${obj.value}.$functionName($argString)"
} else {
ev.isNull = obj.isNull
""
s"(${ctx.boxedType(javaType)}) ${obj.value}.$functionName($argString)"
}

val value = unboxer(s"${obj.value}.$functionName($argString)")
val setIsNull = if (propagateNull && arguments.nonEmpty) {
s"boolean ${ev.isNull} = ${obj.isNull} || ${argGen.map(_.isNull).mkString(" || ")};"
} else {
s"boolean ${ev.isNull} = ${obj.isNull};"
}

val evaluate = if (method.forall(_.getExceptionTypes.isEmpty)) {
s"$javaType ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType) $value;"
s"final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : $callFunc;"
} else {
s"""
$javaType ${ev.value} = ${ctx.defaultValue(javaType)};
try {
${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(javaType)} : ($javaType) $value;
${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $callFunc;
} catch (Exception e) {
org.apache.spark.unsafe.Platform.throwException(e);
}
"""
}

ev.copy(code = s"""
// If the function can return null, we do an extra check to make sure our null bit is still set
// correctly.
val postNullCheck = if (ctx.defaultValue(dataType) == "null") {
s"${ev.isNull} = ${ev.value} == null;"
} else {
""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we no longer need to do the unboxing here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do, but I simplified it; see `callFunc` above.

}

val code = s"""
${obj.code}
${argGen.map(_.code).mkString("\n")}
$setIsNull
$evaluate
$objNullCheck
""")
$postNullCheck
"""
ev.copy(code = code)
}

override def toString: String = s"$targetObject.$functionName"
Expand Down Expand Up @@ -246,39 +236,27 @@ case class NewInstance(

val outer = outerPointer.map(func => Literal.fromObject(func()).genCode(ctx))

val setup =
s"""
${argGen.map(_.code).mkString("\n")}
${outer.map(_.code).getOrElse("")}
""".stripMargin
var isNull = ev.isNull
val setIsNull = if (propagateNull && arguments.nonEmpty) {
s"final boolean $isNull = ${argGen.map(_.isNull).mkString(" || ")};"
} else {
isNull = "false"
""
}

val constructorCall = outer.map { gen =>
s"""${gen.value}.new ${cls.getSimpleName}($argString)"""
}.getOrElse {
s"new $className($argString)"
}

if (propagateNull && argGen.nonEmpty) {
val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})"

ev.copy(code = s"""
$setup

boolean ${ev.isNull} = true;
$javaType ${ev.value} = ${ctx.defaultValue(dataType)};
if ($argsNonNull) {
${ev.value} = $constructorCall;
${ev.isNull} = false;
}
""")
} else {
ev.copy(code = s"""
$setup

final $javaType ${ev.value} = $constructorCall;
final boolean ${ev.isNull} = false;
""")
}
val code = s"""
${argGen.map(_.code).mkString("\n")}
${outer.map(_.code).getOrElse("")}
$setIsNull
final $javaType ${ev.value} = $isNull ? ${ctx.defaultValue(javaType)} : $constructorCall;
"""
ev.copy(code = code, isNull = isNull)
}

override def toString: String = s"newInstance($cls)"
Expand Down Expand Up @@ -306,13 +284,14 @@ case class UnwrapOption(
val javaType = ctx.javaType(dataType)
val inputObject = child.genCode(ctx)

ev.copy(code = s"""
val code = s"""
${inputObject.code}

boolean ${ev.isNull} = ${inputObject.value} == null || ${inputObject.value}.isEmpty();
final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty();
$javaType ${ev.value} =
${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType)${inputObject.value}.get();
""")
${ev.isNull} ? ${ctx.defaultValue(javaType)} : ($javaType) ${inputObject.value}.get();
"""
ev.copy(code = code)
}
}

Expand All @@ -338,14 +317,14 @@ case class WrapOption(child: Expression, optType: DataType)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val inputObject = child.genCode(ctx)

ev.copy(code = s"""
val code = s"""
${inputObject.code}

boolean ${ev.isNull} = false;
scala.Option ${ev.value} =
${inputObject.isNull} ?
scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value});
""")
"""
ev.copy(code = code, isNull = "false")
}
}

Expand Down Expand Up @@ -474,7 +453,7 @@ case class MapObjects private(
s"${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;"
}

ev.copy(code = s"""
val code = s"""
${genInputData.code}

boolean ${ev.isNull} = ${genInputData.value} == null;
Expand Down Expand Up @@ -504,7 +483,8 @@ case class MapObjects private(
${ev.isNull} = false;
${ev.value} = new ${classOf[GenericArrayData].getName}($convertedArray);
}
""")
"""
ev.copy(code = code)
}
}

Expand Down Expand Up @@ -539,14 +519,16 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType)
}
"""
}

val childrenCode = ctx.splitExpressions(ctx.INPUT_ROW, childrenCodes)
val schemaField = ctx.addReferenceObj("schema", schema)
ev.copy(code = s"""
boolean ${ev.isNull} = false;

val code = s"""
$values = new Object[${children.size}];
$childrenCode
final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, this.$schemaField);
""")
"""
ev.copy(code = code, isNull = "false")
}
}

Expand Down Expand Up @@ -579,14 +561,14 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean)

// Code to serialize.
val input = child.genCode(ctx)
ev.copy(code = s"""
val javaType = ctx.javaType(dataType)
val serialize = s"$serializer.serialize(${input.value}, null).array()"

val code = s"""
${input.code}
final boolean ${ev.isNull} = ${input.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = $serializer.serialize(${input.value}, null).array();
}
""")
final $javaType ${ev.value} = ${input.isNull} ? ${ctx.defaultValue(javaType)} : $serialize;
"""
ev.copy(code = code, isNull = input.isNull)
}

override def dataType: DataType = BinaryType
Expand Down Expand Up @@ -617,17 +599,17 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B
serializer,
s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();")

// Code to serialize.
// Code to deserialize.
val input = child.genCode(ctx)
ev.copy(code = s"""
val javaType = ctx.javaType(dataType)
val deserialize =
s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)"

val code = s"""
${input.code}
final boolean ${ev.isNull} = ${input.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = (${ctx.javaType(dataType)})
$serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null);
}
""")
final $javaType ${ev.value} = ${input.isNull} ? ${ctx.defaultValue(javaType)} : $deserialize;
"""
ev.copy(code = code, isNull = input.isNull)
}

override def dataType: DataType = ObjectType(tag.runtimeClass)
Expand Down Expand Up @@ -658,15 +640,13 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
"""
}

ev.isNull = instanceGen.isNull
ev.value = instanceGen.value

ev.copy(code = s"""
val code = s"""
${instanceGen.code}
if (!${instanceGen.isNull}) {
${initialize.mkString("\n")}
}
""")
"""
ev.copy(code = code, isNull = instanceGen.isNull, value = instanceGen.value)
}
}

Expand Down Expand Up @@ -696,13 +676,15 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
"If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
"please try to use scala.Option[_] or other nullable types " +
"(e.g. java.lang.Integer instead of int/scala.Int)."
val idx = ctx.references.length
ctx.references += errMsg
ExprCode(code = s"""
val errMsgField = ctx.addReferenceObj("errMsg", errMsg)

val code = s"""
${childGen.code}

if (${childGen.isNull}) {
throw new RuntimeException((String) references[$idx]);
}""", isNull = "false", value = childGen.value)
throw new RuntimeException(this.$errMsgField);
}
"""
ev.copy(code = code, isNull = "false", value = childGen.value)
}
}