Add Acceleration Patcher and MultiPack Plugin (#67)
* drafting accelerator patcher

Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>

* add comments and cleanup

Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>

* shift dataloader to framework and add multipack

Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>

* fmt + lint

Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>

* more linting and readme updates

Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>

* minor update

Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>

* modifications to multipack

Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>

* add unit tests

Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>

* Apply suggestions from code review

Co-authored-by: Yu Chin Fabian Lim <fabianlim@users.noreply.github.com>
Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>

* additional changes from code review

Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>

* framework lint and fmt

Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>

* aadp lint and fmt

Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>

* update sample configs config generator for consistency

Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>

* Apply suggestions from code review

Co-authored-by: Yu Chin Fabian Lim <fabianlim@users.noreply.github.com>
Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>

* further changes from code review

Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>

* more fixes

Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>

* Update plugins/attention-and-distributed-packing/README.md

Co-authored-by: Yu Chin Fabian Lim <fabianlim@users.noreply.github.com>
Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>

* fixed tox default benchmark command

Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>

* file renaming

Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>

* minor formatting

Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>

* fix test

Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>

---------

Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>
Co-authored-by: 1000850000 user <aaron.chew1@ibm.com>
Co-authored-by: achew010 <165894159+achew010@users.noreply.github.com>
3 people authored Aug 23, 2024
1 parent f4710e7 commit 4224c66
Showing 42 changed files with 1,645 additions and 242 deletions.
2 changes: 1 addition & 1 deletion plugins/attention-and-distributed-packing/.pylintrc
@@ -53,7 +53,7 @@ ignore=CVS,protobufs
# format. Because '\\' represents the directory delimiter on Windows systems,
# it can't be used as an escape character.
# NOTE: do not lint code imported from unsloth
ignore-paths=.*fused_ops/unsloth_lora.*,.*kernels/unsloth*
ignore-paths=.*multipack_sampler.py

# Files or directories matching the regular expression patterns are skipped.
# The regex matches against base names, not paths. The default value ignores
20 changes: 19 additions & 1 deletion plugins/attention-and-distributed-packing/README.md
@@ -3,19 +3,31 @@
This library contains plugins to accelerate finetuning with the following optimizations:

1. Padding-Free Flash Attention Computation
2. Multipack Distributed Sampling


## Plugins

Plugin | Description | Depends | Loading | Augmentation | Callbacks
--|--|--|--|--|--
[padding_free](./src/fms_acceleration_ilab/framework_plugin_padding_free.py) | Padding-Free Flash Attention Computation | flash_attn | | ✅ | ✅
[padding_free](./src/fms_acceleration_aadp/framework_plugin_padding_free.py) | Padding-Free Flash Attention Computation | flash_attn | | ✅ |
[multipack sampler](./src/fms_acceleration_aadp/framework_plugin_multipack.py) | Multipack Distributed Sampling | numba | | ✅ |


## Native Transformers Support from v4.44.0
Transformers natively supports padding-free from v4.44.0 ([see this PR](https://github.com/huggingface/transformers/pull/31629)). The padding-free plugin uses the native transformers implementation when it is available; with `transformers < v4.44.0`, the plugin falls back to an internal implementation instead.
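As an illustration (not the plugin's actual code), such a version gate can be sketched as follows; the helper name is hypothetical:

```python
# Minimal sketch of the version gate described above; `use_native_padding_free`
# is a hypothetical helper name, not part of the plugin's API.
from packaging import version

import transformers


def use_native_padding_free() -> bool:
    # native padding-free collation landed in transformers v4.44.0
    return version.parse(transformers.__version__) >= version.parse("4.44.0")
```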

## Running Benchmarks

To reproduce the benchmarks, run the following commands:

Reproduce [Padding Free on A100 80GB](scripts/benchmarks/refs_orca/a100_80gb_pf.csv)
`bash scripts/run_benchmarks.sh "1 2" "4 8" benchmark_outputs scenarios-orca.yaml "none"`

Reproduce [MultiPack on A100 80GB](scripts/benchmarks/refs_orca/a100_80gb_mp.csv)
`bash scripts/run_benchmarks.sh "2 4 8" "16 32 64" benchmark_outputs scenarios-orca.yaml "padding-free"`

## Known Issues

### Currently Only Supports Pre-Tokenized Dataset
@@ -32,3 +44,9 @@ In the meantime, the plugin expects the user to provide a pretokenized dataset that
- is tokenized
- has template labels that are masked to exclude from loss computation
- has eos token appended
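
For illustration, a minimal sketch of preparing such a dataset with the `datasets` and `transformers` libraries; the column names, tokenizer, and prompt/response split are assumptions, not part of this plugin:

```python
# Hypothetical pre-tokenization that satisfies the three requirements above:
# tokenized input_ids, prompt tokens masked with -100 in labels, eos appended.
from datasets import Dataset
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer with an eos token


def tokenize(example):
    prompt_ids = tok(example["prompt"], add_special_tokens=False)["input_ids"]
    response_ids = tok(example["response"], add_special_tokens=False)["input_ids"]
    input_ids = prompt_ids + response_ids + [tok.eos_token_id]  # eos appended
    labels = [-100] * len(prompt_ids) + response_ids + [tok.eos_token_id]  # prompt masked
    return {"input_ids": input_ids, "labels": labels}


ds = Dataset.from_dict({"prompt": ["Q: 1+1?\nA:"], "response": [" 2"]})
ds = ds.map(tokenize, remove_columns=["prompt", "response"])
```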

### Currently Only Supports Multipack with Padding-Free

The multipack plugin currently requires the padding-free plugin to be enabled as well.
This may change in the future if there is demand for multipack to work standalone, without padding-free.
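
For illustration only, a combined acceleration config enabling both plugins might look like the sketch below; the `padding_free` keys are an assumption (only the multipack section is added in this PR), so consult the padding-free plugin's own sample config for the authoritative keys:

```yaml
# Hypothetical combined config: multipack is only activated together with padding-free.
training:
  attention:
    padding_free:        # assumed keys; see the padding-free sample config
      method: huggingface
  dataloader:
    multipack:
      num_processes: 16
```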

11 changes: 11 additions & 0 deletions plugins/attention-and-distributed-packing/configs/multipack.yaml
@@ -0,0 +1,11 @@
# Configurations to accelerate data packing/padding in training
training:

  # dataloader configurations
  dataloader:

    # multipack dataloader
    multipack:

      # number of processes used to calculate dataset lengths
      num_processes: 16
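
As a usage note (not part of the diff), the nesting above is what the plugin's dotted configuration path `training.dataloader.multipack.num_processes` resolves against; a quick sanity check with PyYAML:

```python
# Resolve the dotted key path against the YAML above; assumes PyYAML is installed
# and the file lives at the path introduced in this PR.
import yaml

with open("plugins/attention-and-distributed-packing/configs/multipack.yaml") as f:
    cfg = yaml.safe_load(f)

node = cfg
for part in "training.dataloader.multipack.num_processes".split("."):
    node = node[part]
print(node)  # 16
```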
4 changes: 3 additions & 1 deletion plugins/attention-and-distributed-packing/pyproject.toml
@@ -13,7 +13,7 @@ authors = [
license = {text = "Apache-2.0"}
readme = "README.md"
requires-python = "~=3.9"
keywords = ['fms-hf-tuning', 'acceleration', 'padding-free']
keywords = ['fms-hf-tuning', 'acceleration', 'padding-free', 'multipack']
classifiers=[
"License :: OSI Approved :: Apache Software License",
"Development Status :: 4 - Beta",
@@ -23,6 +23,8 @@ classifiers=[
"Programming Language :: Python :: 3.11",
]

dependencies = ["numba"]

[tool.hatch.build.targets.wheel]
only-include = ["src/fms_acceleration_aadp"]

plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/__init__.py
@@ -13,4 +13,5 @@
# limitations under the License.

# Local
from .framework_plugin_multipack import MultipackDataloaderAccelerationPlugin
from .framework_plugin_padding_free import PaddingFreeAccelerationPlugin
plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/aadp_utils.py
@@ -12,9 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
from dataclasses import dataclass
import warnings

# Third Party
from transformers import DefaultDataCollator, default_data_collator
import numpy as np


@dataclass
@@ -52,3 +56,13 @@ def __call__(self, features, return_tensors=None):
            else:
                ret["labels"] += [-100] + feature["input_ids"][1:]
        return default_data_collator([ret], return_tensors)


def calculate_token_lengths(dataset, num_processes):
    return np.array(
        dataset.map(
            lambda x: {"len": len(x["input_ids"])},
            num_proc=num_processes,
            load_from_cache_file=True,
        )["len"]
    )
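
A toy usage sketch (not part of the diff) of `calculate_token_lengths`; the import path assumes the `fms_acceleration_aadp` package layout introduced in this PR:

```python
# Hypothetical usage: token lengths of a tiny in-memory dataset.
from datasets import Dataset
from fms_acceleration_aadp.aadp_utils import calculate_token_lengths

ds = Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5], [6, 7, 8, 9]]})
lengths = calculate_token_lengths(ds, num_processes=1)
print(lengths, lengths.mean())  # [3 2 4] 3.0
```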
plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py
@@ -12,21 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import inspect
# Standard
from functools import partial
import torch
from typing import Optional
import inspect
import os

# Third Party
# pylint: disable=no-name-in-module
from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal
from typing import Optional
import torch

if is_flash_attn_2_available():
    # pylint: disable=import-error
    from flash_attn import (
        flash_attn_func,
        flash_attn_varlen_func,
    )
    # Third Party
    from flash_attn import flash_attn_func, flash_attn_varlen_func

    _flash_supports_window_size = "window_size" in list(
        inspect.signature(flash_attn_func).parameters
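
A small self-contained illustration (not from the diff) of the feature-detection idiom ending the fragment above: checking whether a callable accepts a given keyword argument by inspecting its signature. The function below is a stand-in, not the real `flash_attn` API.

```python
# Detect an optional kwarg from a function signature; `_fake_flash_attn` is a
# stand-in used only for this illustration.
import inspect


def _fake_flash_attn(q, k, v, dropout_p=0.0, window_size=(-1, -1)):
    return q


print("window_size" in inspect.signature(_fake_flash_attn).parameters)  # True
```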
plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_multipack.py
@@ -0,0 +1,183 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
from types import MethodType
from typing import Dict, Tuple
import warnings

# Third Party
from accelerate import Accelerator
from fms_acceleration import AccelerationPlugin
from peft import LoraConfig
from torch.utils.data import DataLoader
from transformers import TrainingArguments

# from accelerate.data_loader import DataLoaderShard
import torch


class MultipackDataloaderAccelerationPlugin(AccelerationPlugin):

    require_packages = {"numba"}

    def __init__(
        self,
        configurations: Dict[str, Dict],
        seed: int = 42,
    ):
        super().__init__(configurations)

        self.num_processes = self._check_config_and_maybe_check_values(
            key="training.dataloader.multipack.num_processes",
        )

        # see about the collator
        attention = self._check_config_and_maybe_check_values(
            key="training.attention",
        )

        # internal flags
        self._seed = seed
        self._padding_free = False
        self._pad_token_id = None

        if "padding_free" in attention:
            # for padding free the multipack preparation will ignore the padding tokens
            self._padding_free = True
        else:
            # NOTE: need to get this from somewhere
            assert self._pad_token_id is not None, "need to get pad token id"

    @property
    def requires_agumentation(self):
        return True

    def augmentation(
        self,
        model,
        train_args: TrainingArguments,
        modifiable_args: Tuple[LoraConfig],
    ):

        # guarded because multipack has numba dependencies
        # Third Party
        # pylint: disable=import-outside-toplevel
        from fms_acceleration.accelerator_patcher import (
            AcceleratorPatcher,
            AcceleratorPatcherComponent,
        )

        # Local
        from .aadp_utils import (  # pylint: disable=import-outside-toplevel
            calculate_token_lengths,
        )
        from .multipack_sampler import (  # pylint: disable=import-outside-toplevel
            MultipackDistributedBatchSampler,
        )

        rank, num_bins = 0, 1
        if torch.distributed.is_initialized():
            num_bins = torch.distributed.get_world_size()
            rank = torch.distributed.get_rank()
        else:
            # NOTE: or should we do a silent fallback
            raise AssertionError(
                "Multipack dataloader only works for distributed training."
            )

        # some checks
        def _prereq(dataloader: DataLoader):
            return hasattr(dataloader, "dataset")

        def _build_multipack_dataloader(
            dataloader: DataLoader, accelerator: Accelerator
        ):

            # NOTE: for now we disable support for deepspeed, but can be added in
            # future if needed
            assert (
                not accelerator.state.deepspeed_plugin
            ), "Currently, multipack not supported for deepspeed"

            # get the dataset
            dataset = dataloader.dataset
            if torch.distributed.get_rank() > 0:
                warnings.warn(
                    "Waiting for main process to perform the mapping."
                    "If the dataset is large, some processes might time out,"
                    "You may need to increase the timeout limit or the number "
                    f"of workers processing the dataset > {self.num_processes}."
                )
                torch.distributed.barrier()

            lengths = calculate_token_lengths(dataset, num_processes=self.num_processes)

            if torch.distributed.get_rank() == 0:
                torch.distributed.barrier()

            self._max_number_tokens = (
                train_args.per_device_train_batch_size * lengths.mean()
            )

            # prepare the multipack distributed batch sampler
            sampler = MultipackDistributedBatchSampler(
                batch_max_length=self._max_number_tokens,
                lengths=lengths,
                num_replicas=num_bins,
                rank=rank,
                seed=self._seed,
                padding=not self._padding_free,
            )

            # wanted to use this but its abit annoying,
            # from accelerate.data_loader import DataLoaderShard
            # - so will just patch for now, but lets have a better
            # solution later
            dataloader = DataLoader(
                dataset,
                batch_sampler=sampler,
                num_workers=dataloader.num_workers,
                collate_fn=dataloader.collate_fn,
            )

            # patch a set epoch function to delegate the call to the
            # batch_sampler
            def _set_epoch(self, epoch: int):
                self.batch_sampler.set_epoch(epoch)

            dataloader.set_epoch = MethodType(_set_epoch, dataloader)
            return dataloader

        AcceleratorPatcher.replace(
            "multipack",
            AcceleratorPatcherComponent.data_loader,
            replacement_builder=_build_multipack_dataloader,
            pre_requisite_check=_prereq,
            skip_prepare=True,
        )

        # take a pointer to train args
        self._train_args = train_args
        return model, modifiable_args


# register
AccelerationPlugin.register_plugin(
    MultipackDataloaderAccelerationPlugin,
    configuration_and_paths=[
        "training.dataloader.multipack",  # activate if multipack config
        "training.attention",  # currently we require multipack to work with padding free
    ],
)
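
To make the sampler's role concrete, here is a minimal, self-contained sketch of the idea behind multipack: greedily packing sequence lengths into per-device batches under a token budget so that little compute is wasted on padding. This is an illustration only, not the `MultipackDistributedBatchSampler` algorithm.

```python
# Toy first-fit-decreasing packing under a token budget; illustrative only.
from typing import List


def pack_greedy(lengths: List[int], batch_max_length: int) -> List[List[int]]:
    """Pack example indices into bins so each bin's total tokens stay under budget."""
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    bins, bin_tokens = [], []
    for idx in order:
        for b, used in enumerate(bin_tokens):
            if used + lengths[idx] <= batch_max_length:
                bins[b].append(idx)
                bin_tokens[b] += lengths[idx]
                break
        else:
            bins.append([idx])
            bin_tokens.append(lengths[idx])
    return bins


print(pack_greedy([512, 300, 200, 120, 60], batch_max_length=600))
# [[0, 4], [1, 2], [3]] -> far fewer padding tokens than fixed-size batches
```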