Skip to content

Commit

Permalink
fix(python): Use check_exact for temporal types in `assert_series_e…
Browse files Browse the repository at this point in the history
…qual` (#7896)
  • Loading branch information
stinodego authored Mar 31, 2023
1 parent 76dc3a2 commit 2e64987
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
5 changes: 3 additions & 2 deletions py-polars/polars/testing/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from polars import functions as F
from polars.dataframe import DataFrame
from polars.datatypes import (
Boolean,
Categorical,
DataTypeClass,
Float32,
Expand Down Expand Up @@ -320,7 +319,9 @@ def _assert_series_inner(
except NotImplementedError:
can_be_subtracted = False

check_exact = check_exact or not can_be_subtracted or left.dtype == Boolean
check_exact = (
check_exact or not can_be_subtracted or left.is_boolean() or left.is_temporal()
)
if check_dtype and left.dtype != right.dtype:
raise_assert_detail("Series", "Dtype mismatch", left.dtype, right.dtype)

Expand Down
17 changes: 17 additions & 0 deletions py-polars/tests/unit/test_testing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

from datetime import datetime, time, timedelta
from typing import Any

import pytest

import polars as pl
Expand Down Expand Up @@ -284,3 +287,17 @@ def test_assert_series_equal_int_overflow() -> None:
assert_series_equal(s0, s0, check_exact=check_exact)
with pytest.raises(AssertionError):
assert_series_equal(s1, s2, check_exact=check_exact)


@pytest.mark.parametrize(
("data1", "data2"),
[
([datetime(2022, 10, 2, 12)], [datetime(2022, 10, 2, 13)]),
([time(10, 0, 0)], [time(10, 0, 10)]),
([timedelta(10, 0, 0)], [timedelta(10, 0, 10)]),
],
)
def test_assert_series_equal_temporal(data1: Any, data2: Any) -> None:
s1 = pl.Series(data1)
s2 = pl.Series(data2)
assert_series_not_equal(s1, s2)

0 comments on commit 2e64987

Please sign in to comment.