Skip to content

Commit

Permalink
some cleanups
Browse files Browse the repository at this point in the history
  • Loading branch information
jkbradley committed Nov 2, 2014
1 parent 893ee4c commit 964b32e
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 146 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,85 +36,34 @@ object ScalaReflection {
/**
 * Pairs a catalyst [[DataType]] with a flag saying whether values of that type may be null.
 * This is the result type of the `schemaFor` methods in this object.
 */
case class Schema(dataType: DataType, nullable: Boolean)

/** Converts Scala objects to catalyst rows / types */
/*
def convertToCatalyst(a: Any, dataType: DataType): Any = a match {
// TODO: Why does this not need to flatMap stuff? Does it not support nesting?
case o: Option[_] =>
println(s"convertToCatalyst: option")
o.map(convertToCatalyst(_, dataType)).orNull
case s: Seq[_] =>
println(s"convertToCatalyst: array")
s.map(convertToCatalyst(_, null))
case m: Map[_, _] =>
println(s"convertToCatalyst: map")
m.map { case (k, v) =>
convertToCatalyst(k, null) -> convertToCatalyst(v, null)
}
case p: Product =>
println(s"convertToCatalyst: struct")
new GenericRow(p.productIterator.map(convertToCatalyst(_, null)).toArray)
case other =>
println(s"convertToCatalyst: other")
other
}
*/

def convertToCatalyst(a: Any, dataType: DataType): Any = {
println(s"convertToCatalyst: a = $a, dataType = $dataType")
(a, dataType) match {
// TODO: Why does this not need to flatMap stuff? Does it not support nesting?
case (o: Option[_], _) =>
println(s"convertToCatalyst: option")
o.map(convertToCatalyst(_, dataType)).orNull
case (s: Seq[_], arrayType: ArrayType) =>
println(s"convertToCatalyst: array")
s.map(convertToCatalyst(_, arrayType.elementType))
case (m: Map[_, _], mapType: MapType) =>
println(s"convertToCatalyst: map")
m.map { case (k, v) =>
convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType)
}
case (p: Product, structType: StructType) =>
println(s"convertToCatalyst: struct with")
println(s"\t p: $p")
println(s"\t structType: $structType")
new GenericRow(
p.productIterator.toSeq.zip(structType.fields).map { case (elem, field) =>
convertToCatalyst(elem, field.dataType)
}.toArray)
case (udt, udtType: UserDefinedType[_]) =>
println(s"convertToCatalyst: udt with $udtType")
udtType.serialize(udt)
case (d: BigDecimal, _) => Decimal(d)
case (other, _) =>
println(s"convertToCatalyst: other")
other
/**
 * Converts a Scala value into the form catalyst stores for `dataType`.
 *
 * The value and its declared type are matched together, first match wins:
 *  - an `Option` is unwrapped (`None` becomes `null`) and its content converted with the
 *    same `dataType`;
 *  - a `Seq` with an [[ArrayType]] has every element converted with the element type;
 *  - a `Map` with a [[MapType]] has every key and value converted with the key/value types;
 *  - a `Product` with a [[StructType]] becomes a [[GenericRow]], pairing product elements
 *    with the struct's fields positionally;
 *  - a `BigDecimal` becomes a [[Decimal]];
 *  - any value whose type is a [[UserDefinedType]] is passed to that type's `serialize`;
 *  - anything else is returned unchanged.
 *
 * @param a        the Scala value to convert
 * @param dataType the catalyst type describing `a`
 * @return the converted value
 */
def convertToCatalyst(a: Any, dataType: DataType): Any = {
  // Local shorthand for the recursive call.
  def convert(value: Any, tpe: DataType): Any = convertToCatalyst(value, tpe)

  (a, dataType) match {
    case (opt: Option[_], _) =>
      opt match {
        case Some(inner) => convert(inner, dataType)
        case None => null
      }

    case (elements: Seq[_], arr: ArrayType) =>
      elements.map(element => convert(element, arr.elementType))

    case (entries: Map[_, _], mt: MapType) =>
      entries.map { entry =>
        convert(entry._1, mt.keyType) -> convert(entry._2, mt.valueType)
      }

    case (product: Product, struct: StructType) =>
      // zip pairs elements with fields by position and stops at the shorter side.
      val converted = product.productIterator.toSeq.zip(struct.fields).map { pair =>
        convert(pair._1, pair._2.dataType)
      }
      new GenericRow(converted.toArray)

    case (decimal: BigDecimal, _) =>
      Decimal(decimal)

    case (userValue, userType: UserDefinedType[_]) =>
      userType.serialize(userValue)

    case (unchanged, _) =>
      unchanged
  }
}

/** Converts Catalyst types used internally in rows to standard Scala types */
def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match {
// TODO: Why does this not need to flatMap stuff? Does it not support nesting?
// TODO: What about Option and Product?
case (s: Seq[_], arrayType: ArrayType) =>
println("convertToScala: Seq")
s.map(convertToScala(_, arrayType.elementType))
case (m: Map[_, _], mapType: MapType) =>
println("convertToScala: Map")
m.map { case (k, v) =>
case (s: Seq[_], arrayType: ArrayType) => s.map(convertToScala(_, arrayType.elementType))
case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) =>
convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType)
}
case (d: Decimal, DecimalType) => d.toBigDecimal
case (udt: Row, udtType: UserDefinedType[_]) =>
println("convertToScala: udt")
udtType.deserialize(udt)
case (other, _) =>
println("convertToScala: other")
other
case (udt: Row, udtType: UserDefinedType[_]) => udtType.deserialize(udt)
case (other, _) => other
}

