Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Compute bounds for Index scalars in lowered kernel #3850

Draft
wants to merge 17 commits into
base: main
Choose a base branch
from

Conversation

jacobhinkle
Copy link
Collaborator

@jacobhinkle jacobhinkle commented Feb 7, 2025

This extends #3599 by also computing the minimal dtype required by the expressions in the lowered kernel. Like in #3599, we cast from nvfuser_index_t to int32_t when passing coords to the TMA expression. However, unlike #3599 we actually verify that this is safe to do by checking the bounds of the inputs to those casts. This way we can safely use 64-bit indexing with TMA and know that we will not get silently incorrect results. Also, we will more commonly use 32-bit indexing because with TMA we often do not have extremely large values for index variables since TMA allows us to do multi-dimensional indexing.

Fixes #3601

TODO: add a few tests

Copy link

github-actions bot commented Feb 7, 2025

Review updated until commit 9b8ec99

Description

  • Added bounds-based index type calculation

  • Implemented BoundedInt struct for interval arithmetic

  • Created ScalarBoundsCalculator class for computing bounds of scalars

  • Updated KernelExecutor to compute index type after lowering


Changes walkthrough 📝

Relevant files
Enhancement
index_compute.cpp
Cast TMA box coordinates to int32_t                                           

csrc/index_compute.cpp

  • Included ir/builder.h
  • Added casting of TMA box coordinates to int32_t
  • +9/-0     
    executor.cpp
    Compute bounds for index scalars                                                 

    csrc/runtime/executor.cpp

  • Included additional headers for expression evaluation
  • Defined BoundedInt struct for interval arithmetic
  • Implemented ScalarBoundsCalculator class for bounds computation
  • Updated KernelExecutor::compile to compute index type after lowering
  • +556/-7 
    matmul_utils.cpp
    Remove index type check for Hopper matmul                               

    csrc/scheduler/matmul_utils.cpp

    • Removed index type check for Hopper matmul
    +0/-8     
    kernel.h
    Add setIndexType method to Kernel                                               

    csrc/kernel.h

    • Included type.h
    • Added setIndexType method to Kernel class
    +5/-0     

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 No relevant tests
    ⚡ Recommended focus areas for review

    Possible Issue

    The recip method in BoundedInt class does not handle the case where min and max are both negative correctly. It should handle the case where both min and max are negative and return the correct bounds.

    BoundedInt recip() const {
      if (!canBeZero()) {
        return BoundedInt{1L / max, 1L / min};
      }
    
      if (min == 0L) {
        if (max == 0) {
          return BoundedInt{};
        }
        return BoundedInt{1 / max, std::numeric_limits<int64_t>::max()};
      } else if (max == 0L) {
        return BoundedInt{std::numeric_limits<int64_t>::min(), 1L / min};
      } else {
        return BoundedInt{};
      }
    }
    Performance Concern

    The boundByDataType method in ScalarBoundsCalculator class does not handle the case where the bounds exceed the limits of the data type correctly. It should provide a more detailed error message or handle the overflow in a more robust way.

    //! Return the bounds, computed over all scalars in the fusion with the given
    //! data type
    BoundedInt boundByDataType(DataType dtype = DataType::Index) {
      BoundedInt ret;
      bool initialized = false;
      for (auto& [val, b] : bounds_) {
        if (val->dtype() != dtype) {
          continue;
        }
        if (!initialized) {
          ret = b;
          initialized = true;
        } else {
          ret.min = std::min(ret.min, b.min);
          ret.max = std::max(ret.max, b.max);
        }
        if (b.min < std::numeric_limits<int32_t>::min() ||
            b.max > std::numeric_limits<int32_t>::max()) {
        }
      }
      return ret;
    }
    Code Quality

    The operator* method in BoundedInt class does not handle overflow correctly. It should handle overflow in a more robust way, possibly by using a larger data type for intermediate calculations.

    BoundedInt operator*(const BoundedInt& other) const {
      // TODO: How should we handle overflow here?
      std::vector<int64_t> xs{
          min * other.min, min * other.max, max * other.min, max * other.max};
      return BoundedInt{
          *std::min_element(xs.begin(), xs.end()),
          *std::max_element(xs.begin(), xs.end())};
    }
    
    BoundedInt operator*(const int64_t other) const {
      if (other < 0L) {
        return BoundedInt{max * other, min * other};
      }
      return BoundedInt{min * other, max * other};
    }

    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
    Labels
    None yet
    Projects
    None yet
    Development

    Successfully merging this pull request may close these issues.

    Compute index type by bounding index expressions
    1 participant