[SPARK-11654][SQL] add reduce to GroupedDataset
This PR adds a new method, `reduce`, to `GroupedDataset`, which allows operations similar to `reduceByKey` on a traditional `PairRDD`.

```scala
val ds = Seq("abc", "xyz", "hello").toDS()
ds.groupBy(_.length).reduce(_ + _).collect()  // not actually commutative :P

res0: Array(3 -> "abcxyz", 5 -> "hello")
```
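For comparison, a rough sketch of the equivalent with a plain RDD and `reduceByKey` (illustration only, not part of this patch; `sc` is assumed to be a `SparkContext`, and the concatenation order within each group is likewise not guaranteed):

```scala
// Sketch: keying by length and reducing on a PairRDD (not from this patch)
val rdd = sc.parallelize(Seq("abc", "xyz", "hello"))
rdd.map(s => (s.length, s))   // key by string length
  .reduceByKey(_ + _)         // same non-commutative concatenation caveat applies
  .collect()
// e.g. Array((3,abcxyz), (5,hello)) -- element and concatenation order may vary
```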

While implementing this method and its test cases, several more deficiencies were found in our encoder handling. Specifically, in order to support positional resolution, named resolution, and tuple composition, it is important to keep the unresolved encoder around and to use it when constructing new `Datasets` with the same object type but different output attributes. We now divide the encoder lifecycle into three phases (mirroring the lifecycle of standard expressions) and have checks at various boundaries:

 - Unresolved Encoders: all user-facing encoders (those constructed by implicits, static methods, or tuple composition) are unresolved, meaning they have only `UnresolvedAttributes` for named fields and `BoundReferences` for fields accessed by ordinal.
 - Resolved Encoders: internal to a `[Grouped]Dataset`, the encoder is resolved, meaning all input has been resolved to a specific `AttributeReference`. Any encoders that are placed into a logical plan for use in object construction should be resolved.
 - Bound Encoders: constructed by physical plans, right before the actual row -> object conversion is performed.

It is left to future work to add explicit checks for resolution and provide good error messages when it fails.  We might also consider enforcing the above constraints in the type system (i.e. `fromRow` only exists on a `ResolvedEncoder`), but we should probably wait before spending too much time on this.
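To make the three phases concrete, here is a rough sketch of the intended internal flow (illustration only; `encoderFor`, `assertUnresolved`, `resolve`, `bind`, and `fromRow` are the internal `ExpressionEncoder` APIs touched by this patch, while the implicit `Encoder[(Int, String)]`, `schema: Seq[Attribute]`, and `row: InternalRow` are assumed to be in scope):

```scala
// Sketch of the encoder lifecycle (internal API; schema and row are assumed inputs)
val enc = encoderFor[(Int, String)]   // unresolved: only UnresolvedAttribute / BoundReference
enc.assertUnresolved()                // checked at the Dataset boundary

val resolved = enc.resolve(schema)    // inputs resolved to concrete AttributeReferences
val bound = resolved.bind(schema)     // done by the physical plan, right before execution
val obj: (Int, String) = bound.fromRow(row)  // row -> object conversion on a bound encoder
```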

Author: Michael Armbrust <michael@databricks.com>
Author: Wenchen Fan <wenchen@databricks.com>

Closes apache#9673 from marmbrus/pr/9628.
marmbrus authored and dskrvk committed Nov 13, 2015
1 parent c3be3b6 commit 8e2e7ac
Showing 15 changed files with 309 additions and 197 deletions.
10 changes: 5 additions & 5 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -84,7 +84,7 @@ object Encoders {
private def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
assert(encoders.length > 1)
// make sure all encoders are resolved, i.e. `Attribute` has been resolved to `BoundReference`.
assert(encoders.forall(_.constructExpression.find(_.isInstanceOf[Attribute]).isEmpty))
assert(encoders.forall(_.fromRowExpression.find(_.isInstanceOf[Attribute]).isEmpty))

val schema = StructType(encoders.zipWithIndex.map {
case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
@@ -93,8 +93,8 @@
val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")

val extractExpressions = encoders.map {
case e if e.flat => e.extractExpressions.head
case other => CreateStruct(other.extractExpressions)
case e if e.flat => e.toRowExpressions.head
case other => CreateStruct(other.toRowExpressions)
}.zipWithIndex.map { case (expr, index) =>
expr.transformUp {
case BoundReference(0, t: ObjectType, _) =>
@@ -107,11 +107,11 @@

val constructExpressions = encoders.zipWithIndex.map { case (enc, index) =>
if (enc.flat) {
enc.constructExpression.transform {
enc.fromRowExpression.transform {
case b: BoundReference => b.copy(ordinal = index)
}
} else {
enc.constructExpression.transformUp {
enc.fromRowExpression.transformUp {
case BoundReference(ordinal, dt, _) =>
GetInternalRowField(BoundReference(index, enc.schema, nullable = true), ordinal, dt)
}
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types.{StructField, ObjectType, StructType}
import org.apache.spark.sql.types.{NullType, StructField, ObjectType, StructType}

/**
* A factory for constructing encoders that convert objects and primitives to and from the
@@ -61,20 +61,39 @@ object ExpressionEncoder {

/**
* Given a set of N encoders, constructs a new encoder that produce objects as items in an
* N-tuple. Note that these encoders should first be bound correctly to the combined input
* schema.
* N-tuple. Note that these encoders should be unresolved so that information about
* name/positional binding is preserved.
*/
def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
encoders.foreach(_.assertUnresolved())

val schema =
StructType(
encoders.zipWithIndex.map { case (e, i) => StructField(s"_${i + 1}", e.schema)})
encoders.zipWithIndex.map {
case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
})
val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
val extractExpressions = encoders.map {
case e if e.flat => e.extractExpressions.head
case other => CreateStruct(other.extractExpressions)

// Rebind the encoders to the nested schema.
val newConstructExpressions = encoders.zipWithIndex.map {
case (e, i) if !e.flat => e.nested(i).fromRowExpression
case (e, i) => e.shift(i).fromRowExpression
}

val constructExpression =
NewInstance(cls, encoders.map(_.constructExpression), false, ObjectType(cls))
NewInstance(cls, newConstructExpressions, false, ObjectType(cls))

val input = BoundReference(0, ObjectType(cls), false)
val extractExpressions = encoders.zipWithIndex.map {
case (e, i) if !e.flat => CreateStruct(e.toRowExpressions.map(_ transformUp {
case b: BoundReference =>
Invoke(input, s"_${i + 1}", b.dataType, Nil)
}))
case (e, i) => e.toRowExpressions.head transformUp {
case b: BoundReference =>
Invoke(input, s"_${i + 1}", b.dataType, Nil)
}
}

new ExpressionEncoder[Any](
schema,
@@ -95,35 +114,40 @@ object ExpressionEncoder {
* A generic encoder for JVM objects.
*
* @param schema The schema after converting `T` to a Spark SQL row.
* @param extractExpressions A set of expressions, one for each top-level field that can be used to
* extract the values from a raw object.
* @param toRowExpressions A set of expressions, one for each top-level field that can be used to
* extract the values from a raw object into an [[InternalRow]].
* @param fromRowExpression An expression that will construct an object given an [[InternalRow]].
* @param clsTag A classtag for `T`.
*/
case class ExpressionEncoder[T](
schema: StructType,
flat: Boolean,
extractExpressions: Seq[Expression],
constructExpression: Expression,
toRowExpressions: Seq[Expression],
fromRowExpression: Expression,
clsTag: ClassTag[T])
extends Encoder[T] {

if (flat) require(extractExpressions.size == 1)
if (flat) require(toRowExpressions.size == 1)

@transient
private lazy val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
private lazy val extractProjection = GenerateUnsafeProjection.generate(toRowExpressions)
private val inputRow = new GenericMutableRow(1)

@transient
private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil)
private lazy val constructProjection = GenerateSafeProjection.generate(fromRowExpression :: Nil)

/**
* Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to
* toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should
* copy the result before making another call if required.
*/
def toRow(t: T): InternalRow = {
def toRow(t: T): InternalRow = try {
inputRow(0) = t
extractProjection(inputRow)
} catch {
case e: Exception =>
throw new RuntimeException(
s"Error while encoding: $e\n${toRowExpressions.map(_.treeString).mkString("\n")}", e)
}

/**
@@ -135,17 +159,35 @@ case class ExpressionEncoder[T](
constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T]
} catch {
case e: Exception =>
throw new RuntimeException(s"Error while decoding: $e\n${constructExpression.treeString}", e)
throw new RuntimeException(s"Error while decoding: $e\n${fromRowExpression.treeString}", e)
}

/**
* The process of resolution to a given schema throws away information about where a given field
* is being bound by ordinal instead of by name. This method checks to make sure this process
* has not been done already in places where we plan to do later composition of encoders.
*/
def assertUnresolved(): Unit = {
(fromRowExpression +: toRowExpressions).foreach(_.foreach {
case a: AttributeReference =>
sys.error(s"Unresolved encoder expected, but $a was found.")
case _ =>
})
}

/**
* Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the
* given schema.
*/
def resolve(schema: Seq[Attribute]): ExpressionEncoder[T] = {
val plan = Project(Alias(constructExpression, "")() :: Nil, LocalRelation(schema))
val positionToAttribute = AttributeMap.toIndex(schema)
val unbound = fromRowExpression transform {
case b: BoundReference => positionToAttribute(b.ordinal)
}

val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema))
val analyzedPlan = SimpleAnalyzer.execute(plan)
copy(constructExpression = analyzedPlan.expressions.head.children.head)
copy(fromRowExpression = analyzedPlan.expressions.head.children.head)
}

/**
@@ -154,39 +196,14 @@
* resolve before bind.
*/
def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = {
copy(constructExpression = BindReferences.bindReference(constructExpression, schema))
}

/**
* Replaces any bound references in the schema with the attributes at the corresponding ordinal
* in the provided schema. This can be used to "relocate" a given encoder to pull values from
* a different schema than it was initially bound to. It can also be used to assign attributes
* to ordinal based extraction (i.e. because the input data was a tuple).
*/
def unbind(schema: Seq[Attribute]): ExpressionEncoder[T] = {
val positionToAttribute = AttributeMap.toIndex(schema)
copy(constructExpression = constructExpression transform {
case b: BoundReference => positionToAttribute(b.ordinal)
})
copy(fromRowExpression = BindReferences.bindReference(fromRowExpression, schema))
}

/**
* Given an encoder that has already been bound to a given schema, returns a new encoder
* where the positions are mapped from `oldSchema` to `newSchema`. This can be used, for example,
* when you are trying to use an encoder on grouping keys that were originally part of a larger
* row, but now you have projected out only the key expressions.
* Returns a new encoder with input columns shifted by `delta` ordinals
*/
def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): ExpressionEncoder[T] = {
val positionToAttribute = AttributeMap.toIndex(oldSchema)
val attributeToNewPosition = AttributeMap.byIndex(newSchema)
copy(constructExpression = constructExpression transform {
case r: BoundReference =>
r.copy(ordinal = attributeToNewPosition(positionToAttribute(r.ordinal)))
})
}

def shift(delta: Int): ExpressionEncoder[T] = {
copy(constructExpression = constructExpression transform {
copy(fromRowExpression = fromRowExpression transform {
case r: BoundReference => r.copy(ordinal = r.ordinal + delta)
})
}
@@ -196,19 +213,22 @@ case class ExpressionEncoder[T](
* input row have been modified to pull the object out from a nested struct, instead of the
* top level fields.
*/
def nested(input: Expression = BoundReference(0, schema, true)): ExpressionEncoder[T] = {
copy(constructExpression = constructExpression transform {
case u: Attribute if u != input =>
private def nested(i: Int): ExpressionEncoder[T] = {
// We don't always know our input type at this point since it might be unresolved.
// We fill in null and it will get unbound to the actual attribute at this position.
val input = BoundReference(i, NullType, nullable = true)
copy(fromRowExpression = fromRowExpression transformUp {
case u: Attribute =>
UnresolvedExtractValue(input, Literal(u.name))
case b: BoundReference if b != input =>
case b: BoundReference =>
GetStructField(
input,
StructField(s"i[${b.ordinal}]", b.dataType),
b.ordinal)
})
}

protected val attrs = extractExpressions.flatMap(_.collect {
protected val attrs = toRowExpressions.flatMap(_.collect {
case _: UnresolvedAttribute => ""
case a: Attribute => s"#${a.exprId}"
case b: BoundReference => s"[${b.ordinal}]"
@@ -18,10 +18,19 @@
package org.apache.spark.sql.catalyst

import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.expressions.AttributeReference

package object encoders {
/**
* Returns an internal encoder object that can be used to serialize / deserialize JVM objects
* into Spark SQL rows. The implicit encoder should always be unresolved (i.e. have no attribute
* references from a specific schema.) This requirement allows us to preserve whether a given
* object type is being bound by name or by ordinal when doing resolution.
*/
private[sql] def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match {
case e: ExpressionEncoder[A] => e
case e: ExpressionEncoder[A] =>
e.assertUnresolved()
e
case _ => sys.error(s"Only expression encoders are supported today")
}
}
@@ -97,11 +97,16 @@ object ExtractValue {
* Returns the value of fields in the Struct `child`.
*
* No need to do type checking since it is handled by [[ExtractValue]].
* TODO: Unify with [[GetInternalRowField]], remove the need to specify a [[StructField]].
*/
case class GetStructField(child: Expression, field: StructField, ordinal: Int)
extends UnaryExpression {

override def dataType: DataType = field.dataType
override def dataType: DataType = child.dataType match {
case s: StructType => s(ordinal).dataType
// This is a hack to avoid breaking existing code until we remove the need for the struct field
case _ => field.dataType
}
override def nullable: Boolean = child.nullable || field.nullable
override def toString: String = s"$child.${field.name}"

@@ -483,9 +483,12 @@ case class MapPartitions[T, U](

/** Factory for constructing new `AppendColumn` nodes. */
object AppendColumn {
def apply[T : Encoder, U : Encoder](func: T => U, child: LogicalPlan): AppendColumn[T, U] = {
def apply[T, U : Encoder](
func: T => U,
tEncoder: ExpressionEncoder[T],
child: LogicalPlan): AppendColumn[T, U] = {
val attrs = encoderFor[U].schema.toAttributes
new AppendColumn[T, U](func, encoderFor[T], encoderFor[U], attrs, child)
new AppendColumn[T, U](func, tEncoder, encoderFor[U], attrs, child)
}
}

@@ -506,14 +509,16 @@

/** Factory for constructing new `MapGroups` nodes. */
object MapGroups {
def apply[K : Encoder, T : Encoder, U : Encoder](
def apply[K, T, U : Encoder](
func: (K, Iterator[T]) => TraversableOnce[U],
kEncoder: ExpressionEncoder[K],
tEncoder: ExpressionEncoder[T],
groupingAttributes: Seq[Attribute],
child: LogicalPlan): MapGroups[K, T, U] = {
new MapGroups(
func,
encoderFor[K],
encoderFor[T],
kEncoder,
tEncoder,
encoderFor[U],
groupingAttributes,
encoderFor[U].schema.toAttributes,
43 changes: 41 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -17,13 +17,15 @@

package org.apache.spark.sql

import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression

import scala.language.implicitConversions

import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.DataTypeParser
import org.apache.spark.sql.types._
@@ -45,7 +47,25 @@ private[sql] object Column {
* checked by the analyzer instead of the compiler (i.e. `expr("sum(...)")`).
* @tparam U The output type of this column.
*/
class TypedColumn[-T, U](expr: Expression, val encoder: Encoder[U]) extends Column(expr)
class TypedColumn[-T, U](
expr: Expression,
private[sql] val encoder: ExpressionEncoder[U]) extends Column(expr) {

/**
* Inserts the specific input type and schema into any expressions that are expected to operate
* on a decoded object.
*/
private[sql] def withInputType(
inputEncoder: ExpressionEncoder[_],
schema: Seq[Attribute]): TypedColumn[T, U] = {
new TypedColumn[T, U] (expr transform {
case ta: TypedAggregateExpression if ta.aEncoder.isEmpty =>
ta.copy(
aEncoder = Some(inputEncoder.asInstanceOf[ExpressionEncoder[Any]]),
children = schema)
}, encoder)
}
}

/**
* :: Experimental ::
@@ -73,6 +93,25 @@ class Column(protected[sql] val expr: Expression) extends Logging {
/** Creates a column based on the given expression. */
private def withExpr(newExpr: Expression): Column = new Column(newExpr)

/**
* Returns the expression for this column either with an existing or auto assigned name.
*/
private[sql] def named: NamedExpression = expr match {
// Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
// will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
// make it a NamedExpression.
case u: UnresolvedAttribute => UnresolvedAlias(u)

case expr: NamedExpression => expr

// Leave an unaliased generator with an empty list of names since the analyzer will generate
// the correct defaults after the nested expression's type has been resolved.
case explode: Explode => MultiAlias(explode, Nil)
case jt: JsonTuple => MultiAlias(jt, Nil)

case expr: Expression => Alias(expr, expr.prettyString)()
}

override def toString: String = expr.prettyString

override def equals(that: Any): Boolean = that match {