
[SPARK-49960][SQL] Custom ExpressionEncoder support and TransformingEncoder fixes #50023

AgnosticEncoders.scala
@@ -276,11 +276,14 @@ object AgnosticEncoders {
* another encoder. This is a fallback for scenarios where objects can't be represented using
* standard encoders; an example is where we use a different (opaque) serialization
* format (e.g. Java serialization, Kryo serialization, or protobuf).
+* @param nullable defaults to false, indicating the codec guarantees that
+*   encode / decode results are non-null
*/
case class TransformingEncoder[I, O](
clsTag: ClassTag[I],
transformed: AgnosticEncoder[O],
-codecProvider: () => Codec[_ >: I, O])
+codecProvider: () => Codec[_ >: I, O],
+override val nullable: Boolean = false)
extends AgnosticEncoder[I] {
override def isPrimitive: Boolean = transformed.isPrimitive
override def dataType: DataType = transformed.dataType
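
For orientation, here is a minimal sketch of how the new parameter might be used. The `OpaqueId` type and `OpaqueIdCodec` are hypothetical, and the sketch assumes `Codec` exposes the `encode`/`decode` pair implied by the provider signature above:

```scala
import scala.reflect.classTag

import org.apache.spark.sql.catalyst.encoders.Codec
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{StringEncoder, TransformingEncoder}

// Hypothetical domain type that standard encoders cannot represent.
final class OpaqueId(val value: String)

// Codec bridging the opaque type to a String handled by StringEncoder.
class OpaqueIdCodec extends Codec[OpaqueId, String] {
  override def encode(in: OpaqueId): String = if (in == null) null else in.value
  override def decode(out: String): OpaqueId = if (out == null) null else new OpaqueId(out)
}

// nullable = true opts out of the default non-null guarantee, so downstream
// expressions will null-check the codec's results.
val opaqueIdEncoder = TransformingEncoder(
  classTag[OpaqueId],
  StringEncoder,
  () => new OpaqueIdCodec,
  nullable = true)
```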

DeserializerBuildHelper.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst

import org.apache.spark.sql.catalyst.{expressions => exprs}
import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue}
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec, JavaSerializationCodec, KryoSerializationCodec}
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, AgnosticExpressionPathEncoder, Codec, JavaSerializationCodec, KryoSerializationCodec}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedLeafEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, VarcharEncoder, YearMonthIntervalEncoder}
import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder}
import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, IsNull, Literal, MapKeys, MapValues, UpCast}
@@ -270,6 +270,8 @@ object DeserializerBuildHelper {
enc: AgnosticEncoder[_],
path: Expression,
walkedTypePath: WalkedTypePath): Expression = enc match {
+case ae: AgnosticExpressionPathEncoder[_] =>
+  ae.fromCatalyst(path)
case _ if isNativeEncoder(enc) =>
path
case _: BoxedLeafEncoder[_, _] =>
@@ -447,13 +449,13 @@ object DeserializerBuildHelper {
val result = InitializeJavaBean(newInstance, setters.toMap)
exprs.If(IsNull(path), exprs.Literal.create(null, ObjectType(cls)), result)

-case TransformingEncoder(tag, _, codec) if codec == JavaSerializationCodec =>
+case TransformingEncoder(tag, _, codec, _) if codec == JavaSerializationCodec =>
DecodeUsingSerializer(path, tag, kryo = false)

-case TransformingEncoder(tag, _, codec) if codec == KryoSerializationCodec =>
+case TransformingEncoder(tag, _, codec, _) if codec == KryoSerializationCodec =>
DecodeUsingSerializer(path, tag, kryo = true)

-case TransformingEncoder(tag, encoder, provider) =>
+case TransformingEncoder(tag, encoder, provider, _) =>
Invoke(
Literal.create(provider(), ObjectType(classOf[Codec[_, _]])),
"decode",

SerializerBuildHelper.scala
@@ -21,7 +21,7 @@ import scala.language.existentials

import org.apache.spark.sql.catalyst.{expressions => exprs}
import org.apache.spark.sql.catalyst.DeserializerBuildHelper.expressionWithNullSafety
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec, JavaSerializationCodec, KryoSerializationCodec}
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, AgnosticExpressionPathEncoder, Codec, JavaSerializationCodec, KryoSerializationCodec}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLeafEncoder, BoxedLongEncoder, BoxedShortEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveLeafEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, VarcharEncoder, YearMonthIntervalEncoder}
import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder, lenientExternalDataTypeFor}
import org.apache.spark.sql.catalyst.expressions.{BoundReference, CheckOverflow, CreateNamedStruct, Expression, IsNull, KnownNotNull, Literal, UnsafeArrayData}
@@ -306,6 +306,7 @@ object SerializerBuildHelper {
* by encoder `enc`.
*/
private def createSerializer(enc: AgnosticEncoder[_], input: Expression): Expression = enc match {
+case ae: AgnosticExpressionPathEncoder[_] => ae.toCatalyst(input)
case _ if isNativeEncoder(enc) => input
case BoxedBooleanEncoder => createSerializerForBoolean(input)
case BoxedByteEncoder => createSerializerForByte(input)
@@ -418,18 +419,21 @@
}
createSerializerForObject(input, serializedFields)

-case TransformingEncoder(_, _, codec) if codec == JavaSerializationCodec =>
+case TransformingEncoder(_, _, codec, _) if codec == JavaSerializationCodec =>
EncodeUsingSerializer(input, kryo = false)

-case TransformingEncoder(_, _, codec) if codec == KryoSerializationCodec =>
+case TransformingEncoder(_, _, codec, _) if codec == KryoSerializationCodec =>
EncodeUsingSerializer(input, kryo = true)

-case TransformingEncoder(_, encoder, codecProvider) =>
+case TransformingEncoder(_, encoder, codecProvider, _) =>
val encoded = Invoke(
Literal(codecProvider(), ObjectType(classOf[Codec[_, _]])),
"encode",
externalDataTypeFor(encoder),
-input :: Nil)
+input :: Nil,
+propagateNull = input.nullable,
+returnNullable = input.nullable
+)
createSerializer(encoder, encoded)
}

@@ -486,6 +490,7 @@ object SerializerBuildHelper {
nullable: Boolean): Expression => Expression = { input =>
val expected = enc match {
case OptionEncoder(_) => lenientExternalDataTypeFor(enc)
+case TransformingEncoder(_, transformed, _, _) => lenientExternalDataTypeFor(transformed)
case _ => enc.dataType
}

EncoderUtils.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.encoders

import scala.collection.Map

+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, CalendarIntervalEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, SparkDecimalEncoder, VariantEncoder}
import org.apache.spark.sql.catalyst.expressions.Expression
@@ -26,6 +27,30 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, DataType, DateType, DayTimeIntervalType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ObjectType, ShortType, StringType, StructType, TimestampNTZType, TimestampType, UserDefinedType, VariantType, YearMonthIntervalType}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal}

+/**
+ * :: DeveloperApi ::
+ * Extensible [[AgnosticEncoder]] providing conversion extension points over type T.
+ * @tparam T the external type handled by this encoder
+ */
+@DeveloperApi
+@deprecated("This trait is intended only as a migration tool and will be removed in 4.1")
+trait AgnosticExpressionPathEncoder[T]
+  extends AgnosticEncoder[T] {
+  /**
+   * Converts an expression over T into its Catalyst (InternalRow) representation.
+   * @param input the starting input path
+   * @return the serializer expression for T
+   */
+  def toCatalyst(input: Expression): Expression
+
+  /**
+   * Converts a Catalyst path expression from an InternalRow back into T.
+   * @param inputPath path expression from InternalRow
+   * @return the deserializer expression for T
+   */
+  def fromCatalyst(inputPath: Expression): Expression
+}
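
To make the extension point concrete, here is a minimal hypothetical implementer. `Money` and `MoneyEncoder` are illustrative only; `Invoke` and `NewInstance` are Catalyst's existing object operators:

```scala
import scala.reflect.{classTag, ClassTag}

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, NewInstance}
import org.apache.spark.sql.types.{DataType, LongType, ObjectType}

// Hypothetical external type stored in rows as a primitive long (cents).
case class Money(cents: Long)

case class MoneyEncoder() extends AgnosticExpressionPathEncoder[Money] {
  override def isPrimitive: Boolean = false
  override def dataType: DataType = LongType
  override def clsTag: ClassTag[Money] = classTag[Money]

  // Serializer path: read `cents` from the external Money object.
  override def toCatalyst(input: Expression): Expression =
    Invoke(input, "cents", LongType)

  // Deserializer path: rebuild Money from the long stored in the row.
  override def fromCatalyst(inputPath: Expression): Expression =
    NewInstance(classOf[Money], inputPath :: Nil, ObjectType(classOf[Money]))
}
```

Because the serializer and deserializer builders match on this trait before any other encoder case, such an implementation takes full control of the generated expressions for its type.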

/**
* Helper class for generating [[ExpressionEncoder]]s.
*/

ExpressionEncoder.scala
@@ -24,6 +24,7 @@ import org.apache.spark.SparkRuntimeException
import org.apache.spark.sql.{Encoder, Row}
import org.apache.spark.sql.catalyst.{DeserializerBuildHelper, InternalRow, JavaTypeInference, ScalaReflection, SerializerBuildHelper}
import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{OptionEncoder, TransformingEncoder}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.{Deserializer, Serializer}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
@@ -215,6 +216,13 @@ case class ExpressionEncoder[T](
StructField(s.name, s.dataType, s.nullable)
})

+  private def transformerOfOption(enc: AgnosticEncoder[_]): Boolean =
+    enc match {
+      case t: TransformingEncoder[_, _] => transformerOfOption(t.transformed)
+      case _: OptionEncoder[_] => true
+      case _ => false
+    }

/**
* Returns true if the type `T` is serialized as a struct by `objSerializer`.
*/
@@ -228,7 +236,8 @@
* returns true if `T` is serialized as a struct and is not an `Option` type.
*/
def isSerializedAsStructForTopLevel: Boolean = {
-isSerializedAsStruct && !classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass)
+isSerializedAsStruct && !classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass) &&
+  !transformerOfOption(encoder)
}

Review comment (Contributor): isSerializedAsStruct && !transformerOfOption(encoder)?

Review comment (Contributor): I am wondering if we should make these checks part of the AgnosticEncoder api.

Reply (Author): I think it'd make sense; for the path encoder backwards-compat logic I can embed / document that in the shim. The Builders could embed that. I can take a stab at that post rc2.
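
To illustrate what transformerOfOption guards against, consider a TransformingEncoder that (possibly through nesting) wraps an OptionEncoder: it serializes like a top-level Option, so it must not be reported as a top-level struct. The `IdentityCodec` below is a hypothetical helper, again assuming `Codec`'s `encode`/`decode` signatures:

```scala
import scala.reflect.classTag

import org.apache.spark.sql.catalyst.encoders.Codec
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedIntEncoder, OptionEncoder, TransformingEncoder}

// Hypothetical pass-through codec, used only to wrap another encoder.
class IdentityCodec[T] extends Codec[T, T] {
  override def encode(in: T): T = in
  override def decode(out: T): T = out
}

// transformerOfOption is true for this encoder, so the top-level struct
// check above treats it like a plain Option encoder.
val optionWrapping = TransformingEncoder(
  classTag[Option[Integer]],
  OptionEncoder(BoxedIntEncoder),
  () => new IdentityCodec[Option[Integer]])
```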

// serializer expressions are used to encode an object to a row, while the object is usually an