diff --git a/CMakeLists.txt b/CMakeLists.txt index 04d6fd3e52a..b3aa0e841b1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -157,6 +157,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/id_model/validation_utils.cpp ${NVFUSER_SRCS_DIR}/index_compute.cpp ${NVFUSER_SRCS_DIR}/instrumentation.cpp + ${NVFUSER_SRCS_DIR}/interval_analysis.cpp ${NVFUSER_SRCS_DIR}/ir/base_nodes.cpp ${NVFUSER_SRCS_DIR}/ir/builder.cpp ${NVFUSER_SRCS_DIR}/ir/cloner.cpp @@ -580,6 +581,7 @@ list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/tests/cpp/test_indexing.cpp ${NVFUSER_ROOT}/tests/cpp/test_indexing_advanced.cpp ${NVFUSER_ROOT}/tests/cpp/test_inlining.cpp + ${NVFUSER_ROOT}/tests/cpp/test_interval_analysis.cpp ${NVFUSER_ROOT}/tests/cpp/test_iter_visitor.cpp ${NVFUSER_ROOT}/tests/cpp/test_linked_hash_map.cpp ${NVFUSER_ROOT}/tests/cpp/test_loop_domain_scheduling.cpp diff --git a/csrc/index_compute.cpp b/csrc/index_compute.cpp index 036c573bbfc..c47d62d1c82 100644 --- a/csrc/index_compute.cpp +++ b/csrc/index_compute.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -2718,6 +2719,14 @@ std::pair Index::getCpAsyncBulkGmemIndex( auto indices_inner_to_outer = indexer.getIndexFor(ldst, !is_load, ids_to_index, loops); + // These are the box coordinates of the TMA box, which must be of type + // int32_t. Possible overflow in each of these dims should be checked + // elsewhere. + for (size_t i : c10::irange(indices_inner_to_outer.size())) { + indices_inner_to_outer[i] = + IrBuilder::maybeCastExpr(DataType::Int32, indices_inner_to_outer[i]); + } + auto coordinate = IrBuilder::arrayExpr(indices_inner_to_outer); auto descriptor = tma_info.tensorMap(); if (is_load) { diff --git a/csrc/interval_analysis.cpp b/csrc/interval_analysis.cpp new file mode 100644 index 00000000000..cb830a40eb7 --- /dev/null +++ b/csrc/interval_analysis.cpp @@ -0,0 +1,646 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +#include + +#include +#include +#include + +#include +#include +#include + +namespace nvfuser { + +BoundedInt BoundedInt::operator+(const BoundedInt& other) const { + return BoundedInt{min + other.min, max + other.max}; +} + +BoundedInt BoundedInt::operator+(const int64_t other) const { + return BoundedInt{min + other, max + other}; +} + +BoundedInt BoundedInt::operator-(const BoundedInt& other) const { + return BoundedInt{min - other.max, max - other.min}; +} + +BoundedInt BoundedInt::operator-(const int64_t other) const { + return BoundedInt{min - other, max - other}; +} + +BoundedInt BoundedInt::operator-() const { + return BoundedInt{-max, -min}; +} + +BoundedInt BoundedInt::operator*(const BoundedInt& other) const { + // TODO: How should we handle overflow here? + std::vector xs{ + min * other.min, min * other.max, max * other.min, max * other.max}; + return BoundedInt{ + *std::min_element(xs.begin(), xs.end()), + *std::max_element(xs.begin(), xs.end())}; +} + +BoundedInt BoundedInt::operator*(const int64_t other) const { + if (other < 0L) { + return BoundedInt{max * other, min * other}; + } + return BoundedInt{min * other, max * other}; +} + +// Division ranges are computed differently based on whether the numerator and +// denominator are positive or negative. Because of this, we split the numerator +// into a negative and a non-positive range and we split the denominator into a +// negative and a positive range. 
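The `+`, `-`, and `*` overloads above are plain interval arithmetic; the comment continues below with the division case, which needs the sign-based splitting. As a quick aside, here is a minimal standalone sketch of those three rules, using a hypothetical `Interval` struct rather than the patch's `BoundedInt`, showing why multiplication has to examine all four corner products:

```cpp
// Minimal interval-arithmetic sketch (hypothetical Interval type, for
// illustration only; not the BoundedInt class added by this patch).
#include <algorithm>
#include <cassert>
#include <cstdint>

struct Interval {
  int64_t min;
  int64_t max;
};

Interval add(Interval a, Interval b) {
  return {a.min + b.min, a.max + b.max};
}

Interval sub(Interval a, Interval b) {
  // The smallest difference pairs a.min with b.max, the largest a.max with b.min
  return {a.min - b.max, a.max - b.min};
}

Interval mul(Interval a, Interval b) {
  // With mixed signs any corner can be the extreme: for [-2, 1] * [-3, 2] the
  // maximum 6 comes from min*min and the minimum -4 from min*max.
  int64_t c[] = {a.min * b.min, a.min * b.max, a.max * b.min, a.max * b.max};
  return {*std::min_element(c, c + 4), *std::max_element(c, c + 4)};
}

int main() {
  Interval s = sub({-1, 5}, {-3, 2});
  assert(s.min == -3 && s.max == 8);  // matches the BinaryOps test expectation
  Interval p = mul({-2, 1}, {-3, 2});
  assert(p.min == -4 && p.max == 6);
  return 0;
}
```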
Then we compute the bounds for every non-empty +// combination of ranges, of which there are at most four. The final bound is +// the union of those intervals. +// +// For example, if we have -2 <= a <= 1 and -1 <= b <= 2 and we want to compute +// bounds for a / b, we have the following cases to handle +// +// -2 / -1 = 2 +// -1 / -1 = 1 +// 0 / -1 = 0 +// 1 / -1 = -1 +// -2 / 0 = ERROR +// -1 / 0 = ERROR +// 0 / 0 = ERROR +// 1 / 0 = ERROR +// -2 / 1 = -2 +// -1 / 1 = -1 +// 0 / 1 = 0 +// 1 / 1 = 1 +// -2 / 2 = -1 +// -1 / 2 = 0 +// 0 / 2 = 0 +// 1 / 2 = 0 +// +// We split a into intervals -2 <= a <= -1 and 0 <= a <= 1 which includes zero. +// We split b, on the other hand, into -1 <= b <= -1 and 1 <= b <= 2, excluding +// the error cases. Then for all four combinations we compute a single interval +// before computing the union of those four intervals. +// +// -2 / -1 = 2 +// -1 / -1 = 1 => [1, 2] +// +// -2 / 1 = -2 +// -1 / 1 = -1 +// -2 / 2 = -1 +// -1 / 2 = 0 => [-2, 0] +// +// 0 / -1 = 0 +// 1 / -1 = -1 => [-1, 0] +// +// 0 / 1 = 0 +// 1 / 1 = 1 +// 0 / 2 = 0 +// 1 / 2 = 0 => [0, 1] +// +// The result we return in this case is the union of these four intervals which +// is [-2, 2] +#define DEFINE_DIVISION_LIKE_OP(a, b, pospos, posneg, negpos, negneg) \ + NVF_ERROR( \ + b.min != 0L || b.max != 0L, \ + "Found denominator that cannot be non-zero: ", \ + b); \ + const auto split_ranges_around_zero = [](const BoundedInt& b, \ + bool include_zero) { \ + std::vector ranges; \ + if (b.min < 0L) { \ + ranges.push_back({b.min, std::min(b.max, -1L)}); \ + } \ + int64_t min_nonneg_val = include_zero ? 0L : 1L; \ + if (b.max >= min_nonneg_val) { \ + ranges.push_back({std::max(b.min, min_nonneg_val), b.max}); \ + } \ + return ranges; \ + }; \ + const std::vector numer_ranges = \ + split_ranges_around_zero(a, /*include_zero=*/true); \ + const std::vector denom_ranges = \ + split_ranges_around_zero(b, /*include_zero=*/false); \ + \ + BoundedInt result; \ + bool first = true; \ + for (const BoundedInt& numer : numer_ranges) { \ + for (const BoundedInt& denom : denom_ranges) { \ + BoundedInt simple_range; \ + /* numer and denom are each either only negative or only positive */ \ + if (numer.min >= 0) { \ + if (denom.min > 0) { \ + simple_range = pospos(numer, denom); \ + } else { \ + simple_range = posneg(numer, denom); \ + } \ + } else { \ + if (denom.min > 0) { \ + simple_range = negpos(numer, denom); \ + } else { \ + simple_range = negneg(numer, denom); \ + } \ + } \ + /* Result is the union over all of the simple ranges */ \ + if (first) { \ + result = simple_range; \ + } else { \ + result.min = std::min(result.min, simple_range.min); \ + result.max = std::max(result.max, simple_range.max); \ + } \ + first = false; \ + } \ + } \ + return result; +BoundedInt BoundedInt::operator/(const BoundedInt& other) const { + // positive over positive + const auto pospos = [](const BoundedInt& numer, const BoundedInt& denom) { + return BoundedInt{numer.min / denom.max, numer.max / denom.min}; + }; + // positive over negative + const auto posneg = [](const BoundedInt& numer, const BoundedInt& denom) { + return BoundedInt{numer.max / denom.max, numer.min / denom.min}; + }; + // negative over positive + const auto negpos = [](const BoundedInt& numer, const BoundedInt& denom) { + return BoundedInt{numer.min / denom.min, numer.max / denom.max}; + }; + // negative over negative + const auto negneg = [](const BoundedInt& numer, const BoundedInt& denom) { + return BoundedInt{numer.max / denom.min, numer.min / 
denom.max}; + }; + DEFINE_DIVISION_LIKE_OP(*this, other, pospos, posneg, negpos, negneg); +} + +BoundedInt ceilDiv(const BoundedInt& a, const BoundedInt& b) { + // positive over positive + const auto pospos = [](const BoundedInt& numer, const BoundedInt& denom) { + return BoundedInt{ + ceilDiv(numer.min, denom.max), ceilDiv(numer.max, denom.min)}; + }; + // positive over negative + const auto posneg = [](const BoundedInt& numer, const BoundedInt& denom) { + return BoundedInt{ + ceilDiv(numer.max, denom.max), ceilDiv(numer.min, denom.min)}; + }; + // negative over positive + const auto negpos = [](const BoundedInt& numer, const BoundedInt& denom) { + return BoundedInt{ + ceilDiv(numer.min, denom.min), ceilDiv(numer.max, denom.max)}; + }; + // negative over negative + const auto negneg = [](const BoundedInt& numer, const BoundedInt& denom) { + return BoundedInt{ + ceilDiv(numer.max, denom.min), ceilDiv(numer.min, denom.max)}; + }; + DEFINE_DIVISION_LIKE_OP(a, b, pospos, posneg, negpos, negneg); +} + +// Modulo is the remainder op and satisfies +// +// a % b = a - (a / b) * b +// +// for any a and b. Since division in C++ is truncdiv and rounds toward zero, +// (a / b) * b is the same as (a / (-b)) * (-b) and never maps a negative value +// to a more negative value. This means the remainder is negative when a is +// negative and non-negative when a is non-negative. +// +// Note that like for division, we ignore b==0. Additionally, if we can +// guarantee 0 <= a < b then a % b = a so we can just use a's bounds. This +// is also the case if b < a < 0. +BoundedInt BoundedInt::operator%(const BoundedInt& other) const { + // positive mod positive + const auto pospos = [](const BoundedInt& numer, const BoundedInt& denom) { + if (numer.max < denom.min) { + // mod op is trivial + return numer; + } else { + return BoundedInt{0L, denom.max - 1L}; + } + }; + // positive mod negative + const auto posneg = [&pospos]( + const BoundedInt& numer, const BoundedInt& denom) { + return pospos(numer, -denom); + }; + // negative mod positive + const auto negpos = [](const BoundedInt& numer, const BoundedInt& denom) { + if (numer.min > -denom.min) { + // mod op is trivial + return numer; + } else { + return BoundedInt{1L - denom.max, 0}; + } + return BoundedInt{ + ceilDiv(numer.min, denom.min), ceilDiv(numer.max, denom.max)}; + }; + // negative mod negative + const auto negneg = [&negpos]( + const BoundedInt& numer, const BoundedInt& denom) { + return negpos(numer, -denom); + }; + DEFINE_DIVISION_LIKE_OP(*this, other, pospos, posneg, negpos, negneg); +} +#undef DEFINE_DIVISION_LIKE_OP + +BoundedInt BoundedInt::operator/(const int64_t other) const { + return *this / BoundedInt{other, other}; +} + +BoundedInt BoundedInt::operator%(const int64_t other) const { + return *this % BoundedInt{other, other}; +} + +//! Returns the number of high bits that must be common among all integers in +//! this interval +//! +//! Example: +//! min = 0b10101010 +//! max = 0b10101100 +//! +//! All numbers in this range are of the form 0b10101XXX +//! different_bits = 0b110 +//! 
num_common_bits = 61 +int64_t BoundedInt::countCommonHighBits() const { + // Reinterpret integers as unsigned, so that bitwise ops and + // std::countl_zero are well-defined + uint64_t different_bits = (*reinterpret_cast(&max)) ^ + (*reinterpret_cast(&min)); +#if __cplusplus < 202002L + // TODO: add countl_zero to csrc/C++20/ somewhere for C++17 backward + // compatibility + int64_t num_common_bits = 64L; + while (different_bits != 0L) { + different_bits >>= 1; + num_common_bits--; + } + return num_common_bits; +#else + return (int64_t)std::countl_zero(different_bits); +#endif +} + +// For bitwise operations, we consider the range of each bit independently. +// Consider a number x=0bABCDE. If min(x)=max(x), then each of the bits A, B, +// C, D, and E are fixed. However, if there is a range of values possible then +// a subset of these bits could take on either 0 or 1. Suppose the range of x +// is [0b01010, 0b01100]. Then we know that A=0, B=1, and C, D, and E can have +// either value. Generally speaking, for numbers lying between two positive +// integers, we know the lower-most K many bits are not fixed, where K is +// PRECISION-(number of high bits in common). We can compute the largest K +// between this and other, then we know that the XOR between these two values +// can have any value for that many lower bits and all the higher bits are +// determined by XORing the two min (or max) bounds with one another. +// +// [Note on twos-complement negative integers] +// Since twos-complement negative integers can be envisioned as simply +// stacking (without flipping) the negative values at the right side of the +// positive values, we can apply the same algorithm regardless of signedness. +#define DEFINE_BITWISE_BINARY_OP(op) \ + BoundedInt BoundedInt::operator op(const BoundedInt & other) const { \ + /* New interval has this many fixed bits */ \ + int64_t var_bits = \ + 64L - std::min(countCommonHighBits(), other.countCommonHighBits()); \ + /* Mask everything below the higher fixed_bits */ \ + int64_t low_mask = (1 << var_bits) - 1; \ + int64_t new_min = (min op other.min) & (~low_mask); \ + int64_t new_max = new_min + low_mask; \ + return {new_min, new_max}; \ + } +DEFINE_BITWISE_BINARY_OP(&) +DEFINE_BITWISE_BINARY_OP(|) +DEFINE_BITWISE_BINARY_OP(^) +#undef DEFINE_BITWISE_BINARY_OP + +BoundedInt BoundedInt::operator~() const { + // New interval has this many fixed bits + int64_t var_bits = 64L - countCommonHighBits(); + // Mask everything below the higher fixed_bits + int64_t low_mask = (1 << var_bits) - 1; // 0b00111 + int64_t new_min = (~min) & (~low_mask); // 0b01000 + int64_t new_max = new_min + low_mask; // 0b01111 + return {new_min, new_max}; +} + +// Index types are always signed (always going to be true?). This means that a +// right shift is _arithmetic_ right shift, so if the argument is negative, +// after the shift it stays negative. +BoundedInt BoundedInt::operator>>(const BoundedInt& other) const { + NVF_ERROR(other.min >= 0, "Shift operator must not have negative shift"); + // Note: arithmetic right shift makes negative values closer to zero, as it + // does for positive values + int64_t new_min = (min < 0L) ? (min >> other.min) : (min >> other.max); + int64_t new_max = (max < 0L) ? 
(max >> other.max) : (max >> other.min); + return {new_min, new_max}; +} + +BoundedInt BoundedInt::operator<<(const BoundedInt& other) const { + NVF_ERROR( + min >= 0, + "Left shift must not be applied to number that can be negative"); + NVF_ERROR(other.min >= 0, "Shift operator must not have negative shift"); + return {min << other.min, max << other.max}; +} + +ScalarBoundsCalculator::ScalarBoundsCalculator( + kir::Kernel* kernel, + ExpressionEvaluator& expr_eval, + const LaunchParams& launch_params) + : expr_eval_(expr_eval), launch_params_(launch_params) { + if (kernel != nullptr) { + // If kernel is given, process all exprs in it + kir::IrVisitor::handle(kernel->topLevelExprs()); + } +} + +//! Look at all casts (T)x where x is of type nvfuser_index_t, to ensure that +//! these casts are safe i.e. that the bounds of x do not overflow those +//! representable by T. +bool ScalarBoundsCalculator::castsFromIndexAreSafe() const { + return std::all_of( + casts_from_index_.begin(), casts_from_index_.end(), [&](UnaryOp* cast) { + const BoundedInt& bounds = bounds_.at(cast->in()); + DataType out_dtype = cast->out()->dtype(); + NVF_ERROR( + std::holds_alternative(out_dtype.type), + "Expected PrimDataType but found ", + out_dtype); + switch (std::get(out_dtype.type)) { + case DataType::Int: + return true; + case DataType::Int32: + return bounds.min >= std::numeric_limits::min() && + bounds.max <= std::numeric_limits::max(); + case DataType::Short: + return bounds.min >= std::numeric_limits::min() && + bounds.max <= std::numeric_limits::max(); + case DataType::Char: + return bounds.min >= std::numeric_limits::min() && + bounds.max <= std::numeric_limits::max(); + case DataType::UInt64: + // upper limit is above that of int64_t, which is the type of + // bounds.max + return bounds.min >= 0L; + case DataType::UInt32: + return bounds.min >= std::numeric_limits::min() && + bounds.max <= std::numeric_limits::max(); + case DataType::UInt16: + return bounds.min >= std::numeric_limits::min() && + bounds.max <= std::numeric_limits::max(); + case DataType::Byte: + return bounds.min >= std::numeric_limits::min() && + bounds.max <= std::numeric_limits::max(); + return true; + default: + NVF_THROW("Unhandled DataType ", out_dtype); + return false; + } + }); +} + +std::ostream& operator<<(std::ostream& out, const BoundedInt& b) { + out << "BoundedInt[" << b.min << ", " << b.max << "]"; + return out; +} + +void ScalarBoundsCalculator::setBounds(Val* val, const BoundedInt& bounds) { + bounds_[val] = bounds; +} + +void ScalarBoundsCalculator::setBounds(Val* val, int64_t min, int64_t max) { + setBounds(val, {min, max}); +} + +void ScalarBoundsCalculator::setAsUnbounded(Val* val) { + setBounds( + val, + std::numeric_limits::min(), + std::numeric_limits::max()); +} + +void ScalarBoundsCalculator::setBoundsForNamedScalar(NamedScalar* scalar) { + if (std::optional ptype = scalar->getParallelDim(); + ptype.has_value()) { + // scalar is the extent of a parallel dim, so evaluate it + int64_t dim_int = launch_params_.getDim(ptype.value()); + setBounds(scalar, dim_int, dim_int); + } else if (std::optional ptype = scalar->getParallelIndex(); + ptype.has_value()) { + // scalar is the index of a parallel dim, so bound it by [0, dim-1] + int64_t dim_int = launch_params_.getDim(ptype.value()); + setBounds(scalar, 0L, dim_int - 1L); + } else { + // We do not know how to bound other NamedScalars + setAsUnbounded(scalar); + } +} + +// Non-recursive function to look up bounds if they have been recorded +// already. 
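Looping back to `countCommonHighBits()` above: a standalone check of the example in its comment, using a hypothetical helper name and the C++17-style loop rather than `std::countl_zero`, makes the "61 common bits" figure concrete:

```cpp
// Count the high bits shared by every value in [lo, hi] (hypothetical sketch
// mirroring BoundedInt::countCommonHighBits()).
#include <cassert>
#include <cstdint>

int commonHighBits(int64_t lo, int64_t hi) {
  uint64_t diff = static_cast<uint64_t>(lo) ^ static_cast<uint64_t>(hi);
  int common = 64;
  while (diff != 0) {  // equivalent to std::countl_zero(diff) in C++20
    diff >>= 1;
    --common;
  }
  return common;
}

int main() {
  // min = 0b10101010, max = 0b10101100: only the low bits 0b110 can differ
  assert(commonHighBits(0b10101010, 0b10101100) == 61);
  // A degenerate interval shares all 64 bits
  assert(commonHighBits(42, 42) == 64);
  return 0;
}
```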
For NamedScalars, also look in parallel dimension map. Finally, +// try and evaluate. If all this fails, return nullopt. +std::optional ScalarBoundsCalculator::maybeGetBounds(Val* val) { + if (auto it = bounds_.find(val); it != bounds_.end()) { + return it->second; + } else if (auto* scalar = dynamic_cast(val)) { + setBoundsForNamedScalar(scalar); + return bounds_.at(val); + } else if (PolymorphicValue pv = expr_eval_.evaluate(val, known_scalars_); + pv.hasValue()) { + setBounds(val, pv.as(), pv.as()); + return bounds_.at(val); + } else { + return std::nullopt; + } +} + +void ScalarBoundsCalculator::dispatch(Statement* stmt) { + kir::IrVisitor::dispatch(stmt); +} +void ScalarBoundsCalculator::dispatch(Val* val) { + if (val->isIntegralScalar() && val->definition() != nullptr) { + // This will kick off recursive dispatch + dispatch(val->definition()); + } + kir::IrVisitor::dispatch(val); +} + +void ScalarBoundsCalculator::dispatch(Expr* expr) { + if (auto* uop = dynamic_cast(expr)) { + if (uop->getUnaryOpType() == UnaryOpType::ToUnsignedSmemAddr) { + // This is a workaround for a limitation in being able to evaluate + // metadata for tensors with swizzles. + // TODO: is there a better workaround? + int64_t max_smem_addr = + (int64_t)at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock - + 1L; + known_scalars_[uop->out()] = max_smem_addr; + setBounds(uop->out(), 0L, max_smem_addr); + return; + } + if (uop->getUnaryOpType() == UnaryOpType::Cast && + uop->in()->dtype() == DataType::Index && + uop->out()->isIntegralScalar()) { + // Collect casts _from_ Index scalars, so that we can check that these are + // safe. + casts_from_index_.push_back(uop); + } + } + + if (!expr->isA() && + std::all_of( + expr->outputs().begin(), expr->outputs().end(), [](Val* outp) { + return !outp->isIntegralScalar(); + })) { + // We don't need to process expressions that do not produce integers. + // Note that for loops do "produce" their index variables for our + // purposes. + // It is possible that the expression outputs are constant scalars, so + // try and compute them here. + for (Val* outp : expr->outputs()) { + if (outp->isIntegralScalar()) { + PolymorphicValue pv = expr_eval_.evaluate(outp, known_scalars_); + if (pv.hasValue()) { + setBounds(outp, pv.as(), pv.as()); + } + } + } + return; + } + // Inline scalar expressions might not have their inputs processed yet + // The following loop ensures that all inputs to expr have recorded bounds. + std::vector immediate_inputs = expr->inputs(); + if (auto* loop = dynamic_cast(expr)) { + immediate_inputs.push_back(loop->start()); + immediate_inputs.push_back(loop->stop()); + immediate_inputs.push_back(loop->step()); + } + for (Val* inp : immediate_inputs) { + if (!inp->isIntegralScalar()) { + continue; + } + std::optional inp_bounds = maybeGetBounds(inp); + if (!inp_bounds.has_value()) { + // If inp is not constant, then we can try bounding its inputs, if + // they are int scalars. If it has no producers that are int scalars, + // and it is unbound, then we cannot provide any bounds for it. 
+ if (Expr* def = inp->definition(); def && + std::any_of(def->inputs().begin(), + def->inputs().end(), + [](Val* definp) { return definp->isIntegralScalar(); })) { + // Recursively dispatch definitions + dispatch(def); + } else { + setAsUnbounded(inp); + } + } + } + kir::IrVisitor::dispatch(expr); +} + +int64_t ScalarBoundsCalculator::evalInt(Val* val) { + return expr_eval_.evaluate(val).as(); +} + +void ScalarBoundsCalculator::handle(ForLoop* loop) { + // Set bounds for the loop variable + BoundedInt start = bounds_.at(loop->start()); + BoundedInt stop = bounds_.at(loop->stop()); + setBounds(loop->index(), start.min, stop.max - 1L); + kir::IrVisitor::handle(loop); +} + +void ScalarBoundsCalculator::handle(LoadStoreOp* lsop) { + if (lsop->in()->isIntegralScalar()) { + setBounds(lsop->out(), bounds_.at(lsop->in())); + } +} + +void ScalarBoundsCalculator::handle(UnaryOp* uop) { + BoundedInt a = bounds_.at(uop->in()); + BoundedInt result; + switch (uop->getUnaryOpType()) { + case UnaryOpType::Abs: + result = { + std::min(std::abs(a.min), std::abs(a.max)), + std::max(std::abs(a.min), std::abs(a.max))}; + break; + case UnaryOpType::BitwiseNot: + result = ~a; + break; + case UnaryOpType::Cast: + // This assumes there is no loss or overflow, since those should not + // occur in our kernels. We can check that later for index types using + // castsFromIndexAreSafe(). + result = a; + break; + case UnaryOpType::Neg: + result = {-a.max, -a.min}; + break; + default: + NVF_THROW( + "Propagation of integer bounds is not yet implemented for ", + uop->toString()); + } + setBounds(uop->out(), result); +} + +void ScalarBoundsCalculator::handle(BinaryOp* bop) { + BoundedInt a = bounds_.at(bop->lhs()); + BoundedInt b = bounds_.at(bop->rhs()); + BoundedInt result; + switch (bop->getBinaryOpType()) { + case BinaryOpType::Add: + result = a + b; + break; + case BinaryOpType::BitwiseAnd: + result = a & b; + break; + case BinaryOpType::BitwiseOr: + result = a | b; + break; + case BinaryOpType::BitwiseXor: + result = a ^ b; + break; + case BinaryOpType::CeilDiv: + result = ceilDiv(a, b); + break; + case BinaryOpType::Div: + result = a / b; + break; + case BinaryOpType::Mod: + result = a % b; + break; + case BinaryOpType::Mul: + result = a * b; + break; + case BinaryOpType::Lshift: + result = a << b; + break; + case BinaryOpType::Rshift: + result = a >> b; + break; + case BinaryOpType::Sub: + result = a - b; + break; + default: + NVF_THROW( + "Propagation of integer bounds is not yet implemented for ", + bop->toString()); + } + setBounds(bop->out(), result); +} + +void ScalarBoundsCalculator::handle(TernaryOp* top) { + switch (top->getTernaryOpType()) { + default: + NVF_THROW( + "Propagation of integer bounds is not yet implemented for ", + top->toString()); + } +} + +} // namespace nvfuser diff --git a/csrc/interval_analysis.h b/csrc/interval_analysis.h new file mode 100644 index 00000000000..05284a0fdfe --- /dev/null +++ b/csrc/interval_analysis.h @@ -0,0 +1,153 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace nvfuser { + +//! This holds inclusive bounds for a particular integer Val. We will propagate +//! one of these for each integer scalar in the lowered kernel. That propagation +//! 
makes use of the operators defined in this class. +//! +//! Note that this class does not necessarily represent tight bounds on +//! complicated expressions. For example: +//! +//! for a in iS0{n} +//! b = a * 2 +//! c = b % 8 +//! +//! In our analysis, we will define the following ranges as BoundedInt values: +//! +//! a \in [0, n-1] +//! b \in [0, (n-1) * 2] +//! c \in [0, 7] (assuming 7 is not in the range of n) +//! +//! These bounds are correct even though we could use a tighter bound for c of +//! [0, 6] since we know that b must be a multiple of 2, so c must be 0, 2, 4, +//! or 6 only. This kind of analysis is not provided by the simplistic +//! propagation using a BoundedInt interval at each stage. +struct BoundedInt { + int64_t min = 0L; + int64_t max = 0L; + + //! Returns the number of high bits that must be common among all integers in + //! this interval + //! + //! Example: + //! min = 0b10101010 + //! max = 0b10101100 + //! + //! All numbers in this range are of the form 0b10101XXX + //! different_bits = 0b110 + //! num_common_bits = 61 + int64_t countCommonHighBits() const; + + BoundedInt operator+(const BoundedInt& other) const; + BoundedInt operator+(const int64_t other) const; + BoundedInt operator-(const BoundedInt& other) const; + BoundedInt operator-(const int64_t other) const; + BoundedInt operator-() const; + BoundedInt operator*(const BoundedInt& other) const; + BoundedInt operator*(const int64_t other) const; + BoundedInt operator/(const BoundedInt& other) const; + BoundedInt operator/(const int64_t other) const; + BoundedInt operator%(const BoundedInt& other) const; + BoundedInt operator%(const int64_t other) const; + + BoundedInt operator^(const BoundedInt& other) const; + BoundedInt operator&(const BoundedInt& other) const; + BoundedInt operator|(const BoundedInt& other) const; + BoundedInt operator~() const; + BoundedInt operator>>(const BoundedInt& other) const; + BoundedInt operator<<(const BoundedInt& other) const; + + bool operator==(const BoundedInt& other) const { + return min == other.min && max == other.max; + } + bool operator!=(const BoundedInt& other) const { + return !(*this == other); + } +}; + +BoundedInt ceilDiv(const BoundedInt& numer, const BoundedInt& denom); + +std::ostream& operator<<(std::ostream& out, const BoundedInt& b); + +//! This class traverses the expressions in a kir::Kernel and defines a +//! BoundedInt for each integer scalar encountered. The range is determined by +//! the scalar's definition along with the rules defined in BoundedInt. +class ScalarBoundsCalculator : kir::IrVisitor { + public: + ScalarBoundsCalculator( + kir::Kernel* kernel, + ExpressionEvaluator& expr_eval, + const LaunchParams& launch_params); + + ~ScalarBoundsCalculator() override = default; + + //! Look at all casts (T)x where x is of type nvfuser_index_t, to ensure that + //! these casts are safe i.e. that the bounds of x do not overflow those + //! representable by T. + bool castsFromIndexAreSafe() const; + + //! NamedScalar bounds are set using the launch_params_. For example + //! `blockDim.x` is set to the [blockDim.x, blockDim.x] and `threadIdx.x` is + //! set to [0, blockDim.x - 1]. + void setBoundsForNamedScalar(NamedScalar* scalar); + + using kir::IrVisitor::dispatch; + //! These public methods are useful for processing an individual statement to + //! 
get bounds of all its producers + void dispatch(Statement* statement) final; + void dispatch(Expr* expr) final; + void dispatch(Val* val) final; + + void setBounds(Val* val, const BoundedInt& bounds); + void setBounds(Val* val, int64_t min, int64_t max); + void setAsUnbounded(Val* val); + + //! Non-recursive function to look up bounds if they have been recorded + //! already. For NamedScalars, also look in parallel dimension map and bound + //! if it has not already been bounded. Finally, try and evaluate constants. + //! If all this fails, return nullopt. + std::optional maybeGetBounds(Val* val); + + private: + //! Evaluate val using our ExpressionEvaluator + int64_t evalInt(Val* val); + + using kir::IrVisitor::handle; + + void handle(ForLoop* loop) final; + void handle(LoadStoreOp* lsop) final; + void handle(UnaryOp* uop) final; + void handle(BinaryOp* bop) final; + void handle(TernaryOp* top) final; + + private: + ExpressionEvaluator& expr_eval_; + const LaunchParams& launch_params_; + std::unordered_map bounds_; + std::unordered_map known_scalars_; + std::vector casts_from_index_; +}; + +} // namespace nvfuser diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index bcca2ac66dc..63c9aa4d6d0 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -599,18 +599,20 @@ std::vector BinaryOp::evaluate( return {lhs * rhs}; break; case BinaryOpType::Div: + NVF_CHECK( + !rhs.is() || rhs != 0, "Integer division by zero detected"); return {lhs / rhs}; break; case BinaryOpType::Mod: - NVF_CHECK(rhs != 0); + NVF_CHECK(rhs != 0, "Modulo zero detected"); return {lhs % rhs}; break; case BinaryOpType::Fmod: - NVF_CHECK(rhs != 0); + NVF_CHECK(rhs != 0, "Float modulo zero detected"); return {fmod(lhs, rhs)}; break; case BinaryOpType::CeilDiv: - NVF_CHECK(rhs != 0); + NVF_CHECK(rhs != 0, "CeilDiv by zero detected"); return {ceildiv(lhs, rhs)}; break; case BinaryOpType::LogicalAnd: diff --git a/csrc/kernel.cpp b/csrc/kernel.cpp index 6574172a66e..dd4d8313ebd 100644 --- a/csrc/kernel.cpp +++ b/csrc/kernel.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include @@ -70,6 +71,11 @@ class KernelIrScanner : private IrVisitor { // Do we have any elect sync predicates? if (uop->getUnaryOpType() == UnaryOpType::ElectSync) { summary_.has_elect_sync_predicate = true; + } else if ( + uop->getUnaryOpType() == UnaryOpType::Cast && + uop->in()->dtype() == DataType::Index && + uop->out()->dtype() != DataType::Int) { + summary_.has_narrowing_index_casts = true; } } diff --git a/csrc/kernel.h b/csrc/kernel.h index 58470692133..f568ae48fa4 100644 --- a/csrc/kernel.h +++ b/csrc/kernel.h @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -127,6 +128,10 @@ struct KernelSummary { //! Reason: At runtime, we check that at least a single warp along TIDx axis //! exists. bool has_elect_sync_predicate = false; + + //! Do we have any possibly narrowing casts from DataType::Index variables? + //! These need to be validated to prevent overflow. + bool has_narrowing_index_casts = false; }; class KernelPerformanceProfile { @@ -220,6 +225,10 @@ class NVF_API Kernel final : public Fusion { return index_type_; } + void setIndexType(PrimDataType new_index_type) { + index_type_ = new_index_type; + } + //! 
Checks if parallel type is padded bool isParallelTypePadded(ParallelType ptype) const { return ptype == ParallelType::TIDx && diff --git a/csrc/runtime/executor.cpp b/csrc/runtime/executor.cpp index d22582b556a..9f19b0bd935 100644 --- a/csrc/runtime/executor.cpp +++ b/csrc/runtime/executor.cpp @@ -214,27 +214,13 @@ void KernelExecutor::compile( !(compile_params.index_type.value() == PrimDataType::Int32 && arg_index_type == PrimDataType::Int), "Compilation with int32 is requested but int64 is required for the arguments"); - NVF_ERROR( - !has_cp_async_bulk || - (compile_params.index_type.value() == PrimDataType::Int32), - "Compilation with int64 is requested but int32 is required because ", - "of TMA operations."); - - } else if (arg_index_type == PrimDataType::Int) { + } else { // If the given compile option doesn't specify the index type, and // the arguments require 64-bit indexing, we need to use 64-bit // indexing. Note that if the arg type is 32-bit, it doesn't mean // it's safe to use 32-bit for the whole kernel, so unless it's // specified through CompileParams, we do not use 32-bit indexing. compile_params.index_type = arg_index_type; - NVF_ERROR( - !has_cp_async_bulk, - "Compilation with int64 is required based on input arguments, but ", - "int32 is required because of TMA operations."); - } else if (has_cp_async_bulk) { - // TMA operations require 32-bit indexing. - compile_params.index_type = PrimDataType::Int32; - } else { compile_params.index_type = arg_index_type; } @@ -655,6 +641,9 @@ void KernelExecutor::initializeExecutorEntry( executor_utils::validateCircularBuffering( compiled_kernel_->kernel(), expr_eval); + executor_utils::validateIndexCasts( + compiled_kernel_->kernel(), expr_eval, launch_params); + // Check that a full warp exists in blockDim.x if the kernel contains // ElectSync predicate. constexpr int64_t warp_size = 32; diff --git a/csrc/runtime/executor_utils.cpp b/csrc/runtime/executor_utils.cpp index ebbc57fd1ec..b750337cbe8 100644 --- a/csrc/runtime/executor_utils.cpp +++ b/csrc/runtime/executor_utils.cpp @@ -13,8 +13,10 @@ #include #include +#include #include #include +#include #include #include #include @@ -721,5 +723,19 @@ std::unique_ptr getParallelIterExtents( return parallel_iter_extents_ptr; } +void validateIndexCasts( + kir::Kernel* kernel, + ExpressionEvaluator& expr_eval, + const LaunchParams& launch_params) { + if (!kernel->summary().has_narrowing_index_casts) { + return; + } + ScalarBoundsCalculator calc(kernel, expr_eval, launch_params); + NVF_ERROR( + calc.castsFromIndexAreSafe(), + "Found unsafe casts from DataType::Index. ", + "This is likely because one coordinate of a TMA instruction overflowed Int32"); +} + } // namespace executor_utils } // namespace nvfuser diff --git a/csrc/runtime/executor_utils.h b/csrc/runtime/executor_utils.h index aebeedfcda5..2ff4880456a 100644 --- a/csrc/runtime/executor_utils.h +++ b/csrc/runtime/executor_utils.h @@ -233,5 +233,14 @@ void validateCircularBuffering( kir::Kernel* kernel, ExpressionEvaluator& expr_eval); +//! Check that any narrowing casts from DataType::Index do not overflow. +//! In particular, if TMA expressions are present in the kernel, compute bounds +//! for integer expressions in order to validate that the 32-bit coordinates +//! passed to the TMA PTX instructions do not overflow. 
+void validateIndexCasts( + kir::Kernel* kernel, + ExpressionEvaluator& expr_eval, + const LaunchParams& launch_params); + } // namespace executor_utils } // namespace nvfuser diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index 9f98dd8160d..70a40072a7e 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -989,12 +989,17 @@ std::unique_ptr getMatmulHeuristics( mparams->circular_buffer_options.smem_circular_buffer_stage, tensor_roles, /*ignore_occupancy_drop=*/true); - if (isHopper(mparams->mma_macro)) { + if (isHopper(mparams->mma_macro) && mparams->use_smem_epilogue) { // Always promote smem reuse for Hopper. This is needed because we use TMA // which has higher alignment requirements, so it's important that we place // our TMA buffers at an offset that's a multiple of 64 (like 0) if // possible. mparams->promote_prologue_smem_reuse = true; + + // TMA allows us to avoid linear indexing + // TODO: verify here that we will be able to use Int32 indexing. If not, + // then disable use_smem_epilogue. + // mparams->cparams.index_type = PrimDataType::Int32; } if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) { @@ -1139,14 +1144,24 @@ std::string getMatmulRunTimeRejectReason( Fusion* fusion, HeuristicDataCache* data_cache, SchedulerRuntimeInfo& runtime_info) { + // On Hopper, we use TMA to load operands. Since TMA requires each coordinate + // of the input to be represented with a 32-bit signed int, we will encounter + // overflow if any dimension of an operand is larger than that. const auto device_prop = at::cuda::getCurrentDeviceProperties(); - - if (device_prop->major >= 9 && - runtime_info.getIndexType() != DataType::Int32) { - // See https://github.com/NVIDIA/Fuser/issues/3595 - return "Hopper matmul is not yet supported with problem sizes requiring 64-bit indexing"; + if (device_prop->major == 9) { + for (Val* inp : fusion->inputs()) { + if (auto* tv = dynamic_cast(inp)) { + for (int64_t extent : runtime_info.getInputAllocationSizes(tv)) { + if (extent >= (1L << 31)) { + std::stringstream ss; + ss << "Cannot schedule Hopper matmul with dims larger than 2^31-1, but found " + << extent; + return ss.str(); + } + } + } + } } - return ""; } diff --git a/tests/cpp/test_interval_analysis.cpp b/tests/cpp/test_interval_analysis.cpp new file mode 100644 index 00000000000..07c75eac5c2 --- /dev/null +++ b/tests/cpp/test_interval_analysis.cpp @@ -0,0 +1,390 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace nvfuser { + +class IntervalAnalysisTest : public NVFuserTest { + std::unique_ptr fusion_ptr; + std::unique_ptr fusion_guard_ptr; + + void SetUp() override { + NVFuserTest::SetUp(); + fusion_ptr = std::make_unique(); + fusion_guard_ptr = std::make_unique(fusion_ptr.get()); + } +}; + +namespace { + +// This class lets us test that our computed ranges match our expectation and +// that they are correct. We provide a val, a mapping from input vals to +// bounds, and the expected range. Then the actual range is computed by +// exhaustively checking all valid input combinations. This is checked against +// the expected range, and the range computed by ScalarBoundsCalculator. 
+class RangeChecker { + public: + static void check( + Val* output_val, + const std::unordered_map& input_bounds, + const BoundedInt& expected_range, + const bool bound_is_tight = true, + const LaunchParams& launch_params = LaunchParams()) { + RangeChecker checker( + output_val, + input_bounds, + expected_range, + bound_is_tight, + launch_params); + checker.checkAllInputs(); + } + + private: + RangeChecker( + Val* output_val, + const std::unordered_map& input_bounds, + const BoundedInt& expected_range, + const bool bound_is_tight, + const LaunchParams& launch_params) + : output_val_(output_val), + input_bounds_(input_bounds), + expected_range_(expected_range), + bound_is_tight_(bound_is_tight) { + ExpressionEvaluator expr_eval; + // Compute the range using ScalarBoundsCalculator and check that it matches + // expected + ScalarBoundsCalculator calc(/*kernel=*/nullptr, expr_eval, launch_params); + for (auto& [v, b] : input_bounds_) { + calc.setBounds(v, b); + } + // Check that the computed range is correct + calc.dispatch(output_val_); + auto bound_opt = calc.maybeGetBounds(output_val_); + // Cannot use ASSERT_* in constructor + NVF_ERROR( + bound_opt.has_value(), + "Expected bounds to be computed following call to dispatch"); + EXPECT_EQ(bound_opt.value(), expected_range); + } + + // Evaluate output_val_ exhaustively for every possible combination of inputs + void checkAllInputs() { + std::vector inputs; + inputs.reserve(input_bounds_.size()); + // Number of valid combinations of input values + int64_t num_combos = 1; + for (auto& [v, b] : input_bounds_) { + inputs.push_back(v); + NVF_ERROR(b.max >= b.min); + num_combos *= b.max - b.min + 1; + } + // Sort inputs by name so that test deterministically traverses inputs + std::stable_sort(inputs.begin(), inputs.end(), [](Val* v1, Val* v2) { + return v1->name() < v2->name(); + }); + + // Iterate over all input combinations + for (size_t i : c10::irange(num_combos)) { + ExpressionEvaluator expr_eval; + + // All the input combinations are enumerated + // For example if there are three inputs with the following bounds: + // x: [min_x, max_x] + // y: [min_y, max_y] + // z: [min_z, max_z] + // Then there are nx*ny*nz=(max_x-min_x+1)*(max_y-min_y+1)*(max_z-min_z+1) + // combinations of valid inputs. 
The jth input is determined by + // x = j / (ny*nz) + min_x + // y = (j % (ny*nz)) / nz + min_y + // z = j % nz + min_z + int64_t num_inner_combos = num_combos; + for (size_t inp_num : c10::irange(inputs.size())) { + const BoundedInt& inp_bound = input_bounds_.at(inputs.at(inp_num)); + int64_t next_offset = i % num_inner_combos; + num_inner_combos /= inp_bound.max - inp_bound.min + 1L; + int64_t this_input_value = + inp_bound.min + (next_offset / num_inner_combos); + expr_eval.bind(inputs.at(inp_num), this_input_value); + } + + PolymorphicValue pv; + try { + pv = expr_eval.evaluate(output_val_); + } catch (const std::exception& ex) { + // Floating point exception due to division or modulo by zero avoided + if (std::string(ex.what()).find("zero detected") != std::string::npos) { + continue; + } else { + throw; + } + } + ASSERT_TRUE(pv.hasValue()); + ASSERT_TRUE(pv.is()); + int64_t eval = pv.as(); + EXPECT_GE(eval, expected_range_.min); + EXPECT_LE(eval, expected_range_.max); + + eval_min_ = std::min(eval_min_, eval); + eval_max_ = std::max(eval_max_, eval); + } + + if (bound_is_tight_) { + EXPECT_EQ(eval_min_, expected_range_.min); + EXPECT_EQ(eval_max_, expected_range_.max); + } + } + + private: + Val* output_val_; + const std::unordered_map& input_bounds_; + const BoundedInt& expected_range_; + bool bound_is_tight_; + + int64_t eval_min_ = std::numeric_limits::max(); + int64_t eval_max_ = std::numeric_limits::min(); +}; + +} // namespace + +TEST_F(IntervalAnalysisTest, UnaryOps) { + Val* x = IrBuilder::create(DataType::Index); + RangeChecker::check( + x, /*input_bounds=*/{{x, {-1, 5}}}, /*expected_range=*/{-1, 5}); + RangeChecker::check( + neg(x), /*input_bounds=*/{{x, {-1, 5}}}, /*expected_range=*/{-5, 1}); + // TODO: fix evaluate function for BitwiseNot + // RangeChecker::check(bitwise_not(x), /*input_bounds=*/{{-1, 5}}, {-5, 1}); +} + +TEST_F(IntervalAnalysisTest, BinaryOps) { + Val* x = IrBuilder::create(DataType::Index); + Val* y = IrBuilder::create(DataType::Index); + RangeChecker::check( + x, + /*input_bounds=*/{{x, {-1, 5}}, {y, {-3, 2}}}, + /*expected_range=*/{-1, 5}); + RangeChecker::check( + y, + /*input_bounds=*/{{x, {-1, 5}}, {y, {-3, 2}}}, + /*expected_range=*/{-3, 2}); + RangeChecker::check( + add(x, y), + /*input_bounds=*/{{x, {-1, 5}}, {y, {-3, 2}}}, + /*expected_range=*/{-4, 7}); + RangeChecker::check( + sub(x, y), + /*input_bounds=*/{{x, {-1, 5}}, {y, {-3, 2}}}, + /*expected_range=*/{-3, 8}); + + // Check multiple scenarios for mul + RangeChecker::check( + mul(x, y), + /*input_bounds=*/{{x, {3, 5}}, {y, {4, 6}}}, + /*expected_range=*/{12, 30}); + RangeChecker::check( + mul(x, y), + /*input_bounds=*/{{x, {-1, 5}}, {y, {-3, 2}}}, + /*expected_range=*/{-15, 10}); + RangeChecker::check( + mul(x, y), + /*input_bounds=*/{{x, {0, 1}}, {y, {-2, 1}}}, + /*expected_range=*/{-2, 1}); + RangeChecker::check( + mul(x, y), + /*input_bounds=*/{{x, {-2, 1}}, {y, {-2, 3}}}, + /*expected_range=*/{-6, 4}); + + // Check scenarios for div and ceilDiv where each input contains zero, is + // only positive, or is only negative + RangeChecker::check( + div(x, y), + /*input_bounds=*/{{x, {1, 4}}, {y, {3, 3}}}, + /*expected_range=*/{0, 1}); + RangeChecker::check( + ceilDiv(x, y), + /*input_bounds=*/{{x, {1, 4}}, {y, {3, 3}}}, + /*expected_range=*/{1, 2}); + RangeChecker::check( + div(x, y), + /*input_bounds=*/{{x, {-3, 4}}, {y, {1, 3}}}, + /*expected_range=*/{-3, 4}); + RangeChecker::check( + ceilDiv(x, y), + /*input_bounds=*/{{x, {-3, 4}}, {y, {1, 3}}}, + /*expected_range=*/{-3, 4}); + 
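Stepping away from the check calls for a moment: the mixed-radix decode that `RangeChecker::checkAllInputs` uses to turn a flat combination index into one value per bounded input (described in its comment above) can be sketched standalone as follows. The `Range` struct and `decode` helper are hypothetical, shown only to make the index arithmetic concrete:

```cpp
// Hypothetical standalone version of the mixed-radix decode used in
// checkAllInputs: combination j is split into one value per bounded input.
#include <cassert>
#include <cstdint>
#include <vector>

struct Range { int64_t min, max; };

std::vector<int64_t> decode(int64_t j, const std::vector<Range>& bounds) {
  int64_t combos = 1;
  for (const Range& r : bounds) combos *= r.max - r.min + 1;
  std::vector<int64_t> values;
  for (const Range& r : bounds) {
    int64_t offset = j % combos;  // offset within the remaining inputs
    combos /= r.max - r.min + 1;  // radix of the inputs that follow
    values.push_back(r.min + offset / combos);
  }
  return values;
}

int main() {
  // Two inputs: x in [-1, 5] (7 values), y in [-3, 2] (6 values) -> 42 combos
  std::vector<Range> bounds = {{-1, 5}, {-3, 2}};
  std::vector<int64_t> first = decode(0, bounds);  // {-1, -3}
  std::vector<int64_t> last = decode(41, bounds);  // {5, 2}
  assert(first[0] == -1 && first[1] == -3);
  assert(last[0] == 5 && last[1] == 2);
  return 0;
}
```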
RangeChecker::check( + div(x, y), + /*input_bounds=*/{{x, {-3, -1}}, {y, {1, 3}}}, + /*expected_range=*/{-3, 0}); + RangeChecker::check( + div(x, y), + /*input_bounds=*/{{x, {-3, -1}}, {y, {-3, -1}}}, + /*expected_range=*/{0, 3}); + RangeChecker::check( + div(x, y), + /*input_bounds=*/{{x, {-3, -1}}, {y, {-3, 2}}}, + /*expected_range=*/{-3, 3}); + RangeChecker::check( + div(x, y), + /*input_bounds=*/{{x, {-2, 1}}, {y, {-1, 2}}}, + /*expected_range=*/{-2, 2}); + RangeChecker::check( + ceilDiv(x, y), + /*input_bounds=*/{{x, {-3, -1}}, {y, {-3, 2}}}, + /*expected_range=*/{-3, 5}, + // NOTE: ceilDiv(-3, -1) = (-3 + (-1) - 1) / (-1) = 5 is what is computed + // in-kernel, but ExpressionEvaluator computes (numer + denom + 1) / denom + // when denom < 0. The bound above is actually tight for the in-kernel + // code but that does not currently match our ExpressionEvaluator + /*bound_is_tight=*/false); + RangeChecker::check( + div(x, y), + /*input_bounds=*/{{x, {0, 0}}, {y, {2, 3}}}, + /*expected_range=*/{0, 0}); + RangeChecker::check( + div(x, y), + /*input_bounds=*/{{x, {0, 1}}, {y, {1, 1}}}, + /*expected_range=*/{0, 1}); + RangeChecker::check( + ceilDiv(x, y), + /*input_bounds=*/{{x, {0, 1}}, {y, {2, 3}}}, + /*expected_range=*/{0, 1}); + + RangeChecker::check( + mod(x, y), + /*input_bounds=*/{{x, {2, 3}}, {y, {3, 3}}}, + /*expected_range=*/{0, 2}); + RangeChecker::check( + mod(x, y), + /*input_bounds=*/{{x, {2, 4}}, {y, {7, 8}}}, + /*expected_range=*/{2, 4}); + RangeChecker::check( + mod(x, y), + /*input_bounds=*/{{x, {2, 4}}, {y, {2, 5}}}, + /*expected_range=*/{0, 4}); + RangeChecker::check( + mod(x, y), + /*input_bounds=*/{{x, {2, 4}}, {y, {-8, -7}}}, + /*expected_range=*/{2, 4}); + + // We do not generally place the tightest bounds on bitwise ops because it is + // difficult to do without exhaustively trying input combinations. 
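As a concrete illustration of the comment above (a hypothetical brute-force check, not part of the test): enumerating all nine (x, y) pairs for x in [0b1001, 0b1011] and y in [0b1010, 0b1100] shows the true range of x & y is [0b1000, 0b1011], strictly inside the conservative [0b1000, 0b1111] that the calculator reports:

```cpp
// Brute-force the true range of x & y over small intervals to compare against
// the conservative shared-high-bits bound (illustrative sketch only).
#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  int64_t lo = INT64_MAX, hi = INT64_MIN;
  for (int64_t x = 0b1001; x <= 0b1011; ++x) {
    for (int64_t y = 0b1010; y <= 0b1100; ++y) {
      lo = std::min(lo, x & y);
      hi = std::max(hi, x & y);
    }
  }
  // Tight range is [8, 11]; the calculator's bound of [8, 15] contains it.
  assert(lo == 0b1000 && hi == 0b1011);
  return 0;
}
```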
+ RangeChecker::check( + bitwise_and(x, y), + /*input_bounds=*/{{x, {0b1001, 0b1011}}, {y, {0b1010, 0b1100}}}, + // NOTE: this bound is not tight because we assume all variable bits can + // take any combination of values, but since there is only one y value + // with high third bit, the highest we can actually get is 0b1011=11 + /*expected_range=*/{0b1000, 0b1111}, + /*bound_is_tight=*/false); + RangeChecker::check( + bitwise_or(x, y), + /*input_bounds=*/{{x, {0b1001, 0b1011}}, {y, {0b1010, 0b1100}}}, + /*expected_range=*/{0b1000, 0b1111}, + // NOTE: this bound is not tight because we assume all variable bits can + // take any combination of values, but since there is only one y value + // with high third bit, the lowest we can actually get is 0b1010=10, not + // 0b1000=8 + /*bound_is_tight=*/false); + RangeChecker::check( + bitwise_xor(x, y), + /*input_bounds=*/{{x, {0b1001, 0b1011}}, {y, {0b1010, 0b1100}}}, + /*expected_range=*/{0b0000, 0b0111}); + + RangeChecker::check( + bitwise_left_shift(x, y), + /*input_bounds=*/{{x, {0b1001, 0b1011}}, {y, {1, 5}}}, + /*expected_range=*/{0b10010, 0b101100000}); + RangeChecker::check( + bitwise_right_shift(x, y), + /*input_bounds=*/{{x, {0b100100, 0b101100}}, {y, {1, 5}}}, + /*expected_range=*/{0b1, 0b10110}); +} + +// Test that loop indices are properly bounded, as are expressions derived from +// them +TEST_F(IntervalAnalysisTest, SerialLoops) { + kir::Kernel kernel(FusionGuard::getCurFusion()); + FusionGuard fg(&kernel); + + Val* ext = IrBuilder::create(DataType::Index); + Val* start = kernel.zeroVal(); + auto* id = IterDomainBuilder(start, ext).extent(ext).build(); + Val* index = IrBuilder::create(DataType::Index); + auto* loop = IrBuilder::create( + id, + index, + /*circular_buffer_loop_stage=*/CircularBufferLoopStage::NotApplicable, + /*circular_buffer_loop_stage_depth=*/0); + Val* offset = IrBuilder::create(DataType::Index); + Val* index_plus_offset = add(index, offset); + // Compute index + offset inside the "for index in id" loop + loop->body().push_back(index_plus_offset->definition()); + + ExpressionEvaluator expr_eval; + LaunchParams launch_params; + ScalarBoundsCalculator calc(/*kernel=*/nullptr, expr_eval, launch_params); + calc.setBounds(ext, {4, 7}); + calc.setBounds(offset, {2, 5}); + calc.dispatch(loop); + calc.dispatch(index_plus_offset); + auto bound_opt = calc.maybeGetBounds(index_plus_offset); + NVF_ERROR(bound_opt.has_value()); + BoundedInt true_bound{2, 11}; + EXPECT_EQ(bound_opt.value(), true_bound); +} + +// Test that parallelized loop indices are properly bounded, as are expressions +// derived from them +TEST_F(IntervalAnalysisTest, ParallelLoops) { + kir::Kernel kernel(FusionGuard::getCurFusion()); + FusionGuard fg(&kernel); + + Val* ext = NamedScalar::getParallelDim(ParallelType::TIDx); + Val* start = kernel.zeroVal(); + auto* id = IterDomainBuilder(start, ext) + .extent(ext) + .parallel_type(ParallelType::TIDx) + .build(); + Val* index = IrBuilder::create(DataType::Index); + auto* loop = IrBuilder::create( + id, + index, + /*circular_buffer_loop_stage=*/CircularBufferLoopStage::NotApplicable, + /*circular_buffer_loop_stage_depth=*/0); + Val* offset = IrBuilder::create(DataType::Index); + Val* index_plus_offset = add(index, offset); + // Compute index + offset inside the "for index in id" loop + loop->body().push_back(index_plus_offset->definition()); + + ExpressionEvaluator expr_eval; + LaunchParams launch_params; + launch_params.bind(128, ParallelType::TIDx); + ScalarBoundsCalculator calc(/*kernel=*/nullptr, 
expr_eval, launch_params); + calc.setBounds(offset, {2, 5}); + calc.dispatch(loop); + calc.dispatch(index_plus_offset); + auto bound_opt = calc.maybeGetBounds(index_plus_offset); + NVF_ERROR(bound_opt.has_value()); + BoundedInt true_bound{2, 132}; + EXPECT_EQ(bound_opt.value(), true_bound); +} + +} // namespace nvfuser diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index cfbca9223fa..af59a9524fe 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -4622,4 +4622,107 @@ TEST_F(HopperMatmulTest, ScheduleWithTranslation) { EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K)); } +// Test that we can compile matmul kernels with both 32-bit and 64-bit indexing, +// and that if we pass arguments for which this is unsafe (meaning there is +// overflow), that the appropriate exception is raised +TEST_F(HopperMatmulTest, IndexTypeValidation) { + Fusion fusion; + FusionGuard fg(&fusion); + + const auto dtype = DataType::Half; + + auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype); // M, K + auto tv1 = makeContigConcreteTensor({1, -1, -1}, dtype); // K, N + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = fusedMultiplySum(tv0, tv1, {1}); + + // Reorder the accumulator as [M, N, K] + // [M, K, N] -> [M, N, K] + tv2->reorder({{-2, -1}}); + tv2->commitLeafToLogical(); + + auto tv3 = castOp(DataType::Half, tv2); + fusion.addOutput(tv3); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 256, 64); + gemm_tile.warp_tile = GemmTile(64, 256, 64); + + MatmulParams mparams; + mparams.supported_vec_size = {8, 8, 8}; + mparams.mma_macro = MmaMacro::Hopper_64_256_16; + mparams.tile_sizes = gemm_tile; + mparams.cta_order = MatmulParams::TileRasterizationOrder::RowMajor; + mparams.async_gmem_load_operands = true; + mparams.circular_buffer_options.circular_buffer_smem_write = false; + mparams.circular_buffer_options.circular_buffer_smem_read = false; + mparams.circular_buffer_options.smem_circular_buffer_stage = 1; + mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1; + mparams.splitk_factor = 1; + mparams.use_smem_epilogue = true; + mparams.cluster_dims = {1, 1, 1}; + mparams.promote_prologue_smem_reuse = true; + + constexpr int64_t M = 1 << 17, N = 256, K = 1 << 17; + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA); + + { // This scope is to help us reclaim memory later + auto a_ref = at::randn({M, K, 1}, options); + auto b_ref = at::randn({1, K, N}, options); + auto out_ref = at::matmul(a_ref.squeeze(), b_ref.squeeze()).to(at::kHalf); + const std::vector inputs = {a_ref, b_ref}; + + mparams.cparams.index_type = DataType::Int32; + + at::Tensor int32_output; + { + Fusion fusion_clone; + Fusion::copy(&fusion, &fusion_clone); + SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul) + ->schedule(&fusion_clone, &mparams); + + KernelExecutor ke; + ke.compile(&fusion_clone, inputs); + int32_output = ke.run(inputs).at(0); + } + + mparams.cparams.index_type = DataType::Int; + + Fusion fusion_clone; + Fusion::copy(&fusion, &fusion_clone); + SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul) + ->schedule(&fusion_clone, &mparams); + + KernelExecutor ke; + ke.compile(&fusion_clone, inputs); + auto int64_output = ke.run(inputs).at(0); + EXPECT_TRUE(int64_output.equal(int32_output)); + } + + // Test that passing inputs that are too large in one dimension lead to error + maybeClearAllocator(/*max_bytes=*/0); + { + mparams.cparams.index_type = DataType::Int; + + Fusion fusion_clone; + 
Fusion::copy(&fusion, &fusion_clone); + SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul) + ->schedule(&fusion_clone, &mparams); + + constexpr int64_t M_big = 1L << 32, N_big = 2, K_big = 2; + auto a_big = at::randn({M_big, K_big, 1}, options); + auto b_big = at::randn({1, K_big, N_big}, options); + const std::vector inputs_big{a_big, b_big}; + + KernelExecutor ke; + ke.compile(&fusion_clone, inputs_big); + EXPECT_THAT( + [&]() { ke.run(inputs_big); }, + ::testing::ThrowsMessage( + ::testing::HasSubstr("Found unsafe casts from DataType::Index"))); + } +} + } // namespace nvfuser diff --git a/tests/cpp/test_matmul_scheduler.cpp b/tests/cpp/test_matmul_scheduler.cpp index 8fac9b9e5c3..bd24200d56d 100644 --- a/tests/cpp/test_matmul_scheduler.cpp +++ b/tests/cpp/test_matmul_scheduler.cpp @@ -2763,18 +2763,6 @@ TEST_F(MatmulSchedulerTest, EpilogueFusionInt64Indexing) { testValidate( executor_cache.fusion(), outputs, {t0, t1}, {tref}, __LINE__, __FILE__); - - if (!cudaArchGuardShouldSkip(9, 0)) { - // The Hopper matmul scheduler should reject this fusion since it requires - // 64-bit indexing. - // See https://github.com/NVIDIA/Fuser/issues/3595 - // TODO: Lift this temporary restriction and remove this check - for (const auto& heur : executor_cache.getMostRecentKernelRuntime() - ->schedulerHeuristics() - ->heuristicsList()) { - EXPECT_NE(heur->scheduler_type, SchedulerType::Matmul); - } - } } class MatmulFusionTest
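A note on the sizes chosen in `IndexTypeValidation` above, reasoning only from the test itself and the 2^31 - 1 limit of a signed 32-bit coordinate: M = K = 2^17 keeps every individual dimension, and therefore every TMA box coordinate, well under the int32 limit even though a linearized M*K extent (2^34) would overflow, which is why the Int32 compilation is expected to succeed; the second case makes a single dimension 2^32, so the narrowing cast from `DataType::Index` must overflow and the run is expected to throw. A quick numeric sketch of that arithmetic:

```cpp
// Sanity-check the two scenarios exercised by IndexTypeValidation (sketch).
#include <cstdint>
#include <limits>

int main() {
  constexpr int64_t i32_max = std::numeric_limits<int32_t>::max();  // 2^31 - 1
  constexpr int64_t M = int64_t{1} << 17, K = int64_t{1} << 17, N = 256;
  static_assert(M <= i32_max && N <= i32_max && K <= i32_max,
                "per-dimension TMA coordinates fit in int32");
  static_assert(M * K > i32_max,
                "a linearized row-major index would overflow int32");
  constexpr int64_t M_big = int64_t{1} << 32;
  static_assert(M_big > i32_max,
                "one coordinate overflows, so the run is expected to throw");
  return 0;
}
```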