Skip to content

Commit

Permalink
Merge pull request #1330 from IntelPython/feature/group_id_overloads
Browse files Browse the repository at this point in the history
Adds overloads for group indexing functions
  • Loading branch information
Diptorup Deb authored Feb 13, 2024
2 parents 6290156 + a106444 commit e998f72
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,26 @@ def _intrinsic_spirv_workgroup_size(
)


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_spirv_workgroup_id(
ty_context, ty_dim # pylint: disable=unused-argument
):
"""Generates instruction to get index from BuiltInWorkgroupId."""
return _intrinsic_spirv_global_index_const(
ty_context, ty_dim, "BuiltInWorkgroupId"
)


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_spirv_numworkgroups(
ty_context, ty_dim # pylint: disable=unused-argument
):
"""Generates instruction to get index from BuiltInNumWorkgroups."""
return _intrinsic_spirv_global_index_const(
ty_context, ty_dim, "BuiltInNumWorkgroups"
)


def generate_index_overload(_type, _intrinsic):
"""Generates overload for the index method that generates specific IR from
provided intrinsic."""
Expand Down Expand Up @@ -167,6 +187,9 @@ def ol_item_get_index_impl(item, dim):
(NdItemType, "get_local_id", _intrinsic_spirv_local_invocation_id),
(NdItemType, "get_global_range", _intrinsic_spirv_global_size),
(NdItemType, "get_local_range", _intrinsic_spirv_workgroup_size),
(GroupType, "get_group_id", _intrinsic_spirv_workgroup_id),
(GroupType, "get_group_range", _intrinsic_spirv_numworkgroups),
(GroupType, "get_local_range", _intrinsic_spirv_workgroup_size),
]

for index_overload in _index_const_overload_methods:
Expand Down
18 changes: 10 additions & 8 deletions numba_dpex/kernel_api/index_space_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,11 @@ def get_group_linear_id(self):
+ (self._index[2])
)

def get_group_range(self):
"""Returns a range representing the number of groups in the nd-range."""
return self._group_range
def get_group_range(self, dim):
"""Returns a the extent of the range representing the number of groups
in the nd-range for a specified dimension.
"""
return self._group_range[dim]

def get_group_linear_range(self):
"""Return the total number of work-groups in the nd_range."""
Expand All @@ -64,12 +66,12 @@ def get_group_linear_range(self):

return num_wg

def get_local_range(self):
"""Returns a SYCL range representing all dimensions of the local
range. This local range may have been provided by the programmer, or
chosen by the SYCL runtime.
def get_local_range(self, dim):
"""Returns the extent of the SYCL range representing all dimensions
of the local range for a specified dimension. This local range may
have been provided by the programmer, or chosen by the SYCL runtime.
"""
return self._local_range
return self._local_range[dim]

def get_local_linear_range(self):
"""Return the total number of work-items in the work-group."""
Expand Down
89 changes: 88 additions & 1 deletion numba_dpex/tests/experimental/test_index_space_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

import numba_dpex as dpex
import numba_dpex.experimental as dpex_exp
from numba_dpex.kernel_api import Item, NdItem
from numba_dpex.kernel_api import Item, NdItem, NdRange
from numba_dpex.kernel_api import call_kernel as kapi_call_kernel
from numba_dpex.tests._helper import skip_windows

_SIZE = 16
Expand Down Expand Up @@ -63,6 +64,24 @@ def set_local_ones_nd_item(nd_item: NdItem, a):
a[i] = 1


def _get_group_id_driver(nditem: NdItem, a):
i = nditem.get_global_id(0)
g = nditem.get_group()
a[i] = g.get_group_id(0)


def _get_group_range_driver(nditem: NdItem, a):
i = nditem.get_global_id(0)
g = nditem.get_group()
a[i] = g.get_group_range(0)


def _get_group_local_range_driver(nditem: NdItem, a):
i = nditem.get_global_id(0)
g = nditem.get_group()
a[i] = g.get_local_range(0)


def test_item_get_id():
a = dpnp.zeros(_SIZE, dtype=dpnp.float32)
dpex_exp.call_kernel(set_ones_item, dpex.Range(a.size), a)
Expand Down Expand Up @@ -155,6 +174,74 @@ def test_no_item():
)


# TODO: https://github.com/IntelPython/numba-dpex/issues/1308
@skip_windows
def test_get_group_id():
global_size = 100
group_size = 20
num_groups = global_size // group_size

a = dpnp.empty(global_size, dtype=dpnp.int32)
ka = dpnp.empty(global_size, dtype=dpnp.int32)
expected = np.empty(global_size, dtype=np.int32)
ndrange = NdRange((global_size,), (group_size,))
dpex_exp.call_kernel(dpex_exp.kernel(_get_group_id_driver), ndrange, a)
kapi_call_kernel(_get_group_id_driver, ndrange, ka)

for gid in range(num_groups):
for lid in range(group_size):
expected[gid * group_size + lid] = gid

assert np.array_equal(a.asnumpy(), expected)
assert np.array_equal(ka.asnumpy(), expected)


# TODO: https://github.com/IntelPython/numba-dpex/issues/1308
@skip_windows
def test_get_group_range():
global_size = 100
group_size = 20
num_groups = global_size // group_size

a = dpnp.empty(global_size, dtype=dpnp.int32)
ka = dpnp.empty(global_size, dtype=dpnp.int32)
expected = np.empty(global_size, dtype=np.int32)
ndrange = NdRange((global_size,), (group_size,))
dpex_exp.call_kernel(dpex_exp.kernel(_get_group_range_driver), ndrange, a)
kapi_call_kernel(_get_group_range_driver, ndrange, ka)

for gid in range(num_groups):
for lid in range(group_size):
expected[gid * group_size + lid] = num_groups

assert np.array_equal(a.asnumpy(), expected)
assert np.array_equal(ka.asnumpy(), expected)


# TODO: https://github.com/IntelPython/numba-dpex/issues/1308
@skip_windows
def test_get_group_local_range():
global_size = 100
group_size = 20
num_groups = global_size // group_size

a = dpnp.empty(global_size, dtype=dpnp.int32)
ka = dpnp.empty(global_size, dtype=dpnp.int32)
expected = np.empty(global_size, dtype=np.int32)
ndrange = NdRange((global_size,), (group_size,))
dpex_exp.call_kernel(
dpex_exp.kernel(_get_group_local_range_driver), ndrange, a
)
kapi_call_kernel(_get_group_local_range_driver, ndrange, ka)

for gid in range(num_groups):
for lid in range(group_size):
expected[gid * group_size + lid] = group_size

assert np.array_equal(a.asnumpy(), expected)
assert np.array_equal(ka.asnumpy(), expected)


I_SIZE, J_SIZE, K_SIZE = 2, 3, 4


Expand Down

0 comments on commit e998f72

Please sign in to comment.