This repository has been archived by the owner on Jan 30, 2025. It is now read-only.

Lower torch transposed convolution to a custom TCP op #25

Merged
4 commits merged on Jan 15, 2024

Conversation

srinathava
Contributor

@srinathava srinathava commented Jan 10, 2024

As titled, lower torch transposed convolution to a custom TCP op to avoid a miscompilation in `TorchToTosa`.

@srinathava srinathava changed the title Lower torch transposed convolution to a custom TCP op to avoid miscompiles/errors in TorchToTosa Lower torch transposed convolution to a custom TCP op Jan 10, 2024
Comment on lines +166 to +169
auto addOperand = [&](std::string name, Value value) {
operandNames.push_back(name);
operands.push_back(value);
};
Collaborator

Non-blocking and could be addressed in a follow-on PR:
This would be nice to factor out as a utility, so it can be reused by other TcpCustomOp conversion patterns.

Comment on lines 182 to 190
auto addListOfIntAttr = [&](const std::string &name, Value value) {
SmallVector<int64_t> valueInt;
if (!matchPattern(adaptor.getStride(),
m_TorchListOfConstantInts(valueInt)))
return rewriter.notifyMatchFailure(op, std::string("non-const") + name +
"list unsupported");
attrs.push_back(
rewriter.getNamedAttr(name, rewriter.getIndexArrayAttr(valueInt)));
return success();
};
Collaborator

Same as above: this could be factored out as a reusable utility in a follow-on PR.

Comment on lines 77 to 81
%input2 = torch.vtensor.literal(dense<0.0> : tensor<1x9x16x1600xf32>) : !torch.vtensor<[1,9,16,1600],f32>
%weights2 = torch.vtensor.literal(dense<0.0> : tensor<32x9x3x3xf32>) : !torch.vtensor<[32,9,3,3],f32>
%int0x0 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
%none = torch.constant.none
%output2 = torch.aten.convolution %input2, %weights2, %none, %int1x1, %int1x1, %int1x1, %false, %int0x0, %int1 : !torch.vtensor<[1,9,16,1600],f32>, !torch.vtensor<[32,9,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,32,16,1600],f32>
Collaborator

@sjain-stanford sjain-stanford Jan 12, 2024

Please consider splitting this into a separate test (in the same file) for readability and ease of maintenance. You may use self-explanatory test names, like

func.func @torch.aten.regular_convolution
...

func.func @torch.aten.transposed_convolution
...

Contributor Author

Good suggestion. I split it, and also split the CHECK line into multiple lines to make it a bit clearer.


// ---

// CHECK: tcp.custom_op("torch.aten.convolution") %{{.*}}, %{{.*}}, %{{.*}} {dilation = [2 : index, 2 : index], groups = 1 : i64, output_padding = [2 : index, 2 : index], padding = [2 : index, 2 : index], stride = [2 : index, 2 : index], torch_operand_names = ["input", "weight", "bias"], transposed = true} : tensor<1x64x1x100xf32>, tensor<64x64x3x3xf32>, tensor<64xf32> -> tensor<1x64x2x200xf32>
Collaborator

I was wondering how the test passes despite the bug above. It seems to be because padding is being populated from stride and the check statement propagates the bug.

// This should've been padding = [1 : index, 1 : index]
padding = [2 : index, 2 : index]
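For context, the bug the reviewer is pointing at is in the `addListOfIntAttr` lambda quoted earlier: it always matches `adaptor.getStride()` regardless of which list it was asked to convert, so `padding` gets populated from `stride`. A likely fix is to match on the lambda's `value` parameter instead (a sketch of the intent; the actual patch applied in the PR is not shown here):

```diff
 auto addListOfIntAttr = [&](const std::string &name, Value value) {
   SmallVector<int64_t> valueInt;
-  if (!matchPattern(adaptor.getStride(),
-                    m_TorchListOfConstantInts(valueInt)))
+  if (!matchPattern(value, m_TorchListOfConstantInts(valueInt)))
```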

Contributor Author

:( This PR is turning out to be a bit of a typo magnet. Thanks for the careful review.

@@ -19,6 +19,7 @@
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h"
Collaborator

Just checking: is this include still needed?

Collaborator

@sjain-stanford sjain-stanford left a comment

Looks great to me. Thanks for contributing and persevering through the iterations!

@srinathava srinathava merged commit 85c1e73 into cruise-automation:main Jan 15, 2024
1 check passed
srinathava added a commit to srinathava/mlir-tcp that referenced this pull request Jun 24, 2024
…ion#25)

As titled, lower torch transposed convolution to a custom TCP op to
avoid a mis-compilation in `TorchToTosa`.

---------

Co-authored-by: Srinath Avadhanula <srinath.avadhanula@getcruise.com>
srinathava pushed a commit to srinathava/mlir-tcp that referenced this pull request Jun 24, 2024
…ruise-automation#25) (cruise-automation#10)

As titled, lower torch transposed convolution to a custom TCP op to
avoid a mis-compilation in `TorchToTosa`.

Cherry-pick from upstream: cruise-automation#25

---------

Co-authored-by: Srinath Avadhanula <srinath.avadhanula@getcruise.com>
srinathava pushed a commit to srinathava/mlir-tcp that referenced this pull request Jun 24, 2024
Remove `LRDPrefilterPrediction` related content

Test from c/c:
`bazel test @mlir-tcp//test/...`