Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for torch arange float module #2749

Merged
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
6ae9b32
ADDED SUPPORT FLOAT VALUE IN ARANGE
Abhishek-TyRnT Jan 13, 2024
4650040
Merge branch 'Added-support-for-torch-arange-float-module' of github.…
Abhishek-TyRnT Jan 13, 2024
42fac70
got rid of extra tosa tests
Abhishek-TyRnT Jan 13, 2024
b85c84e
git rid of iostream import
Abhishek-TyRnT Jan 16, 2024
6047cc0
Merge branch 'llvm:main' into Added-support-for-torch-arange-float-mo…
Abhishek-TyRnT Jan 19, 2024
a544ed5
Merge branch 'llvm:main' into Added-support-for-torch-arange-float-mo…
Abhishek-TyRnT Jan 31, 2024
c357abf
Merge branch 'llvm:main' into Added-support-for-torch-arange-float-mo…
Abhishek-TyRnT Feb 3, 2024
1530802
using int in result shape
Abhishek-TyRnT Feb 3, 2024
b6e1bcf
got rid of resultshape for int case
Abhishek-TyRnT Feb 5, 2024
5b59626
got rid of result shape in all int case
Abhishek-TyRnT Feb 6, 2024
7f51909
Merge branch 'llvm:main' into Added-support-for-torch-arange-float-mo…
Abhishek-TyRnT Feb 6, 2024
b2a541c
Merge branch 'llvm:main' into Added-support-for-torch-arange-float-mo…
Abhishek-TyRnT Feb 9, 2024
ac606f6
Merge branch 'llvm:main' into Added-support-for-torch-arange-float-mo…
Abhishek-TyRnT Feb 14, 2024
510a6de
Merge branch 'llvm:main' into Added-support-for-torch-arange-float-mo…
Abhishek-TyRnT Feb 15, 2024
cb4ed3e
using static cast instead of dynamic cast
Abhishek-TyRnT Feb 15, 2024
5d3194b
typecasting for int64type
Abhishek-TyRnT Feb 19, 2024
314ec60
Merge branch 'main' into Added-support-for-torch-arange-float-module
Abhishek-TyRnT Feb 19, 2024
c7e6780
Merge branch 'llvm:main' into Added-support-for-torch-arange-float-mo…
Abhishek-TyRnT Feb 21, 2024
0ee752b
ADDED SUPPORT FLOAT VALUE IN ARANGE
Abhishek-TyRnT Jan 13, 2024
6b26100
got rid of extra tosa tests
Abhishek-TyRnT Jan 13, 2024
ef559c5
git rid of iostream import
Abhishek-TyRnT Jan 16, 2024
08a289f
using int in result shape
Abhishek-TyRnT Feb 3, 2024
7f3caa8
got rid of resultshape for int case
Abhishek-TyRnT Feb 5, 2024
0f6ef1f
got rid of result shape in all int case
Abhishek-TyRnT Feb 6, 2024
8b57a51
using static cast instead of dynamic cast
Abhishek-TyRnT Feb 15, 2024
3140ab1
typecasting for int64type
Abhishek-TyRnT Feb 19, 2024
4c185db
git format, add some stylistic changes
newling Feb 26, 2024
9b4ae1e
update
newling Feb 26, 2024
4cd1632
Merge pull request #1 from newling/newling-update-added-support-for-t…
Abhishek-TyRnT Feb 26, 2024
ba6ba92
Merge branch 'llvm:main' into main
Abhishek-TyRnT Feb 27, 2024
142d14e
Merge branch 'main' into Added-support-for-torch-arange-float-module
Abhishek-TyRnT Feb 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 45 additions & 13 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4067,28 +4067,60 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
op, "unimplemented: pin_memory must be either None or false");
}

int64_t start, step, end;
if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
double start, step, end;
int64_t start_int, step_int, end_int;
bool is_all_inp_int; //Flag to check whether all inputs are integer
Abhishek-TyRnT marked this conversation as resolved.
Show resolved Hide resolved
is_all_inp_int = op.getStart().getType().isa<Torch::IntType>() && op.getEnd().getType().isa<Torch::IntType>() && op.getStep().getType().isa<Torch::IntType>();

if (matchPattern(op.getStart(), m_TorchConstantInt(&start_int)))
{
start = (double)(start_int);
Abhishek-TyRnT marked this conversation as resolved.
Show resolved Hide resolved
}

else if(!matchPattern(op.getStart(), m_TorchConstantFloat(&start)))
return rewriter.notifyMatchFailure(
op, "unimplemented: value `start` should be a torch constant int");
op, "unimplemented: value `start` should be a torch constant int or float");

if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end)))
if (matchPattern(op.getEnd(), m_TorchConstantInt(&end_int)))
{
end = (double)(end_int);
}
else if (!matchPattern(op.getEnd(), m_TorchConstantFloat(&end)))
return rewriter.notifyMatchFailure(
op, "unimplemented: value `end` should be a torch constant int");
op, "unimplemented: value `end` should be a torch constant int or float");

if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)))
if (matchPattern(op.getStep(), m_TorchConstantInt(&step_int)))
{

step = (double)(step_int);
}

else if (!matchPattern(op.getStep(), m_TorchConstantFloat(&step)))
return rewriter.notifyMatchFailure(
op, "unimplemented: value `step` should be a torch constant int");
op, "unimplemented: value `step` should be a torch constant int or float");

// The result will always be a 1-d tensor.
// The size of the result is calculated as follows:
// ceil((end - start)/step)
int64_t resultShape = ceil((float)(end - start) / (float)step);
SmallVector<int64_t> values(resultShape, start);
for (unsigned i = 1; i < resultShape; i++)
values[i] += i * step;
Value result =
tosa::getConstTensor<int64_t>(rewriter, op, values, resultShape).value();
int64_t resultShape = ceil((end - start) / step);
Value result;
if (is_all_inp_int)
{
SmallVector<int64_t> values(resultShape, start);
Abhishek-TyRnT marked this conversation as resolved.
Show resolved Hide resolved
for (unsigned i = 1; i < resultShape; i++)
values[i] += i * step;

result = tosa::getConstTensor<int64_t>(rewriter, op, values, resultShape).value();
}

else
{
SmallVector<float> values(resultShape, start);
for (unsigned i = 1; i < resultShape; i++)
values[i] += (i * step);

result = tosa::getConstTensor<float>(rewriter, op, values, resultShape).value();
}

rewriter.replaceOpWithNewOp<tosa::CastOp>(op, resultType, result);
return success();
Expand Down
8 changes: 8 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,6 +986,14 @@
"ArangeStartOutViewModule_basic",
"ArangeStartStepIntModule_basic",
"ArangeZeroElementOutputModule_basic",
"ArangeDtypeIntModule_basic",
"ArangeFalsePinMemoryModule_basic",
"ArangeFloatModule_basic",
"ArangeNegativeStartFloatModule_basic",
"ArangeStartFloatModule_basic",
"ArangeStartNegativeStepFloatModule_basic",
"ArangeStartOutDtypeModule_basic",
"ArangeStartStepFloatModule_basic",
"ArgmaxModule_keepDim",
"ArgmaxModule_with_dim",
"AtenComplex64Module_basic",
Expand Down