Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: (csl) Switch dsds to use affine maps #3657

Merged
merged 7 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions tests/filecheck/backend/csl/print_csl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ csl.func @builtins() {
%u32_pointer = "csl.addressof"(%u32_value) : (ui32) -> !csl.ptr<ui32, #csl<ptr_kind single>, #csl<ptr_const var>>

%A = memref.get_global @A : memref<24xf32>
%dsd_2d = "csl.get_mem_dsd"(%A, %i32_value, %i32_value) <{"strides" = [3, 4], "offsets" = [1, 2]}> : (memref<24xf32>, si32, si32) -> !csl<dsd mem4d_dsd>
%dsd_2d = "csl.get_mem_dsd"(%A, %i32_value, %i32_value) <{"tensor_access" = affine_map<(d0, d1) -> (((d0 * 3) + 1), ((d1 * 4) + 2))>}> : (memref<24xf32>, si32, si32) -> !csl<dsd mem4d_dsd>
%dest_dsd = "csl.get_mem_dsd"(%A, %i32_value) : (memref<24xf32>, si32) -> !csl<dsd mem1d_dsd>
%src_dsd1 = "csl.get_mem_dsd"(%A, %i32_value) : (memref<24xf32>, si32) -> !csl<dsd mem1d_dsd>
%src_dsd2 = "csl.get_mem_dsd"(%A, %i32_value) : (memref<24xf32>, si32) -> !csl<dsd mem1d_dsd>
Expand All @@ -426,7 +426,9 @@ csl.func @builtins() {
%fabin_dsd = "csl.get_fab_dsd"(%i32_value) <{"fabric_color" = 2 : ui5 , "queue_id" = 0 : i3}> : (si32) -> !csl<dsd fabin_dsd>
%fabout_dsd = "csl.get_fab_dsd"(%i32_value) <{"fabric_color" = 3 : ui5 , "queue_id" = 1 : i3, "control"= true, "wavelet_index_offset" = false}>: (si32) -> !csl<dsd fabout_dsd>

%zero_stride_dsd = "csl.get_mem_dsd"(%A, %i16_value, %i16_value, %i16_value) <{"strides" = [0 : si16, 0 : si16, 1 : si16]}> : (memref<24xf32>, si16, si16, si16) -> !csl<dsd mem4d_dsd>
%zero_stride_dsd = "csl.get_mem_dsd"(%A, %i16_value, %i16_value, %i16_value) <{"tensor_access" = affine_map<(d0, d1, d2) -> (d2)>}> : (memref<24xf32>, si16, si16, si16) -> !csl<dsd mem4d_dsd>
%B = memref.get_global @B : memref<3x64xf32>
%oned_access_into_twod = "csl.get_mem_dsd"(%B, %i16_value) <{"tensor_access" = affine_map<(d0) -> (1, d0)>}> : (memref<3x64xf32>, si16) -> !csl<dsd mem1d_dsd>

"csl.add16"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl<dsd mem1d_dsd>, !csl<dsd mem1d_dsd>, !csl<dsd mem1d_dsd>) -> ()
"csl.addc16"(%dest_dsd, %i16_value, %src_dsd1) : (!csl<dsd mem1d_dsd>, si16, !csl<dsd mem1d_dsd>) -> ()
Expand Down Expand Up @@ -795,7 +797,7 @@ csl.func @builtins() {
// CHECK-NEXT: var u16_pointer : *u16 = &u16_value;
// CHECK-NEXT: var u32_pointer : *u32 = &u32_value;
// CHECK-NEXT: const dsd_2d : mem4d_dsd = @get_dsd( mem4d_dsd, .{
// CHECK-NEXT: .tensor_access = | d0, d1 | { i32_value, i32_value } -> A[ 3 * d0 + 1, 4 * d1 + 2 ]
// CHECK-NEXT: .tensor_access = | d0, d1 | { i32_value, i32_value } -> A[ ((d0 * 3) + 1), ((d1 * 4) + 2) ]
// CHECK-NEXT: });
// CHECK-NEXT: const dest_dsd : mem1d_dsd = @get_dsd( mem1d_dsd, .{
// CHECK-NEXT: .tensor_access = | d0 | { i32_value } -> A[ d0 ]
Expand Down Expand Up @@ -825,6 +827,9 @@ csl.func @builtins() {
// CHECK-NEXT: const zero_stride_dsd : mem4d_dsd = @get_dsd( mem4d_dsd, .{
// CHECK-NEXT: .tensor_access = | d0, d1, d2 | { i16_value, i16_value, i16_value } -> A[ d2 ]
// CHECK-NEXT: });
// CHECK-NEXT: const oned_access_into_twod : mem1d_dsd = @get_dsd( mem1d_dsd, .{
// CHECK-NEXT: .tensor_access = | d0 | { i16_value } -> B[ 1, d0 ]
// CHECK-NEXT: });
// CHECK-NEXT: @add16(dest_dsd, src_dsd1, src_dsd2);
// CHECK-NEXT: @addc16(dest_dsd, i16_value, src_dsd1);
// CHECK-NEXT: @and16(dest_dsd, u16_value, src_dsd1);
Expand Down
14 changes: 7 additions & 7 deletions tests/filecheck/dialects/csl/csl-canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,20 @@ builtin.module {
%1 = "csl.zeros"() : () -> memref<512xf32>
%2 = "csl.get_mem_dsd"(%1, %0) : (memref<512xf32>, i16) -> !csl<dsd mem1d_dsd>

%3 = arith.constant 1 : si16
%4 = "csl.increment_dsd_offset"(%2, %3) <{"elem_type" = f32}> : (!csl<dsd mem1d_dsd>, si16) -> !csl<dsd mem1d_dsd>
%int8 = arith.constant 3 : si8
%3 = "csl.set_dsd_stride"(%2, %int8) : (!csl<dsd mem1d_dsd>, si8) -> !csl<dsd mem1d_dsd>

%5 = arith.constant 510 : ui16
%6 = "csl.set_dsd_length"(%4, %5) : (!csl<dsd mem1d_dsd>, ui16) -> !csl<dsd mem1d_dsd>
%4 = arith.constant 1 : si16
%5 = "csl.increment_dsd_offset"(%3, %4) <{"elem_type" = f32}> : (!csl<dsd mem1d_dsd>, si16) -> !csl<dsd mem1d_dsd>

%int8 = arith.constant 1 : si8
%7 = "csl.set_dsd_stride"(%6, %int8) : (!csl<dsd mem1d_dsd>, si8) -> !csl<dsd mem1d_dsd>
%6 = arith.constant 510 : ui16
%7 = "csl.set_dsd_length"(%5, %6) : (!csl<dsd mem1d_dsd>, ui16) -> !csl<dsd mem1d_dsd>

"test.op"(%7) : (!csl<dsd mem1d_dsd>) -> ()

// CHECK-NEXT: %0 = "csl.zeros"() : () -> memref<512xf32>
// CHECK-NEXT: %1 = arith.constant 510 : ui16
// CHECK-NEXT: %2 = "csl.get_mem_dsd"(%0, %1) <{"offsets" = [1 : si16], "strides" = [1 : si8]}> : (memref<512xf32>, ui16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %2 = "csl.get_mem_dsd"(%0, %1) <{"tensor_access" = affine_map<(d0) -> (((d0 * 3) + 1))>}> : (memref<512xf32>, ui16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: "test.op"(%2) : (!csl<dsd mem1d_dsd>) -> ()


Expand Down
6 changes: 3 additions & 3 deletions tests/filecheck/dialects/csl/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ csl.func @initialize() {
%dir = "csl.get_dir"() <{"dir" = #csl<dir_kind north>}> : () -> !csl.direction

%dsd_1d = "csl.get_mem_dsd"(%arr, %scalar) : (memref<10xf32>, i32) -> !csl<dsd mem1d_dsd>
%dsd_2d = "csl.get_mem_dsd"(%arr, %scalar, %scalar) <{"strides" = [3, 4], "offsets" = [1, 2]}> : (memref<10xf32>, i32, i32) -> !csl<dsd mem4d_dsd>
%dsd_2d = "csl.get_mem_dsd"(%arr, %scalar, %scalar) <{"tensor_access" = affine_map<(d0, d1) -> (((d0 * 3) + 1), ((d1 * 4) + 2))>}> : (memref<10xf32>, i32, i32) -> !csl<dsd mem4d_dsd>
%dsd_3d = "csl.get_mem_dsd"(%arr, %scalar, %scalar, %scalar) : (memref<10xf32>, i32, i32, i32) -> !csl<dsd mem4d_dsd>
%dsd_4d = "csl.get_mem_dsd"(%arr, %scalar, %scalar, %scalar, %scalar) : (memref<10xf32>, i32, i32, i32, i32) -> !csl<dsd mem4d_dsd>
%dsd_1d1 = "csl.set_dsd_base_addr"(%dsd_1d, %many_arr_ptr) : (!csl<dsd mem1d_dsd>, !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const const>>) -> !csl<dsd mem1d_dsd>
Expand Down Expand Up @@ -392,7 +392,7 @@ csl.func @builtins() {
// CHECK-NEXT: %function_ptr = "csl.addressof_fn"() <{"fn_name" = @initialize}> : () -> !csl.ptr<() -> (), #csl<ptr_kind single>, #csl<ptr_const const>>
// CHECK-NEXT: %dir = "csl.get_dir"() <{"dir" = #csl<dir_kind north>}> : () -> !csl.direction
// CHECK-NEXT: %dsd_1d = "csl.get_mem_dsd"(%arr, %scalar) : (memref<10xf32>, i32) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %dsd_2d = "csl.get_mem_dsd"(%arr, %scalar, %scalar) <{"strides" = [3 : i64, 4 : i64], "offsets" = [1 : i64, 2 : i64]}> : (memref<10xf32>, i32, i32) -> !csl<dsd mem4d_dsd>
// CHECK-NEXT: %dsd_2d = "csl.get_mem_dsd"(%arr, %scalar, %scalar) <{"tensor_access" = affine_map<(d0, d1) -> (((d0 * 3) + 1), ((d1 * 4) + 2))>}> : (memref<10xf32>, i32, i32) -> !csl<dsd mem4d_dsd>
// CHECK-NEXT: %dsd_3d = "csl.get_mem_dsd"(%arr, %scalar, %scalar, %scalar) : (memref<10xf32>, i32, i32, i32) -> !csl<dsd mem4d_dsd>
// CHECK-NEXT: %dsd_4d = "csl.get_mem_dsd"(%arr, %scalar, %scalar, %scalar, %scalar) : (memref<10xf32>, i32, i32, i32, i32) -> !csl<dsd mem4d_dsd>
// CHECK-NEXT: %dsd_1d1 = "csl.set_dsd_base_addr"(%dsd_1d, %many_arr_ptr) : (!csl<dsd mem1d_dsd>, !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const const>>) -> !csl<dsd mem1d_dsd>
Expand Down Expand Up @@ -639,7 +639,7 @@ csl.func @builtins() {
// CHECK-GENERIC-NEXT: %function_ptr = "csl.addressof_fn"() <{"fn_name" = @initialize}> : () -> !csl.ptr<() -> (), #csl<ptr_kind single>, #csl<ptr_const const>>
// CHECK-GENERIC-NEXT: %dir = "csl.get_dir"() <{"dir" = #csl<dir_kind north>}> : () -> !csl.direction
// CHECK-GENERIC-NEXT: %dsd_1d = "csl.get_mem_dsd"(%arr, %scalar) : (memref<10xf32>, i32) -> !csl<dsd mem1d_dsd>
// CHECK-GENERIC-NEXT: %dsd_2d = "csl.get_mem_dsd"(%arr, %scalar, %scalar) <{"strides" = [3 : i64, 4 : i64], "offsets" = [1 : i64, 2 : i64]}> : (memref<10xf32>, i32, i32) -> !csl<dsd mem4d_dsd>
// CHECK-GENERIC-NEXT: %dsd_2d = "csl.get_mem_dsd"(%arr, %scalar, %scalar) <{"tensor_access" = affine_map<(d0, d1) -> (((d0 * 3) + 1), ((d1 * 4) + 2))>}> : (memref<10xf32>, i32, i32) -> !csl<dsd mem4d_dsd>
// CHECK-GENERIC-NEXT: %dsd_3d = "csl.get_mem_dsd"(%arr, %scalar, %scalar, %scalar) : (memref<10xf32>, i32, i32, i32) -> !csl<dsd mem4d_dsd>
// CHECK-GENERIC-NEXT: %dsd_4d = "csl.get_mem_dsd"(%arr, %scalar, %scalar, %scalar, %scalar) : (memref<10xf32>, i32, i32, i32, i32) -> !csl<dsd mem4d_dsd>
// CHECK-GENERIC-NEXT: %dsd_1d1 = "csl.set_dsd_base_addr"(%dsd_1d, %many_arr_ptr) : (!csl<dsd mem1d_dsd>, !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const const>>) -> !csl<dsd mem1d_dsd>
Expand Down
4 changes: 2 additions & 2 deletions tests/filecheck/transforms/lower-csl-stencil.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ builtin.module {
// CHECK-NEXT: %offset_1 = arith.index_cast %offset : i16 to index
// CHECK-NEXT: %42 = memref.subview %accumulator[%offset_1] [255] [1] : memref<510xf32> to memref<255xf32, strided<[1], offset: ?>>
// CHECK-NEXT: %43 = arith.constant 4 : i16
// CHECK-NEXT: %44 = "csl.get_mem_dsd"(%accumulator, %43, %29, %31) <{"strides" = [0 : i16, 0 : i16, 1 : i16]}> : (memref<510xf32>, i16, i16, i16) -> !csl<dsd mem4d_dsd>
// CHECK-NEXT: %44 = "csl.get_mem_dsd"(%accumulator, %43, %29, %31) <{"tensor_access" = affine_map<(d0, d1, d2) -> (d2)>}> : (memref<510xf32>, i16, i16, i16) -> !csl<dsd mem4d_dsd>
// CHECK-NEXT: %45 = arith.index_cast %offset_1 : index to si16
// CHECK-NEXT: %46 = "csl.increment_dsd_offset"(%44, %45) <{"elem_type" = f32}> : (!csl<dsd mem4d_dsd>, si16) -> !csl<dsd mem4d_dsd>
// CHECK-NEXT: %47 = "csl.member_call"(%34) <{"field" = "getRecvBufDsd"}> : (!csl.imported_module) -> !csl<dsd mem4d_dsd>
Expand Down Expand Up @@ -308,7 +308,7 @@ builtin.module {
// CHECK-NEXT: %offset_3 = arith.index_cast %offset_2 : i16 to index
// CHECK-NEXT: %88 = memref.subview %accumulator_1[%offset_3] [510] [1] : memref<510xf32> to memref<510xf32, strided<[1], offset: ?>>
// CHECK-NEXT: %89 = arith.constant 4 : i16
// CHECK-NEXT: %90 = "csl.get_mem_dsd"(%accumulator_1, %89, %arg3_1, %arg5_1) <{"strides" = [0 : i16, 0 : i16, 1 : i16]}> : (memref<510xf32>, i16, i16, i16) -> !csl<dsd mem4d_dsd>
// CHECK-NEXT: %90 = "csl.get_mem_dsd"(%accumulator_1, %89, %arg3_1, %arg5_1) <{"tensor_access" = affine_map<(d0, d1, d2) -> (d2)>}> : (memref<510xf32>, i16, i16, i16) -> !csl<dsd mem4d_dsd>
// CHECK-NEXT: %91 = arith.index_cast %offset_3 : index to si16
// CHECK-NEXT: %92 = "csl.increment_dsd_offset"(%90, %91) <{"elem_type" = f32}> : (!csl<dsd mem4d_dsd>, si16) -> !csl<dsd mem4d_dsd>
// CHECK-NEXT: %93 = "csl.member_call"(%69) <{"field" = "getRecvBufDsd"}> : (!csl.imported_module) -> !csl<dsd mem4d_dsd>
Expand Down
31 changes: 9 additions & 22 deletions xdsl/backend/csl/print_csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
i1,
)
from xdsl.ir import Attribute, Block, Operation, OpResult, Region, SSAValue
from xdsl.ir.affine import AffineMap
from xdsl.irdl import Operand
from xdsl.traits import is_side_effect_free
from xdsl.utils.comparisons import to_unsigned
Expand Down Expand Up @@ -755,36 +756,22 @@ def print_block(self, body: Block):
inner.print(f"@rpc(@get_data_task_id({id}));")
case csl.GetMemDsdOp(
base_addr=base_addr,
offsets=offsets,
strides=strides,
tensor_access=tensor_access,
sizes=sizes,
result=result,
):
sizes_str = ", ".join(
self._get_variable_name_for(size) for size in sizes
)
t_accesses = (
tensor_access.data
if tensor_access
else AffineMap.identity(len(sizes))
)

ind_vars = ["d" + str(i) for i in range(len(sizes))]
ind_vars_str = ", ".join(ind_vars)
accesses = [
(
f"{str(s)} * "
if strides and (s := strides.data[i].value.data) != 1
else ""
)
+ ind_vars[i]
+ (f" + {str(offsets.data[i].value.data)}" if offsets else "")
for i in range(len(ind_vars))
]
if strides and 0 in (
strides_data := [s.value.data for s in strides.data]
):
non_zero_stride_idx = [
idx for idx, sd in enumerate(strides_data) if sd != 0
]
# if all except one strides are 0, print only the non-0 part (default to printing all dims)
if len(non_zero_stride_idx) == 1:
accesses = [accesses[non_zero_stride_idx[0]]]
accesses_str = ", ".join(accesses)
accesses_str = ", ".join(str(expr) for expr in t_accesses.results)
self.print(
f"{self._var_use(result)} = @get_dsd( {self.mlir_type_to_csl_type(result.type)}, .{{"
)
Expand Down
19 changes: 9 additions & 10 deletions xdsl/dialects/csl/csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from xdsl.dialects import builtin
from xdsl.dialects.builtin import (
AffineMapAttr,
AnyFloatAttr,
AnyFloatAttrConstr,
AnyIntegerAttr,
Expand Down Expand Up @@ -1143,8 +1144,7 @@ class GetMemDsdOp(_GetDsdOp):

name = "csl.get_mem_dsd"
base_addr = operand_def(base(MemRefType[Attribute]) | base(TensorType[Attribute]))
offsets = opt_prop_def(ArrayAttr[AnyIntegerAttr])
strides = opt_prop_def(ArrayAttr[AnyIntegerAttr])
tensor_access = opt_prop_def(AffineMapAttr)

traits = traits_def(
Pure(),
Expand All @@ -1166,14 +1166,13 @@ def verify_(self) -> None:
raise VerifyException(
"DSD of type mem4d_dsd must have between 1 and 4 dimensions"
)
if self.offsets is not None and len(self.offsets) != len(self.sizes):
raise VerifyException(
"Dimensions of offsets must match dimensions of sizes"
)
if self.strides is not None and len(self.strides) != len(self.sizes):
raise VerifyException(
"Dimensions of strides must match dimensions of sizes"
)
if self.tensor_access:
if len(self.sizes) != self.tensor_access.data.num_dims:
raise VerifyException(
"Dsd must have sizes specified for each dimension of the affine map"
)
if self.tensor_access.data.num_symbols != 0:
raise VerifyException("Symbols on affine map not supported")


@irdl_op_definition
Expand Down
36 changes: 27 additions & 9 deletions xdsl/transforms/canonicalization_patterns/csl.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from xdsl.dialects import arith
from xdsl.dialects.builtin import AnyIntegerAttrConstr, ArrayAttr
from xdsl.dialects.builtin import (
AffineMapAttr,
AnyIntegerAttr,
)
from xdsl.dialects.csl import csl
from xdsl.ir import OpResult
from xdsl.ir.affine import AffineMap
from xdsl.pattern_rewriter import (
PatternRewriter,
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.utils.isattr import isattr
from xdsl.utils.hints import isa


class GetDsdAndOffsetFolding(RewritePattern):
Expand All @@ -23,20 +27,28 @@ def match_and_rewrite(self, op: csl.GetMemDsdOp, rewriter: PatternRewriter) -> N
):
return
# only works on 1d
if op.offsets and len(op.offsets) > 1:
if len(op.sizes) > 1:
return

# check if we can promote arith.const to property
if (
isinstance(offset_op.offset, OpResult)
and isinstance(cnst := offset_op.offset.op, arith.ConstantOp)
and isattr(cnst.value, AnyIntegerAttrConstr)
and isa(attr_val := cnst.value, AnyIntegerAttr)
):
tensor_access = AffineMap.from_callable(
lambda x: (x + attr_val.value.data,)
)
if op.tensor_access:
tensor_access = tensor_access.compose(op.tensor_access.data)
rewriter.replace_matched_op(
new_op := csl.GetMemDsdOp.build(
operands=[op.base_addr, op.sizes],
result_types=op.result_types,
properties={**op.properties, "offsets": ArrayAttr([cnst.value])},
properties={
**op.properties,
"tensor_access": AffineMapAttr(tensor_access),
},
)
)
rewriter.replace_op(offset_op, [], new_results=[new_op.result])
Expand Down Expand Up @@ -81,21 +93,27 @@ def match_and_rewrite(self, op: csl.GetMemDsdOp, rewriter: PatternRewriter) -> N
stride_op := next(iter(op.result.uses)).operation, csl.SetDsdStrideOp
):
return
# only works on 1d
if op.offsets and len(op.offsets) > 1:
# only works on 1d and default (unspecified) tensor_access
if len(op.sizes) > 1 or op.tensor_access:
return

dk949 marked this conversation as resolved.
Show resolved Hide resolved
# check if we can promote arith.const to property
if (
isinstance(stride_op.stride, OpResult)
and isinstance(cnst := stride_op.stride.op, arith.ConstantOp)
and isattr(cnst.value, AnyIntegerAttrConstr)
and isa(attr_val := cnst.value, AnyIntegerAttr)
):
tensor_access = AffineMap.from_callable(
lambda x: (x * attr_val.value.data,)
)
rewriter.replace_matched_op(
dk949 marked this conversation as resolved.
Show resolved Hide resolved
new_op := csl.GetMemDsdOp.build(
operands=[op.base_addr, op.sizes],
result_types=op.result_types,
properties={**op.properties, "strides": ArrayAttr([cnst.value])},
properties={
**op.properties,
"tensor_access": AffineMapAttr(tensor_access),
},
)
)
rewriter.replace_op(stride_op, [], new_results=[new_op.result])
Expand Down
9 changes: 7 additions & 2 deletions xdsl/transforms/lower_csl_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from xdsl.context import MLContext
from xdsl.dialects import arith, func, memref, stencil
from xdsl.dialects.builtin import (
AffineMapAttr,
AnyFloatAttr,
AnyMemRefType,
ArrayAttr,
DenseIntOrFPElementsAttr,
Float16Type,
Float32Type,
Expand All @@ -28,6 +28,7 @@
Region,
SSAValue,
)
from xdsl.ir.affine import AffineMap
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
Expand Down Expand Up @@ -458,7 +459,11 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter,
acc_dsd = csl.GetMemDsdOp.build(
operands=[alloc, [direction_count, pattern, chunk_size]],
result_types=[dsd_t],
properties={"strides": ArrayAttr([IntegerAttr(i, 16) for i in [0, 0, 1]])},
properties={
"tensor_access": AffineMapAttr(
AffineMap.from_callable(lambda x, y, z: (z,))
)
},
)
new_acc = acc_dsd

Expand Down
Loading