remove_dropout.cpp
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

#include "core/util/prelude.h"

namespace torch_tensorrt {
namespace core {
namespace lowering {
namespace passes {

// Dropout is an identity operation at inference time, so every dropout variant
// (regular, feature, alpha, and their in-place forms) is rewritten to pass its
// input straight through.
void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph) {
  std::string dropout_pattern = R"IR(
    graph(%input, %4, %5):
      %6 = aten::dropout(%input, %4, %5)
      return (%6))IR";
  std::string no_dropout_pattern = R"IR(
    graph(%input, %4, %5):
      return (%input))IR";

  torch::jit::SubgraphRewriter remove_dropout;
  remove_dropout.RegisterRewritePattern(dropout_pattern, no_dropout_pattern);
  remove_dropout.runOnGraph(graph);

  // remove dropout inplace
  std::string dropout_inplace_pattern = R"IR(
    graph(%input, %4, %5):
      %6 = aten::dropout_(%input, %4, %5)
      return (%6))IR";
  std::string no_dropout_inplace_pattern = R"IR(
    graph(%input, %4, %5):
      return (%input))IR";

  torch::jit::SubgraphRewriter remove_dropout_inplace_pattern;
  remove_dropout_inplace_pattern.RegisterRewritePattern(dropout_inplace_pattern, no_dropout_inplace_pattern);
  remove_dropout_inplace_pattern.runOnGraph(graph);

  // remove feature_dropout
  std::string feature_dropout_pattern = R"IR(
    graph(%input, %4, %5):
      %6 = aten::feature_dropout(%input, %4, %5)
      return (%6))IR";
  std::string no_feature_dropout_pattern = R"IR(
    graph(%input, %4, %5):
      return (%input))IR";

  torch::jit::SubgraphRewriter remove_feature_dropout_pattern;
  remove_feature_dropout_pattern.RegisterRewritePattern(feature_dropout_pattern, no_feature_dropout_pattern);
  remove_feature_dropout_pattern.runOnGraph(graph);

  // remove feature_dropout inplace
  std::string feature_dropout_inplace_pattern = R"IR(
    graph(%input, %4, %5):
      %6 = aten::feature_dropout_(%input, %4, %5)
      return (%6))IR";
  std::string no_feature_dropout_inplace_pattern = R"IR(
    graph(%input, %4, %5):
      return (%input))IR";

  torch::jit::SubgraphRewriter remove_feature_dropout_inplace_pattern;
  remove_feature_dropout_inplace_pattern.RegisterRewritePattern(
      feature_dropout_inplace_pattern, no_feature_dropout_inplace_pattern);
  remove_feature_dropout_inplace_pattern.runOnGraph(graph);

  // remove feature_alpha_dropout
  std::string feature_alpha_dropout_pattern = R"IR(
    graph(%input, %4, %5):
      %6 = aten::feature_alpha_dropout(%input, %4, %5)
      return (%6))IR";
  std::string no_feature_alpha_dropout_pattern = R"IR(
    graph(%input, %4, %5):
      return (%input))IR";

  torch::jit::SubgraphRewriter remove_feature_alpha_dropout_pattern;
  remove_feature_alpha_dropout_pattern.RegisterRewritePattern(
      feature_alpha_dropout_pattern, no_feature_alpha_dropout_pattern);
  remove_feature_alpha_dropout_pattern.runOnGraph(graph);

  // remove feature_alpha_dropout inplace
  std::string feature_alpha_dropout_inplace_pattern = R"IR(
    graph(%input, %4, %5):
      %6 = aten::feature_alpha_dropout_(%input, %4, %5)
      return (%6))IR";
  std::string no_feature_alpha_dropout_inplace_pattern = R"IR(
    graph(%input, %4, %5):
      return (%input))IR";

  torch::jit::SubgraphRewriter remove_feature_alpha_dropout_inplace_pattern;
  remove_feature_alpha_dropout_inplace_pattern.RegisterRewritePattern(
      feature_alpha_dropout_inplace_pattern, no_feature_alpha_dropout_inplace_pattern);
  remove_feature_alpha_dropout_inplace_pattern.runOnGraph(graph);

  LOG_GRAPH("Post remove dropout: " << *graph);
}

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace torch_tensorrt
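
// Usage sketch (illustrative, not part of the original file): parse a small
// TorchScript graph that contains aten::dropout, run RemoveDropout on it, and
// the dropout node should be rewritten away so the graph returns its input
// directly. The function name CheckDropoutRemoved and the example IR below are
// assumptions made for this sketch; parseIR comes from
// torch/csrc/jit/ir/irparser.h.
#include <torch/csrc/jit/ir/irparser.h>

void CheckDropoutRemoved() {
  const std::string src = R"IR(
    graph(%x : Tensor):
      %p : float = prim::Constant[value=0.5]()
      %train : bool = prim::Constant[value=1]()
      %y : Tensor = aten::dropout(%x, %p, %train)
      return (%y))IR";

  // Build a graph from the textual IR above.
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(src, g.get());

  // RemoveDropout is visible here because it is defined earlier in this file.
  torch_tensorrt::core::lowering::passes::RemoveDropout(g);

  // After the pass the graph should return %x directly, with no aten::dropout
  // node left in it (LOG_GRAPH above prints the rewritten graph).
}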