Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-50050][PYTHON][CONNECT] Make lit accept str and bool type numpy ndarray #48591

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions python/pyspark/sql/connect/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,8 @@ def __init__(self, value: Any, dataType: DataType) -> None:
elif isinstance(dataType, DecimalType):
assert isinstance(value, decimal.Decimal)
elif isinstance(dataType, StringType):
assert isinstance(value, str)
assert isinstance(value, (str, np.str_))
value = str(value)
elif isinstance(dataType, DateType):
assert isinstance(value, (datetime.date, datetime.datetime))
if isinstance(value, datetime.date):
Expand Down Expand Up @@ -319,7 +320,7 @@ def _infer_type(cls, value: Any) -> DataType:
)
elif isinstance(value, float):
return DoubleType()
elif isinstance(value, str):
elif isinstance(value, (str, np.str_)):
return StringType()
elif isinstance(value, decimal.Decimal):
return DecimalType()
Expand Down
5 changes: 4 additions & 1 deletion python/pyspark/sql/functions/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,10 @@ def lit(col: Any) -> Column:
if _has_numpy and isinstance(col, np.generic):
dt = _from_numpy_type(col.dtype)
if dt is not None:
return _invoke_function("lit", _enum_to_value(col)).astype(dt).alias(str(col))
if isinstance(col, np.number):
return _invoke_function("lit", col).astype(dt).alias(str(col))
else:
return _invoke_function("lit", col)
return _invoke_function("lit", _enum_to_value(col))


Expand Down
4 changes: 0 additions & 4 deletions python/pyspark/sql/tests/connect/test_parity_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,6 @@ def test_function_parity(self):
def test_input_file_name_reset_for_rdd(self):
super().test_input_file_name_reset_for_rdd()

# Parity-suite override of the inherited test_str_ndarray case.
# The decorator skips it unconditionally, citing SPARK-50050 (Spark Connect
# lacking str ndarray support) as the reason.
# NOTE(review): the file header for this hunk reports 0 additions and
# 4 deletions, so this override appears to be what the change removes.
@unittest.skip("SPARK-50050: Spark Connect should support str ndarray.")
def test_str_ndarray(self):
# When not skipped, simply delegates to the parent class's implementation.
super().test_str_ndarray()


if __name__ == "__main__":
from pyspark.sql.tests.connect.test_parity_functions import * # noqa: F401
Expand Down
14 changes: 14 additions & 0 deletions python/pyspark/sql/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1263,6 +1263,20 @@ def test_ndarray_input(self):
},
)

# Checks that F.lit() applied to a boolean NumPy array yields a column whose
# reported dtype is "array<boolean>".
# NOTE(review): indentation is flattened in this rendering of the diff; the
# statements below are left exactly as they appear.
# Skipped when the have_numpy flag is falsy.
@unittest.skipIf(not have_numpy, "NumPy not installed")
def test_bool_ndarray(self):
# Imported locally so the module itself does not require NumPy.
import numpy as np

# Three inputs, all built with the np.bool_ dtype: an empty array, an array
# of Python bools, and an array of ints (NumPy coerces non-zero ints to True).
for arr in [
np.array([], np.bool_),
np.array([True, False], np.bool_),
np.array([1, 0, 3], np.bool_),
]:
# Only the resulting schema is asserted, via DataFrame.dtypes; the column
# values themselves are not collected or compared.
# presumably F is pyspark.sql.functions and self.spark a SparkSession
# fixture; confirm against the enclosing test class.
self.assertEqual(
[("a", "array<boolean>")],
self.spark.range(1).select(F.lit(arr).alias("a")).dtypes,
)

@unittest.skipIf(not have_numpy, "NumPy not installed")
def test_str_ndarray(self):
import numpy as np
Expand Down
6 changes: 5 additions & 1 deletion python/pyspark/sql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2167,7 +2167,9 @@ def _from_numpy_type(nt: "np.dtype") -> Optional[DataType]:
"""Convert NumPy type to Spark data type."""
import numpy as np

if nt == np.dtype("int8"):
if nt == np.dtype("bool"):
return BooleanType()
elif nt == np.dtype("int8"):
return ByteType()
elif nt == np.dtype("int16"):
return ShortType()
Expand All @@ -2179,6 +2181,8 @@ def _from_numpy_type(nt: "np.dtype") -> Optional[DataType]:
return FloatType()
elif nt == np.dtype("float64"):
return DoubleType()
elif nt.type == np.dtype("str"):
return StringType()

return None

Expand Down