Fp8 integration #1086

Merged (34 commits) on Mar 7, 2023

Changes from 31 commits
Commits (34)
68b98fc
Draft of FP8 support
sgugger Dec 7, 2022
f10a4b0
Missing import
sgugger Dec 8, 2022
23b0e04
Fix names
sgugger Dec 9, 2022
12ac9bf
Conversion is inplace
sgugger Dec 9, 2022
ab878ac
Enable fp8 in examples
sgugger Dec 9, 2022
b5578ea
Customization point for Recipe
sgugger Dec 9, 2022
b818aee
Auto-enable FP8 depending on compute capability
sgugger Dec 12, 2022
9e7157e
Fix typo
sgugger Dec 12, 2022
a1a86d1
Put back mixed precision arg
sgugger Dec 12, 2022
94d7101
Add debug script
sgugger Dec 12, 2022
13deee8
Add more tests in debug
sgugger Dec 12, 2022
097d3eb
Add more stuff to debug
sgugger Dec 12, 2022
23d188f
Don't forget train
sgugger Dec 12, 2022
4656b23
Put the train in the right place
sgugger Dec 12, 2022
17d0d38
Add options for selective conversion
sgugger Dec 13, 2022
aa3c0c2
Fix typo
sgugger Dec 13, 2022
8c69cd4
Properly recurse
sgugger Dec 13, 2022
4df14d0
Add more debug utils
sgugger Dec 13, 2022
978ed82
Typo and init
sgugger Dec 14, 2022
a2efc93
Last choice
sgugger Dec 14, 2022
583c5c5
More fixes
sgugger Dec 14, 2022
cb2b2f1
More options in example
sgugger Dec 14, 2022
7deee02
Remove debug scripts
sgugger Jan 26, 2023
e3e2ee9
Clean up debug and new names
sgugger Jan 26, 2023
659e1dc
Add torch.no_grad for conversion
sgugger Jan 27, 2023
9c715b0
Optimizer is disconnected from model?
sgugger Jan 27, 2023
b3f997f
Re-attach model parameters to optimizer
sgugger Jan 27, 2023
2787834
Fix extract
sgugger Jan 27, 2023
e2070df
Style
sgugger Jan 27, 2023
6f21320
Cleanup post-rebase
sgugger Feb 15, 2023
063d33e
Deal with padding
sgugger Feb 15, 2023
9f6b40f
fix examples
sgugger Feb 15, 2023
73e06ac
Update src/accelerate/accelerator.py
sgugger Mar 2, 2023
10f879a
Address comments
sgugger Mar 2, 2023
2 changes: 1 addition & 1 deletion examples/complete_cv_example.py
@@ -272,7 +272,7 @@ def main():
         "--mixed_precision",
         type=str,
         default=None,
-        choices=["no", "fp16", "bf16"],
+        choices=["no", "fp16", "bf16", "fp8"],
         help="Whether to use mixed precision. Choose"
         "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
         "and an Nvidia Ampere GPU.",
16 changes: 12 additions & 4 deletions examples/complete_nlp_example.py
@@ -109,9 +109,17 @@ def tokenize_function(examples):

     def collate_fn(examples):
         # On TPU it's best to pad everything to the same length or training will be very slow.
-        if accelerator.distributed_type == DistributedType.TPU:
-            return tokenizer.pad(examples, padding="max_length", max_length=128, return_tensors="pt")
-        return tokenizer.pad(examples, padding="longest", return_tensors="pt")
+        # When using mixed precision we want round multiples of 8/16
+        if accelerator.mixed_precision == "fp8":
+            pad_to_multiple_of = 16
+        elif accelerator.mixed_precision != "no":
+            pad_to_multiple_of = 8
+        else:
+            pad_to_multiple_of = None
+
+        return tokenizer.pad(
+            examples, padding="longest", max_length=128, pad_to_multiple_of=pad_to_multiple_of, return_tensors="pt"
+        )

     # Instantiate dataloaders.
     train_dataloader = DataLoader(
@@ -251,7 +259,7 @@ def main():
         "--mixed_precision",
         type=str,
         default=None,
-        choices=["no", "fp16", "bf16"],
+        choices=["no", "fp16", "bf16", "fp8"],
         help="Whether to use mixed precision. Choose"
         "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
         "and an Nvidia Ampere GPU.",
2 changes: 1 addition & 1 deletion examples/cv_example.py
@@ -190,7 +190,7 @@ def main():
         "--mixed_precision",
         type=str,
         default=None,
-        choices=["no", "fp16", "bf16"],
+        choices=["no", "fp16", "bf16", "fp8"],
         help="Whether to use mixed precision. Choose"
         "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
         "and an Nvidia Ampere GPU.",
18 changes: 13 additions & 5 deletions examples/nlp_example.py
@@ -79,9 +79,17 @@ def tokenize_function(examples):

     def collate_fn(examples):
         # On TPU it's best to pad everything to the same length or training will be very slow.
-        if accelerator.distributed_type == DistributedType.TPU:
-            return tokenizer.pad(examples, padding="max_length", max_length=128, return_tensors="pt")
-        return tokenizer.pad(examples, padding="longest", return_tensors="pt")
+        # When using mixed precision we want round multiples of 8/16
+        if accelerator.mixed_precision == "fp8":
+            pad_to_multiple_of = 16
+        elif accelerator.mixed_precision != "no":
+            pad_to_multiple_of = 8
+        else:
+            pad_to_multiple_of = None
+
+        return tokenizer.pad(
+            examples, padding="longest", max_length=128, pad_to_multiple_of=pad_to_multiple_of, return_tensors="pt"
+        )

     # Instantiate dataloaders.
     train_dataloader = DataLoader(
@@ -120,7 +128,6 @@ def training_function(config, args):
     # Note that if you are placing tensors on devices manually, this line absolutely needs to be before the optimizer
     # creation otherwise training will not work on TPU (`accelerate` will kindly throw an error to make us aware of that).
     model = model.to(accelerator.device)
-
     # Instantiate optimizer
     optimizer = AdamW(params=model.parameters(), lr=lr)

@@ -134,6 +141,7 @@ def training_function(config, args):
     # Prepare everything
     # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the
     # prepare method.
+
     model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
         model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
     )
@@ -177,7 +185,7 @@ def main():
         "--mixed_precision",
         type=str,
         default=None,
-        choices=["no", "fp16", "bf16"],
+        choices=["no", "fp16", "bf16", "fp8"],
         help="Whether to use mixed precision. Choose"
         "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
         "and an Nvidia Ampere GPU.",
47 changes: 41 additions & 6 deletions src/accelerate/accelerator.py
@@ -39,6 +39,7 @@
     DistributedDataParallelKwargs,
     DistributedType,
     DynamoBackend,
+    FP8RecipeKwargs,
     FullyShardedDataParallelPlugin,
     GradScalerKwargs,
     InitProcessGroupKwargs,
@@ -49,12 +50,15 @@
     ProjectConfiguration,
     RNGType,
     compare_versions,
+    convert_model,
     convert_outputs_to_fp32,
     extract_model_from_parallel,
     gather,
     get_pretty_name,
+    has_transformer_engine_layers,
     is_bf16_available,
     is_deepspeed_available,
+    is_fp8_available,
     is_megatron_lm_available,
     is_torch_version,
     is_tpu_available,
@@ -79,6 +83,11 @@
         DummyScheduler,
     )

+if is_fp8_available():
+    import transformer_engine.common.recipe as te_recipe
+    from transformer_engine.pytorch import fp8_autocast
+
+
 if is_megatron_lm_available():
     from .utils import (
         MegatronEngine,
@@ -123,10 +132,11 @@ class Accelerator:
             round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set
             in your script multiplied by the number of processes.
         mixed_precision (`str`, *optional*):
-            Whether or not to use mixed precision training (fp16 or bfloat16). Choose from 'no','fp16','bf16'. Will
-            default to the value in the environment variable `ACCELERATE_MIXED_PRECISION`, which will use the default
-            value in the accelerate config of the current system or the flag passed with the `accelerate.launch`
-            command. 'fp16' requires pytorch 1.6 or higher. 'bf16' requires pytorch 1.10 or higher.
+            Whether or not to use mixed precision training (fp16 or bfloat16). Choose from 'no','fp16','bf16 or 'fp8'.
+            Will default to the value in the environment variable `ACCELERATE_MIXED_PRECISION`, which will use the
+            default value in the accelerate config of the current system or the flag passed with the
+            `accelerate.launch` command. 'fp16' requires pytorch 1.6 or higher. 'bf16' requires pytorch 1.10 or higher.
+            'fp8' requires the installation of transformers-engine.
         gradient_accumulation_steps (`int`, *optional*, default to 1):
             The number of steps that should pass before gradients are accumulated. A number > 1 should be combined with
             `Accelerator.accumulate`. If not passed, will default to the value in the environment variable

Review thread on the new mixed_precision docstring line:
Collaborator, suggested change:
-            Whether or not to use mixed precision training (fp16 or bfloat16). Choose from 'no','fp16','bf16 or 'fp8'.
+            Whether or not to use mixed precision training (fp8, fp16, or bfloat16). Choose from 'no','fp16','bf16 or 'fp8'.
Collaborator: Based on the earlier comment, this could be "fp16, bfloat16, or fp8", or we remove the () and just have "Choose from".
sgugger (author): Removing the parenthesis entirely.
@@ -298,6 +308,7 @@ def __init__(
         self.ddp_handler = None
         self.scaler_handler = None
         self.init_handler = None
+        self.fp8_recipe_handler = None
         if kwargs_handlers is not None:
             for handler in kwargs_handlers:
                 assert isinstance(
@@ -318,6 +329,11 @@ def __init__(
                         raise ValueError("You can only pass one `InitProcessGroupKwargs` in `kwargs_handler`.")
                     else:
                         self.init_handler = handler
+                elif isinstance(handler, FP8RecipeKwargs):
+                    if self.fp8_recipe_handler is not None:
+                        raise ValueError("You can only pass one `FP8RecipeKwargs` in `kwargs_handler`.")
+                    else:
+                        self.fp8_recipe_handler = handler

         kwargs = self.init_handler.to_kwargs() if self.init_handler is not None else {}
         self.state = AcceleratorState(
@@ -1046,7 +1062,7 @@ def prepare(self, *args, device_placement=None):

         # If we're dealing with device placement, this deals with that by...
         tpu_should_fix_optimizer = self.device_placement and self.distributed_type == DistributedType.TPU
-        if tpu_should_fix_optimizer:
+        if tpu_should_fix_optimizer or self.mixed_precision == "fp8":
             # 1. grabbing old model parameters
             old_named_params = self._get_named_parameters(*args)

@@ -1060,7 +1076,7 @@ def prepare(self, *args, device_placement=None):
             )
             result = tuple(self._prepare_one(obj, device_placement=d) for obj, d in zip(result, device_placement))

-        if tpu_should_fix_optimizer:
+        if tpu_should_fix_optimizer or self.mixed_precision == "fp8":
             # 2. grabbing new model parameters
             new_named_params = self._get_named_parameters(*result)
             # 3. building a map from the first to the second
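To make the numbered comments above concrete: converting a model in place for FP8 replaces its modules, so the Parameter objects held by an already-built optimizer go stale and have to be re-attached. The following is a purely illustrative sketch of that idea (a hypothetical helper, not the collapsed lines of this hunk):

```python
import torch

from accelerate.utils import convert_model  # added by this PR


def reattach_params(model: torch.nn.Module, optimizer: torch.optim.Optimizer):
    # Remember the old name -> Parameter mapping before conversion.
    old_named_params = dict(model.named_parameters())
    with torch.no_grad():
        # Swaps supported layers (e.g. nn.Linear) for their transformer_engine equivalents in place.
        convert_model(model)
    new_named_params = dict(model.named_parameters())
    # Map each stale Parameter object to its freshly created counterpart by name.
    mapping = {old: new_named_params[name] for name, old in old_named_params.items()}
    for group in optimizer.param_groups:
        group["params"] = [mapping.get(p, p) for p in group["params"]]
```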
@@ -1144,6 +1160,25 @@ def prepare_model(self, model: torch.nn.Module, device_placement=None):
             else:
                 model.forward = torch.cuda.amp.autocast()(model.forward)
             model.forward = convert_outputs_to_fp32(model.forward)
+        elif self.mixed_precision == "fp8":
+            if not has_transformer_engine_layers(model):
+                with torch.no_grad():
+                    convert_model(model)
+                model._converted_to_transformer_engine = True
+            model._original_forward = model.forward
+
+            kwargs = self.fp8_recipe_handler.to_kwargs() if self.fp8_recipe_handler is not None else {}
+            if "fp8_format" in kwargs:
+                kwargs["fp8_format"] = getattr(te_recipe.Format, kwargs["fp8_format"])
+            fp8_recipe = te_recipe.DelayedScaling(**kwargs)
+            fp8_enabled = torch.cuda.get_device_capability()[0] >= 9
+            if not fp8_enabled:
+                logger.warn(
+                    f"The current device has compute capability of {torch.cuda.get_device_capability()} which is "
+                    "insufficient for FP8 mixed precision training (requires a Hopper GPU or higher, compute "
+                    "capability of 9 or higher). Will use FP16 instead."
+                )
+            model.forward = fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe)(model.forward)
         if self.distributed_type == DistributedType.TPU and self.state.fork_launched:
             model = xmp.MpModelWrapper(model).to(self.device)
         return model

Review thread on the compute-capability warning (marked resolved by sgugger):
Collaborator: Should this only warn? Or should it not explicitly raise an error?
sgugger (author): It still uses transformer engine instead of the regular model, so useful for testing on A100s.
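For context, a minimal end-to-end sketch of the path this file now supports, assuming transformer_engine is installed and an FP8-capable GPU is available; the toy model and optimizer are illustrative, not part of the PR:

```python
import torch

from accelerate import Accelerator
from accelerate.utils import FP8RecipeKwargs

# Any model with supported layers (e.g. nn.Linear) works; prepare() converts them
# to Transformer Engine layers and wraps forward in fp8_autocast.
model = torch.nn.Sequential(torch.nn.Linear(128, 128), torch.nn.ReLU(), torch.nn.Linear(128, 2))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

recipe_kwargs = FP8RecipeKwargs(fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[recipe_kwargs])

# The optimizer is re-attached to the converted model's parameters inside prepare().
model, optimizer = accelerator.prepare(model, optimizer)
```

On pre-Hopper hardware the same code runs, but per the warning in the hunk above the autocast is created with enabled=False.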
2 changes: 1 addition & 1 deletion src/accelerate/commands/config/cluster.py
@@ -485,7 +485,7 @@ def get_cluster_input():
    else:
        mixed_precision = _ask_options(
            "Do you wish to use FP16 or BF16 (mixed precision)?",
-            ["no", "fp16", "bf16"],
+            ["no", "fp16", "bf16", "fp8"],
            _convert_mixed_precision,
        )
2 changes: 1 addition & 1 deletion src/accelerate/commands/config/config_utils.py
@@ -76,7 +76,7 @@ def _convert_dynamo_backend(value):

 def _convert_mixed_precision(value):
     value = int(value)
-    return PrecisionType(["no", "fp16", "bf16"][value])
+    return PrecisionType(["no", "fp16", "bf16", "fp8"][value])


 def _convert_sagemaker_distributed_mode(value):
3 changes: 3 additions & 0 deletions src/accelerate/utils/__init__.py
@@ -5,6 +5,7 @@
     DistributedDataParallelKwargs,
     DistributedType,
     DynamoBackend,
+    FP8RecipeKwargs,
     FullyShardedDataParallelPlugin,
     GradScalerKwargs,
     InitProcessGroupKwargs,
@@ -28,6 +29,7 @@
     is_comet_ml_available,
     is_datasets_available,
     is_deepspeed_available,
+    is_fp8_available,
     is_megatron_lm_available,
     is_mlflow_available,
     is_mps_available,
@@ -130,3 +132,4 @@
 from .random import set_seed, synchronize_rng_state, synchronize_rng_states
 from .torch_xla import install_xla
 from .tqdm import tqdm
+from .transformer_engine import convert_model, has_transformer_engine_layers
35 changes: 34 additions & 1 deletion src/accelerate/utils/dataclasses.py
@@ -27,7 +27,7 @@
 from dataclasses import dataclass, field
 from datetime import timedelta
 from distutils.util import strtobool
-from typing import Any, Callable, Dict, Iterable, List, Optional
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

 import torch

@@ -141,6 +141,38 @@ class InitProcessGroupKwargs:
     timeout: timedelta = timedelta(seconds=1800)


+@dataclass
+class FP8RecipeKwargs(KwargsHandler):
+    """
+    Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision
+    training. Please refer to the documentation of this
+    [class](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html#transformer_engine.common.recipe.DelayedScaling)
+    for more information on each argument.
+
+    ```python
+    from accelerate import Accelerator
+    from accelerate.utils import FP8RecipeKwargs
+
+    kwargs = FP8RecipeKwargs(fp8_format="HYBRID")
+    accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[kwargs])
+    ```
+    """
+
+    margin: int = 0
+    interval: int = 1
+    fp8_format: str = "E4M3"
+    amax_history_len: int = 1
+    amax_compute_algo: str = "most_recent"
+    override_linear_precision: Tuple[bool, bool, bool] = (False, False, False)
+
+    def __post_init__(self):
+        self.fp8_format = self.fp8_format.upper()
+        if self.fp8_format not in ["E4M3", "HYBRID"]:
+            raise ValueError("`fp8_format` must be 'E4M3' or 'HYBRID'.")
+        if self.amax_compute_algo not in ["max", "most_recent"]:
+            raise ValueError("`amax_compute_algo` must be 'max' or 'most_recent'")
+
+
 class DistributedType(str, enum.Enum):
     """
     Represents a type of distributed environment.
@@ -294,6 +326,7 @@ class PrecisionType(BaseEnum):
     """

     NO = "no"
+    FP8 = "fp8"
     FP16 = "fp16"
     BF16 = "bf16"
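To show how these recipe kwargs are consumed, here is a short sketch mirroring the prepare_model hunk above (requires transformer_engine; nothing here goes beyond what the diff already does):

```python
import transformer_engine.common.recipe as te_recipe

from accelerate.utils import FP8RecipeKwargs

handler = FP8RecipeKwargs(fp8_format="HYBRID")
kwargs = handler.to_kwargs()
# The string name is resolved to the transformer_engine Format enum member before
# building the DelayedScaling recipe, exactly as in Accelerator.prepare_model.
if "fp8_format" in kwargs:
    kwargs["fp8_format"] = getattr(te_recipe.Format, kwargs["fp8_format"])
fp8_recipe = te_recipe.DelayedScaling(**kwargs)
```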
4 changes: 4 additions & 0 deletions src/accelerate/utils/imports.py
@@ -56,6 +56,10 @@ def is_apex_available():
     return importlib.util.find_spec("apex") is not None


+def is_fp8_available():
+    return importlib.util.find_spec("transformer_engine") is not None
+
+
 @lru_cache()
 def is_tpu_available(check_device=True):
     "Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
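A small sketch of the optional-dependency pattern this helper enables (the same guard used at the top of accelerator.py in this PR):

```python
from accelerate.utils import is_fp8_available

if is_fp8_available():
    # Only import transformer_engine when it is actually installed.
    import transformer_engine.common.recipe as te_recipe
    from transformer_engine.pytorch import fp8_autocast
```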
3 changes: 3 additions & 0 deletions src/accelerate/utils/other.py
@@ -21,6 +21,7 @@
 from ..state import PartialState
 from .dataclasses import DistributedType
 from .imports import is_deepspeed_available, is_tpu_available
+from .transformer_engine import convert_model


 if is_deepspeed_available():
@@ -59,6 +60,8 @@ def extract_model_from_parallel(model, keep_fp32_wrapper: bool = True):
                 if forward == original_forward:
                     break
             model.forward = forward
+    if getattr(model, "_converted_to_transformer_engine", False):
+        convert_model(model, to_transformer_engine=False)
     return model
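A hedged usage note on the effect of this change: once a model has been converted to Transformer Engine layers, unwrapping it converts it back, so saved checkpoints keep plain PyTorch modules. A sketch assuming the usual accelerate saving pattern (`Accelerator.unwrap_model` routes through `extract_model_from_parallel`):

```python
import torch

# Continuing the usage sketch from the accelerator.py section above, after training
# with mixed_precision="fp8":
unwrapped_model = accelerator.unwrap_model(model, keep_fp32_wrapper=False)
# prepare_model set `_converted_to_transformer_engine`, so unwrapping runs
# convert_model(model, to_transformer_engine=False) before returning.
torch.save(unwrapped_model.state_dict(), "pytorch_model.bin")
```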