feat: Init Sophia Optimizer (using Gauss-Newton-Bartlett, Sophia-G)
ItsNiklas committed Jun 16, 2023
1 parent ccc4cdc commit 17bbbc7
Showing 1 changed file with 71 additions and 14 deletions.
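For context on the commit title: Sophia-G is the variant of the Sophia optimizer (arXiv:2305.14342) that drives its diagonal Hessian estimate with the Gauss-Newton-Bartlett (GNB) estimator: sample labels from the model's own output distribution, backpropagate the loss on those sampled labels, and use bs * ghat^2 as the per-parameter curvature. A hedged sketch of that estimator follows; `model`, `inputs`, and `gnb_hessian_estimate` are placeholder names, not taken from this commit:

import torch
import torch.nn.functional as F

def gnb_hessian_estimate(model, inputs, bs):
    # Forward pass; logits shaped (bs, num_classes).
    logits = model(inputs)
    # Sample labels from the model's own predictive distribution ...
    y_hat = torch.distributions.Categorical(logits=logits).sample()
    # ... and take the gradient of the loss on those sampled labels.
    loss = F.cross_entropy(logits, y_hat)
    params = [p for p in model.parameters() if p.requires_grad]
    grads = torch.autograd.grad(loss, params)
    # Diagonal Gauss-Newton-Bartlett estimate: bs * ghat^2, elementwise.
    return [bs * g.detach() ** 2 for g in grads]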
optimizer.py (85 changes: 71 additions & 14 deletions)
@@ -7,23 +7,33 @@

class AdamW(Optimizer):
    def __init__(
        self,
        params: Iterable[torch.nn.parameter.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        correct_bias: bool = True,
    ):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                "Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0])
            )
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                "Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1])
            )
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            correct_bias=correct_bias,
        )
        super().__init__(params, defaults)

    def step(self, closure: Callable = None):
@@ -40,7 +50,9 @@ def step(self, closure: Callable = None):
                grad = p.grad.data

                if grad.is_sparse:
                    raise RuntimeError(
                        "Adam does not support sparse gradients, please consider SparseAdam instead"
                    )

                # State should be stored in this dictionary
                state = self.state[p]
@@ -60,10 +72,14 @@ def step(self, closure: Callable = None):
state["t"] = torch.tensor([0]).to(device)

if "m" not in state:
state["m"] = torch.zeros(size=grad.size(), dtype=grad.dtype).to(device)
state["m"] = torch.zeros(size=grad.size(), dtype=grad.dtype).to(
device
)

if "v" not in state:
state["v"] = torch.zeros(size=grad.size(), dtype=grad.dtype).to(device)
state["v"] = torch.zeros(size=grad.size(), dtype=grad.dtype).to(
device
)

state["t"] += 1

@@ -83,7 +99,11 @@ def step(self, closure: Callable = None):
                # (using the "efficient version" given in https://arxiv.org/abs/1412.6980;
                # also given in the pseudo-code in the project description).
                if correct_bias:
                    alpha = (
                        alpha
                        * torch.sqrt(1 - beta_2 ** state["t"])
                        / (1 - beta_1 ** state["t"])
                    )
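                    # i.e. alpha_t = alpha * sqrt(1 - beta_2^t) / (1 - beta_1^t):
                    # the bias corrections for m and v are folded into the step
                    # size instead of being applied to each moment estimate.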

                # 3- Update parameters (p.data).

@@ -95,3 +115,40 @@ def step(self, closure: Callable = None):
                p.data = p.data - group["lr"] * p.data * weight_decay

        return loss


class SophiaG(Optimizer):
    def __init__(
        self,
        params: Iterable[torch.nn.parameter.Parameter],
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.965, 0.99),
        rho: float = 0.04,
        weight_decay: float = 0.1,
    ):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                "Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0])
            )
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                "Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1])
            )
        if not 0.0 <= rho:
            raise ValueError("Invalid rho value: {} - should be >= 0.0".format(rho))
        if not 0.0 <= weight_decay:
            raise ValueError(
                "Invalid weight_decay value: {} - should be >= 0.0".format(weight_decay)
            )
        defaults = dict(
            lr=lr,
            betas=betas,
            rho=rho,
            weight_decay=weight_decay,
        )
        super(SophiaG, self).__init__(params, defaults)

    def step(self, closure: Callable = None, bs: int = 5120):
        pass
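The committed step is still a stub. As a reference point, here is a minimal sketch of the Sophia-G update it would have to perform, following the update rule in the Sophia paper (arXiv:2305.14342); the Hessian slot state["h"] and its refresh cadence (updated every k steps via the GNB estimator above) are assumptions of this sketch, not part of this commit:

    # Hedged sketch only - not the committed implementation. Assumes state["h"]
    # holds an EMA of the GNB Hessian estimate, refreshed periodically elsewhere.
    def step(self, closure: Callable = None, bs: int = 5120):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]

                if "m" not in state:
                    state["m"] = torch.zeros_like(grad)
                if "h" not in state:
                    state["h"] = torch.zeros_like(grad)

                beta_1, _ = group["betas"]

                # Decoupled weight decay, as in AdamW above.
                p.data = p.data - group["lr"] * group["weight_decay"] * p.data

                # EMA of the gradient (first moment).
                state["m"] = beta_1 * state["m"] + (1 - beta_1) * grad

                # Newton-like step m / (rho * bs * h), clipped elementwise into
                # [-1, 1] so near-zero curvature cannot blow up the update.
                ratio = (
                    state["m"].abs() / (group["rho"] * bs * state["h"] + 1e-15)
                ).clamp(max=1.0)
                p.data = p.data - group["lr"] * state["m"].sign() * ratio

        return loss

Construction would mirror AdamW above, e.g. SophiaG(model.parameters(), lr=1e-4), with bs matched to the batch size behind the Hessian estimate.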
