Skip to content

Commit

Permalink
[onnx] Update the importer to create a none for missing operands (l…
Browse files Browse the repository at this point in the history
…lvm#2931)

Some operands are optional so we require a placeholder for missing
operands. We invent an `onnx.None` operation as our placeholder.
  • Loading branch information
rsuderman authored Feb 20, 2024
1 parent 4446fa0 commit 13553d4
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 61 deletions.
13 changes: 3 additions & 10 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2184,6 +2184,7 @@
"ElementwiseUnsqueezeNegDimsModule_basic",
"ElementwiseWhereScalarModule_basic",
"FlattenDynamicModule_basic",
"FlipModule_basic",
"FlipModuleStaticShape_basic",
"GluStaticModule_basic",
"MaskedFillTensorFloatValueModule_basic",
Expand All @@ -2193,17 +2194,9 @@
"ReduceMinAlongDimUnsignedInt_basic",
"TensorsStackNegativeDimModule_basic",
"TensorsStackPromoteDTypeModule_basic",
}

ONNX_CRASHING_SET = {
"FlipModule_basic",
"IndexTensorNegativeIndexModule_basic",
"MoveDimIntNegativeIndexModule_basic",
"PermuteNegativeIndexModule_basic",
"RollModule_basic",
"SliceModule_basic",
"SliceNegIdxModule_basic",
"SliceOutOfLowerBoundEndIndexModule_basic",
"SliceOutOfLowerBoundStartIndexModule_basic",
"SliceSizeTwoStepModule_basic",
}

ONNX_CRASHING_SET = { }
20 changes: 19 additions & 1 deletion python/torch_mlir/extras/onnx_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,8 @@ def import_all(self):
# much unused crap.
for init in self._gi.initializer_map.values():
self.import_initializer(init)

self.get_none()
for node in self._gi.graph_proto.node:
self.import_node(node)

Expand All @@ -272,6 +274,20 @@ def import_all(self):
with InsertionPoint(self._b), Location.unknown():
func_dialect.ReturnOp(outputs)

def get_none(self):
if '' in self._nv_map:
return self._nv_map['']

with InsertionPoint(self._b), Location.name("onnx_importer.none"):
nne = Operation.create(
name="torch.constant.none",
results=[self._cc.get_none_type()],
operands=[],
attributes={},
).results[0]
self._nv_map[''] = nne
return nne

def import_node(self, node: onnx.NodeProto):
with InsertionPoint(self._b), Location.name(node.name):
op_type = node.op_type
Expand All @@ -283,7 +299,6 @@ def import_node(self, node: onnx.NodeProto):
was_handled = getattr(self, special_key)(node)
if was_handled:
return

# General node import.
input_values = []
for input_name in node.input:
Expand Down Expand Up @@ -449,6 +464,9 @@ def tensor_element_type(self, elem_type: int) -> IrType:
self._elem_type_map[elem_type] = t
return t

def get_none_type(self):
return IrType.parse("!torch.none", context=self._c)

