# adamw8bit.py
import torch
from bitsandbytes.optim.optimizer import Optimizer2State

from .galore_projector import GaLoreProjector
from .galore_projector_tensor import GaLoreProjectorTensor


class AdamW8bit(Optimizer2State):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False,
                 optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True,
                 is_paged=False):
        super().__init__("adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size,
                         percentile_clipping, block_wise, is_paged=is_paged)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        overflows = []

        if not self.initialized:
            self.check_overrides()
            self.to_gpu()  # needed for fairseq pure fp16 training
            self.initialized = True

        # if self.is_paged: self.page_mng.prefetch_all()
        for gindex, group in enumerate(self.param_groups):
            for pindex, p in enumerate(group["params"]):
                if p.grad is None:
                    continue
                state = self.state[p]

                if "step" not in state:
                    state["step"] = 0

                if "dim" not in group:
                    group["dim"] = 2

                # GaLore projection: project the full-rank gradient into the
                # low-rank subspace before the 8-bit Adam update.
                if "rank" in group:
                    if "projector" not in state:
                        if group["dim"] <= 2:
                            state["projector"] = GaLoreProjector(
                                group["rank"],
                                update_proj_gap=group["update_proj_gap"],
                                scale=group["scale"],
                                proj_type=group["proj_type"],
                            )
                        else:
                            state["projector"] = GaLoreProjectorTensor(
                                group["rank"],
                                update_proj_gap=group["update_proj_gap"],
                                scale=group["scale"],
                                proj_type=group["proj_type"],
                            )

                    grad = state["projector"].project(p.grad, state["step"])

                    # Buffer that receives the low-rank update computed by update_step.
                    lor_update = torch.zeros_like(
                        grad, dtype=p.data.dtype, device=p.data.device, requires_grad=grad.requires_grad
                    )
                    p.grad = grad

                if "state1" not in state:
                    self.init_state(group, p, gindex, pindex)

                self.prefetch_state(p)

                if "rank" in group:
                    self.update_step(group, p, gindex, pindex, return_updates=lor_update)

                    # GaLore projection back: map the low-rank update to the
                    # full-rank parameter space and apply it.
                    p.data.add_(state["projector"].project_back(lor_update))

                    # Decoupled weight decay, applied in the full-rank space.
                    if "weight_decay" in group and group["weight_decay"] > 0:
                        p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"])
                else:
                    self.update_step(group, p, gindex, pindex)

                torch.cuda.synchronize()

        if self.is_paged:
            # All paged operations are asynchronous; synchronize so every
            # tensor is in the right state before returning.
            torch.cuda.synchronize()

        return loss
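

# ---------------------------------------------------------------------------
# Minimal usage sketch. It illustrates the param-group keys that step() reads
# for GaLore parameters ("rank", "update_proj_gap", "scale", "proj_type"); the
# toy model and the concrete hyperparameter values are assumptions chosen for
# illustration only, and running it requires a CUDA device plus bitsandbytes.
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 10)).cuda()

    # Route 2D weight matrices through the low-rank GaLore path; everything
    # else falls back to the plain 8-bit AdamW update.
    galore_params = [p for p in model.parameters() if p.dim() == 2]
    regular_params = [p for p in model.parameters() if p.dim() != 2]

    optimizer = AdamW8bit(
        [
            {"params": regular_params},
            {
                "params": galore_params,
                "rank": 64,              # rank of the projection subspace (assumed value)
                "update_proj_gap": 200,  # steps between projector refreshes (assumed value)
                "scale": 0.25,           # scaling of the projected update (assumed value)
                "proj_type": "std",      # projection variant (assumed value)
            },
        ],
        lr=1e-3,
        weight_decay=1e-2,
    )

    x = torch.randn(16, 128, device="cuda")
    loss = model(x).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()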