[SPARK-21255][SQL][WIP] Fixed NPE when creating encoder for enum #18488

Closed · wants to merge 17 commits
@@ -79,7 +79,7 @@ public ExpressionInfo(
assert name != null;
assert arguments != null;
assert examples != null;
- assert examples.isEmpty() || examples.startsWith("\n Examples:");
+ assert examples.isEmpty() || examples.startsWith(System.lineSeparator() + " Examples:");
Member:

I guess this one is not related?

Contributor Author (@mike0sv), Aug 10, 2017:

No, but without this it's not possible to run the tests if you have a different line separator (on Windows, for example).
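The platform dependence is easy to see in a small standalone sketch (the class and method names here are hypothetical, not from the patch):

```java
public class SeparatorCheck {
    // Mirrors the original assertion, which hard-codes "\n"; it only matches
    // System.lineSeparator() on platforms whose separator is a bare LF.
    static boolean matchesHardCodedLf(String examples) {
        return examples.isEmpty() || examples.startsWith("\n  Examples:");
    }

    public static void main(String[] args) {
        String examples = System.lineSeparator() + "  Examples:";
        // true on Linux/macOS ("\n"), false on Windows ("\r\n")
        System.out.println(matchesHardCodedLf(examples));
    }
}
```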

Member:

I don't think we support Windows for dev. This assertion should probably be weakened anyway, but that's a separate issue from this PR.

Contributor Author:

OK, I got rid of it.

assert note != null;
assert since != null;

@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils

/**
* Type-inference utilities for POJOs and Java collections.
@@ -118,6 +119,10 @@ object JavaTypeInference {
val (valueDataType, nullable) = inferDataType(valueType, seenTypeSet)
(MapType(keyDataType, valueDataType, nullable), true)

case other if other.isEnum =>
(StructType(Seq(StructField(typeToken.getRawType.getSimpleName,
Contributor:

Why do we map an enum to a struct type? Shouldn't an enum always have a single field?

StringType, nullable = false))), true)

Contributor Author:

We use a struct type with a string field to store the enum type and its value.
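Stripped of Spark's types, the encoding amounts to a one-field struct whose single string column holds the constant name. A minimal sketch, with the struct modeled as a one-element array (Fruit and the helper names are made up):

```java
public class EnumAsStruct {
    enum Fruit { APPLE, PEAR }

    // serialize: the enum becomes a "struct" with one string field
    // holding the constant name
    static String[] serialize(Fruit f) {
        return new String[] { f.name() };
    }

    // deserialize: look the constant back up by name
    static Fruit deserialize(String[] row) {
        return Enum.valueOf(Fruit.class, row[0]);
    }

    public static void main(String[] args) {
        System.out.println(deserialize(serialize(Fruit.PEAR)));  // PEAR
    }
}
```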

case other =>
if (seenTypeSet.contains(other)) {
throw new UnsupportedOperationException(
@@ -140,6 +145,7 @@ object JavaTypeInference {
def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
val beanInfo = Introspector.getBeanInfo(beanClass)
beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
.filterNot(_.getName == "declaringClass")
.filter(_.getReadMethod != null)
}

@@ -303,6 +309,11 @@ object JavaTypeInference {
keyData :: valueData :: Nil,
returnNullable = false)

case other if other.isEnum =>
StaticInvoke(JavaTypeInference.getClass, ObjectType(other), "deserializeEnumName",
expressions.Literal.create(other.getEnumConstants.apply(0), ObjectType(other))
:: getPath :: Nil)

Contributor Author:

We pass a literal value of the first enum constant to resolve the type parameter of the deserializeEnumName method.
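The trick can be sketched outside Spark: passing any constant lets the compiler bind T, and the Class<T> is recovered from the instance at runtime. The names below are hypothetical; note that for enums with constant-specific class bodies, getDeclaringClass() would be the safer choice, since getClass() then returns an anonymous subclass.

```java
public class EnumDummyResolve {
    enum Color { RED, GREEN }

    // typeDummy is ignored except for recovering Class<T> at runtime
    static <T extends Enum<T>> T deserializeByDummy(T typeDummy, String name) {
        @SuppressWarnings("unchecked")
        Class<T> cls = (Class<T>) typeDummy.getClass();
        return Enum.valueOf(cls, name);
    }

    public static void main(String[] args) {
        System.out.println(deserializeByDummy(Color.RED, "GREEN"));  // GREEN
    }
}
```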

case other =>
val properties = getJavaBeanReadableAndWritableProperties(other)
val setters = properties.map { p =>
@@ -345,6 +356,30 @@ object JavaTypeInference {
}
}

/** Returns a function that serializes a value of the given enum type to its constant name as a UTF8String */
def enumSerializer[T <: Enum[T]](enum: Class[T]): T => UTF8String = {
assert(enum.isEnum)
inputObject: T =>
UTF8String.fromString(inputObject.name())
Contributor Author:

We use the enum constant name as the field value.

}

/** Serializes the given enum value to its constant name, resolving the enum class from its name */
def serializeEnumName[T <: Enum[T]](enum: UTF8String, inputObject: T): UTF8String = {
enumSerializer(Utils.classForName(enum.toString).asInstanceOf[Class[T]])(inputObject)
}
Contributor Author:

Utils.classForName delegates to Class.forName, which operates at the native level, so additional optimizations like caching are not required.
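A standalone sketch of the same resolve-then-serialize path, using a JDK enum so it runs anywhere (the helper name is hypothetical; repeated Class.forName calls hit the JVM's loaded-class cache, which is why no extra memoization is added):

```java
import java.time.DayOfWeek;

public class EnumByName {
    // Resolve the enum class from its fully-qualified name,
    // then serialize the value as its constant name.
    static String serializeByClassName(String enumClassName, Enum<?> value) {
        try {
            Class<?> cls = Class.forName(enumClassName);
            if (!cls.isEnum()) {
                throw new IllegalArgumentException(enumClassName + " is not an enum");
            }
            return value.name();
        } catch (ClassNotFoundException e) {
            throw new IllegalArgumentException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(serializeByClassName("java.time.DayOfWeek", DayOfWeek.MONDAY));  // MONDAY
    }
}
```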


/** Returns a function that deserializes a row holding an enum constant name into a value of the given enum type */
def enumDeserializer[T <: Enum[T]](enum: Class[T]): InternalRow => T = {
assert(enum.isEnum)
value: InternalRow =>
Enum.valueOf(enum, value.getUTF8String(0).toString)
Contributor Author:

Enum.valueOf uses a cached string-to-value map.

}

/** Deserializes the enum value in the given row, using a dummy constant to recover the enum type */
def deserializeEnumName[T <: Enum[T]](typeDummy: T, inputObject: InternalRow): T = {
enumDeserializer(typeDummy.getClass.asInstanceOf[Class[T]])(inputObject)
}

private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = {

def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = {
@@ -429,6 +464,11 @@ object JavaTypeInference {
valueNullable = true
)

case other if other.isEnum =>
CreateNamedStruct(expressions.Literal("enum") ::
StaticInvoke(JavaTypeInference.getClass, StringType, "serializeEnumName",
expressions.Literal.create(other.getName, StringType) :: inputObject :: Nil) :: Nil)

Contributor Author:

We pass the enum class name to the serializer via a literal.

case other =>
val properties = getJavaBeanReadableAndWritableProperties(other)
val nonNullOutput = CreateNamedStruct(properties.flatMap { p =>
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance}
import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation}
- import org.apache.spark.sql.types.{BooleanType, ObjectType, StructField, StructType}
+ import org.apache.spark.sql.types.{BooleanType, DataType, ObjectType, StringType, StructField, StructType}
import org.apache.spark.util.Utils

/**
@@ -81,9 +81,19 @@ object ExpressionEncoder {
ClassTag[T](cls))
}

def javaEnumSchema[T](beanClass: Class[T]): DataType = {
StructType(Seq(StructField("enum",
StructType(Seq(StructField(beanClass.getSimpleName, StringType, nullable = false))),
nullable = false)))
}

// TODO: improve error message for java bean encoder.
def javaBean[T](beanClass: Class[T]): ExpressionEncoder[T] = {
- val schema = JavaTypeInference.inferDataType(beanClass)._1
+ val schema = if (beanClass.isEnum) {
javaEnumSchema(beanClass)
javaEnumSchema(beanClass)
Contributor Author:

If we use an enum as the top-level object, we need another level of StructType for it to be compatible with our ser/de structure.
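The extra nesting can be sketched with plain arrays; the top-level schema is struct<enum: struct<EnumBean: string>>. The field name "enum" follows the patch, but the code itself is illustrative:

```java
public class TopLevelEnumSchema {
    enum EnumBean { A, B }

    // top-level schema: struct< enum: struct< EnumBean: string > >
    static Object[] serializeTopLevel(EnumBean e) {
        return new Object[] { new Object[] { e.name() } };   // outer struct wraps the enum struct
    }

    static EnumBean deserializeTopLevel(Object[] row) {
        Object[] enumStruct = (Object[]) row[0];             // unwrap the "enum" field
        return Enum.valueOf(EnumBean.class, (String) enumStruct[0]);
    }

    public static void main(String[] args) {
        System.out.println(deserializeTopLevel(serializeTopLevel(EnumBean.B)));  // B
    }
}
```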

} else {
JavaTypeInference.inferDataType(beanClass)._1
}
assert(schema.isInstanceOf[StructType])

val serializer = JavaTypeInference.serializerFor(beanClass)
@@ -154,13 +154,13 @@ case class StaticInvoke(
val evaluate = if (returnNullable) {
if (ctx.defaultValue(dataType) == "null") {
s"""
- ${ev.value} = $callFunc;
+ ${ev.value} = (($javaType) ($callFunc));
Contributor Author (@mike0sv), Aug 22, 2017:

Explicitly cast the value to the needed type: without this, the generated code didn't compile, failing with something like "cannot assign value of type Enum to %RealEnumClassName%".

Contributor Author:

From the Janino documentation: "Type arguments: Are parsed, but otherwise ignored. The most significant restriction that follows is that you must cast return values from method invocations, e.g. '(String) myMap.get(key)'."
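The Janino restriction mirrors how raw types behave under javac; a minimal sketch (not generated code, names are made up):

```java
import java.util.HashMap;
import java.util.Map;

public class RawReturnCast {
    // With type arguments ignored (as Janino does), Map is effectively raw,
    // so get() returns Object and the caller must cast the result.
    static Object rawGet(Map map, Object key) {
        return map.get(key);
    }

    public static void main(String[] args) {
        Map myMap = new HashMap();
        myMap.put("k", "v");
        String v = (String) rawGet(myMap, "k");   // cast required, as in "(String) myMap.get(key)"
        System.out.println(v);
    }
}
```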

${ev.isNull} = ${ev.value} == null;
"""
} else {
val boxedResult = ctx.freshName("boxedResult")
s"""
- ${ctx.boxedType(dataType)} $boxedResult = $callFunc;
+ ${ctx.boxedType(dataType)} $boxedResult = (($javaType) ($callFunc));
${ev.isNull} = $boxedResult == null;
if (!${ev.isNull}) {
${ev.value} = $boxedResult;
@@ -1283,6 +1283,83 @@ public void test() {
ds.collectAsList();
}

public enum EnumBean {
A("www.elgoog.com"),
B("www.google.com");

private String url;

EnumBean(String url) {
this.url = url;
}

public String getUrl() {
return url;
}

public void setUrl(String url) {
this.url = url;
}
}

@Test
public void testEnum() {
List<EnumBean> data = Arrays.asList(EnumBean.B);
Encoder<EnumBean> encoder = Encoders.bean(EnumBean.class);
Dataset<EnumBean> ds = spark.createDataset(data, encoder);
Assert.assertEquals(ds.collectAsList(), data);
}

public static class BeanWithEnum {
EnumBean enumField;
String regularField;

public String getRegularField() {
return regularField;
}

public void setRegularField(String regularField) {
this.regularField = regularField;
}

public EnumBean getEnumField() {
return enumField;
}

public void setEnumField(EnumBean field) {
this.enumField = field;
}

public BeanWithEnum(EnumBean enumField, String regularField) {
this.enumField = enumField;
this.regularField = regularField;
}

public BeanWithEnum() {
}

public String toString() {
return "BeanWithEnum(" + enumField + ", " + regularField + ")";
}

public boolean equals(Object other) {
if (other instanceof BeanWithEnum) {
BeanWithEnum beanWithEnum = (BeanWithEnum) other;
return beanWithEnum.regularField.equals(regularField) && beanWithEnum.enumField.equals(enumField);
}
return false;
}
}

@Test
public void testBeanWithEnum() {
List<BeanWithEnum> data = Arrays.asList(new BeanWithEnum(EnumBean.A, "mira avenue"),
new BeanWithEnum(EnumBean.B, "flower boulevard"));
Encoder<BeanWithEnum> encoder = Encoders.bean(BeanWithEnum.class);
Dataset<BeanWithEnum> ds = spark.createDataset(data, encoder);
Assert.assertEquals(ds.collectAsList(), data);
}

public static class EmptyBean implements Serializable {}

@Test