
Fix the normalization scheduler to accept DID loop split. #3853

Merged: 4 commits merged from wjy/norm into main on Feb 10, 2025

Conversation

wujingyue (Collaborator) commented on Feb 8, 2025

I'm sure we'll need more tests to be confident, but this incremental PR feels good!

For #2563

Review comment on csrc/multidevice/utils.cpp (diff hunk @@ -32,16 +32,6 @@ NVF_API bool distributedEnabled() {):

```cpp
namespace {

std::unordered_set<IterDomain*> getShardedIterDomains(TensorView* tv) {
```

wujingyue (Collaborator, Author), on the removed getShardedIterDomains: Not used

@wujingyue wujingyue requested a review from naoyam February 8, 2025 01:03
github-actions bot commented on Feb 8, 2025

Review updated until commit 5723c20

Description

  • Added support for DID loop split in the normalization scheduler.

  • Introduced the getShardedLoopAxis function for loop-axis retrieval (a sketch follows this list).

  • Enhanced scheduleReductionTV to handle DID loop split.

  • Added a new test case, DivideBySum, for multidevice sharding.
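
For reference, here is a minimal sketch of what getShardedLoopAxis could look like, pieced together from the reviewer-guide excerpts below (the TensorView*/ParallelType arguments, the negative return value when nothing is sharded, and the isParallelTypeDeviceDim check). It assumes nvFuser's internal headers; the actual implementation lives in csrc/multidevice/utils.cpp and may differ in details such as the exact error message or iteration order.

```cpp
// Sketch only: inferred from the excerpts in this PR description, not copied
// from the PR. Returns the position of the loop axis parallelized on the
// given device dimension, or -1 if no loop axis is sharded on it.
int64_t getShardedLoopAxis(TensorView* tv, ParallelType parallel_type) {
  NVF_ERROR(
      isParallelTypeDeviceDim(parallel_type),
      "Expect a DID but found: ",
      parallel_type);
  const auto& loop = tv->getLoopDomain();
  for (int64_t i = 0; i < static_cast<int64_t>(loop.size()); i++) {
    if (loop[i]->getParallelType() == parallel_type) {
      return i;
    }
  }
  return -1;
}
```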


Changes walkthrough 📝

Relevant files:

Enhancement

  • csrc/multidevice/utils.cpp: Add getShardedLoopAxis and remove getShardedIterDomains (+15/-10)
      • Removed the getShardedIterDomains function.
      • Added the getShardedLoopAxis function.

  • csrc/scheduler/reduction_utils.cpp: Update scheduleReductionTV for DID loop split (+11/-4)
      • Updated scheduleReductionTV to use getShardedLoopAxis.
      • Added error checks for DID loop split.

  • csrc/multidevice/utils.h: Declare getShardedLoopAxis (+4/-0)
      • Added the declaration for getShardedLoopAxis.

Tests

  • tests/cpp/test_multidevice_sharding.cpp: Add DivideBySum test case (+42/-0)
      • Added a DivideBySum test case for multidevice sharding.

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Assumption Check

The PR assumes that the DIDx domain is always the outermost domain in the loop. This assumption should be validated with more test cases to ensure correctness. For example, in the DivideBySum test below, the outer split, DIDx parallelization, and reorder place DIDx at loop position 0 on every sharded tensor.

```cpp
int64_t sharded_axis = getShardedLoopAxis(reduction_tv, ParallelType::DIDx);
if (sharded_axis >= 0) {
  NVF_ERROR(
      sharded_axis == 0,
      "Expect 1D mesh and DIDx only appear outermost in loop, but found: ",
      reduction_tv->getLoopDomain());
}
```
Error Handling

The error handling in getShardedLoopAxis could be improved by providing more context in the error message, such as the specific parallel type that caused the failure.

```cpp
// The enclosing assertion is assumed to be NVF_ERROR, matching the check in
// the Assumption Check excerpt above; the review excerpt showed only its
// arguments.
NVF_ERROR(
    isParallelTypeDeviceDim(parallel_type),
    "Expect a DID but found: ",
    parallel_type);
```
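
As one way to illustrate this suggestion (a sketch, not the PR's actual code), the check could name the helper and the tensor being queried:

```cpp
// Hypothetical, more descriptive variant of the check above. Assumes the
// surrounding function receives `tv` (TensorView*) and `parallel_type`.
NVF_ERROR(
    isParallelTypeDeviceDim(parallel_type),
    "getShardedLoopAxis expects a device-parallel type (e.g. DIDx), but got ",
    parallel_type,
    " for TensorView: ",
    tv->toString());
```
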
Test Coverage

While a new test DivideBySum is added, it would be beneficial to add more test cases to cover different scenarios and edge cases for the new functionality.

```cpp
TEST_F(MultiDeviceTest, DivideBySum) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());

  const int64_t d = communicator_->size();

  // [b, h, s, s]
  TensorView* x = makeContigTensor(4);
  TensorView* sum_x = sum(x, {-1});
  TensorView* sum_x_broadcasted = broadcast(sum_x, {false, false, false, true});
  TensorView* y = div(x, sum_x_broadcasted);
  fusion->addInput(x);
  fusion->addOutput(y);

  auto mesh = DeviceMesh::createForNumDevices(d);
  for (auto* tv : {x, sum_x, sum_x_broadcasted, y}) {
    tv->setDeviceMesh(mesh);
    // Outer-split h by d, shard the outer factor on DIDx, and move it to the
    // front so DIDx is the outermost loop axis.
    tv->split(1, d, /*inner_split=*/false);
    tv->axis(1)->parallelize(ParallelType::DIDx);
    tv->reorder({{1, 0}});
  }
  for (auto* tv : {x, y}) {
    tv->setAllocationDomain(tv->getLoopDomain(), true);
  }

  const int64_t b = 2;
  const int64_t h = d * 3;
  const int64_t s = 5;
  at::Tensor unsharded_x_tensor = at::randint(5, {b, h, s, s}, tensor_options);
  at::Tensor x_tensor = shardTensor(unsharded_x_tensor, x);

  FusionExecutorCache executor_cache(std::move(fusion));
  at::Tensor y_tensor = executor_cache.runFusionWithInputs({x_tensor})[0];
  testValidate(
      executor_cache.fusion(),
      {y_tensor},
      {x_tensor},
      {x_tensor / x_tensor.sum(-1, true)},
      __LINE__,
      __FILE__);
}
```
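
As a possible extension (not in the PR), the new helper could also assert the layout the test sets up, since getShardedLoopAxis should report DIDx at position 0 for every sharded TensorView:

```cpp
// Hypothetical extra assertions for the test above (gtest macros), checking
// that DIDx ends up as the outermost loop axis on each sharded TensorView.
for (auto* tv : {x, sum_x, sum_x_broadcasted, y}) {
  EXPECT_EQ(getShardedLoopAxis(tv, ParallelType::DIDx), 0);
}
```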

@wujingyue wujingyue requested a review from Priya2698 February 8, 2025 01:03

wujingyue (Collaborator, Author) commented:

!test

in the same way as ExpressionEvaluator::bindTensorDomain and several other places. Caveat: having to fix multiple places in the same way probably indicates a pre-existing duplication of logic.

@wujingyue wujingyue changed the base branch from wjy/gdb to bug3817 February 8, 2025 07:50
wujingyue (Collaborator, Author) commented:

!test

@naoyam naoyam mentioned this pull request Feb 10, 2025

naoyam (Collaborator) left a review comment:

LGTM

Base automatically changed from bug3817 to main February 10, 2025 18:45

wujingyue (Collaborator, Author) commented:

!test

Priya2698 (Collaborator) left a review comment:

LGTM.

wujingyue (Collaborator, Author) commented:

CI failures are due to http://nv/exg

@wujingyue wujingyue merged commit fd96f84 into main Feb 10, 2025
49 of 52 checks passed
@wujingyue wujingyue deleted the wjy/norm branch February 10, 2025 23:54