From 13553d49c9488e09fec6ba790fb095eea66c48ea Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Tue, 20 Feb 2024 09:30:30 -0800 Subject: [PATCH] [onnx] Update the importer to create a `none` for missing operands (#2931) Some operands are optional so we require a placeholder for missing operands. We invent an `onnx.None` operation as our placeholder. --- projects/pt1/e2e_testing/xfail_sets.py | 13 ++--- python/torch_mlir/extras/onnx_importer.py | 20 ++++++- .../python/onnx_importer/import_smoke_test.py | 53 ++----------------- 3 files changed, 25 insertions(+), 61 deletions(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 52e1ea3321b8..632b15e85c74 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2184,6 +2184,7 @@ "ElementwiseUnsqueezeNegDimsModule_basic", "ElementwiseWhereScalarModule_basic", "FlattenDynamicModule_basic", + "FlipModule_basic", "FlipModuleStaticShape_basic", "GluStaticModule_basic", "MaskedFillTensorFloatValueModule_basic", @@ -2193,17 +2194,9 @@ "ReduceMinAlongDimUnsignedInt_basic", "TensorsStackNegativeDimModule_basic", "TensorsStackPromoteDTypeModule_basic", -} -ONNX_CRASHING_SET = { - "FlipModule_basic", - "IndexTensorNegativeIndexModule_basic", "MoveDimIntNegativeIndexModule_basic", "PermuteNegativeIndexModule_basic", - "RollModule_basic", - "SliceModule_basic", - "SliceNegIdxModule_basic", - "SliceOutOfLowerBoundEndIndexModule_basic", - "SliceOutOfLowerBoundStartIndexModule_basic", - "SliceSizeTwoStepModule_basic", } + +ONNX_CRASHING_SET = { } diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py index 24520e9ce970..c62324832520 100644 --- a/python/torch_mlir/extras/onnx_importer.py +++ b/python/torch_mlir/extras/onnx_importer.py @@ -258,6 +258,8 @@ def import_all(self): # much unused crap. for init in self._gi.initializer_map.values(): self.import_initializer(init) + + self.get_none() for node in self._gi.graph_proto.node: self.import_node(node) @@ -272,6 +274,20 @@ def import_all(self): with InsertionPoint(self._b), Location.unknown(): func_dialect.ReturnOp(outputs) + def get_none(self): + if '' in self._nv_map: + return self._nv_map[''] + + with InsertionPoint(self._b), Location.name("onnx_importer.none"): + nne = Operation.create( + name="torch.constant.none", + results=[self._cc.get_none_type()], + operands=[], + attributes={}, + ).results[0] + self._nv_map[''] = nne + return nne + def import_node(self, node: onnx.NodeProto): with InsertionPoint(self._b), Location.name(node.name): op_type = node.op_type @@ -283,7 +299,6 @@ def import_node(self, node: onnx.NodeProto): was_handled = getattr(self, special_key)(node) if was_handled: return - # General node import. input_values = [] for input_name in node.input: @@ -449,6 +464,9 @@ def tensor_element_type(self, elem_type: int) -> IrType: self._elem_type_map[elem_type] = t return t + def get_none_type(self): + return IrType.parse("!torch.none", context=self._c) + def get_vtensor_type( self, dims: tuple[Optional[int]], element_type: IrType ) -> IrType: diff --git a/test/python/onnx_importer/import_smoke_test.py b/test/python/onnx_importer/import_smoke_test.py index f27cc9caf5bd..708324e72db6 100644 --- a/test/python/onnx_importer/import_smoke_test.py +++ b/test/python/onnx_importer/import_smoke_test.py @@ -102,22 +102,12 @@ "node_test_castlike_FLOAT_to_STRING_model", "node_test_castlike_STRING_to_FLOAT_expanded_model", "node_test_castlike_STRING_to_FLOAT_model", - "node_test_center_crop_pad_crop_axes_chw_expanded_model", - "node_test_center_crop_pad_crop_axes_hwc_expanded_model", - "node_test_center_crop_pad_crop_negative_axes_hwc_expanded_model", - "node_test_clip_default_inbounds_model", - "node_test_clip_default_int8_inbounds_model", - "node_test_clip_default_int8_max_model", - "node_test_clip_default_max_model", "node_test_constantofshape_float_ones_model", "node_test_constantofshape_int_shape_zero_model", "node_test_constantofshape_int_zeros_model", "node_test_dequantizelinear_e4m3fn_model", "node_test_dequantizelinear_e4m3fn_zero_point_model", "node_test_dequantizelinear_e5m2_model", - "node_test_dft_axis_model", - "node_test_dft_inverse_model", - "node_test_dft_model", "node_test_equal_string_broadcast_model", "node_test_equal_string_model", "node_test_gru_defaults_model", @@ -175,8 +165,6 @@ "node_test_optional_get_element_optional_sequence_model", "node_test_optional_get_element_optional_tensor_model", "node_test_optional_get_element_sequence_model", - "node_test_optional_has_element_empty_no_input_name_optional_input_model", - "node_test_optional_has_element_empty_no_input_name_tensor_input_model", "node_test_optional_has_element_empty_optional_input_model", "node_test_optional_has_element_optional_input_model", "node_test_optional_has_element_tensor_input_model", @@ -187,43 +175,6 @@ "node_test_regex_full_match_basic_model", "node_test_regex_full_match_email_domain_model", "node_test_regex_full_match_empty_model", - "node_test_resize_downsample_scales_cubic_A_n0p5_exclude_outside_model", - "node_test_resize_downsample_scales_cubic_align_corners_model", - "node_test_resize_downsample_scales_cubic_antialias_model", - "node_test_resize_downsample_scales_cubic_model", - "node_test_resize_downsample_scales_linear_align_corners_model", - "node_test_resize_downsample_scales_linear_antialias_model", - "node_test_resize_downsample_scales_linear_half_pixel_symmetric_model", - "node_test_resize_downsample_scales_linear_model", - "node_test_resize_downsample_scales_nearest_model", - "node_test_resize_downsample_sizes_cubic_antialias_model", - "node_test_resize_downsample_sizes_cubic_model", - "node_test_resize_downsample_sizes_linear_antialias_model", - "node_test_resize_downsample_sizes_linear_pytorch_half_pixel_model", - "node_test_resize_downsample_sizes_nearest_model", - "node_test_resize_downsample_sizes_nearest_not_larger_model", - "node_test_resize_downsample_sizes_nearest_not_smaller_model", - "node_test_resize_tf_crop_and_resize_axes_2_3_model", - "node_test_resize_tf_crop_and_resize_axes_3_2_model", - "node_test_resize_tf_crop_and_resize_model", - "node_test_resize_upsample_scales_cubic_A_n0p5_exclude_outside_model", - "node_test_resize_upsample_scales_cubic_align_corners_model", - "node_test_resize_upsample_scales_cubic_asymmetric_model", - "node_test_resize_upsample_scales_cubic_model", - "node_test_resize_upsample_scales_linear_align_corners_model", - "node_test_resize_upsample_scales_linear_half_pixel_symmetric_model", - "node_test_resize_upsample_scales_linear_model", - "node_test_resize_upsample_scales_nearest_axes_2_3_model", - "node_test_resize_upsample_scales_nearest_axes_3_2_model", - "node_test_resize_upsample_scales_nearest_model", - "node_test_resize_upsample_sizes_cubic_model", - "node_test_resize_upsample_sizes_nearest_axes_2_3_model", - "node_test_resize_upsample_sizes_nearest_axes_3_2_model", - "node_test_resize_upsample_sizes_nearest_ceil_half_pixel_model", - "node_test_resize_upsample_sizes_nearest_floor_align_corners_model", - "node_test_resize_upsample_sizes_nearest_model", - "node_test_resize_upsample_sizes_nearest_not_larger_model", - "node_test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric_model", "node_test_rnn_seq_length_model", "node_test_scan9_sum_model", "node_test_scan_sum_model", @@ -246,7 +197,6 @@ "node_test_split_to_sequence_1_model", "node_test_split_to_sequence_2_model", "node_test_split_to_sequence_nokeepdims_model", - "node_test_stft_model", "node_test_string_concat_broadcasting_model", "node_test_string_concat_empty_string_model", "node_test_string_concat_model", @@ -281,6 +231,9 @@ ] + + + class ImportSmokeTest(unittest.TestCase): @classmethod def setUpClass(cls):