From 6ebcdd961ff2e3aef7e4c8ac247b15d70eff985d Mon Sep 17 00:00:00 2001 From: Tim Head Date: Mon, 25 Nov 2024 15:56:54 +0100 Subject: [PATCH 1/8] Allow scalar arguments to where() --- .../_draft/searching_functions.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/array_api_stubs/_draft/searching_functions.py b/src/array_api_stubs/_draft/searching_functions.py index 029459b9a..800f1d4c5 100644 --- a/src/array_api_stubs/_draft/searching_functions.py +++ b/src/array_api_stubs/_draft/searching_functions.py @@ -1,7 +1,7 @@ __all__ = ["argmax", "argmin", "nonzero", "searchsorted", "where"] -from ._types import Optional, Tuple, Literal, array +from ._types import Optional, Tuple, Literal, Union, array def argmax(x: array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> array: @@ -139,21 +139,24 @@ def searchsorted( """ -def where(condition: array, x1: array, x2: array, /) -> array: +def where(condition: array, x1: Union[array, int, float, bool], x2: Union[array, int, float, bool], /) -> array: """ Returns elements chosen from ``x1`` or ``x2`` depending on ``condition``. Parameters ---------- condition: array - when ``True``, yield ``x1_i``; otherwise, yield ``x2_i``. Must be compatible with ``x1`` and ``x2`` (see :ref:`broadcasting`). - x1: array - first input array. Must be compatible with ``condition`` and ``x2`` (see :ref:`broadcasting`). - x2: array - second input array. Must be compatible with ``condition`` and ``x1`` (see :ref:`broadcasting`). + when ``True``, yield ``x1_i`` (scalar ``x1``); otherwise, yield ``x2_i`` (scalar ``x2``). Must be compatible with ``x1`` and ``x2`` (see :ref:`broadcasting`). + x1: Union[array, int, float, bool] + first input array or scalar. Must be compatible with ``condition`` and ``x2`` (see :ref:`broadcasting`). + x2: Union[array, int, float, bool] + second input array or scalar. Must be compatible with ``condition`` and ``x1`` (see :ref:`broadcasting`). Returns ------- out: array an array with elements from ``x1`` where ``condition`` is ``True``, and elements from ``x2`` elsewhere. The returned array must have a data type determined by :ref:`type-promotion` rules with the arrays ``x1`` and ``x2``. + + .. versionchanged:: 2024.12 + ``x1`` and ``x2`` may be scalars. """ From 252331a8e05434867a4656a14ceac759364bd362 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Mon, 25 Nov 2024 16:10:56 +0100 Subject: [PATCH 2/8] Appease formatting gods --- src/array_api_stubs/_draft/searching_functions.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/array_api_stubs/_draft/searching_functions.py b/src/array_api_stubs/_draft/searching_functions.py index 800f1d4c5..9d9096f48 100644 --- a/src/array_api_stubs/_draft/searching_functions.py +++ b/src/array_api_stubs/_draft/searching_functions.py @@ -139,7 +139,12 @@ def searchsorted( """ -def where(condition: array, x1: Union[array, int, float, bool], x2: Union[array, int, float, bool], /) -> array: +def where( + condition: array, + x1: Union[array, int, float, bool], + x2: Union[array, int, float, bool], + /, +) -> array: """ Returns elements chosen from ``x1`` or ``x2`` depending on ``condition``. From c8048d3c38ca5372a7c5be985d27c40d3f64cd49 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Tue, 26 Nov 2024 16:13:42 +0100 Subject: [PATCH 3/8] Refer to existing "mix scalars and arrays" section --- spec/draft/API_specification/type_promotion.rst | 2 ++ src/array_api_stubs/_draft/searching_functions.py | 14 +++++++++----- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/spec/draft/API_specification/type_promotion.rst b/spec/draft/API_specification/type_promotion.rst index 339b90e45..63c2581a3 100644 --- a/spec/draft/API_specification/type_promotion.rst +++ b/spec/draft/API_specification/type_promotion.rst @@ -120,6 +120,8 @@ Notes .. note:: Mixed integer and floating-point type promotion rules are not specified because behavior varies between implementations. + +.. _mixing-scalars-and-arrays: Mixing arrays with Python scalars ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/array_api_stubs/_draft/searching_functions.py b/src/array_api_stubs/_draft/searching_functions.py index 9d9096f48..e76b95811 100644 --- a/src/array_api_stubs/_draft/searching_functions.py +++ b/src/array_api_stubs/_draft/searching_functions.py @@ -151,17 +151,21 @@ def where( Parameters ---------- condition: array - when ``True``, yield ``x1_i`` (scalar ``x1``); otherwise, yield ``x2_i`` (scalar ``x2``). Must be compatible with ``x1`` and ``x2`` (see :ref:`broadcasting`). - x1: Union[array, int, float, bool] - first input array or scalar. Must be compatible with ``condition`` and ``x2`` (see :ref:`broadcasting`). - x2: Union[array, int, float, bool] - second input array or scalar. Must be compatible with ``condition`` and ``x1`` (see :ref:`broadcasting`). + when ``True``, yield ``x1_i``; otherwise, yield ``x2_i``. Must be compatible with ``x1`` and ``x2`` (see :ref:`broadcasting`). + x1: Union[array, int, float, complex, bool] + first input array or scalar. Scalar values are treated like an array filled with this value. Must be compatible with ``condition`` and ``x2`` (see :ref:`broadcasting`). + x2: Union[array, int, float, complex, bool] + second input array or scalar. Scalar values are treated like an array filled with this value. Must be compatible with ``condition`` and ``x1`` (see :ref:`broadcasting`). Returns ------- out: array an array with elements from ``x1`` where ``condition`` is ``True``, and elements from ``x2`` elsewhere. The returned array must have a data type determined by :ref:`type-promotion` rules with the arrays ``x1`` and ``x2``. + Notes + ----- + See :ref:`mixing-scalars-and-arrays` on compatibility requirements and handling of scalar arguments for ``x1`` and ``x2``. + .. versionchanged:: 2024.12 ``x1`` and ``x2`` may be scalars. """ From da3395b40cf34d747837119929b39729245ce016 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Tue, 26 Nov 2024 16:23:17 +0100 Subject: [PATCH 4/8] Fix syntax --- spec/draft/API_specification/type_promotion.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/spec/draft/API_specification/type_promotion.rst b/spec/draft/API_specification/type_promotion.rst index 63c2581a3..4b3791aca 100644 --- a/spec/draft/API_specification/type_promotion.rst +++ b/spec/draft/API_specification/type_promotion.rst @@ -122,6 +122,7 @@ Notes .. _mixing-scalars-and-arrays: + Mixing arrays with Python scalars ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ From 1344c3504e877b549b46883adfd374ccc84fa668 Mon Sep 17 00:00:00 2001 From: Athan Date: Thu, 9 Jan 2025 02:30:50 -0800 Subject: [PATCH 5/8] fix: add missing complex dtypes --- src/array_api_stubs/_draft/searching_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/array_api_stubs/_draft/searching_functions.py b/src/array_api_stubs/_draft/searching_functions.py index e76b95811..c491c8e8c 100644 --- a/src/array_api_stubs/_draft/searching_functions.py +++ b/src/array_api_stubs/_draft/searching_functions.py @@ -141,8 +141,8 @@ def searchsorted( def where( condition: array, - x1: Union[array, int, float, bool], - x2: Union[array, int, float, bool], + x1: Union[array, int, float, complex, bool], + x2: Union[array, int, float, complex, bool], /, ) -> array: """ From 37b577f4c13c15443d41c361f592bd918a48cc0f Mon Sep 17 00:00:00 2001 From: Athan Date: Thu, 9 Jan 2025 02:31:38 -0800 Subject: [PATCH 6/8] docs: update copy --- src/array_api_stubs/_draft/searching_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/array_api_stubs/_draft/searching_functions.py b/src/array_api_stubs/_draft/searching_functions.py index c491c8e8c..41596664a 100644 --- a/src/array_api_stubs/_draft/searching_functions.py +++ b/src/array_api_stubs/_draft/searching_functions.py @@ -153,9 +153,9 @@ def where( condition: array when ``True``, yield ``x1_i``; otherwise, yield ``x2_i``. Must be compatible with ``x1`` and ``x2`` (see :ref:`broadcasting`). x1: Union[array, int, float, complex, bool] - first input array or scalar. Scalar values are treated like an array filled with this value. Must be compatible with ``condition`` and ``x2`` (see :ref:`broadcasting`). + first input array. Must be compatible with ``condition`` and ``x2`` (see :ref:`broadcasting`). x2: Union[array, int, float, complex, bool] - second input array or scalar. Scalar values are treated like an array filled with this value. Must be compatible with ``condition`` and ``x1`` (see :ref:`broadcasting`). + second input array. Must be compatible with ``condition`` and ``x1`` (see :ref:`broadcasting`). Returns ------- From 114fe06179858a5c556dbd43a4c4d2023f234fd6 Mon Sep 17 00:00:00 2001 From: Athan Date: Thu, 9 Jan 2025 02:33:47 -0800 Subject: [PATCH 7/8] docs: remove obsolete note now that `condition` must be an array --- src/array_api_stubs/_draft/searching_functions.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/array_api_stubs/_draft/searching_functions.py b/src/array_api_stubs/_draft/searching_functions.py index 41596664a..11489cdd4 100644 --- a/src/array_api_stubs/_draft/searching_functions.py +++ b/src/array_api_stubs/_draft/searching_functions.py @@ -164,7 +164,9 @@ def where( Notes ----- - See :ref:`mixing-scalars-and-arrays` on compatibility requirements and handling of scalar arguments for ``x1`` and ``x2``. + + - At least one of ``x1`` and ``x2`` must be an array. + - If either ``x1`` or ``x2`` is a scalar value, the returned array must have a data type determined according to :ref:`mixing-scalars-and-arrays`. .. versionchanged:: 2024.12 ``x1`` and ``x2`` may be scalars. From 4fa6d9803ea49ed29aa7197423f251aa6b5512d6 Mon Sep 17 00:00:00 2001 From: Athan Date: Thu, 9 Jan 2025 02:34:09 -0800 Subject: [PATCH 8/8] docs: update copy --- src/array_api_stubs/_draft/searching_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_stubs/_draft/searching_functions.py b/src/array_api_stubs/_draft/searching_functions.py index 11489cdd4..edaabc03a 100644 --- a/src/array_api_stubs/_draft/searching_functions.py +++ b/src/array_api_stubs/_draft/searching_functions.py @@ -169,5 +169,5 @@ def where( - If either ``x1`` or ``x2`` is a scalar value, the returned array must have a data type determined according to :ref:`mixing-scalars-and-arrays`. .. versionchanged:: 2024.12 - ``x1`` and ``x2`` may be scalars. + Added support for scalar arguments. """