Skip to content

Commit

Permalink
Support for merging models with custom domains
Browse files Browse the repository at this point in the history
  • Loading branch information
PINTO0309 committed Jan 2, 2023
1 parent 4140f30 commit acf637d
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 17 deletions.
2 changes: 1 addition & 1 deletion snc4onnx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from snc4onnx.onnx_network_combine import combine, main

__version__ = '1.0.9'
__version__ = '1.0.10'
116 changes: 100 additions & 16 deletions snc4onnx/onnx_network_combine.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#! /usr/bin/env python

import sys
import os
import re
import sys
import traceback
import collections
import itertools
Expand All @@ -12,6 +12,7 @@
from onnxsim import simplify
from typing import Optional, List


class Color:
BLACK = '\033[30m'
RED = '\033[31m'
Expand All @@ -37,6 +38,12 @@ class Color:
BG_DEFAULT = '\033[49m'
RESET = '\033[0m'

ONNX_STANDARD_DOMAINS = [
'ai.onnx',
'ai.onnx.ml',
'',
]


def combine(
srcop_destop: List[str],
Expand Down Expand Up @@ -206,7 +213,8 @@ def has_duplicates(seq):

# MODEL_INDX print - only input_onnx_file_paths
if len(onnx_graphs) == 0:
for idx, (input_onnx_file_path, op_prefix_after_merging) in enumerate(itertools.zip_longest(input_onnx_file_paths, op_prefixes_after_merging)):
for idx, (input_onnx_file_path, op_prefix_after_merging) in \
enumerate(itertools.zip_longest(input_onnx_file_paths, op_prefixes_after_merging)):
if not non_verbose:
print(
f'{Color.GREEN}INFO:{Color.RESET} '+
Expand All @@ -216,20 +224,33 @@ def has_duplicates(seq):
# Combine
## 1. ONNX load
tmp_onnx_graphs = []
custom_domain_check_onnx_nodes = []
if len(onnx_graphs) > 0:
for onnx_graph in onnx_graphs:
gs_graph = gs.import_onnx(onnx_graph)
gs_graph.cleanup().toposort()
tmp_onnx_graphs.append(gs.export_onnx(gs_graph))
custom_domain_check_onnx_nodes = \
custom_domain_check_onnx_nodes + \
[
node for node in onnx_graph.graph.node \
if node.domain not in ONNX_STANDARD_DOMAINS
]
else:
for onnx_path in input_onnx_file_paths:
gs_graph = gs.import_onnx(onnx.load(onnx_path))
gs_graph.cleanup().toposort()
tmp_onnx_graphs.append(gs.export_onnx(gs_graph))
custom_domain_check_onnx_graph = onnx.load(onnx_path)
custom_domain_check_onnx_nodes = \
custom_domain_check_onnx_nodes + \
[
node for node in custom_domain_check_onnx_graph.graph.node \
if node.domain not in ONNX_STANDARD_DOMAINS
]

## 2. Repeat Merge
for model_idx in range(0, len(tmp_onnx_graphs) - 1):

src_prefix = ''
dest_prefix = ''

Expand Down Expand Up @@ -259,15 +280,57 @@ def has_duplicates(seq):
src_gs_model = gs.import_onnx(src_model)
dest_gs_model = gs.import_onnx(dest_model)

# Merging Domain Lists
src_gs_model_domains: List[onnx.OperatorSetIdProto] = src_gs_model.import_domains
dest_gs_model_domains: List[onnx.OperatorSetIdProto] = dest_gs_model.import_domains
merged_gs_model_domains: List[onnx.OperatorSetIdProto] = src_gs_model_domains
distinct_dest_gs_model_domains = [
domain \
for domain in dest_gs_model_domains \
if domain not in merged_gs_model_domains
]
for domain in distinct_dest_gs_model_domains:
merged_gs_model_domains.append(domain)
src_gs_model.import_domains = merged_gs_model_domains

# Check if Graph contains a custom domain (custom module)
contains_custom_domain = len(
[
domain \
for domain in src_gs_model.import_domains \
if domain.domain not in ONNX_STANDARD_DOMAINS
]
) > 0

# Duplicate OP name check
src_node_names = [graph_node.name for graph_node in src_gs_model.nodes]
src_input_names = [graph_input.name for graph_input in src_gs_model.inputs]
src_output_names = [graph_output.name for graph_output in src_gs_model.outputs if graph_output.name not in src_node_names]
src_node_names = [
graph_node.name \
for graph_node in src_gs_model.nodes
]
src_input_names = [
graph_input.name \
for graph_input in src_gs_model.inputs
]
src_output_names = [
graph_output.name \
for graph_output in src_gs_model.outputs \
if graph_output.name not in src_node_names
]
src_model_op_names = src_node_names + src_input_names + src_output_names

dest_node_names = [graph_node.name for graph_node in dest_gs_model.nodes]
dest_input_names = [graph_input.name for graph_input in dest_gs_model.inputs]
dest_output_names = [graph_output.name for graph_output in dest_gs_model.outputs if graph_output.name not in dest_node_names]
dest_node_names = [
graph_node.name \
for graph_node in dest_gs_model.nodes
]
dest_input_names = [
graph_input.name \
for graph_input in dest_gs_model.inputs
]
dest_output_names = [
graph_output.name \
for graph_output in dest_gs_model.outputs \
if graph_output.name not in dest_node_names
]
dest_model_op_names = dest_node_names + dest_input_names + dest_output_names

merged_model_op_names = src_model_op_names + dest_model_op_names
Expand Down Expand Up @@ -301,9 +364,18 @@ def has_duplicates(seq):
# If the OP specified as srcop in io_map_srcop_destop is a graph INPUT,
# use onnx_graphsurgeon to merge
# Otherwise, use onnx.compose.merge_models for simple merging
srcop_names = [f'{src_prefix}{srcop_destop_src}' for srcop_destop_src in srcop_destop[model_idx][::2]]
destop_names = [f'{dest_prefix}{srcop_destop_dest}' for srcop_destop_dest in srcop_destop[model_idx][1::2]]
src_gs_model_input_names = [src_gs_model_input.name for src_gs_model_input in src_gs_model.inputs]
srcop_names = [
f'{src_prefix}{srcop_destop_src}' \
for srcop_destop_src in srcop_destop[model_idx][::2]
]
destop_names = [
f'{dest_prefix}{srcop_destop_dest}' \
for srcop_destop_dest in srcop_destop[model_idx][1::2]
]
src_gs_model_input_names = [
src_gs_model_input.name \
for src_gs_model_input in src_gs_model.inputs
]

for srcop_name, destop_name in zip(srcop_names, destop_names):
# Split processing if srcop_name is included or not included in the graph INPUT
Expand Down Expand Up @@ -352,7 +424,9 @@ def has_duplicates(seq):
if not used_flg:
remove_input_names.append(input_name)
src_gs_model.inputs = [
src_gs_model_input for src_gs_model_input in src_gs_model.inputs if src_gs_model_input.name not in remove_input_names
src_gs_model_input \
for src_gs_model_input in src_gs_model.inputs \
if src_gs_model_input.name not in remove_input_names
]

# Cleaning
Expand Down Expand Up @@ -409,7 +483,10 @@ def has_duplicates(seq):

## 4. Optimize
try:
combined_model, check = simplify(combined_model)
# onnx-simplifier does not support optimization of ONNX files containing custom domains,
# so skip simplify if it contains custom domains
if not contains_custom_domain:
combined_model, check = simplify(combined_model)
except Exception as e:
if not non_verbose:
print(
Expand All @@ -419,14 +496,21 @@ def has_duplicates(seq):
tracetxt = traceback.format_exc().splitlines()[-1]
print(f'{Color.YELLOW}WARNING:{Color.RESET} {tracetxt}')

## 5. Final save
## 5. Restore a node's custom domain
combined_model_graph_nodes = combined_model.graph.node
for combined_model_graph_node in combined_model_graph_nodes:
for custom_domain_check_onnx_node in custom_domain_check_onnx_nodes:
if combined_model_graph_node.name == custom_domain_check_onnx_node.name:
combined_model_graph_node.domain = custom_domain_check_onnx_node.domain

## 6. Final save
if output_onnx_file_path:
onnx.save(combined_model, output_onnx_file_path)

if not non_verbose:
print(f'{Color.GREEN}INFO:{Color.RESET} Finish!')

# 6. Return
# 7. Return
return combined_model


Expand Down

0 comments on commit acf637d

Please sign in to comment.