Extend circular buffer tests to test 1D TMA and fix index for 1D TMA #3859

Open
liqiangxl wants to merge 9 commits into main
Conversation

@liqiangxl (Collaborator) commented Feb 9, 2025

Two changes in this PR:
(1) Circular buffer tests are extended to cover both LoadStoreOpType::CpAsyncBulkTensorTile and LoadStoreOpType::CpAsyncBulk.
(2) Use IdModel indexing for 1D TMA to avoid an offset bug when using warp specialization with prefetch.
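
For context, here is a minimal sketch of the kind of 1D TMA load schedule that change (2) affects, assembled from the same scheduling API used in the test snippet later in this thread; the tensor shape, split sizes, and circular-buffer options are illustrative assumptions, not the PR's actual test code.

```cpp
// Sketch only (assumed shapes and options): load gmem -> smem with the 1D
// bulk variant (LoadStoreOpType::CpAsyncBulk) and circular-buffer the
// shared-memory tensor.
Fusion fusion;
FusionGuard fg(&fusion);

TensorView* in = makeContigTensor(1);
fusion.addInput(in);
TensorView* out = set(in);
fusion.addOutput(out);

TensorView* smem = in->cacheAfter(LoadStoreOpType::CpAsyncBulk);
smem->setMemoryType(MemoryType::Shared);

// [I0] -> [I0/256, 256]; the inner 256 elements form one bulk transfer.
smem->split(0, 256);
smem->axis(0)->parallelize(ParallelType::BIDx);
smem->axis(1)->parallelize(ParallelType::Bulk);

// Matching schedule for the smem -> gmem copy.
out->split(0, 256);
out->axis(0)->parallelize(ParallelType::BIDx);
out->axis(1)->parallelize(ParallelType::TIDx);

inlineAllAt(out, /*pos=*/1);
// With warp specialization and prefetch, the hard-coded 1D TMA offset was
// wrong; this PR switches that path to IdModel indexing.
smem->circularBuffer(/*number_of_stages=*/2);
```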

@liqiangxl (Collaborator Author)

!test

github-actions bot commented Feb 9, 2025

Review updated until commit d99a60f

Description

  • Extend circular buffer tests for CpAsyncBulk and CpAsyncBulkTensorTile

  • Fix index calculation for 1D TMA with warp specialization

  • Add checks to skip tests causing source address overflow

  • Update Index::getProducerIndex to handle CpAsyncBulk operations
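
To make the extended coverage concrete, here is a sketch of how the new LoadStoreOpType axis could be folded into the test instantiation; the generator name and the way it is combined with the other parameters are assumptions, not the PR's exact instantiation.

```cpp
// Sketch only: every existing circular-buffering configuration is exercised
// once per TMA load type by adding a LoadStoreOpType axis to the test params.
auto tma_load_types = testing::Values(
    LoadStoreOpType::CpAsyncBulkTensorTile,  // multi-dimensional tensor-tile TMA
    LoadStoreOpType::CpAsyncBulk);           // 1D bulk TMA
// Combined with the other axes via testing::Combine(stages, prefetch,
// outer_dims, inner_dims, circular_buffer_types, tma_load_types).
```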


Changes walkthrough 📝

Relevant files (Enhancement)

csrc/device_lower/analysis/device_version.cpp: Extend device version check (+4/-2)
  • Extend device version check to include LoadStoreOpType::CpAsyncBulk

csrc/index_compute.cpp: Update index computation for CpAsyncBulk (+5/-4)
  • Update getProducerIndex to handle CpAsyncBulk operations
  • Add isCpAsyncBulkLoad utility function check
  • Remove redundant error check for 1D TMA

tests/cpp/test_circular_buffering.cpp: Extend and fix circular buffering tests (+88/-22)
  • Extend test parameters to include LoadStoreOpType
  • Add checks to skip tests causing source address overflow
  • Update test cases to use tma_load_type
  • Skip CpAsyncBulk for multi-dimensional TMA

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Version Check

Ensure that the version checks for LoadStoreOpType::CpAsyncBulk and LoadStoreOpType::CpAsyncBulkTensorTile are correct and that they cover all necessary cases.

    } else if (
        ls_op->opType() == LoadStoreOpType::CpAsyncBulkTensorTile ||
        ls_op->opType() == LoadStoreOpType::CpAsyncBulk) {
      ensureVersion(
          {9, 0},
          "LoadStoreOpType::CpAsyncBulk{TensorTile} requires Hopper (9.0) or newer");

Index Calculation

Verify that the index calculation logic for LoadStoreOpType::CpAsyncBulk and LoadStoreOpType::CpAsyncBulkTensorTile is correct and handles all edge cases.

    // (The leading declaration is truncated in the review snippet; it is
    // reconstructed here to mirror the consumer check, so the name is assumed.)
    bool is_producer_tma_op = producer->definition() != nullptr &&
        producer->definition()->isA<LoadStoreOp>() &&
        ir_utils::isCpAsyncBulkLoad(producer->definition());
    bool is_consumer_tma_op = consumer->definition() != nullptr &&
        consumer->definition()->isA<LoadStoreOp>() &&
        ir_utils::isCpAsyncBulkLoad(consumer->definition());

Test Coverage

Ensure that the new tests cover all scenarios, including edge cases, for both LoadStoreOpType::CpAsyncBulk and LoadStoreOpType::CpAsyncBulkTensorTile.

    using TmaCircularBufferingParams = std::tuple<
        int64_t,
        int64_t,
        int64_t,
        int64_t,
        CircularBufferType,
        LoadStoreOpType>;
    
    class TmaCircularBufferingTest
        : public NVFuserFixtureParamTest<TmaCircularBufferingParams> {
     protected:
      int64_t number_of_stages = 1;
      int64_t prefetch_distance = 1;
      int64_t tensor_outer_dim = 1;
      int64_t tensor_inner_dim = 1;
      CircularBufferType circular_buffer_type;
      LoadStoreOpType tma_load_type;
    
      void SetUp() override {
        number_of_stages = std::get<0>(GetParam());
        prefetch_distance = std::get<1>(GetParam());
        tensor_outer_dim = std::get<2>(GetParam());
        tensor_inner_dim = std::get<3>(GetParam());
        circular_buffer_type = std::get<4>(GetParam());
        tma_load_type = std::get<5>(GetParam());
    
        // NOTE: Multiple of 16 required for inner dimension
        NVF_ERROR(tensor_inner_dim % 16 == 0);
        NVFuserTest::SetUp();
      }
    
      bool testEnablesRegisterSharing() {
        return std::holds_alternative<WarpSpecialized>(circular_buffer_type) &&
            std::get<WarpSpecialized>(circular_buffer_type)
                .num_registers.has_value();
      }
    
      // https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk
      // the memory range [srcMem, srcMem + size - 1] must not overflow the source
      // memory space. Otherwise, the behavior is undefined.
      bool tma1dSrcAddressOverflow(int64_t bulk_inner_dim) {
        return tensor_inner_dim % bulk_inner_dim != 0 &&
            tma_load_type == LoadStoreOpType::CpAsyncBulk;
      }
    
      template <typename data_type>
      void compare(int64_t tensor_dim, at::Tensor result, at::Tensor reference) {
        at::Tensor reference_cpu_data = reference.cpu();
        at::Tensor result_cpu_data = result.cpu();
    
        auto reference_cpu = reference_cpu_data.accessor<data_type, 1>();
        auto result_cpu = result_cpu_data.accessor<data_type, 1>();
    
        constexpr double abs_tolerance = 1e-3;
        constexpr double rel_tolerance = 1e-3;
        for (int64_t pos = 0; pos < tensor_dim; ++pos) {
          double tolerance =
              abs_tolerance + rel_tolerance * fabs((double)reference_cpu[pos]);
          if (fabs((double)result_cpu[pos] - (double)reference_cpu[pos]) >
              tolerance) {
            std::cout << "[" << pos << "] - result: " << result_cpu[pos]
                      << " | reference: " << reference_cpu[pos] << std::endl;
          }
        }
      }
    
      template <typename data_type>
      void compare(
          int64_t tensor_outer_dim,
          int64_t tensor_inner_dim,
          at::Tensor result,
          at::Tensor reference) {
        at::Tensor reference_cpu_data = reference.cpu();
        at::Tensor result_cpu_data = result.cpu();
    
        auto reference_cpu = reference_cpu_data.accessor<data_type, 2>();
        auto result_cpu = result_cpu_data.accessor<data_type, 2>();
    
        constexpr double abs_tolerance = 1e-3;
        constexpr double rel_tolerance = 1e-3;
        for (int64_t out_pos = 0; out_pos < tensor_outer_dim; ++out_pos) {
          for (int64_t in_pos = 0; in_pos < tensor_inner_dim; ++in_pos) {
            double tolerance = abs_tolerance +
                rel_tolerance * fabs((double)reference_cpu[out_pos][in_pos]);
            if (fabs(
                    (double)reference_cpu[out_pos][in_pos] -
                    (double)result_cpu[out_pos][in_pos]) > tolerance) {
              std::cout << "[" << out_pos << ", " << in_pos
                        << "] - result: " << result_cpu[out_pos][in_pos]
                        << " | ref: " << reference_cpu[out_pos][in_pos]
                        << std::endl;
            }
          }
        }
      }
    };
    
    TEST_F(NVFuserTest, ElectSyncCompatibility) {
      NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
    
      std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
    
      TensorView* input = makeContigTensor(3);
      fusion->addInput(input);
      TensorView* output = set(input);
      fusion->addOutput(output);
    
      TensorView* smem_cache =
          input->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
      smem_cache->setMemoryType(MemoryType::Shared);
    
      // For TMA load, both the shared memory layout and the loop nest and
      // parallelization of TMA are specified by the consumer: smem_cache
    
      // Step 1: define TMA domain
      // Because we want to treat the entire tensor as 1D, we define the TMA
      // domain as [I0*I1*I2]
      smem_cache->merge(0);
      smem_cache->merge(0);
      // Note that the TMA domain only exists in people's mind; there is no
      // need to set anything here.
    
      // Step 2: define box
      smem_cache->split(0, 256);
      // [I0*I1*I2/256, 256]
      // partitioned IterDomain: I0*I1*I2
      // coordinate IterDomain: I0*I1*I2/256
      // box IterDomain: 256
    
      // Step 3: define tile
      // We use dense tile here, so tile == box. Nothing to do here.
    
      // Step 4: schedule the shared memory tensor
      // By default, the allocation domain is the logical domain, which is already
      // in good shape for this case.
    
      constexpr int64_t number_of_stages = 2;
      // Step 5: schedule the consumer tensor
      smem_cache->split(0, 4);
      // [I0*I1*I2/256/4, 4, 256]
      smem_cache->split(0, number_of_stages);
      // [I0*I1*I2/256/4/2, 2, 4, 256]
    
      // [BIDx, 2, TIDx, Bulk]
      smem_cache->axis(0)->parallelize(ParallelType::BIDx);
      smem_cache->axis(2)->parallelize(ParallelType::TIDx);
      smem_cache->axis(3)->parallelize(ParallelType::Bulk);
    
      // Schedule the smem->gmem part
      output->merge(0);
      output->merge(0);
      output->split(0, 256);
      output->split(0, 4);
      output->split(0, number_of_stages);
      output->axis(0)->parallelize(ParallelType::BIDx);
      output->axis(3)->parallelize(ParallelType::TIDx);
    
      inlineAllAt(output, /*pos=*/2);
      smem_cache->circularBuffer(number_of_stages);
    
      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      std::vector<int64_t> shape(3, 300);
      auto t0 = at::randn(shape, options);
    
      // IterDomain 2 for the TMA load is parallelized with TIDx, so we generate
      // (threadIdx.x < 4) predicate. This thread predicate is incompatible with
      // circular buffering because we generate an ElectSync predicate that uses
      // a single thread.
      KernelExecutor ke;
      try {
        ke.compile(fusion.get(), {t0});
      } catch (const std::exception& e) {
        const char* reference =
            R"(This thread-parallelized TensorView T2_s_float[ iblockIdx.x15{( ceilDiv(( ceilDiv(( ceilDiv(( ( ( (( (( getMetaData(T0) )).logical_size ))[0] ) * ( (( (( getMetaData(T0) )).logical_size ))[1] ) ) * ( (( (( getMetaData(T0) )).logical_size ))[2] ) ), 256) ), 4) ), 2) )}, iS16{2}, ithreadIdx.x14{4}, iB12{256} ] ca_pos( 2 ) is incorrectly contained within a If-Then-Else with the ElectSync predicate.)";
        const char* str_match_pointer = strstr(e.what(), reference);

@liqiangxl (Collaborator Author)

!test

@liqiangxl (Collaborator Author)

!test

@liqiangxl (Collaborator Author)

!test

@liqiangxl (Collaborator Author)

Hi @xwang233, jit_binary_tests_sanitizer_20_H100_{1,2,3}/3 failed multiple times due to timeout. Is there a way to fix it, e.g., extend the allowed time for these tests?

liqiangxl marked this pull request as ready for review on February 11, 2025 19:03
@xwang233 (Collaborator) commented Feb 11, 2025

> Hi @xwang233, jit_binary_tests_sanitizer_20_H100_{1,2,3}/3 failed multiple times due to timeout. Is there a way to fix it, e.g., extend the allowed time for these tests?

The dashboard shows the H100 binary sanitizer jobs are mostly stable and take about 35 minutes for most other PR pipelines. Are there any changes in your PR that could cause the compute sanitizer to get stuck?

It could also be a runner issue that mismeasures the time for each CI job. We can restart those jobs and see if it helps.

@liqiangxl (Collaborator Author)

> > Hi @xwang233, jit_binary_tests_sanitizer_20_H100_{1,2,3}/3 failed multiple times due to timeout. Is there a way to fix it, e.g., extend the allowed time for these tests?
>
> The dashboard shows the H100 binary sanitizer jobs are mostly stable and take about 35 minutes for most other PR pipelines. Are there any changes in your PR that could cause the compute sanitizer to get stuck?
>
> It could also be a runner issue that mismeasures the time for each CI job. We can restart those jobs and see if it helps.

Good point, I didn't check whether it gets stuck due to a deadlock. Let me try it on a local node.

@zasdfgbnm (Collaborator) left a comment

LGTM, but please check if there are hangs

@liqiangxl (Collaborator Author)

!test

@xwang233 (Collaborator)

!test --dev

@liqiangxl (Collaborator Author)

!test --dev

@liqiangxl (Collaborator Author)

> LGTM, but please check if there are hangs

The hangs are real. They come from tests with a small inner dimension (e.g. 128) combined with a large TMA load length (e.g. 256), which causes a source address overflow that is not supported for 1D TMA. I skipped these tests and will add a check during the lowering of cp.async.bulk. Thanks @xwang233 for helping debug CI.

Reference (PTX ISA, cp.async.bulk): "the memory range [srcMem, srcMem + size - 1] must not overflow the source memory space. Otherwise, the behavior is undefined."
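
For illustration, a minimal sketch of how such configurations can be skipped inside a parameterized test body, using the tma1dSrcAddressOverflow() helper quoted earlier; the surrounding test body and the bulk_inner_dim value are assumptions, not the PR's exact code.

```cpp
// Sketch only: skip 1D TMA (cp.async.bulk) configurations whose bulk copy
// would read past the end of the source tensor.
constexpr int64_t bulk_inner_dim = 256;  // assumed TMA load length
if (tma1dSrcAddressOverflow(bulk_inner_dim)) {
  GTEST_SKIP() << "cp.async.bulk source address range would overflow for "
               << "inner dim " << tensor_inner_dim;
}
```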
