# optimizer.py
import math

import torch
from torch.optim.optimizer import Optimizer


class DenseSparseAdam(Optimizer):
"""
NOTE: This class has been copied from AllenNLP repository:
https://github.com/allenai/allennlp/blob/main/allennlp/training/optimizers.py
Implements Adam algorithm with dense & sparse gradients.
It has been proposed in Adam: A Method for Stochastic Optimization.
Registered as an `Optimizer` with name "dense_sparse_adam".
# Parameters
params : `iterable`
iterable of parameters to optimize or dicts defining parameter groups
lr : `float`, optional (default = `1e-3`)
The learning rate.
betas : `Tuple[float, float]`, optional (default = `(0.9, 0.999)`)
coefficients used for computing running averages of gradient
and its square.
eps : `float`, optional, (default = `1e-8`)
A term added to the denominator to improve numerical stability.
"""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(DenseSparseAdam, self).__init__(params, defaults)
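
    # NOTE: unlike `lr`, `eps`, and `betas`, `weight_decay` is stored without
    # a range check, so negative values are not rejected here.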

    def step(self, closure=None):
        """
        Performs a single optimization step.

        Parameters
        ----------
        closure : `callable`, optional
            A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]

                # State initialization
                if 'step' not in state:
                    state['step'] = 0
                if 'exp_avg' not in state:
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                if 'exp_avg_sq' not in state:
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                state['step'] += 1

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']
                weight_decay = group['weight_decay']

                if grad.is_sparse:
                    grad = grad.coalesce()  # the update is non-linear so indices must be unique
                    grad_indices = grad._indices()
                    grad_values = grad._values()
                    size = grad.size()

                    def make_sparse(values):
                        constructor = grad.new
                        if grad_indices.dim() == 0 or values.dim() == 0:
                            return constructor().resize_as_(grad)
                        return constructor(grad_indices, values, size)
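
                    # `make_sparse` rebuilds a sparse tensor that shares `grad`'s
                    # indices, so every moment update below touches only the rows
                    # present in this step's gradient.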

                    # Decay the first and second moment running average coefficient:
                    #     old <- b * old + (1 - b) * new
                    # <==> old += (1 - b) * (new - old)
                    old_exp_avg_values = exp_avg.sparse_mask(grad)._values()
                    exp_avg_update_values = grad_values.sub(old_exp_avg_values).mul_(1 - beta1)
                    exp_avg.add_(make_sparse(exp_avg_update_values))
                    old_exp_avg_sq_values = exp_avg_sq.sparse_mask(grad)._values()
                    exp_avg_sq_update_values = (
                        grad_values.pow(2).sub_(old_exp_avg_sq_values).mul_(1 - beta2)
                    )
                    exp_avg_sq.add_(make_sparse(exp_avg_sq_update_values))

                    # Dense addition again is intended, avoiding another sparse_mask
                    numer = exp_avg_update_values.add_(old_exp_avg_values)
                    exp_avg_sq_update_values.add_(old_exp_avg_sq_values)
                    denom = exp_avg_sq_update_values.sqrt_().add_(group['eps'])
                    del exp_avg_update_values, exp_avg_sq_update_values

                    bias_correction1 = 1 - beta1 ** state['step']
                    bias_correction2 = 1 - beta2 ** state['step']
                    step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
                    p.data.add_(make_sparse(-step_size * numer.div_(denom)))
                    if weight_decay > 0.0:
                        # Apply weight decay only to the rows that received a gradient
                        p.data.add_(p.data.sparse_mask(grad), alpha=-group['lr'] * weight_decay)
                else:
                    # Decay the first and second moment running average coefficient
                    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                    bias_correction1 = 1 - beta1 ** state['step']
                    bias_correction2 = 1 - beta2 ** state['step']
                    step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
                    p.data.addcdiv_(exp_avg, denom, value=-step_size)
                    if weight_decay > 0.0:
                        p.data.add_(p.data, alpha=-group['lr'] * weight_decay)
        return loss
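

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file). Illustrates wiring the
# optimizer to a model whose embedding layer emits sparse gradients; the
# model, sizes, and learning rate below are illustrative assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    embedding = torch.nn.Embedding(10_000, 64, sparse=True)  # sparse grads
    optimizer = DenseSparseAdam(embedding.parameters(), lr=1e-3)

    tokens = torch.randint(0, 10_000, (32,))
    loss = embedding(tokens).pow(2).mean()

    optimizer.zero_grad()
    loss.backward()   # embedding.weight.grad is a sparse tensor
    optimizer.step()  # handled by the `grad.is_sparse` branch above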