Skip to content

Commit

Permalink
unify GetStructField and GetInternalRowField
Browse files Browse the repository at this point in the history
  • Loading branch information
cloud-fan committed Nov 23, 2015
1 parent 426004a commit ec40d23
Show file tree
Hide file tree
Showing 9 changed files with 18 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ object ScalaReflection extends ScalaReflection {

/** Returns the current path with a field at ordinal extracted. */
def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path
.map(p => GetInternalRowField(p, ordinal, dataType))
.map(p => GetStructField(p, ordinal))
.getOrElse(BoundReference(ordinal, dataType, false))

/** Returns the current path or `BoundReference`. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,12 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu
if (attribute.isDefined) {
// This target resolved to an attribute in child. It must be a struct. Expand it.
attribute.get.dataType match {
case s: StructType => {
s.fields.map( f => {
val extract = GetStructField(attribute.get, f, s.getFieldIndex(f.name).get)
case s: StructType => s.zipWithIndex.map {
case (f, i) =>
val extract = GetStructField(attribute.get, i)
Alias(extract, target.get + "." + f.name)()
})
}

case _ => {
throw new AnalysisException("Can only star expand struct data types. Attribute: `" +
target.get + "`")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ object ExpressionEncoder {
case UnresolvedAttribute(nameParts) =>
assert(nameParts.length == 1)
UnresolvedExtractValue(input, Literal(nameParts.head))
case BoundReference(ordinal, dt, _) => GetInternalRowField(input, ordinal, dt)
case BoundReference(ordinal, dt, _) => GetStructField(input, ordinal)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ object RowEncoder {
If(
Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil),
Literal.create(null, externalDataTypeFor(f.dataType)),
constructorFor(GetInternalRowField(input, i, f.dataType)))
constructorFor(GetStructField(input, i)))
}
CreateExternalRow(convertedFields)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ abstract class Expression extends TreeNode[Expression] {
*/
def prettyString: String = {
transform {
case a: AttributeReference => PrettyAttribute(a.name)
case a: AttributeReference => PrettyAttribute(a.name, a.dataType)
case u: UnresolvedAttribute => PrettyAttribute(u.name)
}.toString
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ object ExtractValue {
case (StructType(fields), NonNullLiteral(v, StringType)) =>
val fieldName = v.toString
val ordinal = findField(fields, fieldName, resolver)
GetStructField(child, fields(ordinal).copy(name = fieldName), ordinal)
GetStructField(child, ordinal, Some(fieldName))

case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) =>
val fieldName = v.toString
Expand Down Expand Up @@ -97,18 +97,15 @@ object ExtractValue {
* Returns the value of the field at position `ordinal` in the struct `child`.
*
* No need to do type checking since it is handled by [[ExtractValue]].
* TODO: Unify with [[GetInternalRowField]], remove the need to specify a [[StructField]].
*/
case class GetStructField(child: Expression, field: StructField, ordinal: Int)
case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None)
extends UnaryExpression {

override def dataType: DataType = child.dataType match {
case s: StructType => s(ordinal).dataType
// This is a hack to avoid breaking existing code until we remove the need for the struct field
case _ => field.dataType
}
private lazy val field = child.dataType.asInstanceOf[StructType](ordinal)

override def dataType: DataType = field.dataType
override def nullable: Boolean = child.nullable || field.nullable
override def toString: String = s"$child.${field.name}"
override def toString: String = s"$child.${name.getOrElse(field.name)}"

protected override def nullSafeEval(input: Any): Any =
input.asInstanceOf[InternalRow].get(ordinal, field.dataType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,8 @@ case class AttributeReference(
* A placeholder used when printing expressions without debugging information such as the
* expression id or the unresolved indicator.
*/
case class PrettyAttribute(name: String) extends Attribute with Unevaluable {
case class PrettyAttribute(name: String, dataType: DataType = NullType)
extends Attribute with Unevaluable {

override def toString: String = name

Expand All @@ -286,7 +287,6 @@ case class PrettyAttribute(name: String) extends Attribute with Unevaluable {
override def qualifiers: Seq[String] = throw new UnsupportedOperationException
override def exprId: ExprId = throw new UnsupportedOperationException
override def nullable: Boolean = throw new UnsupportedOperationException
override def dataType: DataType = NullType
}

object VirtualColumn {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -522,27 +522,6 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression {
}
}

/**
 * Unary expression that reads one field, identified by `ordinal`, out of the value produced
 * by `child`. Only code-generated evaluation is implemented; interpreted `eval` always throws.
 *
 * @param child expression whose generated value the field is read from
 *              (presumably an internal row, going by the class name -- confirm against callers)
 * @param ordinal position of the field to read, passed to `ctx.getValue` as a string
 * @param dataType declared type of the extracted field; drives the generated Java type,
 *                 the default value, and the accessor chosen by `ctx.getValue`
 */
case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataType)
extends UnaryExpression {

// Always reported as nullable, regardless of the nullability of `child`.
override def nullable: Boolean = true

// Interpreted evaluation is deliberately unsupported; this expression must be code-generated.
override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported")

/**
 * Generates Java source that evaluates `child`, copies the child's null flag into
 * `ev.isNull`, initialises `ev.value` to the default for `dataType`, and, only when the
 * child is not null, assigns `ev.value` from `ctx.getValue` on the child's value at `ordinal`.
 */
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
// Generate the child's code first; its `code`, `isNull` and `value` are spliced in below.
val row = child.gen(ctx)
s"""
${row.code}
final boolean ${ev.isNull} = ${row.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = ${ctx.getValue(row.value, dataType, ordinal.toString)};
}
"""
}
}

/**
* Serializes an input object using a generic serializer (Kryo or Java).
* @param kryo if true, use Kryo. Otherwise, use Java.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
def getStructField(expr: Expression, fieldName: String): GetStructField = {
expr.dataType match {
case StructType(fields) =>
val field = fields.find(_.name == fieldName).get
GetStructField(expr, field, fields.indexOf(field))
val index = fields.indexWhere(_.name == fieldName)
GetStructField(expr, index)
}
}

Expand Down

0 comments on commit ec40d23

Please sign in to comment.