From bbd08be1306192895c914ebcee38927117815aa8 Mon Sep 17 00:00:00 2001
From: 1000850000 user
Date: Mon, 26 Aug 2024 03:30:35 +0000
Subject: [PATCH] formatted foak

Signed-off-by: 1000850000 user
---
 .../framework_plugin_fast_quantized_peft.py   | 19 ++++++++++++++++---
 .../src/fms_acceleration_foak/models/llama.py | 13 +++++++------
 .../fms_acceleration_foak/models/mistral.py   | 14 +++++++-------
 .../fms_acceleration_foak/models/mixtral.py   | 12 ++++++------
 .../src/fms_acceleration_foak/models/utils.py |  2 +-
 .../tests/test_fused_ops.py                   |  4 +---
 6 files changed, 38 insertions(+), 26 deletions(-)

diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py
index d2abd5b1..ff67229c 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py
@@ -24,6 +24,7 @@
 import torch
 import torch.distributed as dist
 
+
 # consider moving this somewhere else later
 def lora_adapters_switch_ddp_from_fsdp(modules, fsdp_plugin):
     """
@@ -58,9 +59,20 @@ def _all_reduce_hook(grad):
         if not B.weight.is_cuda:
             set_module_tensor_to_device(B, "weight", "cuda")
 
+
 def register_foak_model_patch_rules(base_type):
-    from fms_acceleration.model_patcher import ModelPatcher # pylint: disable=import-outside-toplevel
-    from .models import llama, mistral, mixtral # pylint: disable=import-outside-toplevel
+    # Third Party
+    from fms_acceleration.model_patcher import (  # pylint: disable=import-outside-toplevel
+        ModelPatcher,
+    )
+
+    # Local
+    from .models import (  # pylint: disable=import-outside-toplevel
+        llama,
+        mistral,
+        mixtral,
+    )
+
     rules = [
         *llama.get_mp_rules(base_type),
         *mistral.get_mp_rules(base_type),
@@ -69,6 +81,7 @@ def register_foak_model_patch_rules(base_type):
     for _rule in rules:
         ModelPatcher.register(_rule)
 
+
 class FastQuantizedPeftAccelerationPlugin(AccelerationPlugin):
 
     # NOTE: may remove this when we have generic model rules
@@ -122,7 +135,7 @@ def augmentation(
         ), "need to run in fp16 mixed precision or load model in fp16"
 
         # wrapper function to register foak patches
-        register_foak_model_patch_rules(base_type = self._base_layer)
+        register_foak_model_patch_rules(base_type=self._base_layer)
         return model, modifiable_args
 
     def get_callbacks_and_ready_for_train(
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py
index a934fc1e..58bb456f 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py
@@ -16,17 +16,17 @@
 from functools import partial
 
 # Third Party
-from transformers.models.llama.modeling_llama import (
-    LlamaAttention,
-    LlamaMLP,
-    LlamaRMSNorm,
-)
 from fms_acceleration.model_patcher import (
     ModelPatcherRule,
     ModelPatcherTrigger,
     combine_functions,
     combine_triggers,
 )
+from transformers.models.llama.modeling_llama import (
+    LlamaAttention,
+    LlamaMLP,
+    LlamaRMSNorm,
+)
 
 # Local
 from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
@@ -34,6 +34,7 @@
 from ..kernels.unsloth.rope_embedding import fast_rope_embedding
 from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops
 
+
 def get_mp_rules(base_type: str):
     """
     Function to access all patch rules in this module.
@@ -125,5 +126,5 @@ def get_mp_rules(base_type: str):
                 fast_rope_embedding,
                 None,
             ),
-        )
+        ),
     ]
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py
index d090da5f..8e773a24 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py
@@ -16,18 +16,17 @@
 from functools import partial
 
 # Third Party
-from transformers.models.mistral.modeling_mistral import (
-    MistralAttention,
-    MistralMLP,
-    MistralRMSNorm,
-)
 from fms_acceleration.model_patcher import (
     ModelPatcherRule,
     ModelPatcherTrigger,
     combine_functions,
     combine_triggers,
 )
-
+from transformers.models.mistral.modeling_mistral import (
+    MistralAttention,
+    MistralMLP,
+    MistralRMSNorm,
+)
 
 # Local
 from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
@@ -35,6 +34,7 @@
 from ..kernels.unsloth.rope_embedding import fast_rope_embedding
 from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops
 
+
 def get_mp_rules(base_type):
     """
     Function to access all patch rules in this module.
@@ -117,5 +117,5 @@ def get_mp_rules(base_type):
                 fast_rope_embedding,
                 None,
             ),
-        )
+        ),
     ]
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py
index 7c0c58ab..67eada1c 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py
@@ -16,24 +16,24 @@
 from functools import partial
 
 # Third Party
-from transformers.models.mixtral.modeling_mixtral import (
-    MixtralAttention,
-    MixtralRMSNorm,
-)
 from fms_acceleration.model_patcher import (
     ModelPatcherRule,
     ModelPatcherTrigger,
     combine_functions,
     combine_triggers,
 )
+from transformers.models.mixtral.modeling_mixtral import (
+    MixtralAttention,
+    MixtralRMSNorm,
+)
 
 # Local
 from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
 from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
 from ..kernels.unsloth.rope_embedding import fast_rope_embedding
-
 from .utils import KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops
 
+
 def get_mp_rules(base_type):
     """
     Function to access all patch rules in this module.
@@ -100,5 +100,5 @@ def get_mp_rules(base_type):
                 fast_rope_embedding,
                 None,
             ),
-        )
+        ),
     ]
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py
index 9d624277..3653dc06 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py
@@ -4,6 +4,7 @@
 import os
 
 # Third Party
+from fms_acceleration.model_patcher import ModelPatcherTrigger
 import torch
 
 # Local
@@ -16,7 +17,6 @@
 from ..fused_ops.unsloth_lora.gptq.fast_lora import apply_lora_mlp as fused_op_mlp_gptq
 from ..fused_ops.unsloth_lora.gptq.fast_lora import apply_lora_o_v2 as fused_op_o_gptq
 from ..fused_ops.unsloth_lora.gptq.fast_lora import apply_lora_qkv as fused_op_qkv_gptq
-from fms_acceleration.model_patcher import ModelPatcherTrigger
 
 KEY_QKV = "qkv"
 KEY_O = "o"
diff --git a/plugins/fused-ops-and-kernels/tests/test_fused_ops.py b/plugins/fused-ops-and-kernels/tests/test_fused_ops.py
index 237a3a6f..356c00b3 100644
--- a/plugins/fused-ops-and-kernels/tests/test_fused_ops.py
+++ b/plugins/fused-ops-and-kernels/tests/test_fused_ops.py
@@ -3,6 +3,7 @@
 from itertools import product
 
 # Third Party
+from fms_acceleration.model_patcher import patch_model
 from peft import LoraConfig
 from transformers import AutoConfig
 from transformers.models.llama.modeling_llama import LlamaAttention
@@ -10,9 +11,6 @@
 import pytest  # pylint: disable=import-error
 import torch
 
-# First Party
-from fms_acceleration.model_patcher import patch_model
-
 BNB = "bitsandbytes"
 GPTQ = "auto_gptq"
 