diff --git a/README.md b/README.md index 0e0a0a1..26e2583 100644 --- a/README.md +++ b/README.md @@ -94,9 +94,9 @@ structure_check( ## 6. Sample https://github.com/PINTO0309/ssc4onnx/releases/download/1.0.6/deqflow_b_things_opset12_192x320.onnx -https://github.com/PINTO0309/ssc4onnx/assets/33194443/5ddd242d-41e1-4186-85a7-5306cd410e1d +https://github.com/PINTO0309/ssc4onnx/assets/33194443/fd6a4aa2-9ed5-492b-82ae-1f8306af5119 -![image](https://github.com/PINTO0309/ssc4onnx/assets/33194443/0e079a4d-b227-488f-bc4e-cc1b686126ed) +![image](https://github.com/PINTO0309/ssc4onnx/assets/33194443/45343c95-dbb9-471c-8718-3d0a4d653250) ## 7. Reference 1. https://github.com/onnx/onnx/blob/main/docs/Operators.md diff --git a/ssc4onnx/__init__.py b/ssc4onnx/__init__.py index 7a4dd7d..c7108fa 100644 --- a/ssc4onnx/__init__.py +++ b/ssc4onnx/__init__.py @@ -1,3 +1,3 @@ from ssc4onnx.onnx_structure_check import structure_check, main -__version__ = '1.0.7' +__version__ = '1.0.8' diff --git a/ssc4onnx/onnx_structure_check.py b/ssc4onnx/onnx_structure_check.py index bf22dfe..6f8a312 100644 --- a/ssc4onnx/onnx_structure_check.py +++ b/ssc4onnx/onnx_structure_check.py @@ -72,6 +72,7 @@ def __init__(self, model: onnx.ModelProto): self.op_nums = defaultdict(int) self.op_bytesizes = defaultdict(int) self.model_params_size = 0 + self.model_params = 0 for graph_node in gs_graph.nodes: self.op_nums[graph_node.op] += 1 if hasattr(graph_node, 'attrs') \ @@ -80,6 +81,7 @@ def __init__(self, model: onnx.ModelProto): value: np.ndarray = graph_node.attrs['value'].values self.op_bytesizes[graph_node.op] += value.nbytes self.model_params_size += value.nbytes + self.model_params += np.prod(value.shape) if hasattr(graph_node, 'attrs') \ and len(graph_node.attrs) > 0: for key, value in graph_node.attrs.items(): @@ -91,6 +93,7 @@ def __init__(self, model: onnx.ModelProto): and isinstance(graph_node_input.values, np.ndarray): self.op_bytesizes[graph_node.op] += graph_node_input.values.nbytes self.model_params_size += graph_node_input.values.nbytes + self.model_params += np.prod(graph_node_input.values.shape) else: self.op_bytesizes[graph_node.op] += 0 self.model_size = model.ByteSize() @@ -169,13 +172,15 @@ def structure_check( table = Table() table.add_column('OP Type') table.add_column('OPs') - table.add_column('Params') + table.add_column('Sizes') sorted_list = sorted(list(set(model_info.op_nums.keys()))) sorted_bytes_list = sorted(list(set(model_info.op_bytesizes.keys()))) _ = [table.add_row(key1, f"{model_info.op_nums[key1]:,}", f"{human_readable_size(model_info.op_bytesizes[key2])}") for key1, key2 in zip(sorted_list, sorted_bytes_list)] table.add_row('----------------------', '----------', '----------') ops_count = sum([model_info.op_nums[key] for key in sorted_list]) table.add_row('Total number of OPs', f"{ops_count:,}") + table.add_row('----------------------', '----------', '----------') + table.add_row('Total params', f"{human_readable_size(model_info.model_params).replace('iB','')}") table.add_row('======================', '==========', '==========') table.add_row('Model Size', human_readable_size(model_info.model_size), human_readable_size(model_info.model_params_size)) rich_print(table)