fix: Replace RemoveDropout lowering pass implementation with modified JIT pass #1589

Merged: 2 commits, Jan 25, 2023
126 changes: 46 additions & 80 deletions core/lowering/passes/remove_dropout.cpp
@@ -1,4 +1,4 @@
-#include <torch/csrc/jit/passes/subgraph_rewrite.h>
+#include "torch/csrc/jit/passes/dead_code_elimination.h"

#include "core/util/prelude.h"

@@ -7,86 +7,52 @@ namespace core {
 namespace lowering {
 namespace passes {

-void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph) {
-  std::string dropout_pattern = R"IR(
-    graph(%input, %4, %5):
-      %6 = aten::dropout(%input, %4, %5)
-      return (%6))IR";
-  std::string no_dropout_pattern = R"IR(
-    graph(%input, %4, %5):
-      return (%input))IR";
-
-  torch::jit::SubgraphRewriter remove_dropout;
-  remove_dropout.RegisterRewritePattern(dropout_pattern, no_dropout_pattern);
-  remove_dropout.runOnGraph(graph);
-
-  std::string dropout_inplace_pattern = R"IR(
-    graph(%input, %4, %5):
-      %6 = aten::dropout_(%input, %4, %5)
-      return (%6))IR";
-  std::string no_dropout_inplace_pattern = R"IR(
-    graph(%input, %4, %5):
-      return (%input))IR";
-
-  torch::jit::SubgraphRewriter remove_dropout_inplace_pattern;
-  remove_dropout_inplace_pattern.RegisterRewritePattern(dropout_inplace_pattern, no_dropout_inplace_pattern);
-  remove_dropout_inplace_pattern.runOnGraph(graph);
-
-  // remove feature_dropout
-  std::string feature_dropout_pattern = R"IR(
-    graph(%input, %4, %5):
-      %6 = aten::feature_dropout(%input, %4, %5)
-      return (%6))IR";
-  std::string no_feature_dropout_pattern = R"IR(
-    graph(%input, %4, %5):
-      return (%input))IR";
-
-  torch::jit::SubgraphRewriter remove_feature_dropout_pattern;
-  remove_feature_dropout_pattern.RegisterRewritePattern(feature_dropout_pattern, no_feature_dropout_pattern);
-  remove_feature_dropout_pattern.runOnGraph(graph);
-
-  // remove feature_dropout inplace
-  std::string feature_dropout_inplace_pattern = R"IR(
-    graph(%input, %4, %5):
-      %6 = aten::feature_dropout_(%input, %4, %5)
-      return (%6))IR";
-  std::string no_feature_dropout_inplace_pattern = R"IR(
-    graph(%input, %4, %5):
-      return (%input))IR";
-
-  torch::jit::SubgraphRewriter remove_feature_dropout_inplace_pattern;
-  remove_feature_dropout_inplace_pattern.RegisterRewritePattern(
-      feature_dropout_inplace_pattern, no_feature_dropout_inplace_pattern);
-  remove_feature_dropout_inplace_pattern.runOnGraph(graph);
-
-  // remove feature_alpha_dropout
-  std::string feature_alpha_dropout_pattern = R"IR(
-    graph(%input, %4, %5):
-      %6 = aten::feature_alpha_dropout(%input, %4, %5)
-      return (%6))IR";
-  std::string no_feature_alpha_dropout_pattern = R"IR(
-    graph(%input, %4, %5):
-      return (%input))IR";
-
-  torch::jit::SubgraphRewriter remove_feature_alpha_dropout_pattern;
-  remove_feature_alpha_dropout_pattern.RegisterRewritePattern(
-      feature_alpha_dropout_pattern, no_feature_alpha_dropout_pattern);
-  remove_feature_alpha_dropout_pattern.runOnGraph(graph);
-
-  // remove feature_alpha_dropout inplace
-  std::string feature_alpha_dropout_inplace_pattern = R"IR(
-    graph(%input, %4, %5):
-      %6 = aten::feature_alpha_dropout_(%input, %4, %5)
-      return (%6))IR";
-  std::string no_feature_alpha_dropout_inplace_pattern = R"IR(
-    graph(%input, %4, %5):
-      return (%input))IR";
-
-  torch::jit::SubgraphRewriter remove_feature_alpha_dropout_inplace_pattern;
-  remove_feature_alpha_dropout_inplace_pattern.RegisterRewritePattern(
-      feature_alpha_dropout_inplace_pattern, no_feature_alpha_dropout_inplace_pattern);
-  remove_feature_alpha_dropout_inplace_pattern.runOnGraph(graph);
+// Schemas for dropout variants
+const std::unordered_set<c10::Symbol> DropoutNodeKinds = {
+    c10::Symbol::fromQualString("aten::dropout"),
+    c10::Symbol::fromQualString("aten::dropout_"),
+    c10::Symbol::fromQualString("aten::feature_dropout"),
+    c10::Symbol::fromQualString("aten::feature_dropout_"),
+    c10::Symbol::fromQualString("aten::feature_alpha_dropout"),
+    c10::Symbol::fromQualString("aten::feature_alpha_dropout_"),
+};
+
+void removeDropoutInBlock(torch::jit::Block* block) {
+  /*
+  Function adapted from:
+  torch/csrc/jit/passes/remove_dropout.cpp
+
+  Modified for conciseness, documentation, and to allow new dropout variants to be added quickly
+  */
+  std::vector<torch::jit::Node*> dropout_nodes_to_remove;
+
+  for (auto node : block->nodes()) {
+    // Remove dropout for each member block within a node
+    for (auto block : node->blocks()) {
+      removeDropoutInBlock(block);
+    }
+
+    // For each node having a dropout-variant schema, remove the node
+    if (DropoutNodeKinds.find(node->kind()) != DropoutNodeKinds.end()) {
+      // Extract input and output tensors of dropout operator
+      auto input_value = node->inputs()[0];
+      auto output_value = node->outputs()[0];
+
+      output_value->replaceAllUsesWith(input_value);
+      dropout_nodes_to_remove.push_back(node);
+    }
+  }
+
+  // Delete dropout nodes
+  for (auto del_node : dropout_nodes_to_remove) {
+    del_node->destroy();
+  }
+}
+
+void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph) {
+  // Remove all instances of dropout variants from graph
+  removeDropoutInBlock(graph->block());
+  torch::jit::EliminateDeadCode(graph);
+  LOG_GRAPH("Post remove dropout: " << *graph);
+}

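Because the new pass is driven by a set of node kinds rather than per-operator rewrite patterns, covering another dropout variant is a one-line change per schema. As a hypothetical illustration (not part of this PR), extending the set to also strip aten::alpha_dropout and its in-place form, which take the same (Tensor input, float p, bool train) arguments, might look like the sketch below; under the old SubgraphRewriter approach each variant required its own multi-line pattern pair.

#include <unordered_set>

#include "torch/csrc/jit/ir/ir.h"  // brings in c10::Symbol (include path assumed)

// Hypothetical extension of the set from the diff above: the last two
// entries are additions shown for illustration only.
const std::unordered_set<c10::Symbol> DropoutNodeKinds = {
    c10::Symbol::fromQualString("aten::dropout"),
    c10::Symbol::fromQualString("aten::dropout_"),
    c10::Symbol::fromQualString("aten::feature_dropout"),
    c10::Symbol::fromQualString("aten::feature_dropout_"),
    c10::Symbol::fromQualString("aten::feature_alpha_dropout"),
    c10::Symbol::fromQualString("aten::feature_alpha_dropout_"),
    c10::Symbol::fromQualString("aten::alpha_dropout"),   // hypothetical addition
    c10::Symbol::fromQualString("aten::alpha_dropout_"),  // hypothetical addition
};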
52 changes: 52 additions & 0 deletions tests/core/lowering/test_remove_dropout_pass.cpp
@@ -32,6 +32,32 @@ TEST(LoweringPasses, RemoveDropoutLowersCorrectly) {
   ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
 }

+TEST(LoweringPasses, RemoveDropoutNestedLowersCorrectly) {
+  std::string source_graph = R"IR(
+    graph(%x.1):
+      %3 : float = prim::Constant[value=0.5]()
+      %4 : bool = prim::Constant[value=0]()
+      %y.1 : Tensor = aten::dropout(%x.1, %3, %4)
+      %z.1 : Tensor = aten::dropout(%y.1, %3, %4)
+      %12 : Tensor = aten::relu(%z.1)
+      return (%12))IR";
+  std::string target_graph = R"IR(
+    graph(%x.1):
+      %11 : Tensor = aten::relu(%x.1)
+      return (%11))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, sg.get());
+  torch_tensorrt::core::lowering::passes::RemoveDropout(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, tg.get());
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}

 TEST(LoweringPasses, RemoveDropoutInplaceLowersCorrectly) {
   std::string source_graph = R"IR(
     graph(%x.1):
@@ -132,6 +158,32 @@ TEST(LoweringPasses, RemoveFeatureAlphaDropoutLowersCorrectly) {
   ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
 }

+TEST(LoweringPasses, RemoveFeatureAlphaDropoutNestedLowersCorrectly) {
+  std::string source_graph = R"IR(
+    graph(%x.1):
+      %3 : float = prim::Constant[value=0.5]()
+      %4 : bool = prim::Constant[value=0]()
+      %y.1 : Tensor = aten::feature_alpha_dropout(%x.1, %3, %4)
+      %z.1 : Tensor = aten::feature_alpha_dropout(%y.1, %3, %4)
+      %12 : Tensor = aten::relu(%z.1)
+      return (%12))IR";
+  std::string target_graph = R"IR(
+    graph(%x.1):
+      %11 : Tensor = aten::relu(%x.1)
+      return (%11))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, sg.get());
+  torch_tensorrt::core::lowering::passes::RemoveDropout(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, tg.get());
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}

 TEST(LoweringPasses, RemoveFeatureAlphaDropoutInplaceLowersCorrectly) {
   std::string source_graph = R"IR(
     graph(%x.1):
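Beyond the test suite, the pass can also be exercised directly, mirroring the calls the tests above make. A minimal driver sketch follows (illustrative only: the main function and IR literal are not part of the PR, and the header paths are assumed from the repository layout shown in this diff).

#include <iostream>
#include <memory>
#include <string>

#include "core/lowering/passes/passes.h"  // assumed declaration site of RemoveDropout
#include "torch/csrc/jit/ir/irparser.h"

int main() {
  // Graph with a single dropout feeding a relu.
  const std::string ir = R"IR(
    graph(%x.1):
      %3 : float = prim::Constant[value=0.5]()
      %4 : bool = prim::Constant[value=0]()
      %y.1 : Tensor = aten::dropout(%x.1, %3, %4)
      %5 : Tensor = aten::relu(%y.1)
      return (%5))IR";

  auto graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(ir, graph.get());

  // Rewires the relu to consume %x.1 directly, destroys the dropout node,
  // and dead-code-eliminates the now-unused constants.
  torch_tensorrt::core::lowering::passes::RemoveDropout(graph);

  std::cout << *graph << std::endl;  // expect only the relu to remain
  return 0;
}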