Skip to content

Commit

Permalink
[PatternMatcher] Support matching tuples, call nodes, and functions w…
Browse files Browse the repository at this point in the history
…ith variable numbers of inputs (apache#7754)

* Allow TuplePattern to have null fields and match any tuple

* support matching functions and call nodes with variable numbers of parameters

* remove development code that was commented out

* add docs for fuzzy matching
  • Loading branch information
Matthew Brookhart authored and trevor-m committed May 11, 2021
1 parent 3a6063a commit 599a640
Show file tree
Hide file tree
Showing 6 changed files with 192 additions and 52 deletions.
16 changes: 16 additions & 0 deletions docs/langref/relay_pattern.rst
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,22 @@ The final example is matching diamonds with a post-dominator relationship. We em
assert diamond.match(out)


Matching Fuzzy Patterns
=======================

The Dominator analysis above lets one match a subgraph of Relay AST that doesn't correspond to a set of patterns nodes exactly 1-to-1. There are a few other places where we support such "fuzzy" matching.

Tuples, Functions, and Call nodes with any number of inputs can be matched by passing `None` as the argument value, i.e.::

tuple_pattern = is_tuple(None)
func_pattern = FunctionPattern(None, wildcard() + wildcard())
call_pattern = func_pattern(None)

These patterns allow matching more generic classes patterns by constraining the use of the arguments rather than the number of arguments.

Additionally, we support matching Functions with fuzzy bodies, i.e., a function body that is under constrained by the pattern. The pattern `FunctionPattern([is_var(), is_var()], wildcard() + wildcard()])` will match `relay.Function([x, y], x + y)`, but it will also match `relay.Function([x, y], x * x + y)`. In the second case, the pattern doesn't perfectly constrain the body of the function, so the resulting match is fuzzy.


Pattern Language Design
=======================

Expand Down
5 changes: 4 additions & 1 deletion python/tvm/relay/dataflow_pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ class DFPattern(Node):
"""Base class of all Patterns."""

def __call__(self, *args):
return CallPattern(self, list(args))
args = list(args)
if len(args) == 1 and args[0] is None:
args = None
return CallPattern(self, args)

def __or__(self, other):
return AltPattern(self, other)
Expand Down
107 changes: 73 additions & 34 deletions src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex
}
return false;
};

