allow any type in UDT
mengxr authored and jkbradley committed Nov 2, 2014
1 parent 4500d8a commit b028675
Showing 4 changed files with 50 additions and 43 deletions.
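In short, the commit relaxes the UserDefinedType contract: sqlType may now be any DataType rather than only a StructType, and serialize/deserialize exchange plain SQL datums (Any) rather than Rows. As a quick orientation before the diffs, here is a minimal sketch of what the relaxed contract allows; the Point2D names are hypothetical and not part of this commit, and wiring the UDT to its user class (for example via the SQLUserDefinedType annotation used in the test suite) is omitted.

import org.apache.spark.sql.catalyst.types._

// Hypothetical user class: a 2-D point.
class Point2D(val x: Double, val y: Double)

class Point2DUDT extends UserDefinedType[Point2D] {

  // Any DataType may now back a UDT; here a flat array column instead of a struct.
  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)

  // serialize may return any SQL datum; a Seq[Double] matches the ArrayType above.
  override def serialize(obj: Any): Seq[Double] = obj match {
    case p: Point2D => Seq(p.x, p.y)
  }

  // deserialize receives whatever value Catalyst stored for sqlType.
  override def deserialize(datum: Any): Point2D = datum match {
    case values: Seq[_] =>
      val Seq(x, y) = values.asInstanceOf[Seq[Double]]
      new Point2D(x, y)
  }
}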
59 changes: 33 additions & 26 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -285,15 +285,18 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
     row
   }

-  override def deserialize(row: Row): Vector = {
-    require(row.length == 3,
-      s"VectorUDT.deserialize given row with length ${row.length} but requires length == 3")
-    val vectorType = row.getByte(0)
-    vectorType match {
-      case 0 =>
-        new DenseVectorUDT().deserialize(row.getAs[Row](1))
-      case 1 =>
-        new SparseVectorUDT().deserialize(row.getAs[Row](2))
+  override def deserialize(datum: Any): Vector = {
+    datum match {
+      case row: Row =>
+        require(row.length == 3,
+          s"VectorUDT.deserialize given row with length ${row.length} but requires length == 3")
+        val vectorType = row.getByte(0)
+        vectorType match {
+          case 0 =>
+            new DenseVectorUDT().deserialize(row.getAs[Row](1))
+          case 1 =>
+            new SparseVectorUDT().deserialize(row.getAs[Row](2))
+        }
     }
   }
 }
@@ -304,19 +307,20 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
  */
 private[spark] class DenseVectorUDT extends UserDefinedType[DenseVector] {

-  override def sqlType: StructType = StructType(Seq(
-    StructField("values", ArrayType(DoubleType, containsNull = false), nullable = false)))
+  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)

-  override def serialize(obj: Any): Row = obj match {
-    case v: DenseVector =>
-      val row: GenericMutableRow = new GenericMutableRow(1)
-      row.update(0, v.values.toSeq)
-      row
+  override def serialize(obj: Any): Seq[Double] = {
+    obj match {
+      case v: DenseVector =>
+        v.values.toSeq
+    }
   }

-  override def deserialize(row: Row): DenseVector = {
-    val values = row.getAs[Seq[Double]](0).toArray
-    new DenseVector(values)
+  override def deserialize(datum: Any): DenseVector = {
+    datum match {
+      case values: Seq[_] =>
+        new DenseVector(values.asInstanceOf[Seq[Double]].toArray)
+    }
   }
 }

@@ -340,12 +344,15 @@ private[spark] class SparseVectorUDT extends UserDefinedType[SparseVector] {
     row
   }

-  override def deserialize(row: Row): SparseVector = {
-    require(row.length >= 1,
-      s"SparseVectorUDT.deserialize given row with length ${row.length} but requires length >= 1")
-    val vSize = row.getInt(0)
-    val indices = row.getAs[Seq[Int]](1).toArray
-    val values = row.getAs[Seq[Double]](2).toArray
-    new SparseVector(vSize, indices, values)
+  override def deserialize(datum: Any): SparseVector = {
+    datum match {
+      case row: Row =>
+        require(row.length == 3,
+          s"SparseVectorUDT.deserialize given row with length ${row.length} but expect 3.")
+        val vSize = row.getInt(0)
+        val indices = row.getAs[Seq[Int]](1).toArray
+        val values = row.getAs[Seq[Double]](2).toArray
+        new SparseVector(vSize, indices, values)
+    }
   }
 }
@@ -65,7 +65,7 @@ object ScalaReflection {
       convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType)
     }
     case (d: Decimal, _: DecimalType) => d.toBigDecimal
-    case (r: Row, udt: UserDefinedType[_]) => udt.deserialize(r)
+    case (d, udt: UserDefinedType[_]) => udt.deserialize(d)
     case (other, _) => other
   }

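With this one-line change, convertToScala no longer assumes that the Catalyst-side value of a UDT column is a Row; whatever datum matches the UDT's sqlType is handed to deserialize unchanged. A small illustrative round trip (not part of the commit), reusing the MyDenseVector and MyDenseVectorUDT classes from the test suite further down:

// serialize now produces the storage value for sqlType (a Seq[Double] for
// ArrayType(DoubleType), not a Row), and deserialize accepts it back as Any.
val udt = new MyDenseVectorUDT()
val v = new MyDenseVector(Array(0.25, 1.0))
val datum: Any = udt.serialize(v)                 // Seq(0.25, 1.0)
val back: MyDenseVector = udt.deserialize(datum)
assert(back.data.sameElements(v.data))            // identity round trip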
@@ -594,15 +594,15 @@ case class MapType(
 abstract class UserDefinedType[UserType] extends DataType with Serializable {

   /** Underlying storage type for this UDT used by SparkSQL */
-  def sqlType: StructType
+  def sqlType: DataType

-  /** Convert the user type to a Row object */
+  /** Convert the user type to a SQL datum */
   // TODO: Can we make this take obj: UserType? The issue is in ScalaReflection.convertToCatalyst,
   // where we need to convert Any to UserType.
-  def serialize(obj: Any): Row
+  def serialize(obj: Any): Any

-  /** Convert a Row object to the user type */
-  def deserialize(row: Row): UserType
+  /** Convert a SQL datum to the user type */
+  def deserialize(datum: Any): UserType

   override private[sql] def jsonValue: JValue = {
     ("type" -> "udt") ~
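Struct-backed UDTs keep working under the relaxed signatures; they simply pattern-match the incoming datum back to a Row, exactly as VectorUDT and SparseVectorUDT do above. A hedged sketch of that shape (the Interval names are hypothetical, not from this commit):

import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row}
import org.apache.spark.sql.catalyst.types._

// Hypothetical user class stored as a two-field struct.
class Interval(val start: Double, val end: Double)

class IntervalUDT extends UserDefinedType[Interval] {

  // A StructType remains a valid storage type.
  override def sqlType: DataType = StructType(Seq(
    StructField("start", DoubleType, nullable = false),
    StructField("end", DoubleType, nullable = false)))

  // serialize still builds a Row here; it is merely typed as Any by the contract.
  override def serialize(obj: Any): Any = obj match {
    case i: Interval =>
      val row = new GenericMutableRow(2)
      row.update(0, i.start)
      row.update(1, i.end)
      row
  }

  // deserialize narrows the Any datum back to a Row before reading its fields.
  override def deserialize(datum: Any): Interval = datum match {
    case row: Row => new Interval(row.getDouble(0), row.getDouble(1))
  }
}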
@@ -19,7 +19,6 @@ package org.apache.spark.sql

 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
-import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
 import org.apache.spark.sql.catalyst.types.UserDefinedType
 import org.apache.spark.sql.test.TestSQLContext._

@@ -36,19 +35,20 @@ case class MyLabeledPoint(label: Double, features: MyDenseVector)

 class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {

-  override def sqlType: StructType = StructType(Seq(
-    StructField("data", ArrayType(DoubleType, containsNull = false), nullable = false)))
+  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)

-  override def serialize(obj: Any): Row = obj match {
-    case features: MyDenseVector =>
-      val row: GenericMutableRow = new GenericMutableRow(1)
-      row.update(0, features.data.toSeq)
-      row
+  override def serialize(obj: Any): Seq[Double] = {
+    obj match {
+      case features: MyDenseVector =>
+        features.data.toSeq
+    }
   }

-  override def deserialize(row: Row): MyDenseVector = {
-    val data = row.getAs[Seq[Double]](0).toArray
-    new MyDenseVector(data)
+  override def deserialize(datum: Any): MyDenseVector = {
+    datum match {
+      case data: Seq[_] =>
+        new MyDenseVector(data.asInstanceOf[Seq[Double]].toArray)
+    }
   }
 }

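For context, the classes above are meant to be exercised by building a SchemaRDD of MyLabeledPoint values and selecting the UDT column back out. A hedged usage sketch (not part of the diff; it assumes the implicit RDD-to-SchemaRDD conversion and the Symbol-to-attribute DSL brought into scope by the TestSQLContext._ import above, and that MyDenseVector is annotated with SQLUserDefinedType pointing at MyDenseVectorUDT):

// Build a small RDD of labeled points whose features column is the UDT.
val points = Seq(
  MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))),
  MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0))))
val pointsRDD: RDD[MyLabeledPoint] = sparkContext.parallelize(points)

// Selecting the column yields MyDenseVector instances again, deserialized from
// their ArrayType(DoubleType) storage.
val features: Array[MyDenseVector] =
  pointsRDD.select('features).collect().map(_.getAs[MyDenseVector](0))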
