diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index b49c9af8adce..ce0a1af2f834 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -8,24 +8,23 @@
 //===----------------------------------------------------------------------===//
 
 #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
-#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
-#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
-#include "torch-mlir/Conversion/Utils/Utils.h"
-
 #include "../PassDetail.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/Dialect/Traits.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
+#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
+#include "torch-mlir/Conversion/Utils/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
 #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
+#include <cmath>
 
 using namespace mlir;
 using namespace mlir::torch;
@@ -4067,28 +4066,138 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
         op, "unimplemented: pin_memory must be either None or false");
   }
 
-  int64_t start, step, end;
-  if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
+  // Stores a range value (a start, end, or step value) and whether or not it
+  // was initialized with a constant integer, a constant float or neither.
+  class ConstRangeValue {
+  public:
+    explicit ConstRangeValue(double v)
+        : vDouble(v), fromDouble(true), vInt(static_cast<int64_t>(v)),
+          fromInt(false) {}
+
+    explicit ConstRangeValue(int64_t v)
+        : vDouble(static_cast<double>(v)), fromDouble(false), vInt(v),
+          fromInt(true) {}
+
+    // Constructor for the case where there is no constant value to use.
+    ConstRangeValue()
+        : vDouble(0), fromDouble(false), vInt(0), fromInt(false) {}
+
+    static ConstRangeValue fromValue(Value v) {
+      int64_t intVal{0};
+      double floatVal{0.0};
+      if (matchPattern(v, m_TorchConstantFloat(&floatVal))) {
+        return ConstRangeValue(floatVal);
+      } else if (matchPattern(v, m_TorchConstantInt(&intVal))) {
+        return ConstRangeValue(intVal);
+      }
+      return ConstRangeValue();
+    }
+
+    bool hasConstInt() const { return fromInt; }
+    bool hasConstDouble() const { return fromDouble; }
+    bool hasConst() const { return fromInt || fromDouble; }
+    double getDouble() const { return vDouble; }
+    int64_t getInt() const { return vInt; }
+
+  private:
+    double vDouble;
+    bool fromDouble;
+    int64_t vInt;
+    bool fromInt;
+  };
+
+  auto start = ConstRangeValue::fromValue(op.getStart());
+  if (!start.hasConst()) {
     return rewriter.notifyMatchFailure(
-        op, "unimplemented: value `start` should be a torch constant int");
+        op, "unimplemented: case where `start` is not a constant int or float");
+  }
 
-  if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end)))
+  auto end = ConstRangeValue::fromValue(op.getEnd());
+  if (!end.hasConst()) {
     return rewriter.notifyMatchFailure(
-        op, "unimplemented: value `end` should be a torch constant int");
+        op,
+        "unimplemented: case where value `end` is not a constant int or float");
+  }
 
-  if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)))
+  auto step = ConstRangeValue::fromValue(op.getStep());
+  if (!step.hasConst()) {
+    return rewriter.notifyMatchFailure(op,
+                                       "unimplemented: case where value `step` "
+                                       "is not a constant int or float");
+  }
+
+  auto getRange = [](auto start, auto end, auto step) {
+    // Initialize a small vector of the same type as start:
+    using T = decltype(start);
+    SmallVector<T> values;
+
+    uint64_t counter{0};
+    if (start == end) {
+      return values;
+    }
+    assert(step != T(0));
+    values.reserve(
+        1 + static_cast<uint64_t>(std::abs((end - start) / std::abs(step))));
+    if (step > 0) {
+      while (start + T(counter) * step < end) {
+        values.push_back(start + counter * step);
+        counter++;
+      }
+    } else {
+      while (start + T(counter) * step > end) {
+        values.push_back(start + counter * step);
+        counter++;
+      }
+    }
+    return values;
+  };
+
+  const auto isIntType =
+      resultType.getElementType().dyn_cast_or_null<mlir::IntegerType>();
+
+  const auto isDoubleType =
+      resultType.getElementType().dyn_cast_or_null<mlir::FloatType>();
+
+  auto maybeResult = [&]() -> std::optional<Value> {
+    // Integer output type, and start / end / range are all integers.
+    if (isIntType && start.hasConstInt() && end.hasConstInt() &&
+        step.hasConstInt()) {
+      auto values = getRange(start.getInt(), end.getInt(), step.getInt());
+      return tosa::getConstTensor<int64_t>(rewriter, op, values, values.size());
+    }
+
+    // Get a double range.
+    auto values =
+        getRange(start.getDouble(), end.getDouble(), step.getDouble());
+    if (isIntType) {
+      SmallVector<int64_t> values_i64;
+      values_i64.reserve(values.size());
+      for (auto v : values) {
+        values_i64.push_back(static_cast<int64_t>(v));
+      }
+      return tosa::getConstTensor<int64_t>(rewriter, op, values_i64,
+                                           values.size());
+    }
+
+    if (!isDoubleType) {
+      return {};
+    }
+
+    SmallVector<float> values_f32;
+    values_f32.reserve(values.size());
+    for (auto v : values) {
+      values_f32.push_back(static_cast<float>(v));
+    }
+    auto vs = tosa::getConstTensor<float>(rewriter, op, values_f32,
+                                          values_f32.size());
+    return vs;
+  }();
+
+  if (!maybeResult.has_value()) {
     return rewriter.notifyMatchFailure(
-        op, "unimplemented: value `step` should be a torch constant int");
-
-  // The result will always be a 1-d tensor.
-  // The size of the result is calculated as follows:
-  //        ceil((end - start)/step)
-  int64_t resultShape = ceil((float)(end - start) / (float)step);
-  SmallVector<int64_t> values(resultShape, start);
-  for (unsigned i = 1; i < resultShape; i++)
-    values[i] += i * step;
-  Value result =
-      tosa::getConstTensor<int64_t>(rewriter, op, values, resultShape).value();
+        op, "failed to generate constant tensor for arange");
+  }
+  auto result = maybeResult.value();
 
   rewriter.replaceOpWithNewOp<tosa::CastOp>(op, resultType, result);
 
   return success();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 195a5e42f249..74f7300c9274 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -892,6 +892,14 @@
     "ArangeStartOutViewModule_basic",
     "ArangeStartStepIntModule_basic",
     "ArangeZeroElementOutputModule_basic",
+    "ArangeDtypeIntModule_basic",
+    "ArangeFalsePinMemoryModule_basic",
+    "ArangeFloatModule_basic",
+    "ArangeNegativeStartFloatModule_basic",
+    "ArangeStartFloatModule_basic",
+    "ArangeStartNegativeStepFloatModule_basic",
+    "ArangeStartOutDtypeModule_basic",
+    "ArangeStartStepFloatModule_basic",
     "ArgmaxModule_keepDim",
     "ArgmaxModule_with_dim",
     "AtenComplex64Module_basic",