diff --git a/core/lowering/passes/remove_dropout.cpp b/core/lowering/passes/remove_dropout.cpp
index 54baea22b7..80bef0cb0b 100644
--- a/core/lowering/passes/remove_dropout.cpp
+++ b/core/lowering/passes/remove_dropout.cpp
@@ -1,15 +1,57 @@
-#include "torch/csrc/jit/passes/remove_dropout.h"
-#include
-
 #include "core/util/prelude.h"
+#include "torch/csrc/jit/passes/dead_code_elimination.h"
 
 namespace torch_tensorrt {
 namespace core {
 namespace lowering {
 namespace passes {
 
+// Schemas for dropout variants
+const std::unordered_set<c10::Symbol> DropoutNodeKinds = {
+    c10::Symbol::fromQualString("aten::dropout"),
+    c10::Symbol::fromQualString("aten::dropout_"),
+    c10::Symbol::fromQualString("aten::feature_dropout"),
+    c10::Symbol::fromQualString("aten::feature_dropout_"),
+    c10::Symbol::fromQualString("aten::feature_alpha_dropout"),
+    c10::Symbol::fromQualString("aten::feature_alpha_dropout_"),
+};
+
+void removeDropoutInBlock(torch::jit::Block* block) {
+  /*
+    Function adapted from:
+    torch/csrc/jit/passes/remove_dropout.cpp
+
+    Modified for conciseness, documentation, and allowing new variants of dropout operators to be quickly added
+  */
+  std::vector<torch::jit::Node*> dropout_nodes_to_remove;
+
+  for (auto node : block->nodes()) {
+    // Remove dropout for each member block within a node
+    for (auto block : node->blocks()) {
+      removeDropoutInBlock(block);
+    }
+
+    // For each node having a dropout-variant Schema, remove the node
+    if (DropoutNodeKinds.find(node->kind()) != DropoutNodeKinds.end()) {
+      // Extract input and output tensors of dropout operator
+      auto input_value = node->inputs()[0];
+      auto output_value = node->outputs()[0];
+
+      output_value->replaceAllUsesWith(input_value);
+      dropout_nodes_to_remove.push_back(node);
+    }
+  }
+
+  // Delete dropout nodes
+  for (auto del_node : dropout_nodes_to_remove) {
+    del_node->destroy();
+  }
+}
+
 void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph) {
-  torch::jit::removeDropout(graph);
+  // Remove all instances of dropout variants from graph
+  removeDropoutInBlock(graph->block());
+  torch::jit::EliminateDeadCode(graph);
   LOG_GRAPH("Post remove dropout: " << *graph);
 }
 
diff --git a/tests/core/lowering/test_remove_dropout_pass.cpp b/tests/core/lowering/test_remove_dropout_pass.cpp
index 615dfdefcc..76e85d661a 100644
--- a/tests/core/lowering/test_remove_dropout_pass.cpp
+++ b/tests/core/lowering/test_remove_dropout_pass.cpp
@@ -132,3 +132,79 @@ TEST(LoweringPasses, RemoveFeatureDropoutInplaceLowersCorrectly) {
 
   ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
 }
+
+TEST(LoweringPasses, RemoveFeatureAlphaDropoutLowersCorrectly) {
+  std::string source_graph = R"IR(
+      graph(%x.1):
+        %3 : float = prim::Constant[value=0.5]()
+        %4 : bool = prim::Constant[value=0]()
+        %y.1 : Tensor = aten::feature_alpha_dropout(%x.1, %3, %4)
+        %11 : Tensor = aten::relu(%y.1)
+        return (%11))IR";
+  std::string target_graph = R"IR(
+      graph(%x.1):
+        %11 : Tensor = aten::relu(%x.1)
+        return (%11))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, sg.get());
+  torch_tensorrt::core::lowering::passes::RemoveDropout(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, tg.get());
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
+
+TEST(LoweringPasses, RemoveFeatureAlphaDropoutNestedLowersCorrectly) {
+  std::string source_graph = R"IR(
+      graph(%x.1):
+        %3 : float = prim::Constant[value=0.5]()
+        %4 : bool = prim::Constant[value=0]()
+        %y.1 : Tensor = aten::feature_alpha_dropout(%x.1, %3, %4)
+        %z.1 : Tensor = aten::feature_alpha_dropout(%y.1, %3, %4)
+        %12 : Tensor = aten::relu(%z.1)
+        return (%12))IR";
+  std::string target_graph = R"IR(
+      graph(%x.1):
+        %11 : Tensor = aten::relu(%x.1)
+        return (%11))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, sg.get());
+  torch_tensorrt::core::lowering::passes::RemoveDropout(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, tg.get());
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
+
+TEST(LoweringPasses, RemoveFeatureAlphaDropoutInplaceLowersCorrectly) {
+  std::string source_graph = R"IR(
+      graph(%x.1):
+        %3 : float = prim::Constant[value=0.5]()
+        %4 : bool = prim::Constant[value=0]()
+        %y.1 : Tensor = aten::feature_alpha_dropout_(%x.1, %3, %4)
+        %11 : Tensor = aten::relu(%y.1)
+        return (%11))IR";
+  std::string target_graph = R"IR(
+      graph(%x.1):
+        %11 : Tensor = aten::relu(%x.1)
+        return (%11))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, sg.get());
+  torch_tensorrt::core::lowering::passes::RemoveDropout(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, tg.get());
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
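
Usage sketch (not part of the patch): because removeDropoutInBlock recurses into node->blocks(), the pass also strips dropout nodes nested inside sub-blocks, a case the tests above do not cover. The snippet below is a minimal illustration of that behavior on a prim::If graph; the include paths (in particular "core/lowering/passes/passes.h") and the example graph are assumptions for illustration, not code from this diff.

#include <memory>
#include <string>

#include "core/lowering/passes/passes.h" // assumed header declaring RemoveDropout
#include "torch/csrc/jit/ir/ir.h"
#include "torch/csrc/jit/ir/irparser.h"

int main() {
  // Hypothetical graph: the dropout sits inside block0 of a prim::If, so a
  // non-recursive sweep over the top-level block alone would miss it.
  std::string source_graph = R"IR(
      graph(%x.1 : Tensor, %cond.1 : bool):
        %p : float = prim::Constant[value=0.5]()
        %train : bool = prim::Constant[value=0]()
        %out : Tensor = prim::If(%cond.1)
          block0():
            %y.1 : Tensor = aten::dropout(%x.1, %p, %train)
            -> (%y.1)
          block1():
            -> (%x.1)
        return (%out))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, g.get());

  // Rewires all uses of %y.1 to %x.1, destroys the aten::dropout node, and
  // lets EliminateDeadCode clean up the now-unused %p and %train constants.
  torch_tensorrt::core::lowering::passes::RemoveDropout(g);
  g->dump(); // block0 should now yield %x.1 directly
  return 0;
}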