From 79d1883b5d41f22636488a18235c5d27270cccbf Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Wed, 28 Sep 2022 15:59:43 -0700
Subject: [PATCH 1/4] feat: Add converter for einsum operator

- Add einsum converter
- Add test cases to ensure common einsum patterns are correctly converted
- Reflect updated operators in BUILD files
---
 core/conversion/converters/BUILD              |   1 +
 core/conversion/converters/impl/einsum.cpp    |  62 +++++++++++
 tests/core/conversion/converters/BUILD        |   5 +
 .../conversion/converters/test_einsum.cpp     | 104 ++++++++++++++++++
 4 files changed, 172 insertions(+)
 create mode 100644 core/conversion/converters/impl/einsum.cpp
 create mode 100644 tests/core/conversion/converters/test_einsum.cpp

diff --git a/core/conversion/converters/BUILD b/core/conversion/converters/BUILD
index 95dde838dc..354f17a734 100755
--- a/core/conversion/converters/BUILD
+++ b/core/conversion/converters/BUILD
@@ -62,6 +62,7 @@ cc_library(
         "impl/constant_pad.cpp",
         "impl/conv_deconv.cpp",
         "impl/cumsum.cpp",
+        "impl/einsum.cpp",
         "impl/element_wise.cpp",
         "impl/expand.cpp",
         "impl/interpolate.cpp",
diff --git a/core/conversion/converters/impl/einsum.cpp b/core/conversion/converters/impl/einsum.cpp
new file mode 100644
index 0000000000..a8f08ae1e1
--- /dev/null
+++ b/core/conversion/converters/impl/einsum.cpp
@@ -0,0 +1,62 @@
+#include "NvInfer.h"
+#include "core/conversion/converters/converters.h"
+#include "core/conversion/tensorcontainer/TensorContainer.h"
+#include "core/util/prelude.h"
+#include "torch/torch.h"
+
+#include <ATen/ATen.h>
+#include <vector>
+
+namespace torch_tensorrt {
+namespace core {
+namespace conversion {
+namespace converters {
+namespace impl {
+namespace {
+
+auto stack_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
+    {"aten::einsum(str equation, Tensor[] tensors) -> (Tensor)",
+     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+       // Extract equation and list of tensors
+       auto equation = args[0].unwrapToString();
+       auto in = args[1].IValue()->toListRef();
+
+       std::vector<nvinfer1::ITensor*> tensors;
+
+       // Populate vector of ITensor pointers
+       for (auto t : in) {
+         nvinfer1::ITensor* itensor;
+
+         // Tensor is either an ITensor (wrapped) or PyTorch Tensor
+         if (t.isTensor()) {
+           auto weight = Weights(ctx, t.toTensor());
+
+           auto const_layer = ctx->net->addConstant(weight.shape, weight.data);
+           TORCHTRT_CHECK(const_layer, "Unable to create constant layer from node: " << *n);
+
+           itensor = const_layer->getOutput(0);
+         } else {
+           auto cont = t.toCustomClass<TensorContainer>();
+           itensor = cont->tensor();
+         }
+
+         tensors.push_back(itensor);
+       }
+
+       // Add Tensor-RT Einsum layer
+       auto einsum_layer = ctx->net->addEinsum(tensors.data(), tensors.size(), equation.c_str());
+       TORCHTRT_CHECK(einsum_layer, "Unable to create einsum layer from node: " << *n);
+
+       einsum_layer->setName(util::node_info(n).c_str());
+       auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], einsum_layer->getOutput(0));
+
+       LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+       return true;
+     }});
+
+} // namespace
+} // namespace impl
+} // namespace converters
+} // namespace conversion
+} // namespace core
+} // namespace torch_tensorrt
diff --git a/tests/core/conversion/converters/BUILD b/tests/core/conversion/converters/BUILD
index 5246de4cf1..98630b6dc1 100644
--- a/tests/core/conversion/converters/BUILD
+++ b/tests/core/conversion/converters/BUILD
@@ -51,6 +51,10 @@ converter_test(
     name = "test_cumsum",
 )
 
+converter_test(
+    name = "test_einsum",
+)
+
 converter_test(
     name = "test_element_wise",
 )
