Skip to content

Commit

Permalink
Block dynamic dimensions of contraction-like operations. (#19056)
Browse files Browse the repository at this point in the history
Block the dynamic dimensions of contraction-like operations when the
dynamic dimensions are known to be multiples of a static value.

This also fixes a couple of bugs
1) Fix a bug in the integer divisibility propagation function for
   `arith.muli`
2) Add divisiblity propagation for `arith.divui`
3) Add better tests in `integer_divisibility.mlir` to check the computed
divisibility.

---------

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
  • Loading branch information
MaheshRavishankar authored Nov 13, 2024
1 parent f3c1467 commit d32ce2f
Show file tree
Hide file tree
Showing 5 changed files with 304 additions and 74 deletions.
136 changes: 96 additions & 40 deletions compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,11 @@ getTensorDivisibilityInfo(const TensorDynamicDimAnalysis &dynamicDimAnalysis,
/// inverses of each other. The `util.optimization.barrier` avoid these from
/// getting folded away during reshape propagation. Return the result of the
/// `tensor.collapse_shape generated.
static std::optional<Value>
struct ReshapeOps {
tensor::ExpandShapeOp expandShapeOp;
tensor::CollapseShapeOp collapseShapeOp;
};
static std::optional<ReshapeOps>
blockDynamicDimensionsOfValue(RewriterBase &rewriter,
const TensorDivisibilityInfo &divisibilityInfo,
Value v) {
Expand Down Expand Up @@ -154,18 +158,23 @@ blockDynamicDimensionsOfValue(RewriterBase &rewriter,
auto outputType = RankedTensorType::get(
staticOutputShape, tensorType.getElementType(), tensorType.getEncoding());

Value expandShape = rewriter.create<tensor::ExpandShapeOp>(
auto expandShapeOp = rewriter.create<tensor::ExpandShapeOp>(
loc, outputType, v, reassociation, outputShape);
Value barrier =
rewriter.create<IREE::Util::OptimizationBarrierOp>(loc, expandShape)
.getResult(0);
Value collapseShape = rewriter.create<tensor::CollapseShapeOp>(
Value barrier = rewriter
.create<IREE::Util::OptimizationBarrierOp>(
loc, expandShapeOp.getResult())
.getResult(0);
auto collapseShapeOp = rewriter.create<tensor::CollapseShapeOp>(
loc, tensorType, barrier, reassociation);
return collapseShape;
return ReshapeOps{expandShapeOp, collapseShapeOp};
}

//===---------------------------------------------------------------------===//
// Methods for blocking operands of operations
//===---------------------------------------------------------------------===//

/// For an operation, replace the operands at indices specified in
/// `limitToOperandIndices` with the result of
/// `limitToOperandNumbers` with the result of
/// `tensor.expand_shape`/`tensor.collapse_shape` pair to materialize the
/// information about dynamic dimensions that are known to be a multiple of a
/// compile-time static value. For example,
Expand All @@ -186,68 +195,104 @@ blockDynamicDimensionsOfValue(RewriterBase &rewriter,
/// ```
static LogicalResult blockDynamicDimensions(
RewriterBase &rewriter, const TensorDynamicDimAnalysis &dynamicDimAnalysis,
Operation *operation, llvm::SmallDenseSet<int64_t> limitToOperandIndices) {
OpBuilder::InsertionGuard g(rewriter);

Operation *operation, llvm::SmallDenseSet<int64_t> limitToOperandNumbers,
llvm::SmallDenseSet<int64_t> limitToResultNumbers) {
for (OpOperand &operand : operation->getOpOperands()) {
if (!limitToOperandIndices.contains(operand.getOperandNumber()))
if (!limitToOperandNumbers.contains(operand.getOperandNumber()))
continue;
if (operand.get().getDefiningOp<tensor::CollapseShapeOp>())
continue;
TensorDivisibilityInfo operandDivisibilityInfo =
getTensorDivisibilityInfo(dynamicDimAnalysis, operand.get());
if (operandDivisibilityInfo.empty())
continue;
std::optional<Value> newOperand = blockDynamicDimensionsOfValue(
std::optional<ReshapeOps> reshapes = blockDynamicDimensionsOfValue(
rewriter, operandDivisibilityInfo, operand.get());
if (newOperand) {
rewriter.modifyOpInPlace(operation,
[&]() { operand.set(newOperand.value()); });
if (reshapes) {
rewriter.modifyOpInPlace(
operation, [&]() { operand.set(reshapes->collapseShapeOp); });
}
}

OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(operation);
for (OpResult result : operation->getResults()) {
if (!limitToResultNumbers.contains(result.getResultNumber()))
continue;
TensorDivisibilityInfo resultDivisibilityInfo =
getTensorDivisibilityInfo(dynamicDimAnalysis, result);
if (resultDivisibilityInfo.empty())
continue;
std::optional<ReshapeOps> reshapes =
blockDynamicDimensionsOfValue(rewriter, resultDivisibilityInfo, result);
if (reshapes) {
llvm::SmallPtrSet<Operation *, 1> ignoreUses;
ignoreUses.insert(reshapes->expandShapeOp);
rewriter.replaceAllUsesExcept(
result, reshapes->collapseShapeOp.getResult(), ignoreUses);
}
}
return success();
}

/// Insert `tensor.expand_shape` operations to materialize in IR information
/// about dynamic dimensions that are known to be a multiple of a compile-time
/// know value, for the operands of `iree_linalg_ext.attention` operation.
/// Generic method for blocking all operands of an operation.
static LogicalResult blockDynamicDimensionsOfAllTensorOperandsAndResults(
RewriterBase &rewriter, const TensorDynamicDimAnalysis &dynamicDimAnalysis,
Operation *op) {
llvm::SmallDenseSet<int64_t> tensorOperandsList, tensorResultsList;
for (OpOperand &opOperand : op->getOpOperands()) {
if (isa<RankedTensorType>(opOperand.get().getType())) {
tensorOperandsList.insert(opOperand.getOperandNumber());
}
}
for (OpResult result : op->getResults()) {
if (isa<RankedTensorType>(result.getType())) {
tensorResultsList.insert(result.getResultNumber());
}
}
return blockDynamicDimensions(rewriter, dynamicDimAnalysis, op,
tensorOperandsList, tensorResultsList);
}

/// Block dynamic dimensions in operands of `LinalgOp`.
static LogicalResult
blockDynamicDimensions(RewriterBase &rewriter,
const TensorDynamicDimAnalysis &dynamicDimAnalysis,
linalg::LinalgOp linalgOp) {
if (linalg::isaContractionOpInterface(linalgOp)) {
return blockDynamicDimensionsOfAllTensorOperandsAndResults(
rewriter, dynamicDimAnalysis, linalgOp);
}
return success();
}

/// Block dynamic dimensions in operands of `AttentionOp`.
static LogicalResult
blockDynamicDimensions(RewriterBase &rewriter,
const TensorDynamicDimAnalysis &dynamicDimAnalysis,
IREE::LinalgExt::AttentionOp attentionOp) {
// Only block the q and k values.
llvm::SmallDenseSet<int64_t> prunedOperandsList;
llvm::SmallDenseSet<int64_t> prunedOperandsList, prunedResultsList;
prunedOperandsList.insert(attentionOp.getQueryMutable().getOperandNumber());
prunedOperandsList.insert(attentionOp.getKeyMutable().getOperandNumber());
return blockDynamicDimensions(rewriter, dynamicDimAnalysis, attentionOp,
prunedOperandsList);
prunedOperandsList, prunedResultsList);
}

/// Generic method to block dynamic dimensions for all tensor operands.
/// Only used for testing for now
/// Dispatch to methods that block dynamic dimensions of operations.
static LogicalResult
blockDynamicDimensions(RewriterBase &rewriter,
const TensorDynamicDimAnalysis &dynamicDimAnalysis,
Operation *operation, bool test) {
Operation *operation) {
return TypeSwitch<Operation *, LogicalResult>(operation)
.Case<IREE::LinalgExt::AttentionOp>([&](auto attentionOp) {
return blockDynamicDimensions(rewriter, dynamicDimAnalysis,
attentionOp);
})
.Default([&](Operation *op) {
if (!test) {
return success();
}
// The default path here is for now only for testing.
llvm::SmallDenseSet<int64_t> tensorOperandsList;
for (OpOperand &opOperand : operation->getOpOperands()) {
if (isa<RankedTensorType>(opOperand.get().getType())) {
tensorOperandsList.insert(opOperand.getOperandNumber());
}
}
return blockDynamicDimensions(rewriter, dynamicDimAnalysis, operation,
tensorOperandsList);
});
.Case<linalg::LinalgOp>([&](auto linalgOp) {
return blockDynamicDimensions(rewriter, dynamicDimAnalysis, linalgOp);
})
.Default([&](Operation *op) { return success(); });
}

void BlockDynamicDimensionsPass::runOnOperation() {
Expand All @@ -261,7 +306,7 @@ void BlockDynamicDimensionsPass::runOnOperation() {
IRRewriter rewriter(context);
auto walkResult = operation->walk([&](Operation *op) -> WalkResult {
rewriter.setInsertionPoint(op);
return blockDynamicDimensions(rewriter, dynamicDimAnalysis, op, test);
return blockDynamicDimensions(rewriter, dynamicDimAnalysis, op);
});
if (walkResult.wasInterrupted()) {
return signalPassFailure();
Expand All @@ -278,7 +323,11 @@ void BlockDynamicDimensionsPass::runOnOperation() {
// Add patterns to "push down" the `tensor.collapse_shape` patterns (which
// are the dual of the patterns to "bubble up" `tensor.expand_shape`
// patterns)
linalg::ControlFusionFn controlFn = [](OpOperand *) { return true; };
linalg::ControlFusionFn controlFn = [](OpOperand *opOperand) {
// Avoid fusion with fills/empty using the propagation patterns.
return !isa_and_nonnull<linalg::FillOp, tensor::EmptyOp>(
opOperand->get().getDefiningOp());
};
linalg::populateFoldReshapeOpsByExpansionPatterns(bubbleExpandShapePatterns,
controlFn);
IREE::LinalgExt::populateFoldReshapeOpsByExpansionPatterns(
Expand All @@ -288,6 +337,8 @@ void BlockDynamicDimensionsPass::runOnOperation() {
// bindings or `tensor.empty` operations.
populateReshapeToInterfaceTensorPatterns(bubbleExpandShapePatterns);
tensor::populateFoldTensorEmptyPatterns(bubbleExpandShapePatterns);
linalg::FillOp::getCanonicalizationPatterns(bubbleExpandShapePatterns,
context);
// Add some additional patterns that can simplify the IR and remove dead
// operations.
memref::populateResolveRankedShapedTypeResultDimsPatterns(
Expand Down Expand Up @@ -315,6 +366,11 @@ void BlockDynamicDimensionsPass::runOnOperation() {
context);
tensor::CollapseShapeOp::getCanonicalizationPatterns(
removeBarrierOpsPatterns, context);
tensor::populateFoldTensorEmptyPatterns(removeBarrierOpsPatterns);
linalg::FillOp::getCanonicalizationPatterns(removeBarrierOpsPatterns,
context);
memref::populateResolveRankedShapedTypeResultDimsPatterns(
removeBarrierOpsPatterns);
if (failed(applyPatternsAndFoldGreedily(
operation, std::move(removeBarrierOpsPatterns)))) {
operation->emitOpError("failed in cleanup patterns");
Expand Down
3 changes: 0 additions & 3 deletions compiler/src/iree/compiler/Codegen/Common/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@ def BlockDynamicDimensionsPass
: Pass<"iree-codegen-block-dynamic-dimensions"> {
let summary = "Expand dynamic dimensions that are known to be multiples of "
"statically known values.";
let options = [
Option<"test", "test", "bool", /*default=*/"false", "Enable test mode">
];
}

def BubbleUpOrdinalOpsPass : Pass<"iree-codegen-bubble-up-ordinal-ops", ""> {
Expand Down
128 changes: 112 additions & 16 deletions compiler/src/iree/compiler/Codegen/Common/test/block_dynamic_dims.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-block-dynamic-dimensions{test}, cse))" --split-input-file --mlir-print-local-scope %s | FileCheck %s
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-block-dynamic-dimensions, cse))" --split-input-file --mlir-print-local-scope %s | FileCheck %s

#pipeline_layout = #hal.pipeline.layout<constants = 4, bindings = [
#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">,
Expand Down Expand Up @@ -102,33 +102,129 @@ func.func @block_attention_dims() {

// -----

func.func @basic_blocking_test(%arg0 : index) -> tensor<?xf32> {
func.func @basic_blocking_test(%arg0 : index) -> tensor<?x4096xf32> {
%0 = util.assume.int %arg0<umin = 0, umax = 1024, udiv = 16> : index
%1 = tensor.empty(%0) : tensor<?xf32>
return %1 : tensor<?xf32>
%lhs = tensor.empty(%0) : tensor<?x2048xf32>
%rhs = tensor.empty() : tensor<2048x4096xf32>
%init = tensor.empty(%0) : tensor<?x4096xf32>
%matmul = linalg.matmul ins(%lhs, %rhs : tensor<?x2048xf32>, tensor<2048x4096xf32>)
outs(%init : tensor<?x4096xf32>) -> tensor<?x4096xf32>
return %matmul : tensor<?x4096xf32>
}
// CHECK-LABEL: func @basic_blocking_test(
// CHECK: %[[EMPTY:.+]] = tensor.empty(%{{.+}}) : tensor<?x16xf32>
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EMPTY]]
// CHECK-DAG: %[[LHS:.+]] = tensor.empty(%{{.+}}) : tensor<?x16x2048xf32>
// CHECK-DAG: %[[INIT:.+]] = tensor.empty(%{{.+}}) : tensor<?x16x4096xf32>
// CHECK: %[[MATMUL:.+]] = linalg.generic
// CHECK-SAME: ins(%[[LHS]],
// CHECK-SAME: outs(%[[INIT]] :
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[MATMUL]]
// CHECK: return %[[COLLAPSE]]

// -----

func.func @no_blocking(%arg0 : index) -> tensor<?xf32> {
%1 = tensor.empty(%arg0) : tensor<?xf32>
return %1 : tensor<?xf32>
func.func @no_blocking(%arg0 : index) -> tensor<?x4096xf32> {
%lhs = tensor.empty(%arg0) : tensor<?x2048xf32>
%rhs = tensor.empty() : tensor<2048x4096xf32>
%init = tensor.empty(%arg0) : tensor<?x4096xf32>
%matmul = linalg.matmul ins(%lhs, %rhs : tensor<?x2048xf32>, tensor<2048x4096xf32>)
outs(%init : tensor<?x4096xf32>) -> tensor<?x4096xf32>
return %matmul : tensor<?x4096xf32>
}
// CHECK-LABEL: func @no_blocking(
// CHECK: %[[EMPTY:.+]] = tensor.empty(%{{.+}}) : tensor<?xf32>
// CHECK: return %[[EMPTY]]
// CHECK-DAG: %[[LHS:.+]] = tensor.empty(%{{.+}}) : tensor<?x2048xf32>
// CHECK-DAG: %[[INIT:.+]] = tensor.empty(%{{.+}}) : tensor<?x4096xf32>
// CHECK: %[[MATMUL:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[LHS]],
// CHECK-SAME: outs(%[[INIT]] :
// CHECK: return %[[MATMUL]]

// -----

func.func @no_unit_blocking(%arg0 : index) -> tensor<?xf32> {
func.func @no_unit_blocking(%arg0 : index) -> tensor<?x4096xf32> {
%0 = util.assume.int %arg0<umin = 0, umax = 1024, udiv = 1> : index
%1 = tensor.empty(%0) : tensor<?xf32>
return %1 : tensor<?xf32>
%lhs = tensor.empty(%0) : tensor<?x2048xf32>
%rhs = tensor.empty() : tensor<2048x4096xf32>
%init = tensor.empty(%0) : tensor<?x4096xf32>
%matmul = linalg.matmul ins(%lhs, %rhs : tensor<?x2048xf32>, tensor<2048x4096xf32>)
outs(%init : tensor<?x4096xf32>) -> tensor<?x4096xf32>
return %matmul : tensor<?x4096xf32>
}
// CHECK-LABEL: func @no_unit_blocking(
// CHECK: %[[EMPTY:.+]] = tensor.empty(%{{.+}}) : tensor<?xf32>
// CHECK: return %[[EMPTY]]
// CHECK-DAG: %[[LHS:.+]] = tensor.empty(%{{.+}}) : tensor<?x2048xf32>
// CHECK-DAG: %[[INIT:.+]] = tensor.empty(%{{.+}}) : tensor<?x4096xf32>
// CHECK: %[[MATMUL:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[LHS]],
// CHECK-SAME: outs(%[[INIT]] :
// CHECK: return %[[MATMUL]]


// -----

func.func @contract_op_interface_op(%rhs : tensor<2048x4096xf16>, %m : index)
-> tensor<?x2048xf32> {
%0 = util.assume.int %m<udiv = 16> : index
%lhs = tensor.empty(%0) : tensor<?x4096xf16>
%init = tensor.empty(%0) : tensor<?x2048xf32>
%1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
affine_map<(d0, d1, d2) -> (d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"]}
ins(%lhs, %rhs : tensor<?x4096xf16>, tensor<2048x4096xf16>)
outs(%init : tensor<?x2048xf32>) {
^bb0(%in: f16, %in_0: f16, %out: f32):
%17 = arith.extf %in : f16 to f32
%18 = arith.extf %in_0 : f16 to f32
%19 = arith.mulf %17, %18 : f32
%20 = arith.addf %out, %19 : f32
linalg.yield %20 : f32
} -> tensor<?x2048xf32>
return %1 : tensor<?x2048xf32>
}
// CHECK-LABEL: func @contract_op_interface_op(
// CHECK-DAG: %[[LHS:.+]] = tensor.empty(%{{.+}}) : tensor<?x16x4096xf16>
// CHECK-DAG: %[[INIT:.+]] = tensor.empty(%{{.+}}) : tensor<?x16x2048xf32>
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[LHS]],
// CHECK-SAME: outs(%[[INIT]] :
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1], [2]{{\]}}
// CHECK: return %[[COLLAPSED]]

// -----

func.func @reshape_propagation_test(%rhs : tensor<2048x4096xf16>, %m : index)
-> tensor<?x2048xf16> {
%cst = arith.constant 0.0 : f32
%0 = util.assume.int %m<udiv = 16> : index
%lhs = tensor.empty(%0) : tensor<?x4096xf16>
%init = tensor.empty(%0) : tensor<?x2048xf32>
%init2 = tensor.empty(%0) : tensor<?x2048xf16>
%fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x2048xf32>) -> tensor<?x2048xf32>
%1 = linalg.matmul_transpose_b
ins(%lhs, %rhs : tensor<?x4096xf16>, tensor<2048x4096xf16>)
outs(%fill : tensor<?x2048xf32>) -> tensor<?x2048xf32>
%2 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%1 : tensor<?x2048xf32>) outs(%init2 : tensor<?x2048xf16>) {
^bb0(%b0 : f32, %b1 : f16):
%3 = arith.truncf %b0 : f32 to f16
linalg.yield %3 : f16
} -> tensor<?x2048xf16>
return %2 : tensor<?x2048xf16>
}
// CHECK-LABEL: func @reshape_propagation_test(
// CHECK-DAG: %[[LHS:.+]] = tensor.empty(%{{.+}}) : tensor<?x16x4096xf16>
// CHECK-DAG: %[[INIT:.+]] = tensor.empty(%{{.+}}) : tensor<?x16x2048xf32>
// CHECK: %[[FILL:.+]] = linalg.fill
// CHECK-SAME: outs(%[[INIT]] :
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[LHS]],
// CHECK-SAME: outs(%[[FILL]] :
// CHECK: %[[EMPTY:.+]] = tensor.empty(%{{.+}}) : tensor<?x16x2048xf16>
// CHECK: %[[TRUNC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[GENERIC]] :
// CHECK-SAME: outs(%[[EMPTY]] :
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[TRUNC]]
// CHECK: return %[[COLLAPSED]]
Loading

0 comments on commit d32ce2f

Please sign in to comment.