Skip to content

Commit

Permalink
Revert "Increase default threshold of TileLargeTensor pass (#19671)" (#19693)
Browse files Browse the repository at this point in the history

This reverts commit 3978ce6.

It may be causing a regression in MI250 SDXL that was not observed on pre-submit.
  • Loading branch information
nirvedhmeshram authored Jan 13, 2025
1 parent 3e34e03 commit 158c636
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 18 deletions.
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Codegen/Common/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ def TileLargeTensorsPass :
];
let options = [
Option<"maxVectorSize", "max-vector-size", "int64_t",
/*default=*/"256",
/*default=*/"64",
"Maximum static size to tile to (i.e. all remaining ops will be smaller)">,
];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,22 @@
// RUN: FileCheck %s

#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @simple_generic(%3: tensor<64x512xf32>, %4: tensor<64x512xf32>, %5: tensor<64x512xf32>) -> tensor<64x512xf32> {
func.func @simple_generic(%3: tensor<64x256xf32>, %4: tensor<64x256xf32>, %5: tensor<64x256xf32>) -> tensor<64x256xf32> {
%6 = linalg.generic {
indexing_maps = [#map, #map, #map],
iterator_types = ["parallel", "parallel"]
} ins(%3, %4 : tensor<64x512xf32>, tensor<64x512xf32>) outs(%5 : tensor<64x512xf32>) {
} ins(%3, %4 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%5 : tensor<64x256xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%7 = arith.addf %in, %in_0 : f32
linalg.yield %7 : f32
} -> tensor<64x512xf32>
return %6 : tensor<64x512xf32>
} -> tensor<64x256xf32>
return %6 : tensor<64x256xf32>
}

// CHECK-LABEL: func.func @simple_generic
// CHECK: scf.for %{{.*}} = %c0 to %c64 step %c1
// CHECK: scf.for %{{.*}} = %c0 to %c512 step %c256
// CHECK: linalg.generic {{.*}} outs({{.*}}: tensor<1x256xf32>)
// CHECK: scf.for %{{.*}} = %c0 to %c256 step %c64
// CHECK: linalg.generic {{.*}} outs({{.*}}: tensor<1x64xf32>)

// -----

Expand Down Expand Up @@ -65,21 +65,21 @@ func.func @in_nested_region(%3: tensor<64x64xf32>, %4: tensor<64x64xf32>, %5: te

// -----

func.func @multiple_use_tilable_op(%3: tensor<64x512xf32>, %4: tensor<64x512xf32>) -> (tensor<64x512xf32>, tensor<512x64xf32>) {
%add_empty = tensor.empty() : tensor<64x512xf32>
func.func @multiple_use_tilable_op(%3: tensor<64x256xf32>, %4: tensor<64x256xf32>) -> (tensor<64x256xf32>, tensor<256x64xf32>) {
%add_empty = tensor.empty() : tensor<64x256xf32>
%6 = linalg.add
ins(%3, %4 : tensor<64x512xf32>, tensor<64x512xf32>)
outs(%add_empty : tensor<64x512xf32>) -> tensor<64x512xf32>
%transpose_empty = tensor.empty() : tensor<512x64xf32>
ins(%3, %4 : tensor<64x256xf32>, tensor<64x256xf32>)
outs(%add_empty : tensor<64x256xf32>) -> tensor<64x256xf32>
%transpose_empty = tensor.empty() : tensor<256x64xf32>
%7 = linalg.transpose
ins(%6 : tensor<64x512xf32>)
outs(%transpose_empty : tensor<512x64xf32>) permutation = [1, 0]
return %6, %7 : tensor<64x512xf32>, tensor<512x64xf32>
ins(%6 : tensor<64x256xf32>)
outs(%transpose_empty : tensor<256x64xf32>) permutation = [1, 0]
return %6, %7 : tensor<64x256xf32>, tensor<256x64xf32>
}

// CHECK-LABEL: func.func @multiple_use_tilable_op
// CHECK: %[[ADD_TILING:.+]] = scf.for
// CHECK: linalg.add {{.*}} -> tensor<1x256xf32>
// CHECK: linalg.add {{.*}} -> tensor<1x64xf32>
// CHECK: %[[T_TILING:.+]] = scf.for
// CHECK: %[[FUSED_ADD:.+]] = linalg.add {{.*}} -> tensor<64x1xf32>
// CHECK: linalg.transpose ins(%[[FUSED_ADD]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1013,8 +1013,9 @@ hal.executable public @main {
// CHECK: scf.yield %[[REDUCE]]

// CHECK: scf.for %{{.*}} = %{{.*}} to %c16 step %c1
// CHECK-COUNT-4: arith.addf {{.*}} : vector<9x9xf32>
// CHECK: vector.transfer_write {{.*}} vector<9x9xi8>, memref<32x16x9x9xi8, #hal.descriptor_type<storage_buffer>>
// CHECK: scf.for
// CHECK-COUNT-4: arith.addf {{.*}} : vector<9xf32>
// CHECK: vector.transfer_write {{.*}} vector<9xi8>, memref<32x16x9x9xi8, #hal.descriptor_type<storage_buffer>>

// -----

Expand Down

0 comments on commit 158c636

Please sign in to comment.