Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Torch] More graph rewrites for Faster RCNN / MaskRCNN #7346

Merged
merged 11 commits into from
Jan 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
258 changes: 245 additions & 13 deletions python/tvm/relay/frontend/pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
# under the License.
# pylint: disable=import-outside-toplevel, unused-argument, invalid-name
""" Common utilities used by PyTorch frontend """
from .. import expr
from .. import op
from ..dataflow_pattern import (
wildcard,
is_constant,
is_op,
rewrite,
is_tuple,
wildcard,
is_tuple_get_item,
is_if,
DFPatternCallback,
)

Expand All @@ -36,6 +39,19 @@ def is_version_greater_than(ver):
)


def dyn_strided_slice_pattern(inp, end):
    """Return a pattern matching a dynamic strided slice emitted by the frontend.

    The PyTorch frontend lowers negative begin indices via a
    less/cast_like/add/where normalization before feeding them into
    ``dyn.strided_slice``; this pattern captures that whole subgraph.

    Parameters
    ----------
    inp : DFPattern
        Pattern for the tensor being sliced.
    end : DFPattern
        Pattern for the slice end indices.
    """
    const_zero = is_constant()
    zero_cast = is_op("cast_like")(const_zero, is_constant())
    # begin < 0 test used to decide whether to wrap negative indices
    is_negative = is_op("less")(is_constant(), zero_cast)
    inp_shape = is_op("shape_of")(inp)
    shape_cast = is_op("cast_like")(inp_shape, is_constant())
    # begin + shape, i.e. the wrapped-around index for negative begins
    wrapped = is_op("add")(is_constant(), shape_cast)
    begin = is_op("where")(is_negative, wrapped, is_constant())
    return is_op("dyn.strided_slice")(inp, begin, end, is_constant())


def batched_nms_pattern(boxes, scores, idxs, iou_threshold, num_boxes, indices):
"""A pattern to detect batched_nms function in torchvision

Expand Down Expand Up @@ -73,7 +89,6 @@ def batched_nms(boxes, scores, idxs, iou_threshold):

"""
one = is_constant()
zero = is_constant()

    # Equivalent PyTorch code from above snippet
# offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
Expand All @@ -84,17 +99,10 @@ def batched_nms(boxes, scores, idxs, iou_threshold):

# The following doesn't appear in the above Relay snippet. It is required for dynamic
# stride_slice handling
cast_like = is_op("cast_like")(zero, is_constant())
less = is_op("less")(is_constant(), cast_like)
shape_of = is_op("shape_of")(mul)
cast_like = is_op("cast_like")(shape_of, is_constant())
add = is_op("add")(is_constant(), cast_like)
where = is_op("where")(less, add, is_constant())
shape_of = is_op("shape_of")(mul)
cast = is_op("cast")(shape_of)

# This corresponds to offsets[:, None], where offsets is the result of multiplication
dyn_strided_slice = is_op("dyn.strided_slice")(mul, where, cast, is_constant())
dyn_strided_slice = dyn_strided_slice_pattern(mul, cast)

# Add offsets to the boxes
expand_dims = is_op("expand_dims")(dyn_strided_slice)
Expand All @@ -112,8 +120,49 @@ def batched_nms(boxes, scores, idxs, iou_threshold):
)


class NMSRewrite(DFPatternCallback):
"""A callback to rewrite nms and restore batched nms"""
def topk_after_batch_nms_pattern(cond, true_branch, data, valid_count, indices, iou_threshold):
    """Return a pattern for the post-NMS topk slicing used by torchvision detection models.

    Detects the Relay lowering of the following PyTorch code:

    def batched_nms(...):
        if boxes.numel() == 0:
            return torch.empty((0,), dtype=torch.int64, device=boxes.device)
        else:
            ...
            return nms(boxes_for_nms, scores, iou_threshold)

    keep = batched_nms(boxes, scores, lvl, self.nms_thresh)
    keep = keep[:post_nms_top_k]  # keep only topk scoring predictions

    An equivalent Relay subgraph:

    %1184 = if (%1117) {
      ...
    } else {
      ...
      %1172 = vision.non_max_suppression(%1167, %1168, %1171, -1, 0.7f, ...);
      ...
      %1183 = dyn.strided_slice(%1174, %1180, %1182, ...);
      cast(%1183, dtype="int64")
    };
    %1185 = strided_slice(%1184, begin=[0], end=[1000], strides=[1]);

    """
    nms_pat = is_op("vision.non_max_suppression")(
        data, valid_count, indices, is_constant(), iou_threshold
    )
    # First output: selected box indices; second output: number of valid entries.
    selected = is_op("squeeze")(is_tuple_get_item(nms_pat, 0))
    num_selected = is_op("squeeze")(is_tuple_get_item(nms_pat, 1))
    sliced = dyn_strided_slice_pattern(selected, num_selected)
    as_int64 = is_op("cast")(sliced)
    # The If wraps the empty-boxes early return vs. the actual NMS result.
    nms_result = is_if(cond, true_branch, as_int64)
    # The trailing static strided_slice is the `keep[:post_nms_top_k]`.
    return is_op("strided_slice")(nms_result)


