Commit f451d65
fixes from comments in PR
BryanCutler committed Sep 19, 2017
1 parent 69112a5 commit f451d65
Showing 3 changed files with 8 additions and 6 deletions.
python/pyspark/serializers.py (3 additions & 3 deletions)

@@ -231,7 +231,7 @@ def dumps(self, series):
         series = [series]
         series = [(s, None) if not isinstance(s, (list, tuple)) else s for s in series]
         arrs = [pa.Array.from_pandas(s[0], type=s[1], mask=s[0].isnull()) for s in series]
-        batch = pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])
+        batch = pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])
         return super(ArrowPandasSerializer, self).dumps(batch)

     def loads(self, obj):
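[Note] Why range becomes xrange: on Python 2, range() materializes a full list while xrange() is lazy. pyspark.serializers keeps this portable with a module-level compat alias; a minimal sketch of that shim, assuming it sits near the top of the module (the exact form in serializers.py may differ):

    import sys

    # Python 3 removed xrange, and its range is already lazy, so aliasing
    # the name lets the Python-2-friendly xrange(...) calls run unchanged.
    if sys.version >= '3':
        xrange = range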
@@ -241,9 +241,9 @@ def loads(self, obj):
"""
import pyarrow as pa
reader = pa.RecordBatchFileReader(pa.BufferReader(obj))
batches = [reader.get_batch(i) for i in range(reader.num_record_batches)]
batches = [reader.get_batch(i) for i in xrange(reader.num_record_batches)]
# NOTE: a 0-parameter pandas_udf will produce an empty batch that can have num_rows set
num_rows = sum([batch.num_rows for batch in batches])
num_rows = sum((batch.num_rows for batch in batches))
table = pa.Table.from_batches(batches)
return [c.to_pandas() for c in table.itercolumns()] + [{"length": num_rows}]
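[Note] The sum() change swaps a list comprehension for a generator expression, so the batch row counts are consumed lazily instead of being collected into a throwaway list first. A small self-contained illustration, with made-up batch sizes:

    batch_row_counts = [10, 0, 25]  # hypothetical num_rows per batch

    total_via_list = sum([n for n in batch_row_counts])  # builds a temporary list
    total_via_gen = sum(n for n in batch_row_counts)     # no intermediate list
    assert total_via_list == total_via_gen == 35

When the generator is the sole argument, the extra parentheses in sum((...)) are optional.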


python/pyspark/sql/tests.py (1 addition & 1 deletion)

@@ -3231,7 +3231,7 @@ def test_vectorized_udf_null_string(self):
     def test_vectorized_udf_zero_parameter(self):
         from pyspark.sql.functions import pandas_udf
         import pandas as pd
-        df = self.spark.range(100000)
+        df = self.spark.range(10)
         f0 = pandas_udf(lambda **kwargs: pd.Series(1).repeat(kwargs['length']), LongType())
         res = df.select(f0())
         self.assertEquals(df.select(lit(1)).collect(), res.collect())
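[Note] Shrinking range(100000) to range(10) keeps the test's coverage (a zero-parameter pandas_udf must still emit one row per input row) while making it much cheaper to run. A pure-pandas simulation of what the assertion checks, a sketch assuming the worker-side contract in this PR where a 0-parameter UDF receives only a kwargs dict carrying the batch length:

    import pandas as pd

    # Stand-in for the zero-parameter pandas_udf above: given only the
    # batch length, it must return a Series of exactly that length.
    f0 = lambda **kwargs: pd.Series(1).repeat(kwargs['length'])

    result = f0(length=10)
    assert len(result) == 10    # one value per input row
    assert (result == 1).all()  # matches df.select(lit(1))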

python/pyspark/worker.py (4 additions & 2 deletions)

@@ -74,14 +74,16 @@ def wrap_udf(f, return_type):


 def wrap_pandas_udf(f, return_type):
+    arrow_return_type = toArrowType(return_type)
+
     def verify_result_length(*a):
         kwargs = a[-1]
         result = f(*a[:-1], **kwargs)
         if len(result) != kwargs["length"]:
             raise RuntimeError("Result vector from pandas_udf was not the required length: "
-                               "expected %d, got %d\nUse input vector length or kwarg['length']"
+                               "expected %d, got %d\nUse input vector length or kwargs['length']"
                                % (kwargs["length"], len(result)))
-        return result, toArrowType(return_type)
+        return result, arrow_return_type
     return lambda *a: verify_result_length(*a)
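[Note] The worker.py change hoists toArrowType(return_type) out of the per-call closure: the Arrow type depends only on the declared return type, so it can be computed once when the UDF is wrapped rather than on every batch. A self-contained sketch of the pattern, where to_arrow_type is a toy stand-in and not the real toArrowType:

    def to_arrow_type(return_type):
        # toy mapping standing in for pyspark's toArrowType
        return {"long": "int64"}.get(return_type, "unknown")

    def wrap(f, return_type):
        arrow_return_type = to_arrow_type(return_type)  # computed once, at wrap time

        def call(*args):
            return f(*args), arrow_return_type          # reused on every call
        return call

    add_one = wrap(lambda x: x + 1, "long")
    print(add_one(41))  # (42, 'int64')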

