formatted foak
Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>
achew010 committed Aug 26, 2024
1 parent d5546d7 commit bbd08be
Showing 6 changed files with 38 additions and 26 deletions.
@@ -24,6 +24,7 @@
import torch
import torch.distributed as dist


# consider moving this somewhere else later
def lora_adapters_switch_ddp_from_fsdp(modules, fsdp_plugin):
"""
@@ -58,9 +59,20 @@ def _all_reduce_hook(grad):
if not B.weight.is_cuda:
set_module_tensor_to_device(B, "weight", "cuda")


def register_foak_model_patch_rules(base_type):
from fms_acceleration.model_patcher import ModelPatcher # pylint: disable=import-outside-toplevel
from .models import llama, mistral, mixtral # pylint: disable=import-outside-toplevel
# Third Party
from fms_acceleration.model_patcher import ( # pylint: disable=import-outside-toplevel
ModelPatcher,
)

# Local
from .models import ( # pylint: disable=import-outside-toplevel
llama,
mistral,
mixtral,
)

rules = [
*llama.get_mp_rules(base_type),
*mistral.get_mp_rules(base_type),
@@ -69,6 +81,7 @@ def register_foak_model_patch_rules(base_type):
for _rule in rules:
ModelPatcher.register(_rule)


class FastQuantizedPeftAccelerationPlugin(AccelerationPlugin):

# NOTE: may remove this when we have generic model rules
@@ -122,7 +135,7 @@ def augmentation(
), "need to run in fp16 mixed precision or load model in fp16"

# wrapper function to register foak patches
register_foak_model_patch_rules(base_type = self._base_layer)
register_foak_model_patch_rules(base_type=self._base_layer)
return model, modifiable_args

def get_callbacks_and_ready_for_train(
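
For context, a minimal sketch of how the registration changed above is exercised; this is assumed usage, not code from this commit. register_foak_model_patch_rules (modified in this file) pushes the llama/mistral/mixtral rules into ModelPatcher, and the plugin's augmentation step calls it with the configured base layer ("auto_gptq" and "bitsandbytes" appear in the test module below); how patch_model applies the registered rules is an assumption.

# Hedged sketch, not part of this commit. Assumes register_foak_model_patch_rules
# is importable from this plugin module.
from fms_acceleration.model_patcher import patch_model


def apply_foak_patches(model, base_type: str = "auto_gptq"):
    # register the per-model fused-op rules with ModelPatcher
    register_foak_model_patch_rules(base_type=base_type)
    # assumption: patch_model walks the model and applies every registered rule
    return patch_model(model)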
@@ -16,24 +16,25 @@
from functools import partial

# Third Party
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaMLP,
LlamaRMSNorm,
)
from fms_acceleration.model_patcher import (
ModelPatcherRule,
ModelPatcherTrigger,
combine_functions,
combine_triggers,
)
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaMLP,
LlamaRMSNorm,
)

# Local
from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
from ..kernels.unsloth.rope_embedding import fast_rope_embedding
from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops


def get_mp_rules(base_type: str):
"""
Function to access all patch rules in this module.
@@ -125,5 +126,5 @@ def get_mp_rules(base_type: str):
fast_rope_embedding,
None,
),
)
),
]
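
The reordered imports above feed the rule definitions in get_mp_rules, of which this hunk only shows the tail. As an illustration only, a fused QKV/O rule for LlamaAttention might be assembled along the following lines, using the names imported above; the keyword arguments of ModelPatcherRule, combine_triggers, trigger_fused_ops and build_lora_fused_ops are assumptions, not taken from this diff.

# Hypothetical sketch of one patch rule; names come from the imports above,
# keyword arguments are assumed.
qkvo_rule = ModelPatcherRule(
    rule_id="llama-fused-qkvo",  # hypothetical identifier
    trigger=combine_triggers(
        ModelPatcherTrigger(
            check=partial(trigger_fused_ops, attn_cls=LlamaAttention)  # assumed signature
        ),
    ),
    forward_builder=partial(
        build_lora_fused_ops,
        submodule_names=[KEY_QKV, KEY_O],  # assumed parameter name
    ),
)
# such rules are returned in the list built by get_mp_rules and registered via
# register_foak_model_patch_rules in the first file of this commit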
@@ -16,25 +16,25 @@
from functools import partial

# Third Party
from transformers.models.mistral.modeling_mistral import (
MistralAttention,
MistralMLP,
MistralRMSNorm,
)
from fms_acceleration.model_patcher import (
ModelPatcherRule,
ModelPatcherTrigger,
combine_functions,
combine_triggers,
)

from transformers.models.mistral.modeling_mistral import (
MistralAttention,
MistralMLP,
MistralRMSNorm,
)

# Local
from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
from ..kernels.unsloth.rope_embedding import fast_rope_embedding
from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops


def get_mp_rules(base_type):
"""
Function to access all patch rules in this module.
@@ -117,5 +117,5 @@ def get_mp_rules(base_type):
fast_rope_embedding,
None,
),
)
),
]
@@ -16,24 +16,24 @@
from functools import partial

# Third Party
from transformers.models.mixtral.modeling_mixtral import (
MixtralAttention,
MixtralRMSNorm,
)
from fms_acceleration.model_patcher import (
ModelPatcherRule,
ModelPatcherTrigger,
combine_functions,
combine_triggers,
)
from transformers.models.mixtral.modeling_mixtral import (
MixtralAttention,
MixtralRMSNorm,
)

# Local
from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
from ..kernels.unsloth.rope_embedding import fast_rope_embedding

from .utils import KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops


def get_mp_rules(base_type):
"""
Function to access all patch rules in this module.
@@ -100,5 +100,5 @@ def get_mp_rules(base_type):
fast_rope_embedding,
None,
),
)
),
]
@@ -4,6 +4,7 @@
import os

# Third Party
from fms_acceleration.model_patcher import ModelPatcherTrigger
import torch

# Local
@@ -16,7 +17,6 @@
from ..fused_ops.unsloth_lora.gptq.fast_lora import apply_lora_mlp as fused_op_mlp_gptq
from ..fused_ops.unsloth_lora.gptq.fast_lora import apply_lora_o_v2 as fused_op_o_gptq
from ..fused_ops.unsloth_lora.gptq.fast_lora import apply_lora_qkv as fused_op_qkv_gptq
from fms_acceleration.model_patcher import ModelPatcherTrigger

KEY_QKV = "qkv"
KEY_O = "o"
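
A note on the aliased imports above: each fused op is imported under a name that encodes the submodule it replaces (qkv, o, mlp) and the quantization backend (gptq), which suggests build_lora_fused_ops selects an op by those two keys. The mapping below is purely illustrative; its name and shape are assumptions, not code from this file.

# Illustrative only; the real dispatch in this module may differ.
FUSED_OPS_GPTQ = {
    KEY_QKV: fused_op_qkv_gptq,
    KEY_O: fused_op_o_gptq,
    # KEY_MLP is not visible in this hunk but is imported from .utils by the
    # llama/mistral modules above, so it is assumed to be defined further down:
    # KEY_MLP: fused_op_mlp_gptq,
}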
4 changes: 1 addition & 3 deletions plugins/fused-ops-and-kernels/tests/test_fused_ops.py
@@ -3,16 +3,14 @@
from itertools import product

# Third Party
from fms_acceleration.model_patcher import patch_model
from peft import LoraConfig
from transformers import AutoConfig
from transformers.models.llama.modeling_llama import LlamaAttention
from transformers.utils.import_utils import _is_package_available
import pytest # pylint: disable=import-error
import torch

# First Party
from fms_acceleration.model_patcher import patch_model

BNB = "bitsandbytes"
GPTQ = "auto_gptq"

Expand Down
