[SPARK-21365][PYTHON] Deduplicate logics parsing DDL type/schema definition

## What changes were proposed in this pull request?

This PR addresses the following four points:

- Reuse the existing DDL parser APIs rather than reimplementing them within PySpark.

- Support DDL-formatted strings of the form `field type, field type` (see the UDF sketch after the examples below).

- Support case-insensitive parsing.

- Support nested data types, as shown below:

  **Before**
  ```
  >>> spark.createDataFrame([[[1]]], "struct<a: struct<b: int>>").show()
  ...
  ValueError: The strcut field string format is: 'field_name:field_type', but got: a: struct<b: int>
  ```

  ```
  >>> spark.createDataFrame([[[1]]], "a: struct<b: int>").show()
  ...
  ValueError: The strcut field string format is: 'field_name:field_type', but got: a: struct<b: int>
  ```

  ```
  >>> spark.createDataFrame([[1]], "a int").show()
  ...
  ValueError: Could not parse datatype: a int
  ```

  **After**
  ```
  >>> spark.createDataFrame([[[1]]], "struct<a: struct<b: int>>").show()
  +---+
  |  a|
  +---+
  |[1]|
  +---+
  ```

  ```
  >>> spark.createDataFrame([[[1]]], "a: struct<b: int>").show()
  +---+
  |  a|
  +---+
  |[1]|
  +---+
  ```

  ```
  >>> spark.createDataFrame([[1]], "a int").show()
  +---+
  |  a|
  +---+
  |  1|
  +---+
  ```
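The same parsing path also backs UDF return type strings (see the `functions.py` change below). A minimal usage sketch, assuming a Spark build with this patch and an active `SparkSession` named `spark`:

```python
from pyspark.sql.functions import udf

# A DDL-style, case-insensitive type string as the UDF return type.
slen = udf(lambda s: len(s), "INT")
spark.createDataFrame([("abc",)], "s: string").select(slen("s")).show()
```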

## How was this patch tested?

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #18590 from HyukjinKwon/deduplicate-python-ddl.
HyukjinKwon authored and cloud-fan committed Jul 11, 2017
1 parent 66d2168 commit ebc124d
Showing 4 changed files with 97 additions and 57 deletions.
16 changes: 13 additions & 3 deletions python/pyspark/sql/functions.py
@@ -2037,15 +2037,25 @@ def __init__(self, func, returnType, name=None):
                 "{0}".format(type(func)))
 
         self.func = func
-        self.returnType = (
-            returnType if isinstance(returnType, DataType)
-            else _parse_datatype_string(returnType))
+        self._returnType = returnType
         # Stores UserDefinedPythonFunctions jobj, once initialized
+        self._returnType_placeholder = None
         self._judf_placeholder = None
         self._name = name or (
             func.__name__ if hasattr(func, '__name__')
             else func.__class__.__name__)
 
+    @property
+    def returnType(self):
+        # This makes sure this is called after SparkContext is initialized.
+        # ``_parse_datatype_string`` accesses to JVM for parsing a DDL formatted string.
+        if self._returnType_placeholder is None:
+            if isinstance(self._returnType, DataType):
+                self._returnType_placeholder = self._returnType
+            else:
+                self._returnType_placeholder = _parse_datatype_string(self._returnType)
+        return self._returnType_placeholder
+
     @property
     def _judf(self):
         # It is possible that concurrent access, to newly created UDF,
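A small usage sketch of the deferred parsing above (hypothetical script layout, not from the PR): because `returnType` is now a lazy property, a UDF with a DDL string return type can be declared before any SparkContext exists, and the string is only parsed, via the JVM, on first use:

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf

# Declared at import time: no SparkContext yet, so the DDL string is stored unparsed.
to_upper = udf(lambda s: s.upper(), "string")

if __name__ == "__main__":
    spark = SparkSession.builder.getOrCreate()
    # First use touches the `returnType` property, which parses "string" through the JVM.
    spark.createDataFrame([("spark",)], "word: string").select(to_upper("word")).show()
```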
25 changes: 25 additions & 0 deletions python/pyspark/sql/tests.py
@@ -1255,6 +1255,31 @@ def test_struct_type(self):
         with self.assertRaises(TypeError):
             not_a_field = struct1[9.9]
 
+    def test_parse_datatype_string(self):
+        from pyspark.sql.types import _all_atomic_types, _parse_datatype_string
+        for k, t in _all_atomic_types.items():
+            if t != NullType:
+                self.assertEqual(t(), _parse_datatype_string(k))
+        self.assertEqual(IntegerType(), _parse_datatype_string("int"))
+        self.assertEqual(DecimalType(1, 1), _parse_datatype_string("decimal(1 ,1)"))
+        self.assertEqual(DecimalType(10, 1), _parse_datatype_string("decimal( 10,1 )"))
+        self.assertEqual(DecimalType(11, 1), _parse_datatype_string("decimal(11,1)"))
+        self.assertEqual(
+            ArrayType(IntegerType()),
+            _parse_datatype_string("array<int >"))
+        self.assertEqual(
+            MapType(IntegerType(), DoubleType()),
+            _parse_datatype_string("map< int, double >"))
+        self.assertEqual(
+            StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
+            _parse_datatype_string("struct<a:int, c:double >"))
+        self.assertEqual(
+            StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
+            _parse_datatype_string("a:int, c:double"))
+        self.assertEqual(
+            StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
+            _parse_datatype_string("a INT, c DOUBLE"))
+
     def test_metadata_null(self):
         from pyspark.sql.types import StructType, StringType, StructField
         schema = StructType([StructField("f1", StringType(), True, None),
88 changes: 34 additions & 54 deletions python/pyspark/sql/types.py
@@ -32,6 +32,7 @@
 from py4j.protocol import register_input_converter
 from py4j.java_gateway import JavaClass
 
+from pyspark import SparkContext
 from pyspark.serializers import CloudPickleSerializer
 
 __all__ = [
@@ -727,18 +728,6 @@ def __eq__(self, other):
 _BRACKETS = {'(': ')', '[': ']', '{': '}'}
 
 
-def _parse_basic_datatype_string(s):
-    if s in _all_atomic_types.keys():
-        return _all_atomic_types[s]()
-    elif s == "int":
-        return IntegerType()
-    elif _FIXED_DECIMAL.match(s):
-        m = _FIXED_DECIMAL.match(s)
-        return DecimalType(int(m.group(1)), int(m.group(2)))
-    else:
-        raise ValueError("Could not parse datatype: %s" % s)
-
-
 def _ignore_brackets_split(s, separator):
     """
     Splits the given string by given separator, but ignore separators inside brackets pairs, e.g.
@@ -771,32 +760,23 @@ def _ignore_brackets_split(s, separator):
     return parts
 
 
-def _parse_struct_fields_string(s):
-    parts = _ignore_brackets_split(s, ",")
-    fields = []
-    for part in parts:
-        name_and_type = _ignore_brackets_split(part, ":")
-        if len(name_and_type) != 2:
-            raise ValueError("The strcut field string format is: 'field_name:field_type', " +
-                             "but got: %s" % part)
-        field_name = name_and_type[0].strip()
-        field_type = _parse_datatype_string(name_and_type[1])
-        fields.append(StructField(field_name, field_type))
-    return StructType(fields)
-
-
 def _parse_datatype_string(s):
     """
     Parses the given data type string to a :class:`DataType`. The data type string format equals
     to :class:`DataType.simpleString`, except that top level struct type can omit
     the ``struct<>`` and atomic types use ``typeName()`` as their format, e.g. use ``byte`` instead
     of ``tinyint`` for :class:`ByteType`. We can also use ``int`` as a short name
-    for :class:`IntegerType`.
+    for :class:`IntegerType`. Since Spark 2.3, this also supports a schema in a DDL-formatted
+    string and case-insensitive strings.
     >>> _parse_datatype_string("int ")
     IntegerType
+    >>> _parse_datatype_string("INT ")
+    IntegerType
     >>> _parse_datatype_string("a: byte, b: decimal( 16 , 8 ) ")
     StructType(List(StructField(a,ByteType,true),StructField(b,DecimalType(16,8),true)))
+    >>> _parse_datatype_string("a DOUBLE, b STRING")
+    StructType(List(StructField(a,DoubleType,true),StructField(b,StringType,true)))
     >>> _parse_datatype_string("a: array< short>")
     StructType(List(StructField(a,ArrayType(ShortType,true),true)))
     >>> _parse_datatype_string(" map<string , string > ")
@@ -806,43 +786,43 @@ def _parse_datatype_string(s):
     >>> _parse_datatype_string("blabla") # doctest: +IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
         ...
-    ValueError:...
+    ParseException:...
     >>> _parse_datatype_string("a: int,") # doctest: +IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
         ...
-    ValueError:...
+    ParseException:...
     >>> _parse_datatype_string("array<int") # doctest: +IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
         ...
-    ValueError:...
+    ParseException:...
     >>> _parse_datatype_string("map<int, boolean>>") # doctest: +IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
         ...
-    ValueError:...
+    ParseException:...
     """
-    s = s.strip()
-    if s.startswith("array<"):
-        if s[-1] != ">":
-            raise ValueError("'>' should be the last char, but got: %s" % s)
-        return ArrayType(_parse_datatype_string(s[6:-1]))
-    elif s.startswith("map<"):
-        if s[-1] != ">":
-            raise ValueError("'>' should be the last char, but got: %s" % s)
-        parts = _ignore_brackets_split(s[4:-1], ",")
-        if len(parts) != 2:
-            raise ValueError("The map type string format is: 'map<key_type,value_type>', " +
-                             "but got: %s" % s)
-        kt = _parse_datatype_string(parts[0])
-        vt = _parse_datatype_string(parts[1])
-        return MapType(kt, vt)
-    elif s.startswith("struct<"):
-        if s[-1] != ">":
-            raise ValueError("'>' should be the last char, but got: %s" % s)
-        return _parse_struct_fields_string(s[7:-1])
-    elif ":" in s:
-        return _parse_struct_fields_string(s)
-    else:
-        return _parse_basic_datatype_string(s)
+    sc = SparkContext._active_spark_context
+
+    def from_ddl_schema(type_str):
+        return _parse_datatype_json_string(
+            sc._jvm.org.apache.spark.sql.types.StructType.fromDDL(type_str).json())
+
+    def from_ddl_datatype(type_str):
+        return _parse_datatype_json_string(
+            sc._jvm.org.apache.spark.sql.api.python.PythonSQLUtils.parseDataType(type_str).json())
+
+    try:
+        # DDL format, "fieldname datatype, fieldname datatype".
+        return from_ddl_schema(s)
+    except Exception as e:
+        try:
+            # For backwards compatibility, "integer", "struct<fieldname: datatype>" and etc.
+            return from_ddl_datatype(s)
+        except:
+            try:
+                # For backwards compatibility, "fieldname: datatype, fieldname: datatype" case.
+                return from_ddl_datatype("struct<%s>" % s.strip())
+            except:
+                raise e
 
 
 def _parse_datatype_json_string(json_string):
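For reference, a sketch of what the three fallback branches above accept (assuming an active SparkContext, since parsing is now delegated to the JVM):

```python
from pyspark.sql.types import _parse_datatype_string

# 1. DDL schema string: "fieldname datatype, fieldname datatype".
_parse_datatype_string("a INT, b DOUBLE")
# 2. A single data type in simpleString form, kept for backwards compatibility.
_parse_datatype_string("array<int>")
# 3. Colon-separated fields, also kept for backwards compatibility;
#    retried internally as "struct<a: int, b: string>".
_parse_datatype_string("a: int, b: string")
```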
25 changes: 25 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.python
+
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.types.DataType
+
+private[sql] object PythonSQLUtils {
+  def parseDataType(typeText: String): DataType = CatalystSqlParser.parseDataType(typeText)
+}

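For context, a sketch of how this helper is reached from Python over Py4J (assuming an active SparkContext `sc`); this mirrors `from_ddl_datatype` in `types.py` above:

```python
from pyspark import SparkContext

sc = SparkContext.getOrCreate()
# Parse a DDL type string on the JVM side and get its JSON form back,
# which _parse_datatype_json_string turns into a Python DataType.
jdt = sc._jvm.org.apache.spark.sql.api.python.PythonSQLUtils.parseDataType("array<int>")
print(jdt.json())
```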