From cd60cb48d36142a152cd02a263212f5c041e6c23 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Tue, 21 Oct 2014 11:57:14 -0700
Subject: [PATCH] Trying to get other SQL tests to run

---
 .../spark/sql/catalyst/ScalaReflection.scala  | 110 ++++++++++--------
 .../spark/sql/catalyst/UDTRegistry.scala      |   2 +-
 .../annotation/SQLUserDefinedType.java        |   5 +
 .../org/apache/spark/sql/SQLQuerySuite.scala  |  40 ++++---
 4 files changed, 88 insertions(+), 69 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 8989dafa2b2a5..75eb3b4872475 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst
 
 import java.sql.{Date, Timestamp}
 
+import org.apache.spark.util.Utils
 import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
 import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference, Row}
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
@@ -86,56 +87,65 @@ object ScalaReflection {
   def schemaFor[T: TypeTag]: Schema = schemaFor(typeOf[T])
 
   /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
-  def schemaFor(tpe: `Type`): Schema = tpe match {
-    case t if getClass.getClassLoader.loadClass(t.typeSymbol.asClass.fullName)
-      .isAnnotationPresent(classOf[SQLUserDefinedType]) =>
-      UDTRegistry.registerType(t)
-      Schema(UDTRegistry.udtRegistry(t), nullable = true)
-    case t if t <:< typeOf[Option[_]] =>
-      val TypeRef(_, _, Seq(optType)) = t
-      Schema(schemaFor(optType).dataType, nullable = true)
-    case t if t <:< typeOf[Product] =>
-      val formalTypeArgs = t.typeSymbol.asClass.typeParams
-      val TypeRef(_, _, actualTypeArgs) = t
-      val params = t.member(nme.CONSTRUCTOR).asMethod.paramss
-      Schema(StructType(
-        params.head.map { p =>
-          val Schema(dataType, nullable) =
-            schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs))
-          StructField(p.name.toString, dataType, nullable)
-        }), nullable = true)
-    // Need to decide if we actually need a special type here.
-    case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
-    case t if t <:< typeOf[Array[_]] =>
-      sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
-    case t if t <:< typeOf[Seq[_]] =>
-      val TypeRef(_, _, Seq(elementType)) = t
-      val Schema(dataType, nullable) = schemaFor(elementType)
-      Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
-    case t if t <:< typeOf[Map[_, _]] =>
-      val TypeRef(_, _, Seq(keyType, valueType)) = t
-      val Schema(valueDataType, valueNullable) = schemaFor(valueType)
-      Schema(MapType(schemaFor(keyType).dataType,
-        valueDataType, valueContainsNull = valueNullable), nullable = true)
-    case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
-    case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
-    case t if t <:< typeOf[Date] => Schema(DateType, nullable = true)
-    case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
-    case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true)
-    case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
-    case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
-    case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
-    case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true)
-    case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true)
-    case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true)
-    case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true)
-    case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false)
-    case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false)
-    case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false)
-    case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false)
-    case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
-    case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
-    case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
+  def schemaFor(tpe: `Type`): Schema = {
+    val className: String = tpe.erasure.typeSymbol.asClass.fullName
+    tpe match {
+      case t if Utils.classIsLoadable(className) &&
+        Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) =>
+        // Note: We check for classIsLoadable above since Utils.classForName uses Java reflection,
+        // whereas className is from Scala reflection. This can make it hard to find classes
+        // in some cases, such as when a class is enclosed in an object (in which case
+        // Java appends a '$' to the object name but Scala does not).
+        UDTRegistry.registerType(t)
+        Schema(UDTRegistry.udtRegistry(t), nullable = true)
+      case t if UDTRegistry.udtRegistry.contains(t) =>
+        Schema(UDTRegistry.udtRegistry(t), nullable = true)
+      case t if t <:< typeOf[Option[_]] =>
+        val TypeRef(_, _, Seq(optType)) = t
+        Schema(schemaFor(optType).dataType, nullable = true)
+      case t if t <:< typeOf[Product] =>
+        val formalTypeArgs = t.typeSymbol.asClass.typeParams
+        val TypeRef(_, _, actualTypeArgs) = t
+        val params = t.member(nme.CONSTRUCTOR).asMethod.paramss
+        Schema(StructType(
+          params.head.map { p =>
+            val Schema(dataType, nullable) =
+              schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs))
+            StructField(p.name.toString, dataType, nullable)
+          }), nullable = true)
+      // Need to decide if we actually need a special type here.
+      case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
+      case t if t <:< typeOf[Array[_]] =>
+        sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
+      case t if t <:< typeOf[Seq[_]] =>
+        val TypeRef(_, _, Seq(elementType)) = t
+        val Schema(dataType, nullable) = schemaFor(elementType)
+        Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
+      case t if t <:< typeOf[Map[_, _]] =>
+        val TypeRef(_, _, Seq(keyType, valueType)) = t
+        val Schema(valueDataType, valueNullable) = schemaFor(valueType)
+        Schema(MapType(schemaFor(keyType).dataType,
+          valueDataType, valueContainsNull = valueNullable), nullable = true)
+      case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
+      case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
+      case t if t <:< typeOf[Date] => Schema(DateType, nullable = true)
+      case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
+      case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true)
+      case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
+      case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
+      case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
+      case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true)
+      case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true)
+      case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true)
+      case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true)
+      case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false)
+      case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false)
+      case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false)
+      case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false)
+      case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
+      case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
+      case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
+    }
   }
 
   def typeOfObject: PartialFunction[Any, DataType] = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala
