Better support multidevice placement with stream.async.barrier (#19651)

Barriers and transfers should have semantics that let partitioning
parallelize work. If a value has a barrier placed on it, the barrier should
divide partitions so that no partition spans cross-device dependencies.

Intermediate and ending transfers we want to place in the producing
partition, so that the producing partition ends by materializing the value
at the needed destination.

Incoming transfers we place in the destination partition, as these do not
add a dependency on the incoming data.
rsuderman authored Jan 13, 2025
1 parent 88d5f59 commit 40c19e3
Showing 13 changed files with 333 additions and 59 deletions.
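
As a minimal sketch of the new op at the flow level (adapted from the lit test added in this change; the function name is illustrative), a barrier pins a value to a device affinity without moving it:

util.global private @device : !hal.device

util.func public @pinToDevice(%input: tensor<?x128xi8>, %dim0: index) -> tensor<?x128xi8> {
  // The barrier marks where %input must be materialized; partitioning
  // tries to keep it in the partition that produces %input.
  %pinned = flow.tensor.barrier %input : tensor<?x128xi8>{%dim0} on #hal.device.affinity<@device>
  util.return %pinned : tensor<?x128xi8>
}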
7 changes: 7 additions & 0 deletions compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -1158,6 +1158,13 @@ void TensorCloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.insert<ReplaceOpIfTensorOperandEmpty<TensorCloneOp, 0, 0>>(context);
}

//===----------------------------------------------------------------------===//
// flow.tensor.barrier
//===----------------------------------------------------------------------===//

void TensorBarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {}

//===----------------------------------------------------------------------===//
// flow.tensor.transfer
//===----------------------------------------------------------------------===//
6 changes: 6 additions & 0 deletions compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
@@ -1836,6 +1836,12 @@ LogicalResult TensorCloneOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// flow.tensor.barrier
//===----------------------------------------------------------------------===//

LogicalResult TensorBarrierOp::verify() { return success(); }

//===----------------------------------------------------------------------===//
// flow.tensor.transfer
//===----------------------------------------------------------------------===//
46 changes: 46 additions & 0 deletions compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -1502,6 +1502,52 @@ def FLOW_TensorCloneOp : FLOW_PureOp<"tensor.clone", [
let hasFolder = 1;
}

def FLOW_TensorBarrierOp : FLOW_PureOp<"tensor.barrier", [
AllTypesMatch<["operand", "result"]>,
DeclareOpInterfaceMethods<Util_HoistableOpInterface>,
Util_ShapeAwareOp,
]> {
let summary = [{}];
let description = [{
}];

let arguments = (ins
FLOW_Tensor:$operand,
FLOW_ShapeDynamicDims:$argument_dims,
AnyAttr:$target
);
let results = (outs
FLOW_Tensor:$result
);

let assemblyFormat = [{
$operand `:` type($result) (`{` $argument_dims^ `}`)?
`on` $target
attr-dict-with-keyword
}];

let builders = [
OpBuilder<(ins "Value":$operand, "Attribute":$target),
[{
build($_builder, $_state,
operand.getType(),
operand,
IREE::Util::buildDynamicDimsForValue($_state.location, operand, $_builder),
target);
}]>,
];

let extraClassDeclaration = [{
bool isHoistableLeafOp() { return false; }

ValueRange getOperandDynamicDims(unsigned idx) { return getArgumentDims(); }
ValueRange getResultDynamicDims(unsigned idx) { return getArgumentDims(); }
}];

let hasVerifier = 1;
let hasCanonicalizer = 1;
}

def FLOW_TensorTransferOp : FLOW_PureOp<"tensor.transfer", [
AllTypesMatch<["operand", "result"]>,
DeclareOpInterfaceMethods<Util_HoistableOpInterface>,
@@ -6,10 +6,12 @@

#include "iree/compiler/Dialect/Stream/Analysis/Partitioning.h"
#include "iree/compiler/Dialect/Stream/Analysis/ResourceHazards.h"
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/PatternMatch.h"

@@ -138,6 +140,8 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,

auto asmState = getRootAsmState(block);

llvm::DenseMap<Operation *, llvm::SmallVector<Operation *>> syncOps;

for (auto &op : llvm::reverse(*block)) {
// Skip constants; they just add noise (and since they are heavily CSE'd
// they have lots of users to test).
@@ -163,6 +167,21 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
// Even though not a streamable op we still want to track it below.
}

// Synchronizing operations should join with their producers if the producer
// is streamable.
if (dyn_cast<IREE::Stream::AsyncBarrierOp>(op) ||
dyn_cast<IREE::Stream::AsyncTransferOp>(op)) {
auto producer = op.getOperand(0).getDefiningOp();
auto streamable =
dyn_cast_or_null<IREE::Stream::StreamableOpInterface>(producer);
if (streamable) {
if (!syncOps.contains(producer))
syncOps[producer] = llvm::SmallVector<Operation *>();
syncOps[producer].push_back(&op);
continue;
}
}

// Initialize op info for this op - whether streamable or not. We track
// transitive hazards on each op. Note that thanks to the ordering of ops
// in SSA form (_reversed here!_) we know that once we visit this op no
@@ -202,6 +221,21 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
opInfo.hazards |= userInfo.membership;
opInfo.hazards |= userInfo.hazards;
}

for (auto syncOp : syncOps[&op]) {
for (auto user : syncOp->getUsers()) {
auto userInfoIt = opInfos.find(user);
if (userInfoIt == opInfos.end())
continue;
auto &userInfo = userInfoIt->second;
opInfo.hazards |= userInfo.membership;
opInfo.hazards |= userInfo.hazards;
consumers.reset();
}
}

// Any sync ops not using this op's results need to be placed in a
// non-consumer partition:
llvm::BitVector candidates(builders.size(), /*t=*/true);
candidates ^= opInfo.hazards;
candidates |= consumers;
@@ -216,6 +250,16 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
}
}

for (auto syncOp : syncOps[&op]) {
for (auto ordinal : candidates.set_bits()) {
if (!canAddOpToPartition(*syncOp, opInfo, ordinal)) {
LLVM_DEBUG(llvm::dbgs()
<< "Candidate partition " << ordinal << " incompatible\n");
candidates.reset(ordinal);
}
}
}

// If this op is not streamable then bail here; we've still setup the hazard
// map for following iteration.
auto streamableOp = dyn_cast<IREE::Stream::StreamableOpInterface>(op);
@@ -227,63 +271,60 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
     // First see which partitions are consuming this that we can also safely
     // move in to.
     consumers &= candidates;
+    if (consumers.any())
+      candidates = consumers;
 
     opInfo.membership.reserve(builders.size() + 1);
     opInfo.membership.resize(builders.size(), /*t=*/false);
 
-    // If we have one or more consumers we should go into those first.
-    if (consumers.any()) {
-      // If we are a clonable op (like splat) clone us into every partition.
-      // Otherwise we just pick the first we find (probably a bad heuristic).
-      if (streamableOp.preferCloneToConsumers() && consumers.count() > 1) {
-        for (auto consumerOrdinal : consumers.set_bits()) {
-          LLVM_DEBUG(llvm::dbgs() << "Cloning into consumer partition "
-                                  << consumerOrdinal << "\n");
-          auto &consumerBuilder = builders[consumerOrdinal];
-          consumerBuilder->insert(&op, opInfo);
-          consumerBuilder->clonedOps.insert(&op);
-        }
-      } else {
-        int consumerOrdinal = consumers.find_last();
-        LLVM_DEBUG(llvm::dbgs() << "Moving into consumer partition "
-                                << consumerOrdinal << "\n");
-        auto &consumerBuilder = builders[consumerOrdinal];
-        consumerBuilder->insert(&op, opInfo);
-      }
-      LLVM_DEBUG(llvm::dbgs() << "Handled streamable (continue)\n");
-      continue;
-    }
-
     // No consumers - if there's any candidate then we'll go into that.
     int firstCandidateOrdinal = candidates.find_first();
-    if (firstCandidateOrdinal != -1) {
-      LLVM_DEBUG(llvm::dbgs() << "Moving to first candidate partition "
-                              << firstCandidateOrdinal << " (continue)\n");
-      builders[firstCandidateOrdinal]->insert(&op, opInfo);
-      continue;
-    }
-
-    // Mark the op as having hazards against all other partitions.
-    // It is better to be safe than incorrect, especially with our current
-    // minimal test coverage. It's not always safe to reorder things - if
-    // anything we are unlikely to be conservative enough here - for example,
-    // if there's a stream.resource.load of a resource or a global we can't
-    // move anything that may affect that resource or global. This partitioning
-    // was designed to be conservative because debugging such issues is really
-    // difficult.
-    if (!builders.empty()) {
-      opInfo.hazards.set(0, builders.size() - 1);
-    }
-
-    // Create a new partition just for this op.
-    opInfo.membership.resize(opInfo.membership.size() + 1, /*t=*/true);
-    auto builder = std::make_unique<PartitionBuilder>();
-    builder->ordinal = builders.size();
-    builder->insert(&op, opInfo);
-    LLVM_DEBUG(llvm::dbgs()
-               << "Created partition " << builder->ordinal << "\n");
-    builders.push_back(std::move(builder));
-    usableBuilders.resize(builders.size(), /*t=*/true);
+    if (firstCandidateOrdinal == -1) {
+      // Mark the op as having hazards against all other partitions.
+      // It is better to be safe than incorrect, especially with our current
+      // minimal test coverage. It's not always safe to reorder things - if
+      // anything we are unlikely to be conservative enough here - for example,
+      // if there's a stream.resource.load of a resource or a global we can't
+      // move anything that may affect that resource or global. This
+      // partitioning was designed to be conservative because debugging such
+      // issues is really difficult.
+      if (!builders.empty()) {
+        opInfo.hazards.set(0, builders.size() - 1);
+      }
+
+      // Create a new partition just for this op.
+      opInfo.membership.resize(opInfo.membership.size() + 1, /*t=*/true);
+      auto builder = std::make_unique<PartitionBuilder>();
+      builder->ordinal = builders.size();
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Created partition " << builder->ordinal << "\n");
+      builders.push_back(std::move(builder));
+      usableBuilders.resize(builders.size(), /*t=*/true);
+      firstCandidateOrdinal = builders.size() - 1;
+    }
+
+    auto &builder = builders[firstCandidateOrdinal];
+
+    // If we have synchronization operations we can place in the last block:
+    for (auto syncOp : syncOps[&op]) {
+      builder->insert(syncOp, opInfo);
+    }
+
+    LLVM_DEBUG(llvm::dbgs() << "Moving to first candidate partition "
+                            << firstCandidateOrdinal << " (continue)\n");
+
+    // If we are a clonable op (like splat) clone us into every partition.
+    // Otherwise we just pick the first we find (probably a bad heuristic).
+    if (consumers.count() > 1 && streamableOp.preferCloneToConsumers()) {
+      for (auto consumerOrdinal : consumers.set_bits()) {
+        LLVM_DEBUG(llvm::dbgs() << "Cloning into consumer partition "
+                                << consumerOrdinal << "\n");
+        auto &consumerBuilder = builders[consumerOrdinal];
+        consumerBuilder->insert(&op, opInfo);
+        consumerBuilder->clonedOps.insert(&op);
+      }
+    } else {
+      builder->insert(&op, opInfo);
+    }
   }
 
   // Ops cloned into multiple partitions may still escape if there are
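To illustrate the placement rule implemented above, a hedged stream-level sketch (constants and sizes are illustrative; the stream.async.barrier syntax mirrors the conversion test below): a barrier joins the partition of its streamable producer rather than forcing a partition break, so the producing partition ends by materializing the synchronized value:

%c255_i32 = arith.constant 255 : i32
%size = arith.constant 512 : index
// The splat producer is streamable, so the barrier is grouped with it and
// both are expected to land in the same partition.
%fill = stream.async.splat %c255_i32 : i32 -> !stream.resource<*>{%size}
%sync = stream.async.barrier %fill : !stream.resource<*>{%size} -> !stream.resource<*>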
@@ -237,6 +237,24 @@ struct ConvertTensorCloneOp
}
};

struct ConvertTensorBarrierOp
: public AffinityOpConversionPattern<IREE::Flow::TensorBarrierOp> {
using AffinityOpConversionPattern::AffinityOpConversionPattern;
LogicalResult matchAndRewriteOnAffinity(
IREE::Flow::TensorBarrierOp op, OneToNOpAdaptor adaptor,
IREE::Stream::AffinityAttr executionAffinityAttr,
ConversionPatternRewriter &rewriter) const override {
auto operand = resolveTensorOperands(op.getLoc(), op.getOperand(),
adaptor.getOperand(), rewriter);
auto barrierOp = rewriter.create<IREE::Stream::AsyncBarrierOp>(
op.getLoc(), operand.resource.getType(), operand.resource,
operand.resourceSize,
/*affinity=*/operand.affinity);
rewriter.replaceOpWithMultiple(op, {{barrierOp, operand.resourceSize}});
return success();
}
};

struct ConvertTensorTransferOp
: public AffinityOpConversionPattern<IREE::Flow::TensorTransferOp> {
using AffinityOpConversionPattern::AffinityOpConversionPattern;
@@ -1162,15 +1180,15 @@ void populateFlowToStreamConversionPatterns(
     MLIRContext *context, TypeConverter &typeConverter,
     IREE::Stream::AffinityAnalysis *affinityAnalysis,
     RewritePatternSet &patterns) {
-  patterns
-      .insert<ConvertTensorConstantOp, ConvertTensorDynamicConstantOp,
-              ConvertTensorCastLikeOp<IREE::Flow::TensorReshapeOp>,
-              ConvertTensorCastLikeOp<IREE::Flow::TensorBitCastOp>,
-              ConvertTensorAllocaOp, ConvertTensorEmptyOp, ConvertTensorSplatOp,
-              ConvertTensorCloneOp, ConvertTensorTransferOp,
-              ConvertTensorSliceOp, ConvertTensorUpdateOp, ConvertTensorLoadOp,
-              ConvertTensorStoreOp, ConvertTensorTraceOp>(
-          typeConverter, context, affinityAnalysis);
+  patterns.insert<
+      ConvertTensorConstantOp, ConvertTensorDynamicConstantOp,
+      ConvertTensorCastLikeOp<IREE::Flow::TensorReshapeOp>,
+      ConvertTensorCastLikeOp<IREE::Flow::TensorBitCastOp>,
+      ConvertTensorAllocaOp, ConvertTensorEmptyOp, ConvertTensorSplatOp,
+      ConvertTensorCloneOp, ConvertTensorBarrierOp, ConvertTensorTransferOp,
+      ConvertTensorSliceOp, ConvertTensorUpdateOp, ConvertTensorLoadOp,
+      ConvertTensorStoreOp, ConvertTensorTraceOp>(typeConverter, context,
+                                                  affinityAnalysis);
   patterns.insert<ConvertChannelDefaultOp>(typeConverter, context,
                                            affinityAnalysis);
   patterns.insert<ConvertChannelSplitOp, ConvertChannelRankOp,
@@ -149,6 +149,19 @@ util.func public @tensorTransfer(%input: tensor<?x128xi8>, %dim0: index) -> tens

// -----

util.global private @device : !hal.device

// CHECK-LABEL: @tensorBarrier
// CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[INPUT_SIZE:.+]]: index, %[[DIM0:.+]]: index)
util.func public @tensorBarrier(%input: tensor<?x128xi8>, %dim0: index) -> tensor<?x128xi8> {
// CHECK: %[[TRANSFER:.+]] = stream.async.barrier %[[INPUT]] : !stream.resource<*>{%[[INPUT_SIZE]]} -> !stream.resource<*>
%transfer = flow.tensor.barrier %input : tensor<?x128xi8>{%dim0} on #hal.device.affinity<@device>
// CHECK: util.return %[[TRANSFER]], %[[INPUT_SIZE]]
util.return %transfer : tensor<?x128xi8>
}

// -----

// CHECK-LABEL: @tensorSlice
// CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[INPUT_SIZE:.+]]: index)
util.func public @tensorSlice(%input : tensor<5x24x48xf32>) -> tensor<3x24x48xf32> {
@@ -1964,6 +1964,13 @@ void AsyncCollectiveOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.insert<ElideUnusedOp<AsyncCollectiveOp>>(context);
}

//===----------------------------------------------------------------------===//
// stream.async.barrier
//===----------------------------------------------------------------------===//

void AsyncBarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {}

//===----------------------------------------------------------------------===//
// stream.async.transfer
//===----------------------------------------------------------------------===//
16 changes: 13 additions & 3 deletions compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
@@ -2006,6 +2014,14 @@ void AsyncCollectiveOp::getAsyncAccessRanges(
getTargetOffset(), getTargetEnd(), getTargetLength()});
}

//===----------------------------------------------------------------------===//
// stream.async.barrier
//===----------------------------------------------------------------------===//

bool AsyncBarrierOp::isMetadata() { return true; }

LogicalResult AsyncBarrierOp::verify() { return success(); }

//===----------------------------------------------------------------------===//
// stream.async.transfer
//===----------------------------------------------------------------------===//
@@ -2026,15 +2034,17 @@ IREE::Stream::AffinityAttr AsyncTransferOp::getAffinityAttr() {
       resultType.getLifetime() == IREE::Stream::Lifetime::Staging) {
     // TODO(multi-device): figure out how to model staging->staging transfers.
     return getSourceAffinityAttr();
-  } else if (sourceType.getLifetime() == IREE::Stream::Lifetime::Staging) {
+  } else if (sourceType.getLifetime() == IREE::Stream::Lifetime::External ||
+             sourceType.getLifetime() == IREE::Stream::Lifetime::Staging) {
     // If source is staging then the op should execute on the consumer.
     return getResultAffinityAttr();
-  } else if (resultType.getLifetime() == IREE::Stream::Lifetime::Staging) {
+  } else if (resultType.getLifetime() == IREE::Stream::Lifetime::External ||
+             resultType.getLifetime() == IREE::Stream::Lifetime::Staging) {
     // If result is staging then the op should execute on the producer.
     return getSourceAffinityAttr();
   } else {
-    // Default to result affinity.
-    return getResultAffinityAttr();
+    // Default to source affinity.
+    return getSourceAffinityAttr();
   }
 }
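A hedged sketch of the affinity change for transfers (device globals, sizes, and the abbreviated stream.async.transfer assembly are illustrative): a transfer from an external source now resolves to the consumer's affinity, matching the existing staging behavior, while a transfer producing an external result resolves to the producer's:

// Source is external: under the updated resolution this transfer is
// expected to execute on the consumer side (@device_b).
%t = stream.async.transfer %src : !stream.resource<external>{%size}
    from(#hal.device.affinity<@device_a>) -> to(#hal.device.affinity<@device_b>)
    !stream.resource<*>{%size}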