From b028675714ba5178f7a1e233eeb35f399ac19ee4 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Tue, 28 Oct 2014 12:49:45 -0700
Subject: [PATCH] allow any type in UDT

---
 .../apache/spark/mllib/linalg/Vectors.scala   | 59 +++++++++++--------
 .../spark/sql/catalyst/ScalaReflection.scala  |  2 +-
 .../spark/sql/catalyst/types/dataTypes.scala  | 10 ++--
 .../spark/sql/UserDefinedTypeSuite.scala      | 22 +++----
 4 files changed, 50 insertions(+), 43 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 17d7684b1ddf5..9aaafa34f8c03 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -285,15 +285,18 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
     row
   }
 
-  override def deserialize(row: Row): Vector = {
-    require(row.length == 3,
-      s"VectorUDT.deserialize given row with length ${row.length} but requires length == 3")
-    val vectorType = row.getByte(0)
-    vectorType match {
-      case 0 =>
-        new DenseVectorUDT().deserialize(row.getAs[Row](1))
-      case 1 =>
-        new SparseVectorUDT().deserialize(row.getAs[Row](2))
+  override def deserialize(datum: Any): Vector = {
+    datum match {
+      case row: Row =>
+        require(row.length == 3,
+          s"VectorUDT.deserialize given row with length ${row.length} but requires length == 3")
+        val vectorType = row.getByte(0)
+        vectorType match {
+          case 0 =>
+            new DenseVectorUDT().deserialize(row.getAs[Row](1))
+          case 1 =>
+            new SparseVectorUDT().deserialize(row.getAs[Row](2))
+        }
     }
   }
 }
@@ -304,19 +307,20 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
  */
 private[spark] class DenseVectorUDT extends UserDefinedType[DenseVector] {
 
-  override def sqlType: StructType = StructType(Seq(
-    StructField("values", ArrayType(DoubleType, containsNull = false), nullable = false)))
+  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)
 
-  override def serialize(obj: Any): Row = obj match {
-    case v: DenseVector =>
-      val row: GenericMutableRow = new GenericMutableRow(1)
-      row.update(0, v.values.toSeq)
-      row
+  override def serialize(obj: Any): Seq[Double] = {
+    obj match {
+      case v: DenseVector =>
+        v.values.toSeq
+    }
   }
 
-  override def deserialize(row: Row): DenseVector = {
-    val values = row.getAs[Seq[Double]](0).toArray
-    new DenseVector(values)
+  override def deserialize(datum: Any): DenseVector = {
+    datum match {
+      case values: Seq[_] =>
+        new DenseVector(values.asInstanceOf[Seq[Double]].toArray)
+    }
   }
 }
 
@@ -340,12 +344,15 @@ private[spark] class SparseVectorUDT extends UserDefinedType[SparseVector] {
     row
   }
 
-  override def deserialize(row: Row): SparseVector = {
-    require(row.length >= 1,
-      s"SparseVectorUDT.deserialize given row with length ${row.length} but requires length >= 1")
-    val vSize = row.getInt(0)
-    val indices = row.getAs[Seq[Int]](1).toArray
-    val values = row.getAs[Seq[Double]](2).toArray
-    new SparseVector(vSize, indices, values)
+  override def deserialize(datum: Any): SparseVector = {
+    datum match {
+      case row: Row =>
+        require(row.length == 3,
+          s"SparseVectorUDT.deserialize given row with length ${row.length} but expect 3.")
+        val vSize = row.getInt(0)
+        val indices = row.getAs[Seq[Int]](1).toArray
+        val values = row.getAs[Seq[Double]](2).toArray
+        new SparseVector(vSize, indices, values)
+    }
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index db1a8924c008e..de409b8c376b3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -65,7 +65,7 @@ object ScalaReflection {
         convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType)
     }
     case (d: Decimal, _: DecimalType) => d.toBigDecimal
-    case (r: Row, udt: UserDefinedType[_]) => udt.deserialize(r)
+    case (d, udt: UserDefinedType[_]) => udt.deserialize(d)
     case (other, _) => other
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index b1d90dba16ce7..220a347af5c0a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -594,15 +594,15 @@ case class MapType(
 abstract class UserDefinedType[UserType] extends DataType with Serializable {
 
   /** Underlying storage type for this UDT used by SparkSQL */
-  def sqlType: StructType
+  def sqlType: DataType
 
-  /** Convert the user type to a Row object */
+  /** Convert the user type to a SQL datum */
   // TODO: Can we make this take obj: UserType?  The issue is in ScalaReflection.convertToCatalyst,
   // where we need to convert Any to UserType.
-  def serialize(obj: Any): Row
+  def serialize(obj: Any): Any
 
-  /** Convert a Row object to the user type */
-  def deserialize(row: Row): UserType
+  /** Convert a SQL datum to the user type */
+  def deserialize(datum: Any): UserType
 
   override private[sql] def jsonValue: JValue = {
     ("type" -> "udt") ~
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 3208c910a5bc4..cf793ccbd0c02 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
-import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
 import org.apache.spark.sql.catalyst.types.UserDefinedType
 import org.apache.spark.sql.test.TestSQLContext._
 
@@ -36,19 +35,20 @@ case class MyLabeledPoint(label: Double, features: MyDenseVector)
 
 class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
 
-  override def sqlType: StructType = StructType(Seq(
-    StructField("data", ArrayType(DoubleType, containsNull = false), nullable = false)))
+  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)
 
-  override def serialize(obj: Any): Row = obj match {
-    case features: MyDenseVector =>
-      val row: GenericMutableRow = new GenericMutableRow(1)
-      row.update(0, features.data.toSeq)
-      row
+  override def serialize(obj: Any): Seq[Double] = {
+    obj match {
+      case features: MyDenseVector =>
+        features.data.toSeq
+    }
   }
 
-  override def deserialize(row: Row): MyDenseVector = {
-    val data = row.getAs[Seq[Double]](0).toArray
-    new MyDenseVector(data)
+  override def deserialize(datum: Any): MyDenseVector = {
+    datum match {
+      case data: Seq[_] =>
+        new MyDenseVector(data.asInstanceOf[Seq[Double]].toArray)
+    }
   }
 }
 
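For reference, a minimal sketch of a UDT written against the generalized contract above, following the same pattern as MyDenseVectorUDT in the test suite. The Point/PointUDT names and field layout are illustrative only and not part of this patch; the imports assume the catalyst packages shown in the test file.

import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
import org.apache.spark.sql.catalyst.types._

// Illustrative user class backed by a plain SQL array instead of a Row.
@SQLUserDefinedType(udt = classOf[PointUDT])
case class Point(x: Double, y: Double)

class PointUDT extends UserDefinedType[Point] {

  // The underlying storage type may now be any DataType, not only StructType.
  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)

  // Serialize directly to the SQL datum (a Seq[Double]) rather than a Row.
  override def serialize(obj: Any): Seq[Double] = obj match {
    case p: Point => Seq(p.x, p.y)
  }

  // Deserialize from whatever datum backs the column.
  override def deserialize(datum: Any): Point = datum match {
    case values: Seq[_] =>
      val v = values.asInstanceOf[Seq[Double]]
      Point(v(0), v(1))
  }
}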