Added AdamW fused. Made fused logic more generic. #1

Closed
68 changes: 62 additions & 6 deletions library/adafactor_fused.py → library/optimizer_fused.py
@@ -1,6 +1,6 @@
import math
import torch
from transformers import Adafactor
from transformers import Adafactor, AdamW

@torch.no_grad()
def adafactor_step_param(self, p, group):
@@ -81,9 +81,57 @@ def adafactor_step_param(self, p, group):
    if p.dtype in {torch.float16, torch.bfloat16}:
        p.copy_(p_data_fp32)

@torch.no_grad()
def adamw_step_param(self, p, group):
    if p.grad is None:
        return
    grad = p.grad
    if grad.is_sparse:
        raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

    state = self.state[p]

    # State initialization
    if len(state) == 0:
        state["step"] = 0
        # Exponential moving average of gradient values
        state["exp_avg"] = torch.zeros_like(p)
        # Exponential moving average of squared gradient values
        state["exp_avg_sq"] = torch.zeros_like(p)

    exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
    beta1, beta2 = group["betas"]

    state["step"] += 1

    # Decay the first and second moment running average coefficient
    # In-place operations to update the averages at the same time
    exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
    denom = exp_avg_sq.sqrt().add_(group["eps"])

    step_size = group["lr"]
    # if group["correct_bias"]:  # No bias correction for Bert
    #     bias_correction1 = 1.0 - beta1 ** state["step"]
    #     bias_correction2 = 1.0 - beta2 ** state["step"]
    #     step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

    p.addcdiv_(exp_avg, denom, value=-step_size)

    # Just adding the square of the weights to the loss function is *not*
    # the correct way of using L2 regularization/weight decay with Adam,
    # since that will interact with the m and v parameters in strange ways.
    #
    # Instead we want to decay the weights in a manner that doesn't interact
    # with the m/v parameters. This is equivalent to adding the square
    # of the weights to the loss with plain (non-momentum) SGD.
    # Add weight decay at the end (fixed version)
    if group["weight_decay"] > 0.0:
        p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
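    # Illustration (assumed values, not from this PR): with lr=1e-4 and
    # weight_decay=1e-2, the line above shrinks every weight by a factor of
    # (1 - 1e-6) per step, independent of exp_avg / exp_avg_sq. Plain L2
    # regularization would instead add weight_decay * p to grad, letting the
    # penalty leak into both moment estimates.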


@torch.no_grad()
def adafactor_step(self, closure=None):
def optimizer_step(self, optimizer_step_param, closure=None):
    """
    Performs a single optimization step

@@ -97,10 +145,18 @@ def adafactor_step(self, closure=None):

    for group in self.param_groups:
        for p in group["params"]:
            adafactor_step_param(self, p, group)
            optimizer_step_param(self, p, group)

    return loss

def patch_adafactor_fused(optimizer: Adafactor):
    optimizer.step_param = adafactor_step_param.__get__(optimizer)
    optimizer.step = adafactor_step.__get__(optimizer)
def patch_optimizer_fused(optimizer, optimizer_type):
    if optimizer_type.lower() == "adamw":
        print("Using AdamW Fused")
        optimizer.step_param = adamw_step_param.__get__(optimizer)
        # __get__ only binds the optimizer instance; use a closure so the
        # per-parameter step function is passed through to optimizer_step as well
        optimizer.step = lambda closure=None: optimizer_step(optimizer, adamw_step_param, closure)
    elif optimizer_type.lower() == "adafactor":
        print("Using Adafactor Fused")
        optimizer.step_param = adafactor_step_param.__get__(optimizer)
        optimizer.step = lambda closure=None: optimizer_step(optimizer, adafactor_step_param, closure)

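For reference, a minimal usage sketch of the patched module (illustrative only, not part of this diff; the model and hyperparameters are assumptions):

import library.optimizer_fused
from transformers import AdamW

optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)  # "model" is an assumed stand-in
library.optimizer_fused.patch_optimizer_fused(optimizer, "AdamW")

# After loss.backward(), step each parameter individually and free its gradient
# instead of running one monolithic optimizer.step() over all parameters.
for group in optimizer.param_groups:
    for p in group["params"]:
        if p.grad is not None:
            optimizer.step_param(p, group)
            p.grad = None
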
5 changes: 3 additions & 2 deletions library/train_util.py
@@ -3852,9 +3852,10 @@ def get_optimizer(args, trainable_params):
    optimizer_type = optimizer_type.lower()

    if args.fused_backward_pass:
        accepted_optimizers = ["Adafactor", "AdamW"]
        assert (
            optimizer_type == "Adafactor".lower()
        ), "fused_backward_pass currently only works with optimizer_type Adafactor / fused_backward_passは現在optimizer_type Adafactorでのみ機能します"
            optimizer_type in [optimizer.lower() for optimizer in accepted_optimizers]
        ), f"fused_backward_pass currently only works with optimizer_type in {accepted_optimizers} / fused_backward_passは現在optimizer_type {accepted_optimizers}でのみ機能します"
        assert (
            args.gradient_accumulation_steps == 1
        ), "fused_backward_pass does not work with gradient_accumulation_steps > 1 / fused_backward_passはgradient_accumulation_steps>1では機能しません"
4 changes: 2 additions & 2 deletions sdxl_train.py
@@ -431,8 +431,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
    optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

    if args.fused_backward_pass:
        import library.adafactor_fused
        library.adafactor_fused.patch_adafactor_fused(optimizer)
        import library.optimizer_fused
        library.optimizer_fused.patch_optimizer_fused(optimizer, args.optimizer_type)
        for param_group in optimizer.param_groups:
            for parameter in param_group["params"]:
                if parameter.requires_grad:
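The loop above is truncated by the diff view. As a sketch only (an assumed continuation pattern, not the verbatim hidden lines), a fused backward pass typically registers a per-parameter hook that applies the optimizer step as soon as that parameter's gradient has been accumulated, then frees the gradient to lower peak memory (register_post_accumulate_grad_hook requires PyTorch 2.1+):

def make_grad_hook(param_group):
    def grad_hook(tensor: torch.Tensor):
        # apply the per-parameter fused step installed by patch_optimizer_fused
        optimizer.step_param(tensor, param_group)
        tensor.grad = None  # release this gradient immediately
    return grad_hook

parameter.register_post_accumulate_grad_hook(make_grad_hook(param_group))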