Commit

Merge branch 'main' into wjy/norm
wujingyue authored Feb 10, 2025
2 parents 5611117 + 0510726 commit 5723c20
Showing 14 changed files with 155 additions and 326 deletions.
57 changes: 55 additions & 2 deletions csrc/device_lower/pass/insert_syncs.cpp
@@ -411,6 +411,11 @@ class ReadAfterWriteSyncs : public kir::ExprMutator {
auto fence_async = IrBuilder::create<kir::FenceAsyncProxy>();
registerInsertBefore(expr, fence_async, scope);
}

// An mma operation is added to the async mma pipeline.
fill_async_mma_pipeline_ = true;
// The async mma pipeline has not been flushed yet.
flush_async_mma_pipeline_ = false;
}
} else if (ir_utils::isCpAsyncBulkStore(expr)) {
// Add a fence before TMA store so that writes in the generic proxy is
@@ -420,7 +425,7 @@ class ReadAfterWriteSyncs : public kir::ExprMutator {
registerInsertBefore(expr, fence_async, scope);
}

// Insert sync exprs before async ops. For example, insert
// Insert sync exprs after async ops. For example, insert
// wgmma.commit_group.sync.aligned
// wgmma.wait_group.sync.aligned 0
// before the use of mma results. Note that cp.async is not handled
@@ -430,13 +435,34 @@ class ReadAfterWriteSyncs : public kir::ExprMutator {
for (auto inp : expr->inputs()) {
auto def = inp->definition();
auto async_type = ir_utils::getAsyncOpType(def);

NVF_ERROR(
!flush_async_mma_pipeline_ || !fill_async_mma_pipeline_,
"The async mma pipeline cannot be filled without encountering ",
"another mma op after it is flushed with a RAW sync.");

// Detect an expression that consumes an async mma operation.
// The async mma pipeline is already flushed and is empty.
// Adding a RAW wgmma.wait_group(0) is not necessary, so skip it.
if (async_type == AsyncOpType::WgMma && !fill_async_mma_pipeline_ &&
flush_async_mma_pipeline_) {
continue;
}

if (async_type != AsyncOpType::NotAsync &&
async_type != AsyncOpType::CpAsync) {
input_async_ops[async_type].insert(def);
// async mma pipeline is flushed.
flush_async_mma_pipeline_ = true;
// No mma operations are active in the async mma pipeline.
fill_async_mma_pipeline_ = false;
}
}
for (const auto& [async_type, ops] : input_async_ops) {
auto sync_exprs = lower_utils::getSyncExprs(async_type, 0);
auto sync_exprs = lower_utils::getSyncExprs(
async_type,
/*keep_stages=*/0,
/*requires_commit=*/async_type != AsyncOpType::WgMma);
for (auto sync_expr : sync_exprs) {
insertSyncExpr(ops, expr, sync_expr, nullptr);
}
@@ -758,7 +784,31 @@ class ReadAfterWriteSyncs : public kir::ExprMutator {
NVF_ERROR(sync_before_.empty(), "Didn't place all required syncs.");
}

bool checkAsyncMmaPipeline() {
return fill_async_mma_pipeline_ == false;
}

private:
//! fill_async_mma_pipeline_ is true when any mma expression is issued. A RAW
//! sync is required before any consumer operations use the results of the mma
//! expression.
//!
//! flush_async_mma_pipeline_ is true when a RAW sync is issued for the async
//! mma pipeline. The RAW sync for async wgmma is `wgmma.wait_group(0)`. All prior
//! mma operations are completed after this operation. No additional RAW syncs
//! are required for other consumer expressions unless another mma expression
//! occurs in the fusion.
//!
//! fill_async_mma_pipeline_ and flush_async_mma_pipeline_ cannot be true
//! simultaneously.
//!
//! fill_async_mma_pipeline_ is always false at the end of `ReadAfterWriteSyncs`.
//! It detects mma ops in the async mma pipeline that still require a RAW sync.
bool fill_async_mma_pipeline_ = false;

//! Only a single wgmma wait_group is needed to flush the async mma pipeline.
bool flush_async_mma_pipeline_ = false;

//! Keep track of expressions that must be followed by syncthreads
std::deque<std::pair<Expr*, ParallelTypeBitmap>> sync_before_;

@@ -788,6 +838,9 @@ class ReadAfterWriteSyncs : public kir::ExprMutator {
public:
static std::vector<Expr*> insert(const std::vector<Expr*>& loop_nests) {
ReadAfterWriteSyncs inserter(loop_nests);
NVF_ERROR(
inserter.checkAsyncMmaPipeline(),
"Async mma pipeline should be empty at end of cuda kernel.");
return inserter.exprs_;
}
};
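The pair of flags added in this file behaves like a tiny state machine: an mma op fills the pipeline, the first RAW sync flushes it, and later consumers of wgmma results can then skip a redundant `wgmma.wait_group(0)`. Below is a minimal standalone sketch of that bookkeeping; the `Op` enum, the `AsyncMmaPipelineTracker` name, and the `process` helper are illustrative stand-ins, not part of this commit.

#include <cassert>
#include <vector>

// Hypothetical stand-in for the pass above: tracks whether an async mma is
// pending (fill) and whether a RAW sync has already drained it (flush).
enum class Op { Mma, ConsumerOfMma, Other };

struct AsyncMmaPipelineTracker {
  bool fill = false;   // an mma has been issued and not yet waited on
  bool flush = false;  // wgmma.wait_group(0) already emitted for it

  // Returns true if a RAW sync must be inserted before `op`.
  bool process(Op op) {
    assert(!(fill && flush));  // mirrors the NVF_ERROR invariant above
    switch (op) {
      case Op::Mma:
        fill = true;
        flush = false;  // a new mma re-arms the pipeline
        return false;
      case Op::ConsumerOfMma:
        if (fill) {  // first consumer after an mma needs the wait
          fill = false;
          flush = true;
          return true;
        }
        return false;  // pipeline already flushed: skip the redundant wait
      default:
        return false;
    }
  }
};

int main() {
  AsyncMmaPipelineTracker t;
  std::vector<Op> exprs = {
      Op::Mma, Op::ConsumerOfMma, Op::ConsumerOfMma, Op::Mma, Op::ConsumerOfMma};
  for (Op op : exprs) {
    t.process(op);
  }
  assert(!t.fill);  // matches checkAsyncMmaPipeline() at the end of the pass
  return 0;
}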
11 changes: 8 additions & 3 deletions csrc/device_lower/utils.cpp
@@ -2022,11 +2022,16 @@ bool allMmaInputsGuardedByMBarrier(const MmaOp* mma) {
ir_utils::isCpAsyncBulkLoad(ir_utils::getTv(mma->inB())->definition());
}

std::vector<Expr*> getSyncExprs(AsyncOpType async_type, int64_t keep_stages) {
std::vector<Expr*> getSyncExprs(
AsyncOpType async_type,
int64_t keep_stages,
bool requires_commit) {
std::vector<Expr*> sync_exprs;
sync_exprs.reserve(2);
auto commit = IrBuilder::create<kir::AsyncCommit>(async_type);
sync_exprs.push_back(commit);
if (requires_commit) {
auto commit = IrBuilder::create<kir::AsyncCommit>(async_type);
sync_exprs.push_back(commit);
}
auto wait = IrBuilder::create<kir::AsyncWait>(async_type, keep_stages);
sync_exprs.push_back(wait);
return sync_exprs;
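The new `requires_commit` parameter lets the RAW path for WgMma emit only the wait: per the updated call site in insert_syncs.cpp, `requires_commit` is false exactly when `async_type == AsyncOpType::WgMma`. A rough sketch of the resulting PTX-level sequence, using strings as stand-ins for the `kir::AsyncCommit` / `kir::AsyncWait` nodes (the mnemonics follow the comment in insert_syncs.cpp; this is not the real IR builder):

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Simplified stand-in for lower_utils::getSyncExprs: returns the PTX-level
// sync sequence as strings instead of kir::AsyncCommit / kir::AsyncWait nodes.
std::vector<std::string> getSyncSequence(bool requires_commit, int64_t keep_stages) {
  std::vector<std::string> seq;
  if (requires_commit) {
    seq.push_back("wgmma.commit_group.sync.aligned");
  }
  seq.push_back("wgmma.wait_group.sync.aligned " + std::to_string(keep_stages));
  return seq;
}

int main() {
  // RAW sync before a consumer of wgmma results (requires_commit == false per
  // the call site above): only the wait is emitted.
  for (const auto& s : getSyncSequence(/*requires_commit=*/false, /*keep_stages=*/0)) {
    std::cout << s << "\n";
  }
  // Default behavior (requires_commit == true): commit followed by wait.
  for (const auto& s : getSyncSequence(/*requires_commit=*/true, /*keep_stages=*/0)) {
    std::cout << s << "\n";
  }
  return 0;
}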
3 changes: 2 additions & 1 deletion csrc/device_lower/utils.h
@@ -375,7 +375,8 @@ bool allMmaInputsGuardedByMBarrier(const MmaOp* mma);
// wgmma.wait_group.sync.aligned
std::vector<Expr*> getSyncExprs(
AsyncOpType async_type,
int64_t keep_stages = 0);
int64_t keep_stages = 0,
bool requires_commit = true);

} // namespace lower_utils

4 changes: 1 addition & 3 deletions csrc/host_ir/executor.cpp
@@ -68,10 +68,8 @@ void HostIrExecutor::compile(Fusion* fusion) {
}
} else {
std::vector<Expr*> exprs = fusion->exprs();
DeviceIdxType my_device_idx = communicator_ ? communicator_->deviceId() : 0;
for (Expr* e : exprs) {
std::vector<Expr*> communications =
HostIrLower::lower(cloner.clone(e), my_device_idx);
std::vector<Expr*> communications = HostIrLower::lower(cloner.clone(e));
for (auto* communication : communications) {
host_ir_container_->pushBackTopLevelExprs(communication);
}
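For context, the loop this hunk edits lowers each cloned Expr into zero or more communication Exprs and appends them to the host IR container; the only change is that the single-Expr `HostIrLower::lower` no longer takes the local device index. A toy outline of that loop with placeholder types (none of these are the real nvFuser classes):

#include <vector>

// Placeholder types standing in for nvFuser's Expr and HostIrContainer; only
// the shape of the loop is meant to match HostIrExecutor::compile.
struct Expr {};
struct HostIrContainer {
  std::vector<Expr*> top_level;
  void pushBackTopLevelExprs(Expr* e) { top_level.push_back(e); }
};

// Stand-in for HostIrLower::lower(Expr*): one Expr may lower to zero or more
// communication Exprs.
std::vector<Expr*> lowerToCommunications(Expr* e) {
  return {e};  // trivial placeholder
}

void compile(const std::vector<Expr*>& exprs, HostIrContainer& container) {
  for (Expr* e : exprs) {
    // After this commit, lowering no longer needs the local device index.
    for (Expr* communication : lowerToCommunications(e)) {
      container.pushBackTopLevelExprs(communication);
    }
  }
}

int main() {
  Expr a, b;
  std::vector<Expr*> exprs = {&a, &b};
  HostIrContainer container;
  compile(exprs, container);
  return container.top_level.size() == 2 ? 0 : 1;
}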
70 changes: 19 additions & 51 deletions csrc/host_ir/lower.cpp
@@ -57,10 +57,6 @@ void lowerToScatter(
std::vector<Expr*>& comms) {
// we arbitrarily choose the first device of the sender mesh to be the root
const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
NVF_ERROR(
receiver_mesh.rank() == 1,
"Gather only supported on a 1D mesh. Given ",
receiver_mesh);
auto root = input_tv->getDeviceMesh().at(0);
Team team = receiver_mesh.vector();
if (!receiver_mesh.has(root)) {
@@ -74,18 +70,14 @@
Adds zero or multiple Gather communications to the vector 'comms'
Note that since the root of a Gather collective is a destination, we possibly
need multiple Gathers if the tensor is replicated in the receiver mesh.
need multiple Gather if the tensor is replicated in the receiver mesh.
*/
void lowerToGather(
TensorView* input_tv,
TensorView* output_tv,
std::vector<Expr*>& comms) {
// we create as many 'Gathers' as there are devices in the receiver mesh
const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
NVF_ERROR(
sender_mesh.rank() == 1,
"Currently only lower Gather on a 1D mesh. Given ",
sender_mesh);
for (auto root : output_tv->getDeviceMesh().vector()) {
Team team = sender_mesh.vector();
if (!sender_mesh.has(root)) {
@@ -100,12 +92,10 @@ void lowerToGather(
void lowerToAllgather(
TensorView* input_tv,
TensorView* output_tv,
std::vector<Expr*>& comms,
DeviceIdxType my_device_idx) {
Team team =
input_tv->getDeviceMesh().getSlice(my_device_idx, ParallelType::DIDx);
std::vector<Expr*>& comms) {
const DeviceMesh& mesh = input_tv->getDeviceMesh();
comms.push_back(IrBuilder::create<Communication>(
CommunicationType::Allgather, output_tv, input_tv, team));
CommunicationType::Allgather, output_tv, input_tv, mesh.vector()));
}

// Adds one or zero Broadcast communication to the vector 'comms'
@@ -115,8 +105,6 @@ void lowerToBroadcast(
DeviceIdxType root,
std::vector<Expr*>& comms) {
const DeviceMesh& mesh = output_tv->getDeviceMesh();
NVF_ERROR(
mesh.rank() == 1, "Broadcast only supported a 1D mesh. Given ", mesh);
Team team = mesh.vector();
if (!mesh.has(root)) {
team.push_back(root);
@@ -135,14 +123,6 @@ void lowerToBroadcastOrSendRecv(
std::vector<Expr*>& comms) {
const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
NVF_ERROR(
sender_mesh.rank() == 1,
"Broadcast only supported a 1D mesh. Given ",
sender_mesh);
NVF_ERROR(
receiver_mesh.rank() == 1,
"Broadcast only supported a 1D mesh. Given ",
receiver_mesh);
if (isSharded(input_tv) && sender_mesh.size() > 1) {
// if the inputs and outputs are parallelized,
// we create as many Broadcast as that will be handled in parallel
@@ -184,14 +164,6 @@ void lowerToReduce(
std::vector<Expr*>& comms) {
const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
NVF_ERROR(
sender_mesh.rank() == 1,
"Reduce only supported a 1D mesh. Given ",
sender_mesh);
NVF_ERROR(
receiver_mesh.rank() == 1,
"Reduce only supported a 1D mesh. Given ",
receiver_mesh);
const auto reduce_op_type = getC10dReduceOpType(op_type);
// we create as many Reduces as there are devices in the receiver mesh
for (auto root : receiver_mesh.vector()) {
@@ -213,15 +185,13 @@ void lowerToAllreduce(
TensorView* input_tv,
TensorView* output_tv,
BinaryOpType op_type,
std::vector<Expr*>& comms,
DeviceIdxType my_device_idx) {
Team team =
input_tv->getDeviceMesh().getSlice(my_device_idx, ParallelType::DIDx);
std::vector<Expr*>& comms) {
const DeviceMesh& mesh = input_tv->getDeviceMesh();
comms.push_back(IrBuilder::create<Communication>(
CommunicationType::Allreduce,
output_tv,
input_tv,
team,
mesh.vector(),
/*root=*/-1,
getC10dReduceOpType(op_type)));
}
@@ -230,10 +200,8 @@ void lowerToReduceScatter(
TensorView* input_tv,
TensorView* output_tv,
BinaryOpType op_type,
std::vector<Expr*>& comms,
DeviceIdxType my_device_idx) {
Team team =
input_tv->getDeviceMesh().getSlice(my_device_idx, ParallelType::DIDx);
std::vector<Expr*>& comms) {
const DeviceMesh& mesh = input_tv->getDeviceMesh();
auto reduction_axis = output_tv->getReductionAxis().value();
auto scattered_axis = getShardedLogicalAxis(output_tv, ParallelType::DIDx);
// The output tensor is sharded on scattered_axis and needs to be mapped
@@ -248,7 +216,7 @@
CommunicationType::ReduceScatter,
output_tv,
input_tv,
/*team=*/team,
/*team=*/mesh.vector(),
/*root=*/-1,
getC10dReduceOpType(op_type),
scattered_axis));
@@ -265,7 +233,7 @@ void lowerToReduceScatter(
sources
*) Leverage the topology to ensure that the senders and recerivers are close
*/
std::vector<Expr*> HostIrLower::lower(Expr* c, DeviceIdxType my_device_idx) {
std::vector<Expr*> HostIrLower::lower(Expr* c) {
FusionGuard fg(c->fusion());

if (c->isOneOf<MatmulOp, LinearOp>()) {
@@ -288,7 +256,7 @@ std::vector<Expr*> HostIrLower::lower(Expr* c, DeviceIdxType my_device_idx) {
const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
const bool same_mesh = sender_mesh == receiver_mesh;

// Stores whether the I/O has its first axis parallelized on DIDx
// Stores whether the I/O has its first axis parallelized on Didx
const bool is_input_sharded = isSharded(input_tv) && sender_mesh.size() > 1;
const bool is_output_sharded =
isSharded(output_tv) && receiver_mesh.size() > 1;
@@ -314,11 +282,11 @@ std::vector<Expr*> HostIrLower::lower(Expr* c, DeviceIdxType my_device_idx) {
NVF_ERROR(
same_mesh,
"ReduceScatter operation must have the same sender and receiver device mesh. "
"Insert a Set operation before or after the reduction to reshard to another device mesh");
lowerToReduceScatter(input_tv, output_tv, op_type, comms, my_device_idx);
"Insert a Set operation before or after the reduction to reshard ot another device mesh");
lowerToReduceScatter(input_tv, output_tv, op_type, comms);
} else {
if (same_mesh) {
lowerToAllreduce(input_tv, output_tv, op_type, comms, my_device_idx);
lowerToAllreduce(input_tv, output_tv, op_type, comms);
} else {
lowerToReduce(input_tv, output_tv, op_type, comms);
}
Expand All @@ -328,7 +296,7 @@ std::vector<Expr*> HostIrLower::lower(Expr* c, DeviceIdxType my_device_idx) {
lowerToScatter(input_tv, output_tv, comms);
} else if (is_input_sharded && !is_output_sharded) {
if (same_mesh) {
lowerToAllgather(input_tv, output_tv, comms, my_device_idx);
lowerToAllgather(input_tv, output_tv, comms);
} else {
lowerToGather(input_tv, output_tv, comms);
}
@@ -538,7 +506,7 @@ std::vector<Expr*> HostIrLower::lowerToCollectiveBasedPipelinedGemmComm(

std::unique_ptr<hir::HostIrContainer> HostIrLower::lower(
std::unique_ptr<Fusion> fusion,
DeviceIdxType my_device_index) {
int64_t my_device_index) {
// Sharding PreSegmenter passes.
// Note: passes run before PreSegmenter optimization passes.
preseg_passes::OptimizationPass<
@@ -596,8 +564,8 @@ std::unique_ptr<hir::HostIrContainer> HostIrLower::lower(
NVF_ERROR(
group->exprs().size() == 1,
"Communication segments must contain only one Expr");
for (auto* expr : HostIrLower::lower(
ir_cloner.clone(group->exprs().at(0)), my_device_index)) {
for (auto* expr :
HostIrLower::lower(ir_cloner.clone(group->exprs().at(0)))) {
// Allocate the recv buffers of communications
if (expr->isA<Communication>()) {
auto* communication = expr->as<Communication>();
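The recurring change across lower.cpp is that Allgather, Allreduce, and ReduceScatter now build their team from the full 1D device mesh (`mesh.vector()`) instead of a per-device `DIDx` slice, which is why `my_device_idx` drops out of the signatures. For orientation, the dispatch in `HostIrLower::lower` visible in the hunks above roughly reduces to the decision tree sketched below; the function name and string labels are illustrative only, and hidden context lines may add further conditions.

#include <string>

// Illustrative only: mirrors the branch structure visible in the hunks above,
// not the actual nvFuser implementation.
std::string chooseCommunication(bool is_reduction,
                                bool is_input_sharded,
                                bool is_output_sharded,
                                bool same_mesh) {
  if (is_reduction) {
    if (is_output_sharded) {
      // Requires same_mesh; otherwise a Set op must reshard first (NVF_ERROR above).
      return "ReduceScatter";
    }
    return same_mesh ? "Allreduce" : "Reduce";
  }
  if (!is_input_sharded && is_output_sharded) {
    return "Scatter";  // condition simplified; the real check is partly hidden
  }
  if (is_input_sharded && !is_output_sharded) {
    return same_mesh ? "Allgather" : "Gather";
  }
  return "BroadcastOrSendRecv";
}

int main() {
  // Input sharded, output replicated, same mesh -> Allgather (cf. lowerToAllgather).
  return chooseCommunication(/*is_reduction=*/false,
                             /*is_input_sharded=*/true,
                             /*is_output_sharded=*/false,
                             /*same_mesh=*/true) == "Allgather"
      ? 0
      : 1;
}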
4 changes: 2 additions & 2 deletions csrc/host_ir/lower.h
@@ -22,11 +22,11 @@ class HostIrLower {
static bool canLower(Expr* expr, bool ignore_inner_resharding = false);

// Lower a sharded Expr into a series of Communication.
static std::vector<Expr*> lower(Expr* c, DeviceIdxType my_device_index);
static std::vector<Expr*> lower(Expr* c);

static std::unique_ptr<hir::HostIrContainer> lower(
std::unique_ptr<Fusion> fusion,
DeviceIdxType my_device_index);
int64_t my_device_index);

private:
static std::vector<Expr*> lowerToCollectiveBasedPipelinedGemmComm(Expr* expr);