diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 18adec12c6c07..3814108471a9d 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -39,7 +39,6 @@
 from array import array
 from operator import itemgetter
 from itertools import imap
-import importlib
 
 from py4j.protocol import Py4JError
 from py4j.java_collections import ListConverter, MapConverter
@@ -416,25 +415,15 @@ class UserDefinedType(DataType):
     """
 
     @classmethod
-    def sqlType(self):
-        """
-        Underlying SQL storage type for this UDT.
-        """
-        raise NotImplementedError("UDT must implement sqlType().")
-
-    @classmethod
-    def serialize(self, obj):
-        """
-        Converts the a user-type object into a SQL datum.
-        """
-        raise NotImplementedError("UDT must implement serialize().")
+    def typeName(cls):
+        return cls.__name__.lower()
 
     @classmethod
-    def deserialize(self, datum):
+    def sqlType(cls):
         """
-        Converts a SQL datum into a user-type object.
+        Underlying SQL storage type for this UDT.
         """
-        raise NotImplementedError("UDT must implement deserialize().")
+        raise NotImplementedError("UDT must implement sqlType().")
 
     @classmethod
     def module(cls):
@@ -450,25 +439,35 @@ def scalaUDT(cls):
         """
         raise NotImplementedError("UDT must have a paired Scala UDT.")
 
-    @classmethod
-    def json(cls):
-        return json.dumps(cls.jsonValue(), separators=(',', ':'), sort_keys=True)
+    def serialize(self, obj):
+        """
+        Converts the a user-type object into a SQL datum.
+        """
+        raise NotImplementedError("UDT must implement serialize().")
 
-    @classmethod
-    def jsonValue(cls):
+    def deserialize(self, datum):
+        """
+        Converts a SQL datum into a user-type object.
+        """
+        raise NotImplementedError("UDT must implement deserialize().")
+
+    def json(self):
+        return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True)
+
+    def jsonValue(self):
         schema = {
             "type": "udt",
-            "pyModule": cls.module(),
-            "pyClass": cls.__name__}
-        if cls.scalaUDT() is not None:
-            schema['class'] = cls.scalaUDT()
+            "pyModule": self.module(),
+            "pyClass": type(self).__name__,
+            "class": self.scalaUDT()
+        }
         return schema
 
     @classmethod
     def fromJson(cls, json):
         pyModule = json['pyModule']
         pyClass = json['pyClass']
-        m = importlib.import_module(pyModule)
+        m = __import__(pyModule, globals(), locals(), [pyClass], -1)
         UDT = getattr(m, pyClass)
         return UDT()
 
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 3368881755fc1..f272d8f995e82 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -689,12 +689,6 @@ class ExamplePointUDT(UserDefinedType):
     def sqlType(self):
         return ArrayType(DoubleType(), False)
 
-    def serialize(self, obj):
-        return [obj.x, obj.y]
-
-    def deserialize(self, datum):
-        return ExamplePoint(datum[0], datum[1])
-
     @classmethod
     def module(cls):
         return 'pyspark.tests'
@@ -703,6 +697,12 @@ def module(cls):
     def scalaUDT(cls):
         return 'org.apache.spark.sql.test.ExamplePointUDT'
 
+    def serialize(self, obj):
+        return [obj.x, obj.y]
+
+    def deserialize(self, datum):
+        return ExamplePoint(datum[0], datum[1])
+
 
 class ExamplePoint:
     """
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index bd8d1fc8ce775..83d18daf7a259 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.{Optimizer, DefaultOptimizer}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.catalyst.types.UserDefinedType
 import org.apache.spark.sql.execution.{SparkStrategies, _}
 import org.apache.spark.sql.json._
 import org.apache.spark.sql.parquet.ParquetRelation
@@ -483,7 +484,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
       case ArrayType(_, _) => true
       case MapType(_, _, _) => true
       case StructType(_) => true
-      case udt: UserDefinedType[_] => true
+      case _: UserDefinedType[_] => true
       case other => false
     }
 
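
Note on the reshaped Python API: after this patch, serialize, deserialize, json, and jsonValue are instance methods of UserDefinedType, while typeName, sqlType, module, scalaUDT, and fromJson remain classmethods, and jsonValue now always emits the "class" key. The sketch below shows what a UDT written against this API could look like. It is illustrative only: the Point/PointUDT names and the Scala class name org.example.PointUDT are invented for the example, and it assumes a Python 2 PySpark build that contains this change (the __import__(..., -1) call in fromJson is Python 2 specific).

# point_udt.py -- hypothetical example, not part of the patch above.
from pyspark.sql import UserDefinedType, ArrayType, DoubleType


class Point(object):
    """Plain Python value object to be carried through Spark SQL as a UDT."""
    def __init__(self, x, y):
        self.x, self.y = x, y


class PointUDT(UserDefinedType):
    """Pairs Point with an ARRAY<DOUBLE> storage type (sketch only)."""

    @classmethod
    def sqlType(cls):
        # Underlying SQL storage type for this UDT (still a classmethod).
        return ArrayType(DoubleType(), False)

    @classmethod
    def module(cls):
        # Module that fromJson() re-imports to find this class; '__main__'
        # works when this file is run directly as a script.
        return '__main__'

    @classmethod
    def scalaUDT(cls):
        # Fully qualified name of the paired Scala UDT (assumed to exist).
        return 'org.example.PointUDT'

    def serialize(self, obj):
        # Instance method after this patch: user object -> SQL datum.
        return [obj.x, obj.y]

    def deserialize(self, datum):
        # SQL datum -> user object.
        return Point(datum[0], datum[1])


if __name__ == '__main__':
    udt = PointUDT()
    print udt.typeName()   # 'pointudt', from the new typeName() default
    print udt.json()       # json()/jsonValue() are instance calls now
    # Round-trip through the JSON schema form; fromJson() re-imports
    # module()/pyClass with __import__ and instantiates the UDT.
    assert isinstance(UserDefinedType.fromJson(udt.jsonValue()), PointUDT)

Wiring Point into Spark's schema inference and providing the actual Scala counterpart are outside the scope of this sketch; it only illustrates that, with json/jsonValue as instance methods, a UDT instance can be dumped to its JSON schema form and rebuilt through fromJson.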