Skip to content
This repository has been archived by the owner on Jan 30, 2025. It is now read-only.

Commit

Permalink
PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
Srinath Avadhanula authored and srinathava committed Jan 15, 2024
1 parent 8f366f8 commit 5a40ca5
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 12 deletions.
3 changes: 1 addition & 2 deletions lib/Conversion/TorchToTcp/TcpCustomOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {

auto addListOfIntAttr = [&](const std::string &name, Value value) {
SmallVector<int64_t> valueInt;
if (!matchPattern(adaptor.getStride(),
m_TorchListOfConstantInts(valueInt)))
if (!matchPattern(value, m_TorchListOfConstantInts(valueInt)))
return rewriter.notifyMatchFailure(op, std::string("non-const") + name +
"list unsupported");
attrs.push_back(
Expand Down
36 changes: 26 additions & 10 deletions test/Conversion/TorchToTcp/tcp_custom_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,19 @@ func.func @torch.aten.index_put_impl_op(%arg0: !torch.vtensor<[25],f32>, %arg1:
return %1 : !torch.vtensor<[25],f32>
}

// ---

// CHECK: tcp.custom_op("torch.aten.convolution") %{{.*}}, %{{.*}}, %{{.*}} {dilation = [2 : index, 2 : index], groups = 1 : i64, output_padding = [2 : index, 2 : index], padding = [2 : index, 2 : index], stride = [2 : index, 2 : index], torch_operand_names = ["input", "weight", "bias"], transposed = true} : tensor<1x64x1x100xf32>, tensor<64x64x3x3xf32>, tensor<64xf32> -> tensor<1x64x2x200xf32>
// CHECK: torch.aten.convolution %{{.*}}
func.func @forward(%input: !torch.vtensor<[1,64,1,100],f32>) -> (!torch.vtensor<[1,64,2,200],f32>, !torch.vtensor<[1,32,16,1600],f32>) {
// -----

// CHECK: tcp.custom_op("torch.aten.convolution") %{{.*}}, %{{.*}}, %{{.*}} {
// CHECK-SAME: dilation = [1 : index, 1 : index],
// CHECK-SAME: groups = 1 : i64,
// CHECK-SAME: output_padding = [1 : index, 1 : index],
// CHECK-SAME: padding = [1 : index, 1 : index],
// CHECK-SAME: stride = [2 : index, 2 : index],
// CHECK-SAME: torch_operand_names = ["input", "weight", "bias"],
// CHECK-SAME: transposed = true} : tensor<1x64x1x100xf32>, tensor<64x64x3x3xf32>, tensor<64xf32> -> tensor<1x64x2x200xf32>
func.func @torch.aten.transposed_convolution(%input: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[1,64,2,200],f32> {
%true = torch.constant.bool true
%false = torch.constant.bool false
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%weight = torch.vtensor.literal(dense<0.0> : tensor<64x64x3x3xf32>) : !torch.vtensor<[64,64,3,3],f32>
Expand All @@ -74,11 +79,22 @@ func.func @forward(%input: !torch.vtensor<[1,64,1,100],f32>) -> (!torch.vtensor<
%int1x1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%output = torch.aten.convolution %input, %weight, %bias, %stride, %int1x1, %int1x1, %true, %int1x1, %int1 : !torch.vtensor<[1,64,1,100],f32>, !torch.vtensor<[64,64,3,3],f32>, !torch.vtensor<[64],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,64,2,200],f32>

%input2 = torch.vtensor.literal(dense<0.0> : tensor<1x9x16x1600xf32>) : !torch.vtensor<[1,9,16,1600],f32>
%weights2 = torch.vtensor.literal(dense<0.0> : tensor<32x9x3x3xf32>) : !torch.vtensor<[32,9,3,3],f32>
return %output : !torch.vtensor<[1,64,2,200],f32>
}

// -----

// CHECK: torch.aten.convolution %{{.*}}
func.func @torch.aten.regular_convolution() -> !torch.vtensor<[1,32,16,1600],f32> {
%false = torch.constant.bool false
%input = torch.vtensor.literal(dense<0.0> : tensor<1x9x16x1600xf32>) : !torch.vtensor<[1,9,16,1600],f32>
%weights = torch.vtensor.literal(dense<0.0> : tensor<32x9x3x3xf32>) : !torch.vtensor<[32,9,3,3],f32>
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int0x0 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
%int1x1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%none = torch.constant.none
%output2 = torch.aten.convolution %input2, %weights2, %none, %int1x1, %int1x1, %int1x1, %false, %int0x0, %int1 : !torch.vtensor<[1,9,16,1600],f32>, !torch.vtensor<[32,9,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,32,16,1600],f32>
%output = torch.aten.convolution %input, %weights, %none, %int1x1, %int1x1, %int1x1, %false, %int0x0, %int1 : !torch.vtensor<[1,9,16,1600],f32>, !torch.vtensor<[32,9,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,32,16,1600],f32>

return %output, %output2 : !torch.vtensor<[1,64,2,200],f32>, !torch.vtensor<[1,32,16,1600],f32>
return %output : !torch.vtensor<[1,32,16,1600],f32>
}

0 comments on commit 5a40ca5

Please sign in to comment.