Skip to content

Commit

Permalink
[SPARK-50050][PYTHON][CONNECT] Make lit accept str and bool type numpy ndarray
Browse files Browse the repository at this point in the history

### What changes were proposed in this pull request?
Make `lit` accept `str` and `bool` type numpy ndarray

### Why are the changes needed?
To be consistent with PySpark Classic, where passing a `str` ndarray to `lit` already works:
```python
In [4]: spark.range(1).select(sf.lit(np.array(["a", "b"], np.str_))).show()
+---------------+
|ARRAY('a', 'b')|
+---------------+
|         [a, b]|
+---------------+
```

### Does this PR introduce _any_ user-facing change?
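Yes. In Spark Connect, `lit` now accepts `str` and `bool` type numpy ndarrays instead of raising `PySparkTypeError`.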

before:
```python
In [3]: spark.range(1).select(sf.lit(np.array(["a", "b"], np.str_))).show()
---------------------------------------------------------------------------
PySparkTypeError                          Traceback (most recent call last)
Cell In[3], line 1
----> 1 spark.range(1).select(sf.lit(np.array(["a", "b"], np.str_))).schema

File ~/Dev/spark/python/pyspark/sql/utils.py:272, in try_remote_functions.<locals>.wrapped(*args, **kwargs)
    269 if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
    270     from pyspark.sql.connect import functions
--> 272     return getattr(functions, f.__name__)(*args, **kwargs)
    273 else:
    274     return f(*args, **kwargs)

File ~/Dev/spark/python/pyspark/sql/connect/functions/builtin.py:274, in lit(col)
    272 dt = _from_numpy_type(col.dtype)
    273 if dt is None:
--> 274     raise PySparkTypeError(
    275         errorClass="UNSUPPORTED_NUMPY_ARRAY_SCALAR",
    276         messageParameters={"dtype": col.dtype.name},
    277     )
    279 # NumpyArrayConverter for Py4J can not support ndarray with int8 values.
    280 # Actually this is not a problem for Connect, but here still convert it
    281 # to int16 for compatibility.
    282 if dt == ByteType():

PySparkTypeError: [UNSUPPORTED_NUMPY_ARRAY_SCALAR] The type of array scalar 'str32' is not supported.
```

after:
```python
In [4]: spark.range(1).select(sf.lit(np.array(["a", "b"], np.str_))).show()
+-----------+
|array(a, b)|
+-----------+
|     [a, b]|
+-----------+
```

### How was this patch tested?
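CI, with the new `test_bool_ndarray` test and the previously skipped `test_str_ndarray` parity test re-enabled for Spark Connect.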

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #48591 from zhengruifeng/connect_lit_bool_str.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
  • Loading branch information
zhengruifeng authored and HyukjinKwon committed Oct 23, 2024
1 parent 0083815 commit 2bf41a6
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 8 deletions.
5 changes: 3 additions & 2 deletions python/pyspark/sql/connect/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,8 @@ def __init__(self, value: Any, dataType: DataType) -> None:
elif isinstance(dataType, DecimalType):
assert isinstance(value, decimal.Decimal)
elif isinstance(dataType, StringType):
assert isinstance(value, str)
assert isinstance(value, (str, np.str_))
value = str(value)
elif isinstance(dataType, DateType):
assert isinstance(value, (datetime.date, datetime.datetime))
if isinstance(value, datetime.date):
Expand Down Expand Up @@ -319,7 +320,7 @@ def _infer_type(cls, value: Any) -> DataType:
)
elif isinstance(value, float):
return DoubleType()
elif isinstance(value, str):
elif isinstance(value, (str, np.str_)):
return StringType()
elif isinstance(value, decimal.Decimal):
return DecimalType()
Expand Down
5 changes: 4 additions & 1 deletion python/pyspark/sql/functions/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,10 @@ def lit(col: Any) -> Column:
if _has_numpy and isinstance(col, np.generic):
dt = _from_numpy_type(col.dtype)
if dt is not None:
return _invoke_function("lit", _enum_to_value(col)).astype(dt).alias(str(col))
if isinstance(col, np.number):
return _invoke_function("lit", col).astype(dt).alias(str(col))
else:
return _invoke_function("lit", col)
return _invoke_function("lit", _enum_to_value(col))


Expand Down
4 changes: 0 additions & 4 deletions python/pyspark/sql/tests/connect/test_parity_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,6 @@ def test_function_parity(self):
def test_input_file_name_reset_for_rdd(self):
super().test_input_file_name_reset_for_rdd()

@unittest.skip("SPARK-50050: Spark Connect should support str ndarray.")
def test_str_ndarray(self):
super().test_str_ndarray()


if __name__ == "__main__":
from pyspark.sql.tests.connect.test_parity_functions import * # noqa: F401
Expand Down
14 changes: 14 additions & 0 deletions python/pyspark/sql/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1277,6 +1277,20 @@ def test_ndarray_input(self):
},
)

@unittest.skipIf(not have_numpy, "NumPy not installed")
def test_bool_ndarray(self):
    """A boolean numpy ndarray passed to ``lit`` yields an ``array<boolean>`` column."""
    import numpy as np

    # Cases: an empty array, genuine booleans, and integers given the bool dtype.
    bool_arrays = [
        np.array([], np.bool_),
        np.array([True, False], np.bool_),
        np.array([1, 0, 3], np.bool_),
    ]
    for bool_array in bool_arrays:
        expected_dtypes = [("a", "array<boolean>")]
        projected = self.spark.range(1).select(F.lit(bool_array).alias("a"))
        self.assertEqual(expected_dtypes, projected.dtypes)

@unittest.skipIf(not have_numpy, "NumPy not installed")
def test_str_ndarray(self):
import numpy as np
Expand Down
6 changes: 5 additions & 1 deletion python/pyspark/sql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2167,7 +2167,9 @@ def _from_numpy_type(nt: "np.dtype") -> Optional[DataType]:
"""Convert NumPy type to Spark data type."""
import numpy as np

if nt == np.dtype("int8"):
if nt == np.dtype("bool"):
return BooleanType()
elif nt == np.dtype("int8"):
return ByteType()
elif nt == np.dtype("int16"):
return ShortType()
Expand All @@ -2179,6 +2181,8 @@ def _from_numpy_type(nt: "np.dtype") -> Optional[DataType]:
return FloatType()
elif nt == np.dtype("float64"):
return DoubleType()
elif nt.type == np.dtype("str"):
return StringType()

return None

Expand Down

0 comments on commit 2bf41a6

Please sign in to comment.