From f54205e45467f38f19ea07bd61a412ad09e71850 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Tue, 13 Feb 2024 13:10:16 -0600 Subject: [PATCH 1/2] Adds overloads for group indexing functions --- .../_index_space_id_overloads.py | 23 +++++ numba_dpex/kernel_api/index_space_ids.py | 18 ++-- .../experimental/test_index_space_ids.py | 86 ++++++++++++++++++- 3 files changed, 118 insertions(+), 9 deletions(-) diff --git a/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py b/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py index c8ff48c750..d574f6aced 100644 --- a/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py +++ b/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py @@ -123,6 +123,26 @@ def _intrinsic_spirv_workgroup_size( ) +@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME) +def _intrinsic_spirv_workgroup_id( + ty_context, ty_dim # pylint: disable=unused-argument +): + """Generates instruction to get index from BuiltInWorkgroupId.""" + return _intrinsic_spirv_global_index_const( + ty_context, ty_dim, "BuiltInWorkgroupId" + ) + + +@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME) +def _intrinsic_spirv_numworkgroups( + ty_context, ty_dim # pylint: disable=unused-argument +): + """Generates instruction to get index from BuiltInNumWorkgroups.""" + return _intrinsic_spirv_global_index_const( + ty_context, ty_dim, "BuiltInNumWorkgroups" + ) + + def generate_index_overload(_type, _intrinsic): """Generates overload for the index method that generates specific IR from provided intrinsic.""" @@ -167,6 +187,9 @@ def ol_item_get_index_impl(item, dim): (NdItemType, "get_local_id", _intrinsic_spirv_local_invocation_id), (NdItemType, "get_global_range", _intrinsic_spirv_global_size), (NdItemType, "get_local_range", _intrinsic_spirv_workgroup_size), + (GroupType, "get_group_id", _intrinsic_spirv_workgroup_id), + (GroupType, "get_group_range", _intrinsic_spirv_numworkgroups), + (GroupType, "get_local_range", _intrinsic_spirv_workgroup_size), ] for index_overload in _index_const_overload_methods: diff --git a/numba_dpex/kernel_api/index_space_ids.py b/numba_dpex/kernel_api/index_space_ids.py index fba6ea1a51..4e10bfc688 100644 --- a/numba_dpex/kernel_api/index_space_ids.py +++ b/numba_dpex/kernel_api/index_space_ids.py @@ -52,9 +52,11 @@ def get_group_linear_id(self): + (self._index[2]) ) - def get_group_range(self): - """Returns a range representing the number of groups in the nd-range.""" - return self._group_range + def get_group_range(self, dim): + """Returns a the extent of the range representing the number of groups + in the nd-range for a specified dimension. + """ + return self._group_range[dim] def get_group_linear_range(self): """Return the total number of work-groups in the nd_range.""" @@ -64,12 +66,12 @@ def get_group_linear_range(self): return num_wg - def get_local_range(self): - """Returns a SYCL range representing all dimensions of the local - range. This local range may have been provided by the programmer, or - chosen by the SYCL runtime. + def get_local_range(self, dim): + """Returns the extent of the SYCL range representing all dimensions + of the local range for a specified dimension. This local range may + have been provided by the programmer, or chosen by the SYCL runtime. """ - return self._local_range + return self._local_range[dim] def get_local_linear_range(self): """Return the total number of work-items in the work-group.""" diff --git a/numba_dpex/tests/experimental/test_index_space_ids.py b/numba_dpex/tests/experimental/test_index_space_ids.py index fd7dd3e646..faacb0c29f 100644 --- a/numba_dpex/tests/experimental/test_index_space_ids.py +++ b/numba_dpex/tests/experimental/test_index_space_ids.py @@ -11,7 +11,8 @@ import numba_dpex as dpex import numba_dpex.experimental as dpex_exp -from numba_dpex.kernel_api import Item, NdItem +from numba_dpex.kernel_api import Item, NdItem, NdRange +from numba_dpex.kernel_api import call_kernel as kapi_call_kernel from numba_dpex.tests._helper import skip_windows _SIZE = 16 @@ -63,6 +64,24 @@ def set_local_ones_nd_item(nd_item: NdItem, a): a[i] = 1 +def _get_group_id_driver(nditem: NdItem, a): + i = nditem.get_global_id(0) + g = nditem.get_group() + a[i] = g.get_group_id(0) + + +def _get_group_range_driver(nditem: NdItem, a): + i = nditem.get_global_id(0) + g = nditem.get_group() + a[i] = g.get_group_range(0) + + +def _get_group_local_range_driver(nditem: NdItem, a): + i = nditem.get_global_id(0) + g = nditem.get_group() + a[i] = g.get_local_range(0) + + def test_item_get_id(): a = dpnp.zeros(_SIZE, dtype=dpnp.float32) dpex_exp.call_kernel(set_ones_item, dpex.Range(a.size), a) @@ -155,6 +174,71 @@ def test_no_item(): ) +def test_get_group_id(): + + global_size = 100 + group_size = 20 + num_groups = global_size // group_size + + a = dpnp.empty(global_size, dtype=dpnp.int32) + ka = dpnp.empty(global_size, dtype=dpnp.int32) + expected = np.empty(global_size, dtype=np.int32) + ndrange = NdRange((global_size,), (group_size,)) + dpex_exp.call_kernel(dpex_exp.kernel(_get_group_id_driver), ndrange, a) + kapi_call_kernel(_get_group_id_driver, ndrange, ka) + + for gid in range(num_groups): + for lid in range(group_size): + expected[gid * group_size + lid] = gid + + assert np.array_equal(a.asnumpy(), expected) + assert np.array_equal(ka.asnumpy(), expected) + + +def test_get_group_range(): + + global_size = 100 + group_size = 20 + num_groups = global_size // group_size + + a = dpnp.empty(global_size, dtype=dpnp.int32) + ka = dpnp.empty(global_size, dtype=dpnp.int32) + expected = np.empty(global_size, dtype=np.int32) + ndrange = NdRange((global_size,), (group_size,)) + dpex_exp.call_kernel(dpex_exp.kernel(_get_group_range_driver), ndrange, a) + kapi_call_kernel(_get_group_range_driver, ndrange, ka) + + for gid in range(num_groups): + for lid in range(group_size): + expected[gid * group_size + lid] = num_groups + + assert np.array_equal(a.asnumpy(), expected) + assert np.array_equal(ka.asnumpy(), expected) + + +def test_get_group_local_range(): + + global_size = 100 + group_size = 20 + num_groups = global_size // group_size + + a = dpnp.empty(global_size, dtype=dpnp.int32) + ka = dpnp.empty(global_size, dtype=dpnp.int32) + expected = np.empty(global_size, dtype=np.int32) + ndrange = NdRange((global_size,), (group_size,)) + dpex_exp.call_kernel( + dpex_exp.kernel(_get_group_local_range_driver), ndrange, a + ) + kapi_call_kernel(_get_group_local_range_driver, ndrange, ka) + + for gid in range(num_groups): + for lid in range(group_size): + expected[gid * group_size + lid] = group_size + + assert np.array_equal(a.asnumpy(), expected) + assert np.array_equal(ka.asnumpy(), expected) + + I_SIZE, J_SIZE, K_SIZE = 2, 3, 4 From a106444863ac3510d848eff38ec064c017769bfe Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Tue, 13 Feb 2024 14:09:40 -0600 Subject: [PATCH 2/2] Skip failing test cases on Windows for now. --- numba_dpex/tests/experimental/test_index_space_ids.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/numba_dpex/tests/experimental/test_index_space_ids.py b/numba_dpex/tests/experimental/test_index_space_ids.py index faacb0c29f..54ceca8807 100644 --- a/numba_dpex/tests/experimental/test_index_space_ids.py +++ b/numba_dpex/tests/experimental/test_index_space_ids.py @@ -174,8 +174,9 @@ def test_no_item(): ) +# TODO: https://github.com/IntelPython/numba-dpex/issues/1308 +@skip_windows def test_get_group_id(): - global_size = 100 group_size = 20 num_groups = global_size // group_size @@ -195,8 +196,9 @@ def test_get_group_id(): assert np.array_equal(ka.asnumpy(), expected) +# TODO: https://github.com/IntelPython/numba-dpex/issues/1308 +@skip_windows def test_get_group_range(): - global_size = 100 group_size = 20 num_groups = global_size // group_size @@ -216,8 +218,9 @@ def test_get_group_range(): assert np.array_equal(ka.asnumpy(), expected) +# TODO: https://github.com/IntelPython/numba-dpex/issues/1308 +@skip_windows def test_get_group_local_range(): - global_size = 100 group_size = 20 num_groups = global_size // group_size