Skip to content

Commit

Permalink
Fix WAR async issue with pipelined circular buffering
Browse files Browse the repository at this point in the history
  • Loading branch information
rdspring1 committed Feb 10, 2025
1 parent 5c774fb commit 7bba9a4
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 28 deletions.
27 changes: 22 additions & 5 deletions csrc/device_lower/pass/insert_syncs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1069,7 +1069,7 @@ class WarAsyncWaitInserter : private kir::ExprMutator {
auto tv = dynamic_cast<TensorView*>(inp);
if (tv == nullptr) {
continue;
};
}
if (!tv->isCircularBuffered()) {
return 0;
}
Expand Down Expand Up @@ -1205,11 +1205,28 @@ class WarAsyncWaitInserter : private kir::ExprMutator {
// Actually insert these wait expressions.
for (auto [type, pending_ops] : types_and_pending_ops_to_protect) {
auto sync_exprs = lower_utils::getSyncExprs(type, pending_ops);
NVF_ERROR(!for_loop->body().exprs().empty());

// Default position is last expression in for loop
size_t num_exprs = for_loop->body().exprs().size();
int64_t pos = num_exprs - 1;

// The sync qualifier in the `wgmma.wait_group` ptx instruction only
// guarantees that a warp executes the instruction. The entire warp
// group must complete `wgmma` instruction before loading next circular
// buffer stage, so place the wgmma sync expressions before the
// kir::BlockSync to avoid incorrect results.
if (type == AsyncOpType::WgMma &&
for_loop->circularBufferLoopStage() ==
CircularBufferLoopStage::Main) {
NVF_ERROR(num_exprs > 1);
NVF_ERROR(for_loop->body().exprs().back()->isA<kir::BlockSync>());
--pos;
}

Expr* expr = for_loop->body().exprs().at(pos);
while (!sync_exprs.empty()) {
registerInsertAfter(
for_loop->body().exprs().back(),
sync_exprs.back(),
&for_loop->body());
registerInsertAfter(expr, sync_exprs.back(), &for_loop->body());
sync_exprs.pop_back();
}
}
Expand Down
31 changes: 8 additions & 23 deletions tests/cpp/test_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4029,8 +4029,8 @@ TEST_F(HopperMatmulTest, HSH_NT_UseScheduler) {
auto out_ref = at::matmul(a_ref.squeeze().t(), b_ref.squeeze()).to(at::kHalf);

MatMulTileOptions gemm_tile;
gemm_tile.cta_tile = GemmTile(128, 256, 32);
gemm_tile.warp_tile = GemmTile(64, 256, 32);
gemm_tile.cta_tile = GemmTile(128, 256, 64);
gemm_tile.warp_tile = GemmTile(64, 256, 64);

MatmulParams mparams;
mparams.supported_vec_size = {8, 8, 8};
Expand Down Expand Up @@ -4086,8 +4086,8 @@ TEST_F(HopperMatmulTest, HSH_TN_UseScheduler) {
auto out_ref = at::matmul(a_ref.squeeze(), b_ref.squeeze().t()).to(at::kHalf);

MatMulTileOptions gemm_tile;
gemm_tile.cta_tile = GemmTile(128, 256, 32);
gemm_tile.warp_tile = GemmTile(64, 256, 32);
gemm_tile.cta_tile = GemmTile(128, 256, 64);
gemm_tile.warp_tile = GemmTile(64, 256, 64);

MatmulParams mparams;
mparams.supported_vec_size = {8, 8, 8};
Expand Down Expand Up @@ -4149,8 +4149,8 @@ TEST_F(HopperMatmulTest, HSH_NN_UseScheduler) {
at::matmul(a_ref.squeeze().t(), b_ref.squeeze().t()).to(at::kHalf);

MatMulTileOptions gemm_tile;
gemm_tile.cta_tile = GemmTile(128, 256, 32);
gemm_tile.warp_tile = GemmTile(64, 256, 32);
gemm_tile.cta_tile = GemmTile(128, 256, 64);
gemm_tile.warp_tile = GemmTile(64, 256, 64);

MatmulParams mparams;
mparams.supported_vec_size = {8, 8, 8};
Expand Down Expand Up @@ -4211,8 +4211,8 @@ TEST_F(HopperMatmulTest, HSH_TT_UseScheduler) {
auto out_ref = at::matmul(a_ref.squeeze(), b_ref.squeeze()).to(at::kHalf);

MatMulTileOptions gemm_tile;
gemm_tile.cta_tile = GemmTile(128, 256, 32);
gemm_tile.warp_tile = GemmTile(64, 256, 32);
gemm_tile.cta_tile = GemmTile(128, 256, 64);
gemm_tile.warp_tile = GemmTile(64, 256, 64);

MatmulParams mparams;
mparams.supported_vec_size = {8, 8, 8};
Expand Down Expand Up @@ -4324,11 +4324,6 @@ TEST_P(MLPBenchmarkTest, FwdGEMM) {
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
ke.compiledKernel()->kernel()));

if (!test_params.warp_specialization) {
GTEST_SKIP()
<< "Sync error with pipelined circular buffering causes incorrect results";
}

// Relax tolerance for larger sum due to large K
EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K));
}
Expand Down Expand Up @@ -4406,11 +4401,6 @@ TEST_P(MLPBenchmarkTest, FwdEpilogueFusion) {
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
ke.compiledKernel()->kernel()));

if (!test_params.warp_specialization) {
GTEST_SKIP()
<< "Sync error with pipelined circular buffering causes incorrect results";
}

// Relax tolerance for larger sum due to large K
EXPECT_TRUE(cg_outputs[0].allclose(tv3_ref, 1e-6 * K, 1e-6 * K));
EXPECT_TRUE(cg_outputs[1].allclose(tv11_ref, 1e-2, 1e-2));
Expand Down Expand Up @@ -4498,11 +4488,6 @@ TEST_P(MLPBenchmarkTest, FwdHorizontalFusion) {
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
ke.compiledKernel()->kernel()));

if (!test_params.warp_specialization) {
GTEST_SKIP()
<< "Sync error with pipelined circular buffering causes incorrect results";
}

// Relax tolerance for larger sum due to large K
EXPECT_TRUE(cg_outputs[0].allclose(tv3_ref, 1e-6 * K, 1e-6 * K));
EXPECT_TRUE(cg_outputs[1].allclose(tv10_ref, 1e-6 * K, 1e-6 * K));
Expand Down

0 comments on commit 7bba9a4

Please sign in to comment.