def convertRowToScala(r: Row, schema: StructType): Row = {
println("convertRowToScala called with schema: $schema")
new GenericRow(
r.zip(schema.fields.map(_.dataType))
.map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray)
Expand All @@ -133,63 +82,57 @@ object ScalaReflection {
def schemaFor[T: TypeTag](udtRegistry: scala.collection.Map[Any, UserDefinedType[_]]): Schema = {
  // Resolve the reflected type of T from its TypeTag, then delegate to the
  // Type-based overload, passing the registry through unchanged.
  val reflectedType = typeOf[T]
  schemaFor(reflectedType, udtRegistry)
}

/**
* Returns a catalyst DataType and its nullability for the given Scala Type using reflection.
* TODO: ADD DOC
*/
def schemaFor(tpe: `Type`, udtRegistry: scala.collection.Map[Any, UserDefinedType[_]]): Schema = {
println(s"schemaFor: $tpe")
tpe match {
case t if t <:< typeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
Schema(schemaFor(optType, udtRegistry).dataType, nullable = true)
case t if t <:< typeOf[Product] =>
println(s" --schemaFor matched on Product")
val formalTypeArgs = t.typeSymbol.asClass.typeParams
val TypeRef(_, _, actualTypeArgs) = t
val params = t.member(nme.CONSTRUCTOR).asMethod.paramss
Schema(StructType(
params.head.map { p =>
val Schema(dataType, nullable) =
schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs), udtRegistry)
StructField(p.name.toString, dataType, nullable)
}), nullable = true)
// Need to decide if we actually need a special type here.
case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
case t if t <:< typeOf[Array[_]] =>
sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
case t if t <:< typeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, nullable) = schemaFor(elementType, udtRegistry)
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
case t if t <:< typeOf[Map[_, _]] =>
val TypeRef(_, _, Seq(keyType, valueType)) = t
val Schema(valueDataType, valueNullable) = schemaFor(valueType, udtRegistry)
Schema(MapType(schemaFor(keyType, udtRegistry).dataType,
valueDataType, valueContainsNull = valueNullable), nullable = true)
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
case t if t <:< typeOf[Date] => Schema(DateType, nullable = true)
case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true)
case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true)
case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true)
case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true)
case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false)
case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false)
case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false)
case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false)
case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
case t if udtRegistry.contains(t) =>
println(s" schemaFor T matched udtRegistry")
Schema(udtRegistry(t), nullable = true)
}
/**
 * Returns a catalyst DataType and its nullability for the given Scala Type using reflection.
 *
 * @param tpe         the Scala type to map
 * @param udtRegistry user-defined types keyed by the Scala type they stand for. It is
 *                    consulted before any built-in mapping, so a registered type keeps its
 *                    user-defined type even when it is also a `Product`, `Seq`, `Map`, etc.
 * @return the catalyst type for `tpe`; `nullable` is false only for Scala primitive types
 * @throws RuntimeException (via `sys.error`) for arrays other than `Array[Byte]` and for
 *                          any type that has no mapping
 */
