Skip to content

Commit

Permalink
feat: improve error message when casting to invalid type (#1429)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Nov 23, 2024
1 parent 35c34f4 commit 8973d50
Show file tree
Hide file tree
Showing 10 changed files with 22 additions and 57 deletions.
12 changes: 0 additions & 12 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Any
from typing import Sequence

from narwhals.dependencies import get_polars
from narwhals.utils import isinstance_or_issubclass

if TYPE_CHECKING:
Expand Down Expand Up @@ -77,17 +76,6 @@ def native_to_narwhals_dtype(dtype: pa.DataType, dtypes: DTypes) -> DType:


def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any:
if (pl := get_polars()) is not None and isinstance(
dtype, (pl.DataType, pl.DataType.__class__)
):
msg = (
f"Expected Narwhals object, got: {type(dtype)}.\n\n"
"Perhaps you:\n"
"- Forgot a `nw.from_native` somewhere?\n"
"- Used `pl.Int64` instead of `nw.Int64`?"
)
raise TypeError(msg)

import pyarrow as pa # ignore-banned-import

if isinstance_or_issubclass(dtype, dtypes.Float64):
Expand Down
12 changes: 0 additions & 12 deletions narwhals/_dask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Any

from narwhals.dependencies import get_pandas
from narwhals.dependencies import get_polars
from narwhals.dependencies import get_pyarrow
from narwhals.exceptions import InvalidIntoExprError
from narwhals.utils import isinstance_or_issubclass
Expand Down Expand Up @@ -86,17 +85,6 @@ def validate_comparand(lhs: dask_expr.Series, rhs: dask_expr.Series) -> None:


def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any:
if (pl := get_polars()) is not None and isinstance(
dtype, (pl.DataType, pl.DataType.__class__)
):
msg = (
f"Expected Narwhals object, got: {type(dtype)}.\n\n"
"Perhaps you:\n"
"- Forgot a `nw.from_native` somewhere?\n"
"- Used `pl.Int64` instead of `nw.Int64`?"
)
raise TypeError(msg)

if isinstance_or_issubclass(dtype, dtypes.Float64):
return "float64"
if isinstance_or_issubclass(dtype, dtypes.Float32):
Expand Down
12 changes: 0 additions & 12 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from narwhals._arrow.utils import (
native_to_narwhals_dtype as arrow_native_to_narwhals_dtype,
)
from narwhals.dependencies import get_polars
from narwhals.exceptions import ColumnNotFoundError
from narwhals.utils import Implementation
from narwhals.utils import isinstance_or_issubclass
Expand Down Expand Up @@ -384,17 +383,6 @@ def narwhals_to_native_dtype( # noqa: PLR0915
backend_version: tuple[int, ...],
dtypes: DTypes,
) -> Any:
if (pl := get_polars()) is not None and isinstance(
dtype, (pl.DataType, pl.DataType.__class__)
):
msg = (
f"Expected Narwhals object, got: {type(dtype)}.\n\n"
"Perhaps you:\n"
"- Forgot a `nw.from_native` somewhere?\n"
"- Used `pl.Int64` instead of `nw.Int64`?"
)
raise TypeError(msg)

dtype_backend = get_dtype_backend(starting_dtype, implementation)
if isinstance_or_issubclass(dtype, dtypes.Float64):
if dtype_backend == "pyarrow-nullable":
Expand Down
9 changes: 0 additions & 9 deletions narwhals/_polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,6 @@ def native_to_narwhals_dtype(dtype: pl.DataType, dtypes: DTypes) -> DType:
def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> pl.DataType:
import polars as pl # ignore-banned-import()

if isinstance(dtype, (pl.DataType, pl.DataType.__class__)): # type: ignore[arg-type]
msg = (
f"Expected Narwhals object, got: {type(dtype)}.\n\n"
"Perhaps you:\n"
"- Forgot a `nw.from_native` somewhere?\n"
"- Used `pl.Int64` instead of `nw.Int64`?"
)
raise TypeError(msg)

if dtype == dtypes.Float64:
return pl.Float64()
if dtype == dtypes.Float32:
Expand Down
11 changes: 11 additions & 0 deletions narwhals/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from typing import TYPE_CHECKING
from typing import Mapping

from narwhals.utils import isinstance_or_issubclass

if TYPE_CHECKING:
from typing import Iterator
from typing import Literal
Expand All @@ -13,6 +15,15 @@
from typing_extensions import Self


def _validate_dtype(dtype: DType | type[DType]) -> None:
if not isinstance_or_issubclass(dtype, DType):
msg = (
f"Expected Narwhals dtype, got: {type(dtype)}.\n\n"
"Hint: if you were trying to cast to a type, use e.g. nw.Int64 instead of 'int64'."
)
raise TypeError(msg)


class DType:
def __repr__(self) -> str: # pragma: no cover
return self.__class__.__qualname__
Expand Down
2 changes: 2 additions & 0 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import TypeVar

from narwhals.dependencies import is_numpy_array
from narwhals.dtypes import _validate_dtype
from narwhals.utils import _validate_rolling_arguments
from narwhals.utils import flatten

Expand Down Expand Up @@ -202,6 +203,7 @@ def cast(self: Self, dtype: DType | type[DType]) -> Self:
foo: [[1,2,3]]
bar: [[6,7,8]]
"""
_validate_dtype(dtype)
return self.__class__(
lambda plx: self._call(plx).cast(dtype),
)
Expand Down
2 changes: 2 additions & 0 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import TypeVar
from typing import overload

from narwhals.dtypes import _validate_dtype
from narwhals.utils import _validate_rolling_arguments
from narwhals.utils import parse_version

Expand Down Expand Up @@ -516,6 +517,7 @@ def cast(self: Self, dtype: DType | type[DType]) -> Self:
1
]
"""
_validate_dtype(dtype)
return self._from_compliant_series(self._compliant_series.cast(dtype))

def to_frame(self) -> DataFrame[Any]:
Expand Down
2 changes: 1 addition & 1 deletion narwhals/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def isinstance_or_issubclass(obj: Any, cls: Any) -> bool:

if isinstance(obj, DType):
return isinstance(obj, cls)
return isinstance(obj, cls) or issubclass(obj, cls)
return isinstance(obj, cls) or (isinstance(obj, type) and issubclass(obj, cls))


def validate_laziness(items: Iterable[Any]) -> None:
Expand Down
15 changes: 5 additions & 10 deletions tests/expr_and_series/cast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,24 +175,19 @@ def test_cast_string() -> None:


def test_cast_raises_for_unknown_dtype(
constructor: Constructor,
request: pytest.FixtureRequest,
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "pyarrow_table_constructor" in str(constructor) and PYARROW_VERSION <= (
15,
): # pragma: no cover
if "pyarrow_table" in str(constructor) and PYARROW_VERSION < (15,):
# Unsupported cast from string to dictionary using function cast_dictionary
request.applymarker(pytest.mark.xfail)
if "polars" in str(constructor):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data)).select(
nw.col(key).cast(value) for key, value in schema.items()
)

