Fp8 integration #1086
Conversation
The documentation is not available anymore as the PR was closed or merged.
Awesome job! I left a few notes and general questions, but great work :)
src/accelerate/accelerator.py
Outdated
default to the value in the environment variable `ACCELERATE_MIXED_PRECISION`, which will use the default
value in the accelerate config of the current system or the flag passed with the `accelerate.launch`
command. 'fp16' requires pytorch 1.6 or higher. 'bf16' requires pytorch 1.10 or higher.
Whether or not to use mixed precision training (fp16 or bfloat16). Choose from 'no','fp16','bf16 or 'fp8'.
Suggested change:
- Whether or not to use mixed precision training (fp16 or bfloat16). Choose from 'no','fp16','bf16 or 'fp8'.
+ Whether or not to use mixed precision training (fp8, fp16, or bfloat16). Choose from 'no','fp16','bf16 or 'fp8'.
Based on the earlier comment, this could be "fp16, bfloat16, or fp8", or we remove the parentheses and just have "Choose from".
Removing the parentheses entirely.
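For readers following along, a minimal sketch of how the option discussed in this docstring is selected in user code, assuming `'fp8'` is accepted by the same `mixed_precision` argument as the existing modes:

```python
from accelerate import Accelerator

# "no", "fp16", "bf16", and (with this PR) "fp8" are the accepted values;
# the same setting can also come from `accelerate config` or `accelerate launch`.
accelerator = Accelerator(mixed_precision="fp8")
```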
src/accelerate/accelerator.py
Outdated
f"The current device has compute capability of {torch.cuda.get_device_capability()} which is " | ||
"insufficient for FP8 mixed precision training (requires a GPU Hopper or higher, compute " | ||
"capability of 9 or higher). Will using FP16 instead." |
Should this only warn, or should it explicitly raise an error instead?
It still uses Transformer Engine instead of the regular model, so it's useful for testing on A100s.
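To make the warn-and-fall-back behavior discussed here concrete, a rough sketch; the helper name `_warn_or_fallback_fp8` is hypothetical and not the PR's actual code:

```python
import warnings

import torch


def _warn_or_fallback_fp8(mixed_precision: str) -> str:
    # Hypothetical helper: FP8 kernels need a Hopper-class GPU
    # (compute capability 9.0 or higher). On older devices, warn and
    # fall back to FP16 rather than raising, so the Transformer Engine
    # code path can still be exercised on e.g. A100s.
    major, _minor = torch.cuda.get_device_capability()
    if mixed_precision == "fp8" and major < 9:
        warnings.warn(
            f"The current device has compute capability {torch.cuda.get_device_capability()}, "
            "which is insufficient for FP8 mixed precision training. Will use FP16 instead."
        )
        return "fp16"
    return mixed_precision
```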
import transformer_engine.pytorch as te

def convert_model(model, to_transformer_engine=True, _convert_linear=True, _convert_ln=True):
Should this run an explicit try/catch for `is_fp8_available` and raise an error if not?
It shouldn't be called if `is_fp8_available` is `False`.
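For context on this thread, a rough sketch of what a conversion like `convert_model` does, assuming the recursive module-swapping approach; weight/bias copying is simplified and the reverse (Transformer Engine back to plain PyTorch) direction is omitted:

```python
import torch.nn as nn
import transformer_engine.pytorch as te


def convert_model(model, to_transformer_engine=True, _convert_linear=True, _convert_ln=True):
    # Recursively swap nn.Linear / nn.LayerNorm modules for their
    # transformer_engine counterparts so the model can run FP8 kernels.
    # Callers are expected to check is_fp8_available first.
    for name, module in model.named_children():
        if isinstance(module, nn.Linear) and _convert_linear and to_transformer_engine:
            te_module = te.Linear(module.in_features, module.out_features, bias=module.bias is not None)
            te_module.weight.data = module.weight.data.clone()
            if module.bias is not None:
                te_module.bias.data = module.bias.data.clone()
            setattr(model, name, te_module)
        elif isinstance(module, nn.LayerNorm) and _convert_ln and to_transformer_engine:
            # Assumes a 1-D normalized_shape, as is typical for transformer blocks.
            te_module = te.LayerNorm(module.normalized_shape[0], eps=module.eps)
            te_module.weight.data = module.weight.data.clone()
            te_module.bias.data = module.bias.data.clone()
            setattr(model, name, te_module)
        else:
            convert_model(module, to_transformer_engine, _convert_linear, _convert_ln)
```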
Awesome! This is going to be a game-changer for LLM training. LGTM 🚀
Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
This PR brings FP8 mixed precision training to Accelerate. Using it requires a Hopper GPU or higher (hard to find at the moment!) and the `transformer_engine` library.
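A minimal end-to-end sketch of turning the new mode on, with a toy model and dataset chosen purely for illustration (dimensions kept at multiples of 16, which FP8 GEMMs generally expect); it assumes a Hopper GPU with `transformer_engine` installed:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator

# Enable FP8 the same way as the other mixed-precision modes.
accelerator = Accelerator(mixed_precision="fp8")

# Toy regression setup purely for illustration.
model = torch.nn.Linear(64, 64)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
dataset = TensorDataset(torch.randn(256, 64), torch.randn(256, 64))
dataloader = DataLoader(dataset, batch_size=16)

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for inputs, targets in dataloader:
    optimizer.zero_grad()
    preds = model(inputs)
    loss = torch.nn.functional.mse_loss(preds, targets)
    accelerator.backward(loss)
    optimizer.step()
```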