class MulticlassNMSRewrite(DFPatternCallback):
"""A callback to rewrite nms and restore batched nms."""

def __init__(self):
super().__init__()
Expand Down Expand Up @@ -169,10 +218,193 @@ def callback(self, pre, post, node_map):
return self.convert_batched_nms(boxes, scores, idxs, iou_thres, num_boxes, indices)


class PostNMSTopKRewrite(DFPatternCallback):
    """A callback to rewrite nms to exploit max_out_size parameter."""

    def __init__(self):
        super().__init__()
        # Wildcards capture the inputs of the matched subgraph so the
        # callback can splice them into the rewritten expression.
        self.cond = wildcard()
        self.true_branch = wildcard()
        self.data = wildcard()
        self.valid_count = wildcard()
        self.indices = wildcard()
        self.iou_threshold = wildcard()

        self.pattern = topk_after_batch_nms_pattern(
            self.cond,
            self.true_branch,
            self.data,
            self.valid_count,
            self.indices,
            self.iou_threshold,
        )

    def rewrite_batch_nms_with_max_out_size(
        self, cond, true_branch, data, valid_count, indices, iou_threshold, post_nms_topk
    ):
        """Use the detected post NMS topk parameter in NMS op."""
        # Re-emit NMS, letting the op itself cap the number of kept boxes
        # instead of slicing afterwards.
        new_nms = op.vision.non_max_suppression(
            data=data,
            valid_count=valid_count,
            indices=indices,
            max_output_size=post_nms_topk,
            iou_threshold=iou_threshold,
            force_suppress=False,
            top_k=-1,
            coord_start=2,
            score_index=1,
            id_index=0,
            return_indices=True,
            invalid_to_bottom=False,
        )

        # new_nms[0]: selected indices, new_nms[1]: number of valid entries.
        kept_indices = op.squeeze(new_nms[0], axis=[0])
        num_kept = op.squeeze(new_nms[1], axis=[1])

        # Take only the first num_kept entries (slice_mode="size" makes
        # `end` an element count rather than an index).
        topk = op.strided_slice(kept_indices, begin=expr.const([0]), end=num_kept, slice_mode="size")
        as_int64 = op.cast(topk, "int64")

        # Keep the original empty-input branch of the If intact.
        return expr.If(cond, true_branch, as_int64)

    def callback(self, pre, post, node_map):
        # `post` is the outer strided_slice; its `end` attribute holds the
        # post NMS topk value we want to push into the NMS op.
        post_nms_topk = post.attrs.end[0].value
        captured = [
            node_map[self.cond][0],
            node_map[self.true_branch][0],
            node_map[self.data][0],
            node_map[self.valid_count][0],
            node_map[self.indices][0],
            node_map[self.iou_threshold][0],
        ]
        return self.rewrite_batch_nms_with_max_out_size(*captured, post_nms_topk)


def scatter_roi_align_result_pattern(levels, roi_align_results, num_scales):
    """Detect the Relay subgraph corresponding to the following PyTorch code

    first_result = roi_align_results[0]
    dtype, device = first_result.dtype, first_result.device
    res = torch.zeros((levels.size(0), first_result.size(1),
                       first_result.size(2), first_result.size(3)),
                      dtype=dtype, device=device)
    for level in range(len(roi_align_results)):
        index = torch.where(levels == level)[0].view(-1, 1, 1, 1)
        index = index.expand(index.size(0),
                             roi_align_results[level].size(1),
                             roi_align_results[level].size(2),
                             roi_align_results[level].size(3))
        res = res.scatter(0, index, roi_align_results[level])
    return res
    """

    def do_where(levels, _):
        # torch.where(levels == level)[0] lowers to argwhere + split + squeeze.
        idx_in_level = is_op("argwhere")(is_op("equal")(levels, is_constant()))
        idx_in_level = is_op("split")(idx_in_level)
        idx_in_level = is_tuple_get_item(idx_in_level, 0)
        idx_in_level = is_op("squeeze")(idx_in_level)
        idx_in_level = is_tuple_get_item(is_tuple([idx_in_level]), 0)
        return idx_in_level

    scatter_res = wildcard()

    # Each iteration matches one scatter in the per-level loop; the result
    # of the previous scatter feeds the next one.
    for i in range(num_scales):
        # index = torch.where(levels == level)[0].view(-1, 1, 1, 1)
        scatter_indices = do_where(levels, i)
        scatter_indices = is_op("reshape")(scatter_indices)

        # NOTE: stray review-UI text that had been pasted into this comment
        # block (breaking the Python syntax) has been removed.
        # index = index.expand(index.size(0),
        #                      unmerged_results[level].size(1),
        #                      unmerged_results[level].size(2),
        #                      unmerged_results[level].size(3))
        scatter_indices = is_op("repeat")(scatter_indices)
        scatter_indices = is_op("repeat")(scatter_indices)
        scatter_indices = is_op("repeat")(scatter_indices)

        scatter_res = is_op("scatter")(scatter_res, scatter_indices, roi_align_results[i])

    return is_op("reshape")(scatter_res)


