From f45dcfb551a66bbacb0c2b694ea2b7ce5a1cbdf2 Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 16 Jan 2025 15:36:50 +0000 Subject: [PATCH] test: Make `skip_requires_pyarrow` compatible w/ `pytest.param` (#3772) --- tests/__init__.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/tests/__init__.py b/tests/__init__.py index 617cfca80..17a33e91e 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -5,14 +5,14 @@ import sys from importlib.util import find_spec from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, overload import pytest from tests import examples_arguments_syntax, examples_methods_syntax if TYPE_CHECKING: - from collections.abc import Callable, Collection, Iterator, Mapping + from collections.abc import Collection, Iterator, Mapping from re import Pattern if sys.version_info >= (3, 11): @@ -20,6 +20,7 @@ else: from typing_extensions import TypeAlias from _pytest.mark import ParameterSet + from _pytest.mark.structures import Markable MarksType: TypeAlias = ( "pytest.MarkDecorator | Collection[pytest.MarkDecorator | pytest.Mark]" @@ -96,9 +97,21 @@ def windows_has_tzdata() -> bool: """ +@overload def skip_requires_pyarrow( - fn: Callable[..., Any] | None = None, /, *, requires_tzdata: bool = False -) -> Callable[..., Any]: + fn: None = ..., /, *, requires_tzdata: bool = ... +) -> pytest.MarkDecorator: ... + + +@overload +def skip_requires_pyarrow( + fn: Markable, /, *, requires_tzdata: bool = ... +) -> Markable: ... + + +def skip_requires_pyarrow( + fn: Markable | None = None, /, *, requires_tzdata: bool = False +) -> pytest.MarkDecorator | Markable: """ ``pytest.mark.skipif`` decorator. @@ -109,7 +122,7 @@ def skip_requires_pyarrow( https://github.com/vega/altair/issues/3050 .. _pyarrow: - https://pypi.org/project/pyarrow/ + https://pypi.org/project/pyarrow/ """ composed = pytest.mark.skipif( find_spec("pyarrow") is None, reason="`pyarrow` not installed." @@ -120,13 +133,7 @@ def skip_requires_pyarrow( reason="Timezone database is not installed on Windows", )(composed) - def wrap(test_fn: Callable[..., Any], /) -> Callable[..., Any]: - return composed(test_fn) - - if fn is None: - return wrap - else: - return wrap(fn) + return composed if fn is None else composed(fn) def id_func_str_only(val) -> str: