Add support for torch arange float module (llvm#2749)
Added support for float dtype in torch.arange in the TOSA dialect.

This resolves the following issue: llvm#2762

The following test cases pass after this change (a sketch of one such test follows the list):

1. ArangeDtypeIntModule_basic
2. ArangeFloatModule_basic
3. ArangeNegativeStartFloatModule_basic
4. ArangeStartFloatModule_basic
5. ArangeStartNegativeStepFloatModule_basic
6. ArangeStartOutDtypeModule_basic
7. ArangeStartStepFloatModule_basic
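
For context, here is a minimal sketch of the kind of e2e test this enables, written in the style of the torch-mlir test suite (the imports and decorators follow that suite's conventions; the module name and arange arguments are illustrative, not copied from the real tests):

import torch
from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export

class ArangeFloatSketchModule(torch.nn.Module):
    @export
    @annotate_args([None])
    def forward(self):
        # Float start/end/step values, previously rejected by the TOSA lowering.
        return torch.arange(-2.5, 10.0, 1.5)

@register_test_case(module_factory=lambda: ArangeFloatSketchModule())
def ArangeFloatSketchModule_basic(module, tu: TestUtils):
    module.forward()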

---------

Co-authored-by: James Newling <james.newling@gmail.com>
Abhishek-TyRnT and newling authored Feb 27, 2024
1 parent 3021254 commit d541779
Showing 2 changed files with 139 additions and 22 deletions.
153 changes: 131 additions & 22 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -8,24 +8,23 @@
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include <optional>

using namespace mlir;
using namespace mlir::torch;
@@ -4067,28 +4066,138 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
op, "unimplemented: pin_memory must be either None or false");
}

// Stores a range value (a start, end, or step value) and whether or not it
// was initialized with a constant integer, a constant float, or neither.
class ConstRangeValue {
public:
explicit ConstRangeValue(double v)
: vDouble(v), fromDouble(true), vInt(static_cast<int64_t>(v)),
fromInt(false) {}

explicit ConstRangeValue(int64_t v)
: vDouble(static_cast<double>(v)), fromDouble(false), vInt(v),
fromInt(true) {}

// Constructor for the case where there is no constant value to use.
ConstRangeValue()
: vDouble(0), fromDouble(false), vInt(0), fromInt(false) {}

static ConstRangeValue fromValue(Value v) {
int64_t intVal{0};
double floatVal{0.0};
if (matchPattern(v, m_TorchConstantFloat(&floatVal))) {
return ConstRangeValue(floatVal);
} else if (matchPattern(v, m_TorchConstantInt(&intVal))) {
return ConstRangeValue(intVal);
}
return ConstRangeValue();
}

bool hasConstInt() const { return fromInt; }
bool hasConstDouble() const { return fromDouble; }
bool hasConst() const { return fromInt || fromDouble; }
double getDouble() const { return vDouble; }
int64_t getInt() const { return vInt; }

private:
double vDouble;
bool fromDouble;
int64_t vInt;
bool fromInt;
};
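// Example (illustrative, not part of the diff): for a `torch.constant.float
// 1.5` operand, fromValue returns a value with hasConstDouble() == true and
// getDouble() == 1.5; for a non-constant operand, hasConst() is false.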

auto start = ConstRangeValue::fromValue(op.getStart());
if (!start.hasConst()) {
return rewriter.notifyMatchFailure(
    op, "unimplemented: case where `start` is not a constant int or float");
}

auto end = ConstRangeValue::fromValue(op.getEnd());
if (!end.hasConst()) {
return rewriter.notifyMatchFailure(
op, "unimplemented: value `end` should be a torch constant int");
op,
"unimplemented: case where value `end` is not a constant int or float");
}

auto step = ConstRangeValue::fromValue(op.getStep());
if (!step.hasConst()) {
return rewriter.notifyMatchFailure(op,
"unimplemented: case where value `step` "
"is not a constant int or float");
}

auto getRange = [](auto start, auto end, auto step) {
// Initialize a small vector of the same type as start:
using T = decltype(start);
SmallVector<T> values;

uint64_t counter{0};
if (start == end) {
return values;
}
assert(step != T(0));
values.reserve(
1 + static_cast<size_t>(std::abs((end - start) / std::abs(step))));
if (step > 0) {
while (start + T(counter) * step < end) {
values.push_back(start + counter * step);
counter++;
}
} else {
while (start + T(counter) * step > end) {
values.push_back(start + counter * step);
counter++;
}
}
return values;
};
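// Illustrative examples (not part of the diff): getRange(0.0, 2.5, 1.0)
// yields {0.0, 1.0, 2.0}; getRange(5, 1, -2) yields {5, 3}; and a range
// with start == end yields an empty vector.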

const auto isIntType =
resultType.getElementType().dyn_cast_or_null<mlir::IntegerType>();

const auto isDoubleType =
resultType.getElementType().dyn_cast_or_null<mlir::FloatType>();

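// The lambda below materializes the range as a TOSA constant tensor. Three
// cases are handled: an all-integer range built with integer arithmetic, an
// integer result truncated from a double-valued range, and a float result
// narrowed to f32; any other element type yields std::nullopt.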
auto maybeResult = [&]() -> std::optional<Value> {
// Integer output type, and start / end / range are all integers.
if (isIntType && start.hasConstInt() && end.hasConstInt() &&
step.hasConstInt()) {
auto values = getRange(start.getInt(), end.getInt(), step.getInt());
return tosa::getConstTensor<int64_t>(rewriter, op, values, values.size());
}

// Get a double range.
auto values =
getRange(start.getDouble(), end.getDouble(), step.getDouble());
if (isIntType) {
SmallVector<int64_t> values_i64;
values_i64.reserve(values.size());
for (auto v : values) {
values_i64.push_back(static_cast<int64_t>(v));
}
return tosa::getConstTensor<int64_t>(rewriter, op, values_i64,
values.size());
}

if (!isDoubleType) {
return {};
}

SmallVector<float> values_f32;
values_f32.reserve(values.size());
for (auto v : values) {
values_f32.push_back(static_cast<float>(v));
}
auto vs = tosa::getConstTensor<float>(rewriter, op, values_f32,
values_f32.size());
return vs;
}();

if (!maybeResult.has_value()) {
return rewriter.notifyMatchFailure(
op, "unimplemented: value `step` should be a torch constant int");

// The result will always be a 1-d tensor.
// The size of the result is calculated as follows:
// ceil((end - start)/step)
int64_t resultShape = ceil((float)(end - start) / (float)step);
SmallVector<int64_t> values(resultShape, start);
for (unsigned i = 1; i < resultShape; i++)
values[i] += i * step;
Value result =
tosa::getConstTensor<int64_t>(rewriter, op, values, resultShape).value();
op, "failed to generate constant tensor for arange");
}
auto result = maybeResult.value();

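// The constant tensor above is materialized as i64 or f32; the final
// tosa.cast converts it to the element type the op actually requested
// (e.g. i32 or f64).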
rewriter.replaceOpWithNewOp<tosa::CastOp>(op, resultType, result);
return success();
8 changes: 8 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -892,6 +892,14 @@
"ArangeStartOutViewModule_basic",
"ArangeStartStepIntModule_basic",
"ArangeZeroElementOutputModule_basic",
"ArangeDtypeIntModule_basic",
"ArangeFalsePinMemoryModule_basic",
"ArangeFloatModule_basic",
"ArangeNegativeStartFloatModule_basic",
"ArangeStartFloatModule_basic",
"ArangeStartNegativeStepFloatModule_basic",
"ArangeStartOutDtypeModule_basic",
"ArangeStartStepFloatModule_basic",
"ArgmaxModule_keepDim",
"ArgmaxModule_with_dim",
"AtenComplex64Module_basic",
