adam_reset.py
import torch
from nqgl.sae.hsae.hsae import HierarchicalAutoEncoder, HierarchicalAutoEncoderConfig


class AdamResetter:
    # Indexable wrapper around a parameter: AdamResetter(param)[indices]
    # yields a callable that resets the Adam state at those indices.
    def __init__(self, param):
        self.param = param

    def __getitem__(self, indices):
        return AdamResetterCallable(self.param, indices)


class AdamResetterCallable:
    def __init__(self, param, indices):
        self.param = param
        self.indices = indices

    def __call__(self, adam: torch.optim.Adam):
        state = adam.state[self.param]
        state["exp_avg"][self.indices] = 0  # zero momentum
        # Re-initialize the second-moment estimate at the reset indices to
        # roughly the mean of the remaining entries, smoothed by eps and
        # discounted by ratio, so the reset entries do not get an outsized
        # effective learning rate.
        eps = 1e-5
        ratio = 0.99
        state["exp_avg_sq"][self.indices] = (
            eps
            + torch.sum(state["exp_avg_sq"])
            - torch.sum(state["exp_avg_sq"][self.indices] * ratio)
        ) / (eps + state["exp_avg_sq"].numel() - self.indices.numel() * ratio)
        # leave step as is
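

# A minimal usage sketch (an assumption, not part of the original file): after
# re-initializing some dead features, reset the corresponding Adam state so the
# stale moments do not distort their first updates. The optimizer must already
# hold state for `param`, i.e. at least one optimizer step has been taken.
def reset_rows_example(optimizer: torch.optim.Adam, param, rows: torch.Tensor):
    AdamResetter(param)[rows](optimizer)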


def reset_adam(adam: torch.optim.Adam, param, indices):
    # Functional form of the above. Adam's per-parameter state holds
    # "exp_avg", "exp_avg_sq" and "step"; "step" is left as is.
    # amsgrad? if so, then also "max_exp_avg_sq"
    AdamResetterCallable(param, indices)(adam)


def main():
    cfg = HierarchicalAutoEncoderConfig(d_data=4)
    hsae = HierarchicalAutoEncoder(cfg)
    adam = torch.optim.Adam(hsae.parameters(), lr=cfg.lr, betas=(cfg.beta1, cfg.beta2))
    groups = adam.param_groups
    # print(groups[0].keys())
    x = torch.randn(10, 4, device="cuda")
    x_reconstruct = hsae(x)
    loss = hsae.get_loss()
    loss.backward()
    adam.step()
    # Inspect the Adam state created for one parameter after a single step.
    for key in adam.state[hsae.layers[0].b_enc].keys():
        print(key)
        print(adam.state[hsae.layers[0].b_enc][key].shape)
    # for p in groups[0]["params"]:
    #     print(p.shape)
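    # Illustrative only (assumes b_enc is indexable along its first dimension):
    # exercise the resetter on the same parameter and confirm that its
    # momentum entries were zeroed.
    dead_idx = torch.tensor([0], device=hsae.layers[0].b_enc.device)
    AdamResetter(hsae.layers[0].b_enc)[dead_idx](adam)
    print(adam.state[hsae.layers[0].b_enc]["exp_avg"][dead_idx])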


if __name__ == "__main__":
    main()