Skip to content

Commit

Permalink
Fix generation of float4 when float3 requested
Browse files Browse the repository at this point in the history
Closes gh-922
  • Loading branch information
inducer committed Feb 20, 2025
1 parent d32e610 commit d783acc
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 22 deletions.
65 changes: 45 additions & 20 deletions loopy/kernel/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@

if TYPE_CHECKING:
from pymbolic import ArithmeticExpression
from pymbolic.typing import Integer

from loopy.codegen import VectorizationInfo
from loopy.kernel import LoopKernel
Expand Down Expand Up @@ -1103,43 +1104,67 @@ def none_pass_mapper(s: Expression | None) -> Expression | None:
else:
return self

def _vector_axis_index(self) -> int | None:
if self.dim_tags is None or self.shape is None:
return None

vec_axes = [
i for i, dim_tag in enumerate(self.dim_tags)
if isinstance(dim_tag, VectorArrayDimTag)
]
if len(vec_axes) > 1:
raise LoopyError("more than one axis of '{self.name}' is tagged 'vec'")

if not vec_axes:
return None

iaxis, = vec_axes
return iaxis

def vector_length(self) -> Integer:
iaxis = self._vector_axis_index()
if iaxis is None:
return 1

assert isinstance(self.shape, tuple)

shape_i = self.shape[iaxis]
if not is_integer(shape_i):
raise LoopyError("shape of '%s' has non-constant-integer "
"length for vector axis %d (0-based)" % (
self.name, iaxis))

return shape_i

def vector_size(self, target: TargetBase) -> int:
"""Return the size of the vector type used for the array
divided by the basic data type.
Note: For 3-vectors, this will be 4.
"""

if self.dim_tags is None or self.shape is None:
iaxis = self._vector_axis_index()
if iaxis is None:
return 1

assert isinstance(self.shape, tuple)
assert isinstance(self.dtype, LoopyType)

saw_vec_tag = False

for i, dim_tag in enumerate(self.dim_tags):
if isinstance(dim_tag, VectorArrayDimTag):
if saw_vec_tag:
raise LoopyError("more than one axis of '{self.name}' "
"is tagged 'vec'")
saw_vec_tag = True
shape_i = self.shape[iaxis]
if not is_integer(shape_i):
raise LoopyError("shape of '%s' has non-constant-integer "
"length for vector axis %d (0-based)" % (
self.name, iaxis))

shape_i = self.shape[i]
if not is_integer(shape_i):
raise LoopyError("shape of '%s' has non-constant-integer "
"length for vector axis %d (0-based)" % (
self.name, i))

vec_dtype = target.vector_dtype(self.dtype, shape_i)

return int(vec_dtype.itemsize) // int(self.dtype.itemsize)
if self.dim_tags is None or self.shape is None:
return 1

return 1
vec_dtype = target.vector_dtype(self.dtype, shape_i)

return int(vec_dtype.itemsize) // int(self.dtype.itemsize)

# }}}


def drop_vec_dims(
dim_tags: tuple[ArrayDimImplementationTag, ...],
t: tuple[T, ...]) -> tuple[T, ...]:
Expand Down
2 changes: 1 addition & 1 deletion loopy/target/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ def wrap_global_constant(self, decl: Declarator) -> Declarator:
def get_array_base_declarator(self, ary: ArrayBase) -> Declarator:
dtype = ary.dtype

vec_size = ary.vector_size(self.target)
vec_size = ary.vector_length()
if vec_size > 1:
dtype = self.target.vector_dtype(dtype, vec_size)

Expand Down
2 changes: 1 addition & 1 deletion loopy/target/opencl.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ def wrap_global_constant(self, decl: Declarator) -> Declarator:
def get_array_base_declarator(self, ary: ArrayBase) -> Declarator:
dtype = ary.dtype

vec_size = ary.vector_size(self.target)
vec_size = ary.vector_length()
if vec_size > 1:
dtype = self.target.vector_dtype(dtype, vec_size)

Expand Down
23 changes: 23 additions & 0 deletions test/test_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,29 @@ def test_cl_vectorize_ternary(ctx_factory):
assert np.allclose(result, result_ref)


def test_float3():
# https://github.com/inducer/loopy/issues/922
knl = lp.make_kernel(
"{ [i]: 0<=i<n }",
"""
out[i] = a if i == 0 else b
"""
)
vec_size = 3
knl = lp.split_array_axis(knl, "out", 0, vec_size)
knl = lp.split_iname(knl, "i", vec_size)
knl = lp.tag_inames(knl, {"i_inner": "vec"})
knl = lp.tag_array_axes(knl, "out", "c,vec")
knl = lp.assume(knl, f"n % {vec_size} = 0 and n>0")

knl = lp.add_and_infer_dtypes(knl,
{"a": np.dtype(np.float32), "b": np.dtype(np.float32)})

device_code = lp.generate_code_v2(knl).device_code()

assert "float3" in device_code


if __name__ == "__main__":
import sys
if len(sys.argv) > 1:
Expand Down

0 comments on commit d783acc

Please sign in to comment.