diff --git a/sam4onnx/__init__.py b/sam4onnx/__init__.py index 367383e..534591c 100644 --- a/sam4onnx/__init__.py +++ b/sam4onnx/__init__.py @@ -1,3 +1,3 @@ from sam4onnx.onnx_attr_const_modify import modify, main -__version__ = '1.0.15' +__version__ = '1.0.16' diff --git a/sam4onnx/onnx_attr_const_modify.py b/sam4onnx/onnx_attr_const_modify.py index 70c30da..a32c87f 100644 --- a/sam4onnx/onnx_attr_const_modify.py +++ b/sam4onnx/onnx_attr_const_modify.py @@ -64,13 +64,6 @@ class Color: 'complex128': np.complex128, } -ONNX_STANDARD_DOMAINS = [ - 'ai.onnx', - 'ai.onnx.ml', - '', -] - - def __subgraph_node_search( search_graph: gs.Graph, op_name: str, @@ -292,30 +285,16 @@ def modify( if not onnx_graph: onnx_graph = onnx.load(input_onnx_file_path) - # Acquisition of Node with custom domain - custom_domain_check_onnx_nodes = [] - custom_domain_check_onnx_nodes = \ - custom_domain_check_onnx_nodes + \ - [ - node for node in onnx_graph.graph.node \ - if node.domain not in ONNX_STANDARD_DOMAINS - ] - # domain, ir_version domain: str = onnx_graph.domain ir_version: int = onnx_graph.ir_version + meta_data = {'domain': domain, 'ir_version': ir_version} + metadata_props = None + if hasattr(onnx_graph, 'metadata_props'): + metadata_props = onnx_graph.metadata_props graph = gs.import_onnx(onnx_graph) - # Check if Graph contains a custom domain (custom module) - contains_custom_domain = len( - [ - domain \ - for domain in graph.import_domains \ - if domain.domain not in ONNX_STANDARD_DOMAINS - ] - ) > 0 - # Search for OPs matching op_name node_subject_to_change = None if op_name: @@ -440,7 +419,9 @@ def modify( # Cleanup graph.cleanup().toposort() - modified_graph = gs.export_onnx(graph, do_type_check=False, **{'domain': domain, 'ir_version': ir_version}) + modified_graph = gs.export_onnx(graph, do_type_check=False, **meta_data) + if metadata_props is not None: + modified_graph.metadata_props.extend(metadata_props) # Optimize new_model = None @@ -457,14 +438,6 @@ def modify( tracetxt = traceback.format_exc().splitlines()[-1] print(f'{Color.YELLOW}WARNING:{Color.RESET} {tracetxt}') - ## Restore a node's custom domain - if contains_custom_domain: - new_model_nodes = new_model.graph.node - for new_model_node in new_model_nodes: - for custom_domain_check_onnx_node in custom_domain_check_onnx_nodes: - if new_model_node.name == custom_domain_check_onnx_node.name: - new_model_node.domain = custom_domain_check_onnx_node.domain - # Save if output_onnx_file_path: onnx.save(new_model, f'{output_onnx_file_path}')