diff --git a/loopy/kernel/array.py b/loopy/kernel/array.py index 5844143f4..c60f28e84 100644 --- a/loopy/kernel/array.py +++ b/loopy/kernel/array.py @@ -52,6 +52,7 @@ if TYPE_CHECKING: from pymbolic import ArithmeticExpression + from pymbolic.typing import Integer from loopy.codegen import VectorizationInfo from loopy.kernel import LoopKernel @@ -1103,43 +1104,67 @@ def none_pass_mapper(s: Expression | None) -> Expression | None: else: return self + def _vector_axis_index(self) -> int | None: + if self.dim_tags is None or self.shape is None: + return None + + vec_axes = [ + i for i, dim_tag in enumerate(self.dim_tags) + if isinstance(dim_tag, VectorArrayDimTag) + ] + if len(vec_axes) > 1: + raise LoopyError("more than one axis of '{self.name}' is tagged 'vec'") + + if not vec_axes: + return None + + iaxis, = vec_axes + return iaxis + + def vector_length(self) -> Integer: + iaxis = self._vector_axis_index() + if iaxis is None: + return 1 + + assert isinstance(self.shape, tuple) + + shape_i = self.shape[iaxis] + if not is_integer(shape_i): + raise LoopyError("shape of '%s' has non-constant-integer " + "length for vector axis %d (0-based)" % ( + self.name, iaxis)) + + return shape_i + def vector_size(self, target: TargetBase) -> int: """Return the size of the vector type used for the array divided by the basic data type. Note: For 3-vectors, this will be 4. """ - - if self.dim_tags is None or self.shape is None: + iaxis = self._vector_axis_index() + if iaxis is None: return 1 assert isinstance(self.shape, tuple) assert isinstance(self.dtype, LoopyType) - saw_vec_tag = False - - for i, dim_tag in enumerate(self.dim_tags): - if isinstance(dim_tag, VectorArrayDimTag): - if saw_vec_tag: - raise LoopyError("more than one axis of '{self.name}' " - "is tagged 'vec'") - saw_vec_tag = True + shape_i = self.shape[iaxis] + if not is_integer(shape_i): + raise LoopyError("shape of '%s' has non-constant-integer " + "length for vector axis %d (0-based)" % ( + self.name, iaxis)) - shape_i = self.shape[i] - if not is_integer(shape_i): - raise LoopyError("shape of '%s' has non-constant-integer " - "length for vector axis %d (0-based)" % ( - self.name, i)) - - vec_dtype = target.vector_dtype(self.dtype, shape_i) - - return int(vec_dtype.itemsize) // int(self.dtype.itemsize) + if self.dim_tags is None or self.shape is None: + return 1 - return 1 + vec_dtype = target.vector_dtype(self.dtype, shape_i) + return int(vec_dtype.itemsize) // int(self.dtype.itemsize) # }}} + def drop_vec_dims( dim_tags: tuple[ArrayDimImplementationTag, ...], t: tuple[T, ...]) -> tuple[T, ...]: diff --git a/loopy/target/cuda.py b/loopy/target/cuda.py index 50d2ac7fe..7a9b4c11a 100644 --- a/loopy/target/cuda.py +++ b/loopy/target/cuda.py @@ -431,7 +431,7 @@ def wrap_global_constant(self, decl: Declarator) -> Declarator: def get_array_base_declarator(self, ary: ArrayBase) -> Declarator: dtype = ary.dtype - vec_size = ary.vector_size(self.target) + vec_size = ary.vector_length() if vec_size > 1: dtype = self.target.vector_dtype(dtype, vec_size) diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py index 3fe951c4e..07c5b49d0 100644 --- a/loopy/target/opencl.py +++ b/loopy/target/opencl.py @@ -751,7 +751,7 @@ def wrap_global_constant(self, decl: Declarator) -> Declarator: def get_array_base_declarator(self, ary: ArrayBase) -> Declarator: dtype = ary.dtype - vec_size = ary.vector_size(self.target) + vec_size = ary.vector_length() if vec_size > 1: dtype = self.target.vector_dtype(dtype, vec_size) diff --git a/test/test_target.py b/test/test_target.py index 32ca21ed5..fe2ad1d8a 100644 --- a/test/test_target.py +++ b/test/test_target.py @@ -852,6 +852,29 @@ def test_cl_vectorize_ternary(ctx_factory): assert np.allclose(result, result_ref) +def test_float3(): + # https://github.com/inducer/loopy/issues/922 + knl = lp.make_kernel( + "{ [i]: 0<=i0") + + knl = lp.add_and_infer_dtypes(knl, + {"a": np.dtype(np.float32), "b": np.dtype(np.float32)}) + + device_code = lp.generate_code_v2(knl).device_code() + + assert "float3" in device_code + + if __name__ == "__main__": import sys if len(sys.argv) > 1: