diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py
index efbe56abf7..0503a38ae6 100644
--- a/bigframes/dataframe.py
+++ b/bigframes/dataframe.py
@@ -367,14 +367,37 @@ def __iter__(self):
 
     def astype(
         self,
-        dtype: Union[bigframes.dtypes.DtypeString, bigframes.dtypes.Dtype],
+        dtype: Union[
+            bigframes.dtypes.DtypeString,
+            bigframes.dtypes.Dtype,
+            dict[str, Union[bigframes.dtypes.DtypeString, bigframes.dtypes.Dtype]],
+        ],
         *,
         errors: Literal["raise", "null"] = "raise",
     ) -> DataFrame:
         if errors not in ["raise", "null"]:
             raise ValueError("Arg 'error' must be one of 'raise' or 'null'")
-        return self._apply_unary_op(
-            ops.AsTypeOp(to_type=dtype, safe=(errors == "null"))
+
+        safe_cast = errors == "null"
+
+        # Type strings check
+        if dtype in bigframes.dtypes.DTYPE_STRINGS:
+            return self._apply_unary_op(ops.AsTypeOp(dtype, safe_cast))
+
+        # Type instances check (isinstance so dtype subclasses are accepted too)
+        if isinstance(dtype, bigframes.dtypes.DTYPES):
+            return self._apply_unary_op(ops.AsTypeOp(dtype, safe_cast))
+
+        # Per-column mapping: cast only the listed columns on a copy, and
+        # forward `errors` so errors="null" is honored on this path as well.
+        if isinstance(dtype, dict):
+            result = self.copy()
+            for col, to_type in dtype.items():
+                result[col] = result[col].astype(to_type, errors=errors)
+            return result
+
+        raise TypeError(
+            f"Invalid type {type(dtype)} for dtype input. {constants.FEEDBACK_LINK}"
         )
 
     def _to_sql_query(
diff --git a/bigframes/dtypes.py b/bigframes/dtypes.py
index ff3e7a31fb..6e179225ea 100644
--- a/bigframes/dtypes.py
+++ b/bigframes/dtypes.py
@@ -36,6 +36,8 @@
     pd.ArrowDtype,
     gpd.array.GeometryDtype,
 ]
+
+DTYPES = typing.get_args(Dtype)
 # Represents both column types (dtypes) and local-only types
 # None represents the type of a None scalar.
 ExpressionType = typing.Optional[Dtype]
@@ -238,6 +240,8 @@ class SimpleDtypeInfo:
     "binary[pyarrow]",
 ]
 
+DTYPE_STRINGS = typing.get_args(DtypeString)
+
 BOOL_BIGFRAMES_TYPES = [BOOL_DTYPE]
 
 # Corresponds to the pandas concept of numeric type (such as when 'numeric_only' is specified in an operation)
diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py
index bae71b33be..4e0e5c2739 100644
--- a/tests/system/small/test_dataframe.py
+++ b/tests/system/small/test_dataframe.py
@@ -5199,3 +5199,33 @@ def test__resample_start_time(rule, origin, data):
     pd.testing.assert_frame_equal(
         bf_result, pd_result, check_dtype=False, check_index_type=False
     )
+
+
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        pytest.param("string[pyarrow]", id="type-string"),
+        pytest.param(pd.StringDtype(storage="pyarrow"), id="type-literal"),
+        pytest.param(
+            {"bool_col": "string[pyarrow]", "int64_col": pd.Float64Dtype()},
+            id="multiple-types",
+        ),
+    ],
+)
+def test_astype(scalars_dfs, dtype):
+    bf_df, pd_df = scalars_dfs
+    target_cols = ["bool_col", "int64_col"]
+    bf_df = bf_df[target_cols]
+    pd_df = pd_df[target_cols]
+
+    bf_result = bf_df.astype(dtype).to_pandas()
+    pd_result = pd_df.astype(dtype)
+
+    pd.testing.assert_frame_equal(bf_result, pd_result, check_index_type=False)
+
+
+def test_astype_invalid_type_fail(scalars_dfs):
+    bf_df, _ = scalars_dfs
+
+    with pytest.raises(TypeError, match=r".*Share your usecase with.*"):
+        bf_df.astype(123)