
Commit

renamed UDT types
jkbradley committed Nov 2, 2014
1 parent 3579035 commit 2f40c02
Showing 5 changed files with 65 additions and 78 deletions.
File 1 of 5: ScalaReflection (org.apache.spark.sql.catalyst)
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst

import java.sql.{Date, Timestamp}

import org.apache.spark.sql.catalyst.annotation.UserDefinedType
import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference, Row}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.types._
@@ -53,7 +53,7 @@ object ScalaReflection {
convertToCatalyst(elem, field.dataType)
}.toArray)
case (d: BigDecimal, _) => Decimal(d)
case (udt, udtType: UserDefinedTypeType[_]) => udtType.serialize(udt)
case (udt, udtType: UserDefinedType[_]) => udtType.serialize(udt)
case (other, _) => other
}

@@ -64,7 +64,7 @@ object ScalaReflection {
convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType)
}
case (d: Decimal, DecimalType) => d.toBigDecimal
case (udt: Row, udtType: UserDefinedTypeType[_]) => udtType.deserialize(udt)
case (udt: Row, udtType: UserDefinedType[_]) => udtType.deserialize(udt)
case (other, _) => other
}

@@ -86,58 +86,56 @@ object ScalaReflection {
def schemaFor[T: TypeTag]: Schema = schemaFor(typeOf[T])

/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
def schemaFor(tpe: `Type`): Schema = {
tpe match {
case t if t <:< typeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
Schema(schemaFor(optType).dataType, nullable = true)
case t if t <:< typeOf[Product] =>
val formalTypeArgs = t.typeSymbol.asClass.typeParams
val TypeRef(_, _, actualTypeArgs) = t
val params = t.member(nme.CONSTRUCTOR).asMethod.paramss
Schema(StructType(
params.head.map { p =>
val Schema(dataType, nullable) =
schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs))
StructField(p.name.toString, dataType, nullable)
}), nullable = true)
// Need to decide if we actually need a special type here.
case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
case t if t <:< typeOf[Array[_]] =>
sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
case t if t <:< typeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, nullable) = schemaFor(elementType)
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
case t if t <:< typeOf[Map[_, _]] =>
val TypeRef(_, _, Seq(keyType, valueType)) = t
val Schema(valueDataType, valueNullable) = schemaFor(valueType)
Schema(MapType(schemaFor(keyType).dataType,
valueDataType, valueContainsNull = valueNullable), nullable = true)
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
case t if t <:< typeOf[Date] => Schema(DateType, nullable = true)
case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true)
case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true)
case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true)
case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true)
case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false)
case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false)
case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false)
case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false)
case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
case t if getClass.getClassLoader.loadClass(t.typeSymbol.asClass.fullName)
.isAnnotationPresent(classOf[UserDefinedType]) =>
UDTRegistry.registerType(t)
Schema(UDTRegistry.udtRegistry(t), nullable = true)
}
def schemaFor(tpe: `Type`): Schema = tpe match {
case t if getClass.getClassLoader.loadClass(t.typeSymbol.asClass.fullName)
.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
UDTRegistry.registerType(t)
Schema(UDTRegistry.udtRegistry(t), nullable = true)
case t if t <:< typeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
Schema(schemaFor(optType).dataType, nullable = true)
case t if t <:< typeOf[Product] =>
val formalTypeArgs = t.typeSymbol.asClass.typeParams
val TypeRef(_, _, actualTypeArgs) = t
val params = t.member(nme.CONSTRUCTOR).asMethod.paramss
Schema(StructType(
params.head.map { p =>
val Schema(dataType, nullable) =
schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs))
StructField(p.name.toString, dataType, nullable)
}), nullable = true)
// Need to decide if we actually need a special type here.
case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
case t if t <:< typeOf[Array[_]] =>
sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
case t if t <:< typeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, nullable) = schemaFor(elementType)
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
case t if t <:< typeOf[Map[_, _]] =>
val TypeRef(_, _, Seq(keyType, valueType)) = t
val Schema(valueDataType, valueNullable) = schemaFor(valueType)
Schema(MapType(schemaFor(keyType).dataType,
valueDataType, valueContainsNull = valueNullable), nullable = true)
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
case t if t <:< typeOf[Date] => Schema(DateType, nullable = true)
case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true)
case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true)
case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true)
case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true)
case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false)
case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false)
case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false)
case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false)
case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
}

def typeOfObject: PartialFunction[Any, DataType] = {
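For context, the rewritten schemaFor now checks for the SQLUserDefinedType annotation before any of the built-in cases. A minimal sketch of hitting that branch, assuming a hypothetical Point class and PointUDT (illustrative names, not part of this commit; they are sketched under the annotation and data-type files below):

import org.apache.spark.sql.catalyst.ScalaReflection

// Point is annotated with @SQLUserDefinedType(udt = classOf[PointUDT]).
val schema = ScalaReflection.schemaFor[Point]
// schema.dataType is the PointUDT instance that registerType placed in UDTRegistry;
// schema.nullable is true, matching the Schema(..., nullable = true) case above.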
File 2 of 5: UDTRegistry (org.apache.spark.sql.catalyst)
@@ -17,11 +17,11 @@

package org.apache.spark.sql.catalyst

import org.apache.spark.sql.catalyst.annotation.UserDefinedType
import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType

import scala.collection.mutable

import org.apache.spark.sql.catalyst.types.UserDefinedTypeType
import org.apache.spark.sql.catalyst.types.UserDefinedType

import scala.reflect.runtime.universe._

@@ -30,7 +30,7 @@ import scala.reflect.runtime.universe._
*/
private[sql] object UDTRegistry {
/** Map: UserType --> UserDefinedType */
val udtRegistry = new mutable.HashMap[Any, UserDefinedTypeType[_]]()
val udtRegistry = new mutable.HashMap[Any, UserDefinedType[_]]()

/**
* Register a user-defined type and its serializer, to allow automatic conversion between
@@ -42,7 +42,7 @@ private[sql] object UDTRegistry {
if (!UDTRegistry.udtRegistry.contains(userType)) {
val udt =
getClass.getClassLoader.loadClass(userType.typeSymbol.asClass.fullName)
.getAnnotation(classOf[UserDefinedType]).udt().newInstance()
.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
UDTRegistry.udtRegistry(userType) = udt
}
// TODO: Else: Should we check (assert) that udt is the same as what is in the registry?
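A small sketch of the registry flow after the rename, again with the hypothetical Point class; UDTRegistry is private[sql], so this is illustrative rather than user-facing API:

import scala.reflect.runtime.universe._

// Reads the @SQLUserDefinedType annotation on Point, instantiates its udt() once,
// and caches the instance keyed by the reflection Type.
UDTRegistry.registerType(typeOf[Point])
val udt = UDTRegistry.udtRegistry(typeOf[Point])  // a UserDefinedType[_], here a PointUDT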
File 3 of 5: the SQLUserDefinedType annotation (org.apache.spark.sql.catalyst.annotation), renamed from UserDefinedType
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.annotation;

import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.sql.catalyst.types.UserDefinedTypeType;
import org.apache.spark.sql.catalyst.types.UserDefinedType;

import java.lang.annotation.*;

@@ -29,6 +29,6 @@
@DeveloperApi
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface UserDefinedType {
Class<? extends UserDefinedTypeType<?> > udt();
public @interface SQLUserDefinedType {
Class<? extends UserDefinedType<?> > udt();
}
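With the annotation renamed, a user class opts in exactly the way the DenseVector test class below does. A hypothetical sketch (Point and PointUDT are illustrative names, not part of this commit):

import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType

@SQLUserDefinedType(udt = classOf[PointUDT])
class Point(val x: Double, val y: Double) extends Serializable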
File 4 of 5: catalyst data types (org.apache.spark.sql.catalyst.types), where UserDefinedTypeType becomes UserDefinedType
@@ -572,7 +572,7 @@ case class MapType(
* The data type for User Defined Types.
*/
@DeveloperApi
abstract class UserDefinedTypeType[UserType] extends DataType with Serializable {
abstract class UserDefinedType[UserType] extends DataType with Serializable {

/** Underlying storage type for this UDT used by SparkSQL */
def sqlType: DataType
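A sketch of a concrete subclass after the rename. Only sqlType is visible in this hunk; the serialize and deserialize signatures below are assumptions inferred from the calls in ScalaReflection (serialize is applied to the user object, deserialize to a Row), so treat them as a sketch rather than the exact contract:

import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row}
import org.apache.spark.sql.catalyst.types._

class PointUDT extends UserDefinedType[Point] {

  // Underlying storage: a struct of the two coordinates.
  override def sqlType: DataType = StructType(Seq(
    StructField("x", DoubleType, nullable = false),
    StructField("y", DoubleType, nullable = false)))

  // Assumed signature: convert the user object into its SQL representation.
  def serialize(obj: Any): Row = obj match {
    case p: Point =>
      val row = new GenericMutableRow(2)
      row.setDouble(0, p.x)
      row.setDouble(1, p.y)
      row
  }

  // Assumed signature: convert the SQL representation back into the user object.
  def deserialize(datum: Any): Point = datum match {
    case row: Row => new Point(row.getDouble(0), row.getDouble(1))
  }
}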
File 5 of 5: UserDefinedTypeSuite (org.apache.spark.sql)
@@ -18,13 +18,12 @@
package org.apache.spark.sql

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.annotation.UserDefinedType
import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.types.UserDefinedTypeType
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.catalyst.types.UserDefinedType
import org.apache.spark.sql.test.TestSQLContext._

@UserDefinedType(udt = classOf[DenseVectorUDT])
@SQLUserDefinedType(udt = classOf[DenseVectorUDT])
class DenseVector(val data: Array[Double]) extends Serializable {
override def equals(other: Any): Boolean = other match {
case v: DenseVector =>
@@ -35,7 +34,7 @@ class DenseVector(val data: Array[Double]) extends Serializable {

case class LabeledPoint(label: Double, features: DenseVector)

class DenseVectorUDT extends UserDefinedTypeType[DenseVector] {
class DenseVectorUDT extends UserDefinedType[DenseVector] {

override def sqlType: ArrayType = ArrayType(DoubleType, containsNull = false)

@@ -83,14 +82,4 @@ class UserDefinedTypeSuite extends QueryTest {
assert(featuresArrays.contains(new DenseVector(Array(0.2, 2.0))))
}

/*
test("UDTs can be registered twice, overriding previous registration") {
// TODO
}
test("UDTs cannot override built-in types") {
// TODO
}
*/

}
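For completeness, a hedged sketch of how the annotated classes above might be exercised end to end through TestSQLContext; the suite's body is partly collapsed in this view, so this only mirrors the visible assertion rather than quoting the test:

val points = sparkContext.parallelize(Seq(
  LabeledPoint(1.0, new DenseVector(Array(0.1, 1.0))),
  LabeledPoint(0.0, new DenseVector(Array(0.2, 2.0)))))
// The implicit conversion from TestSQLContext._ builds a SchemaRDD, picking up
// DenseVectorUDT through the @SQLUserDefinedType annotation on DenseVector.
val featuresArrays = points.select('features).collect().map(_(0).asInstanceOf[DenseVector])
assert(featuresArrays.contains(new DenseVector(Array(0.2, 2.0))))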