class Banana:
pass

with pytest.raises(AssertionError, match=r"Unknown dtype"):
with pytest.raises(TypeError, match="Expected Narwhals dtype"):
df.select(nw.col("a").cast(Banana)) # type: ignore[arg-type]


Expand Down Expand Up @@ -229,5 +224,5 @@ def test_cast_datetime_tz_aware(
@pytest.mark.parametrize("dtype", [pl.String, pl.String()])
def test_raise_if_polars_dtype(constructor: Constructor, dtype: Any) -> None:
df = nw.from_native(constructor({"a": [1, 2, 3], "b": [4, 5, 6]}))
with pytest.raises(TypeError, match="Expected Narwhals object, got:"):
with pytest.raises(TypeError, match="Expected Narwhals dtype, got:"):
df.select(nw.col("a").cast(dtype))
2 changes: 1 addition & 1 deletion tests/frame/invalid_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_invalid() -> None:
df.select(nw.all() + nw.all())
with pytest.raises(TypeError, match="Perhaps you:"):
df.select([pl.col("a")]) # type: ignore[list-item]
with pytest.raises(TypeError, match="Perhaps you:"):
with pytest.raises(TypeError, match="Expected Narwhals dtype"):
df.select([nw.col("a").cast(pl.Int64)]) # type: ignore[arg-type]


Expand Down

0 comments on commit 8973d50

Please sign in to comment.