diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 203b6ce371a5c..4915078af0225 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -266,7 +266,8 @@ def __init__(self, value: Any, dataType: DataType) -> None:
             elif isinstance(dataType, DecimalType):
                 assert isinstance(value, decimal.Decimal)
             elif isinstance(dataType, StringType):
-                assert isinstance(value, str)
+                assert isinstance(value, (str, np.str_))
+                value = str(value)
             elif isinstance(dataType, DateType):
                 assert isinstance(value, (datetime.date, datetime.datetime))
                 if isinstance(value, datetime.date):
@@ -319,7 +320,7 @@ def _infer_type(cls, value: Any) -> DataType:
                 )
         elif isinstance(value, float):
             return DoubleType()
-        elif isinstance(value, str):
+        elif isinstance(value, (str, np.str_)):
             return StringType()
         elif isinstance(value, decimal.Decimal):
             return DecimalType()
diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py
index 872411b5bb995..caa83bd2e1a57 100644
--- a/python/pyspark/sql/functions/builtin.py
+++ b/python/pyspark/sql/functions/builtin.py
@@ -225,7 +225,10 @@ def lit(col: Any) -> Column:
         if _has_numpy and isinstance(col, np.generic):
             dt = _from_numpy_type(col.dtype)
             if dt is not None:
-                return _invoke_function("lit", _enum_to_value(col)).astype(dt).alias(str(col))
+                if isinstance(col, np.number):
+                    return _invoke_function("lit", col).astype(dt).alias(str(col))
+                else:
+                    return _invoke_function("lit", col)
         return _invoke_function("lit", _enum_to_value(col))
diff --git a/python/pyspark/sql/tests/connect/test_parity_functions.py b/python/pyspark/sql/tests/connect/test_parity_functions.py
index fbeff787a3d18..0a77c5531082a 100644
--- a/python/pyspark/sql/tests/connect/test_parity_functions.py
+++ b/python/pyspark/sql/tests/connect/test_parity_functions.py
@@ -34,10 +34,6 @@ def test_function_parity(self):
     def test_input_file_name_reset_for_rdd(self):
         super().test_input_file_name_reset_for_rdd()
 
-    @unittest.skip("SPARK-50050: Spark Connect should support str ndarray.")
-    def test_str_ndarray(self):
-        super().test_str_ndarray()
-
 
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.test_parity_functions import *  # noqa: F401
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index f6f54aee6283d..c83300a4d4fbe 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -1277,6 +1277,20 @@ def test_ndarray_input(self):
             },
         )
 
+    @unittest.skipIf(not have_numpy, "NumPy not installed")
+    def test_bool_ndarray(self):
+        import numpy as np
+
+        for arr in [
+            np.array([], np.bool_),
+            np.array([True, False], np.bool_),
+            np.array([1, 0, 3], np.bool_),
+        ]:
+            self.assertEqual(
+                [("a", "array<boolean>")],
+                self.spark.range(1).select(F.lit(arr).alias("a")).dtypes,
+            )
+
     @unittest.skipIf(not have_numpy, "NumPy not installed")
     def test_str_ndarray(self):
         import numpy as np
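For orientation before the type-mapping change below, here is a minimal sketch of the behavior these hunks enable at the API surface. It assumes a running local `SparkSession` bound to `spark` (not part of the patch); the expected dtype strings mirror `test_bool_ndarray` above and the unskipped `test_str_ndarray`:

```python
import numpy as np
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()

# bool ndarrays now map through the new np.dtype("bool") -> BooleanType
# branch in _from_numpy_type, yielding an array<boolean> column.
bool_df = spark.range(1).select(F.lit(np.array([True, False])).alias("b"))
print(bool_df.dtypes)  # [('b', 'array<boolean>')]

# str ndarrays are what the removed skip (SPARK-50050) unblocks on Spark
# Connect: elements arrive as np.str_, which LiteralExpression now accepts
# and normalizes with str(value).
str_df = spark.range(1).select(F.lit(np.array(["a", "b"])).alias("s"))
print(str_df.dtypes)  # expected: [('s', 'array<string>')]
```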
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 1773e0e9604ab..1f3558c37d09d 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -2167,7 +2167,9 @@ def _from_numpy_type(nt: "np.dtype") -> Optional[DataType]:
     """Convert NumPy type to Spark data type."""
     import numpy as np
 
-    if nt == np.dtype("int8"):
+    if nt == np.dtype("bool"):
+        return BooleanType()
+    elif nt == np.dtype("int8"):
         return ByteType()
     elif nt == np.dtype("int16"):
         return ShortType()
@@ -2179,6 +2181,8 @@ def _from_numpy_type(nt: "np.dtype") -> Optional[DataType]:
         return FloatType()
     elif nt == np.dtype("float64"):
         return DoubleType()
+    elif nt.type == np.dtype("str"):
+        return StringType()
 
     return None
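One subtlety in the final hunk: the new string branch compares `nt.type` rather than `nt` itself, because NumPy unicode dtypes encode their element length (`<U1`, `<U2`, ...), so no single dtype literal would match them all. A standalone NumPy illustration of the distinction:

```python
import numpy as np

nt = np.array(["ab", "cd"]).dtype   # dtype('<U2'): the length is part of the dtype
print(nt == np.dtype("str"))        # False: dtype('<U2') != dtype('<U')
print(nt.type == np.dtype("str"))   # True: np.str_ matches regardless of length
```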