WIP: galore optimizer #1370

Closed · wants to merge 5 commits
1 change: 1 addition & 0 deletions setup.py
@@ -89,5 +89,6 @@ def parse_requirements():
"lion-pytorch": [
"lion-pytorch==0.1.2",
],
"galore": ["galore_torch @ git+https://github.com/jiaweizzhao/GaLore"],
},
)
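The new extra only points galore_torch at the upstream GitHub repo, so a quick sanity check after installing it is to import the three optimizer classes the trainer builder uses below. This is just an illustrative check, not part of the diff:

# Confirm the galore_torch optional dependency added in setup.py resolves.
# Assumes the "galore" extra (or galore_torch itself) has been installed.
from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit

print(GaLoreAdamW, GaLoreAdamW8bit, GaLoreAdafactor)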
71 changes: 67 additions & 4 deletions src/axolotl/core/trainer_builder.py
@@ -18,6 +18,7 @@
import torch
import transformers
from datasets import Dataset
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import (
@@ -633,7 +634,7 @@ def peft_config(self, peft_config):
self._peft_config = peft_config

@abstractmethod
def build(self, total_num_steps):
def build(self, total_num_steps, model):
pass

def get_callbacks(self) -> List[TrainerCallback]:
@@ -740,7 +741,7 @@ def _get_trainer_cls(self):
return AxolotlMambaTrainer
return AxolotlTrainer

def build(self, total_num_steps):
def build(self, total_num_steps, model):
warmup_steps = None
if self.cfg.warmup_steps is not None:
warmup_steps = self.cfg.warmup_steps
@@ -1037,6 +1038,68 @@ def build(self, total_num_steps):
)
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
elif self.cfg.optimizer in [
"galore_adamw",
"galore_adamw8bit",
"galore_ada_factor",
]:
from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit

galore_params = []
for module_name, module in model.named_modules():
if not isinstance(module, nn.Linear):
continue

if not any(
target_key in module_name
for target_key in self.cfg.galore_target_modules
):
continue

galore_params.append(module.weight)

id_galore_params = [id(p) for p in galore_params]
regular_params = [
p for p in model.parameters() if id(p) not in id_galore_params
]

param_groups = [
{"params": regular_params},
{
"params": galore_params,
"rank": self.cfg.galore_rank,
"update_proj_gap": self.cfg.galore_update_proj_gap,
"scale": self.cfg.galore_scale,
"proj_type": self.cfg.galore_proj_type,
},
]
if self.cfg.optimizer == "galore_ada_factor":
optimizer = GaLoreAdafactor(
param_groups,
lr=training_arguments_kwargs["learning_rate"],
weight_decay=training_arguments_kwargs.get("weight_decay", 0.0),
)
else:
galore_cls = GaLoreAdamW
if self.cfg.optimizer == "galore_adamw8bit":
galore_cls = GaLoreAdamW8bit
optimizer = galore_cls(
param_groups,
lr=training_arguments_kwargs["learning_rate"],
betas=(
training_arguments_kwargs.get("adam_beta1", 0.9),
training_arguments_kwargs.get("adam_beta2", 0.999),
),
eps=training_arguments_kwargs.get("adam_epsilon", 0.00000001),
weight_decay=training_arguments_kwargs.get("weight_decay", 0.0),
)

trainer_kwargs["optimizers"] = (
optimizer,
None,
)
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"

if self.cfg.optimizer == "adamw_anyprecision":
if Path(self.cfg.torchdistx_path).exists():
@@ -1246,7 +1309,7 @@ def build_training_arguments(self, total_num_steps):

return training_args

def build(self, total_num_steps):
def build(self, total_num_steps, model):
training_args = self.build_training_arguments(total_num_steps)
dpo_trainer_kwargs = {}
if self.cfg.rl == "ipo":
@@ -1297,6 +1360,6 @@ def get_post_trainer_create_callbacks(self, trainer):
callbacks = []
return callbacks

def build(self, total_num_steps):
def build(self, total_num_steps, model):
# build PPOConfig
pass
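The core of this change is the parameter-group split in build(): only nn.Linear weights whose module name contains one of galore_target_modules get the GaLore low-rank projection settings, every other parameter lands in a plain group, and the resulting optimizer is handed to the trainer via trainer_kwargs["optimizers"] = (optimizer, None), with optim left at "adamw_hf" only so transformers' argument validation does not reject the config. A minimal standalone sketch of the same grouping logic, using a toy model and placeholder hyperparameter values (the real ones come from cfg.galore_*), might look like this:

# Sketch of the parameter-group split done in build() above, applied to a toy
# model. galore_torch is assumed to be installed; the numbers are placeholders.
from torch import nn
from galore_torch import GaLoreAdamW

model = nn.Sequential()
model.add_module("attn_proj", nn.Linear(64, 64))
model.add_module("mlp_up", nn.Linear(64, 256))
model.add_module("head", nn.Linear(256, 10))

target_modules = ["attn", "mlp"]  # mirrors the galore_target_modules default

# Linear weights whose module name matches a target get the low-rank treatment.
galore_params = [
    module.weight
    for name, module in model.named_modules()
    if isinstance(module, nn.Linear)
    and any(key in name for key in target_modules)
]
galore_ids = {id(p) for p in galore_params}
regular_params = [p for p in model.parameters() if id(p) not in galore_ids]

optimizer = GaLoreAdamW(
    [
        {"params": regular_params},
        {
            "params": galore_params,
            "rank": 128,             # galore_rank
            "update_proj_gap": 50,   # galore_update_proj_gap
            "scale": 1.0,            # galore_scale
            "proj_type": "std",      # galore_proj_type
        },
    ],
    lr=1e-4,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.0,
)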
22 changes: 21 additions & 1 deletion src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -2,6 +2,8 @@
Module for pydantic models for configuration
"""

# pylint: disable=too-many-lines

import logging
import os
from enum import Enum
@@ -246,6 +248,16 @@ def validate_quantized_dora(cls, data):
return data


class GaloreConfig(BaseModel):
"""Galore optimizer configuration"""

galore_rank: Optional[int] = None
galore_update_proj_gap: Optional[int] = 50
galore_scale: Optional[float] = 1.0
galore_proj_type: Optional[str] = "std"
galore_target_modules: Optional[List[str]] = ["attn", "mlp"]


class ReLoRAConfig(BaseModel):
"""ReLoRA configuration subset"""

@@ -305,7 +317,14 @@ class HyperparametersConfig(BaseModel):

learning_rate: Union[str, float]
weight_decay: Optional[float] = None
optimizer: Optional[Union[OptimizerNames, Literal["lion_pytorch"]]] = None
optimizer: Optional[
Union[
OptimizerNames,
Literal[
"lion_pytorch", "galore_adamw8bit", "galore_adamw", "galore_ada_factor"
],
]
] = None
torchdistx_path: Optional[str] = None
lr_scheduler: Optional[SchedulerType] = None
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
@@ -388,6 +407,7 @@ class AxolotlInputConfig(
ModelOutputConfig,
LoraConfig,
ReLoRAConfig,
GaloreConfig,
HyperparametersConfig,
WandbConfig,
MLFlowConfig,
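The config side simply declares the galore_* knobs with defaults that match the optimizer code above and mixes GaloreConfig into AxolotlInputConfig. Instantiating the new model directly with placeholder values, assuming the import path of the file being edited here, would look roughly like this:

# Rough illustration of the new GaloreConfig fields; the values are
# placeholders, and the defaults mirror the class definition above.
from axolotl.utils.config.models.input.v0_4_1 import GaloreConfig

galore_cfg = GaloreConfig(
    galore_rank=128,
    galore_update_proj_gap=200,
    galore_scale=0.25,
    galore_proj_type="std",
    galore_target_modules=["attn", "mlp"],
)
print(galore_cfg)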
2 changes: 1 addition & 1 deletion src/axolotl/utils/trainer.py
@@ -331,4 +331,4 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
trainer_builder.train_dataset = train_dataset
trainer_builder.eval_dataset = eval_dataset

return trainer_builder.build(total_num_steps)
return trainer_builder.build(total_num_steps, model[0])