format and lint
Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>
achew010 committed Sep 16, 2024
1 parent 8080f86 commit 87663c1
Showing 8 changed files with 35 additions and 26 deletions.
@@ -326,8 +326,9 @@ def augmentation(
train_mode=True, # install adapters for training
)

# We do not set `is_loaded_in_4bit` at this point because otherwise
# `accelerate.prepare_model` will think the device placement is finalized for the quantized model, and will raise
# We do not set `is_loaded_in_4bit` at this point because otherwise
# `accelerate.prepare_model` will think the device placement is finalized
# for the quantized model, and will raise
# Reassign `quantization_method` after PEFT installation replaces the top-level class
setattr(model, "quantization_method", "gptq")

@@ -179,7 +179,7 @@ def _is_backbone(module: torch.nn.Module):
try:
# if it is peft
_module_path = model.get_base_model().__module__
except AttributeError:
except AttributeError:
_module_path = model.__module__

ModelPatcher.register(
2 changes: 1 addition & 1 deletion plugins/framework/src/fms_acceleration/model_patcher.py
@@ -350,7 +350,7 @@ def _import_and_reload(model: torch.nn.Module):
)

for i_s, rule_s in enumerate(_with_reload[:-1]):
for rule_l in _with_reload[i_s+1:]:
for rule_l in _with_reload[i_s + 1 :]:
# if target paths in rule s is a prefix of rule l, raise an error
_name_s, _obj_s, _path_s = rule_s.import_and_maybe_reload
_, _, _path_l = rule_l.import_and_maybe_reload
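
The loop above guards against two reload rules whose target module paths nest. A standalone sketch of that check, with illustrative (name, obj, path) tuples in place of the real ModelPatcherRule objects; the exact exception raised is not shown in this hunk:

# Illustrative stand-ins mirroring the unpacking of rule.import_and_maybe_reload.
rules = [
    ("rule_a", None, "transformers.models.llama"),
    ("rule_b", None, "transformers.models.llama.modeling_llama"),
]

try:
    for i_s, rule_s in enumerate(rules[:-1]):
        for rule_l in rules[i_s + 1 :]:
            _name_s, _obj_s, _path_s = rule_s
            _, _, _path_l = rule_l
            # if rule_s targets a prefix of rule_l's path, reloading the shorter
            # path would clobber the patch applied at the longer one
            if _path_l.startswith(_path_s):
                raise ValueError(f"'{_path_s}' is a prefix of '{_path_l}'")
except ValueError as e:
    print("conflicting reload targets:", e)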
@@ -13,5 +13,5 @@
# limitations under the License.

# Local
from .framework_plugin_fast_quantized_peft import FastQuantizedPeftAccelerationPlugin
from .framework_plugin_fast_kernels import FastKernelsAccelerationPlugin
from .framework_plugin_fast_quantized_peft import FastQuantizedPeftAccelerationPlugin
@@ -13,7 +13,7 @@
# limitations under the License.

# Standard
from typing import Dict, Tuple, Set
from typing import Dict, Set, Tuple

# Third Party
from fms_acceleration import AccelerationPlugin, AccelerationPluginConfigError
@@ -22,9 +22,11 @@
from transformers import TrainingArguments
import torch

# Local
from .framework_plugin_fast_quantized_peft import lora_adapters_switch_ddp_from_fsdp

# consider rewriting register_foak_model_patch_rules into something

# consider rewriting register_foak_model_patch_rules into something
# like this also
def register_foak_model_patch_rules2(base_type: str, filter_endswith: Set[str] = None):

@@ -36,11 +38,12 @@ def register_foak_model_patch_rules2(base_type: str, filter_endswith: Set[str] =
# Local
from .models import ( # pylint: disable=import-outside-toplevel
gpt_bigcode,
granite,
llama,
mistral,
mixtral,
granite,
)

rules = [
*gpt_bigcode.get_mp_rules(base_type),
*granite.get_mp_rules(base_type),
@@ -52,13 +55,13 @@ def register_foak_model_patch_rules2(base_type: str, filter_endswith: Set[str] =
if filter_endswith is not None:
# filter rules
rules = [
r for r in rules if
any(r.rule_id.endswith(x) for x in filter_endswith)
r for r in rules if any(r.rule_id.endswith(x) for x in filter_endswith)
]

for _rule in rules:
ModelPatcher.register(_rule)
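
A usage sketch of the suffix filter above: only rules whose rule_id ends with one of the requested suffixes are kept. The DummyRule class and the rule ids are hypothetical stand-ins; the suffixes come from the FILTER_MAP entries visible below:

from dataclasses import dataclass


@dataclass
class DummyRule:
    rule_id: str


rules = [DummyRule("llama-qkvo"), DummyRule("llama-mlp"), DummyRule("llama-rope")]

filter_endswith = {"qkvo", "mlp"}  # e.g. what FILTER_MAP maps "fused_lora" to
rules = [r for r in rules if any(r.rule_id.endswith(x) for x in filter_endswith)]
print([r.rule_id for r in rules])  # ['llama-qkvo', 'llama-mlp']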


# maybe we should define these as envvars
FILTER_MAP = {
"fused_lora": {"qkvo", "mlp"},
@@ -67,6 +70,7 @@ def register_foak_model_patch_rules2(base_type: str, filter_endswith: Set[str] =
"fast_rope_embeddings": "rope",
}


class FastKernelsAccelerationPlugin(AccelerationPlugin):

# NOTE: may remove this when we have generic model rules
@@ -93,21 +97,23 @@ def __init__(self, configurations: Dict[str, Dict]):
)

self.configurations["base_layer"] = self._check_config_and_maybe_check_values(
key="base_layer",
values=["auto_gptq", "bitsandbytes"],
default="auto_gptq"
key="base_layer", values=["auto_gptq", "bitsandbytes"], default="auto_gptq"
)
self.configurations["fused_lora"] = self._check_config_and_maybe_check_values(
key="fused_lora", values=[False,True], default=True
key="fused_lora", values=[False, True], default=True
)
self.configurations["fast_loss"] = self._check_config_and_maybe_check_values(
key="fast_loss", values=[False,True], default=True
key="fast_loss", values=[False, True], default=True
)
self.configurations["fast_rms_layernorm"] = self._check_config_and_maybe_check_values(
key="fast_rms_layernorm", values=[False,True], default=True
self.configurations["fast_rms_layernorm"] = (
self._check_config_and_maybe_check_values(
key="fast_rms_layernorm", values=[False, True], default=True
)
)
self.configurations["fast_rope_embeddings"] = self._check_config_and_maybe_check_values(
key="fast_rope_embeddings", values=[False,True], default=True
self.configurations["fast_rope_embeddings"] = (
self._check_config_and_maybe_check_values(
key="fast_rope_embeddings", values=[False, True], default=True
)
)

@property
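
For reference, a configurations dict consistent with the checks above, using the defaults they declare; the exact layout of the plugin's config file is not part of this diff, so treat this as a sketch:

# Sketch of a configurations dict matching the keys and defaults checked above;
# the real values are read from the acceleration framework's config file.
configurations = {
    "base_layer": "auto_gptq",   # or "bitsandbytes"
    "fused_lora": True,
    "fast_loss": True,
    "fast_rms_layernorm": True,
    "fast_rope_embeddings": True,
}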
@@ -135,7 +141,7 @@ def augmentation(
if getattr(model, "quantization_method", None) is None:
# - fused_lora only required for quant-peft
omitted.add("fused_lora")

terms = set()
for k, v in self.configurations.items():
if k in FILTER_MAP and k not in omitted:
Expand All @@ -149,8 +155,7 @@ def augmentation(
# wrapper function to register foak patches
# - the base layer setting below will be ignored in non quantized-lora settings
register_foak_model_patch_rules2(
base_type=self.configurations['base_layer'],
filter_endswith=terms
base_type=self.configurations["base_layer"], filter_endswith=terms
)
return model, modifiable_args
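
The body of the loop that builds terms is cut off by this hunk; an assumed reconstruction of its likely shape, reproducing only the two FILTER_MAP entries visible earlier in this file:

# Assumed reconstruction of the truncated loop: each enabled flag contributes
# its FILTER_MAP entry (a set of suffixes, or a single suffix string) to the
# suffixes later passed to register_foak_model_patch_rules2.
FILTER_MAP = {
    "fused_lora": {"qkvo", "mlp"},
    "fast_rope_embeddings": "rope",
}

configurations = {"fused_lora": True, "fast_rope_embeddings": True}
omitted = {"fused_lora"}  # e.g. when the model is not quantized (see above)

terms = set()
for k, v in configurations.items():
    if k in FILTER_MAP and k not in omitted:
        entry = FILTER_MAP[k]
        if isinstance(entry, str):
            entry = {entry}
        if v:
            terms.update(entry)

print(terms)  # {'rope'}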

@@ -179,6 +184,7 @@ def get_callbacks_and_ready_for_train(
)
return callbacks


# register
AccelerationPlugin.register_plugin(
FastKernelsAccelerationPlugin,
@@ -160,6 +160,7 @@ def get_callbacks_and_ready_for_train(
)
return callbacks


# This plugin is currently deregistered in favour of framework_plugin_fast_kernels.py
# to additionally support both full-FT and standard PEFT
# AccelerationPlugin.register_plugin(
@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from fms_acceleration.model_patcher import (
ModelPatcherRule,
)
# Third Party
from fms_acceleration.model_patcher import ModelPatcherRule

# Local
from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss


def get_mp_rules(base_type: str):
"""
Function to access all patch rules in this module.
@@ -38,7 +38,8 @@ def get_mp_rules(base_type: str):
function as a partial function with the base_type argument
"""
try:
from transformers.models.granite.modeling_granite import ( # pylint: disable=import-outside-toplevel
# Third Party
from transformers.models.granite.modeling_granite import ( # pylint: disable=import-outside-toplevel
GraniteAttention,
GraniteMLP,
GraniteRMSNorm,
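
The except branch of this guard falls outside the hunk; a hedged sketch of the usual pattern, where a transformers build without granite support simply contributes no patch rules (the real fallback is not shown here):

def granite_available() -> bool:
    # Sketch only: granite classes exist only in newer transformers releases,
    # so the import is attempted lazily and failure just means no granite
    # patch rules get registered.
    try:
        # Third Party
        from transformers.models.granite.modeling_granite import (  # noqa: F401
            GraniteAttention,
            GraniteMLP,
            GraniteRMSNorm,
        )
    except ImportError:
        return False
    return True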
