Skip to content

Commit

Permalink
Merge branch 'main' into b338872698-bigframes-v1
Browse files Browse the repository at this point in the history
  • Loading branch information
rey-esp authored Jan 7, 2025
2 parents 80c18d8 + 5934f8e commit 6c30806
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 3 deletions.
27 changes: 24 additions & 3 deletions bigframes/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,14 +367,35 @@ def __iter__(self):

def astype(
self,
dtype: Union[bigframes.dtypes.DtypeString, bigframes.dtypes.Dtype],
dtype: Union[
bigframes.dtypes.DtypeString,
bigframes.dtypes.Dtype,
dict[str, Union[bigframes.dtypes.DtypeString, bigframes.dtypes.Dtype]],
],
*,
errors: Literal["raise", "null"] = "raise",
) -> DataFrame:
if errors not in ["raise", "null"]:
raise ValueError("Arg 'error' must be one of 'raise' or 'null'")
return self._apply_unary_op(
ops.AsTypeOp(to_type=dtype, safe=(errors == "null"))

safe_cast = errors == "null"

# Type strings check
if dtype in bigframes.dtypes.DTYPE_STRINGS:
return self._apply_unary_op(ops.AsTypeOp(dtype, safe_cast))

# Type instances check
if type(dtype) in bigframes.dtypes.DTYPES:
return self._apply_unary_op(ops.AsTypeOp(dtype, safe_cast))

if isinstance(dtype, dict):
result = self.copy()
for col, to_type in dtype.items():
result[col] = result[col].astype(to_type)
return result

raise TypeError(
f"Invalid type {type(dtype)} for dtype input. {constants.FEEDBACK_LINK}"
)

def _to_sql_query(
Expand Down
4 changes: 4 additions & 0 deletions bigframes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
pd.ArrowDtype,
gpd.array.GeometryDtype,
]

DTYPES = typing.get_args(Dtype)
# Represents both column types (dtypes) and local-only types
# None represents the type of a None scalar.
ExpressionType = typing.Optional[Dtype]
Expand Down Expand Up @@ -238,6 +240,8 @@ class SimpleDtypeInfo:
"binary[pyarrow]",
]

DTYPE_STRINGS = typing.get_args(DtypeString)

BOOL_BIGFRAMES_TYPES = [BOOL_DTYPE]

# Corresponds to the pandas concept of numeric type (such as when 'numeric_only' is specified in an operation)
Expand Down
30 changes: 30 additions & 0 deletions tests/system/small/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5199,3 +5199,33 @@ def test__resample_start_time(rule, origin, data):
pd.testing.assert_frame_equal(
bf_result, pd_result, check_dtype=False, check_index_type=False
)


@pytest.mark.parametrize(
"dtype",
[
pytest.param("string[pyarrow]", id="type-string"),
pytest.param(pd.StringDtype(storage="pyarrow"), id="type-literal"),
pytest.param(
{"bool_col": "string[pyarrow]", "int64_col": pd.Float64Dtype()},
id="multiple-types",
),
],
)
def test_astype(scalars_dfs, dtype):
bf_df, pd_df = scalars_dfs
target_cols = ["bool_col", "int64_col"]
bf_df = bf_df[target_cols]
pd_df = pd_df[target_cols]

bf_result = bf_df.astype(dtype).to_pandas()
pd_result = pd_df.astype(dtype)

pd.testing.assert_frame_equal(bf_result, pd_result, check_index_type=False)


def test_astype_invalid_type_fail(scalars_dfs):
bf_df, _ = scalars_dfs

with pytest.raises(TypeError, match=r".*Share your usecase with.*"):
bf_df.astype(123)

0 comments on commit 6c30806

Please sign in to comment.