From 2e64987b187b2a64f36bd09357865f1c688342f7 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Fri, 31 Mar 2023 10:21:59 +0200 Subject: [PATCH] fix(python): Use `check_exact` for temporal types in `assert_series_equal` (#7896) --- py-polars/polars/testing/asserts.py | 5 +++-- py-polars/tests/unit/test_testing.py | 17 +++++++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/py-polars/polars/testing/asserts.py b/py-polars/polars/testing/asserts.py index 7d9d2b04f7cc..b15f103b59b7 100644 --- a/py-polars/polars/testing/asserts.py +++ b/py-polars/polars/testing/asserts.py @@ -6,7 +6,6 @@ from polars import functions as F from polars.dataframe import DataFrame from polars.datatypes import ( - Boolean, Categorical, DataTypeClass, Float32, @@ -320,7 +319,9 @@ def _assert_series_inner( except NotImplementedError: can_be_subtracted = False - check_exact = check_exact or not can_be_subtracted or left.dtype == Boolean + check_exact = ( + check_exact or not can_be_subtracted or left.is_boolean() or left.is_temporal() + ) if check_dtype and left.dtype != right.dtype: raise_assert_detail("Series", "Dtype mismatch", left.dtype, right.dtype) diff --git a/py-polars/tests/unit/test_testing.py b/py-polars/tests/unit/test_testing.py index 6c044738a71b..e671aca73134 100644 --- a/py-polars/tests/unit/test_testing.py +++ b/py-polars/tests/unit/test_testing.py @@ -1,5 +1,8 @@ from __future__ import annotations +from datetime import datetime, time, timedelta +from typing import Any + import pytest import polars as pl @@ -284,3 +287,17 @@ def test_assert_series_equal_int_overflow() -> None: assert_series_equal(s0, s0, check_exact=check_exact) with pytest.raises(AssertionError): assert_series_equal(s1, s2, check_exact=check_exact) + + +@pytest.mark.parametrize( + ("data1", "data2"), + [ + ([datetime(2022, 10, 2, 12)], [datetime(2022, 10, 2, 13)]), + ([time(10, 0, 0)], [time(10, 0, 10)]), + ([timedelta(10, 0, 0)], [timedelta(10, 0, 10)]), + ], +) +def test_assert_series_equal_temporal(data1: Any, data2: Any) -> None: + s1 = pl.Series(data1) + s2 = pl.Series(data2) + assert_series_not_equal(s1, s2)