@@ -152,6 +156,7 @@ test_suite(
         ":test_conv_deconv",
         ":test_copy",
         ":test_cumsum",
+        ":test_einsum",
         ":test_element_wise",
         ":test_expand",
         ":test_instance_norm",
diff --git a/tests/core/conversion/converters/test_einsum.cpp b/tests/core/conversion/converters/test_einsum.cpp
new file mode 100644
index 0000000000..5702f610e7
--- /dev/null
+++ b/tests/core/conversion/converters/test_einsum.cpp
@@ -0,0 +1,104 @@
+#include <string>
+#include "core/compiler.h"
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/csrc/jit/ir/irparser.h"
+
+TEST(Converters, ATenEinsumConvertsMatMulCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor, %x.2 : Tensor):
+        %0 : str = prim::Constant[value="ij,jk->ik"]()
+        %3 : Tensor[] = prim::ListConstruct(%x.1, %x.2)
+        %4 : Tensor = aten::einsum(%0, %3)
+        return (%4))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  // Test matrix multiplication via einsum
+  auto in_0 = at::rand({12, 17}, {at::kCUDA});
+  auto in_1 = at::rand({17, 35}, {at::kCUDA});
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in_0, in_1});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in_0, in_1});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
+TEST(Converters, ATenEinsumConvertsElementwiseProdCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor, %x.2 : Tensor):
+        %0 : str = prim::Constant[value="abcd,abcd->abcd"]()
+        %3 : Tensor[] = prim::ListConstruct(%x.1, %x.2)
+        %4 : Tensor = aten::einsum(%0, %3)
+        return (%4))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  // Test elementwise tensor product via einsum
+  auto in_0 = at::rand({7, 5, 2, 8}, {at::kCUDA});
+  auto in_1 = at::rand({7, 5, 2, 8}, {at::kCUDA});
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in_0, in_1});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in_0, in_1});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
+TEST(Converters, ATenEinsumConvertsTransposeCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %0 : str = prim::Constant[value="jk->kj"]()
+        %3 : Tensor[] = prim::ListConstruct(%x.1)
+        %4 : Tensor = aten::einsum(%0, %3)
+        return (%4))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  // Test single-matrix transpose via einsum
+  auto in_0 = at::rand({25, 28}, {at::kCUDA});
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in_0});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in_0});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
+TEST(Converters, ATenEinsumConvertsVectorsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor, %x.2 : Tensor):
+        %0 : str = prim::Constant[value="a,b->ab"]()
+        %3 : Tensor[] = prim::ListConstruct(%x.1, %x.2)
+        %4 : Tensor = aten::einsum(%0, %3)
+        return (%4))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  // Test vector outer product via einsum
+  auto in_0 = at::rand({25}, {at::kCUDA});
+  auto in_1 = at::rand({4}, {at::kCUDA});
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in_0, in_1});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in_0, in_1});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
\ No newline at end of file

From 498ac590087806e822028603dcc83eee3f48681e Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Wed, 28 Sep 2022 16:49:32 -0700
Subject: [PATCH 2/4] Remove redundant imports

---
 core/conversion/converters/impl/einsum.cpp | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/core/conversion/converters/impl/einsum.cpp b/core/conversion/converters/impl/einsum.cpp
index a8f08ae1e1..daf14cbefd 100644
--- a/core/conversion/converters/impl/einsum.cpp
+++ b/core/conversion/converters/impl/einsum.cpp
@@ -1,10 +1,7 @@
-#include "NvInfer.h"
 #include "core/conversion/converters/converters.h"
 #include "core/conversion/tensorcontainer/TensorContainer.h"
 #include "core/util/prelude.h"
-#include "torch/torch.h"
 
-#include <ATen/ATen.h>
 #include <vector>
 
 namespace torch_tensorrt {

From 3be2f4c21f8a7d5bbdcc61d7f4d844e25aa40b0f Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Wed, 28 Sep 2022 16:52:21 -0700
Subject: [PATCH 3/4] Add newline at end of file

---
 tests/core/conversion/converters/test_einsum.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/core/conversion/converters/test_einsum.cpp b/tests/core/conversion/converters/test_einsum.cpp
index 5702f610e7..ff7ba201ff 100644
--- a/tests/core/conversion/converters/test_einsum.cpp
+++ b/tests/core/conversion/converters/test_einsum.cpp
@@ -101,4 +101,4 @@ TEST(Converters, ATenEinsumConvertsVectorsCorrectly) {
 
   ASSERT_TRUE(
       torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
-}
\ No newline at end of file
+}

From 371fc3868e68f0c8ae04a530cc164ac51056835c Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Thu, 29 Sep 2022 12:46:43 -0700
Subject: [PATCH 4/4] Renamed registration and updated a comment

---
 core/conversion/converters/impl/einsum.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/core/conversion/converters/impl/einsum.cpp b/core/conversion/converters/impl/einsum.cpp
index daf14cbefd..fb031f6c38 100644
--- a/core/conversion/converters/impl/einsum.cpp
+++ b/core/conversion/converters/impl/einsum.cpp
@@ -11,7 +11,7 @@ namespace converters {
 namespace impl {
 namespace {
 
-auto stack_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
+auto einsum_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
     {"aten::einsum(str equation, Tensor[] tensors) -> (Tensor)",
      [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
        // Extract equation and list of tensors
@@ -40,7 +40,7 @@ auto stack_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().patt
          tensors.push_back(itensor);
        }
 
-       // Add Tensor-RT Einsum layer
+       // Add TensorRT Einsum layer
        auto einsum_layer = ctx->net->addEinsum(tensors.data(), tensors.size(), equation.c_str());
        TORCHTRT_CHECK(einsum_layer, "Unable to create einsum layer from node: " << *n);
 