def get_vtensor_type(
self, dims: tuple[Optional[int]], element_type: IrType
) -> IrType:
Expand Down
53 changes: 3 additions & 50 deletions test/python/onnx_importer/import_smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,22 +102,12 @@
"node_test_castlike_FLOAT_to_STRING_model",
"node_test_castlike_STRING_to_FLOAT_expanded_model",
"node_test_castlike_STRING_to_FLOAT_model",
"node_test_center_crop_pad_crop_axes_chw_expanded_model",
"node_test_center_crop_pad_crop_axes_hwc_expanded_model",
"node_test_center_crop_pad_crop_negative_axes_hwc_expanded_model",
"node_test_clip_default_inbounds_model",
"node_test_clip_default_int8_inbounds_model",
"node_test_clip_default_int8_max_model",
"node_test_clip_default_max_model",
"node_test_constantofshape_float_ones_model",
"node_test_constantofshape_int_shape_zero_model",
"node_test_constantofshape_int_zeros_model",
"node_test_dequantizelinear_e4m3fn_model",
"node_test_dequantizelinear_e4m3fn_zero_point_model",
"node_test_dequantizelinear_e5m2_model",
"node_test_dft_axis_model",
"node_test_dft_inverse_model",
"node_test_dft_model",
"node_test_equal_string_broadcast_model",
"node_test_equal_string_model",
"node_test_gru_defaults_model",
Expand Down Expand Up @@ -175,8 +165,6 @@
"node_test_optional_get_element_optional_sequence_model",
"node_test_optional_get_element_optional_tensor_model",
"node_test_optional_get_element_sequence_model",
"node_test_optional_has_element_empty_no_input_name_optional_input_model",
"node_test_optional_has_element_empty_no_input_name_tensor_input_model",
"node_test_optional_has_element_empty_optional_input_model",
"node_test_optional_has_element_optional_input_model",
"node_test_optional_has_element_tensor_input_model",
Expand All @@ -187,43 +175,6 @@
"node_test_regex_full_match_basic_model",
"node_test_regex_full_match_email_domain_model",
"node_test_regex_full_match_empty_model",
"node_test_resize_downsample_scales_cubic_A_n0p5_exclude_outside_model",
"node_test_resize_downsample_scales_cubic_align_corners_model",
"node_test_resize_downsample_scales_cubic_antialias_model",
"node_test_resize_downsample_scales_cubic_model",
"node_test_resize_downsample_scales_linear_align_corners_model",
"node_test_resize_downsample_scales_linear_antialias_model",
"node_test_resize_downsample_scales_linear_half_pixel_symmetric_model",
"node_test_resize_downsample_scales_linear_model",
"node_test_resize_downsample_scales_nearest_model",
"node_test_resize_downsample_sizes_cubic_antialias_model",
"node_test_resize_downsample_sizes_cubic_model",
"node_test_resize_downsample_sizes_linear_antialias_model",
"node_test_resize_downsample_sizes_linear_pytorch_half_pixel_model",
"node_test_resize_downsample_sizes_nearest_model",
"node_test_resize_downsample_sizes_nearest_not_larger_model",
"node_test_resize_downsample_sizes_nearest_not_smaller_model",
"node_test_resize_tf_crop_and_resize_axes_2_3_model",
"node_test_resize_tf_crop_and_resize_axes_3_2_model",
"node_test_resize_tf_crop_and_resize_model",
"node_test_resize_upsample_scales_cubic_A_n0p5_exclude_outside_model",
"node_test_resize_upsample_scales_cubic_align_corners_model",
"node_test_resize_upsample_scales_cubic_asymmetric_model",
"node_test_resize_upsample_scales_cubic_model",
"node_test_resize_upsample_scales_linear_align_corners_model",
"node_test_resize_upsample_scales_linear_half_pixel_symmetric_model",
"node_test_resize_upsample_scales_linear_model",
"node_test_resize_upsample_scales_nearest_axes_2_3_model",
"node_test_resize_upsample_scales_nearest_axes_3_2_model",
"node_test_resize_upsample_scales_nearest_model",
"node_test_resize_upsample_sizes_cubic_model",
"node_test_resize_upsample_sizes_nearest_axes_2_3_model",
"node_test_resize_upsample_sizes_nearest_axes_3_2_model",
"node_test_resize_upsample_sizes_nearest_ceil_half_pixel_model",
"node_test_resize_upsample_sizes_nearest_floor_align_corners_model",
"node_test_resize_upsample_sizes_nearest_model",
"node_test_resize_upsample_sizes_nearest_not_larger_model",
"node_test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric_model",
"node_test_rnn_seq_length_model",
"node_test_scan9_sum_model",
"node_test_scan_sum_model",
Expand All @@ -246,7 +197,6 @@
"node_test_split_to_sequence_1_model",
"node_test_split_to_sequence_2_model",
"node_test_split_to_sequence_nokeepdims_model",
"node_test_stft_model",
"node_test_string_concat_broadcasting_model",
"node_test_string_concat_empty_string_model",
"node_test_string_concat_model",
Expand Down Expand Up @@ -281,6 +231,9 @@
]





class ImportSmokeTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
Expand Down

0 comments on commit 13553d4

Please sign in to comment.