Commit

remove unnecessary changes
mengxr authored and jkbradley committed Nov 2, 2014
1 parent cfbc321 commit 3143ac3
Showing 9 changed files with 45 additions and 65 deletions.
23 changes: 11 additions & 12 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -202,7 +202,7 @@ object Vectors {
/**
* A dense vector represented by a value array.
*/
@SQLUserDefinedType(serdes = classOf[DenseVectorUDT])
@SQLUserDefinedType(udt = classOf[DenseVectorUDT])
class DenseVector(val values: Array[Double]) extends Vector {

override def size: Int = values.length
@@ -259,16 +259,16 @@ class SparseVector(
* User-defined type for [[Vector]] which allows easy interaction with SQL
* via [[org.apache.spark.sql.SchemaRDD]].
*/
private[spark] class VectorUDT extends UserDefinedTypeSerDes[Vector] {
private[spark] class VectorUDT extends UserDefinedType[Vector] {

/**
* vectorType: 0 = dense, 1 = sparse.
* dense, sparse: One element holds the vector, and the other is null.
*/
override def sqlType: StructType = StructType(Seq(
StructField("vectorType", ByteType, nullable = false),
StructField("dense", new UserDefinedType(new DenseVectorUDT), nullable = true),
StructField("sparse", new UserDefinedType(new SparseVectorUDT), nullable = true)))
StructField("dense", new DenseVectorUDT, nullable = true),
StructField("sparse", new SparseVectorUDT, nullable = true)))

override def serialize(obj: Any): Row = {
val row = new GenericMutableRow(3)
@@ -297,16 +297,17 @@ private[spark] class VectorUDT extends UserDefinedTypeSerDes[Vector] {
}
}

override def userType: Class[Vector] = classOf[Vector]
// override def userType: Class[Vector] = classOf[Vector]
}

/**
* User-defined type for [[DenseVector]] which allows easy interaction with SQL
* via [[org.apache.spark.sql.SchemaRDD]].
*/
private[spark] class DenseVectorUDT extends UserDefinedTypeSerDes[DenseVector] {
private[spark] class DenseVectorUDT extends UserDefinedType[DenseVector] {

override def sqlType: ArrayType = ArrayType(DoubleType, containsNull = false)
override def sqlType: StructType = StructType(Seq(
StructField("values", ArrayType(DoubleType, containsNull = false), nullable = false)))

override def serialize(obj: Any): Row = obj match {
case v: DenseVector =>
@@ -320,14 +321,14 @@ private[spark] class DenseVectorUDT extends UserDefinedTypeSerDes[DenseVector] {
new DenseVector(values)
}

override def userType: Class[DenseVector] = classOf[DenseVector]
// override def userType: Class[DenseVector] = classOf[DenseVector]
}

/**
* User-defined type for [[SparseVector]] which allows easy interaction with SQL
* via [[org.apache.spark.sql.SchemaRDD]].
*/
private[spark] class SparseVectorUDT extends UserDefinedTypeSerDes[SparseVector] {
private[spark] class SparseVectorUDT extends UserDefinedType[SparseVector] {

override def sqlType: StructType = StructType(Seq(
StructField("size", IntegerType, nullable = false),
@@ -341,8 +342,6 @@ private[spark] class SparseVectorUDT extends UserDefinedTypeSerDes[SparseVector]
row.update(1, v.indices.toSeq)
row.update(2, v.values.toSeq)
row
case row: Row =>
row
}

override def deserialize(row: Row): SparseVector = {
@@ -354,5 +353,5 @@ private[spark] class SparseVectorUDT extends UserDefinedTypeSerDes[SparseVector]
new SparseVector(vSize, indices, values)
}

override def userType: Class[SparseVector] = classOf[SparseVector]
// override def userType: Class[SparseVector] = classOf[SparseVector]
}
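
Note: the serialize and deserialize bodies of the new VectorUDT are not included in the hunks shown above. The following is only a hedged sketch of how the declared layout (a vectorType byte plus nullable dense/sparse fields, exactly one of which is set) could be handled, using the Row and GenericMutableRow accessors that appear elsewhere in this diff. VectorUDTSketch and its imports are illustrative, not the actual Spark code.

package org.apache.spark.mllib.linalg

import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row}
import org.apache.spark.sql.catalyst.types._

// Hypothetical sketch only; assumes it sits alongside DenseVectorUDT and
// SparseVectorUDT in org.apache.spark.mllib.linalg.
private[spark] class VectorUDTSketch extends UserDefinedType[Vector] {

  // Same layout as the VectorUDT above: a type tag plus two nullable payloads.
  override def sqlType: StructType = StructType(Seq(
    StructField("vectorType", ByteType, nullable = false),
    StructField("dense", new DenseVectorUDT, nullable = true),
    StructField("sparse", new SparseVectorUDT, nullable = true)))

  override def serialize(obj: Any): Row = {
    val row = new GenericMutableRow(3)
    obj match {
      case v: DenseVector =>
        row.setByte(0, 0)                                // 0 = dense
        row.update(1, new DenseVectorUDT().serialize(v))
        row.setNullAt(2)
      case v: SparseVector =>
        row.setByte(0, 1)                                // 1 = sparse
        row.setNullAt(1)
        row.update(2, new SparseVectorUDT().serialize(v))
    }
    row
  }

  override def deserialize(row: Row): Vector = row.getByte(0) match {
    case 0 => new DenseVectorUDT().deserialize(row.getAs[Row](1))
    case 1 => new SparseVectorUDT().deserialize(row.getAs[Row](2))
    case other => sys.error(s"Unknown vectorType: $other")
  }
}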
@@ -54,7 +54,7 @@ object ScalaReflection {
convertToCatalyst(elem, field.dataType)
}.toArray)
case (d: BigDecimal, _) => Decimal(d)
case (udt, udtType: UserDefinedTypeSerDes[_]) => udtType.serialize(udt)
case (obj, udt: UserDefinedType[_]) => udt.serialize(obj)
case (other, _) => other
}

@@ -65,7 +65,7 @@ object ScalaReflection {
convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType)
}
case (d: Decimal, _: DecimalType) => d.toBigDecimal
case (r: Row, udt: UserDefinedType[_]) => udt.serdes.deserialize(r)
case (r: Row, udt: UserDefinedType[_]) => udt.deserialize(r)
case (other, _) => other
}

@@ -94,12 +94,12 @@ object ScalaReflection {
// whereas className is from Scala reflection. This can make it hard to find classes
// in some cases, such as when a class is enclosed in an object (in which case
// Java appends a '$' to the object name but Scala does not).
val serdes = Utils.classForName(className)
.getAnnotation(classOf[SQLUserDefinedType]).serdes().newInstance()
UDTRegistry.registerType(t, serdes)
Schema(new UserDefinedType(serdes), nullable = true)
val udt = Utils.classForName(className)
.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
UDTRegistry.registerType(t, udt)
Schema(udt, nullable = true)
case t if UDTRegistry.udtRegistry.contains(t) =>
Schema(new UserDefinedType(UDTRegistry.udtRegistry(t)), nullable = true)
Schema(UDTRegistry.udtRegistry(t), nullable = true)
case t if t <:< typeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
Schema(schemaFor(optType).dataType, nullable = true)
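
The branch above is what lets an @SQLUserDefinedType-annotated class flow through schema inference. A hedged illustration only: MyLabeledPoint and MyDenseVectorUDT are the classes from the test file at the end of this diff, and the schemaFor overload taking a reflection Type is the one visible in this hunk; the expected output shape is an assumption.

import scala.reflect.runtime.universe.typeOf
import org.apache.spark.sql.catalyst.ScalaReflection

// Hypothetical: schemaFor should resolve the annotated features field of
// MyLabeledPoint (defined in the test suite below) to its UDT instead of
// failing on an unknown type.
val schema = ScalaReflection.schemaFor(typeOf[MyLabeledPoint])
// Expected shape (illustrative): a StructType with a "label" DoubleType field
// and a "features" field whose data type is a MyDenseVectorUDT instance.
println(schema.dataType)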
@@ -21,7 +21,7 @@ import scala.collection.mutable
import scala.reflect.runtime.universe._

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.types.UserDefinedTypeSerDes
import org.apache.spark.sql.catalyst.types.UserDefinedType

/**
* ::DeveloperApi::
@@ -32,14 +32,14 @@ import org.apache.spark.sql.catalyst.types.UserDefinedTypeSerDes
@DeveloperApi
object UDTRegistry {
/** Map: UserType --> UserDefinedType */
val udtRegistry = new mutable.HashMap[Any, UserDefinedTypeSerDes[_]]()
val udtRegistry = new mutable.HashMap[Any, UserDefinedType[_]]()

/**
* Register a user-defined type and its serializer, to allow automatic conversion between
* RDDs of user types and SchemaRDDs.
* If this type has already been registered, this does nothing.
*/
def registerType(userType: Type, udt: UserDefinedTypeSerDes[_]): Unit = {
def registerType(userType: Type, udt: UserDefinedType[_]): Unit = {
// TODO: Check to see if type is built-in. Throw exception?
UDTRegistry.udtRegistry(userType) = udt
}
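
Besides the annotation path, the registry above gives a programmatic way to attach a UDT to a type whose source cannot be annotated. A minimal hedged sketch, reusing the MyDenseVector / MyDenseVectorUDT pair defined in the test file at the end of this diff:

import scala.reflect.runtime.universe.typeOf
import org.apache.spark.sql.catalyst.UDTRegistry

// Hypothetical usage: register the UDT by hand so that schema inference can
// find it even without an @SQLUserDefinedType annotation on MyDenseVector.
UDTRegistry.registerType(typeOf[MyDenseVector], new MyDenseVectorUDT)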
@@ -20,7 +20,7 @@
import java.lang.annotation.*;

import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.sql.catalyst.types.UserDefinedTypeSerDes;
import org.apache.spark.sql.catalyst.types.UserDefinedType;

/**
* ::DeveloperApi::
@@ -38,5 +38,5 @@
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface SQLUserDefinedType {
Class<? extends UserDefinedTypeSerDes<?> > serdes();
Class<? extends UserDefinedType<?> > udt();
}
@@ -70,11 +70,9 @@ object DataType {
StructType(fields.map(parseStructField))

case JSortedObject(
("serdes", JString(serdesClass)),
("type", JString("udt"))) => {
val serdes = Class.forName(serdesClass).newInstance().asInstanceOf[UserDefinedTypeSerDes[_]]
new UserDefinedType(serdes)
}
("class", JString(udtClass)),
("type", JString("udt"))) =>
Class.forName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]]
}

private def parseStructField(json: JValue): StructField = json match {
@@ -580,7 +578,7 @@ case class MapType(
* The data type for User Defined Types (UDTs).
*
* This interface allows a user to make their own classes more interoperable with SparkSQL;
* e.g., by creating a [[UserDefinedTypeSerDes]] for a class X, it becomes possible to create a SchemaRDD
* e.g., by creating a [[UserDefinedType]] for a class X, it becomes possible to create a SchemaRDD
* which has class X in the schema.
*
* For SparkSQL to recognize UDTs, the UDT must be registered in
@@ -593,12 +591,12 @@ case class MapType(
* The conversion via `deserialize` occurs when reading from a `SchemaRDD`.
*/
@DeveloperApi
abstract class UserDefinedTypeSerDes[UserType] extends Serializable {
abstract class UserDefinedType[UserType] extends DataType with Serializable {

def userType: Class[UserType]
// def userType: Class[UserType]

/** Underlying storage type for this UDT used by SparkSQL */
def sqlType: DataType
def sqlType: StructType

/** Convert the user type to a Row object */
// TODO: Can we make this take obj: UserType? The issue is in ScalaReflection.convertToCatalyst,
@@ -607,12 +605,9 @@ abstract class UserDefinedTypeSerDes[UserType] extends Serializable {

/** Convert a Row object to the user type */
def deserialize(row: Row): UserType
}

case class UserDefinedType[UserType](serdes: UserDefinedTypeSerDes[UserType])
extends DataType with Serializable {
override private[sql] def jsonValue: JValue = {
("type" -> "udt") ~
("serdes" -> serdes.getClass.getName)
("class" -> this.getClass.getName)
}
}
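
To make the contract above concrete, here is a hedged, minimal example in the style of the MyDenseVectorUDT test case later in this diff. Temperature and TemperatureUDT are made-up names for illustration; the imports assume the catalyst packages used elsewhere in this changeset.

import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row}
import org.apache.spark.sql.catalyst.types._

// Hypothetical user class backed by a one-field struct.
@SQLUserDefinedType(udt = classOf[TemperatureUDT])
class Temperature(val celsius: Double) extends Serializable

class TemperatureUDT extends UserDefinedType[Temperature] {

  // Underlying SQL storage: a single non-nullable double column.
  override def sqlType: StructType = StructType(Seq(
    StructField("celsius", DoubleType, nullable = false)))

  // Called when a Temperature is stored into a SchemaRDD.
  override def serialize(obj: Any): Row = obj match {
    case t: Temperature =>
      val row = new GenericMutableRow(1)
      row.setDouble(0, t.celsius)
      row
  }

  // Called when the column is read back out of a SchemaRDD.
  override def deserialize(row: Row): Temperature =
    new Temperature(row.getDouble(0))
}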
@@ -99,8 +99,8 @@ private[sql] object CatalystConverter {
fieldIndex,
parent)
}
case UserDefinedType(serdes) => {
createConverter(field.copy(dataType = serdes.sqlType), fieldIndex, parent)
case udt: UserDefinedType[_] => {
createConverter(field.copy(dataType = udt.sqlType), fieldIndex, parent)
}
// Strings, Shorts and Bytes do not have a corresponding type in Parquet
// so we need to treat them separately
@@ -183,10 +183,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
case t @ StructType(_) => writeStruct(
t,
value.asInstanceOf[CatalystConverter.StructScalaType[_]])
case UserDefinedType(serdes) => {
println(value.getClass)
writeValue(serdes.sqlType, serdes.serialize(value))
}
case t: UserDefinedType[_] => writeValue(t.sqlType, value)
case _ => writePrimitive(schema.asInstanceOf[PrimitiveType], value)
}
}
@@ -337,8 +337,8 @@ private[parquet] object ParquetTypesConverter extends Logging {
parquetKeyType,
parquetValueType)
}
case UserDefinedType(serdes) => {
fromDataType(serdes.sqlType, name, nullable, inArray)
case udt: UserDefinedType[_] => {
fromDataType(udt.sqlType, name, nullable, inArray)
}
case _ => sys.error(s"Unsupported datatype $ctype")
}
@@ -18,13 +18,12 @@
package org.apache.spark.sql

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.UDTRegistry
import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.types.UserDefinedTypeSerDes
import org.apache.spark.sql.catalyst.types.UserDefinedType
import org.apache.spark.sql.test.TestSQLContext._

@SQLUserDefinedType(serdes = classOf[MyDenseVectorUDT])
@SQLUserDefinedType(udt = classOf[MyDenseVectorUDT])
class MyDenseVector(val data: Array[Double]) extends Serializable {
override def equals(other: Any): Boolean = other match {
case v: MyDenseVector =>
@@ -35,31 +34,21 @@ class MyDenseVector(val data: Array[Double]) extends Serializable {

case class MyLabeledPoint(label: Double, features: MyDenseVector)

class MyDenseVectorUDT extends UserDefinedTypeSerDes[MyDenseVector] {
class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {

override def userType: Class[MyDenseVector] = classOf[MyDenseVector]

override def sqlType: ArrayType = ArrayType(DoubleType, containsNull = false)
override def sqlType: StructType = StructType(Seq(
StructField("data", ArrayType(DoubleType, containsNull = false), nullable = false)))

override def serialize(obj: Any): Row = obj match {
case features: MyDenseVector =>
val row: GenericMutableRow = new GenericMutableRow(features.data.length)
var i = 0
while (i < features.data.length) {
row.setDouble(i, features.data(i))
i += 1
}
val row: GenericMutableRow = new GenericMutableRow(1)
row.update(0, features.data.toSeq)
row
}

override def deserialize(row: Row): MyDenseVector = {
val features = new MyDenseVector(new Array[Double](row.length))
var i = 0
while (i < row.length) {
features.data(i) = row.getDouble(i)
i += 1
}
features
val features = row.getAs[Seq[Double]](0).toArray
new MyDenseVector(features)
}
}

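
The test bodies themselves are not included in this diff. Purely as a hedged usage sketch of the era's API — assuming the createSchemaRDD implicit and sql method pulled in by the TestSQLContext._ import at the top of this file, and SchemaRDD.registerTempTable — a UDT-backed case class would be exercised roughly like this:

// Hypothetical usage sketch; not the omitted test code.
val points = sparkContext.parallelize(Seq(
  MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))),
  MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))))

// Converting the RDD of case classes runs MyDenseVector values through
// MyDenseVectorUDT.serialize.
val schemaRDD = createSchemaRDD(points)
schemaRDD.registerTempTable("points")

// Reading the column back goes through MyDenseVectorUDT.deserialize.
val features = sql("SELECT features FROM points").collect()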