addressed code review changes
Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>
achew010 committed Sep 10, 2024
1 parent 741e58f commit 870ea03
Showing 8 changed files with 60 additions and 35 deletions.
@@ -230,8 +230,9 @@ def model_loader(self, model_name: str, **kwargs):
# - in particular "is_loaded_in_4bit" will be checked in prepare_model_for_kbit_training
# and there is a section of code that will be skipped if not set.
setattr(model, "is_loaded_in_4bit", True)
setattr(model, "quantization_method", "gptq")

# Need to set this on model.model instead of model; applies to all versions of gptq
# as this attribute is accessed later on from model.model
setattr(model.model, "quantization_method", "gptq")
return model

@property
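For context, the hunk above moves the `quantization_method` flag from the outer object onto the inner transformers model. A minimal sketch of the resulting pattern (the helper name is illustrative and not part of the plugin; `model` is assumed to be a quantized-model wrapper whose underlying transformers model sits at `model.model`):

```python
# Sketch only: illustrates why the two flags end up on different objects.
def _mark_gptq_flags(model):
    # PEFT's prepare_model_for_kbit_training checks `is_loaded_in_4bit` on the
    # object it receives, so this flag stays on the outer object.
    setattr(model, "is_loaded_in_4bit", True)
    # Later code reads `quantization_method` from the inner model, hence it is
    # set on `model.model` rather than on `model`.
    setattr(model.model, "quantization_method", "gptq")
    return model
```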
1 change: 0 additions & 1 deletion plugins/framework/pyproject.toml
@@ -24,7 +24,6 @@ classifiers=[
dependencies = [
"numpy<2.0", # numpy needs to be bounded due to incompatiblity with current torch<2.3
"torch>2.2",
"git+https://github.com/huggingface/transformers.git@9230d78e76611cfa38c845213021aeb185362d10",
"peft",
"accelerate",
"pandas",
4 changes: 0 additions & 4 deletions plugins/fused-ops-and-kernels/README.md
@@ -39,10 +39,6 @@ Path | Description | Extracted From | Modifications | Date
[fused_ops/unsloth_lora/gptq](./src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq) | GPTQ fast dequant (triton_v2) | `jeromeku/main` @ [2839d39](https://github.com/jeromeku/unsloth/commit/2839d390ef3bb318904289bfb9a7751a782c4e44) | `fast_lora.py`<br>`triton/layers.py` | 6 Feb 2024
[kernels/unsloth](./src/fms_acceleration_foak/kernels/unsloth) | Fast RMS, RoPE, CrossEnt kernels | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | `cross_entropy_loss.py`<br>`rms_layernorm.py` | 28 Jan 2024
## Code Extracted from Liger
- TODO
## Known Issues
- MixedPrecision `--fp16` should be used with `fast_lora`. Also consider loading the model in `torch.float16`.
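The Known Issues entry above is terse; the following is a minimal sketch of what it suggests, assuming a standard `transformers` workflow (the model id and output directory are placeholders):

```python
import torch
from transformers import AutoModelForCausalLM, TrainingArguments

# Load the base model in float16, as the note suggests for fast_lora.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",  # placeholder model id
    torch_dtype=torch.float16,
)

# Enable fp16 mixed precision, the equivalent of passing --fp16 on the CLI.
training_args = TrainingArguments(output_dir="./results", fp16=True)
```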
@@ -22,6 +22,8 @@
from transformers import TrainingArguments
import torch

from .framework_plugin_fast_quantized_peft import lora_adapters_switch_ddp_from_fsdp

# consider rewriting register_foak_model_patch_rules into something
# like this also
def register_foak_model_patch_rules2(base_type: str, filter_endswith: Set[str] = None):
@@ -59,7 +61,6 @@ def register_foak_model_patch_rules2(base_type: str, filter_endswith: Set[str] =

# maybe we should define envvars for this
FILTER_MAP = {
"base_layer": set(),
"fused_lora": {"qkvo", "mlp"},
"fast_loss": "cross-ent",
"fast_rsm_layernorm": "rms",
@@ -91,7 +92,7 @@ def __init__(self, configurations: Dict[str, Dict]):
key="peft.quantization.fused_ops_and_kernels",
)

self._check_config_and_maybe_check_values(
self.configurations["base_layer"] = self._check_config_and_maybe_check_values(
key="base_layer",
values=["auto_gptq", "bitsandbytes"],
default="auto_gptq"
@@ -124,8 +125,8 @@ def augmentation(
# will still be installed but never triggered
# if no peft layer is detected at the point of patching
terms = set()
for k, v in self.configurations.items():
if v:
for k in self.configurations:
if k in FILTER_MAP:
ts = FILTER_MAP[k]
if isinstance(ts, str):
ts = {ts}
@@ -134,13 +135,37 @@
# wrapper function to register foak patches
# NOTE: we never take the lora modules so just set arbitrarily
# to "auto_gptq"
_base_layer = self.configurations['base_layer'] if 'base_layer' \
in self.configurations else 'auto_gptq'
_base_layer = self.configurations.get('base_layer', None)
register_foak_model_patch_rules2(
base_type=_base_layer, filter_endswith=terms
)
return model, modifiable_args

def get_callbacks_and_ready_for_train(
self, model: torch.nn.Module = None, accelerator=None
):
# This callback applies only for qpeft
# should not install this for full FT and standard peft
is_quantized = getattr(model, "quantization_method", None)
callbacks = []
if (
accelerator is not None
and getattr(accelerator.state, "fsdp_plugin", None) is not None
and is_quantized is not None
):
# This function installs grad reduction hooks on adapters if
# FSDP is detected. Because of incompatibility between FSDP and
# fused modules, adapters are not sharded - instead
# accumulated gradients from adapters in each device are reduced
# in these grad reduce hooks
# This function might be removed in the future if the incompatibility
# is resolved
lora_adapters_switch_ddp_from_fsdp(
[mod for mod in model.modules() if isinstance(mod, LoraLayer)],
accelerator.state.fsdp_plugin,
)
return callbacks

# register
AccelerationPlugin.register_plugin(
FastKernelsAccelerationPlugin,
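The term-collection loop in `augmentation` above is cut off by the diff; here is a standalone sketch of its visible logic, with an illustrative configuration and an assumed `terms.update(ts)` accumulation step (the part hidden by the truncated hunk):

```python
# Standalone sketch: FILTER_MAP and the example configuration are illustrative.
FILTER_MAP = {
    "base_layer": set(),
    "fused_lora": {"qkvo", "mlp"},
    "fast_loss": "cross-ent",
    "fast_rsm_layernorm": "rms",
}

configurations = {"base_layer": "auto_gptq", "fast_loss": True}  # example input

terms = set()
for k in configurations:
    if k in FILTER_MAP:
        ts = FILTER_MAP[k]
        if isinstance(ts, str):
            ts = {ts}          # normalize a single term to a set
        terms.update(ts)       # assumed: terms accumulate for rule filtering

print(terms)  # e.g. {"cross-ent"}; "base_layer" maps to an empty set
```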
@@ -22,11 +22,6 @@
combine_functions,
combine_triggers,
)
from transformers.models.granite.modeling_granite import (
GraniteAttention,
GraniteMLP,
GraniteRMSNorm,
)

# Local
from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
@@ -42,6 +37,15 @@ def get_mp_rules(base_type: str):
its forward builder argument, wrap the forward_builder
function as a partial function with the base_type argument
"""
try:
from transformers.models.granite.modeling_granite import ( # pylint: disable=import-outside-toplevel
GraniteAttention,
GraniteMLP,
GraniteRMSNorm,
)
except ImportError:
return []

return [
# TODO: have a generic version of this rule
# - do regex on RMSNorm class name
32 changes: 16 additions & 16 deletions scripts/benchmarks/scenarios.yaml
@@ -43,11 +43,11 @@ scenarios:
arguments:
learning_rate: 2e-5
model_name_or_path:
# - 'ibm/PowerLM-3b'
- 'bigcode/gpt_bigcode-santacoder'
- 'ibm/PowerLM-3b'
- 'mistralai/Mistral-7B-v0.1'
- 'mistralai/Mixtral-8x7B-Instruct-v0.1'
- 'NousResearch/Llama-2-70b-hf'
# - 'mistralai/Mixtral-8x7B-Instruct-v0.1'
# - 'NousResearch/Llama-2-70b-hf'
torch_dtype: bfloat16

- name: standard-peft
@@ -64,14 +64,14 @@
target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
model_name_or_path:
- 'mistralai/Mistral-7B-v0.1'
- 'mistralai/Mixtral-8x7B-Instruct-v0.1'
- 'NousResearch/Llama-2-70b-hf'
# - 'mistralai/Mixtral-8x7B-Instruct-v0.1'
# - 'NousResearch/Llama-2-70b-hf'

- name: baseline-peft-bnb
framework_config:
- baseline-peft-bnb
arguments:
fp16: True
bf16: True
learning_rate: 2e-4
torch_dtype: bfloat16
peft_method: lora
@@ -81,36 +81,36 @@
target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
model_name_or_path:
- 'mistralai/Mistral-7B-v0.1'
- 'mistralai/Mixtral-8x7B-Instruct-v0.1'
- 'NousResearch/Llama-2-70b-hf'
# - 'mistralai/Mixtral-8x7B-Instruct-v0.1'
# - 'NousResearch/Llama-2-70b-hf'

- name: accelerated-peft-bnb
framework_config:
- accelerated-peft-bnb
- accelerated-peft-bnb-foak
arguments:
fp16: True
bf16: True
learning_rate: 2e-4
torch_dtype: bfloat16
peft_method: lora
r: 16
lora_alpha: 16
lora_dropout: 0.1
target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "c_attn"]
model_name_or_path:
- 'ibm/PowerLM-3b'
# - 'ibm/PowerLM-3b'
- 'bigcode/gpt_bigcode-santacoder'
- 'mistralai/Mistral-7B-v0.1'
- 'mistralai/Mixtral-8x7B-Instruct-v0.1'
- 'NousResearch/Llama-2-70b-hf'
# - 'mistralai/Mixtral-8x7B-Instruct-v0.1'
# - 'NousResearch/Llama-2-70b-hf'

- name: accelerated-peft-gptq
framework_config:
- accelerated-peft-autogptq
- accelerated-peft-autogptq-foak
arguments:
learning_rate: 2e-4
fp16: True
bf16: True
torch_dtype: bfloat16
peft_method: lora
r: 16
@@ -119,5 +119,5 @@
target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
model_name_or_path:
- 'TheBloke/Mistral-7B-v0.1-GPTQ'
- 'TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ'
- 'TheBloke/Llama-2-70B-GPTQ'
# - 'TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ'
# - 'TheBloke/Llama-2-70B-GPTQ'
2 changes: 1 addition & 1 deletion scripts/generate_sample_configurations.py
@@ -172,7 +172,7 @@ def read_configuration(path: str) -> Dict:
),
KEY_AADP_PADDING_FREE: "plugins/attention-and-distributed-packing/configs/padding_free.yaml",
KEY_AADP_MULTIPACK: "plugins/attention-and-distributed-packing/configs/multipack.yaml",
KEY_FAST_KERNELS: "plugins/fused-ops-and-kernels/configs/fast_full.yaml",
KEY_FAST_KERNELS: "plugins/fused-ops-and-kernels/configs/fast_kernels.yaml",
}

# list of (tag, combi) tuples