Added option to auto downcast
PINTO0309 committed Apr 6, 2022
1 parent e252555 commit 1a1e2cc
Showing 3 changed files with 23 additions and 3 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -11,7 +11,7 @@ A very simple tool that compresses the overall size of the ONNX model by aggrega
- [ ] ~Finally, create a Fork of **[onnx-simplifier](https://github.com/daquexian/onnx-simplifier)** and merge this process just before the onnx file output process~ -> Temporarily abandoned because it turned out that the onnx-simplifier specification needed to be changed in a major way.
- [x] Implementation of a specification for separating the weight of a specified OP name to an external file.
- [ ] Implementation of a specification for separating the weight of a specified Constant name to an external file.
-- [ ] Added option to downcast from Float64 to Float32 and INT64 to INT32 to attempt size compression.
+- [x] Added option to downcast from Float64 to Float32 and INT64 to INT32 to attempt size compression.
- [ ] Final work-around idea for breaking the 2GB limit, since the internal logic of onnx has a Protocol Buffers limit of 2GB checked. Recombine after optimization. Splitting and merging seems like it would be easy. For each partitioned onnx component, optimization is performed in the order of onnx-simplifier → scs4onnx to optimize the structure while keeping the buffer size to a minimum, and then the optimized components are recombined to reconstruct the whole graph. Finally, run scs4onnx again on the reconstructed, optimized overall graph to further reduce the model-wide constant.


2 changes: 1 addition & 1 deletion scs4onnx/__init__.py
@@ -1,3 +1,3 @@
from scs4onnx.onnx_shrink_constant import shrinking, main

-__version__ = '1.0.8'
+__version__ = '1.0.9'
22 changes: 21 additions & 1 deletion scs4onnx/onnx_shrink_constant.py
@@ -2,6 +2,7 @@

import os
import sys
+import traceback
from pprint import pprint
from argparse import ArgumentParser
import numpy as np
@@ -112,6 +113,22 @@ def shrinking(
     continue
 if np.isscalar(graph_node_input.values):
     continue
+
+# Try downcast
+### INT64 -> INT32
+if graph_node_input.values.dtype == np.int64:
+    orig = graph_node_input.values
+    dist = graph_node_input.values.astype(np.int32)
+    if (orig == dist).all():
+        graph_node_input.values = dist
+
+### Float64 -> Float32
+if graph_node_input.values.dtype == np.float64:
+    orig = graph_node_input.values
+    dist = graph_node_input.values.astype(np.float32)
+    if (orig == dist).all():
+        graph_node_input.values = dist
+
 constants[graph_node_input.name] = graph_node_input
 if not non_verbose:
     print(
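
The downcast check added in this hunk can be tried in isolation. The following is a minimal sketch, not code from scs4onnx: the helper name `try_downcast` and the sample arrays are assumptions made for illustration. It narrows an INT64 or Float64 array only when every element compares equal after the cast, which is the same round-trip test the hunk applies to `graph_node_input.values`.

```python
import numpy as np


def try_downcast(values: np.ndarray) -> np.ndarray:
    """Hypothetical helper: narrow int64/float64 only if no element changes."""
    narrower = {
        np.dtype(np.int64): np.int32,
        np.dtype(np.float64): np.float32,
    }
    target = narrower.get(values.dtype)
    if target is None:
        # Not a dtype this check handles; leave the array untouched.
        return values
    candidate = values.astype(target)
    # Keep the narrower array only when the cast is lossless for all elements.
    if (values == candidate).all():
        return candidate
    return values


small = try_downcast(np.array([1, 2, 3], dtype=np.int64))    # fits in int32
large = try_downcast(np.array([2**40], dtype=np.int64))      # would wrap in int32
exact = try_downcast(np.array([0.5, 2.0], dtype=np.float64)) # exact in float32
inexact = try_downcast(np.array([0.1], dtype=np.float64))    # rounds in float32

print(small.dtype, large.dtype, exact.dtype, inexact.dtype)
# -> int32 int64 float32 float64
```

Because the comparison is elementwise equality, an array containing NaN never passes the check (NaN is not equal to itself) and therefore keeps its original dtype.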
@@ -251,14 +268,16 @@ def shrinking(
 new_model = None
 try:
     new_model = onnx.shape_inference.infer_shapes(shrunken_graph)
-except:
+except Exception as e:
     new_model = shrunken_graph
     if not non_verbose:
         print(
             f'{Color.YELLOW}WARNING:{Color.RESET} '+
             'The input shape of the next OP does not match the output shape. '+
             'Be sure to open the .onnx file to verify the certainty of the geometry.'
         )
+        tracetxt = traceback.format_exc().splitlines()[-1]
+        print(f'{Color.YELLOW}WARNING:{Color.RESET} {tracetxt}')

# Save
if output_onnx_file_path:
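
The `tracetxt` line added in this hunk keeps only the last line of the formatted traceback, which is the exception type and message. A minimal standalone sketch of that idiom follows; the helper name `last_traceback_line` and the failing `int()` call are assumptions chosen for illustration and do not appear in scs4onnx.

```python
import traceback


def last_traceback_line() -> str:
    """Hypothetical helper: final line of the active exception's traceback."""
    # format_exc() returns the full traceback as one string; the last line
    # holds "<ExceptionType>: <message>".
    return traceback.format_exc().splitlines()[-1]


try:
    int('not a number')
except Exception:
    summary = last_traceback_line()

print(summary)
```

This must be called while the exception is still being handled, that is, inside the `except` block or a function called from it, since `traceback.format_exc()` reads the exception currently being handled.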
@@ -299,6 +318,7 @@ def main():
 parser.add_argument(
     '--forced_extraction_op_names',
     type=str,
+    default='',
     help="\
         Extracts the constant value of the specified OP name to .npy \
         regardless of the mode specified. \
