From 160dfeff3f43017e61f5bc158c421c9ed771c359 Mon Sep 17 00:00:00 2001
From: 1000850000 user
Date: Fri, 13 Sep 2024 12:12:52 +0000
Subject: [PATCH] format and lint

Signed-off-by: 1000850000 user
---
 .../framework_plugin_autogptq.py              |  5 ++-
 .../framework_plugin_padding_free.py          |  2 +-
 .../src/fms_acceleration/model_patcher.py     |  2 +-
 .../src/fms_acceleration_foak/__init__.py     |  2 +-
 .../framework_plugin_fast_kernels.py          | 40 +++++++++++--------
 .../framework_plugin_fast_quantized_peft.py   |  1 +
 .../models/gpt_bigcode.py                     |  6 +--
 .../fms_acceleration_foak/models/granite.py   |  3 +-
 8 files changed, 35 insertions(+), 26 deletions(-)

diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
index e939244f..cde3465c 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
@@ -326,8 +326,9 @@ def augmentation(
             train_mode=True,  # install adapaters for training
         )
 
-        # We do not set `is_loaded_in_4bit`` at this point because otherwise
-        # `accelerate.prepare_model` will think the device placement is finalized for the quantized model, and will raise
+        # We do not set `is_loaded_in_4bit`` at this point because otherwise
+        # `accelerate.prepare_model` will think the device placement is finalized
+        # for the quantized model, and will raise
         # Reassign `quantization_method` after PEFT installation replaces the top-level class
         setattr(model, "quantization_method", "gptq")
 
diff --git a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py
index 111336c4..0e4e5ef9 100644
--- a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py
+++ b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py
@@ -179,7 +179,7 @@ def _is_backbone(module: torch.nn.Module):
         try:
             # if it is peft
             _module_path = model.get_base_model().__module__
-        except AttributeError: 
+        except AttributeError:
             _module_path = model.__module__
 
         ModelPatcher.register(
diff --git a/plugins/framework/src/fms_acceleration/model_patcher.py b/plugins/framework/src/fms_acceleration/model_patcher.py
index 4907858d..56c0771f 100644
--- a/plugins/framework/src/fms_acceleration/model_patcher.py
+++ b/plugins/framework/src/fms_acceleration/model_patcher.py
@@ -350,7 +350,7 @@ def _import_and_reload(model: torch.nn.Module):
             )
 
         for i_s, rule_s in enumerate(_with_reload[:-1]):
-            for rule_l in _with_reload[i_s+1:]:
+            for rule_l in _with_reload[i_s + 1 :]:
                 # if target paths in rule s is a prefix of rule l, raise an error
                 _name_s, _obj_s, _path_s = rule_s.import_and_maybe_reload
                 _, _, _path_l = rule_l.import_and_maybe_reload
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/__init__.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/__init__.py
index be1c3e9c..361bac23 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/__init__.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/__init__.py
@@ -13,5 +13,5 @@
 # limitations under the License.
 
 # Local
-from .framework_plugin_fast_quantized_peft import FastQuantizedPeftAccelerationPlugin
 from .framework_plugin_fast_kernels import FastKernelsAccelerationPlugin
+from .framework_plugin_fast_quantized_peft import FastQuantizedPeftAccelerationPlugin
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py
index 66f1ad21..cb39d4e6 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 # Standard
-from typing import Dict, Tuple, Set
+from typing import Dict, Set, Tuple
 
 # Third Party
 from fms_acceleration import AccelerationPlugin, AccelerationPluginConfigError
@@ -22,9 +22,11 @@
 from transformers import TrainingArguments
 import torch
 
+# Local
 from .framework_plugin_fast_quantized_peft import lora_adapters_switch_ddp_from_fsdp
 
-# consider rewriting register_foak_model_patch_rules into something 
+
+# consider rewriting register_foak_model_patch_rules into something
 # like this also
 def register_foak_model_patch_rules2(base_type: str, filter_endswith: Set[str] = None):
 
