
[SPARK-21375][PYSPARK][SQL] Add Date and Timestamp support to ArrowConverters for toPandas() Conversion #18664

Closed

Changes from all commits
39 commits
5aa8b9e
added date type and started test, still some issue with time difference
BryanCutler Jul 13, 2017
20313f9
DateTimeUtils forces defaultTimeZone
BryanCutler Jul 18, 2017
69e1e21
fix style checks
BryanCutler Jul 18, 2017
dbfbef3
date type java tests passing
BryanCutler Jul 18, 2017
436afff
timestamp type java tests passing
BryanCutler Jul 18, 2017
78119ca
adding date and timestamp data to python tests, not passing
BryanCutler Jul 19, 2017
b709d78
TimestampType is correctly inferred as datetime64[ns]
BryanCutler Jul 19, 2017
399e527
Merge remote-tracking branch 'upstream/master' into arrow-date-timest…
BryanCutler Jul 24, 2017
e6d8590
Adding DateType and TimestampType to ArrowUtils conversions
BryanCutler Jul 24, 2017
719e77c
using default timezone, fixed tests
BryanCutler Jul 24, 2017
3585520
fixed scala tests for timestamp
BryanCutler Jul 25, 2017
f977d0b
Adding sync between Python and Java default timezones
BryanCutler Jul 26, 2017
b826445
Merge remote-tracking branch 'upstream/master' into arrow-date-timest…
BryanCutler Jul 27, 2017
3b83d7a
added date timestamp writers, fixed tests
BryanCutler Jul 27, 2017
a6009a5
Modify ArrowUtils to have timeZoneId when convert schema to Arrow sch…
ueshin Jul 28, 2017
2ec98cc
fixed python test tearDownClass
BryanCutler Aug 1, 2017
c29018c
using Date.valueOf for tests instead
BryanCutler Aug 2, 2017
7dbdb1f
Made timezone id required for TimestampType
BryanCutler Aug 14, 2017
c3f4e4d
added test for TimestampType without specifying timezone id
BryanCutler Aug 14, 2017
ddbea24
added date and timestamp to ArrowWriter and tests
BryanCutler Aug 15, 2017
c6b597d
removed unused import
BryanCutler Aug 16, 2017
874f104
Merge remote-tracking branch 'upstream/master' into arrow-date-timest…
BryanCutler Oct 10, 2017
d8bae0b
added Python timezone conversions for working with Pandas
BryanCutler Oct 10, 2017
36f58b1
Merge remote-tracking branch 'upstream/master' into arrow-date-timest…
BryanCutler Oct 11, 2017
c4fd5ae
fix compilation
BryanCutler Oct 11, 2017
d1617fd
fixed test comp
BryanCutler Oct 11, 2017
d7d9b47
add conversion to Python system local timezone before localize
BryanCutler Oct 11, 2017
efe3e27
timestamps with Arrow almost working for pandas_udfs
BryanCutler Oct 11, 2017
9894519
added workaround for Series to_pandas with timestamps, store os.envir…
BryanCutler Oct 17, 2017
a3ba4ac
change use of xrange for py3
BryanCutler Oct 17, 2017
7266304
remove check for valid timezone in vector for ArrowWriter
BryanCutler Oct 17, 2017
e428cbe
added note for 'us' conversion
BryanCutler Oct 17, 2017
cade921
changed python api for is_datetime64
BryanCutler Oct 19, 2017
f512deb
remove Option for timezoneId
BryanCutler Oct 19, 2017
171d9e1
Merge remote-tracking branch 'upstream/master' into arrow-date-timest…
BryanCutler Oct 20, 2017
79bb93f
added pandas_udf test for date
BryanCutler Oct 23, 2017
c555207
added workaround for date casting, put back check for timestamp conve…
BryanCutler Oct 24, 2017
4d40893
added fillna for null timestamp values
BryanCutler Oct 25, 2017
addd35f
added check for pandas_udf return is a timestamp with tz, added comme…
BryanCutler Oct 26, 2017
24 changes: 20 additions & 4 deletions python/pyspark/serializers.py
@@ -214,6 +214,7 @@ def __repr__(self):


def _create_batch(series):
from pyspark.sql.types import _check_series_convert_timestamps_internal
import pyarrow as pa
# Make input conform to [(series1, type1), (series2, type2), ...]
if not isinstance(series, (list, tuple)) or \
@@ -224,12 +225,25 @@ def _create_batch(series):
# If a nullable integer series has been promoted to floating point with NaNs, need to cast
# NOTE: this is not necessary with Arrow >= 0.7
def cast_series(s, t):
if t is None or s.dtype == t.to_pandas_dtype():
if type(t) == pa.TimestampType:
# NOTE: convert to 'us' with astype here, unit ignored in `from_pandas` see ARROW-1680
return _check_series_convert_timestamps_internal(s.fillna(0))\
.values.astype('datetime64[us]', copy=False)
elif t == pa.date32():
# TODO: this converts the series to Python objects, possibly avoid with Arrow >= 0.8
return s.dt.date
elif t is None or s.dtype == t.to_pandas_dtype():
return s
else:
return s.fillna(0).astype(t.to_pandas_dtype(), copy=False)

arrs = [pa.Array.from_pandas(cast_series(s, t), mask=s.isnull(), type=t) for s, t in series]
# Some object types don't support masks in Arrow, see ARROW-1721
def create_array(s, t):
casted = cast_series(s, t)
mask = None if casted.dtype == 'object' else s.isnull()
return pa.Array.from_pandas(casted, mask=mask, type=t)

arrs = [create_array(s, t) for s, t in series]
return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])


@@ -260,11 +274,13 @@ def load_stream(self, stream):
"""
Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
"""
from pyspark.sql.types import _check_dataframe_localize_timestamps
import pyarrow as pa
reader = pa.open_stream(stream)
for batch in reader:
table = pa.Table.from_batches([batch])
yield [c.to_pandas() for c in table.itercolumns()]
# NOTE: changed from pa.Columns.to_pandas, timezone issue in conversion fixed in 0.7.1
pdf = _check_dataframe_localize_timestamps(batch.to_pandas())
yield [c for _, c in pdf.iteritems()]
Member Author
@BryanCutler BryanCutler Oct 17, 2017
After running some tests, this change does not significantly degrade performance, but there seems to be a small difference. cc @ueshin

I ran various columns of random data through a pandas_udf repeatedly with and without this change. The test ran in local mode with the default Spark conf, comparing the minimum wall-clock time over 10 loops.

before change: 2.595558
after change: 2.681813

Do you think the difference here is acceptable for now, until Arrow is upgraded and we can look into it again?
pandas_udf_perf.py.txt
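
The attached script is not reproduced on this page. Below is a minimal sketch of a comparable timing harness (hypothetical file and column names; assumes a local SparkSession with the default conf and the pandas_udf support on this branch):

# pandas_udf_timing_sketch.py: hedged reconstruction, not the attached script
import time

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, pandas_udf, rand
from pyspark.sql.types import DoubleType

spark = SparkSession.builder.master("local[*]").getOrCreate()
df = spark.range(1 << 20).withColumn("v", rand())

# trivial vectorized UDF so the measurement is dominated by serialization
plus_one = pandas_udf(lambda s: s + 1, returnType=DoubleType())

times = []
for _ in range(10):
    start = time.time()
    df.select(plus_one(col("v"))).collect()
    times.append(time.time() - start)
print("min wall clock time of 10 loops: %f s" % min(times))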

Member

I ran your script on my machine, too.

  • before change:
    • mean: 2.605722
    • min: 2.502404
    • max: 3.045294
  • after change:
    • mean: 2.626306
    • min: 2.341781
    • max: 2.742432

I think it's okay to use this workaround.


def __repr__(self):
return "ArrowStreamPandasSerializer"
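For context on the cast_series and create_array changes above, here is a small standalone sketch of the same conversions (assumes pandas and pyarrow are installed; an explicit example timezone stands in for the session-local one, and null handling is omitted):

# Hedged sketch: normalize a tz-naive timestamp Series to UTC microseconds,
# then build the Arrow array, skipping the null mask for object-dtype results
# as create_array does (see ARROW-1721).
import pandas as pd
import pyarrow as pa

s = pd.Series(pd.to_datetime(["1969-01-01 01:01:01", "2012-02-02 02:02:02"]))
t = pa.timestamp('us', tz='UTC')

# what cast_series does for a timestamp type: localize, convert to UTC, cast to 'us'
casted = s.dt.tz_localize('America/Los_Angeles').dt.tz_convert('UTC') \
    .values.astype('datetime64[us]', copy=False)

mask = None if casted.dtype == 'object' else s.isnull()
arr = pa.Array.from_pandas(casted, mask=mask, type=t)
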
7 changes: 6 additions & 1 deletion python/pyspark/sql/dataframe.py
@@ -1880,11 +1880,13 @@ def toPandas(self):
import pandas as pd
if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true":
try:
from pyspark.sql.types import _check_dataframe_localize_timestamps
import pyarrow
tables = self._collectAsArrow()
if tables:
table = pyarrow.concat_tables(tables)
return table.to_pandas()
pdf = table.to_pandas()
return _check_dataframe_localize_timestamps(pdf)
else:
return pd.DataFrame.from_records([], columns=self.columns)
except ImportError as e:
@@ -1952,6 +1954,7 @@ def _to_corrected_pandas_type(dt):
"""
When converting Spark SQL records to Pandas DataFrame, the inferred data type may be wrong.
This method gets the corrected data type for Pandas if that type may be inferred incorrectly.
NOTE: DateType is inferred incorrectly as 'object', TimestampType is correct with datetime64[ns]
"""
import numpy as np
if type(dt) == ByteType:
Expand All @@ -1962,6 +1965,8 @@ def _to_corrected_pandas_type(dt):
return np.int32
elif type(dt) == FloatType:
return np.float32
elif type(dt) == DateType:
return 'datetime64[ns]'
else:
return None

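The note added to _to_corrected_pandas_type above is easiest to see directly in pandas. A small illustration (plain pandas, not the Spark code path) of why DateType columns need the explicit 'datetime64[ns]' correction:

# Hedged illustration: datetime.date values are inferred as 'object' by pandas,
# so the non-Arrow toPandas() path casts them to the corrected dtype.
import datetime

import pandas as pd

pdf = pd.DataFrame({"6_date_t": [datetime.date(1969, 1, 1), datetime.date(2012, 2, 2)]})
print(pdf.dtypes)  # 6_date_t is 'object' without the correction
pdf["6_date_t"] = pdf["6_date_t"].astype("datetime64[ns]", copy=False)
print(pdf.dtypes)  # now datetime64[ns], matching how TimestampType columns come back
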
106 changes: 95 additions & 11 deletions python/pyspark/sql/tests.py
@@ -3086,18 +3086,38 @@ class ArrowTests(ReusedPySparkTestCase):

@classmethod
def setUpClass(cls):
from datetime import datetime
ReusedPySparkTestCase.setUpClass()

# Synchronize default timezone between Python and Java
cls.tz_prev = os.environ.get("TZ", None) # save current tz if set
tz = "America/Los_Angeles"
os.environ["TZ"] = tz
time.tzset()

cls.spark = SparkSession(cls.sc)
cls.spark.conf.set("spark.sql.session.timeZone", tz)
cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
cls.schema = StructType([
StructField("1_str_t", StringType(), True),
StructField("2_int_t", IntegerType(), True),
StructField("3_long_t", LongType(), True),
StructField("4_float_t", FloatType(), True),
StructField("5_double_t", DoubleType(), True)])
cls.data = [("a", 1, 10, 0.2, 2.0),
("b", 2, 20, 0.4, 4.0),
("c", 3, 30, 0.8, 6.0)]
StructField("5_double_t", DoubleType(), True),
StructField("6_date_t", DateType(), True),
StructField("7_timestamp_t", TimestampType(), True)])
cls.data = [("a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
("b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
("c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]

@classmethod
def tearDownClass(cls):
del os.environ["TZ"]
if cls.tz_prev is not None:
os.environ["TZ"] = cls.tz_prev
time.tzset()
ReusedPySparkTestCase.tearDownClass()
cls.spark.stop()

def assertFramesEqual(self, df_with_arrow, df_without):
msg = ("DataFrame from Arrow is not equal" +
@@ -3106,8 +3126,8 @@ def assertFramesEqual(self, df_with_arrow, df_without):
self.assertTrue(df_without.equals(df_with_arrow), msg=msg)

def test_unsupported_datatype(self):
schema = StructType([StructField("dt", DateType(), True)])
df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema)
schema = StructType([StructField("decimal", DecimalType(), True)])
df = self.spark.createDataFrame([(None,)], schema=schema)
with QuietTest(self.sc):
self.assertRaises(Exception, lambda: df.toPandas())

@@ -3385,13 +3405,77 @@ def test_vectorized_udf_varargs(self):

def test_vectorized_udf_unsupported_types(self):
from pyspark.sql.functions import pandas_udf, col
schema = StructType([StructField("dt", DateType(), True)])
df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema)
f = pandas_udf(lambda x: x, DateType())
schema = StructType([StructField("dt", DecimalType(), True)])
df = self.spark.createDataFrame([(None,)], schema=schema)
f = pandas_udf(lambda x: x, DecimalType())
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
df.select(f(col('dt'))).collect()

def test_vectorized_udf_null_date(self):
from pyspark.sql.functions import pandas_udf, col
from datetime import date
schema = StructType().add("date", DateType())
data = [(date(1969, 1, 1),),
(date(2012, 2, 2),),
(None,),
(date(2100, 4, 4),)]
df = self.spark.createDataFrame(data, schema=schema)
date_f = pandas_udf(lambda t: t, returnType=DateType())
res = df.select(date_f(col("date")))
self.assertEquals(df.collect(), res.collect())

def test_vectorized_udf_timestamps(self):
from pyspark.sql.functions import pandas_udf, col
from datetime import datetime
schema = StructType([
StructField("idx", LongType(), True),
StructField("timestamp", TimestampType(), True)])
data = [(0, datetime(1969, 1, 1, 1, 1, 1)),
(1, datetime(2012, 2, 2, 2, 2, 2)),
(2, None),
(3, datetime(2100, 4, 4, 4, 4, 4))]
df = self.spark.createDataFrame(data, schema=schema)

# Check that a timestamp passed through a pandas_udf will not be altered by timezone calc
f_timestamp_copy = pandas_udf(lambda t: t, returnType=TimestampType())
df = df.withColumn("timestamp_copy", f_timestamp_copy(col("timestamp")))

@pandas_udf(returnType=BooleanType())
def check_data(idx, timestamp, timestamp_copy):
is_equal = timestamp.isnull() # use this array to check values are equal
for i in range(len(idx)):
# Check that timestamps are as expected in the UDF
is_equal[i] = (is_equal[i] and data[idx[i]][1] is None) or \
timestamp[i].to_pydatetime() == data[idx[i]][1]
return is_equal

result = df.withColumn("is_equal", check_data(col("idx"), col("timestamp"),
col("timestamp_copy"))).collect()
# Check that collection values are correct
self.assertEquals(len(data), len(result))
for i in range(len(result)):
self.assertEquals(data[i][1], result[i][1]) # "timestamp" col
self.assertTrue(result[i][3]) # "is_equal" data in udf was as expected

def test_vectorized_udf_return_timestamp_tz(self):
from pyspark.sql.functions import pandas_udf, col
import pandas as pd
df = self.spark.range(10)

@pandas_udf(returnType=TimestampType())
def gen_timestamps(id):
ts = [pd.Timestamp(i, unit='D', tz='America/Los_Angeles') for i in id]
return pd.Series(ts)

result = df.withColumn("ts", gen_timestamps(col("id"))).collect()
spark_ts_t = TimestampType()
for r in result:
i, ts = r
ts_tz = pd.Timestamp(i, unit='D', tz='America/Los_Angeles').to_pydatetime()
expected = spark_ts_t.fromInternal(spark_ts_t.toInternal(ts_tz))
self.assertEquals(expected, ts)


@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
class GroupbyApplyTests(ReusedPySparkTestCase):
Expand Down Expand Up @@ -3550,8 +3634,8 @@ def test_wrong_args(self):
def test_unsupported_types(self):
from pyspark.sql.functions import pandas_udf, col
schema = StructType(
[StructField("id", LongType(), True), StructField("dt", DateType(), True)])
df = self.spark.createDataFrame([(1, datetime.date(1970, 1, 1),)], schema=schema)
[StructField("id", LongType(), True), StructField("dt", DecimalType(), True)])
df = self.spark.createDataFrame([(1, None,)], schema=schema)
f = pandas_udf(lambda x: x, df.schema)
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
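The tests above exercise the change through ArrowTests and the pandas_udf cases. A condensed end-to-end sketch of what the new support enables (assumes an existing SparkSession named spark, configured as in setUpClass with Arrow enabled):

# Hedged end-to-end sketch: with Arrow enabled, toPandas() now handles
# DateType and TimestampType instead of raising an unsupported type error.
from datetime import date, datetime

from pyspark.sql.types import StructType, StructField, DateType, TimestampType

spark.conf.set("spark.sql.execution.arrow.enabled", "true")
schema = StructType([StructField("d", DateType(), True),
                     StructField("ts", TimestampType(), True)])
df = spark.createDataFrame([(date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2))], schema)
pdf = df.toPandas()
print(pdf.dtypes)  # both columns are expected to come back as datetime64[ns]
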
36 changes: 36 additions & 0 deletions python/pyspark/sql/types.py
@@ -1619,11 +1619,47 @@ def to_arrow_type(dt):
arrow_type = pa.decimal(dt.precision, dt.scale)
elif type(dt) == StringType:
arrow_type = pa.string()
elif type(dt) == DateType:
arrow_type = pa.date32()
elif type(dt) == TimestampType:
# Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
arrow_type = pa.timestamp('us', tz='UTC')
else:
raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
return arrow_type


def _check_dataframe_localize_timestamps(pdf):
"""
Convert timezone aware timestamps to timezone-naive in local time

:param pdf: pandas.DataFrame
:return pandas.DataFrame where any timezone aware columns have been converted to tz-naive
"""
from pandas.api.types import is_datetime64tz_dtype
for column, series in pdf.iteritems():
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if is_datetime64tz_dtype(series.dtype):
pdf[column] = series.dt.tz_convert('tzlocal()').dt.tz_localize(None)
return pdf


def _check_series_convert_timestamps_internal(s):
"""
Convert a tz-naive timestamp in local tz to UTC normalized for Spark internal storage
:param s: a pandas.Series
:return pandas.Series where if it is a timestamp, has been UTC normalized without a time zone
"""
from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if is_datetime64_dtype(s.dtype):
return s.dt.tz_localize('tzlocal()').dt.tz_convert('UTC')
elif is_datetime64tz_dtype(s.dtype):
return s.dt.tz_convert('UTC')
else:
return s


def _test():
import doctest
from pyspark.context import SparkContext
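A short usage sketch of the two helpers added above, showing the round trip from tz-naive local time to UTC and back (assumes a pandas installation with pandas.api.types and timestamps that are unambiguous in the local timezone):

# Hedged round-trip sketch of the helpers defined above.
import pandas as pd

from pyspark.sql.types import (_check_dataframe_localize_timestamps,
                               _check_series_convert_timestamps_internal)

s = pd.Series(pd.to_datetime(["1969-01-01 01:01:01", "2012-02-02 02:02:02"]))
utc = _check_series_convert_timestamps_internal(s)  # tz-naive local -> tz-aware UTC
pdf = _check_dataframe_localize_timestamps(pd.DataFrame({"ts": utc}))
assert (pdf["ts"] == s).all()  # back to the original tz-naive local values
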
ArrowColumnVector.java
@@ -320,6 +320,10 @@ public ArrowColumnVector(ValueVector vector) {
accessor = new StringAccessor((NullableVarCharVector) vector);
} else if (vector instanceof NullableVarBinaryVector) {
accessor = new BinaryAccessor((NullableVarBinaryVector) vector);
} else if (vector instanceof NullableDateDayVector) {
accessor = new DateAccessor((NullableDateDayVector) vector);
} else if (vector instanceof NullableTimeStampMicroTZVector) {
accessor = new TimestampAccessor((NullableTimeStampMicroTZVector) vector);
} else if (vector instanceof ListVector) {
ListVector listVector = (ListVector) vector;
accessor = new ArrayAccessor(listVector);
@@ -575,6 +579,36 @@ final byte[] getBinary(int rowId) {
}
}

private static class DateAccessor extends ArrowVectorAccessor {

private final NullableDateDayVector.Accessor accessor;

DateAccessor(NullableDateDayVector vector) {
super(vector);
this.accessor = vector.getAccessor();
}

@Override
final int getInt(int rowId) {
return accessor.get(rowId);
}
}

private static class TimestampAccessor extends ArrowVectorAccessor {

private final NullableTimeStampMicroTZVector.Accessor accessor;

TimestampAccessor(NullableTimeStampMicroTZVector vector) {
super(vector);
this.accessor = vector.getAccessor();
}

@Override
final long getLong(int rowId) {
return accessor.get(rowId);
}
}

private static class ArrayAccessor extends ArrowVectorAccessor {

private final UInt4Vector.Accessor accessor;
4 changes: 3 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3143,9 +3143,11 @@ class Dataset[T] private[sql](
private[sql] def toArrowPayload: RDD[ArrowPayload] = {
val schemaCaptured = this.schema
val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
queryExecution.toRdd.mapPartitionsInternal { iter =>
val context = TaskContext.get()
ArrowConverters.toPayloadIterator(iter, schemaCaptured, maxRecordsPerBatch, context)
ArrowConverters.toPayloadIterator(
iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context)
}
}
}
ArrowConverters.scala
@@ -74,9 +74,10 @@ private[sql] object ArrowConverters {
rowIter: Iterator[InternalRow],
schema: StructType,
maxRecordsPerBatch: Int,
timeZoneId: String,
context: TaskContext): Iterator[ArrowPayload] = {

val arrowSchema = ArrowUtils.toArrowSchema(schema)
val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
val allocator =
ArrowUtils.rootAllocator.newChildAllocator("toPayloadIterator", 0, Long.MaxValue)
