load_model.py
import os
import torch
from torch import nn
import yaml
import sys

# Make the local "model" module importable regardless of the caller's working
# directory, then drop the script directory again to avoid polluting sys.path.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(SCRIPT_DIR)
from model import ft_net, ft_net_dense, ft_net_hr, ft_net_swin, ft_net_efficient, ft_net_NAS, PCB
sys.path.remove(SCRIPT_DIR)


def load_weights(model, ckpt_path):
    """Loads the weights of the model from a checkpoint file.

    Parameters
    ----------
    model: torch.nn.Module
        Model to load the weights into (needs to have a model.classifier head).
    ckpt_path: str
        Path to the checkpoint file to load (e.g. net_X.pth).

    Returns
    -------
    model: torch.nn.Module
        The model object with the loaded weights.
    """
    state = torch.load(ckpt_path, map_location="cpu")
    # If the checkpoint was trained with a different number of classes, keep the
    # model's own (freshly initialized) classifier weights instead of the saved ones.
    if model.classifier.classifier[0].weight.shape != state["classifier.classifier.0.weight"].shape:
        state["classifier.classifier.0.weight"] = model.classifier.classifier[0].weight
        state["classifier.classifier.0.bias"] = model.classifier.classifier[0].bias
    model.load_state_dict(state)
    return model


def create_model(n_classes, kind="resnet", **kwargs):
    """Creates a model of a given kind and number of classes."""
    if kind == "resnet":
        return ft_net(n_classes, **kwargs)
    elif kind == "densenet":
        return ft_net_dense(n_classes, **kwargs)
    elif kind == "hr":
        return ft_net_hr(n_classes, **kwargs)
    elif kind == "efficientnet":
        return ft_net_efficient(n_classes, **kwargs)
    elif kind == "NAS":
        return ft_net_NAS(n_classes, **kwargs)
    elif kind == "swin":
        return ft_net_swin(n_classes, **kwargs)
    elif kind == "PCB":
        return PCB(n_classes)
    else:
        raise ValueError("Model type cannot be created: {}".format(kind))
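
# Minimal usage sketch of the two helpers above. The class count (751) and the
# checkpoint path ("net_last.pth") are illustrative assumptions, not values taken
# from this repository:
#   model = create_model(751, kind="resnet")
#   model = load_weights(model, "net_last.pth")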


def load_model(n_classes, kind="resnet", ckpt=None, remove_classifier=False, **kwargs):
    """Loads a model of a given kind and number of classes.

    Parameters
    ----------
    n_classes: int
        Number of classes at the classifier head.
    kind: str
        Kind of the model ('resnet', 'efficientnet', 'densenet', 'hr', 'swin', 'NAS', 'PCB').
    ckpt: Union[str, None]
        Path to the checkpoint to load, or None to keep the freshly initialized weights.
    remove_classifier: bool
        Whether to remove the classifier head.
    **kwargs:
        Extra parameters passed on to the model constructor.

    Returns
    -------
    model: torch.nn.Module
    """
    model = create_model(n_classes, kind, **kwargs)
    if ckpt:
        model = load_weights(model, ckpt)
    if remove_classifier:
        model.classifier.classifier = nn.Sequential()
        model.eval()
    return model
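
# Minimal usage sketch: load a trained backbone for feature extraction, dropping
# the classifier head. The checkpoint path and class count are illustrative values:
#   extractor = load_model(751, kind="resnet", ckpt="net_last.pth", remove_classifier=True)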


def load_model_from_opts(opts_file, ckpt=None, return_feature=False, remove_classifier=False):
    """Loads a saved model by reading its opts.yaml file.

    Parameters
    ----------
    opts_file: str
        Path to the saved opts.yaml file of the model.
    ckpt: str
        Path to the saved checkpoint of the model (net_X.pth).
    return_feature: bool
        Whether the model should return the feature along with the result in the
        forward function. This is needed for certain loss functions (e.g. circle loss).
    remove_classifier: bool
        Whether to remove the classifier block from the model; the classifier is
        needed for training but not for evaluation.

    Returns
    -------
    model: torch.nn.Module
        The model requested to be loaded.
    """
    with open(opts_file, "r") as stream:
        opts = yaml.load(stream, Loader=yaml.FullLoader)

    n_classes = opts["nclasses"]
    droprate = opts["droprate"]
    stride = opts["stride"]
    linear_num = opts["linear_num"]
    model_subtype = opts.get("model_subtype", "default")
    model_type = opts.get("model", "resnet_ibn")
    mixstyle = opts.get("mixstyle", False)

    if model_type in ("resnet", "resnet_ibn"):
        model = create_model(n_classes, "resnet", droprate=droprate, ibn=(model_type == "resnet_ibn"),
                             stride=stride, circle=return_feature, linear_num=linear_num,
                             model_subtype=model_subtype, mixstyle=mixstyle)
    elif model_type == "densenet":
        model = create_model(n_classes, "densenet", droprate=droprate, circle=return_feature,
                             linear_num=linear_num)
    elif model_type == "efficientnet":
        model = create_model(n_classes, "efficientnet", droprate=droprate, circle=return_feature,
                             linear_num=linear_num, model_subtype=model_subtype)
    elif model_type == "NAS":
        model = create_model(n_classes, "NAS", droprate=droprate, linear_num=linear_num)
    elif model_type == "PCB":
        model = create_model(n_classes, "PCB")
    elif model_type == "hr":
        model = create_model(n_classes, "hr", droprate=droprate, circle=return_feature,
                             linear_num=linear_num)
    elif model_type == "swin":
        model = create_model(n_classes, "swin", droprate=droprate, stride=stride,
                             circle=return_feature, linear_num=linear_num)
    else:
        raise ValueError("Unsupported model type: {}".format(model_type))

    if ckpt:
        load_weights(model, ckpt)
    if remove_classifier:
        model.classifier.classifier = nn.Sequential()
        model.eval()
    return model
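

# A minimal, optional entry point for sanity-checking a saved run. The default
# paths below are illustrative assumptions; point them at a real opts.yaml and
# checkpoint produced by training.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Load a model from its opts.yaml and print it.")
    parser.add_argument("--opts_file", default="model/opts.yaml")
    parser.add_argument("--checkpoint", default="model/net_last.pth")
    args = parser.parse_args()

    # Load the backbone without its classifier head, as done for feature extraction.
    model = load_model_from_opts(args.opts_file, ckpt=args.checkpoint, remove_classifier=True)
    print(model)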