diff --git a/csrc/device_lower/pass/insert_syncs.cpp b/csrc/device_lower/pass/insert_syncs.cpp index 7e2e12ad72f..ece2e8d34ca 100644 --- a/csrc/device_lower/pass/insert_syncs.cpp +++ b/csrc/device_lower/pass/insert_syncs.cpp @@ -1069,7 +1069,7 @@ class WarAsyncWaitInserter : private kir::ExprMutator { auto tv = dynamic_cast(inp); if (tv == nullptr) { continue; - }; + } if (!tv->isCircularBuffered()) { return 0; } @@ -1205,11 +1205,28 @@ class WarAsyncWaitInserter : private kir::ExprMutator { // Actually insert these wait expressions. for (auto [type, pending_ops] : types_and_pending_ops_to_protect) { auto sync_exprs = lower_utils::getSyncExprs(type, pending_ops); + NVF_ERROR(!for_loop->body().exprs().empty()); + + // Default position is last expression in for loop + size_t num_exprs = for_loop->body().exprs().size(); + int64_t pos = num_exprs - 1; + + // The sync qualifier in the `wgmma.wait_group` ptx instruction only + // guarantees that a warp executes the instruction. The entire warp + // group must complete `wgmma` instruction before loading next circular + // buffer stage, so place the wgmma sync expressions before the + // kir::BlockSync to avoid incorrect results. + if (type == AsyncOpType::WgMma && + for_loop->circularBufferLoopStage() == + CircularBufferLoopStage::Main) { + NVF_ERROR(num_exprs > 1); + NVF_ERROR(for_loop->body().exprs().back()->isA()); + --pos; + } + + Expr* expr = for_loop->body().exprs().at(pos); while (!sync_exprs.empty()) { - registerInsertAfter( - for_loop->body().exprs().back(), - sync_exprs.back(), - &for_loop->body()); + registerInsertAfter(expr, sync_exprs.back(), &for_loop->body()); sync_exprs.pop_back(); } } diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index 1aab9983746..7007488224b 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -4029,8 +4029,8 @@ TEST_F(HopperMatmulTest, HSH_NT_UseScheduler) { auto out_ref = at::matmul(a_ref.squeeze().t(), b_ref.squeeze()).to(at::kHalf); MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 256, 32); - gemm_tile.warp_tile = GemmTile(64, 256, 32); + gemm_tile.cta_tile = GemmTile(128, 256, 64); + gemm_tile.warp_tile = GemmTile(64, 256, 64); MatmulParams mparams; mparams.supported_vec_size = {8, 8, 8}; @@ -4086,8 +4086,8 @@ TEST_F(HopperMatmulTest, HSH_TN_UseScheduler) { auto out_ref = at::matmul(a_ref.squeeze(), b_ref.squeeze().t()).to(at::kHalf); MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 256, 32); - gemm_tile.warp_tile = GemmTile(64, 256, 32); + gemm_tile.cta_tile = GemmTile(128, 256, 64); + gemm_tile.warp_tile = GemmTile(64, 256, 64); MatmulParams mparams; mparams.supported_vec_size = {8, 8, 8}; @@ -4149,8 +4149,8 @@ TEST_F(HopperMatmulTest, HSH_NN_UseScheduler) { at::matmul(a_ref.squeeze().t(), b_ref.squeeze().t()).to(at::kHalf); MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 256, 32); - gemm_tile.warp_tile = GemmTile(64, 256, 32); + gemm_tile.cta_tile = GemmTile(128, 256, 64); + gemm_tile.warp_tile = GemmTile(64, 256, 64); MatmulParams mparams; mparams.supported_vec_size = {8, 8, 8}; @@ -4211,8 +4211,8 @@ TEST_F(HopperMatmulTest, HSH_TT_UseScheduler) { auto out_ref = at::matmul(a_ref.squeeze(), b_ref.squeeze()).to(at::kHalf); MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 256, 32); - gemm_tile.warp_tile = GemmTile(64, 256, 32); + gemm_tile.cta_tile = GemmTile(128, 256, 64); + gemm_tile.warp_tile = GemmTile(64, 256, 64); MatmulParams mparams; mparams.supported_vec_size = {8, 8, 8}; @@ -4324,11 +4324,6 @@ TEST_P(MLPBenchmarkTest, FwdGEMM) { ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse( ke.compiledKernel()->kernel())); - if (!test_params.warp_specialization) { - GTEST_SKIP() - << "Sync error with pipelined circular buffering causes incorrect results"; - } - // Relax tolerance for larger sum due to large K EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K)); } @@ -4406,11 +4401,6 @@ TEST_P(MLPBenchmarkTest, FwdEpilogueFusion) { ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse( ke.compiledKernel()->kernel())); - if (!test_params.warp_specialization) { - GTEST_SKIP() - << "Sync error with pipelined circular buffering causes incorrect results"; - } - // Relax tolerance for larger sum due to large K EXPECT_TRUE(cg_outputs[0].allclose(tv3_ref, 1e-6 * K, 1e-6 * K)); EXPECT_TRUE(cg_outputs[1].allclose(tv11_ref, 1e-2, 1e-2)); @@ -4498,11 +4488,6 @@ TEST_P(MLPBenchmarkTest, FwdHorizontalFusion) { ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse( ke.compiledKernel()->kernel())); - if (!test_params.warp_specialization) { - GTEST_SKIP() - << "Sync error with pipelined circular buffering causes incorrect results"; - } - // Relax tolerance for larger sum due to large K EXPECT_TRUE(cg_outputs[0].allclose(tv3_ref, 1e-6 * K, 1e-6 * K)); EXPECT_TRUE(cg_outputs[1].allclose(tv10_ref, 1e-6 * K, 1e-6 * K));