fix: Adapt torch JIT pass for removeDropout
- Adapt the JIT dropout-removal pass to accommodate multiple dropout schemas
- Include additional test cases to verify the new removal code
gs-olive committed Jan 24, 2023
1 parent ae8b569 commit e7a469d
Showing 2 changed files with 122 additions and 4 deletions.
50 changes: 46 additions & 4 deletions core/lowering/passes/remove_dropout.cpp
@@ -1,15 +1,57 @@
#include "torch/csrc/jit/passes/remove_dropout.h"
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

#include "core/util/prelude.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"

namespace torch_tensorrt {
namespace core {
namespace lowering {
namespace passes {

// Schemas for dropout variants
const std::unordered_set<c10::Symbol> DropoutNodeKinds = {
c10::Symbol::fromQualString("aten::dropout"),
c10::Symbol::fromQualString("aten::dropout_"),
c10::Symbol::fromQualString("aten::feature_dropout"),
c10::Symbol::fromQualString("aten::feature_dropout_"),
c10::Symbol::fromQualString("aten::feature_alpha_dropout"),
c10::Symbol::fromQualString("aten::feature_alpha_dropout_"),
};

void removeDropoutInBlock(torch::jit::Block* block) {
/*
Function adapted from:
torch/csrc/jit/passes/remove_dropout.cpp
Modified for conciseness, documentation, and allowing new variants of dropout operators to be quickly added
*/
std::vector<torch::jit::Node*> dropout_nodes_to_remove;

for (auto node : block->nodes()) {
// Remove dropout for each member block within a node
for (auto block : node->blocks()) {
removeDropoutInBlock(block);
}

// For each node having a dropout-variant Schema, remove the node
if (DropoutNodeKinds.find(node->kind()) != DropoutNodeKinds.end()) {
// Extract input and output tensors of dropout operator
auto input_value = node->inputs()[0];
auto output_value = node->outputs()[0];

output_value->replaceAllUsesWith(input_value);
dropout_nodes_to_remove.push_back(node);
}
}

// Delete dropout nodes
for (auto del_node : dropout_nodes_to_remove) {
del_node->destroy();
}
}

void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph) {
torch::jit::removeDropout(graph);
// Remove all instances of dropout variants from graph
removeDropoutInBlock(graph->block());
torch::jit::EliminateDeadCode(graph);
LOG_GRAPH("Post remove dropout: " << *graph);
}
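
For context, a minimal standalone sketch of how the updated pass behaves (illustrative only, not part of this commit; it assumes libtorch and the Torch-TensorRT lowering headers, e.g. core/lowering/passes/passes.h, are on the include path). It mirrors the harness used in the tests below: parse a small graph containing a dropout node, run RemoveDropout, and the dropout is rewired away while the EliminateDeadCode call inside the pass cleans up the now-unused constants.

#include <iostream>
#include <memory>
#include <string>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>

#include "core/lowering/passes/passes.h"

int main() {
  // Small illustrative graph containing a basic aten::dropout node feeding a relu
  const std::string source_graph = R"IR(
    graph(%x.1):
      %p : float = prim::Constant[value=0.5]()
      %train : bool = prim::Constant[value=0]()
      %y.1 : Tensor = aten::dropout(%x.1, %p, %train)
      %out : Tensor = aten::relu(%y.1)
      return (%out))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, g.get());

  // Removes the aten::dropout node and rewires relu to consume %x.1 directly;
  // dead-code elimination inside the pass then drops the unused constants.
  torch_tensorrt::core::lowering::passes::RemoveDropout(g);

  // Expected to print a graph containing only the aten::relu node
  std::cout << *g << std::endl;
  return 0;
}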

76 changes: 76 additions & 0 deletions tests/core/lowering/test_remove_dropout_pass.cpp
@@ -132,3 +132,79 @@ TEST(LoweringPasses, RemoveFeatureDropoutInplaceLowersCorrectly) {

  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}

TEST(LoweringPasses, RemoveFeatureAlphaDropoutLowersCorrectly) {
  std::string source_graph = R"IR(
    graph(%x.1):
      %3 : float = prim::Constant[value=0.5]()
      %4 : bool = prim::Constant[value=0]()
      %y.1 : Tensor = aten::feature_alpha_dropout(%x.1, %3, %4)
      %11 : Tensor = aten::relu(%y.1)
      return (%11))IR";
  std::string target_graph = R"IR(
    graph(%x.1):
      %11 : Tensor = aten::relu(%x.1)
      return (%11))IR";

  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
  auto sg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, sg.get());
  torch_tensorrt::core::lowering::passes::RemoveDropout(sg);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, tg.get());

  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}

TEST(LoweringPasses, RemoveFeatureAlphaDropoutNestedLowersCorrectly) {
  std::string source_graph = R"IR(
    graph(%x.1):
      %3 : float = prim::Constant[value=0.5]()
      %4 : bool = prim::Constant[value=0]()
      %y.1 : Tensor = aten::feature_alpha_dropout(%x.1, %3, %4)
      %z.1 : Tensor = aten::feature_alpha_dropout(%y.1, %3, %4)
      %12 : Tensor = aten::relu(%z.1)
      return (%12))IR";
  std::string target_graph = R"IR(
    graph(%x.1):
      %11 : Tensor = aten::relu(%x.1)
      return (%11))IR";

  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
  auto sg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, sg.get());
  torch_tensorrt::core::lowering::passes::RemoveDropout(sg);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, tg.get());

  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}

TEST(LoweringPasses, RemoveFeatureAlphaDropoutInplaceLowersCorrectly) {
  std::string source_graph = R"IR(
    graph(%x.1):
      %3 : float = prim::Constant[value=0.5]()
      %4 : bool = prim::Constant[value=0]()
      %y.1 : Tensor = aten::feature_alpha_dropout_(%x.1, %3, %4)
      %11 : Tensor = aten::relu(%y.1)
      return (%11))IR";
  std::string target_graph = R"IR(
    graph(%x.1):
      %11 : Tensor = aten::relu(%x.1)
      return (%11))IR";

  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
  auto sg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, sg.get());
  torch_tensorrt::core::lowering::passes::RemoveDropout(sg);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, tg.get());

  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}
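
As an aside on the intent of the refactor (illustrative only, not part of this commit): supporting an additional dropout schema should now only require adding its symbol to DropoutNodeKinds, plus a matching test. For a hypothetical operator with the same (input, p, train) signature, the change would look like:

// Hypothetical extension sketch: "aten::my_dropout" is an invented schema name
// used purely for illustration and does not exist in PyTorch.
const std::unordered_set<c10::Symbol> DropoutNodeKinds = {
    c10::Symbol::fromQualString("aten::dropout"),
    c10::Symbol::fromQualString("aten::dropout_"),
    c10::Symbol::fromQualString("aten::feature_dropout"),
    c10::Symbol::fromQualString("aten::feature_dropout_"),
    c10::Symbol::fromQualString("aten::feature_alpha_dropout"),
    c10::Symbol::fromQualString("aten::feature_alpha_dropout_"),
    c10::Symbol::fromQualString("aten::my_dropout"), // newly supported variant
};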