class ScatterRewrite(DFPatternCallback):
    """A callback to rewrite repeated scatters with a batched gather."""

    def __init__(self, num_scales):
        super().__init__()
        self.num_scales = num_scales
        self.levels = wildcard()
        # One wildcard per scale, capturing that scale's roi_align output.
        self.roi_align_results = [wildcard() for _ in range(num_scales)]
        self.pattern = scatter_roi_align_result_pattern(
            self.levels, self.roi_align_results, num_scales
        )

    def convert_scatter_to_gather(self, levels, roi_align_results):
        """Replace the detected scatter loop with the following PyTorch code

        indices_per_level = []
        for level in range(num_scales):
            idx_in_level = torch.where(levels == level)[0]
            indices_per_level.append(idx_in_level)

        stacked_features = torch.cat(roi_align_results, dim=0)
        stacked_indices = torch.cat(indices_per_level, dim=0)
        argsort_indices = torch.argsort(stacked_indices)
        return stacked_features[argsort_indices, :]
        """
        # Collect the row indices selected at each level and concat them.
        per_level_indices = []
        for level in range(self.num_scales):
            is_level = op.equal(levels, expr.const(level, dtype="int64"))
            rows = op.argwhere(is_level)
            first_col = op.split(rows, indices_or_sections=1, axis=1)[0]
            level_indices = op.cast(op.squeeze(first_col, axis=[1]), dtype="int64")
            per_level_indices.append(level_indices)

        all_indices = op.concatenate(per_level_indices, 0)

        # Concat roi align results across levels, and argsort the indices
        # to prepare for a batched gather.
        all_features = op.concatenate(roi_align_results, 0)
        order = op.cast(op.argsort(all_indices), dtype="int64")

        # Permute rows into their pre-scatter positions.
        gathered = op.take(all_features, order, axis=0)

        return op.reshape(gathered, [0, -1, 1, 1])

    def callback(self, pre, post, node_map):
        levels = node_map[self.levels][0]
        roi_align_results = [node_map[pat][0] for pat in self.roi_align_results]
        return self.convert_scatter_to_gather(levels, roi_align_results)


def rewrite_nms_to_batched_nms(mod):
    """Rewrite the input graph to replace non maximum suppression
    in torchvision that does not take class id into account with the one
    that avoids IOU tests between different classes.

    Parameters
    ----------
    mod : tvm.IRModule
        Module whose ``main`` function is rewritten in place.

    Returns
    -------
    tvm.IRModule
        The same module with the rewrite applied.
    """
    # NMSRewrite was renamed to MulticlassNMSRewrite; only the renamed
    # callback exists, so only it is invoked here.
    mod["main"] = rewrite(MulticlassNMSRewrite(), mod["main"])
    return mod


def rewrite_batched_nms_with_max_out_size(mod):
    """Rewrite the input graph to detect slicing after batched nms and
    use the slicing size as the parameter max_out_size in NMS.
    """
    rewritten_main = rewrite(PostNMSTopKRewrite(), mod["main"])
    mod["main"] = rewritten_main
    return mod


def rewrite_scatter_to_gather(mod, num_scales):
    """Rewrite the input graph to replace a repeated scatter loop with
    a batched gather. The scatter loop is used in torchvision MultiScaleRoIAlign
    to merge roi_align results for all scales. The scatter is used to emulate
    inplace updates.
    """
    rewritten_main = rewrite(ScatterRewrite(num_scales), mod["main"])
    mod["main"] = rewritten_main
    return mod
18 changes: 16 additions & 2 deletions tests/python/frontend/pytorch/test_object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@
import tvm.testing
from tvm import relay
from tvm.runtime.vm import VirtualMachine
from tvm.relay.frontend.pytorch_utils import rewrite_nms_to_batched_nms
from tvm.relay.frontend.pytorch_utils import (
rewrite_nms_to_batched_nms,
rewrite_batched_nms_with_max_out_size,
rewrite_scatter_to_gather,
)
from tvm.contrib.download import download


Expand Down Expand Up @@ -72,7 +76,7 @@ def generate_jit_model(index):
]

model_func = model_funcs[index]
model = TraceWrapper(model_func(pretrained=True, rpn_pre_nms_top_n_test=200))
model = TraceWrapper(model_func(pretrained=True, rpn_pre_nms_top_n_test=1000))

model.eval()
inp = torch.Tensor(np.random.uniform(0.0, 250.0, size=(1, 3, in_size, in_size)))
Expand Down Expand Up @@ -141,6 +145,16 @@ def compile_and_run_vm(mod, params, data_np, target):
after = mod["main"]
assert not tvm.ir.structural_equal(after, before)

before = mod["main"]
mod = rewrite_batched_nms_with_max_out_size(mod)
after = mod["main"]
assert not tvm.ir.structural_equal(after, before)

before = mod["main"]
mod = rewrite_scatter_to_gather(mod, 4) # num_scales is 4 for maskrcnn_resnet50_fpn
after = mod["main"]
assert not tvm.ir.structural_equal(after, before)

tvm_res_after_rewrite = compile_and_run_vm(mod, params, data_np, "llvm")

# Results should be equivalent after rewriting
Expand Down