Skip to content

Commit

Permalink
Increase default threshold of TileLargeTensor pass to 128
Browse files Browse the repository at this point in the history
  • Loading branch information
nirvedhmeshram committed Jan 14, 2025
1 parent 6e52bd6 commit acdbd18
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 8 deletions.
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Codegen/Common/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ def TileLargeTensorsPass :
];
let options = [
Option<"maxVectorSize", "max-vector-size", "int64_t",
/*default=*/"64",
/*default=*/"128",
"Maximum static size to tile to (i.e. all remaining ops will be smaller)">,
];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ func.func @simple_generic(%3: tensor<64x256xf32>, %4: tensor<64x256xf32>, %5: te

// CHECK-LABEL: func.func @simple_generic
// CHECK: scf.for %{{.*}} = %c0 to %c64 step %c1
// CHECK: scf.for %{{.*}} = %c0 to %c256 step %c64
// CHECK: linalg.generic {{.*}} outs({{.*}}: tensor<1x64xf32>)
// CHECK: scf.for %{{.*}} = %c0 to %c256 step %c128
// CHECK: linalg.generic {{.*}} outs({{.*}}: tensor<1x128xf32>)

// -----

Expand Down Expand Up @@ -79,7 +79,7 @@ func.func @multiple_use_tilable_op(%3: tensor<64x256xf32>, %4: tensor<64x256xf32

// CHECK-LABEL: func.func @multiple_use_tilable_op
// CHECK: %[[ADD_TILING:.+]] = scf.for
// CHECK: linalg.add {{.*}} -> tensor<1x64xf32>
// CHECK: linalg.add {{.*}} -> tensor<1x128xf32>
// CHECK: %[[T_TILING:.+]] = scf.for
// CHECK: %[[FUSED_ADD:.+]] = linalg.add {{.*}} -> tensor<64x1xf32>
// CHECK: linalg.transpose ins(%[[FUSED_ADD]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1013,10 +1013,8 @@ hal.executable public @main {
// CHECK: scf.yield %[[REDUCE]]

// CHECK: scf.for %{{.*}} = %{{.*}} to %c16 step %c1
// CHECK: scf.for
// CHECK-COUNT-4: arith.addf {{.*}} : vector<9xf32>
// CHECK: vector.transfer_write {{.*}} vector<9xi8>, memref<32x16x9x9xi8, #hal.descriptor_type<storage_buffer>>

// CHECK-COUNT-4: arith.addf {{.*}} : vector<9x9xf32>
// CHECK: vector.transfer_write {{.*}} vector<9x9xi8>, memref<32x16x9x9xi8, #hal.descriptor_type<storage_buffer>>
// -----

#pipeline_layout = #hal.pipeline.layout<bindings = [
Expand Down

0 comments on commit acdbd18

Please sign in to comment.