diff --git a/tests/filecheck/backend/csl/print_csl.mlir b/tests/filecheck/backend/csl/print_csl.mlir index c18499ca86..88997b59bb 100644 --- a/tests/filecheck/backend/csl/print_csl.mlir +++ b/tests/filecheck/backend/csl/print_csl.mlir @@ -403,6 +403,8 @@ csl.func @builtins() { %fabin_dsd = "csl.get_fab_dsd"(%i32_value) <{"fabric_color" = 2 : i5 , "queue_id" = 0 : i3}> : (si32) -> !csl %fabout_dsd = "csl.get_fab_dsd"(%i32_value) <{"fabric_color" = 3 : i5 , "queue_id" = 1 : i3, "control"= true, "wavelet_index_offset" = false}>: (si32) -> !csl + %zero_stride_dsd = "csl.get_mem_dsd"(%A, %i16_value, %i16_value, %i16_value) <{"strides" = [0 : si16, 0 : si16, 1 : si16]}> : (memref<24xf32>, si16, si16, si16) -> !csl + "csl.add16"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () "csl.addc16"(%dest_dsd, %i16_value, %src_dsd1) : (!csl, si16, !csl) -> () "csl.and16"(%dest_dsd, %u16_value, %src_dsd1) : (!csl, ui16, !csl) -> () @@ -780,6 +782,9 @@ csl.func @builtins() { // CHECK-NEXT: .wavelet_index_offset = false, // CHECK-NEXT: .control = true, // CHECK-NEXT: }}); +// CHECK-NEXT: const zero_stride_dsd : mem4d_dsd = @get_dsd( mem4d_dsd, .{ +// CHECK-NEXT: .tensor_access = | d0, d1, d2 | { i16_value, i16_value, i16_value } -> A[ d2 ] +// CHECK-NEXT: }); // CHECK-NEXT: @add16(dest_dsd, src_dsd1, src_dsd2); // CHECK-NEXT: @addc16(dest_dsd, i16_value, src_dsd1); // CHECK-NEXT: @and16(dest_dsd, u16_value, src_dsd1); diff --git a/xdsl/backend/csl/print_csl.py b/xdsl/backend/csl/print_csl.py index 2bbc78f943..0a0bfd92e2 100644 --- a/xdsl/backend/csl/print_csl.py +++ b/xdsl/backend/csl/print_csl.py @@ -760,11 +760,24 @@ def print_block(self, body: Block): ind_vars = ["d" + str(i) for i in range(len(sizes))] ind_vars_str = ", ".join(ind_vars) accesses = [ - (f"{str(strides.data[i].value.data)} * " if strides else "") + ( + f"{str(s)} * " + if strides and (s := strides.data[i].value.data) != 1 + else "" + ) + ind_vars[i] + (f" + {str(offsets.data[i].value.data)}" if offsets else "") for i in range(len(ind_vars)) ] + if strides and 0 in ( + strides_data := [s.value.data for s in strides.data] + ): + non_zero_stride_idx = [ + idx for idx, sd in enumerate(strides_data) if sd != 0 + ] + # if all except one strides are 0, print only the non-0 part (default to printing all dims) + if len(non_zero_stride_idx) == 1: + accesses = [accesses[non_zero_stride_idx[0]]] accesses_str = ", ".join(accesses) self.print( f"{self._var_use(result)} = @get_dsd( {self.mlir_type_to_csl_type(result.type)}, .{{"