index b82f0b5f3eb02..a9be187ded96e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/UDTRegistry.scala
@@ -35,7 +35,7 @@ private[sql] object UDTRegistry {
    * RDDs of user types and SchemaRDDs.
    * If this type has already been registered, this does nothing.
    */
-  def registerType[UserType](implicit userType: Type): Unit = {
+  def registerType(userType: Type): Unit = {
     // TODO: Check to see if type is built-in. Throw exception?
     if (!UDTRegistry.udtRegistry.contains(userType)) {
       val udt =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java
index 1ecb0ac00bb09..fa909a9eb1b3b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java
@@ -24,6 +24,11 @@
 
 /**
  * A user-defined type which can be automatically recognized by a SQLContext and registered.
+ *
+ * WARNING: This annotation will only work if both Java and Scala reflection return the same class
+ * names (after erasure) for the UDT. This will NOT be the case when, e.g., the UDT class
+ * is enclosed in an object (a singleton). In these cases, the UDT must be registered
+ * manually.
  */
 // TODO: Should I used @Documented ?
 @DeveloperApi
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 6befe1b755cc6..73dac52452f23 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -737,28 +737,32 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
   }
 
   test("throw errors for non-aggregate attributes with aggregation") {
-    def checkAggregation(query: String, isInvalidQuery: Boolean = true) {
-      val logicalPlan = sql(query).queryExecution.logical
-
-      if (isInvalidQuery) {
-        val e = intercept[TreeNodeException[LogicalPlan]](sql(query).queryExecution.analyzed)
-        assert(
-          e.getMessage.startsWith("Expression not in GROUP BY"),
-          "Non-aggregate attribute(s) not detected\n" + logicalPlan)
-      } else {
-        // Should not throw
-        sql(query).queryExecution.analyzed
+    try {
+      def checkAggregation(query: String, isInvalidQuery: Boolean = true) {
+        val logicalPlan = sql(query).queryExecution.logical
+
+        if (isInvalidQuery) {
+          val e = intercept[TreeNodeException[LogicalPlan]](sql(query).queryExecution.analyzed)
+          assert(
+            e.getMessage.startsWith("Expression not in GROUP BY"),
+            "Non-aggregate attribute(s) not detected\n" + logicalPlan)
+        } else {
+          // Should not throw
+          sql(query).queryExecution.analyzed
+        }
       }
-    }
 
-    checkAggregation("SELECT key, COUNT(*) FROM testData")
-    checkAggregation("SELECT COUNT(key), COUNT(*) FROM testData", false)
+      checkAggregation("SELECT key, COUNT(*) FROM testData")
+      checkAggregation("SELECT COUNT(key), COUNT(*) FROM testData", false)
 
-    checkAggregation("SELECT value, COUNT(*) FROM testData GROUP BY key")
-    checkAggregation("SELECT COUNT(value), SUM(key) FROM testData GROUP BY key", false)
+      checkAggregation("SELECT value, COUNT(*) FROM testData GROUP BY key")
+      checkAggregation("SELECT COUNT(value), SUM(key) FROM testData GROUP BY key", false)
 
-    checkAggregation("SELECT key + 2, COUNT(*) FROM testData GROUP BY key + 1")
-    checkAggregation("SELECT key + 1 + 1, COUNT(*) FROM testData GROUP BY key + 1", false)
+      checkAggregation("SELECT key + 2, COUNT(*) FROM testData GROUP BY key + 1")
+      checkAggregation("SELECT key + 1 + 1, COUNT(*) FROM testData GROUP BY key + 1", false)
+    } catch {
+      case e: Exception => println(e.getStackTraceString)
+    }
   }
 
   test("Test to check we can use Long.MinValue") {