diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
index 6134f9e036638..5f619d6c339e3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -84,7 +84,7 @@ object Encoders {
   private def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
     assert(encoders.length > 1)
     // make sure all encoders are resolved, i.e. `Attribute` has been resolved to `BoundReference`.
-    assert(encoders.forall(_.constructExpression.find(_.isInstanceOf[Attribute]).isEmpty))
+    assert(encoders.forall(_.fromRowExpression.find(_.isInstanceOf[Attribute]).isEmpty))
 
     val schema = StructType(encoders.zipWithIndex.map {
       case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
@@ -93,8 +93,8 @@ object Encoders {
     val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
 
     val extractExpressions = encoders.map {
-      case e if e.flat => e.extractExpressions.head
-      case other => CreateStruct(other.extractExpressions)
+      case e if e.flat => e.toRowExpressions.head
+      case other => CreateStruct(other.toRowExpressions)
     }.zipWithIndex.map { case (expr, index) =>
       expr.transformUp {
         case BoundReference(0, t: ObjectType, _) =>
@@ -107,11 +107,11 @@ object Encoders {
 
     val constructExpressions = encoders.zipWithIndex.map { case (enc, index) =>
       if (enc.flat) {
-        enc.constructExpression.transform {
+        enc.fromRowExpression.transform {
           case b: BoundReference => b.copy(ordinal = index)
         }
       } else {
-        enc.constructExpression.transformUp {
+        enc.fromRowExpression.transformUp {
           case BoundReference(ordinal, dt, _) =>
             GetInternalRowField(BoundReference(index, enc.schema, nullable = true), ordinal, dt)
         }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 294afde5347e2..0d3e4aafb0af4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.types.{StructField, ObjectType, StructType}
+import org.apache.spark.sql.types.{NullType, StructField, ObjectType, StructType}
 
 /**
  * A factory for constructing encoders that convert objects and primitives to and from the
@@ -61,20 +61,39 @@ object ExpressionEncoder {
 
   /**
    * Given a set of N encoders, constructs a new encoder that produces objects as items in an
-   * N-tuple.  Note that these encoders should first be bound correctly to the combined input
-   * schema.
+   * N-tuple.  Note that these encoders should be unresolved so that information about
+   * name/positional binding is preserved.
    */
   def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
+    encoders.foreach(_.assertUnresolved())
+
     val schema = StructType(
-      encoders.zipWithIndex.map { case (e, i) => StructField(s"_${i + 1}", e.schema)})
+      encoders.zipWithIndex.map {
+        case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
+      })
 
     val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
-    val extractExpressions = encoders.map {
-      case e if e.flat => e.extractExpressions.head
-      case other => CreateStruct(other.extractExpressions)
+
+    // Rebind the encoders to the nested schema.
+    val newConstructExpressions = encoders.zipWithIndex.map {
+      case (e, i) if !e.flat => e.nested(i).fromRowExpression
+      case (e, i) => e.shift(i).fromRowExpression
     }
+
     val constructExpression =
-      NewInstance(cls, encoders.map(_.constructExpression), false, ObjectType(cls))
+      NewInstance(cls, newConstructExpressions, false, ObjectType(cls))
+
+    val input = BoundReference(0, ObjectType(cls), false)
+    val extractExpressions = encoders.zipWithIndex.map {
+      case (e, i) if !e.flat => CreateStruct(e.toRowExpressions.map(_ transformUp {
+        case b: BoundReference =>
+          Invoke(input, s"_${i + 1}", b.dataType, Nil)
+      }))
+      case (e, i) => e.toRowExpressions.head transformUp {
+        case b: BoundReference =>
+          Invoke(input, s"_${i + 1}", b.dataType, Nil)
+      }
+    }
 
     new ExpressionEncoder[Any](
       schema,
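Reviewer note: a minimal sketch of the new composition contract. Both component encoders below come straight from the `ExpressionEncoder.apply` factory (and are therefore unresolved), which is what the new `assertUnresolved()` guard at the top of `tuple()` enforces; the schema logic then inlines each flat encoder as a single column. Hypothetical REPL usage, not part of the patch:

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// Freshly derived encoders: unresolved, so name/positional binding is preserved.
val longEnc = ExpressionEncoder[Long](flat = true)
val stringEnc = ExpressionEncoder[String](flat = true)

// Composing already-resolved encoders would fail the assertUnresolved() check.
val pairEnc = ExpressionEncoder.tuple(Seq(longEnc, stringEnc))
println(pairEnc.schema)  // struct<_1:bigint,_2:string>
```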
@@ -95,35 +114,40 @@ object ExpressionEncoder {
  * A generic encoder for JVM objects.
  *
  * @param schema The schema after converting `T` to a Spark SQL row.
- * @param extractExpressions A set of expressions, one for each top-level field that can be used to
- *                           extract the values from a raw object.
+ * @param toRowExpressions A set of expressions, one for each top-level field that can be used to
+ *                         extract the values from a raw object into an [[InternalRow]].
+ * @param fromRowExpression An expression that will construct an object given an [[InternalRow]].
  * @param clsTag A classtag for `T`.
  */
 case class ExpressionEncoder[T](
     schema: StructType,
     flat: Boolean,
-    extractExpressions: Seq[Expression],
-    constructExpression: Expression,
+    toRowExpressions: Seq[Expression],
+    fromRowExpression: Expression,
     clsTag: ClassTag[T])
   extends Encoder[T] {
 
-  if (flat) require(extractExpressions.size == 1)
+  if (flat) require(toRowExpressions.size == 1)
 
   @transient
-  private lazy val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
+  private lazy val extractProjection = GenerateUnsafeProjection.generate(toRowExpressions)
   private val inputRow = new GenericMutableRow(1)
 
   @transient
-  private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil)
+  private lazy val constructProjection = GenerateSafeProjection.generate(fromRowExpression :: Nil)
 
   /**
    * Returns an encoded version of `t` as a Spark SQL row.  Note that multiple calls to
    * toRow are allowed to return the same actual [[InternalRow]] object.  Thus, the caller should
    * copy the result before making another call if required.
    */
-  def toRow(t: T): InternalRow = {
+  def toRow(t: T): InternalRow = try {
     inputRow(0) = t
     extractProjection(inputRow)
+  } catch {
+    case e: Exception =>
+      throw new RuntimeException(
+        s"Error while encoding: $e\n${toRowExpressions.map(_.treeString).mkString("\n")}", e)
   }
 
   /**
@@ -135,7 +159,20 @@ case class ExpressionEncoder[T](
     constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T]
   } catch {
     case e: Exception =>
-      throw new RuntimeException(s"Error while decoding: $e\n${constructExpression.treeString}", e)
+      throw new RuntimeException(s"Error while decoding: $e\n${fromRowExpression.treeString}", e)
+  }
+
+  /**
+   * The process of resolution to a given schema throws away information about where a given field
+   * is being bound by ordinal instead of by name.  This method checks to make sure this process
+   * has not been done already in places where we plan to do later composition of encoders.
+   */
+  def assertUnresolved(): Unit = {
+    (fromRowExpression +: toRowExpressions).foreach(_.foreach {
+      case a: AttributeReference =>
+        sys.error(s"Unresolved encoder expected, but $a was found.")
+      case _ =>
+    })
   }
 
   /**
@@ -143,9 +180,14 @@ case class ExpressionEncoder[T](
    * given schema.
    */
   def resolve(schema: Seq[Attribute]): ExpressionEncoder[T] = {
-    val plan = Project(Alias(constructExpression, "")() :: Nil, LocalRelation(schema))
+    val positionToAttribute = AttributeMap.toIndex(schema)
+    val unbound = fromRowExpression transform {
+      case b: BoundReference => positionToAttribute(b.ordinal)
+    }
+
+    val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema))
     val analyzedPlan = SimpleAnalyzer.execute(plan)
-    copy(constructExpression = analyzedPlan.expressions.head.children.head)
+    copy(fromRowExpression = analyzedPlan.expressions.head.children.head)
   }
 
   /**
@@ -154,39 +196,14 @@ case class ExpressionEncoder[T](
    * resolve before bind.
    */
   def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = {
-    copy(constructExpression = BindReferences.bindReference(constructExpression, schema))
-  }
-
-  /**
-   * Replaces any bound references in the schema with the attributes at the corresponding ordinal
-   * in the provided schema.  This can be used to "relocate" a given encoder to pull values from
-   * a different schema than it was initially bound to.  It can also be used to assign attributes
-   * to ordinal based extraction (i.e. because the input data was a tuple).
-   */
-  def unbind(schema: Seq[Attribute]): ExpressionEncoder[T] = {
-    val positionToAttribute = AttributeMap.toIndex(schema)
-    copy(constructExpression = constructExpression transform {
-      case b: BoundReference => positionToAttribute(b.ordinal)
-    })
+    copy(fromRowExpression = BindReferences.bindReference(fromRowExpression, schema))
   }
 
   /**
-   * Given an encoder that has already been bound to a given schema, returns a new encoder
-   * where the positions are mapped from `oldSchema` to `newSchema`.  This can be used, for example,
-   * when you are trying to use an encoder on grouping keys that were originally part of a larger
-   * row, but now you have projected out only the key expressions.
+   * Returns a new encoder with input columns shifted by `delta` ordinals.
    */
-  def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): ExpressionEncoder[T] = {
-    val positionToAttribute = AttributeMap.toIndex(oldSchema)
-    val attributeToNewPosition = AttributeMap.byIndex(newSchema)
-    copy(constructExpression = constructExpression transform {
-      case r: BoundReference =>
-        r.copy(ordinal = attributeToNewPosition(positionToAttribute(r.ordinal)))
-    })
-  }
-
   def shift(delta: Int): ExpressionEncoder[T] = {
-    copy(constructExpression = constructExpression transform {
+    copy(fromRowExpression = fromRowExpression transform {
       case r: BoundReference => r.copy(ordinal = r.ordinal + delta)
     })
   }
@@ -196,11 +213,14 @@ case class ExpressionEncoder[T](
    * input row have been modified to pull the object out from a nested struct, instead of the
    * top level fields.
    */
-  def nested(input: Expression = BoundReference(0, schema, true)): ExpressionEncoder[T] = {
-    copy(constructExpression = constructExpression transform {
-      case u: Attribute if u != input =>
+  private def nested(i: Int): ExpressionEncoder[T] = {
+    // We don't always know our input type at this point since it might be unresolved.
+    // We fill in null and it will get unbound to the actual attribute at this position.
+    val input = BoundReference(i, NullType, nullable = true)
+    copy(fromRowExpression = fromRowExpression transformUp {
+      case u: Attribute =>
         UnresolvedExtractValue(input, Literal(u.name))
-      case b: BoundReference if b != input =>
+      case b: BoundReference =>
         GetStructField(
           input,
           StructField(s"i[${b.ordinal}]", b.dataType),
@@ -208,7 +228,7 @@ case class ExpressionEncoder[T](
     })
   }
 
-  protected val attrs = extractExpressions.flatMap(_.collect {
+  protected val attrs = toRowExpressions.flatMap(_.collect {
     case _: UnresolvedAttribute => ""
     case a: Attribute => s"#${a.exprId}"
     case b: BoundReference => s"[${b.ordinal}]"
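Reviewer note: the methods above imply a three-step lifecycle: derive (unresolved), `resolve` against a plan's output, then `bind` to ordinals right before evaluation; this is the order `Dataset.rdd` follows below. A sketch under that assumption (`attrs` and `internalRow` are hypothetical stand-ins):

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

case class Person(name: String, age: Int)

// 1. Derivation yields an unresolved encoder, still safe to compose or relocate.
val enc = ExpressionEncoder[Person]()
enc.assertUnresolved()  // passes: no AttributeReference has leaked in yet

// 2-3. Against a concrete plan this would become:
// val ready  = enc.resolve(attrs).bind(attrs)  // attrs = plan.output
// val person = ready.fromRow(internalRow)
```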
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
index 2c35adca9c925..9e283f5eb6342 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
@@ -18,10 +18,19 @@
 package org.apache.spark.sql.catalyst
 
 import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
 
 package object encoders {
+  /**
+   * Returns an internal encoder object that can be used to serialize / deserialize JVM objects
+   * into Spark SQL rows.  The implicit encoder should always be unresolved (i.e. have no
+   * attribute references from a specific schema).  This requirement allows us to preserve
+   * whether a given object type is being bound by name or by ordinal when doing resolution.
+   */
   private[sql] def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match {
-    case e: ExpressionEncoder[A] => e
+    case e: ExpressionEncoder[A] =>
+      e.assertUnresolved()
+      e
     case _ => sys.error(s"Only expression encoders are supported today")
   }
 }
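Reviewer note: since `encoderFor` is `private[sql]`, the sketch below only compiles from Spark's own source tree; it shows the new guard passing for a well-behaved implicit. `Encoders.STRING` is the same public factory the Java test changes below rely on:

```scala
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.catalyst.encoders.encoderFor

implicit val stringEncoder: Encoder[String] = Encoders.STRING

// Succeeds because the implicit is unresolved; an encoder that had already been
// resolved against some plan's output would trip assertUnresolved() here.
val internal = encoderFor[String]
```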
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 41cd0a104a1f5..f871b737fff3a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -97,11 +97,16 @@ object ExtractValue {
  * Returns the value of fields in the Struct `child`.
  *
  * No need to do type checking since it is handled by [[ExtractValue]].
+ * TODO: Unify with [[GetInternalRowField]], remove the need to specify a [[StructField]].
  */
 case class GetStructField(child: Expression, field: StructField, ordinal: Int)
   extends UnaryExpression {
 
-  override def dataType: DataType = field.dataType
+  override def dataType: DataType = child.dataType match {
+    case s: StructType => s(ordinal).dataType
+    // This is a hack to avoid breaking existing code until we remove the need for the struct field
+    case _ => field.dataType
+  }
   override def nullable: Boolean = child.nullable || field.nullable
   override def toString: String = s"$child.${field.name}"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 32b09b59af436..d9f046efce0bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -483,9 +483,12 @@ case class MapPartitions[T, U](
 
 /** Factory for constructing new `AppendColumn` nodes. */
 object AppendColumn {
-  def apply[T : Encoder, U : Encoder](func: T => U, child: LogicalPlan): AppendColumn[T, U] = {
+  def apply[T, U : Encoder](
+      func: T => U,
+      tEncoder: ExpressionEncoder[T],
+      child: LogicalPlan): AppendColumn[T, U] = {
     val attrs = encoderFor[U].schema.toAttributes
-    new AppendColumn[T, U](func, encoderFor[T], encoderFor[U], attrs, child)
+    new AppendColumn[T, U](func, tEncoder, encoderFor[U], attrs, child)
   }
 }
 
@@ -506,14 +509,16 @@ case class AppendColumn[T, U](
 
 /** Factory for constructing new `MapGroups` nodes. */
 object MapGroups {
-  def apply[K : Encoder, T : Encoder, U : Encoder](
+  def apply[K, T, U : Encoder](
       func: (K, Iterator[T]) => TraversableOnce[U],
+      kEncoder: ExpressionEncoder[K],
+      tEncoder: ExpressionEncoder[T],
       groupingAttributes: Seq[Attribute],
       child: LogicalPlan): MapGroups[K, T, U] = {
     new MapGroups(
       func,
-      encoderFor[K],
-      encoderFor[T],
+      kEncoder,
+      tEncoder,
       encoderFor[U],
       groupingAttributes,
       encoderFor[U].schema.toAttributes,
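Reviewer note: end to end, these are the nodes a typed `groupBy` plans: `AppendColumn` computes the grouping key and `MapGroups` applies the user function, with both now receiving their input encoders pre-resolved by the caller instead of summoning them implicitly. A runnable local-mode sketch (hypothetical app setup; it mirrors the new DatasetSuite test further down):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("demo"))
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

val ds = Seq("abc", "xyz", "hello").toDS()
// groupBy plans AppendColumn(func, resolvedTEncoder, plan); flatMap then plans
// MapGroups(f, resolvedKEncoder, resolvedTEncoder, groupingAttributes, plan).
val perGroup = ds.groupBy(_.length).flatMap((key, it) => Iterator(key + ":" + it.mkString))
perGroup.collect()  // e.g. Array("3:abcxyz", "5:hello")
```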
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index f0f275e91f1a3..929224460dc09 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -17,13 +17,15 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
+
 import scala.language.implicitConversions
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.Logging
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.encoders.encoderFor
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.DataTypeParser
 import org.apache.spark.sql.types._
@@ -45,7 +47,25 @@ private[sql] object Column {
  * checked by the analyzer instead of the compiler (i.e. `expr("sum(...)")`).
  * @tparam U The output type of this column.
  */
-class TypedColumn[-T, U](expr: Expression, val encoder: Encoder[U]) extends Column(expr)
+class TypedColumn[-T, U](
+    expr: Expression,
+    private[sql] val encoder: ExpressionEncoder[U]) extends Column(expr) {
+
+  /**
+   * Inserts the specific input type and schema into any expressions that are expected to operate
+   * on a decoded object.
+   */
+  private[sql] def withInputType(
+      inputEncoder: ExpressionEncoder[_],
+      schema: Seq[Attribute]): TypedColumn[T, U] = {
+    new TypedColumn[T, U](expr transform {
+      case ta: TypedAggregateExpression if ta.aEncoder.isEmpty =>
+        ta.copy(
+          aEncoder = Some(inputEncoder.asInstanceOf[ExpressionEncoder[Any]]),
+          children = schema)
+    }, encoder)
+  }
+}
 
 /**
  * :: Experimental ::
@@ -73,6 +93,25 @@ class Column(protected[sql] val expr: Expression) extends Logging {
   /** Creates a column based on the given expression. */
   private def withExpr(newExpr: Expression): Column = new Column(newExpr)
 
+  /**
+   * Returns the expression for this column either with an existing or auto-assigned name.
+   */
+  private[sql] def named: NamedExpression = expr match {
+    // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
+    // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
+    // make it a NamedExpression.
+    case u: UnresolvedAttribute => UnresolvedAlias(u)
+
+    case expr: NamedExpression => expr
+
+    // Leave an unaliased generator with an empty list of names since the analyzer will generate
+    // the correct defaults after the nested expression's type has been resolved.
+    case explode: Explode => MultiAlias(explode, Nil)
+    case jt: JsonTuple => MultiAlias(jt, Nil)
+
+    case expr: Expression => Alias(expr, expr.prettyString)()
+  }
+
   override def toString: String = expr.prettyString
 
   override def equals(that: Any): Boolean = that match {
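Reviewer note: `Column.named` centralizes the aliasing rules that `DataFrame.select` previously inlined (the DataFrame hunk just below deletes that copy), while `withInputType` lets a `TypedColumn` receive its input encoder lazily. A small sketch of `named`'s three interesting cases, reusing the `sqlContext` from the previous sketch:

```scala
val df = sc.parallelize(Seq((1, "a"), (2, "b"))).toDF("id", "name")

// Each column below passes through Column.named inside select():
//   $"id"           -> UnresolvedAlias(UnresolvedAttribute)
//   $"name".as("n") -> already a NamedExpression, passed through unchanged
//   $"id" + 1       -> Alias with an auto-assigned name from prettyString
df.select($"id", $"name".as("n"), $"id" + 1).show()
```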
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index a492099b9392b..3ba4ba18d2122 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -735,22 +735,7 @@ class DataFrame private[sql](
    */
   @scala.annotation.varargs
   def select(cols: Column*): DataFrame = withPlan {
-    val namedExpressions = cols.map {
-      // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
-      // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
-      // make it a NamedExpression.
-      case Column(u: UnresolvedAttribute) => UnresolvedAlias(u)
-
-      case Column(expr: NamedExpression) => expr
-
-      // Leave an unaliased generator with an empty list of names since the analyzer will generate
-      // the correct defaults after the nested expression's type has been resolved.
-      case Column(explode: Explode) => MultiAlias(explode, Nil)
-      case Column(jt: JsonTuple) => MultiAlias(jt, Nil)
-
-      case Column(expr: Expression) => Alias(expr, expr.prettyString)()
-    }
-    Project(namedExpressions.toSeq, logicalPlan)
+    Project(cols.map(_.named), logicalPlan)
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 87dae6b331593..b930e4661c1a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.{Queryable, QueryExecution}
-import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -63,15 +62,20 @@ import org.apache.spark.sql.types.StructType
 class Dataset[T] private[sql](
     @transient val sqlContext: SQLContext,
     @transient val queryExecution: QueryExecution,
-    unresolvedEncoder: Encoder[T]) extends Queryable with Serializable {
+    tEncoder: Encoder[T]) extends Queryable with Serializable {
+
+  /**
+   * An unresolved version of the internal encoder for the type of this dataset.  This one is
+   * marked implicit so that we can use it when constructing new [[Dataset]] objects that have
+   * the same object type (which will possibly be resolved to a different schema).
+   */
+  private implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder)
 
   /** The encoder for this [[Dataset]] that has been resolved to its output schema. */
-  private[sql] implicit val encoder: ExpressionEncoder[T] = unresolvedEncoder match {
-    case e: ExpressionEncoder[T] => e.resolve(queryExecution.analyzed.output)
-    case _ => throw new IllegalArgumentException("Only expression encoders are currently supported")
-  }
+  private[sql] val resolvedTEncoder: ExpressionEncoder[T] =
+    unresolvedTEncoder.resolve(queryExecution.analyzed.output)
 
-  private implicit def classTag = encoder.clsTag
+  private implicit def classTag = resolvedTEncoder.clsTag
 
   private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) =
     this(sqlContext, new QueryExecution(sqlContext, plan), encoder)
@@ -81,7 +85,7 @@ class Dataset[T] private[sql](
    *
    * @since 1.6.0
    */
-  def schema: StructType = encoder.schema
+  def schema: StructType = resolvedTEncoder.schema
 
   /* ************* *
    *  Conversions  *
@@ -134,7 +138,7 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   def rdd: RDD[T] = {
-    val tEnc = encoderFor[T]
+    val tEnc = resolvedTEncoder
     val input = queryExecution.analyzed.output
     queryExecution.toRdd.mapPartitions { iter =>
       val bound = tEnc.bind(input)
@@ -195,7 +199,7 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
-    new Dataset(
+    new Dataset[U](
       sqlContext,
       MapPartitions[T, U](
         func,
@@ -295,12 +299,12 @@ class Dataset[T] private[sql](
    */
   def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = {
     val inputPlan = queryExecution.analyzed
-    val withGroupingKey = AppendColumn(func, inputPlan)
+    val withGroupingKey = AppendColumn(func, resolvedTEncoder, inputPlan)
     val executed = sqlContext.executePlan(withGroupingKey)
 
     new GroupedDataset(
-      encoderFor[K].resolve(withGroupingKey.newColumns),
-      encoderFor[T].bind(inputPlan.output),
+      encoderFor[K],
+      encoderFor[T],
       executed,
       inputPlan.output,
       withGroupingKey.newColumns)
@@ -360,7 +364,15 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = {
-    new Dataset[U1](sqlContext, Project(Alias(withEncoder(c1).expr, "_1")() :: Nil, logicalPlan))
+    // We use an unbound encoder since the expression will make up its own schema.
+    // TODO: This probably doesn't work if we are relying on reordering of the input class fields.
+    new Dataset[U1](
+      sqlContext,
+      Project(
+        c1.withInputType(
+          resolvedTEncoder.bind(queryExecution.analyzed.output),
+          queryExecution.analyzed.output).named :: Nil,
+        logicalPlan))
   }
 
   /**
@@ -369,28 +381,14 @@ class Dataset[T] private[sql](
    * that cast appropriately for the user facing interface.
    */
   protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
-    val withEncoders = columns.map(withEncoder)
-    val aliases = withEncoders.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() }
-    val unresolvedPlan = Project(aliases, logicalPlan)
-    val execution = new QueryExecution(sqlContext, unresolvedPlan)
-    // Rebind the encoders to the nested schema that will be produced by the select.
-    val encoders = withEncoders.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]).zip(aliases).map {
-      case (e: ExpressionEncoder[_], a) if !e.flat =>
-        e.nested(a.toAttribute).resolve(execution.analyzed.output)
-      case (e, a) =>
-        e.unbind(a.toAttribute :: Nil).resolve(execution.analyzed.output)
-    }
-    new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
-  }
+    val encoders = columns.map(_.encoder)
+    // We use an unbound encoder since the expression will make up its own schema.
+    // TODO: This probably doesn't work if we are relying on reordering of the input class fields.
+    val namedColumns =
+      columns.map(_.withInputType(unresolvedTEncoder, queryExecution.analyzed.output).named)
+    val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan))
 
-  private def withEncoder(c: TypedColumn[_, _]): TypedColumn[_, _] = {
-    val e = c.expr transform {
-      case ta: TypedAggregateExpression if ta.aEncoder.isEmpty =>
-        ta.copy(
-          aEncoder = Some(encoder.asInstanceOf[ExpressionEncoder[Any]]),
-          children = queryExecution.analyzed.output)
-    }
-    new TypedColumn(e, c.encoder)
+    new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
   }
 
   /**
@@ -497,23 +495,18 @@ class Dataset[T] private[sql](
     val left = this.logicalPlan
     val right = other.logicalPlan
 
-    val leftData = this.encoder match {
+    val leftData = this.unresolvedTEncoder match {
       case e if e.flat => Alias(left.output.head, "_1")()
       case _ => Alias(CreateStruct(left.output), "_1")()
     }
-    val rightData = other.encoder match {
+    val rightData = other.unresolvedTEncoder match {
       case e if e.flat => Alias(right.output.head, "_2")()
       case _ => Alias(CreateStruct(right.output), "_2")()
     }
-    val leftEncoder =
-      if (encoder.flat) encoder else encoder.nested(leftData.toAttribute)
-    val rightEncoder =
-      if (other.encoder.flat) other.encoder else other.encoder.nested(rightData.toAttribute)
-    implicit val tuple2Encoder: Encoder[(T, U)] =
-      ExpressionEncoder.tuple(
-        leftEncoder,
-        rightEncoder.rebind(right.output, left.output ++ right.output))
+
+    implicit val tuple2Encoder: Encoder[(T, U)] =
+      ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder)
 
     withPlan[(T, U)](other) { (left, right) =>
       Project(
         leftData :: rightData :: Nil,
@@ -580,7 +573,7 @@ class Dataset[T] private[sql](
   private[sql] def logicalPlan = queryExecution.analyzed
 
   private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] =
-    new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), encoder)
+    new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), tEncoder)
 
   private[sql] def withPlan[R : Encoder](
       other: Dataset[_])(
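Reviewer note: the observable payoff of routing typed selects through `withInputType`: a class-input aggregator now resolves its fields by name against the dataset's schema, so column order no longer has to match the case class declaration. Sketch reusing the earlier local-mode setup, plus the `AggData`/`ClassInputAgg` definitions added to DatasetAggregatorSuite below:

```scala
// Columns arrive as (b, a), reversed relative to AggData(a: Int, b: String).
val ds = sqlContext.sql("SELECT 'one' AS b, 1 AS a").as[AggData]

// withInputType injects the resolved AggData encoder into the
// TypedAggregateExpression before the Project is planned.
ds.select(ClassInputAgg.toColumn).collect()  // Array(1)
```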
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 61e2a9545069b..ae1272ae531fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -17,20 +17,16 @@
 
 package org.apache.spark.sql
 
-import java.util.{Iterator => JIterator}
-
 import scala.collection.JavaConverters._
 
 import org.apache.spark.annotation.Experimental
-import org.apache.spark.api.java.function.{Function2 => JFunction2, Function3 => JFunction3, _}
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute}
+import org.apache.spark.api.java.function._
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
-import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute}
+import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.execution.QueryExecution
 
-
 /**
  * :: Experimental ::
  * A [[Dataset]] has been logically grouped by a user specified grouping key.  Users should not
@@ -44,23 +40,21 @@ import org.apache.spark.sql.execution.QueryExecution
  */
 @Experimental
 class GroupedDataset[K, T] private[sql](
-    private val kEncoder: Encoder[K],
-    private val tEncoder: Encoder[T],
-    queryExecution: QueryExecution,
+    kEncoder: Encoder[K],
+    tEncoder: Encoder[T],
+    val queryExecution: QueryExecution,
     private val dataAttributes: Seq[Attribute],
     private val groupingAttributes: Seq[Attribute]) extends Serializable {
 
-  private implicit val kEnc = kEncoder match {
-    case e: ExpressionEncoder[K] => e.unbind(groupingAttributes).resolve(groupingAttributes)
-    case other =>
-      throw new UnsupportedOperationException("Only expression encoders are currently supported")
-  }
+  // Similar to [[Dataset]], we use unresolved encoders for later composition and resolved
+  // encoders when constructing new logical plans that will operate on the output of the current
+  // query execution.
 
-  private implicit val tEnc = tEncoder match {
-    case e: ExpressionEncoder[T] => e.resolve(dataAttributes)
-    case other =>
-      throw new UnsupportedOperationException("Only expression encoders are currently supported")
-  }
+  private implicit val unresolvedKEncoder = encoderFor(kEncoder)
+  private implicit val unresolvedTEncoder = encoderFor(tEncoder)
+
+  private val resolvedKEncoder = unresolvedKEncoder.resolve(groupingAttributes)
+  private val resolvedTEncoder = unresolvedTEncoder.resolve(dataAttributes)
 
   /** Encoders for built in aggregations. */
   private implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true)
@@ -79,7 +73,7 @@ class GroupedDataset[K, T] private[sql](
   def asKey[L : Encoder]: GroupedDataset[L, T] =
     new GroupedDataset(
       encoderFor[L],
-      tEncoder,
+      unresolvedTEncoder,
       queryExecution,
       dataAttributes,
       groupingAttributes)
@@ -95,7 +89,7 @@ class GroupedDataset[K, T] private[sql](
   }
 
   /**
-   * Applies the given function to each group of data. For each unique group, the function will
+   * Applies the given function to each group of data.  For each unique group, the function will
    * be passed the group key and an iterator that contains all of the elements in the group. The
    * function can return an iterator containing elements of an arbitrary type which will be returned
    * as a new [[Dataset]].
@@ -108,7 +102,12 @@ class GroupedDataset[K, T] private[sql](
   def flatMap[U : Encoder](f: (K, Iterator[T]) => TraversableOnce[U]): Dataset[U] = {
     new Dataset[U](
       sqlContext,
-      MapGroups(f, groupingAttributes, logicalPlan))
+      MapGroups(
+        f,
+        resolvedKEncoder,
+        resolvedTEncoder,
+        groupingAttributes,
+        logicalPlan))
   }
 
   def flatMap[U](f: FlatMapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = {
@@ -127,15 +126,28 @@ class GroupedDataset[K, T] private[sql](
    */
   def map[U : Encoder](f: (K, Iterator[T]) => U): Dataset[U] = {
     val func = (key: K, it: Iterator[T]) => Iterator(f(key, it))
-    new Dataset[U](
-      sqlContext,
-      MapGroups(func, groupingAttributes, logicalPlan))
+    flatMap(func)
   }
 
   def map[U](f: MapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = {
     map((key, data) => f.call(key, data.asJava))(encoder)
   }
 
+  /**
+   * Reduces the elements of each group of data using the specified binary function.
+   * The given function must be commutative and associative or the result may be non-deterministic.
+   */
+  def reduce(f: (T, T) => T): Dataset[(K, T)] = {
+    val func = (key: K, it: Iterator[T]) => Iterator(key -> it.reduce(f))
+
+    implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedTEncoder)
+    flatMap(func)
+  }
+
+  def reduce(f: ReduceFunction[T]): Dataset[(K, T)] = {
+    reduce(f.call _)
+  }
+
   // To ensure valid overloading.
   protected def agg(expr: Column, exprs: Column*): DataFrame =
     groupedData.agg(expr, exprs: _*)
 
@@ -147,37 +159,17 @@ class GroupedDataset[K, T] private[sql](
    * TODO: does not handle aggregations that return nonflat results.
    */
   protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
-    val aliases = (groupingAttributes ++ columns.map(_.expr)).map {
-      case u: UnresolvedAttribute => UnresolvedAlias(u)
-      case expr: NamedExpression => expr
-      case expr: Expression => Alias(expr, expr.prettyString)()
-    }
-
-    val unresolvedPlan = Aggregate(groupingAttributes, aliases, logicalPlan)
-
-    // Fill in the input encoders for any aggregators in the plan.
-    val withEncoders = unresolvedPlan transformAllExpressions {
-      case ta: TypedAggregateExpression if ta.aEncoder.isEmpty =>
-        ta.copy(
-          aEncoder = Some(tEnc.asInstanceOf[ExpressionEncoder[Any]]),
-          children = dataAttributes)
-    }
-    val execution = new QueryExecution(sqlContext, withEncoders)
-
-    val columnEncoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]])
-
-    // Rebind the encoders to the nested schema that will be produced by the aggregation.
-    val encoders = (kEnc +: columnEncoders).zip(execution.analyzed.output).map {
-      case (e: ExpressionEncoder[_], a) if !e.flat =>
-        e.nested(a).resolve(execution.analyzed.output)
-      case (e, a) =>
-        e.unbind(a :: Nil).resolve(execution.analyzed.output)
-    }
+    val encoders = columns.map(_.encoder)
+    val namedColumns =
+      columns.map(
+        _.withInputType(resolvedTEncoder.bind(dataAttributes), dataAttributes).named)
+    val aggregate = Aggregate(groupingAttributes, groupingAttributes ++ namedColumns, logicalPlan)
+    val execution = new QueryExecution(sqlContext, aggregate)
 
     new Dataset(
       sqlContext,
       execution,
-      ExpressionEncoder.tuple(encoders))
+      ExpressionEncoder.tuple(unresolvedKEncoder +: encoders))
   }
 
   /**
@@ -230,7 +222,7 @@ class GroupedDataset[K, T] private[sql](
   def cogroup[U, R : Encoder](
       other: GroupedDataset[K, U])(
       f: (K, Iterator[T], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
-    implicit def uEnc: Encoder[U] = other.tEncoder
+    implicit def uEnc: Encoder[U] = other.unresolvedTEncoder
     new Dataset[R](
       sqlContext,
       CoGroup(
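Reviewer note: usage of the new `reduce`. The grouping key is retained, giving a `Dataset[(K, T)]`, and the result encoder is built with `ExpressionEncoder.tuple` from the two unresolved encoders, per the contract established earlier. This mirrors the new DatasetSuite test below, reusing the earlier sketch's setup:

```scala
val words = Seq("abc", "xyz", "hello").toDS()

// The reducer must be commutative and associative: element order within a
// group is not guaranteed.
val perKey = words.groupBy(_.length).reduce(_ + _)
perKey.collect()  // e.g. Array((3, "abcxyz"), (5, "hello"))
```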
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index dfcbac8687b3e..3f2775896bb8c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -55,7 +55,7 @@ case class TypedAggregateExpression(
     aEncoder: Option[ExpressionEncoder[Any]],
     bEncoder: ExpressionEncoder[Any],
     cEncoder: ExpressionEncoder[Any],
-    children: Seq[Expression],
+    children: Seq[Attribute],
     mutableAggBufferOffset: Int,
     inputAggBufferOffset: Int)
   extends ImperativeAggregate with Logging {
@@ -78,8 +78,7 @@ case class TypedAggregateExpression(
 
   override lazy val resolved: Boolean = aEncoder.isDefined
 
-  override lazy val inputTypes: Seq[DataType] =
-    aEncoder.map(_.schema.map(_.dataType)).getOrElse(Nil)
+  override lazy val inputTypes: Seq[DataType] = Nil
 
   override val aggBufferSchema: StructType = bEncoder.schema
 
@@ -90,12 +89,8 @@ case class TypedAggregateExpression(
   override val inputAggBufferAttributes: Seq[AttributeReference] =
     aggBufferAttributes.map(_.newInstance())
 
-  lazy val inputAttributes = aEncoder.get.schema.toAttributes
-  lazy val inputMapping = AttributeMap(inputAttributes.zip(children))
-  lazy val boundA =
-    aEncoder.get.copy(constructExpression = aEncoder.get.constructExpression transform {
-      case a: AttributeReference => inputMapping(a)
-    })
+  // We let the dataset do the binding for us.
+  lazy val boundA = aEncoder.get
 
   val bAttributes = bEncoder.schema.toAttributes
   lazy val boundB = bEncoder.resolve(bAttributes).bind(bAttributes)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index ae08fb71bf4cb..ed82c9a6a3770 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -311,6 +311,10 @@ case class AppendColumns[T, U](
     newColumns: Seq[Attribute],
     child: SparkPlan) extends UnaryNode {
 
+  // We are using an unsafe combiner.
+  override def canProcessSafeRows: Boolean = false
+  override def canProcessUnsafeRows: Boolean = true
+
   override def output: Seq[Attribute] = child.output ++ newColumns
 
   override protected def doExecute(): RDD[InternalRow] = {
@@ -349,11 +353,12 @@ case class MapGroups[K, T, U](
     child.execute().mapPartitions { iter =>
       val grouped = GroupedIterator(iter, groupingAttributes, child.output)
       val groupKeyEncoder = kEncoder.bind(groupingAttributes)
+      val groupDataEncoder = tEncoder.bind(child.output)
 
       grouped.flatMap { case (key, rowIter) =>
         val result = func(
           groupKeyEncoder.fromRow(key),
-          rowIter.map(tEncoder.fromRow))
+          rowIter.map(groupDataEncoder.fromRow))
         result.map(uEncoder.toRow)
       }
     }
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 33d8388f615ae..46169ca07d715 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -157,7 +157,6 @@ public Integer call(Integer v1, Integer v2) throws Exception {
     Assert.assertEquals(6, reduced);
   }
 
-  @Test
   public void testGroupBy() {
     List<String> data = Arrays.asList("a", "foo", "bar");
     Dataset<String> ds = context.createDataset(data, Encoders.STRING());
@@ -196,6 +195,17 @@ public Iterable<String> call(Integer key, Iterator<String> values) throws Exception {
 
     Assert.assertEquals(Arrays.asList("1a", "3foobar"), flatMapped.collectAsList());
 
+    Dataset<Tuple2<Integer, String>> reduced = grouped.reduce(new ReduceFunction<String>() {
+      @Override
+      public String call(String v1, String v2) throws Exception {
+        return v1 + v2;
+      }
+    });
+
+    Assert.assertEquals(
+      Arrays.asList(tuple2(1, "a"), tuple2(3, "foobar")),
+      reduced.collectAsList());
+
     List<Integer> data2 = Arrays.asList(2, 6, 10);
     Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT());
     GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new MapFunction<Integer, Integer>() {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 378cd365276b3..20896efdfec16 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -67,6 +67,28 @@ object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] {
   override def finish(reduction: (Long, Long)): (Long, Long) = reduction
 }
 
+case class AggData(a: Int, b: String)
+
+object ClassInputAgg extends Aggregator[AggData, Int, Int] with Serializable {
+  /** A zero value for this aggregation.  Should satisfy the property that any b + zero = b. */
+  override def zero: Int = 0
+
+  /**
+   * Combine two values to produce a new value.  For performance, the function may modify `b` and
+   * return it instead of constructing a new object for b.
+   */
+  override def reduce(b: Int, a: AggData): Int = b + a.a
+
+  /**
+   * Transform the output of the reduction.
+   */
+  override def finish(reduction: Int): Int = reduction
+
+  /**
+   * Merge two intermediate values.
+   */
+  override def merge(b1: Int, b2: Int): Int = b1 + b2
+}
+
 class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
 
   import testImplicits._
@@ -123,4 +145,24 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
       ds.select(sum((i: Int) => i), sum((i: Int) => i * 2)),
       11 -> 22)
   }
+
+  test("typed aggregation: class input") {
+    val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS()
+
+    checkAnswer(
+      ds.select(ClassInputAgg.toColumn),
+      3)
+  }
+
+  test("typed aggregation: class input with reordering") {
+    val ds = sql("SELECT 'one' AS b, 1 as a").as[AggData]
+
+    checkAnswer(
+      ds.select(ClassInputAgg.toColumn),
+      1)
+
+    checkAnswer(
+      ds.groupBy(_.b).agg(ClassInputAgg.toColumn),
+      ("one", 1))
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 621148528714f..c23dd46d3767b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -218,6 +218,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       "a", "30", "b", "3", "c", "1")
   }
 
+  test("groupBy function, reduce") {
+    val ds = Seq("abc", "xyz", "hello").toDS()
+    val agged = ds.groupBy(_.length).reduce(_ + _)
+
+    checkAnswer(
+      agged,
+      3 -> "abcxyz", 5 -> "hello")
+  }
+
   test("groupBy columns, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy($"_1")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 7a8b7ae5bf265..b5417b195f396 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -82,18 +82,21 @@ abstract class QueryTest extends PlanTest {
         fail(
           s"""
              |Exception collecting dataset as objects
-             |${ds.encoder}
-             |${ds.encoder.constructExpression.treeString}
+             |${ds.resolvedTEncoder}
+             |${ds.resolvedTEncoder.fromRowExpression.treeString}
              |${ds.queryExecution}
           """.stripMargin, e)
     }
 
     if (decoded != expectedAnswer.toSet) {
+      val expected = expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted
+      val actual = decoded.toSet.toSeq.map((a: Any) => a.toString).sorted
+
+      val comparison = sideBySide("expected" +: expected, "spark" +: actual).mkString("\n")
       fail(
         s"""Decoded objects do not match expected objects:
-          |Expected: ${expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted}
-          |Actual ${decoded.toSet.toSeq.map((a: Any) => a.toString).sorted}
-          |${ds.encoder.constructExpression.treeString}
+          |$comparison
+          |${ds.resolvedTEncoder.fromRowExpression.treeString}
          """.stripMargin)
     }
   }