-
Notifications
You must be signed in to change notification settings - Fork 138
/
Copy pathinner_loop_optimizers.py
114 lines (100 loc) · 5.22 KB
/
inner_loop_optimizers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import logging
import os
from collections import OrderedDict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class GradientDescentLearningRule(nn.Module):
    """Simple (stochastic) gradient descent learning rule.

    For a scalar error function `E(p[0], p[1] ... )` of some set of
    potentially multidimensional parameters this attempts to find a local
    minimum of the loss function by applying updates to each parameter of the
    form

        p[i] := p[i] - learning_rate * dE/dp[i]

    With `learning_rate` a positive scaling parameter.

    The error function used in successive applications of these updates may be
    a stochastic estimator of the true error function (e.g. when the error with
    respect to only a subset of data-points is calculated) in which case this
    will correspond to a stochastic gradient descent learning rule.
    """

    def __init__(self, device, learning_rate=1e-3):
        """Creates a new learning rule object.

        Args:
            device: Device (anything accepted by `torch.Tensor.to`) on which
                the learning-rate tensor is stored.
            learning_rate: A positive scalar to scale gradient updates to the
                parameters by. This needs to be carefully set - if too large
                the learning dynamic will be unstable and may diverge, while
                if set too small learning will proceed very slowly.

        Raises:
            AssertionError: If `learning_rate` is not strictly positive.
        """
        super(GradientDescentLearningRule, self).__init__()
        assert learning_rate > 0., 'learning_rate should be positive.'
        # `Tensor.to` returns a new tensor instead of moving the receiver in
        # place, so its result has to be the value that is stored.
        self.learning_rate = (torch.ones(1) * learning_rate).to(device)

    def update_params(self, names_weights_dict, names_grads_wrt_params_dict, num_step, tau=0.9):
        """Applies a single gradient descent update to all parameters.

        The update is computed out-of-place: the tensors in
        `names_weights_dict` are left untouched and the updated values are
        returned in a new dictionary.

        Args:
            names_weights_dict: Mapping from parameter name to the current
                value of that parameter.
            names_grads_wrt_params_dict: Mapping from parameter name to the
                gradient of the scalar loss with respect to that parameter.
                Must contain every key of `names_weights_dict`.
            num_step: Unused by this rule; accepted so the signature matches
                `LSLRGradientDescentLearningRule.update_params`.
            tau: Unused by this rule.

        Returns:
            A dict with the same keys as `names_weights_dict`, each mapped to
            `weight - learning_rate * gradient`.

        Raises:
            KeyError: If a key of `names_weights_dict` is missing from
                `names_grads_wrt_params_dict`.
        """
        return {
            key: weight - self.learning_rate * names_grads_wrt_params_dict[key]
            for key, weight in names_weights_dict.items()
        }
class LSLRGradientDescentLearningRule(nn.Module):
    """Gradient descent learning rule with per-parameter, per-step learning rates.

    For a scalar error function `E(p[0], p[1] ... )` of some set of
    potentially multidimensional parameters this applies updates to each
    parameter of the form

        p[i] := p[i] - learning_rate[i][num_step] * dE/dp[i]

    where every named parameter owns its own vector of learning rates, one
    entry per inner-loop step. The learning rates are stored as
    `nn.Parameter`s and are trainable when `use_learnable_learning_rates` is
    true.

    `initialise` must be called before `update_params`, since it is what
    creates the learning-rate table.
    """

    def __init__(self, device, total_num_inner_loop_steps, use_learnable_learning_rates, init_learning_rate=1e-3):
        """Creates a new learning rule object.

        Args:
            device: Device (anything accepted by `torch.Tensor.to`) on which
                the initial learning-rate tensor, and the learning rates
                created from it by `initialise`, are stored.
            total_num_inner_loop_steps: Number of inner-loop steps; each
                parameter gets `total_num_inner_loop_steps + 1` learning
                rates.
            use_learnable_learning_rates: Whether the per-step learning rates
                require gradients.
            init_learning_rate: A positive scalar that every learning rate is
                initialised to. This needs to be carefully set - if too large
                the learning dynamic will be unstable and may diverge, while
                if set too small learning will proceed very slowly.

        Raises:
            AssertionError: If `init_learning_rate` is not strictly positive.
        """
        super(LSLRGradientDescentLearningRule, self).__init__()
        # Diagnostic output goes through logging rather than an unconditional
        # print so that constructing the rule does not write to stdout.
        logging.getLogger(__name__).debug(
            "LSLR init_learning_rate: %s", init_learning_rate)
        assert init_learning_rate > 0., 'learning_rate should be positive.'
        # `Tensor.to` returns a new tensor instead of moving the receiver in
        # place, so its result has to be the value that is stored.
        self.init_learning_rate = (torch.ones(1) * init_learning_rate).to(device)
        self.total_num_inner_loop_steps = total_num_inner_loop_steps
        self.use_learnable_learning_rates = use_learnable_learning_rates

    def initialise(self, names_weights_dict):
        """Creates the per-parameter, per-step learning-rate table.

        Any table created by a previous call is replaced.

        Args:
            names_weights_dict: Mapping whose keys are the parameter names to
                create learning rates for. Only the keys are used.
        """
        self.names_learning_rates_dict = nn.ParameterDict()
        num_rates = self.total_num_inner_loop_steps + 1
        for key in names_weights_dict:
            # Build the vector on the same device as the initial learning
            # rate so the product never mixes devices.
            initial_rates = torch.ones(
                num_rates, device=self.init_learning_rate.device
            ) * self.init_learning_rate
            # Entries are stored with "." replaced by "-" in the parameter
            # name; `update_params` applies the same mapping on lookup.
            self.names_learning_rates_dict[key.replace(".", "-")] = nn.Parameter(
                data=initial_rates,
                requires_grad=self.use_learnable_learning_rates)

    def reset(self):
        """Does nothing; the learning rates keep their current values."""
        pass

    def update_params(self, names_weights_dict, names_grads_wrt_params_dict, num_step, tau=0.1):
        """Applies a single gradient descent update to all parameters.

        The update is computed out-of-place: the tensors in
        `names_weights_dict` are left untouched and the updated values are
        returned in a new dictionary.

        Args:
            names_weights_dict: Mapping from parameter name to the current
                value of that parameter.
            names_grads_wrt_params_dict: Mapping from parameter name to the
                gradient of the scalar loss with respect to that parameter.
                Its keys determine which parameters are updated.
            num_step: Index of the current inner-loop step; selects which of
                each parameter's learning rates is used.
            tau: Unused by this rule.

        Returns:
            A dict with the keys of `names_grads_wrt_params_dict`, each mapped
            to `weight - learning_rate[num_step] * gradient`.

        Raises:
            AttributeError: If `initialise` has not been called yet.
            KeyError: If a gradient key has no matching weight or learning
                rate.
        """
        return {
            key: names_weights_dict[key]
            - self.names_learning_rates_dict[key.replace(".", "-")][num_step] * grad
            for key, grad in names_grads_wrt_params_dict.items()
        }