-
Notifications
You must be signed in to change notification settings - Fork 7
Lower torch transposed convolution to a custom TCP op #25
Conversation
auto addOperand = [&](std::string name, Value value) { | ||
operandNames.push_back(name); | ||
operands.push_back(value); | ||
}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Non-blocking and could be addressed in a follow-on PR:
This would be nice to factor out as a utility, so they can be reused by other TcpCustomOp conversion patterns.
auto addListOfIntAttr = [&](const std::string &name, Value value) { | ||
SmallVector<int64_t> valueInt; | ||
if (!matchPattern(adaptor.getStride(), | ||
m_TorchListOfConstantInts(valueInt))) | ||
return rewriter.notifyMatchFailure(op, std::string("non-const") + name + | ||
"list unsupported"); | ||
attrs.push_back( | ||
rewriter.getNamedAttr(name, rewriter.getIndexArrayAttr(valueInt))); | ||
return success(); | ||
}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same
%input2 = torch.vtensor.literal(dense<0.0> : tensor<1x9x16x1600xf32>) : !torch.vtensor<[1,9,16,1600],f32> | ||
%weights2 = torch.vtensor.literal(dense<0.0> : tensor<32x9x3x3xf32>) : !torch.vtensor<[32,9,3,3],f32> | ||
%int0x0 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int> | ||
%none = torch.constant.none | ||
%output2 = torch.aten.convolution %input2, %weights2, %none, %int1x1, %int1x1, %int1x1, %false, %int0x0, %int1 : !torch.vtensor<[1,9,16,1600],f32>, !torch.vtensor<[32,9,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,32,16,1600],f32> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please consider splitting this into a separate test (in the same file) for readability and/or ease of maintenance. You may use self-explanatory test names, like
func.func @torch.aten.regular_convolution
...
func.func @torch.aten.transposed_convolution
...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good suggestion. I split it and also split the CHECK line into multiple lines to make it a bit clearer.
|
||
// --- | ||
|
||
// CHECK: tcp.custom_op("torch.aten.convolution") %{{.*}}, %{{.*}}, %{{.*}} {dilation = [2 : index, 2 : index], groups = 1 : i64, output_padding = [2 : index, 2 : index], padding = [2 : index, 2 : index], stride = [2 : index, 2 : index], torch_operand_names = ["input", "weight", "bias"], transposed = true} : tensor<1x64x1x100xf32>, tensor<64x64x3x3xf32>, tensor<64xf32> -> tensor<1x64x2x200xf32> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was wondering how the test passes despite the bug above. It seems to be because padding is being populated from stride and the check statement propagates the bug.
// This should've been padding = [1 : index, 1 : index]
padding = [2 : index, 2 : index]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
:( This PR is turning out to be a bit of typo central. Thanks for the careful review.
@@ -19,6 +19,7 @@ | |||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h" | |||
|
|||
#include "llvm/ADT/StringSet.h" | |||
#include "llvm/Support/Debug.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just checking, is this still needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great to me. Thanks for contributing and persevering through the iterations!
…ion#25) As titled, lower torch transposed convolution to a custom TCP op to avoid a mis-compilation in `TorchToTosa`. --------- Co-authored-by: Srinath Avadhanula <srinath.avadhanula@getcruise.com>
…ruise-automation#25) (cruise-automation#10) As titled, lower torch transposed convolution to a custom TCP op to avoid a mis-compilation in `TorchToTosa`. Cherry-pick from upstream: cruise-automation#25 --------- Co-authored-by: Srinath Avadhanula <srinath.avadhanula@getcruise.com>
Remove `LRDPrefilterPrediction` related content Test from c/c: `bazel test @mlir-tcp//test/...`
As titled, lower torch transposed convolution to a custom TCP op to avoid a mis-compilation in
TorchToTosa
.