diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
index 89fdee46..ca8e48f5 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
@@ -230,8 +230,9 @@ def model_loader(self, model_name: str, **kwargs):
# - in particular "is_loaded_in_4bit" will be checked in prepare_model_for_kbit_training
# and there is a section of code that will be skipped if not set.
setattr(model, "is_loaded_in_4bit", True)
- setattr(model, "quantization_method", "gptq")
-
+ # Need to set this on model.model instead of model; this applies to all versions of gptq
+ # as this attribute is accessed later on from model.model
+ setattr(model.model, "quantization_method", "gptq")
return model
@property
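The hunk above hinges on two attributes that downstream code inspects: `is_loaded_in_4bit` on the wrapper (checked inside `prepare_model_for_kbit_training`, per the comment) and `quantization_method`, which must now live on the inner `model.model` because that is where it is read later. A minimal sketch of that wiring, assuming a GPTQ-quantized model whose inner module is exposed as `model.model`; `mark_gptq_model` is an illustrative helper, not part of the plugin:

```python
# Illustrative sketch only - mirrors the attribute wiring described in the
# comments above; `mark_gptq_model` is a hypothetical helper, not plugin API.
from peft import prepare_model_for_kbit_training


def mark_gptq_model(model):
    # flag checked by prepare_model_for_kbit_training so the k-bit
    # preparation path is not skipped
    setattr(model, "is_loaded_in_4bit", True)
    # downstream code reads the quantization method from the inner module,
    # hence it is set on model.model rather than on the wrapper
    setattr(model.model, "quantization_method", "gptq")
    return prepare_model_for_kbit_training(model)
```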
diff --git a/plugins/framework/pyproject.toml b/plugins/framework/pyproject.toml
index aefa9fe1..f60ca1a3 100644
--- a/plugins/framework/pyproject.toml
+++ b/plugins/framework/pyproject.toml
@@ -24,7 +24,6 @@ classifiers=[
dependencies = [
"numpy<2.0", # numpy needs to be bounded due to incompatiblity with current torch<2.3
"torch>2.2",
- "git+https://github.com/huggingface/transformers.git@9230d78e76611cfa38c845213021aeb185362d10",
"peft",
"accelerate",
"pandas",
diff --git a/plugins/fused-ops-and-kernels/README.md b/plugins/fused-ops-and-kernels/README.md
index 2f17ddfc..9add3339 100644
--- a/plugins/fused-ops-and-kernels/README.md
+++ b/plugins/fused-ops-and-kernels/README.md
@@ -39,10 +39,6 @@ Path | Description | Extracted From | Modifications | Date
[fused_ops/unsloth_lora/gptq](./src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq) | GPTQ fast dequant (triton_v2) | `jeromeku/main` @ [2839d39](https://github.com/jeromeku/unsloth/commit/2839d390ef3bb318904289bfb9a7751a782c4e44) | `fast_lora.py`
`triton/layers.py` | 6 Feb 2024
[kernels/unsloth](./src/fms_acceleration_foak/kernels/unsloth) | Fast RMS, RoPE, CrossEnt kernels | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | `cross_entropy_loss.py`
`rms_layernorm.py` | 28 Jan 2024
-
-## Code Extracted from Liger
-- TODO
-
## Known Issues
- MixedPrecision `--fp16` should be used with `fast_lora`. Also consider loading the model in `torch.float16`.
diff --git a/plugins/fused-ops-and-kernels/configs/fast_full.yaml b/plugins/fused-ops-and-kernels/configs/fast_kernels.yaml
similarity index 100%
rename from plugins/fused-ops-and-kernels/configs/fast_full.yaml
rename to plugins/fused-ops-and-kernels/configs/fast_kernels.yaml
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py
index 17305bec..f11687cc 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py
@@ -22,6 +22,8 @@
from transformers import TrainingArguments
import torch
+from .framework_plugin_fast_quantized_peft import lora_adapters_switch_ddp_from_fsdp
+
# consider rewriting register_foak_model_patch_rules into something
# like this also
def register_foak_model_patch_rules2(base_type: str, filter_endswith: Set[str] = None):
@@ -59,7 +61,6 @@ def register_foak_model_patch_rules2(base_type: str, filter_endswith: Set[str] =
# maybe we should define envvars for this
FILTER_MAP = {
- "base_layer": set(),
"fused_lora": {"qkvo", "mlp"},
"fast_loss": "cross-ent",
"fast_rsm_layernorm": "rms",
@@ -91,7 +92,7 @@ def __init__(self, configurations: Dict[str, Dict]):
key="peft.quantization.fused_ops_and_kernels",
)
- self._check_config_and_maybe_check_values(
+ self.configurations["base_layer"] = self._check_config_and_maybe_check_values(
key="base_layer",
values=["auto_gptq", "bitsandbytes"],
default="auto_gptq"
@@ -124,8 +125,8 @@ def augmentation(
# will still be installed but never triggered
# if no peft layer is detected at the point of patching
terms = set()
- for k, v in self.configurations.items():
- if v:
+ for k in self.configurations:
+ if k in FILTER_MAP:
ts = FILTER_MAP[k]
if isinstance(ts, str):
ts = {ts}
@@ -134,13 +135,37 @@ def augmentation(
# wrapper function to register foak patches
# NOTE: we never take the lora modules so just set arbitrarily
# to "auto_gptq"
- _base_layer = self.configurations['base_layer'] if 'base_layer' \
- in self.configurations else 'auto_gptq'
+ _base_layer = self.configurations.get('base_layer', None)
register_foak_model_patch_rules2(
base_type=_base_layer, filter_endswith=terms
)
return model, modifiable_args
+ def get_callbacks_and_ready_for_train(
+ self, model: torch.nn.Module = None, accelerator=None
+ ):
+ # This callback applies only to qpeft;
+ # it should not be installed for full FT and standard peft
+ is_quantized = getattr(model, "quantization_method", None)
+ callbacks = []
+ if (
+ accelerator is not None
+ and getattr(accelerator.state, "fsdp_plugin", None) is not None
+ and is_quantized is not None
+ ):
+ # This function installs grad reduction hooks on adapters if
+ # FSDP is detected. Because of incompatibility between FSDP and
+ # fused modules, adapters are not sharded - instead,
+ # accumulated gradients from adapters on each device are reduced
+ # in these grad reduce hooks.
+ # This function might be removed in the future if the incompatibility
+ # is resolved.
+ lora_adapters_switch_ddp_from_fsdp(
+ [mod for mod in model.modules() if isinstance(mod, LoraLayer)],
+ accelerator.state.fsdp_plugin,
+ )
+ return callbacks
+
# register
AccelerationPlugin.register_plugin(
FastKernelsAccelerationPlugin,
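To make the revised filtering in `augmentation` concrete: only configuration keys that also appear in `FILTER_MAP` contribute patch-rule suffixes, so the `base_layer` value now stored in `self.configurations` is skipped when building `terms`. A standalone sketch of that loop, with the map trimmed to the entries visible in this diff (not the plugin's full set):

```python
# Standalone sketch of the term-collection loop after this change; values are
# trimmed to the FILTER_MAP entries visible in the diff.
FILTER_MAP = {
    "fused_lora": {"qkvo", "mlp"},
    "fast_loss": "cross-ent",
    "fast_rsm_layernorm": "rms",
}

def collect_filter_terms(configurations: dict) -> set:
    terms = set()
    for k in configurations:
        if k in FILTER_MAP:          # keys such as "base_layer" are skipped
            ts = FILTER_MAP[k]
            if isinstance(ts, str):  # normalize a single suffix to a set
                ts = {ts}
            terms.update(ts)
    return terms

# e.g. a config enabling fused_lora and fast_loss alongside base_layer
assert collect_filter_terms(
    {"fused_lora": True, "fast_loss": True, "base_layer": "auto_gptq"}
) == {"qkvo", "mlp", "cross-ent"}
```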
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granite.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granite.py
index 778b6211..e5307791 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granite.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granite.py
@@ -22,11 +22,6 @@
combine_functions,
combine_triggers,
)
-from transformers.models.granite.modeling_granite import (
- GraniteAttention,
- GraniteMLP,
- GraniteRMSNorm,
-)
# Local
from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
@@ -42,6 +37,15 @@ def get_mp_rules(base_type: str):
its forward builder argument, wrap the forward_builder
function as a partial function with the base_type argument
"""
+ try:
+ from transformers.models.granite.modeling_granite import ( # pylint: disable=import-outside-toplevel
+ GraniteAttention,
+ GraniteMLP,
+ GraniteRMSNorm,
+ )
+ except ImportError:
+ return []
+
return [
# TODO: have a generic version of this rule
# - do regex on RMSNorm class name
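The guarded import above lets `get_mp_rules` degrade gracefully on transformers versions that do not ship the Granite modeling classes: callers simply receive an empty rule list and register nothing. A self-contained sketch of the same pattern; `load_granite_rules` is a hypothetical name used only for illustration:

```python
# Hypothetical sketch of the guarded-import fallback used above; it returns an
# empty list when the installed transformers has no Granite classes.
def load_granite_rules():
    try:
        # only present in sufficiently recent transformers releases
        from transformers.models.granite.modeling_granite import (
            GraniteAttention,
            GraniteMLP,
            GraniteRMSNorm,
        )
    except ImportError:
        return []
    # in the plugin these classes seed the patch rules; here we just return
    # them to show the shape of the fallback
    return [GraniteAttention, GraniteMLP, GraniteRMSNorm]


rules = load_granite_rules()
print(f"granite patch rules available: {len(rules) > 0}")
```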
diff --git a/scripts/benchmarks/scenarios.yaml b/scripts/benchmarks/scenarios.yaml
index 89502227..13c5356d 100644
--- a/scripts/benchmarks/scenarios.yaml
+++ b/scripts/benchmarks/scenarios.yaml
@@ -43,11 +43,11 @@ scenarios:
arguments:
learning_rate: 2e-5
model_name_or_path:
+ # - 'ibm/PowerLM-3b'
- 'bigcode/gpt_bigcode-santacoder'
- - 'ibm/PowerLM-3b'
- 'mistralai/Mistral-7B-v0.1'
- - 'mistralai/Mixtral-8x7B-Instruct-v0.1'
- - 'NousResearch/Llama-2-70b-hf'
+ # - 'mistralai/Mixtral-8x7B-Instruct-v0.1'
+ # - 'NousResearch/Llama-2-70b-hf'
torch_dtype: bfloat16
- name: standard-peft
@@ -64,14 +64,14 @@ scenarios:
target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
model_name_or_path:
- 'mistralai/Mistral-7B-v0.1'
- - 'mistralai/Mixtral-8x7B-Instruct-v0.1'
- - 'NousResearch/Llama-2-70b-hf'
+ # - 'mistralai/Mixtral-8x7B-Instruct-v0.1'
+ # - 'NousResearch/Llama-2-70b-hf'
- name: baseline-peft-bnb
framework_config:
- baseline-peft-bnb
arguments:
- fp16: True
+ bf16: True
learning_rate: 2e-4
torch_dtype: bfloat16
peft_method: lora
@@ -81,28 +81,28 @@ scenarios:
target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
model_name_or_path:
- 'mistralai/Mistral-7B-v0.1'
- - 'mistralai/Mixtral-8x7B-Instruct-v0.1'
- - 'NousResearch/Llama-2-70b-hf'
+ # - 'mistralai/Mixtral-8x7B-Instruct-v0.1'
+ # - 'NousResearch/Llama-2-70b-hf'
- name: accelerated-peft-bnb
framework_config:
- accelerated-peft-bnb
- accelerated-peft-bnb-foak
arguments:
- fp16: True
+ bf16: True
learning_rate: 2e-4
torch_dtype: bfloat16
peft_method: lora
r: 16
lora_alpha: 16
lora_dropout: 0.1
- target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
+ target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "c_attn"]
model_name_or_path:
- - 'ibm/PowerLM-3b'
+ # - 'ibm/PowerLM-3b'
- 'bigcode/gpt_bigcode-santacoder'
- 'mistralai/Mistral-7B-v0.1'
- - 'mistralai/Mixtral-8x7B-Instruct-v0.1'
- - 'NousResearch/Llama-2-70b-hf'
+ # - 'mistralai/Mixtral-8x7B-Instruct-v0.1'
+ # - 'NousResearch/Llama-2-70b-hf'
- name: accelerated-peft-gptq
framework_config:
@@ -110,7 +110,7 @@ scenarios:
- accelerated-peft-autogptq-foak
arguments:
learning_rate: 2e-4
- fp16: True
+ bf16: True
torch_dtype: bfloat16
peft_method: lora
r: 16
@@ -119,5 +119,5 @@ scenarios:
target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
model_name_or_path:
- 'TheBloke/Mistral-7B-v0.1-GPTQ'
- - 'TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ'
- - 'TheBloke/Llama-2-70B-GPTQ'
+ # - 'TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ'
+ # - 'TheBloke/Llama-2-70B-GPTQ'
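The scenario edits above switch mixed precision from `fp16: True` to `bf16: True`, presumably to match the `torch_dtype: bfloat16` used to load the models. A hedged sketch of how these two arguments typically surface in a Hugging Face training run (the trainer wiring here is illustrative, not the benchmark harness):

```python
from transformers import TrainingArguments

# bf16 mixed precision pairs naturally with weights loaded in torch.bfloat16;
# fp16 autocast over a bfloat16 model risks dtype mismatches and overflow.
args = TrainingArguments(
    output_dir="./out",
    bf16=True,            # was fp16=True in the previous scenarios
    learning_rate=2e-4,
)
```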
diff --git a/scripts/generate_sample_configurations.py b/scripts/generate_sample_configurations.py
index 27cd3df2..11619106 100644
--- a/scripts/generate_sample_configurations.py
+++ b/scripts/generate_sample_configurations.py
@@ -172,7 +172,7 @@ def read_configuration(path: str) -> Dict:
),
KEY_AADP_PADDING_FREE: "plugins/attention-and-distributed-packing/configs/padding_free.yaml",
KEY_AADP_MULTIPACK: "plugins/attention-and-distributed-packing/configs/multipack.yaml",
- KEY_FAST_KERNELS: "plugins/fused-ops-and-kernels/configs/fast_full.yaml",
+ KEY_FAST_KERNELS: "plugins/fused-ops-and-kernels/configs/fast_kernels.yaml",
}
# list of (tag, combi) tuples