From acf637ddf2d58029ab48d59d657a4fdbf6809568 Mon Sep 17 00:00:00 2001 From: pinto0309 Date: Mon, 2 Jan 2023 14:42:02 +0900 Subject: [PATCH] Support for merging models with custom domains --- snc4onnx/__init__.py | 2 +- snc4onnx/onnx_network_combine.py | 116 ++++++++++++++++++++++++++----- 2 files changed, 101 insertions(+), 17 deletions(-) diff --git a/snc4onnx/__init__.py b/snc4onnx/__init__.py index d351872..bc83dd4 100644 --- a/snc4onnx/__init__.py +++ b/snc4onnx/__init__.py @@ -1,3 +1,3 @@ from snc4onnx.onnx_network_combine import combine, main -__version__ = '1.0.9' +__version__ = '1.0.10' diff --git a/snc4onnx/onnx_network_combine.py b/snc4onnx/onnx_network_combine.py index c0875b6..4267ec3 100644 --- a/snc4onnx/onnx_network_combine.py +++ b/snc4onnx/onnx_network_combine.py @@ -1,8 +1,8 @@ #! /usr/bin/env python -import sys import os import re +import sys import traceback import collections import itertools @@ -12,6 +12,7 @@ from onnxsim import simplify from typing import Optional, List + class Color: BLACK = '\033[30m' RED = '\033[31m' @@ -37,6 +38,12 @@ class Color: BG_DEFAULT = '\033[49m' RESET = '\033[0m' +ONNX_STANDARD_DOMAINS = [ + 'ai.onnx', + 'ai.onnx.ml', + '', +] + def combine( srcop_destop: List[str], @@ -206,7 +213,8 @@ def has_duplicates(seq): # MODEL_INDX print - only input_onnx_file_paths if len(onnx_graphs) == 0: - for idx, (input_onnx_file_path, op_prefix_after_merging) in enumerate(itertools.zip_longest(input_onnx_file_paths, op_prefixes_after_merging)): + for idx, (input_onnx_file_path, op_prefix_after_merging) in \ + enumerate(itertools.zip_longest(input_onnx_file_paths, op_prefixes_after_merging)): if not non_verbose: print( f'{Color.GREEN}INFO:{Color.RESET} '+ @@ -216,20 +224,33 @@ def has_duplicates(seq): # Combine ## 1. ONNX load tmp_onnx_graphs = [] + custom_domain_check_onnx_nodes = [] if len(onnx_graphs) > 0: for onnx_graph in onnx_graphs: gs_graph = gs.import_onnx(onnx_graph) gs_graph.cleanup().toposort() tmp_onnx_graphs.append(gs.export_onnx(gs_graph)) + custom_domain_check_onnx_nodes = \ + custom_domain_check_onnx_nodes + \ + [ + node for node in onnx_graph.graph.node \ + if node.domain not in ONNX_STANDARD_DOMAINS + ] else: for onnx_path in input_onnx_file_paths: gs_graph = gs.import_onnx(onnx.load(onnx_path)) gs_graph.cleanup().toposort() tmp_onnx_graphs.append(gs.export_onnx(gs_graph)) + custom_domain_check_onnx_graph = onnx.load(onnx_path) + custom_domain_check_onnx_nodes = \ + custom_domain_check_onnx_nodes + \ + [ + node for node in custom_domain_check_onnx_graph.graph.node \ + if node.domain not in ONNX_STANDARD_DOMAINS + ] ## 2. Repeat Merge for model_idx in range(0, len(tmp_onnx_graphs) - 1): - src_prefix = '' dest_prefix = '' @@ -259,15 +280,57 @@ def has_duplicates(seq): src_gs_model = gs.import_onnx(src_model) dest_gs_model = gs.import_onnx(dest_model) + # Merging Domain Lists + src_gs_model_domains: List[onnx.OperatorSetIdProto] = src_gs_model.import_domains + dest_gs_model_domains: List[onnx.OperatorSetIdProto] = dest_gs_model.import_domains + merged_gs_model_domains: List[onnx.OperatorSetIdProto] = src_gs_model_domains + distinct_dest_gs_model_domains = [ + domain \ + for domain in dest_gs_model_domains \ + if domain not in merged_gs_model_domains + ] + for domain in distinct_dest_gs_model_domains: + merged_gs_model_domains.append(domain) + src_gs_model.import_domains = merged_gs_model_domains + + # Check if Graph contains a custom domain (custom module) + contains_custom_domain = len( + [ + domain \ + for domain in src_gs_model.import_domains \ + if domain.domain not in ONNX_STANDARD_DOMAINS + ] + ) > 0 + # Duplicate OP name check - src_node_names = [graph_node.name for graph_node in src_gs_model.nodes] - src_input_names = [graph_input.name for graph_input in src_gs_model.inputs] - src_output_names = [graph_output.name for graph_output in src_gs_model.outputs if graph_output.name not in src_node_names] + src_node_names = [ + graph_node.name \ + for graph_node in src_gs_model.nodes + ] + src_input_names = [ + graph_input.name \ + for graph_input in src_gs_model.inputs + ] + src_output_names = [ + graph_output.name \ + for graph_output in src_gs_model.outputs \ + if graph_output.name not in src_node_names + ] src_model_op_names = src_node_names + src_input_names + src_output_names - dest_node_names = [graph_node.name for graph_node in dest_gs_model.nodes] - dest_input_names = [graph_input.name for graph_input in dest_gs_model.inputs] - dest_output_names = [graph_output.name for graph_output in dest_gs_model.outputs if graph_output.name not in dest_node_names] + dest_node_names = [ + graph_node.name \ + for graph_node in dest_gs_model.nodes + ] + dest_input_names = [ + graph_input.name \ + for graph_input in dest_gs_model.inputs + ] + dest_output_names = [ + graph_output.name \ + for graph_output in dest_gs_model.outputs \ + if graph_output.name not in dest_node_names + ] dest_model_op_names = dest_node_names + dest_input_names + dest_output_names merged_model_op_names = src_model_op_names + dest_model_op_names @@ -301,9 +364,18 @@ def has_duplicates(seq): # If the OP specified as srcop in io_map_srcop_destop is a graph INPUT, # use onnx_graphsurgeon to merge # Otherwise, use onnx.compose.merge_models for simple merging - srcop_names = [f'{src_prefix}{srcop_destop_src}' for srcop_destop_src in srcop_destop[model_idx][::2]] - destop_names = [f'{dest_prefix}{srcop_destop_dest}' for srcop_destop_dest in srcop_destop[model_idx][1::2]] - src_gs_model_input_names = [src_gs_model_input.name for src_gs_model_input in src_gs_model.inputs] + srcop_names = [ + f'{src_prefix}{srcop_destop_src}' \ + for srcop_destop_src in srcop_destop[model_idx][::2] + ] + destop_names = [ + f'{dest_prefix}{srcop_destop_dest}' \ + for srcop_destop_dest in srcop_destop[model_idx][1::2] + ] + src_gs_model_input_names = [ + src_gs_model_input.name \ + for src_gs_model_input in src_gs_model.inputs + ] for srcop_name, destop_name in zip(srcop_names, destop_names): # Split processing if srcop_name is included or not included in the graph INPUT @@ -352,7 +424,9 @@ def has_duplicates(seq): if not used_flg: remove_input_names.append(input_name) src_gs_model.inputs = [ - src_gs_model_input for src_gs_model_input in src_gs_model.inputs if src_gs_model_input.name not in remove_input_names + src_gs_model_input \ + for src_gs_model_input in src_gs_model.inputs \ + if src_gs_model_input.name not in remove_input_names ] # Cleaning @@ -409,7 +483,10 @@ def has_duplicates(seq): ## 4. Optimize try: - combined_model, check = simplify(combined_model) + # onnx-simplifier does not support optimization of ONNX files containing custom domains, + # so skip simplify if it contains custom domains + if not contains_custom_domain: + combined_model, check = simplify(combined_model) except Exception as e: if not non_verbose: print( @@ -419,14 +496,21 @@ def has_duplicates(seq): tracetxt = traceback.format_exc().splitlines()[-1] print(f'{Color.YELLOW}WARNING:{Color.RESET} {tracetxt}') - ## 5. Final save + ## 5. Restore a node's custom domain + combined_model_graph_nodes = combined_model.graph.node + for combined_model_graph_node in combined_model_graph_nodes: + for custom_domain_check_onnx_node in custom_domain_check_onnx_nodes: + if combined_model_graph_node.name == custom_domain_check_onnx_node.name: + combined_model_graph_node.domain = custom_domain_check_onnx_node.domain + + ## 6. Final save if output_onnx_file_path: onnx.save(combined_model, output_onnx_file_path) if not non_verbose: print(f'{Color.GREEN}INFO:{Color.RESET} Finish!') - # 6. Return + # 7. Return return combined_model