Commit: Allow Kernels for Full FT and Non-Quantized PEFT (#79)
* add or logic for plugin registration

Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>

* add fast kernels plugin

Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>

* prepare full-foak benchmarks

Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>

* update benchmark logic to have empty framework_config

Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>

* minor fixes to foak full

Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>

* addressed code review changes

Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>

* additional fixes from code review

Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>

* minor fixes to standard peft

Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>

* Apply suggestions from code review

Co-authored-by: Yu Chin Fabian Lim <fabianlim@users.noreply.github.com>
Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>

* changes to filtering function and modifications to allow flexibility of activating kernels

Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>

* additional check in fastkernels and changes to FOAK README.md

Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>

* fix syntax error

Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>

* fix reloads on multiple patches

Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>

* dtype changes to scenarios.yaml and README.md

Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>

* changes to scenarios.yaml

Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>

* additional comments

Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>

* format and lint

Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>

* fixes and updates to benchmark

Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>

* fixes on reload

Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>

---------

Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>
Co-authored-by: 1000850000 user <aaron.chew1@ibm.com>
Co-authored-by: achew010 <165894159+achew010@users.noreply.github.com>
3 people authored Sep 16, 2024
1 parent a0ac97a commit 4e81c64
Showing 25 changed files with 890 additions and 179 deletions.
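
The headline change is new OR logic for plugin registration (framework_plugin.py below): a plugin can now also be registered against OR'd configuration paths, so it activates when any one of several configuration sections is present, which is what allows kernels to be enabled for full fine-tuning and non-quantized PEFT per the commit title. A rough sketch of such a registration follows; the plugin class and the dotted configuration keys are illustrative placeholders, not names taken from this diff.

# Sketch only: MyKernelsPlugin and the configuration keys below are illustrative.
from fms_acceleration.framework_plugin import AccelerationPlugin

class MyKernelsPlugin(AccelerationPlugin):
    # a real plugin would implement its model-loading / augmentation hooks
    pass

# OR'd paths (new in this commit): activate if *any* listed section is present.
# AND'd paths (pre-existing): activate only if *all* listed sections are present.
AccelerationPlugin.register_plugin(
    MyKernelsPlugin,
    configuration_or_paths=[
        "training.fused_ops_and_kernels",           # illustrative key
        "peft.quantization.fused_ops_and_kernels",  # illustrative key
    ],
)
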
@@ -231,7 +231,6 @@ def model_loader(self, model_name: str, **kwargs):
# and there is a section of code that will be skipped if not set.
setattr(model, "is_loaded_in_4bit", True)
setattr(model, "quantization_method", "gptq")

return model

@property
@@ -275,6 +274,8 @@ def augmentation(

# some assertions
assert peft_config is not None, "need peft_config to install PEFT adapters"
# running this plugin in float16 is the most performant
# https://github.com/foundation-model-stack/fms-acceleration/issues/84
assert (
model.dtype == torch.float16 or train_args.fp16
), "need to run in fp16 mixed precision or load model in fp16"
@@ -324,6 +325,13 @@ def augmentation(
auto_find_all_linears=requires_installation_on_all_linears(peft_config),
train_mode=True, # install adapters for training
)

# We do not set `is_loaded_in_4bit` at this point because otherwise
# `accelerate.prepare_model` will think the device placement is finalized
# for the quantized model, and will raise an error
# Reassign `quantization_method` after PEFT installation replaces the top-level class
setattr(model, "quantization_method", "gptq")

modifiable_args = (None,) # return a None for peft_config

if self.use_external_lib:
@@ -175,13 +175,20 @@ def _is_backbone(module: torch.nn.Module):
# Local
from .flash_attn import _flash_attention_forward_with_posids

# - we need to reload on the correct module
try:
# if it is peft
_module_path = model.get_base_model().__module__
except AttributeError:
_module_path = model.__module__

ModelPatcher.register(
ModelPatcherRule(
rule_id="flash_attn_forward",
import_and_maybe_reload=(
"transformers.modeling_flash_attention_utils._flash_attention_forward",
partial(_flash_attention_forward_with_posids, id(model)),
model.__module__,
_module_path,
),
),
)
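
The try/except above is needed because a PEFT-wrapped model reports the PEFT wrapper as its __module__, while the reload has to target the underlying transformers modeling module. A small illustration (the model and LoRA settings are arbitrary examples):

# Illustration only: the model and LoRA config below are arbitrary examples.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("gpt2")
peft_model = get_peft_model(base, LoraConfig(target_modules=["c_attn"]))

print(base.__module__)                         # transformers.models.gpt2.modeling_gpt2
print(peft_model.__module__)                   # peft.peft_model, not the module to reload
print(peft_model.get_base_model().__module__)  # back to the transformers modeling module
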
1 change: 0 additions & 1 deletion plugins/framework/pyproject.toml
@@ -24,7 +24,6 @@ classifiers=[
dependencies = [
"numpy<2.0", # numpy needs to be bounded due to incompatiblity with current torch<2.3
"torch>2.2",
"transformers",
"peft",
"accelerate",
"pandas",
69 changes: 55 additions & 14 deletions plugins/framework/src/fms_acceleration/framework_plugin.py
@@ -29,7 +29,7 @@
class PluginRegistration:
plugin: "AccelerationPlugin"
AND: List[str] = None
# OR: List[str] = None # not implemented yet
OR: List[str] = None

# package metadata
package_name: str = None
@@ -53,28 +53,61 @@ def _trace_key_path(configuration: Dict, key: str):
def get_relevant_configuration_sections(configuration: Dict) -> Dict:
results = []

# this function updates cfg with content
# - equivalent to taking a union
def _update_config_contents(_cfg: Dict, content: Dict, key: str):
path = key.split(".")
n = len(path)
_cfg = relevant_config
while n > 1:
p = path.pop(0)
if p not in _cfg:
_cfg[p] = {}
_cfg = _cfg[p]
n -= 1

_cfg[path[0]] = content

# assume the registrations are all done with at least some default key
for registration in PLUGIN_REGISTRATIONS:
relevant_config = {}
# OR is not implemented yet

_and_keys = registration.AND
_or_keys = registration.OR
if _and_keys is None:
_and_keys = []
if _or_keys is None:
_or_keys = []

# go through AND paths then OR paths
# - if all AND paths are specified, then return the union of all their content
# - if any OR path is specified, then return the union of specified content
reject = False
for key in registration.AND:
for key in _and_keys:
content = _trace_key_path(configuration, key)
if content is None:
# if AND key, then if at least one of them not
# specified, then reject and do not descend config tree
reject = True
break

path = key.split(".")
n = len(path)
_cfg = relevant_config
while n > 1:
p = path.pop(0)
if p not in _cfg:
_cfg[p] = {}
_cfg = _cfg[p]
n -= 1
# update
_update_config_contents(relevant_config, content, key)

# if not all of the AND keys were satisfied, then reset the config
if reject:
relevant_config = {}

for key in _or_keys:
content = _trace_key_path(configuration, key)
if content is not None:
if reject:
# it is an OR key, and if at least one of them specified
# then do not reject
reject = False

_cfg[path[0]] = content
# update all content that is not None
_update_config_contents(relevant_config, content, key)

if reject:
continue
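
The nested update performed by _update_config_contents above is a dotted-key insertion into a nested dict; a simplified standalone re-implementation (for illustration with made-up keys, not the library code itself) behaves like this:

# Simplified re-implementation of the dotted-key union, for illustration only.
from typing import Dict

def update_config_contents(cfg: Dict, content: Dict, key: str) -> None:
    # descend the dotted path, creating empty dicts as needed,
    # then place `content` at the final path segment
    path = key.split(".")
    for p in path[:-1]:
        cfg = cfg.setdefault(p, {})
    cfg[path[-1]] = content

relevant_config = {}
update_config_contents(relevant_config, {"base_layer": True}, "peft.quantization.fused_ops_and_kernels")
update_config_contents(relevant_config, {"fast_loss": True}, "training.fused_ops_and_kernels")
# relevant_config is now the union of both subtrees:
# {"peft": {"quantization": {"fused_ops_and_kernels": {"base_layer": True}}},
#  "training": {"fused_ops_and_kernels": {"fast_loss": True}}}
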
@@ -91,7 +124,8 @@ class AccelerationPlugin:
@staticmethod
def register_plugin(
plugin: "AccelerationPlugin",
configuration_and_paths: List[str],
configuration_and_paths: List[str] = None,
configuration_or_paths: List[str] = None,
**kwargs,
):

@@ -101,6 +135,12 @@ def register_plugin(
# is done (global-variable-not-assigned)
# global PLUGIN_REGISTRATIONS

assert (
configuration_and_paths is not None and len(configuration_and_paths) > 0
) or (
configuration_or_paths is not None and len(configuration_or_paths) > 0
), "Specify at least one AND or OR path"

# get the package metadata
pkg_name = sys.modules[plugin.__module__].__package__
try:
Expand All @@ -112,6 +152,7 @@ def register_plugin(
PluginRegistration(
plugin=plugin,
AND=configuration_and_paths,
OR=configuration_or_paths,
package_name=pkg_name,
package_version=package_version,
)
18 changes: 14 additions & 4 deletions plugins/framework/src/fms_acceleration/model_patcher.py
@@ -348,14 +348,24 @@ def _import_and_reload(model: torch.nn.Module):
key=lambda _rule: len(_rule.import_and_maybe_reload[2]),
reverse=False,
)
for rule_s in _with_reload:
for rule_l in _with_reload[1:]:

for i_s, rule_s in enumerate(_with_reload[:-1]):
for rule_l in _with_reload[i_s + 1 :]:
# if target paths in rule s is a prefix of rule l, raise an error
_, _, _path_s = rule_s.import_and_maybe_reload
_name_s, _obj_s, _path_s = rule_s.import_and_maybe_reload
_, _, _path_l = rule_l.import_and_maybe_reload

if _path_s == _path_l:
# - in the event the target is exactly the same, we will
# only reload once
rule_s.import_and_maybe_reload = (_name_s, _obj_s, None)
continue

# - otherwise, we do not consider the cases where the target
# is a subpath since this results in unpredictability.
assert not _path_l.startswith(
_path_s
), f"Attempting to reload same path `{_path_s}` multiple times in \
), f"Attempting to reload a subpath`{_path_s}` multiple times in \
{rule_s.rule_id} and {rule_l.rule_id}"

# handle those with reload first
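
To make the pairwise check above concrete: two rules that target exactly the same path now collapse into a single reload, while a rule whose target lies underneath another rule's target (a strict prefix relationship) trips the assertion. A tiny illustration with made-up module paths:

# Made-up module paths, illustrating the prefix/subpath check only.
path_s = "transformers.models.llama.modeling_llama"
path_l = "transformers.models.llama.modeling_llama.LlamaAttention"

print(path_l == path_s)           # False, so not collapsed as an identical target
print(path_l.startswith(path_s))  # True, a subpath reload, so the assertion would fire
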
22 changes: 18 additions & 4 deletions plugins/framework/src/fms_acceleration/utils/test_utils.py
@@ -18,7 +18,7 @@
# Standard
from contextlib import contextmanager
from tempfile import NamedTemporaryFile
from typing import Any, Callable, Dict, List, Set, Tuple, Type
from typing import Any, Callable, Dict, List, Set, Tuple, Type, Union

# Third Party
import torch
@@ -67,7 +67,14 @@ def configure_framework_from_json(
@contextmanager
def build_framework_and_maybe_instantiate(
plugins_to_be_registered: List[
Tuple[List[str], Type[AccelerationPlugin]] # and_paths, plugin_class
Union[
Tuple[List[str], Type[AccelerationPlugin]], # and_paths, plugin_class
Tuple[
List[str],
List[str], # and_or_paths
Type[AccelerationPlugin], # plugin_class
],
]
],
configuration_contents: Dict = None,
instantiate: bool = True,
@@ -89,10 +96,17 @@ def build_framework_and_maybe_instantiate(
AccelerationFramework.active_plugins = []
AccelerationFramework.plugins_require_custom_loading = []

for path, plugin in plugins_to_be_registered:
for paths_and_plugins in plugins_to_be_registered:
try:
and_paths, plugin = paths_and_plugins
or_paths = None
except ValueError:
and_paths, or_paths, plugin = paths_and_plugins

AccelerationPlugin.register_plugin(
plugin,
configuration_and_paths=path,
configuration_and_paths=and_paths,
configuration_or_paths=or_paths,
)

if instantiate:
73 changes: 73 additions & 0 deletions plugins/framework/tests/test_framework.py
@@ -313,3 +313,76 @@ def _hook(
framework.augmentation(model, None, None)
for c, (n, _) in zip(plugin_activation_order, plugins_to_be_installed):
assert n in c


def test_plugin_registration_combination_logic():

plugin = create_plugin_cls(
restricted_models={"CausalLM"},
requires_agumentation=True,
agumentation=dummy_augmentation,
)

configuration_contents = {"existing1": {"key1": 1}, "existing2": {"key1": 1}}

# empty conditions
with pytest.raises(AssertionError, match="Specify at least one AND or OR path"):
with build_framework_and_instantiate(
plugins_to_be_registered=[
([], [], plugin),
],
configuration_contents=configuration_contents,
) as framework:
pass

# AND logic - happy
with build_framework_and_instantiate(
plugins_to_be_registered=[
(["existing1", "existing2"], plugin),
],
configuration_contents=configuration_contents,
) as framework:
# check 1.
assert len(PLUGIN_REGISTRATIONS) == 1

# check 2.
assert len(framework.active_plugins) == 1

# AND - sad path
with pytest.raises(
ValueError,
match="No plugins could be configured. Please check the acceleration",
):
with build_framework_and_instantiate(
plugins_to_be_registered=[
(["existing1", "non-existant"], plugin),
],
configuration_contents=configuration_contents,
) as framework:
pass

# OR logic
with build_framework_and_instantiate(
plugins_to_be_registered=[
([], ["existing1", "non-existant"], plugin),
],
configuration_contents=configuration_contents,
) as framework:
# check 1.
assert len(PLUGIN_REGISTRATIONS) == 1

# check 2.
assert len(framework.active_plugins) == 1

# OR - sad path
with pytest.raises(
ValueError,
match="No plugins could be configured. Please check the acceleration",
):
with build_framework_and_instantiate(
plugins_to_be_registered=[
(["non-existant", "non-existant2"], plugin),
],
configuration_contents=configuration_contents,
) as framework:
pass