[SPARK-16097][SQL] Encoders.tuple should handle null object correctly
## What changes were proposed in this pull request?

Although the top-level input object cannot be null, when we use `Encoders.tuple` to combine two encoders, their input objects are no longer top level and can be null. We should handle this case.
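
As a quick illustration (a sketch mirroring the new `DatasetSuite` test below; it assumes a running `SparkSession` named `spark`):

```scala
import org.apache.spark.sql.Encoders

// The inner (String, String) encoder is no longer top level once combined,
// so its input object can be null.
val enc = Encoders.tuple(Encoders.tuple(Encoders.STRING, Encoders.STRING), Encoders.STRING)
val data = Seq((("a", "b"), "c"), (null, "d"))

// Before this patch the null inner tuple was not handled correctly; with the
// fix it round-trips as a null struct.
val ds = spark.createDataset(data)(enc)
ds.collect() // Array((("a","b"),"c"), (null,"d"))
```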

## How was this patch tested?

A new test in `DatasetSuite`.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #13807 from cloud-fan/bug.
cloud-fan authored and liancheng committed Jun 22, 2016
1 parent 39ad53f commit 01277d4
Showing 2 changed files with 42 additions and 13 deletions.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala

```diff
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection
 import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance}
 import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
 import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation}
-import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
+import org.apache.spark.sql.types.{BooleanType, ObjectType, StructField, StructType}
 import org.apache.spark.util.Utils
 
 /**
```
```diff
@@ -110,16 +110,34 @@ object ExpressionEncoder {
 
     val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
 
-    val serializer = encoders.map {
-      case e if e.flat => e.serializer.head
-      case other => CreateStruct(other.serializer)
-    }.zipWithIndex.map { case (expr, index) =>
-      expr.transformUp {
-        case BoundReference(0, t, _) =>
-          Invoke(
-            BoundReference(0, ObjectType(cls), nullable = true),
-            s"_${index + 1}",
-            t)
+    val serializer = encoders.zipWithIndex.map { case (enc, index) =>
+      val originalInputObject = enc.serializer.head.collect { case b: BoundReference => b }.head
+      val newInputObject = Invoke(
+        BoundReference(0, ObjectType(cls), nullable = true),
+        s"_${index + 1}",
+        originalInputObject.dataType)
+
+      val newSerializer = enc.serializer.map(_.transformUp {
+        case b: BoundReference if b == originalInputObject => newInputObject
+      })
+
+      if (enc.flat) {
+        newSerializer.head
+      } else {
+        // For a non-flat encoder, the input object is no longer top level after being combined
+        // into a tuple encoder, thus it can be null, and we should wrap the `CreateStruct` with
+        // an `If` and a null check to handle the null case correctly.
+        // E.g. for Encoder[(Int, String)], the serializer expressions create 2 columns and are
+        // not able to handle the case when the input tuple is null. This is not a problem as
+        // there is a check to make sure the input object won't be null. However, if this encoder
+        // is used to create a bigger tuple encoder, the original input object becomes a field of
+        // the new input tuple and can be null. So instead of creating a struct directly here, we
+        // should add a null/None check and return a null struct if the null/None check fails.
+        val struct = CreateStruct(newSerializer)
+        val nullCheck = Or(
+          IsNull(newInputObject),
+          Invoke(Literal.fromObject(None), "equals", BooleanType, newInputObject :: Nil))
+        If(nullCheck, Literal.create(null, struct.dataType), struct)
       }
     }
 
```
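A note on the `nullCheck` built above: the `Invoke(Literal.fromObject(None), "equals", BooleanType, newInputObject :: Nil)` branch calls `equals` on the `None` singleton, a null-safe way to also treat a `None` input like a null one. A minimal plain-Scala sketch of that intuition (illustrative only, not the Catalyst expression itself):

```scala
// Plain-Scala model of the generated check: a field is written as a null
// struct when the incoming object is null or the None singleton.
def writesNullStruct(input: Any): Boolean =
  input == null || None.equals(input) // None.equals(null) is false; the IsNull branch covers that case

assert(writesNullStruct(null))
assert(writesNullStruct(None))
assert(!writesNullStruct(("a", "b")))
```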
```diff
@@ -203,8 +221,12 @@ case class ExpressionEncoder[T](
   // (intermediate value is not an attribute). We assume that all serializer expressions use the
   // same `BoundReference` to refer to the object, and throw an exception if they don't.
   assert(serializer.forall(_.references.isEmpty), "serializer cannot reference to any attributes.")
-  assert(serializer.flatMap(_.collect { case b: BoundReference => b }).distinct.length <= 1,
-    "all serializer expressions must use the same BoundReference.")
+  assert(serializer.flatMap { ser =>
+    val boundRefs = ser.collect { case b: BoundReference => b }
+    assert(boundRefs.nonEmpty,
+      "each serializer expression should contain at least one `BoundReference`")
+    boundRefs
+  }.distinct.length <= 1, "all serializer expressions must use the same BoundReference.")
 
   /**
    * Returns a new copy of this encoder, where the `deserializer` is resolved and bound to the
```
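The strengthened assertion can be read as the following plain-Scala model (a simplified sketch using a stand-in `Ref` type rather than Catalyst's `BoundReference`):

```scala
// Simplified model of the new invariant: every serializer expression must
// contain at least one bound reference, and all expressions must share it.
case class Ref(ordinal: Int) // stand-in for Catalyst's BoundReference

def checkSerializer(exprRefs: Seq[Seq[Ref]]): Unit = {
  val all = exprRefs.flatMap { refs =>
    assert(refs.nonEmpty, "each serializer expression should contain at least one BoundReference")
    refs
  }
  assert(all.distinct.length <= 1, "all serializer expressions must use the same BoundReference")
}

checkSerializer(Seq(Seq(Ref(0)), Seq(Ref(0), Ref(0)))) // passes: one shared reference
// checkSerializer(Seq(Seq(Ref(0)), Seq(Ref(1)))) // would fail: two different references
```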
sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

```diff
@@ -830,6 +830,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       ds.dropDuplicates("_1", "_2"),
       ("a", 1), ("a", 2), ("b", 1))
   }
+
+  test("SPARK-16097: Encoders.tuple should handle null object correctly") {
+    val enc = Encoders.tuple(Encoders.tuple(Encoders.STRING, Encoders.STRING), Encoders.STRING)
+    val data = Seq((("a", "b"), "c"), (null, "d"))
+    val ds = spark.createDataset(data)(enc)
+    checkDataset(ds, (("a", "b"), "c"), (null, "d"))
+  }
 }
 
 case class Generic[T](id: T, value: Double)
```
