export_onnx.py
import argparse
from typing import Dict

import onnx
import torch
import torch.nn as nn
from onnxruntime.quantization import QuantType, quantize_dynamic

import models

DEVICE = torch.device('cpu')
def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)
    onnx.save(model, filename)
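

# A minimal read-back sketch (not part of the original script): the metadata
# written by add_meta_data can later be inspected with the onnx package.
# "ced_mini.onnx" is just the default output name produced by main() below.
#
#   m = onnx.load("ced_mini.onnx")
#   print({p.key: p.value for p in m.metadata_props})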
class OnnxExportModel(nn.Module):

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, mel):
        # mel = mel.permute(0, 2, 1)
        # If necessary, a transpose can be performed here to swap the feature
        # dimension with the time dimension. If transposed, the dynamic axes
        # below must be interchanged accordingly.
        feat = self.model.forward_spectrogram(mel)
        return feat
@torch.no_grad()
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-m',
        '--model',
        type=str,
        metavar=
        f"Public Checkpoint [{','.join(models.list_models())}] or Experiment Path",
        nargs='?',
        choices=models.list_models(),
        default='ced_mini')
    parser.add_argument(
        '--max-frames',
        type=int,
        default=1012,
        help="Maximum number of frames the model can process.")
    args = parser.parse_args()

    model = getattr(models, args.model)(target_length=args.max_frames,
                                        pretrained=True)
    dummy_input = torch.ones(1, model.n_mels, args.max_frames)
    model = model.to(DEVICE).eval()
    model = OnnxExportModel(model)
    out = model(dummy_input)
    print(f"Model output shape is {out.shape}")

    output_model = args.model + '.onnx'
    torch.onnx.export(model,
                      dummy_input,
                      output_model,
                      do_constant_folding=True,
                      verbose=False,
                      opset_version=12,
                      input_names=['feats'],
                      output_names=['prob'],
                      dynamic_axes={
                          'feats': {
                              0: 'batch_size',
                              2: 'time_dim'
                          },
                          'prob': {
                              0: 'batch_size'
                          }
                      })

    meta_data = {
        "model_type": "CED",
        "version": "1.0",
        "model_author": "RicherMans",
        "url": "https://github.com/RicherMans/CED",
    }
    add_meta_data(filename=output_model, meta_data=meta_data)

    print("Generating int8 quantized model")
    filename_int8 = args.model + ".int8.onnx"
    quantize_dynamic(
        model_input=output_model,
        model_output=filename_int8,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )
    print(f"Result is at {filename_int8}")
    # ced_mini: fp32 onnx ~39.1 MB, int8 onnx ~9.8 MB


if __name__ == "__main__":
    main()
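

# --- Usage sketch (assumption, not part of the original script) --------------
# A quick way to sanity-check the exported model with onnxruntime. The input
# name "feats" and the (batch, n_mels, time) layout follow the export above;
# n_mels=64 and the 1012-frame length match the defaults used here and may
# differ for other checkpoints.
#
#   import numpy as np
#   import onnxruntime as ort
#
#   session = ort.InferenceSession("ced_mini.onnx",
#                                  providers=["CPUExecutionProvider"])
#   mel = np.random.randn(1, 64, 1012).astype(np.float32)
#   (prob,) = session.run(None, {"feats": mel})
#   print(prob.shape)  # expected: (batch_size, num_classes)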