Skip to content

Commit

Permalink
Trying to get other SQL tests to run
Browse files Browse the repository at this point in the history
  • Loading branch information
jkbradley committed Nov 2, 2014
1 parent 34a5831 commit cd60cb4
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst

import java.sql.{Date, Timestamp}

import org.apache.spark.util.Utils
import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference, Row}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
Expand Down Expand Up @@ -86,56 +87,65 @@ object ScalaReflection {
def schemaFor[T: TypeTag]: Schema = schemaFor(typeOf[T])

/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
def schemaFor(tpe: `Type`): Schema = tpe match {
case t if getClass.getClassLoader.loadClass(t.typeSymbol.asClass.fullName)
.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
UDTRegistry.registerType(t)
Schema(UDTRegistry.udtRegistry(t), nullable = true)
case t if t <:< typeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
Schema(schemaFor(optType).dataType, nullable = true)
case t if t <:< typeOf[Product] =>
val formalTypeArgs = t.typeSymbol.asClass.typeParams
val TypeRef(_, _, actualTypeArgs) = t
val params = t.member(nme.CONSTRUCTOR).asMethod.paramss
Schema(StructType(
params.head.map { p =>
val Schema(dataType, nullable) =
schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs))
StructField(p.name.toString, dataType, nullable)
}), nullable = true)
// Need to decide if we actually need a special type here.
case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
case t if t <:< typeOf[Array[_]] =>
sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
case t if t <:< typeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, nullable) = schemaFor(elementType)
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
case t if t <:< typeOf[Map[_, _]] =>
val TypeRef(_, _, Seq(keyType, valueType)) = t
val Schema(valueDataType, valueNullable) = schemaFor(valueType)
Schema(MapType(schemaFor(keyType).dataType,
valueDataType, valueContainsNull = valueNullable), nullable = true)
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
case t if t <:< typeOf[Date] => Schema(DateType, nullable = true)
case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true)
case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true)
case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true)
case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true)
case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false)
case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false)
case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false)
case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false)
case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
def schemaFor(tpe: `Type`): Schema = {
val className: String = tpe.erasure.typeSymbol.asClass.fullName
tpe match {
case t if Utils.classIsLoadable(className) &&
Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) =>
// Note: We check for classIsLoadable above since Utils.classForName uses Java reflection,
// whereas className is from Scala reflection. This can make it hard to find classes
// in some cases, such as when a class is enclosed in an object (in which case
// Java appends a '$' to the object name but Scala does not).
UDTRegistry.registerType(t)
Schema(UDTRegistry.udtRegistry(t), nullable = true)
case t if UDTRegistry.udtRegistry.contains(t) =>
Schema(UDTRegistry.udtRegistry(t), nullable = true)
case t if t <:< typeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
Schema(schemaFor(optType).dataType, nullable = true)
case t if t <:< typeOf[Product] =>
val formalTypeArgs = t.typeSymbol.asClass.typeParams
val TypeRef(_, _, actualTypeArgs) = t
val params = t.member(nme.CONSTRUCTOR).asMethod.paramss
Schema(StructType(
params.head.map { p =>
val Schema(dataType, nullable) =
schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs))
StructField(p.name.toString, dataType, nullable)
}), nullable = true)
// Need to decide if we actually need a special type here.
case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
case t if t <:< typeOf[Array[_]] =>
sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
case t if t <:< typeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, nullable) = schemaFor(elementType)
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
case t if t <:< typeOf[Map[_, _]] =>
val TypeRef(_, _, Seq(keyType, valueType)) = t
val Schema(valueDataType, valueNullable) = schemaFor(valueType)
Schema(MapType(schemaFor(keyType).dataType,
valueDataType, valueContainsNull = valueNullable), nullable = true)
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
case t if t <:< typeOf[Date] => Schema(DateType, nullable = true)
case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true)
case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true)
case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true)
case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true)
case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false)
case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false)
case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false)
case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false)
case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
}
}

def typeOfObject: PartialFunction[Any, DataType] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ private[sql] object UDTRegistry {
* RDDs of user types and SchemaRDDs.
* If this type has already been registered, this does nothing.
*/
def registerType[UserType](implicit userType: Type): Unit = {
def registerType(userType: Type): Unit = {
// TODO: Check to see if type is built-in. Throw exception?
if (!UDTRegistry.udtRegistry.contains(userType)) {
val udt =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@

/**
* A user-defined type which can be automatically recognized by a SQLContext and registered.
*
* WARNING: This annotation will only work if both Java and Scala reflection return the same class
* names (after erasure) for the UDT. This will NOT be the case when, e.g., the UDT class
* is enclosed in an object (a singleton). In these cases, the UDT must be registered
* manually.
*/
// TODO: Should I used @Documented ?
@DeveloperApi
Expand Down
40 changes: 22 additions & 18 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -737,28 +737,32 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
}

test("throw errors for non-aggregate attributes with aggregation") {
def checkAggregation(query: String, isInvalidQuery: Boolean = true) {
val logicalPlan = sql(query).queryExecution.logical

if (isInvalidQuery) {
val e = intercept[TreeNodeException[LogicalPlan]](sql(query).queryExecution.analyzed)
assert(
e.getMessage.startsWith("Expression not in GROUP BY"),
"Non-aggregate attribute(s) not detected\n" + logicalPlan)
} else {
// Should not throw
sql(query).queryExecution.analyzed
try {
def checkAggregation(query: String, isInvalidQuery: Boolean = true) {
val logicalPlan = sql(query).queryExecution.logical

if (isInvalidQuery) {
val e = intercept[TreeNodeException[LogicalPlan]](sql(query).queryExecution.analyzed)
assert(
e.getMessage.startsWith("Expression not in GROUP BY"),
"Non-aggregate attribute(s) not detected\n" + logicalPlan)
} else {
// Should not throw
sql(query).queryExecution.analyzed
}
}
}

checkAggregation("SELECT key, COUNT(*) FROM testData")
checkAggregation("SELECT COUNT(key), COUNT(*) FROM testData", false)
checkAggregation("SELECT key, COUNT(*) FROM testData")
checkAggregation("SELECT COUNT(key), COUNT(*) FROM testData", false)

checkAggregation("SELECT value, COUNT(*) FROM testData GROUP BY key")
checkAggregation("SELECT COUNT(value), SUM(key) FROM testData GROUP BY key", false)
checkAggregation("SELECT value, COUNT(*) FROM testData GROUP BY key")
checkAggregation("SELECT COUNT(value), SUM(key) FROM testData GROUP BY key", false)

checkAggregation("SELECT key + 2, COUNT(*) FROM testData GROUP BY key + 1")
checkAggregation("SELECT key + 1 + 1, COUNT(*) FROM testData GROUP BY key + 1", false)
checkAggregation("SELECT key + 2, COUNT(*) FROM testData GROUP BY key + 1")
checkAggregation("SELECT key + 1 + 1, COUNT(*) FROM testData GROUP BY key + 1", false)
} catch {
case e: Exception => println(e.getStackTraceString)
}
}

test("Test to check we can use Long.MinValue") {
Expand Down

0 comments on commit cd60cb4

Please sign in to comment.