Added support for torch arange float module #2749

Merged

Commits (31), changes from all commits
6ae9b32 ADDED SUPPORT FLOAT VALUE IN ARANGE (Abhishek-TyRnT, Jan 13, 2024)
4650040 Merge branch 'Added-support-for-torch-arange-float-module' of github.… (Abhishek-TyRnT, Jan 13, 2024)
42fac70 got rid of extra tosa tests (Abhishek-TyRnT, Jan 13, 2024)
b85c84e git rid of iostream import (Abhishek-TyRnT, Jan 16, 2024)
6047cc0 Merge branch 'llvm:main' into Added-support-for-torch-arange-float-mo… (Abhishek-TyRnT, Jan 19, 2024)
a544ed5 Merge branch 'llvm:main' into Added-support-for-torch-arange-float-mo… (Abhishek-TyRnT, Jan 31, 2024)
c357abf Merge branch 'llvm:main' into Added-support-for-torch-arange-float-mo… (Abhishek-TyRnT, Feb 3, 2024)
1530802 using int in result shape (Abhishek-TyRnT, Feb 3, 2024)
b6e1bcf got rid of resultshape for int case (Abhishek-TyRnT, Feb 5, 2024)
5b59626 got rid of result shape in all int case (Abhishek-TyRnT, Feb 6, 2024)
7f51909 Merge branch 'llvm:main' into Added-support-for-torch-arange-float-mo… (Abhishek-TyRnT, Feb 6, 2024)
b2a541c Merge branch 'llvm:main' into Added-support-for-torch-arange-float-mo… (Abhishek-TyRnT, Feb 9, 2024)
ac606f6 Merge branch 'llvm:main' into Added-support-for-torch-arange-float-mo… (Abhishek-TyRnT, Feb 14, 2024)
510a6de Merge branch 'llvm:main' into Added-support-for-torch-arange-float-mo… (Abhishek-TyRnT, Feb 15, 2024)
cb4ed3e using static cast instead of dynamic cast (Abhishek-TyRnT, Feb 15, 2024)
5d3194b typecasting for int64type (Abhishek-TyRnT, Feb 19, 2024)
314ec60 Merge branch 'main' into Added-support-for-torch-arange-float-module (Abhishek-TyRnT, Feb 19, 2024)
c7e6780 Merge branch 'llvm:main' into Added-support-for-torch-arange-float-mo… (Abhishek-TyRnT, Feb 21, 2024)
0ee752b ADDED SUPPORT FLOAT VALUE IN ARANGE (Abhishek-TyRnT, Jan 13, 2024)
6b26100 got rid of extra tosa tests (Abhishek-TyRnT, Jan 13, 2024)
ef559c5 git rid of iostream import (Abhishek-TyRnT, Jan 16, 2024)
08a289f using int in result shape (Abhishek-TyRnT, Feb 3, 2024)
7f3caa8 got rid of resultshape for int case (Abhishek-TyRnT, Feb 5, 2024)
0f6ef1f got rid of result shape in all int case (Abhishek-TyRnT, Feb 6, 2024)
8b57a51 using static cast instead of dynamic cast (Abhishek-TyRnT, Feb 15, 2024)
3140ab1 typecasting for int64type (Abhishek-TyRnT, Feb 19, 2024)
4c185db git format, add some stylistic changes (newling, Feb 26, 2024)
9b4ae1e update (newling, Feb 26, 2024)
4cd1632 Merge pull request #1 from newling/newling-update-added-support-for-t… (Abhishek-TyRnT, Feb 26, 2024)
ba6ba92 Merge branch 'llvm:main' into main (Abhishek-TyRnT, Feb 27, 2024)
142d14e Merge branch 'main' into Added-support-for-torch-arange-float-module (Abhishek-TyRnT, Feb 27, 2024)
153 changes: 131 additions & 22 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -8,24 +8,23 @@
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include <optional>

using namespace mlir;
using namespace mlir::torch;
@@ -4067,28 +4066,138 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
        op, "unimplemented: pin_memory must be either None or false");
  }

-  int64_t start, step, end;
-  if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
+  // Stores a range value (a start, end, or step value) and whether it
+  // was initialized with a constant integer, a constant float, or neither.
+  class ConstRangeValue {
+  public:
+    explicit ConstRangeValue(double v)
+        : vDouble(v), fromDouble(true), vInt(static_cast<int64_t>(v)),
+          fromInt(false) {}
+
+    explicit ConstRangeValue(int64_t v)
+        : vDouble(static_cast<double>(v)), fromDouble(false), vInt(v),
+          fromInt(true) {}
+
+    // Constructor for the case where there is no constant value to use.
+    ConstRangeValue()
+        : vDouble(0), fromDouble(false), vInt(0), fromInt(false) {}
+
+    static ConstRangeValue fromValue(Value v) {
+      int64_t intVal{0};
+      double floatVal{0.0};
+      if (matchPattern(v, m_TorchConstantFloat(&floatVal))) {
+        return ConstRangeValue(floatVal);
+      } else if (matchPattern(v, m_TorchConstantInt(&intVal))) {
+        return ConstRangeValue(intVal);
+      }
+      return ConstRangeValue();
+    }
+
+    bool hasConstInt() const { return fromInt; }
+    bool hasConstDouble() const { return fromDouble; }
+    bool hasConst() const { return fromInt || fromDouble; }
+    double getDouble() const { return vDouble; }
+    int64_t getInt() const { return vInt; }
+
+  private:
+    double vDouble;
+    bool fromDouble;
+    int64_t vInt;
+    bool fromInt;
+  };

+  auto start = ConstRangeValue::fromValue(op.getStart());
+  if (!start.hasConst()) {
    return rewriter.notifyMatchFailure(
-        op, "unimplemented: value `start` should be a torch constant int");
+        op, "unimplemented: case where `start` is not a constant int or float");
+  }

-  if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end)))
+  auto end = ConstRangeValue::fromValue(op.getEnd());
+  if (!end.hasConst()) {
    return rewriter.notifyMatchFailure(
-        op, "unimplemented: value `end` should be a torch constant int");
+        op,
+        "unimplemented: case where value `end` is not a constant int or float");
+  }

-  if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)))
+  auto step = ConstRangeValue::fromValue(op.getStep());
+  if (!step.hasConst()) {
+    return rewriter.notifyMatchFailure(op,
+                                       "unimplemented: case where value `step` "
+                                       "is not a constant int or float");
+  }

+  auto getRange = [](auto start, auto end, auto step) {
+    // Initialize a small vector of the same type as start:
+    using T = decltype(start);
+    SmallVector<T> values;
+
+    uint64_t counter{0};
+    if (start == end) {
+      return values;
+    }
+    assert(step != T(0));
+    values.reserve(
+        1 + static_cast<size_t>(std::abs((end - start) / std::abs(step))));
+    if (step > 0) {
+      while (start + T(counter) * step < end) {
+        values.push_back(start + counter * step);
+        counter++;
+      }
+    } else {
+      while (start + T(counter) * step > end) {
+        values.push_back(start + counter * step);
+        counter++;
+      }
+    }
+    return values;
+  };

+  const auto isIntType =
+      resultType.getElementType().dyn_cast_or_null<mlir::IntegerType>();
+
+  const auto isDoubleType =
+      resultType.getElementType().dyn_cast_or_null<mlir::FloatType>();
+
+  auto maybeResult = [&]() -> std::optional<Value> {
+    // Integer output type, and start / end / step are all integers.
+    if (isIntType && start.hasConstInt() && end.hasConstInt() &&
+        step.hasConstInt()) {
+      auto values = getRange(start.getInt(), end.getInt(), step.getInt());
+      return tosa::getConstTensor<int64_t>(rewriter, op, values, values.size());
+    }
+
+    // Otherwise, compute the range as doubles.
+    auto values =
+        getRange(start.getDouble(), end.getDouble(), step.getDouble());
+    if (isIntType) {
+      SmallVector<int64_t> values_i64;
+      values_i64.reserve(values.size());
+      for (auto v : values) {
+        values_i64.push_back(static_cast<int64_t>(v));
+      }
+      return tosa::getConstTensor<int64_t>(rewriter, op, values_i64,
+                                           values.size());
+    }
+
+    if (!isDoubleType) {
+      return {};
+    }
+
+    SmallVector<float> values_f32;
+    values_f32.reserve(values.size());
+    for (auto v : values) {
+      values_f32.push_back(static_cast<float>(v));
+    }
+    auto vs = tosa::getConstTensor<float>(rewriter, op, values_f32,
+                                          values_f32.size());
+    return vs;
+  }();

+  if (!maybeResult.has_value()) {
    return rewriter.notifyMatchFailure(
-        op, "unimplemented: value `step` should be a torch constant int");
-
-  // The result will always be a 1-d tensor.
-  // The size of the result is calculated as follows:
-  // ceil((end - start)/step)
-  int64_t resultShape = ceil((float)(end - start) / (float)step);
-  SmallVector<int64_t> values(resultShape, start);
-  for (unsigned i = 1; i < resultShape; i++)
-    values[i] += i * step;
-  Value result =
-      tosa::getConstTensor<int64_t>(rewriter, op, values, resultShape).value();
+        op, "failed to generate constant tensor for arange");
+  }
+  auto result = maybeResult.value();

  rewriter.replaceOpWithNewOp<tosa::CastOp>(op, resultType, result);
  return success();
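The core of the new lowering is the exclusive-bound enumeration in `getRange`. As a sanity check of those semantics (positive and negative steps, and float steps that do not divide the interval evenly), here is a minimal standalone sketch; `enumerateRange` and the `main` driver are hypothetical stand-ins rather than part of the patch, and the expected outputs in the comments follow torch.arange's convention that `end` is excluded.

// Standalone sketch (illustrative, not from the patch) of the enumeration
// performed by the `getRange` lambda above.
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <vector>

template <typename T>
std::vector<T> enumerateRange(T start, T end, T step) {
  std::vector<T> values;
  if (start == end)
    return values; // empty range, like torch.arange(3, 3)
  assert(step != T(0) && "step must be non-zero");
  // Reserve an upper bound on the number of elements.
  values.reserve(1 + static_cast<size_t>(std::abs((end - start) / step)));
  // `end` is exclusive; computing start + i * step from an integer counter
  // avoids accumulating floating-point error from repeated addition.
  for (uint64_t i = 0;; ++i) {
    T v = start + T(i) * step;
    if ((step > T(0) && v >= end) || (step < T(0) && v <= end))
      break;
    values.push_back(v);
  }
  return values;
}

int main() {
  // Float range: torch.arange(0.0, 1.0, 0.3) -> 0 0.3 0.6 0.9
  for (double v : enumerateRange(0.0, 1.0, 0.3))
    std::cout << v << ' ';
  std::cout << '\n';
  // Negative integer step: torch.arange(5, 1, -2) -> 5 3
  for (int64_t v : enumerateRange<int64_t>(5, 1, -2))
    std::cout << v << ' ';
  std::cout << '\n';
  return 0;
}

One design point worth noting: both the patch and this sketch index each element as start + i * step from a counter rather than repeatedly adding step to an accumulator, which keeps floating-point rounding error from growing across a long range.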
8 changes: 8 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -892,6 +892,14 @@
"ArangeStartOutViewModule_basic",
"ArangeStartStepIntModule_basic",
"ArangeZeroElementOutputModule_basic",
"ArangeDtypeIntModule_basic",
"ArangeFalsePinMemoryModule_basic",
"ArangeFloatModule_basic",
"ArangeNegativeStartFloatModule_basic",
"ArangeStartFloatModule_basic",
"ArangeStartNegativeStepFloatModule_basic",
"ArangeStartOutDtypeModule_basic",
"ArangeStartStepFloatModule_basic",
"ArgmaxModule_keepDim",
"ArgmaxModule_with_dim",
"AtenComplex64Module_basic",