feat: Rsqrt lowering pass (#1394)
* feat: Add lowering pass for rsqrt operator

- Add unpack rsqrt lowering pass
- Add test cases for positive inputs, int and float
- Add references to new function in headers and BUILD files

* Add UnpackRsqrt to lowering passes list
gs-olive authored Oct 12, 2022
1 parent e27103a commit 85e5e99
Showing 6 changed files with 76 additions and 0 deletions.
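The new pass rests on the identity rsqrt(x) = 1 / sqrt(x): the fused aten::rsqrt node is rewritten into aten::sqrt followed by aten::reciprocal. As a minimal standalone sketch (not part of the patch; shape and tolerances here are illustrative), the equivalence can be checked directly with the ATen API:

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Positive inputs, mirroring the test's [0.01, 1.01] range to keep sqrt real.
  auto x = at::rand({2, 3, 5, 7}) + 0.01;
  auto fused = at::rsqrt(x);                    // single fused op
  auto unpacked = at::reciprocal(at::sqrt(x));  // the form the pass lowers to
  std::cout << at::allclose(fused, unpacked, /*rtol=*/1e-5, /*atol=*/2e-6) << std::endl;
}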
1 change: 1 addition & 0 deletions core/lowering/lowering.cpp
@@ -60,6 +60,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
  passes::UnpackAddMM(g);
  // passes::UnpackBatchNorm(g);
  passes::UnpackLogSoftmax(g);
  passes::UnpackRsqrt(g);
  passes::UnpackStd(g);
  passes::UnpackVar(g);
  passes::RemoveNOPs(g);
1 change: 1 addition & 0 deletions core/lowering/passes/BUILD
@@ -33,6 +33,7 @@ cc_library(
"unpack_hardsigmoid.cpp",
"unpack_hardswish.cpp",
"unpack_log_softmax.cpp",
"unpack_rsqrt.cpp",
"unpack_std.cpp",
"unpack_var.cpp",
"view_to_reshape.cpp",
1 change: 1 addition & 0 deletions core/lowering/passes/CMakeLists.txt
@@ -20,6 +20,7 @@ target_sources(${lib_name}
"${CMAKE_CURRENT_SOURCE_DIR}/unpack_hardsigmoid.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/unpack_hardswish.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/unpack_log_softmax.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/unpack_rsqrt.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/unpack_std.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/unpack_var.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/view_to_reshape.cpp"
1 change: 1 addition & 0 deletions core/lowering/passes/passes.h
@@ -33,6 +33,7 @@ void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackRsqrt(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackStd(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackVar(std::shared_ptr<torch::jit::Graph>& graph);
void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph);
30 changes: 30 additions & 0 deletions core/lowering/passes/unpack_rsqrt.cpp
@@ -0,0 +1,30 @@
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

#include "core/util/prelude.h"

namespace torch_tensorrt {
namespace core {
namespace lowering {
namespace passes {

void UnpackRsqrt(std::shared_ptr<torch::jit::Graph>& graph) {
  std::string rsqrt_pattern = R"IR(
    graph(%1):
      %out: Tensor = aten::rsqrt(%1)
      return (%out))IR";
  std::string unpacked_pattern = R"IR(
    graph(%1):
      %intermediate: Tensor = aten::sqrt(%1)
      %out: Tensor = aten::reciprocal(%intermediate)
      return (%out))IR";

  torch::jit::SubgraphRewriter rsqrt_rewriter;
  rsqrt_rewriter.RegisterRewritePattern(rsqrt_pattern, unpacked_pattern);
  rsqrt_rewriter.runOnGraph(graph);
  LOG_GRAPH("Post unpack rsqrt: " << *graph);
}

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace torch_tensorrt
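For context, the pass is a thin wrapper over torch::jit::SubgraphRewriter. A self-contained sketch of the same mechanism (assuming a LibTorch build; the includes and driver main are illustrative, not part of the patch):

#include <iostream>
#include <memory>

#include "torch/csrc/jit/ir/irparser.h"
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

int main() {
  // Parse a graph containing the fused op.
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(R"IR(
    graph(%x : Tensor):
      %out : Tensor = aten::rsqrt(%x)
      return (%out))IR",
      g.get());

  // Register the same source/target patterns the pass uses and run the rewrite.
  torch::jit::SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(
      R"IR(
        graph(%1):
          %out : Tensor = aten::rsqrt(%1)
          return (%out))IR",
      R"IR(
        graph(%1):
          %intermediate : Tensor = aten::sqrt(%1)
          %out : Tensor = aten::reciprocal(%intermediate)
          return (%out))IR");
  rewriter.runOnGraph(g);

  std::cout << *g;  // aten::rsqrt now appears as aten::sqrt + aten::reciprocal
}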
42 changes: 42 additions & 0 deletions tests/core/lowering/test_unpack_reduce_ops.cpp
@@ -204,3 +204,45 @@ TEST(LoweringPasses, UnpackStdUnbiasedKeepDimsLowersCorrectly) {
  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
}

TEST(LoweringPasses, UnpackRsqrtLowersCorrectly) {
  const auto graph = R"IR(
    graph(%x.1 : Tensor):
      %2 : Tensor = aten::rsqrt(%x.1)
      return (%2))IR";

  // Shift the range to [0.01, 1.01] to keep inputs positive and avoid NaN from sqrt of negatives
  auto in = at::rand({2, 3, 5, 7}, {at::kCUDA}) + 0.01;

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
  torch_tensorrt::core::lowering::passes::UnpackRsqrt(g);
  torch::jit::EliminateCommonSubexpression(g);
  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
}

TEST(LoweringPasses, UnpackRsqrtIntLowersCorrectly) {
  const auto graph = R"IR(
    graph(%x.1 : Tensor):
      %2 : Tensor = aten::rsqrt(%x.1)
      return (%2))IR";

  // Make range of ints [1, 10]
  auto in = at::randint(1, 11, {2, 3, 5, 7}, {at::kCUDA});

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
  torch_tensorrt::core::lowering::passes::UnpackRsqrt(g);
  torch::jit::EliminateCommonSubexpression(g);
  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
}
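The integer test works because both forms promote integral inputs to floating point before taking the root. A quick CPU-side illustration (standalone sketch, not part of the test suite):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  auto in = at::randint(1, 11, {2, 3});          // int64 inputs in [1, 10]
  auto fused = at::rsqrt(in);                    // promoted to float internally
  auto unpacked = at::reciprocal(at::sqrt(in));  // same promotion, in two steps
  std::cout << (fused.scalar_type() == unpacked.scalar_type()) << " "
            << at::allclose(fused, unpacked) << std::endl;
}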
