Skip to content

Commit

Permalink
Merge pull request #5 from PINTO0309/feat/params
Browse files Browse the repository at this point in the history
Additional indication of size and number of parameters
  • Loading branch information
PINTO0309 authored Sep 24, 2023
2 parents 6120e67 + 68bb328 commit 89abc07
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 4 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ structure_check(
## 6. Sample
https://github.com/PINTO0309/ssc4onnx/releases/download/1.0.6/deqflow_b_things_opset12_192x320.onnx

https://github.com/PINTO0309/ssc4onnx/assets/33194443/5ddd242d-41e1-4186-85a7-5306cd410e1d
https://github.com/PINTO0309/ssc4onnx/assets/33194443/fd6a4aa2-9ed5-492b-82ae-1f8306af5119

![image](https://github.com/PINTO0309/ssc4onnx/assets/33194443/0e079a4d-b227-488f-bc4e-cc1b686126ed)
![image](https://github.com/PINTO0309/ssc4onnx/assets/33194443/45343c95-dbb9-471c-8718-3d0a4d653250)

## 7. Reference
1. https://github.com/onnx/onnx/blob/main/docs/Operators.md
Expand Down
2 changes: 1 addition & 1 deletion ssc4onnx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from ssc4onnx.onnx_structure_check import structure_check, main

__version__ = '1.0.7'
__version__ = '1.0.8'
7 changes: 6 additions & 1 deletion ssc4onnx/onnx_structure_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def __init__(self, model: onnx.ModelProto):
self.op_nums = defaultdict(int)
self.op_bytesizes = defaultdict(int)
self.model_params_size = 0
self.model_params = 0
for graph_node in gs_graph.nodes:
self.op_nums[graph_node.op] += 1
if hasattr(graph_node, 'attrs') \
Expand All @@ -80,6 +81,7 @@ def __init__(self, model: onnx.ModelProto):
value: np.ndarray = graph_node.attrs['value'].values
self.op_bytesizes[graph_node.op] += value.nbytes
self.model_params_size += value.nbytes
self.model_params += np.prod(value.shape)
if hasattr(graph_node, 'attrs') \
and len(graph_node.attrs) > 0:
for key, value in graph_node.attrs.items():
Expand All @@ -91,6 +93,7 @@ def __init__(self, model: onnx.ModelProto):
and isinstance(graph_node_input.values, np.ndarray):
self.op_bytesizes[graph_node.op] += graph_node_input.values.nbytes
self.model_params_size += graph_node_input.values.nbytes
self.model_params += np.prod(graph_node_input.values.shape)
else:
self.op_bytesizes[graph_node.op] += 0
self.model_size = model.ByteSize()
Expand Down Expand Up @@ -169,13 +172,15 @@ def structure_check(
table = Table()
table.add_column('OP Type')
table.add_column('OPs')
table.add_column('Params')
table.add_column('Sizes')
sorted_list = sorted(list(set(model_info.op_nums.keys())))
sorted_bytes_list = sorted(list(set(model_info.op_bytesizes.keys())))
_ = [table.add_row(key1, f"{model_info.op_nums[key1]:,}", f"{human_readable_size(model_info.op_bytesizes[key2])}") for key1, key2 in zip(sorted_list, sorted_bytes_list)]
table.add_row('----------------------', '----------', '----------')
ops_count = sum([model_info.op_nums[key] for key in sorted_list])
table.add_row('Total number of OPs', f"{ops_count:,}")
table.add_row('----------------------', '----------', '----------')
table.add_row('Total params', f"{human_readable_size(model_info.model_params).replace('iB','')}")
table.add_row('======================', '==========', '==========')
table.add_row('Model Size', human_readable_size(model_info.model_size), human_readable_size(model_info.model_params_size))
rich_print(table)
Expand Down

0 comments on commit 89abc07

Please sign in to comment.