// logic
auto watermark = matched_nodes_.size();
if (const auto* call_node = expr.as<CallNode>()) {
Expand All @@ -253,13 +254,15 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex
const Array<Expr> expr_args) {
bool matches = true;
size_t i = 0;
if (pattern_args.size() == expr_args.size()) {
while (matches && i < pattern_args.size()) {
matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
++i;
if (pattern_args.defined()) {
if (pattern_args.size() == expr_args.size()) {
while (matches && i < pattern_args.size()) {
matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
++i;
}
} else {
matches = false;
}
} else {
matches = false;
}
if (!matches) {
ClearMap(watermark2);
Expand Down Expand Up @@ -381,14 +384,16 @@ bool DFPatternMatcher::VisitDFPattern_(const FunctionPatternNode* op, const Expr
bool matches = false;
if (const auto* func = expr.as<FunctionNode>()) {
matches = true;
size_t i = 0;
if (op->params.size() == func->params.size()) {
while (matches && i < op->params.size()) {
matches &= VisitDFPattern(op->params[i], func->params[i]);
++i;
if (op->params.defined()) {
size_t i = 0;
if (op->params.size() == func->params.size()) {
while (matches && i < op->params.size()) {
matches &= VisitDFPattern(op->params[i], func->params[i]);
++i;
}
} else {
matches = false;
}
} else {
matches = false;
}
if (matches) {
matches &= VisitDFPattern(op->body, func->body);
Expand All @@ -409,12 +414,16 @@ bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const
bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) {
bool matches = false;
if (const auto* tuple_node = expr.as<TupleNode>()) {
if (op->fields.size() == tuple_node->fields.size()) {
matches = true;
size_t i = 0;
while (matches && i < op->fields.size()) {
matches &= VisitDFPattern(op->fields[i], tuple_node->fields[i]);
++i;
matches = true;
if (op->fields.defined()) {
if (op->fields.size() == tuple_node->fields.size()) {
size_t i = 0;
while (matches && i < op->fields.size()) {
matches &= VisitDFPattern(op->fields[i], tuple_node->fields[i]);
++i;
}
} else {
matches = false;
}
}
}
Expand Down Expand Up @@ -657,7 +666,6 @@ class PatternGrouper {
int var_number = 0;

auto node_map = matcher_->GetMemo();

// Get fuzzy patterns
std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> fuzzy_matches;
for (auto node : pattern_graph_.topological_order_) {
Expand All @@ -669,11 +677,13 @@ class PatternGrouper {
}
}
}
// Don't treat Function params as input variables for partition
if (auto op = node->ref_.as<FunctionPatternNode>()) {
for (auto fuzzy_op : op->params) {
for (auto match : node_map[fuzzy_op]) {
fuzzy_matches.insert(match);
// Don't treat Function params or body as input variables for partition
if (node->ref_.as<FunctionPatternNode>()) {
auto matches = node_map[node->ref_];
for (auto match : matches) {
auto graph = CreateIndexedGraph(match.as<FunctionNode>()->body);
for (auto node : graph.topological_order_) {
fuzzy_matches.insert(node->ref_);
}
}
}
Expand All @@ -686,22 +696,46 @@ class PatternGrouper {

std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual> inputs;
Array<Var> params;

for (auto node : pattern_graph_.topological_order_) {
if (node->inputs_.size() == 0) {
auto make_input = [&](const Expr& input) {
if (fuzzy_matches.count(input) == 0 && input.as<OpNode>() == nullptr &&
input.as<FunctionNode>() == nullptr && !EmbedConst(input, node->ref_)) {
inputs[input] =
Var("FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number),
NullValue<Type>());
group.args.push_back(input);
params.push_back(inputs[input]);
var_number++;
}
};
auto tuple = node->ref_.as<TuplePatternNode>();
auto call = node->ref_.as<CallPatternNode>();
if (tuple && !tuple->fields.defined()) {
if (node_map.count(node->ref_)) {
auto matches = node_map[node->ref_];
for (auto match : matches) {
if (fuzzy_matches.count(match) == 0 && match.as<OpNode>() == nullptr &&
match.as<FunctionNode>() == nullptr && !EmbedConst(match, node->ref_)) {
inputs[match] = Var(
"FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number),
NullValue<Type>());
group.args.push_back(match);
params.push_back(inputs[match]);
var_number++;
for (auto input : match.as<TupleNode>()->fields) {
make_input(input);
}
}
}
} else if (call && !call->args.defined()) {
if (node_map.count(node->ref_)) {
auto matches = node_map[node->ref_];
for (auto match : matches) {
for (auto input : match.as<CallNode>()->args) {
make_input(input);
}
}
}
} else if (node->inputs_.size() == 0) {
if (node_map.count(node->ref_)) {
auto matches = node_map[node->ref_];
for (auto match : matches) {
make_input(match);
}
}
}
}

Expand Down Expand Up @@ -898,6 +932,11 @@ class PatternPartitioner : protected MixedModeMutator {
public:
Expr Partition(const DFPattern& pattern, const Expr& pre, const Map<String, ObjectRef>& attrs,
PackedFunc check) {
if (pattern.as<FunctionPatternNode>()) {
LOG(WARNING) << "Partioning a Function that isn't called doesn't make sense, skipping"
<< pattern;
return pre;
}
auto grouper = PatternGrouper();
groups_ = grouper.GroupMatches(pattern, pre);
gid_assignments_ = grouper.GetGIDAssignments();
Expand Down
18 changes: 12 additions & 6 deletions src/relay/ir/dataflow_pattern_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@ void DFPatternVisitor::VisitDFPattern_(const AttrPatternNode* op) { VisitDFPatte

void DFPatternVisitor::VisitDFPattern_(const CallPatternNode* op) {
VisitDFPattern(op->op);
for (auto arg : op->args) {
VisitDFPattern(arg);
if (op->args.defined()) {
for (auto arg : op->args) {
VisitDFPattern(arg);
}
}
}

Expand All @@ -63,8 +65,10 @@ void DFPatternVisitor::VisitDFPattern_(const DominatorPatternNode* op) {
void DFPatternVisitor::VisitDFPattern_(const ExprPatternNode* op) {}

void DFPatternVisitor::VisitDFPattern_(const FunctionPatternNode* op) {
for (auto param : op->params) {
VisitDFPattern(param);
if (op->params.defined()) {
for (auto param : op->params) {
VisitDFPattern(param);
}
}
VisitDFPattern(op->body);
}
Expand All @@ -76,8 +80,10 @@ void DFPatternVisitor::VisitDFPattern_(const TupleGetItemPatternNode* op) {
}

void DFPatternVisitor::VisitDFPattern_(const TuplePatternNode* op) {
for (auto field : op->fields) {
VisitDFPattern(field);
if (op->fields.defined()) {
for (auto field : op->fields) {
VisitDFPattern(field);
}
}
}

Expand Down
18 changes: 12 additions & 6 deletions src/relay/ir/indexed_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,10 @@ IndexedGraph<DFPattern> CreateIndexedGraph(const DFPattern& pattern) {

void VisitDFPattern_(const CallPatternNode* op, NodePtr parent) override {
VisitDFPattern(op->op, graph_.node_map_[GetRef<DFPattern>(op)]);
for (auto arg : op->args) {
VisitDFPattern(arg, graph_.node_map_[GetRef<DFPattern>(op)]);
if (op->args.defined()) {
for (auto arg : op->args) {
VisitDFPattern(arg, graph_.node_map_[GetRef<DFPattern>(op)]);
}
}
}

Expand All @@ -262,8 +264,10 @@ IndexedGraph<DFPattern> CreateIndexedGraph(const DFPattern& pattern) {
void VisitDFPattern_(const ExprPatternNode* op, NodePtr parent) override {}

void VisitDFPattern_(const FunctionPatternNode* op, NodePtr parent) override {
for (auto param : op->params) {
VisitDFPattern(param, graph_.node_map_[GetRef<DFPattern>(op)]);
if (op->params.defined()) {
for (auto param : op->params) {
VisitDFPattern(param, graph_.node_map_[GetRef<DFPattern>(op)]);
}
}
VisitDFPattern(op->body, graph_.node_map_[GetRef<DFPattern>(op)]);
}
Expand All @@ -277,8 +281,10 @@ IndexedGraph<DFPattern> CreateIndexedGraph(const DFPattern& pattern) {
}

void VisitDFPattern_(const TuplePatternNode* op, NodePtr parent) override {
for (auto field : op->fields) {
VisitDFPattern(field, graph_.node_map_[GetRef<DFPattern>(op)]);
if (op->fields.defined()) {
for (auto field : op->fields) {
VisitDFPattern(field, graph_.node_map_[GetRef<DFPattern>(op)]);
}
}
}

Expand Down
Loading

0 comments on commit 599a640

Please sign in to comment.