Merge pull request #2105 from andi4191/anurag.dixit/aten_tile

feat: Added support for aten::tile converter

peri044 authored Aug 5, 2023
2 parents a052cf0 + b7b2725 commit 8c62fca
Showing 9 changed files with 193 additions and 0 deletions.
1 change: 1 addition & 0 deletions core/lowering/lowering.cpp
@@ -152,6 +152,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::I
  passes::ReplaceScalarImplicit(g);
  passes::RewriteInputsWithParams(g, params);
  passes::ReplaceAtenPad(g);
  passes::ReplaceTileWithRepeat(g);
  LOG_GRAPH(*g);
}

1 change: 1 addition & 0 deletions core/lowering/passes/BUILD
@@ -31,6 +31,7 @@ cc_library(
"replace_aten_pad.cpp",
"rewrite_inputs_with_params.cpp",
"silu_to_sigmoid_multiplication.cpp",
"tile_to_repeat.cpp",
"unpack_addmm.cpp",
"unpack_batch_norm.cpp",
"unpack_hardsigmoid.cpp",
1 change: 1 addition & 0 deletions core/lowering/passes/CMakeLists.txt
@@ -6,6 +6,7 @@ target_sources(${lib_name}
"${CMAKE_CURRENT_SOURCE_DIR}/linear_to_addmm.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/module_fallback.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/op_aliasing.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/tile_to_repeat.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/reduce_gelu.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/reduce_remainder.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/reduce_to.cpp"
1 change: 1 addition & 0 deletions core/lowering/passes/passes.h
@@ -51,6 +51,7 @@ void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph, std::st
void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
void ReplaceScalarImplicit(std::shared_ptr<torch::jit::Graph>& graph);
void ReplaceAtenPad(std::shared_ptr<torch::jit::Graph>& graph);
void ReplaceTileWithRepeat(std::shared_ptr<torch::jit::Graph>& graph);

// utility functions exposed for testing
std::string unmangle_cls_name(const std::string& name);
25 changes: 25 additions & 0 deletions core/lowering/passes/tile_to_repeat.cpp
@@ -0,0 +1,25 @@
#include "core/util/prelude.h"
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

namespace torch_tensorrt {
namespace core {
namespace lowering {
namespace passes {
void ReplaceTileWithRepeat(std::shared_ptr<torch::jit::Graph>& graph) {
  std::string tile_pattern = R"IR(
    graph(%input, %1):
      %2 = aten::tile(%input, %1)
      return (%2))IR";
  std::string repeat_pattern = R"IR(
    graph(%input, %1):
      %2 = aten::repeat(%input, %1)
      return (%2))IR";
  torch::jit::SubgraphRewriter tile_to_repeat;
  tile_to_repeat.RegisterRewritePattern(tile_pattern, repeat_pattern);
  tile_to_repeat.runOnGraph(graph);
  LOG_GRAPH("Mapping tile -> repeat: " << *graph);
}
} // namespace passes
} // namespace lowering
} // namespace core
} // namespace torch_tensorrt
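The pass is a direct one-for-one swap, which is sound because aten::tile and aten::repeat compute the same result whenever the repeats list is at least as long as the input's rank (both ops prepend new dimensions when the list is longer than the rank). A minimal libtorch sketch of that equivalence, for illustration only and not part of this commit:

#include <torch/torch.h>
#include <iostream>

int main() {
  // Mirrors the {1, 3} inputs used by the converter tests in this PR.
  auto x = torch::randint(1, 10, {1, 3});

  // Repeats list length equal to the tensor rank: tile and repeat agree.
  std::cout << std::boolalpha
            << torch::equal(x.tile({4, 1}), x.repeat({4, 1})) << '\n';  // true

  // Repeats list longer than the rank: both ops prepend new dimensions.
  std::cout << torch::equal(x.tile({4, 1, 2}), x.repeat({4, 1, 2})) << '\n';  // true
  return 0;
}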
7 changes: 7 additions & 0 deletions docsrc/contributors/lowering.rst
@@ -205,3 +205,10 @@ Unroll Loops
`torch/csrc/jit/passes/loop_unrolling.h <https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/passes/loop_unrolling.h>`_

Unrolls the operations of compatible loops (e.g. sufficiently short) so that you only have to go through the loop once.

Replace Tile with Repeat
***************************************

`Torch-TensorRT/core/lowering/passes/tile_to_repeat.cpp <https://github.com/pytorch/TensorRT/blob/master/core/lowering/passes/tile_to_repeat.cpp>`_

Replaces ``aten::tile`` operators with equivalent ``aten::repeat`` operators, which already have a converter, so ``aten::tile`` is supported through lowering rather than a dedicated converter.
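To see the pass in action, it can be run directly on a parsed graph, exactly as the unit tests in this PR do; a minimal sketch (the statements belong inside a function such as a test body):

#include <memory>
#include "core/lowering/passes/passes.h"
#include "torch/csrc/jit/ir/irparser.h"

// Parse a TorchScript graph containing aten::tile, then rewrite it in place.
auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(R"IR(
    graph(%input, %dims):
      %out : Tensor = aten::tile(%input, %dims)
      return (%out))IR", g.get());
torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g);
// g now contains aten::repeat in place of aten::tile.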
126 changes: 126 additions & 0 deletions tests/core/conversion/converters/test_expand.cpp
@@ -1,6 +1,7 @@
#include <torch/torch.h>
#include <string>
#include "core/compiler.h"
#include "core/lowering/passes/passes.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"
@@ -670,6 +671,131 @@ TEST(Converters, ATenRepeatInterleave3dScalarNoDimConvertsCorrectlyWithDynamicIn
  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenTileConvertsCorrectly) {
  const auto graph = R"IR(
        graph(%x.1 : Tensor):
          %2 : int[] = prim::Constant[value=[4, 1]]()
          %3 : Tensor = aten::tile(%x.1, %2)
          return (%3))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());
  torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g);

  auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(jit_in);
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, ATenTileRepeatRankConvertsCorrectly) {
  const auto graph = R"IR(
        graph(%x.1 : Tensor):
          %2 : int[] = prim::Constant[value=[4, 1, 2]]()
          %3 : Tensor = aten::tile(%x.1, %2)
          return (%3))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());
  torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g);

  auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(jit_in);
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, ATenTileConvertsCorrectlyWithDynamicInput) {
  const auto graph = R"IR(
        graph(%x.1 : Tensor):
          %2 : int[] = prim::Constant[value=[4, 1]]()
          %3 : Tensor = aten::tile(%x.1, %2)
          return (%3))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());
  torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g);

  auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(jit_in);
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in});

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, ATenTile3dConvertsCorrectly) {
  const auto graph = R"IR(
        graph(%x.1 : Tensor):
          %2 : int[] = prim::Constant[value=[2, 2, 2]]()
          %3 : Tensor = aten::tile(%x.1, %2)
          return (%3))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());
  torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g);

  auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(jit_in);
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, ATenTile3dConvertsCorrectlyWithDynamicInput) {
  const auto graph = R"IR(
        graph(%x.1 : Tensor):
          %2 : int[] = prim::Constant[value=[2, 2, 2]]()
          %3 : Tensor = aten::tile(%x.1, %2)
          return (%3))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());
  torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g);

  auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(jit_in);
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in});

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, ATenMeshGridConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%x : Tensor, %y : Tensor, %z : Tensor):
5 changes: 5 additions & 0 deletions tests/core/lowering/BUILD
@@ -103,6 +103,10 @@ lowering_test(
name = "test_replace_aten_pad_pass",
)

lowering_test(
name = "test_tile_to_repeat_pass",
)

test_suite(
name = "lowering_tests",
tests = [
@@ -122,6 +126,7 @@ test_suite(
":test_remove_unnecessary_casts",
":test_replace_aten_pad_pass",
":test_rewrite_inputs_with_params",
":test_tile_to_repeat_pass",
":test_unpack_hardsigmoid",
":test_unpack_hardswish",
":test_unpack_reduce_ops",
26 changes: 26 additions & 0 deletions tests/core/lowering/test_tile_to_repeat_pass.cpp
@@ -0,0 +1,26 @@
#include <string>
#include "core/compiler.h"
#include "core/lowering/passes/passes.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"
#include "torch/csrc/jit/ir/subgraph_matcher.h"

TEST(LoweringPasses, TileToRepeatCorrectly) {
  std::string source_graph = R"IR(
    graph(%input, %dim):
      %o : Tensor = aten::tile(%input, %dim)
      return (%o))IR";
  std::string target_graph = R"IR(
    graph(%input, %dim):
      %o : Tensor = aten::repeat(%input, %dim)
      return (%o))IR";
  auto sg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, sg.get());
  torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(sg);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, tg.get());

  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}
