-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinitialisation.py
82 lines (75 loc) · 3.81 KB
/
initialisation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
import collections
import model as arch
def processor_init(model, model_state):
    """Initialise ``model`` in place from a pretrained state dict.

    For the *Transfer architectures the pretrained processor weights are
    loaded into the first (frozen) processor slot and the termination heads
    are copied over.  For every other architecture the processor/termination
    entries of ``model_state`` are loaded directly; *Freeze architectures
    additionally get their processor frozen.

    Args:
        model: the model to initialise (modified in place).
        model_state: state dict of a pretrained model to copy weights from.

    Returns:
        The same ``model`` instance, for chaining.
    """
    if isinstance(model, arch.NeuralExecutor2Transfer) or isinstance(model, arch.NeuralExecutor3Transfer):
        model.reset_parameters()
        proc_model_state = collections.OrderedDict()
        term_model_state = collections.OrderedDict()
        term_mpnn_model_state = collections.OrderedDict()
        for k, v in model_state.items():
            # Test the longest prefix first: 'termination_mpnn.*' also
            # satisfies startswith('termination'), so plain ifs would route
            # those keys (mis-stripped) into term_model_state as well.
            if k.startswith('termination_mpnn'):
                # BUGFIX: these weights previously went into term_model_state,
                # leaving term_mpnn_model_state empty so the load below was a no-op.
                term_mpnn_model_state[k[17:]] = v  # strip 'termination_mpnn.'
            elif k.startswith('termination'):
                term_model_state[k[12:]] = v  # strip 'termination.'
            elif k.startswith('processor'):
                proc_model_state[k[10:]] = v  # strip 'processor.'
        model.processor.processors[0].load_state_dict(proc_model_state, strict=False)
        # Slot 0 holds the pretrained processor and is kept frozen.
        model.processor.processors[0].requires_grad_(False)
        model.termination.load_state_dict(term_model_state, strict=False)
        model.termination_mpnn.load_state_dict(term_mpnn_model_state, strict=False)
    else:
        new_model_state = collections.OrderedDict()
        for k, v in model_state.items():
            if k.startswith('processor') or k.startswith('termination'):
                new_model_state[k] = v
        model.load_state_dict(new_model_state, strict=False)
    if isinstance(model, arch.NeuralExecutor2Freeze) or isinstance(model, arch.NeuralExecutor3Freeze):
        model.processor.requires_grad_(False)
    return model
def merge_processor_init(init_model, target_model, joint_model):
    """Build a two-processor joint model from two pretrained models.

    ``target_model`` supplies the encoders/decoders and a trainable processor
    (slot 1); ``init_model`` supplies a frozen processor (slot 0) plus the
    termination heads.

    Args:
        init_model: pretrained model providing the frozen processor and
            termination weights.
        target_model: pretrained model providing encoders/decoders and the
            trainable processor.
        joint_model: model to initialise in place; must be a
            ``NeuralExecutor2Transfer``.

    Returns:
        The same ``joint_model`` instance, for chaining.

    Raises:
        NotImplementedError: if ``joint_model`` is not a NeuralExecutor2Transfer.
    """
    if isinstance(joint_model, arch.NeuralExecutor2Transfer):
        joint_model.reset_parameters()
        ## target model enc/dec + proc (slot 1, trainable)
        joint_model.load_state_dict(target_model.state_dict(), strict=False)
        proc_model_state = collections.OrderedDict()
        for k, v in target_model.state_dict().items():
            if k.startswith('processor'):
                proc_model_state[k[10:]] = v  # strip 'processor.'
        joint_model.processor.processors[1].load_state_dict(proc_model_state, strict=False)
        joint_model.processor.processors[1].requires_grad_(True)
        ## init model proc + term (slot 0, frozen)
        proc_model_state = collections.OrderedDict()
        term_model_state = collections.OrderedDict()
        term_mpnn_model_state = collections.OrderedDict()
        for k, v in init_model.state_dict().items():
            # Longest prefix first: 'termination_mpnn.*' also satisfies
            # startswith('termination').
            if k.startswith('termination_mpnn'):
                # BUGFIX: these weights previously went into term_model_state,
                # leaving term_mpnn_model_state empty so the load below was a no-op.
                term_mpnn_model_state[k[17:]] = v  # strip 'termination_mpnn.'
            elif k.startswith('termination'):
                term_model_state[k[12:]] = v  # strip 'termination.'
            elif k.startswith('processor'):
                proc_model_state[k[10:]] = v  # strip 'processor.'
        joint_model.processor.processors[0].load_state_dict(proc_model_state, strict=False)
        joint_model.processor.processors[0].requires_grad_(False)
        joint_model.termination.load_state_dict(term_model_state, strict=False)
        joint_model.termination_mpnn.load_state_dict(term_mpnn_model_state, strict=False)
    else:
        raise NotImplementedError
    return joint_model
def distill_processor(source_model, target_model):
    """Prepare ``target_model`` so only its processor learns during distillation.

    Copies every encoder/decoder/termination/predecessor weight from
    ``source_model`` into ``target_model`` and freezes those modules, leaving
    the processor as the sole trainable component.

    Args:
        source_model: trained model whose non-processor weights are reused.
        target_model: model to initialise in place; must be a ``NeuralExecutor2``.

    Returns:
        The same ``target_model`` instance, for chaining.

    Raises:
        NotImplementedError: if ``target_model`` is not a NeuralExecutor2.
    """
    if not isinstance(target_model, arch.NeuralExecutor2):
        raise NotImplementedError
    target_model.reset_parameters()
    # Transfer everything except the processor from the source model.
    tags = ('encoder', 'decoder', 'termination', 'predecessor')
    transferred = collections.OrderedDict(
        (name, tensor)
        for name, tensor in source_model.state_dict().items()
        if any(tag in name for tag in tags)
    )
    target_model.load_state_dict(transferred, strict=False)
    # Freeze the copied modules so gradients flow only through the processor.
    for module in (
        target_model.node_encoder,
        target_model.edge_encoder,
        target_model.decoder,
        target_model.predecessor[0],
        target_model.termination,
        target_model.termination_mpnn,
    ):
        module.requires_grad_(False)
    return target_model