From d12a5e31564ba3dd16d0a328feb67964c07da1e7 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Thu, 9 Jan 2025 11:41:28 +0100 Subject: [PATCH] feat: add scalar support to `where` PR-URL: https://github.com/data-apis/array-api/pull/860 Ref: https://github.com/data-apis/array-api/issues/807 Co-authored-by: Athan Reines Reviewed-by: Athan Reines Reviewed-by: Evgeni Burovski Reviewed-by: Lucas Colley --- .../API_specification/type_promotion.rst | 3 +++ .../_draft/searching_functions.py | 20 ++++++++++++++++--- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/spec/draft/API_specification/type_promotion.rst b/spec/draft/API_specification/type_promotion.rst index 339b90e45..4b3791aca 100644 --- a/spec/draft/API_specification/type_promotion.rst +++ b/spec/draft/API_specification/type_promotion.rst @@ -120,6 +120,9 @@ Notes .. note:: Mixed integer and floating-point type promotion rules are not specified because behavior varies between implementations. + +.. _mixing-scalars-and-arrays: + Mixing arrays with Python scalars ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/array_api_stubs/_draft/searching_functions.py b/src/array_api_stubs/_draft/searching_functions.py index 9e0053825..04d4fd818 100644 --- a/src/array_api_stubs/_draft/searching_functions.py +++ b/src/array_api_stubs/_draft/searching_functions.py @@ -168,7 +168,12 @@ def searchsorted( """ -def where(condition: array, x1: array, x2: array, /) -> array: +def where( + condition: array, + x1: Union[array, int, float, complex, bool], + x2: Union[array, int, float, complex, bool], + /, +) -> array: """ Returns elements chosen from ``x1`` or ``x2`` depending on ``condition``. @@ -176,13 +181,22 @@ def where(condition: array, x1: array, x2: array, /) -> array: ---------- condition: array when ``True``, yield ``x1_i``; otherwise, yield ``x2_i``. Should have a boolean data type. Must be compatible with ``x1`` and ``x2`` (see :ref:`broadcasting`). - x1: array + x1: Union[array, int, float, complex, bool] first input array. Must be compatible with ``condition`` and ``x2`` (see :ref:`broadcasting`). - x2: array + x2: Union[array, int, float, complex, bool] second input array. Must be compatible with ``condition`` and ``x1`` (see :ref:`broadcasting`). Returns ------- out: array an array with elements from ``x1`` where ``condition`` is ``True``, and elements from ``x2`` elsewhere. The returned array must have a data type determined by :ref:`type-promotion` rules with the arrays ``x1`` and ``x2``. + + Notes + ----- + + - At least one of ``x1`` and ``x2`` must be an array. + - If either ``x1`` or ``x2`` is a scalar value, the returned array must have a data type determined according to :ref:`mixing-scalars-and-arrays`. + + .. versionchanged:: 2024.12 + Added support for scalar arguments. """