def schemaFor(
    tpe: `Type`,
    udtRegistry: scala.collection.Map[Any, UserDefinedType[_]]): Schema = tpe match {
  // Registered types must win over the structural cases below: otherwise a registered
  // case class would match the Product case and be described as a plain struct.
  case t if udtRegistry.contains(t) =>
    Schema(udtRegistry(t), nullable = true)
  case t if t <:< typeOf[Option[_]] =>
    // An Option contributes the schema of its content, but is always nullable.
    val TypeRef(_, _, Seq(optType)) = t
    Schema(schemaFor(optType, udtRegistry).dataType, nullable = true)
  case t if t <:< typeOf[Product] =>
    // Build one struct field per parameter of the first constructor parameter list,
    // substituting the actual type arguments for the class's formal type parameters.
    val formalTypeArgs = t.typeSymbol.asClass.typeParams
    val TypeRef(_, _, actualTypeArgs) = t
    val params = t.member(nme.CONSTRUCTOR).asMethod.paramss
    Schema(StructType(
      params.head.map { p =>
        val Schema(dataType, nullable) =
          schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs), udtRegistry)
        StructField(p.name.toString, dataType, nullable)
      }), nullable = true)
  // Need to decide if we actually need a special type here.
  case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
  case t if t <:< typeOf[Array[_]] =>
    sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
  case t if t <:< typeOf[Seq[_]] =>
    val TypeRef(_, _, Seq(elementType)) = t
    val Schema(dataType, nullable) = schemaFor(elementType, udtRegistry)
    Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
  case t if t <:< typeOf[Map[_, _]] =>
    val TypeRef(_, _, Seq(keyType, valueType)) = t
    val Schema(valueDataType, valueNullable) = schemaFor(valueType, udtRegistry)
    Schema(MapType(schemaFor(keyType, udtRegistry).dataType,
      valueDataType, valueContainsNull = valueNullable), nullable = true)
  case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
  case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
  case t if t <:< typeOf[Date] => Schema(DateType, nullable = true)
  case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
  case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true)
  // Boxed Java types can hold null; the Scala primitives below cannot.
  case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
  case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
  case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
  case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true)
  case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true)
  case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true)
  case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true)
  case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false)
  case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false)
  case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false)
  case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false)
  case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
  case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
  case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
  // Fail with a message naming the type instead of a bare scala.MatchError.
  case other =>
    sys.error(s"Schema for type $other is not supported")
}

def typeOfObject: PartialFunction[Any, DataType] = {
Expand Down Expand Up @@ -227,5 +170,4 @@ object ScalaReflection {
LocalRelation(output, data)
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi

// scalastyle:off
override def eval(input: Row): Any = {
println(s"ScalaUdf.eval called")
val result = children.size match {
case 0 => function.asInstanceOf[() => Any]()
case 1 => function.asInstanceOf[(Any) => Any](children(0).eval(input))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -567,34 +567,19 @@ case class MapType(
("valueContainsNull" -> valueContainsNull)
}

// TODO: Where should this go?
/**
 * Contract for a user-defined type whose values are stored as rows.
 *
 * @tparam T the user-facing class handled by this type
 */
trait UserDefinedType[T] {
/** The struct type describing the stored form of a `T`. */
def dataType: StructType
/** Converts a user object into a [[Row]]. */
def serialize(obj: T): Row
/** Converts a [[Row]] back into a user object. */
def deserialize(row: Row): T
}

object UserDefinedType {
// NOTE(review): this companion currently declares no members. The factory below is
// commented out; as written it takes a key type and a value type and builds a MapType
// with `true` as the third argument, so it appears to have been copied from MapType's
// companion. Confirm the intended signature before enabling it.
//def apply(keyType: DataType, valueType: DataType): MapType =
// MapType(keyType: DataType, valueType: DataType, true)
}

/**
* The data type for User Defined Types.
*/
abstract class UserDefinedType[UserType] extends DataType {

// Used only in regex parser above.
//private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { }

/** Underlying storage type for this UDT used by SparkSQL */
def sqlType: DataType

/** Convert the user type to a Row object */
// TODO: Can we make this take obj: UserType? The issue is in ScalaReflection.convertToCatalyst, where we need to convert Any to UserType.
def serialize(obj: Any): Row

/** Convert a Row object to the user type */
def deserialize(row: Row): UserType

def simpleString: String = "udt"
Expand Down

0 comments on commit 964b32e

Please sign in to comment.