@@ -36,11 +38,12 @@ def register_foak_model_patch_rules2(base_type: str, filter_endswith: Set[str] =
     # Local
     from .models import (  # pylint: disable=import-outside-toplevel
         gpt_bigcode,
+        granite,
         llama,
         mistral,
         mixtral,
-        granite,
     )
+
     rules = [
         *gpt_bigcode.get_mp_rules(base_type),
         *granite.get_mp_rules(base_type),
@@ -52,13 +55,13 @@ def register_foak_model_patch_rules2(base_type: str, filter_endswith: Set[str] =
     if filter_endswith is not None:
         # filter rules
         rules = [
-            r for r in rules if
-            any(r.rule_id.endswith(x) for x in filter_endswith)
+            r for r in rules if any(r.rule_id.endswith(x) for x in filter_endswith)
         ]
 
     for _rule in rules:
         ModelPatcher.register(_rule)
 
+
 # maybe this we should define envvars
 FILTER_MAP = {
     "fused_lora": {"qkvo", "mlp"},
@@ -67,6 +70,7 @@ def register_foak_model_patch_rules2(base_type: str, filter_endswith: Set[str] =
     "fast_rope_embeddings": "rope",
 }
 
+
 class FastKernelsAccelerationPlugin(AccelerationPlugin):
 
     # NOTE: may remove this when we have generic model rules
@@ -93,21 +97,23 @@ def __init__(self, configurations: Dict[str, Dict]):
         )
 
         self.configurations["base_layer"] = self._check_config_and_maybe_check_values(
-            key="base_layer",
-            values=["auto_gptq", "bitsandbytes"],
-            default="auto_gptq"
+            key="base_layer", values=["auto_gptq", "bitsandbytes"], default="auto_gptq"
         )
         self.configurations["fused_lora"] = self._check_config_and_maybe_check_values(
-            key="fused_lora", values=[False,True], default=True
+            key="fused_lora", values=[False, True], default=True
         )
         self.configurations["fast_loss"] = self._check_config_and_maybe_check_values(
-            key="fast_loss", values=[False,True], default=True
+            key="fast_loss", values=[False, True], default=True
         )
-        self.configurations["fast_rms_layernorm"] = self._check_config_and_maybe_check_values(
-            key="fast_rms_layernorm", values=[False,True], default=True
+        self.configurations["fast_rms_layernorm"] = (
+            self._check_config_and_maybe_check_values(
+                key="fast_rms_layernorm", values=[False, True], default=True
+            )
         )
-        self.configurations["fast_rope_embeddings"] = self._check_config_and_maybe_check_values(
-            key="fast_rope_embeddings", values=[False,True], default=True
+        self.configurations["fast_rope_embeddings"] = (
+            self._check_config_and_maybe_check_values(
+                key="fast_rope_embeddings", values=[False, True], default=True
+            )
         )
 
     @property
@@ -135,7 +141,7 @@ def augmentation(
         if getattr(model, "quantization_method", None) is None:
             # - fused_lora only required for quant-peft
             omitted.add("fused_lora")
-        
+
         terms = set()
         for k, v in self.configurations.items():
             if k in FILTER_MAP and k not in omitted:
@@ -149,8 +155,7 @@ def augmentation(
        # wrapper function to register foak patches
        # - the base layer setting below will be ignored in non quantized-lora settings
         register_foak_model_patch_rules2(
-            base_type=self.configurations['base_layer'],
-            filter_endswith=terms
+            base_type=self.configurations["base_layer"], filter_endswith=terms
         )
 
         return model, modifiable_args
@@ -179,6 +184,7 @@ def get_callbacks_and_ready_for_train(
             )
         return callbacks
 
+
 # register
 AccelerationPlugin.register_plugin(
     FastKernelsAccelerationPlugin,
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py
index 35817071..6ec4cd99 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py
@@ -160,6 +160,7 @@ def get_callbacks_and_ready_for_train(
             )
         return callbacks
 
+
 # This plugin is currently deregistered in favour of framework_plugin_fast_kernels.py
 # to additionally support both full-FT and standard PEFT
 # AccelerationPlugin.register_plugin(
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/gpt_bigcode.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/gpt_bigcode.py
index 9eb2cf64..1f09d913 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/gpt_bigcode.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/gpt_bigcode.py
@@ -12,13 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from fms_acceleration.model_patcher import (
-    ModelPatcherRule,
-)
+# Third Party
+from fms_acceleration.model_patcher import ModelPatcherRule
 
 # Local
 from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
 
+
 def get_mp_rules(base_type: str):
     """
     Function to access all patch rules in this module.
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granite.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granite.py
index e5307791..a2be13ab 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granite.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granite.py
@@ -38,7 +38,8 @@ def get_mp_rules(base_type: str):
         function as a partial function with the base_type argument
     """
     try:
-        from transformers.models.granite.modeling_granite import ( # pylint: disable=import-outside-toplevel
+        # Third Party
+        from transformers.models.granite.modeling_granite import (  # pylint: disable=import-outside-toplevel
             GraniteAttention,
             GraniteMLP,
             GraniteRMSNorm,