From 2acf9fade2ba4a2336924bd1811ae7e2db6e8898 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 18 Jan 2021 22:06:21 +0900 Subject: [PATCH 01/11] add post nms topk to max_out_size rewrite --- python/tvm/relay/frontend/pytorch_utils.py | 145 ++++++++++++++++-- .../frontend/pytorch/test_object_detection.py | 12 +- 2 files changed, 142 insertions(+), 15 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py index 6fc5a6af4a36..3a08345e29c4 100644 --- a/python/tvm/relay/frontend/pytorch_utils.py +++ b/python/tvm/relay/frontend/pytorch_utils.py @@ -16,13 +16,16 @@ # under the License. # pylint: disable=import-outside-toplevel, unused-argument, invalid-name """ Common utilities used by PyTorch frontend """ +from .. import expr from .. import op from ..dataflow_pattern import ( + wildcard, is_constant, is_op, rewrite, is_tuple, - wildcard, + is_tuple_get_item, + is_if, DFPatternCallback, ) @@ -36,6 +39,19 @@ def is_version_greater_than(ver): ) +def dyn_strided_slice_pattern(inp, end): + """A pattern to detect dynamic strided slice op.""" + zero = is_constant() + cast_like = is_op("cast_like")(zero, is_constant()) + less = is_op("less")(is_constant(), cast_like) + shape_of = is_op("shape_of")(inp) + cast_like = is_op("cast_like")(shape_of, is_constant()) + add = is_op("add")(is_constant(), cast_like) + where = is_op("where")(less, add, is_constant()) + + return is_op("dyn.strided_slice")(inp, where, end, is_constant()) + + def batched_nms_pattern(boxes, scores, idxs, iou_threshold, num_boxes, indices): """A pattern to detect batched_nms function in torchvision @@ -73,7 +89,6 @@ def batched_nms(boxes, scores, idxs, iou_threshold): """ one = is_constant() - zero = is_constant() # Equivelent PyTorch code from above snippet # offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes)) @@ -84,17 +99,10 @@ def batched_nms(boxes, scores, idxs, iou_threshold): # The following doesn't appear in the above 
Relay snippet. It is required for dynamic # stride_slice handling - cast_like = is_op("cast_like")(zero, is_constant()) - less = is_op("less")(is_constant(), cast_like) - shape_of = is_op("shape_of")(mul) - cast_like = is_op("cast_like")(shape_of, is_constant()) - add = is_op("add")(is_constant(), cast_like) - where = is_op("where")(less, add, is_constant()) shape_of = is_op("shape_of")(mul) cast = is_op("cast")(shape_of) - # This corresponds to offsets[:, None], where offsets is the result of multiplication - dyn_strided_slice = is_op("dyn.strided_slice")(mul, where, cast, is_constant()) + dyn_strided_slice = dyn_strided_slice_pattern(mul, cast) # Add offsets to the boxes expand_dims = is_op("expand_dims")(dyn_strided_slice) @@ -112,8 +120,49 @@ def batched_nms(boxes, scores, idxs, iou_threshold): ) -class NMSRewrite(DFPatternCallback): - """A callback to rewrite nms and restore batched nms""" +def topk_after_batch_nms_pattern(cond, true_branch, data, valid_count, indices, iou_threshold): + """ + Detect the following pattern used in torchvision detection models. + + def batched_nms(...): + if boxes.numel() == 0: + return torch.empty((0,), dtype=torch.int64, device=boxes.device) + else: + ... + return nms(boxes_for_nms, scores, iou_threshold) + + keep = batched_nms(boxes, scores, lvl, self.nms_thresh) + keep = keep[:post_nms_top_k] # keep only topk scoring predictions + + An equivalent Relay subgraph: + + %1184 = if (%1117) { + ... + } else { + ... + %1172 = vision.non_max_suppression(%1167, %1168, %1171, -1, 0.7f, ...); + ... 
+ %1183 = dyn.strided_slice(%1174, %1180, %1182, ...); + cast(%1183, dtype="int64") + }; + %1185 = strided_slice(%1184, begin=[0], end=[1000], strides=[1]); + + """ + nms = is_op("vision.non_max_suppression")( + data, valid_count, indices, is_constant(), iou_threshold + ) + indices = is_op("squeeze")(is_tuple_get_item(nms, 0)) + size = is_op("squeeze")(is_tuple_get_item(nms, 1)) + dyn_strided_slice = dyn_strided_slice_pattern(indices, size) + cast_i64 = is_op("cast")(dyn_strided_slice) + + batched_nms_result = is_if(cond, true_branch, cast_i64) + + return is_op("strided_slice")(batched_nms_result) + + +class MulticlassNMSRewrite(DFPatternCallback): + """A callback to rewrite nms and restore batched nms.""" def __init__(self): super().__init__() @@ -169,10 +218,80 @@ def callback(self, pre, post, node_map): return self.convert_batched_nms(boxes, scores, idxs, iou_thres, num_boxes, indices) +class PostNMSTopKRewrite(DFPatternCallback): + """A callback to rewrite nms to exploit max_out_size parameter.""" + + def __init__(self): + super().__init__() + self.cond = wildcard() + self.true_branch = wildcard() + self.data = wildcard() + self.valid_count = wildcard() + self.indices = wildcard() + self.iou_threshold = wildcard() + + self.pattern = topk_after_batch_nms_pattern( + self.cond, + self.true_branch, + self.data, + self.valid_count, + self.indices, + self.iou_threshold, + ) + + def rewrite_batch_nms_with_max_out_size( + self, cond, true_branch, data, valid_count, indices, iou_threshold, post_nms_topk + ): + """Use the detected post NMS topk parameter in NMS op.""" + nms_ret = op.vision.non_max_suppression( + data=data, + valid_count=valid_count, + indices=indices, + max_output_size=post_nms_topk, + iou_threshold=iou_threshold, + force_suppress=False, + top_k=-1, + coord_start=2, + score_index=1, + id_index=0, + return_indices=True, + invalid_to_bottom=False, + ) + + size = op.squeeze(nms_ret[1], axis=[1]) + data_slice = op.squeeze(nms_ret[0], axis=[0]) + + ret = 
op.strided_slice(data_slice, begin=expr.const([0]), end=size, slice_mode="size") + + nms_result = op.cast(ret, "int64") + + return expr.If(cond, true_branch, nms_result) + + def callback(self, pre, post, node_map): + post_nms_topk = post.attrs.end[0].value + return self.rewrite_batch_nms_with_max_out_size( + node_map[self.cond][0], + node_map[self.true_branch][0], + node_map[self.data][0], + node_map[self.valid_count][0], + node_map[self.indices][0], + node_map[self.iou_threshold][0], + post_nms_topk, + ) + + def rewrite_nms_to_batched_nms(mod): """Rewrite the input graph to replace non maximum surpression in torchvision that does not take class id into account with the one that avoids IOU tests between different classes. """ - mod["main"] = rewrite(NMSRewrite(), mod["main"]) + mod["main"] = rewrite(MulticlassNMSRewrite(), mod["main"]) + return mod + + +def rewrite_batched_nms_with_max_out_size(mod): + """Rewrite the input graph to detect slicing after batched nms and + use the slicing size as the parameter max_out_size in NMS. 
+ """ + mod["main"] = rewrite(PostNMSTopKRewrite(), mod["main"]) return mod diff --git a/tests/python/frontend/pytorch/test_object_detection.py b/tests/python/frontend/pytorch/test_object_detection.py index 2c323776f087..6273fbbd60b2 100644 --- a/tests/python/frontend/pytorch/test_object_detection.py +++ b/tests/python/frontend/pytorch/test_object_detection.py @@ -26,7 +26,10 @@ import tvm.testing from tvm import relay from tvm.runtime.vm import VirtualMachine -from tvm.relay.frontend.pytorch_utils import rewrite_nms_to_batched_nms +from tvm.relay.frontend.pytorch_utils import ( + rewrite_nms_to_batched_nms, + rewrite_batched_nms_with_max_out_size, +) from tvm.contrib.download import download @@ -72,7 +75,7 @@ def generate_jit_model(index): ] model_func = model_funcs[index] - model = TraceWrapper(model_func(pretrained=True, rpn_pre_nms_top_n_test=200)) + model = TraceWrapper(model_func(pretrained=True, rpn_pre_nms_top_n_test=1000)) model.eval() inp = torch.Tensor(np.random.uniform(0.0, 250.0, size=(1, 3, in_size, in_size))) @@ -141,6 +144,11 @@ def compile_and_run_vm(mod, params, data_np, target): after = mod["main"] assert not tvm.ir.structural_equal(after, before) + before = mod["main"] + mod = rewrite_batched_nms_with_max_out_size(mod) + after = mod["main"] + assert not tvm.ir.structural_equal(after, before) + tvm_res_after_rewrite = compile_and_run_vm(mod, params, data_np, "llvm") # Results should be equivalent after rewriting From 54a067c0969bc02e3211821b45481d8b6493a9f0 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 18 Jan 2021 22:07:47 +0900 Subject: [PATCH 02/11] add argsort conversion --- python/tvm/relay/frontend/pytorch.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 991e3a8a0032..3c835fe3e07a 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2081,6 +2081,12 @@ def is_floating_point(self, inputs, 
input_types): is_float = input_type in ["float32", "float64", "float16", "bfloat16"] return _expr.const(is_float) + def argsort(self, inputs, input_types): + data = inputs[0] + dim = inputs[1] + is_descending = inputs[2] + return _op.argsort(data, dim, not is_descending) + # Operator mappings def create_convert_map(self): self.convert_map = { From 3778b89cf1840f6ff4c4a4e72e0f85c58dc1411c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 23 Jan 2021 14:57:17 +0900 Subject: [PATCH 03/11] scatter pattern first cut --- python/tvm/relay/frontend/pytorch_utils.py | 59 ++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py index 3a08345e29c4..35378a4aa545 100644 --- a/python/tvm/relay/frontend/pytorch_utils.py +++ b/python/tvm/relay/frontend/pytorch_utils.py @@ -280,6 +280,59 @@ def callback(self, pre, post, node_map): ) +def scatter_roi_align_result_pattern(levels, rois, per_level_features, scatter_res, num_scales): + def do_where(levels, i): + idx_in_level = is_op("argwhere")(is_op("equal")(levels, i)) + idx_in_level = is_op("split")(idx_in_level) + idx_in_level = is_tuple_get_item(idx_in_level, 0) + idx_in_level = is_op("squeeze")(idx_in_level) + idx_in_level = is_tuple_get_item(is_tuple([idx_in_level]), 0) + return idx_in_level + + for i in range(num_scales): + # index = torch.where(levels == level)[0].view(-1, 1, 1, 1) + scatter_indices = do_where(levels, i) + scatter_indices = is_op("reshape")(scatter_indices) + + # index = index.expand(index.size(0), + # unmerged_results[level].size(1), + # unmerged_results[level].size(2), + # unmerged_results[level].size(3)) + scatter_indices = is_op("repeat")(scatter_indices) + scatter_indices = is_op("repeat")(scatter_indices) + scatter_indices = is_op("repeat")(scatter_indices) + + # idx_in_level = torch.where(levels == level)[0] + idx_in_level = is_op("cast")(do_where(levels, i)) + + # rois_per_level = 
rois[idx_in_level] + rois_per_level = is_op("adv_index")(is_tuple([rois, idx_in_level])) + # result_idx_in_level = roi_align(rois_per_level, ...) + result_idx_in_level = is_op("vision.roi_align")(per_level_features[i], rois_per_level) + + # res = res.scatter(0, index, unmerged_results[level]) + scatter_res = is_op("scatter")(scatter_res, scatter_indices, result_idx_in_level) + + return scatter_res + + +class ScatterRewrite(DFPatternCallback): + def __init__(self, num_scales): + super().__init__() + self.levels = wildcard() + self.rois = wildcard() + self.scatter_res = wildcard() + self.per_level_features = [] + for _ in range(num_scales): + self.per_level_features.append(wildcard()) + + self.pattern = scatter_roi_align_result_pattern(self.levels, self.rois, self.per_level_features, self.scatter_res, num_scales) + + def callback(self, pre, post, node_map): + print("matched") + return pre + + def rewrite_nms_to_batched_nms(mod): """Rewrite the input graph to replace non maximum surpression in torchvision that does not take class id into account with the one @@ -295,3 +348,9 @@ def rewrite_batched_nms_with_max_out_size(mod): """ mod["main"] = rewrite(PostNMSTopKRewrite(), mod["main"]) return mod + + +def rewrite_scatter(mod, num_scales): + """TODO""" + mod["main"] = rewrite(ScatterRewrite(num_scales), mod["main"]) + return mod From 63659e5e519b317b4293700a70490f7d72068785 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 23 Jan 2021 15:12:39 +0900 Subject: [PATCH 04/11] matching seems to working --- python/tvm/relay/frontend/pytorch_utils.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py index 35378a4aa545..a9908fc2e5dc 100644 --- a/python/tvm/relay/frontend/pytorch_utils.py +++ b/python/tvm/relay/frontend/pytorch_utils.py @@ -281,8 +281,8 @@ def callback(self, pre, post, node_map): def scatter_roi_align_result_pattern(levels, rois, 
per_level_features, scatter_res, num_scales): - def do_where(levels, i): - idx_in_level = is_op("argwhere")(is_op("equal")(levels, i)) + def do_where(levels, _): + idx_in_level = is_op("argwhere")(is_op("equal")(levels, is_constant())) idx_in_level = is_op("split")(idx_in_level) idx_in_level = is_tuple_get_item(idx_in_level, 0) idx_in_level = is_op("squeeze")(idx_in_level) @@ -329,7 +329,11 @@ def __init__(self, num_scales): self.pattern = scatter_roi_align_result_pattern(self.levels, self.rois, self.per_level_features, self.scatter_res, num_scales) def callback(self, pre, post, node_map): - print("matched") + levels = node_map[self.levels][0] + rois = node_map[self.rois][0] + scatter_res = node_map[self.scatter_res][0] + per_level_features = [node_map[feat] for feat in self.per_level_features] + return pre From 60f2f4aa65d1c63cef72a4094848c553bd4f5172 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 23 Jan 2021 15:28:17 +0900 Subject: [PATCH 05/11] dup matching fixed --- python/tvm/relay/frontend/pytorch_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py index a9908fc2e5dc..9731526674a0 100644 --- a/python/tvm/relay/frontend/pytorch_utils.py +++ b/python/tvm/relay/frontend/pytorch_utils.py @@ -313,7 +313,7 @@ def do_where(levels, _): # res = res.scatter(0, index, unmerged_results[level]) scatter_res = is_op("scatter")(scatter_res, scatter_indices, result_idx_in_level) - return scatter_res + return is_op("reshape")(scatter_res) class ScatterRewrite(DFPatternCallback): @@ -326,14 +326,15 @@ def __init__(self, num_scales): for _ in range(num_scales): self.per_level_features.append(wildcard()) - self.pattern = scatter_roi_align_result_pattern(self.levels, self.rois, self.per_level_features, self.scatter_res, num_scales) + self.pattern = scatter_roi_align_result_pattern( + self.levels, self.rois, self.per_level_features, self.scatter_res, 
num_scales + ) def callback(self, pre, post, node_map): levels = node_map[self.levels][0] rois = node_map[self.rois][0] scatter_res = node_map[self.scatter_res][0] per_level_features = [node_map[feat] for feat in self.per_level_features] - return pre From 2e07a2fa19c3e3f84d0c7de19485b6417a7eb6c8 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 23 Jan 2021 15:57:30 +0900 Subject: [PATCH 06/11] add converter --- python/tvm/relay/frontend/pytorch_utils.py | 40 +++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py index 9731526674a0..7a3f0c4e7c7d 100644 --- a/python/tvm/relay/frontend/pytorch_utils.py +++ b/python/tvm/relay/frontend/pytorch_utils.py @@ -281,6 +281,7 @@ def callback(self, pre, post, node_map): def scatter_roi_align_result_pattern(levels, rois, per_level_features, scatter_res, num_scales): + """TODO""" def do_where(levels, _): idx_in_level = is_op("argwhere")(is_op("equal")(levels, is_constant())) idx_in_level = is_op("split")(idx_in_level) @@ -319,10 +320,12 @@ def do_where(levels, _): class ScatterRewrite(DFPatternCallback): def __init__(self, num_scales): super().__init__() + self.num_scales = num_scales self.levels = wildcard() self.rois = wildcard() self.scatter_res = wildcard() self.per_level_features = [] + self.roi_align_attrs = [] for _ in range(num_scales): self.per_level_features.append(wildcard()) @@ -330,11 +333,46 @@ def __init__(self, num_scales): self.levels, self.rois, self.per_level_features, self.scatter_res, num_scales ) + def convert_scatter_to_gather(self, levels, rois, per_level_features): + indices_per_level = [] + roi_align_results_per_level = [] + for i in range(self.num_scales): + """ + %1269 = equal(%1268, i); + %1270 = argwhere(%1269); + %1271 = split(%1270, indices_or_sections=1, axis=1); + %1272 = %1271.0; + %1273 = squeeze(%1272, axis=[1]); + %1274 = (%1273,); + %1275 = %1274.0; + %1276 = 
cast(%1275, dtype="int64"); + %1277 = (%1214, %1276); + %1278 = adv_index(%1277); + %1279 = vision.roi_align(%763, %1278, ...) + """ + equal = op.equal(levels, i) + argwhere = op.argwhere(equal) + split = op.split(argwhere, indices_or_sections=1, axis=1) + squeeze = op.squeeze(split[0], axis=[1]); + indices = op.cast(squeeze, dtype="int64") + rois = op.adv_index([rois, indices]) + roi_align = op.vision.roi_align(per_level_features[i], rois) + + indices_per_level.append(indices) + roi_align_results_per_level.append(roi_align) + + indices_concat = op.concatenate(indices_per_level) + roi_align_results_concat = op.concatenate(roi_align_results_per_level) + argsort = op.cast(op.argsort(indices_concat), dtype="int64") + return op.adv_index([roi_align_results_concat, argsort]) + + def callback(self, pre, post, node_map): levels = node_map[self.levels][0] rois = node_map[self.rois][0] - scatter_res = node_map[self.scatter_res][0] per_level_features = [node_map[feat] for feat in self.per_level_features] + print("matched") + # return self.convert_scatter_to_gather(levels, rois, per_level_features) return pre From 32f9a5e6846a989b13ac40c85e67a14e6b259259 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 23 Jan 2021 16:30:19 +0900 Subject: [PATCH 07/11] conversion seems working --- python/tvm/relay/frontend/pytorch_utils.py | 71 +++++++--------------- 1 file changed, 23 insertions(+), 48 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py index 7a3f0c4e7c7d..7ce31b08f777 100644 --- a/python/tvm/relay/frontend/pytorch_utils.py +++ b/python/tvm/relay/frontend/pytorch_utils.py @@ -280,8 +280,9 @@ def callback(self, pre, post, node_map): ) -def scatter_roi_align_result_pattern(levels, rois, per_level_features, scatter_res, num_scales): +def scatter_roi_align_result_pattern(levels, roi_align_results, num_scales): """TODO""" + def do_where(levels, _): idx_in_level = is_op("argwhere")(is_op("equal")(levels, 
is_constant())) idx_in_level = is_op("split")(idx_in_level) @@ -290,6 +291,8 @@ def do_where(levels, _): idx_in_level = is_tuple_get_item(is_tuple([idx_in_level]), 0) return idx_in_level + scatter_res = wildcard() + for i in range(num_scales): # index = torch.where(levels == level)[0].view(-1, 1, 1, 1) scatter_indices = do_where(levels, i) @@ -303,16 +306,7 @@ def do_where(levels, _): scatter_indices = is_op("repeat")(scatter_indices) scatter_indices = is_op("repeat")(scatter_indices) - # idx_in_level = torch.where(levels == level)[0] - idx_in_level = is_op("cast")(do_where(levels, i)) - - # rois_per_level = rois[idx_in_level] - rois_per_level = is_op("adv_index")(is_tuple([rois, idx_in_level])) - # result_idx_in_level = roi_align(rois_per_level, ...) - result_idx_in_level = is_op("vision.roi_align")(per_level_features[i], rois_per_level) - - # res = res.scatter(0, index, unmerged_results[level]) - scatter_res = is_op("scatter")(scatter_res, scatter_indices, result_idx_in_level) + scatter_res = is_op("scatter")(scatter_res, scatter_indices, roi_align_results[i]) return is_op("reshape")(scatter_res) @@ -322,58 +316,39 @@ def __init__(self, num_scales): super().__init__() self.num_scales = num_scales self.levels = wildcard() - self.rois = wildcard() - self.scatter_res = wildcard() - self.per_level_features = [] - self.roi_align_attrs = [] + self.roi_align_results = [] for _ in range(num_scales): - self.per_level_features.append(wildcard()) + self.roi_align_results.append(wildcard()) self.pattern = scatter_roi_align_result_pattern( - self.levels, self.rois, self.per_level_features, self.scatter_res, num_scales + self.levels, self.roi_align_results, num_scales ) - def convert_scatter_to_gather(self, levels, rois, per_level_features): + def convert_scatter_to_gather(self, levels, roi_align_results): + # Collect inidices and concat them indices_per_level = [] - roi_align_results_per_level = [] for i in range(self.num_scales): - """ - %1269 = equal(%1268, i); - %1270 = 
argwhere(%1269); - %1271 = split(%1270, indices_or_sections=1, axis=1); - %1272 = %1271.0; - %1273 = squeeze(%1272, axis=[1]); - %1274 = (%1273,); - %1275 = %1274.0; - %1276 = cast(%1275, dtype="int64"); - %1277 = (%1214, %1276); - %1278 = adv_index(%1277); - %1279 = vision.roi_align(%763, %1278, ...) - """ - equal = op.equal(levels, i) + equal = op.equal(levels, expr.const(i, dtype="int64")) argwhere = op.argwhere(equal) split = op.split(argwhere, indices_or_sections=1, axis=1) - squeeze = op.squeeze(split[0], axis=[1]); + squeeze = op.squeeze(split[0], axis=[1]) indices = op.cast(squeeze, dtype="int64") - rois = op.adv_index([rois, indices]) - roi_align = op.vision.roi_align(per_level_features[i], rois) - indices_per_level.append(indices) - roi_align_results_per_level.append(roi_align) - indices_concat = op.concatenate(indices_per_level) - roi_align_results_concat = op.concatenate(roi_align_results_per_level) - argsort = op.cast(op.argsort(indices_concat), dtype="int64") - return op.adv_index([roi_align_results_concat, argsort]) + indices_concat = op.concatenate(indices_per_level, 0) + + # Concat roi align results per level, and argsort indices + # To prepare for a batched gather + roi_align_results_concat = op.concatenate(roi_align_results, 0) + argsort_indices = op.cast(op.argsort(indices_concat), dtype="int64") + # Permute rows by argsorted indices + return op.adv_index([roi_align_results_concat, argsort_indices]) def callback(self, pre, post, node_map): levels = node_map[self.levels][0] - rois = node_map[self.rois][0] - per_level_features = [node_map[feat] for feat in self.per_level_features] - print("matched") - # return self.convert_scatter_to_gather(levels, rois, per_level_features) - return pre + roi_align_results = [node_map[feat][0] for feat in self.roi_align_results] + return self.convert_scatter_to_gather(levels, roi_align_results) def rewrite_nms_to_batched_nms(mod): @@ -393,7 +368,7 @@ def rewrite_batched_nms_with_max_out_size(mod): return mod -def 
rewrite_scatter(mod, num_scales): +def rewrite_scatter_to_gather(mod, num_scales): """TODO""" mod["main"] = rewrite(ScatterRewrite(num_scales), mod["main"]) return mod From 1ebbcde532951516f64d3786ab10ac6a64b6fae0 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 23 Jan 2021 16:46:57 +0900 Subject: [PATCH 08/11] add reshape, use take --- python/tvm/relay/frontend/pytorch_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py index 7ce31b08f777..21a51971d695 100644 --- a/python/tvm/relay/frontend/pytorch_utils.py +++ b/python/tvm/relay/frontend/pytorch_utils.py @@ -343,7 +343,8 @@ def convert_scatter_to_gather(self, levels, roi_align_results): argsort_indices = op.cast(op.argsort(indices_concat), dtype="int64") # Permute rows by argsorted indices - return op.adv_index([roi_align_results_concat, argsort_indices]) + permuted = op.take(roi_align_results_concat, argsort_indices, axis=0) + return op.reshape(permuted, [0, -1, 1, 1]) def callback(self, pre, post, node_map): levels = node_map[self.levels][0] From 2e69e5a41b99384c10601d47feed778e6f20f291 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 23 Jan 2021 16:50:19 +0900 Subject: [PATCH 09/11] remove pytorch argsort converter --- python/tvm/relay/frontend/pytorch.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 3c835fe3e07a..991e3a8a0032 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2081,12 +2081,6 @@ def is_floating_point(self, inputs, input_types): is_float = input_type in ["float32", "float64", "float16", "bfloat16"] return _expr.const(is_float) - def argsort(self, inputs, input_types): - data = inputs[0] - dim = inputs[1] - is_descending = inputs[2] - return _op.argsort(data, dim, not is_descending) - # Operator mappings def 
create_convert_map(self): self.convert_map = { From 21e6c24f6dc6b0fcfa375cfefec51d0b2090de13 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 23 Jan 2021 16:53:24 +0900 Subject: [PATCH 10/11] update test --- tests/python/frontend/pytorch/test_object_detection.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/python/frontend/pytorch/test_object_detection.py b/tests/python/frontend/pytorch/test_object_detection.py index 6273fbbd60b2..fd33dd1da8b1 100644 --- a/tests/python/frontend/pytorch/test_object_detection.py +++ b/tests/python/frontend/pytorch/test_object_detection.py @@ -29,6 +29,7 @@ from tvm.relay.frontend.pytorch_utils import ( rewrite_nms_to_batched_nms, rewrite_batched_nms_with_max_out_size, + rewrite_scatter_to_gather, ) from tvm.contrib.download import download @@ -149,6 +150,11 @@ def compile_and_run_vm(mod, params, data_np, target): after = mod["main"] assert not tvm.ir.structural_equal(after, before) + before = mod["main"] + mod = rewrite_scatter_to_gather(mod, 4) # num_scales is 4 for maskrcnn_resnet50_fpn + after = mod["main"] + assert not tvm.ir.structural_equal(after, before) + tvm_res_after_rewrite = compile_and_run_vm(mod, params, data_np, "llvm") # Results should be equivalent after rewriting From 5a13030c736763a0fbe82f6aa6083d5bcd93bf76 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 25 Jan 2021 15:55:30 +0900 Subject: [PATCH 11/11] add doc --- python/tvm/relay/frontend/pytorch_utils.py | 39 ++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py index 21a51971d695..248f5354cfbb 100644 --- a/python/tvm/relay/frontend/pytorch_utils.py +++ b/python/tvm/relay/frontend/pytorch_utils.py @@ -281,7 +281,22 @@ def callback(self, pre, post, node_map): def scatter_roi_align_result_pattern(levels, roi_align_results, num_scales): - """TODO""" + """Detect the Relay subgraph corresponding to the following 
PyTorch code
+
+    first_result = roi_align_results[0]
+    dtype, device = first_result.dtype, first_result.device
+    res = torch.zeros((levels.size(0), first_result.size(1),
+                       first_result.size(2), first_result.size(3)),
+                      dtype=dtype, device=device)
+    for level in range(len(roi_align_results)):
+        index = torch.where(levels == level)[0].view(-1, 1, 1, 1)
+        index = index.expand(index.size(0),
+                             roi_align_results[level].size(1),
+                             roi_align_results[level].size(2),
+                             roi_align_results[level].size(3))
+        res = res.scatter(0, index, roi_align_results[level])
+    return res
+    """

     def do_where(levels, _):
         idx_in_level = is_op("argwhere")(is_op("equal")(levels, is_constant()))
@@ -312,6 +327,8 @@ def do_where(levels, _):


 class ScatterRewrite(DFPatternCallback):
+    """A callback to rewrite repeated scatters with a batched gather."""
+
     def __init__(self, num_scales):
         super().__init__()
         self.num_scales = num_scales
@@ -325,6 +342,19 @@ def __init__(self, num_scales):
         )

     def convert_scatter_to_gather(self, levels, roi_align_results):
+        """Replace the detected scatter loop with the following PyTorch code
+
+        indices_per_level = []
+        for level in range(num_scales):
+            idx_in_level = torch.where(levels == level)[0]
+            indices_per_level.append(idx_in_level)
+
+        stacked_features = torch.cat(roi_align_results, dim=0)
+        stacked_indices = torch.cat(indices_per_level, dim=0)
+        argsort_indices = torch.argsort(stacked_indices)
+        return stacked_features[argsort_indices, :]
+        """
+
         # Collect inidices and concat them
         indices_per_level = []
         for i in range(self.num_scales):
@@ -344,6 +374,7 @@ def convert_scatter_to_gather(self, levels, roi_align_results):

         # Permute rows by argsorted indices
         permuted = op.take(roi_align_results_concat, argsort_indices, axis=0)
+
         return op.reshape(permuted, [0, -1, 1, 1])

     def callback(self, pre, post, node_map):
@@ -370,6 +401,10 @@ def rewrite_batched_nms_with_max_out_size(mod):


 def rewrite_scatter_to_gather(mod, num_scales):
-    """TODO"""
+    """Rewrite the input graph to
replace a repeated scatter loop with + a batched gather. The scatter loop is used in torchvision MultiScaleRoIAlign + to merge roi_align results for all scales. The scatter is used to emulate + inplace updates. + """ mod["main"] = rewrite(ScatterRewrite(num_scales), mod["main"]) return mod