Skip to content

Commit

Permalink
[FLOW] Move InitializeEmptyTensors before CaptureDynamicDims (#19563)
Browse files Browse the repository at this point in the history
Some patterns take in a `tensor.empty` op as the `iter_arg` for
`scf.for` ops, but these ops don't get handled when capturing dynamic
dims. Moving the conversion of `tensor.empty` to `flow.tensor.empty`
before `CaptureDynamicDims` allows these dynamic dims to properly
propagate through the loop.

---------

Signed-off-by: zjgarvey <zjgarvey@gmail.com>
  • Loading branch information
zjgarvey authored Jan 23, 2025
1 parent 2cd88d5 commit 710c22a
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 4 deletions.
8 changes: 4 additions & 4 deletions compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,13 @@ void buildFlowTransformPassPipeline(OpPassManager &passManager,
passManager.addPass(IREE::Flow::createVerifyInputLegalityPass());

FunctionLikeNest(passManager)
.addPass(IREE::Flow::createCaptureDynamicDimsPass)
.addPass(IREE::Flow::createCanonicalizerPass)
.addPass(mlir::createCSEPass)
.addPass([&]() {
return IREE::Flow::createInitializeEmptyTensorsPass(
InitializeEmptyTensorsPassOptions{clZeroFillEmptyTensors});
});
})
.addPass(IREE::Flow::createCaptureDynamicDimsPass)
.addPass(IREE::Flow::createCanonicalizerPass)
.addPass(mlir::createCSEPass);

// Module pass to outline dispatch regions (and similar ops) into their own
// functions wrapped in executables.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ iree_lit_test_suite(
"outline_constants.mlir",
"outline_dispatch_externs.mlir",
"outline_dispatch_regions.mlir",
"pipeline_tests.mlir",
"top_level_scf_to_cfg.mlir",
"verify_input_ir.mlir",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ iree_lit_test_suite(
"outline_constants.mlir"
"outline_dispatch_externs.mlir"
"outline_dispatch_regions.mlir"
"pipeline_tests.mlir"
"top_level_scf_to_cfg.mlir"
"verify_input_ir.mlir"
TOOLS
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// RUN: iree-opt --iree-flow-transformation-pipeline --split-input-file --mlir-print-local-scope %s | FileCheck %s

util.func public @scf_for_with_empty_tensor$dynamic_dim_resolution(
%arg0: !hal.buffer_view) -> !hal.buffer_view {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%input_dim = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
%input = hal.tensor.import %arg0 "input0" : !hal.buffer_view -> tensor<?xi64>{%input_dim}
%empty = tensor.empty(%input_dim) : tensor<?xi64>
%loop = scf.for %arg1 = %c0 to %input_dim step %c1 iter_args(%arg2 = %empty) -> (tensor<?xi64>) : index {
%extracted = flow.tensor.load %input[%arg1] : tensor<?xi64>{%input_dim}
%dim_0 = tensor.dim %arg2, %c0 : tensor<?xi64>
%6 = flow.tensor.store %extracted, %arg2[%arg1] : tensor<?xi64>{%dim_0}
scf.yield %6 : tensor<?xi64>
}
%output_dim = tensor.dim %loop, %c0 : tensor<?xi64>
%output = hal.tensor.export %loop "output0" : tensor<?xi64>{%output_dim} -> !hal.buffer_view
util.return %output : !hal.buffer_view
}

// CHECK-LABEL: util.func public @scf_for_with_empty_tensor$dynamic_dim_resolution(
// CHECK-SAME: %[[IN_BUFFER:.*]]: !hal.buffer_view
// CHECK: %[[CST0:.*]] = arith.constant 0 : index
// CHECK: %[[IN_DIM:.*]] = hal.buffer_view.dim<%[[IN_BUFFER]] : !hal.buffer_view>[0] : index
// CHECK: %[[IMPORT:.*]] = hal.tensor.import %[[IN_BUFFER]]
// CHECK-SAME: !hal.buffer_view -> tensor<?xi64>{%[[IN_DIM]]}
// CHECK: %[[EMPTY:.*]] = flow.tensor.empty : tensor<?xi64>{%[[IN_DIM]]}
// CHECK: %[[LOOP:.*]] = scf.for %[[INDEX:.*]] = %[[CST0]]
// CHECK-SAME: iter_args(%[[ITER_TENSOR:.*]] = %[[EMPTY]]) -> (tensor<?xi64>)
// CHECK: %[[LOAD:.*]] = flow.tensor.load %[[IMPORT]][%[[INDEX]]] : tensor<?xi64>{%[[IN_DIM]]}
// CHECK: %[[STORE:.*]] = flow.tensor.store %[[LOAD]], %[[ITER_TENSOR]][%[[INDEX]]] : tensor<?xi64>{%[[IN_DIM]]}
// CHECK: %[[EXPORT:.*]] = hal.tensor.export %[[LOOP]]
// CHECK-SAME: tensor<?xi64>{%[[IN_DIM]]} -> !hal.buffer_view
// CHECK: util.return %[[EXPORT]] : !hal.buffer_view

0 comments on commit 710c22a

Please sign in to comment.