Skip to content

Commit

Permalink
Merge pull request #6 from PINTO0309/preserve_meta
Browse files Browse the repository at this point in the history
Retention of metadata_props
  • Loading branch information
PINTO0309 authored May 28, 2024
2 parents 6e72b47 + b72e0fc commit 84cd2fe
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 35 deletions.
2 changes: 1 addition & 1 deletion sam4onnx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from sam4onnx.onnx_attr_const_modify import modify, main

__version__ = '1.0.15'
__version__ = '1.0.16'
41 changes: 7 additions & 34 deletions sam4onnx/onnx_attr_const_modify.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,6 @@ class Color:
'complex128': np.complex128,
}

ONNX_STANDARD_DOMAINS = [
'ai.onnx',
'ai.onnx.ml',
'',
]


def __subgraph_node_search(
search_graph: gs.Graph,
op_name: str,
Expand Down Expand Up @@ -292,30 +285,16 @@ def modify(
if not onnx_graph:
onnx_graph = onnx.load(input_onnx_file_path)

# Acquisition of Node with custom domain
custom_domain_check_onnx_nodes = []
custom_domain_check_onnx_nodes = \
custom_domain_check_onnx_nodes + \
[
node for node in onnx_graph.graph.node \
if node.domain not in ONNX_STANDARD_DOMAINS
]

# domain, ir_version
domain: str = onnx_graph.domain
ir_version: int = onnx_graph.ir_version
meta_data = {'domain': domain, 'ir_version': ir_version}
metadata_props = None
if hasattr(onnx_graph, 'metadata_props'):
metadata_props = onnx_graph.metadata_props

graph = gs.import_onnx(onnx_graph)

# Check if Graph contains a custom domain (custom module)
contains_custom_domain = len(
[
domain \
for domain in graph.import_domains \
if domain.domain not in ONNX_STANDARD_DOMAINS
]
) > 0

# Search for OPs matching op_name
node_subject_to_change = None
if op_name:
Expand Down Expand Up @@ -440,7 +419,9 @@ def modify(

# Cleanup
graph.cleanup().toposort()
modified_graph = gs.export_onnx(graph, do_type_check=False, **{'domain': domain, 'ir_version': ir_version})
modified_graph = gs.export_onnx(graph, do_type_check=False, **meta_data)
if metadata_props is not None:
modified_graph.metadata_props.extend(metadata_props)

# Optimize
new_model = None
Expand All @@ -457,14 +438,6 @@ def modify(
tracetxt = traceback.format_exc().splitlines()[-1]
print(f'{Color.YELLOW}WARNING:{Color.RESET} {tracetxt}')

## Restore a node's custom domain
if contains_custom_domain:
new_model_nodes = new_model.graph.node
for new_model_node in new_model_nodes:
for custom_domain_check_onnx_node in custom_domain_check_onnx_nodes:
if new_model_node.name == custom_domain_check_onnx_node.name:
new_model_node.domain = custom_domain_check_onnx_node.domain

# Save
if output_onnx_file_path:
onnx.save(new_model, f'{output_onnx_file_path}')
Expand Down

0 comments on commit 84cd2fe

Please sign in to comment.