
[SPARK-23334][SQL][PYTHON] Fix pandas_udf with return type StringType() to handle str type properly in Python 2. #20507

Closed · wants to merge 3 commits

Conversation

ueshin (Member) commented Feb 5, 2018

What changes were proposed in this pull request?

In Python 2, when a pandas_udf tries to return string-type values created inside the UDF as plain str literals ("..", not u".."), the execution fails. E.g.,

```python
from pyspark.sql.functions import pandas_udf, col
import pandas as pd

df = spark.range(10)
str_f = pandas_udf(lambda x: pd.Series(["%s" % i for i in x]), "string")
df.select(str_f(col('id'))).show()
```

raises the following exception:

```
...

java.lang.AssertionError: assertion failed: Invalid schema from pandas_udf: expected StringType, got BinaryType
	at scala.Predef$.assert(Predef.scala:170)
	at org.apache.spark.sql.execution.python.ArrowEvalPythonExec$$anon$2.<init>(ArrowEvalPythonExec.scala:93)

...
```

It seems that pyarrow ignores the type parameter of pa.Array.from_pandas() and treats the values as binary type when the requested type is string type but the values are str instead of unicode in Python 2.
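
A minimal sketch of the behavior described above (Python 2 semantics assumed; under Python 3 every str is already unicode, so the problem does not arise):

```python
import pandas as pd
import pyarrow as pa

s = pd.Series(["a", "b"])  # plain str (byte string) values in Python 2

# The requested type is string, but with Python 2 str values pyarrow
# falls back to binary, which later fails Spark's schema assertion.
arr = pa.Array.from_pandas(s, type=pa.string())
print(arr.type)  # expected: string; observed on Python 2: binary
```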

This PR adds a workaround for this case.

How was this patch tested?

Added a test, and ran the existing tests.

ueshin (Member, Author) commented Feb 5, 2018

cc @BryanCutler @icexelloss @HyukjinKwon
Could you help me double-check this?
Since this seems to happen only in a Python 2 environment, Jenkins will skip the tests.
And let me know if you know a better workaround.

SparkQA commented Feb 5, 2018

Test build #87063 has finished for PR 20507 at commit 47b8873.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

HyukjinKwon (Member) left a comment

LGTM. I don't have a better idea. Just two nits I found while double-checking.

```python
import pandas as pd
df = self.spark.range(10)
str_f = pandas_udf(lambda x: pd.Series(["%s" % i for i in x]), StringType())
res = df.select(str_f(col('id')))
```
HyukjinKwon (Member) commented:
How about variable names 'expected' and 'actual'?

ueshin (Member, Author) replied:
Sure, I'll update it.

```python
from pyspark.sql.functions import pandas_udf, col
import pandas as pd
df = self.spark.range(10)
str_f = pandas_udf(lambda x: pd.Series(["%s" % i for i in x]), StringType())
```
HyukjinKwon (Member) commented:
Not a big deal. How about pd.Series(map(str, x))?

ueshin (Member, Author) replied:
Sounds good. I'll take it.

SparkQA commented Feb 5, 2018

Test build #87069 has finished for PR 20507 at commit 06ae568.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

```diff
@@ -230,6 +230,9 @@ def create_array(s, t):
             s = _check_series_convert_timestamps_internal(s.fillna(0), timezone)
             # TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2
             return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False)
+        elif t is not None and pa.types.is_string(t) and sys.version < '3':
+            # TODO: need decode before converting to Arrow in Python 2
+            return pa.Array.from_pandas(s.str.decode('utf-8'), mask=mask, type=t)
```
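
For illustration, a minimal standalone sketch of what the added branch achieves (Python 2 assumed; the Series here is hypothetical, not the variable from the diff):

```python
import pandas as pd
import pyarrow as pa

s = pd.Series(["a", "b"])  # plain str values in Python 2

# Decoding to unicode before conversion makes pyarrow honor the
# requested string type instead of falling back to binary.
arr = pa.Array.from_pandas(s.str.decode('utf-8'), type=pa.string())
print(arr.type)  # string
```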
HyukjinKwon (Member) commented Feb 5, 2018

@ueshin, actually, how about s.apply(lambda v: v.decode("utf-8") if isinstance(v, str) else v) to also allow unicode values that aren't ASCII-encodable, like u"아"? I was worried about performance, but I ran a simple perf test against s.str.decode('utf-8') to be sure. It seems actually fine.
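
A minimal sketch of the difference this suggestion makes (Python 2 semantics assumed; the mixed-content Series is illustrative):

```python
import pandas as pd

# In Python 2, a Series can mix str (bytes) and unicode values.
s = pd.Series(["abc", u"아"])

# s.str.decode('utf-8') tries to decode the unicode value as well; Python 2
# implicitly encodes it as ASCII first, raising UnicodeEncodeError for
# non-ASCII characters like u"아".
# The apply-based version decodes only actual str values:
decoded = s.apply(lambda v: v.decode("utf-8") if isinstance(v, str) else v)
```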

ueshin (Member, Author) replied:
Good catch! I'll take it. Thanks!

ueshin (Member, Author) commented Feb 6, 2018

It seems that pyarrow ignores the type parameter of pa.Array.from_pandas() and treats the values as binary type when the requested type is string type but the values are str instead of unicode in Python 2.

@BryanCutler Btw, do you think this is a bug in pyarrow in Python 2?

SparkQA commented Feb 6, 2018

Test build #87083 has finished for PR 20507 at commit b3d5209.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

ueshin (Member, Author) commented Feb 6, 2018

also cc @cloud-fan @gatorsmile @sameeragarwal

BryanCutler (Member) commented:

Sorry I've been travelling, but I'll try to look into this soon on the Arrow side to see if it is a bug in pyarrow. The workaround here seems fine to me.

HyukjinKwon (Member) commented:

Merged to master and branch-2.3.

asfgit pushed a commit that referenced this pull request Feb 6, 2018
[SPARK-23334][SQL][PYTHON] Fix pandas_udf with return type StringType() to handle str type properly in Python 2.

Author: Takuya UESHIN <ueshin@databricks.com>

Closes #20507 from ueshin/issues/SPARK-23334.

(cherry picked from commit 63c5bf1)
Signed-off-by: hyukjinkwon <gurwls223@gmail.com>
asfgit closed this in 63c5bf1 on Feb 6, 2018
ueshin (Member, Author) commented Feb 6, 2018

Thanks! @HyukjinKwon @BryanCutler

BryanCutler (Member) commented:
I made https://issues.apache.org/jira/browse/ARROW-2101 to track the issue in Arrow.
