diff --git a/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py b/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py index d574f6aced..a99781e53f 100644 --- a/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py +++ b/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py @@ -9,7 +9,7 @@ import llvmlite.ir as llvmir from numba.core import cgutils, types from numba.core.errors import TypingError -from numba.extending import intrinsic, overload_method +from numba.extending import intrinsic, overload_attribute, overload_method from numba_dpex.core.types.kernel_api.index_space_ids import ( GroupType, @@ -248,3 +248,24 @@ def ol_nd_item_get_group_impl(nd_item): return _intrinsic_get_group(nd_item) return ol_nd_item_get_group_impl + + +@overload_attribute(GroupType, "dimensions", target=DPEX_KERNEL_EXP_TARGET_NAME) +@overload_attribute(ItemType, "dimensions", target=DPEX_KERNEL_EXP_TARGET_NAME) +@overload_attribute( + NdItemType, "dimensions", target=DPEX_KERNEL_EXP_TARGET_NAME +) +def ol_nd_item_dimensions(item): + """ + SPIR-V overload for :meth:`numba_dpex.kernel_api..dimensions`. + + Generates the same LLVM IR instruction as dpcpp for the + `sycl::::dimensions` attribute. + """ + dimensions = item.ndim + + # pylint: disable=unused-argument + def ol_nd_item_get_group_impl(item): + return dimensions + + return ol_nd_item_get_group_impl diff --git a/numba_dpex/experimental/typeof.py b/numba_dpex/experimental/typeof.py index 108e9d3b09..e72c951a0f 100644 --- a/numba_dpex/experimental/typeof.py +++ b/numba_dpex/experimental/typeof.py @@ -68,11 +68,11 @@ def typeof_item(val: Item, c): Returns: A numba_dpex.experimental.core.types.kernel_api.items.ItemType instance. """ - return ItemType(val.ndim) + return ItemType(val.dimensions) @typeof_impl.register(NdItem) -def typeof_nditem(val, c): +def typeof_nditem(val: NdItem, c): """Registers the type inference implementation function for a numba_dpex.kernel_api.NdItem PyObject. @@ -83,4 +83,4 @@ def typeof_nditem(val, c): Returns: A numba_dpex.experimental.core.types.kernel_api.items.NdItemType instance. """ - return NdItemType(val.ndim) + return NdItemType(val.dimensions) diff --git a/numba_dpex/kernel_api/index_space_ids.py b/numba_dpex/kernel_api/index_space_ids.py index 4e10bfc688..faa78ab4cd 100644 --- a/numba_dpex/kernel_api/index_space_ids.py +++ b/numba_dpex/kernel_api/index_space_ids.py @@ -98,6 +98,14 @@ def leader(self): """ return self._leader + @property + def dimensions(self) -> int: + """Returns the rank of a Group object. + Returns: + int: Number of dimensions in the Group object + """ + return self._global_range.ndim + @leader.setter def leader(self, work_item_id): """Sets the leader attribute for the group.""" @@ -147,7 +155,7 @@ def get_range(self, idx): return self._extent[idx] @property - def ndim(self) -> int: + def dimensions(self) -> int: """Returns the rank of a Item object. Returns: @@ -228,10 +236,10 @@ def get_group(self): return self._group @property - def ndim(self) -> int: + def dimensions(self) -> int: """Returns the rank of a NdItem object. Returns: int: Number of dimensions in the NdItem object """ - return self._global_item.ndim + return self._global_item.dimensions diff --git a/numba_dpex/tests/experimental/test_index_space_ids.py b/numba_dpex/tests/experimental/test_index_space_ids.py index 2d1edb54f2..887ce6584e 100644 --- a/numba_dpex/tests/experimental/test_index_space_ids.py +++ b/numba_dpex/tests/experimental/test_index_space_ids.py @@ -63,6 +63,24 @@ def set_local_ones_nd_item(nd_item: NdItem, a): a[i] = 1 +@dpex_exp.kernel +def set_dimensions_item(item: Item, a): + i = item.get_id(0) + a[i] = item.dimensions + + +@dpex_exp.kernel +def set_dimensions_nd_item(nd_item: NdItem, a): + i = nd_item.get_global_id(0) + a[i] = nd_item.dimensions + + +@dpex_exp.kernel +def set_dimensions_group(nd_item: NdItem, a): + i = nd_item.get_global_id(0) + a[i] = nd_item.get_group().dimensions + + def _get_group_id_driver(nditem: NdItem, a): i = nditem.get_global_id(0) g = nditem.get_group() @@ -149,6 +167,29 @@ def test_nd_item_get_local_id(): ) +@pytest.mark.parametrize("dims", [1, 2, 3]) +def test_item_dimensions(dims): + a = dpnp.zeros(_SIZE, dtype=dpnp.float32) + rng = [1] * dims + rng[0] = a.size + dpex_exp.call_kernel(set_dimensions_item, dpex.Range(*rng), a) + + assert np.array_equal(a.asnumpy(), dims * np.ones(a.size, dtype=np.float32)) + + +@pytest.mark.parametrize("dims", [1, 2, 3]) +@pytest.mark.parametrize( + "kernel", [set_dimensions_nd_item, set_dimensions_group] +) +def test_nd_item_dimensions(dims, kernel): + a = dpnp.zeros(_SIZE, dtype=dpnp.float32) + rng, grp = [1] * dims, [1] * dims + rng[0], grp[0] = a.size, _GROUP_SIZE + dpex_exp.call_kernel(kernel, dpex.NdRange(rng, grp), a) + + assert np.array_equal(a.asnumpy(), dims * np.ones(a.size, dtype=np.float32)) + + def test_error_item_get_global_id(): a = dpnp.zeros(_SIZE, dtype=dpnp.float32)