Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[low-bit optim] Add coat for float8 optimizer #1231

Draft
wants to merge 74 commits into
base: main
Choose a base branch
from

Conversation

MirMustafaAli
Copy link

@MirMustafaAli MirMustafaAli commented Nov 6, 2024

This is a Work in Progress PR for #1190.

As a draft PR, I have followed the first piece of advice by @gau-nernst of "extending OptimStateFp8". Have created a separate Dynamic Range Function Instead of creating a different quantize_fp8 method as it will be applied before quantization to achieve larger representation range of float8 datatypes and the class will be storing value k to inverse the it after dequantization.

Requirements:
TBA
Additional Code/logic Added:
TBA
Logic/Code changes to existing codebase:
TBA
Outcome:
TBA
Scope of Usage:
TBA
Example
TBA

Changes:

  • Dynamic Range Expansion Function: implementation of formula from the paper
  • Created OptimStateFp8WithDynamicRangeExpansion class by extending OptimStateFp8: by referencing the implementation of the OptimStatefp8. I have only overridden the dequantize method
  • Implemented aten.copy.default and aten.to_copy.default for OptimStateFp8WithDynamicRangeExpansion:

Benchmarks

Parameters

Parameter Value
Learning Rate (lr) 0.0001
Automatic Mixed Precision (amp) bf16
Seed 42
Model timm/vit_base_patch16_224.augreg_in21k
Optimizer (optim) AdamWFp8Ao_coat
Compile False
Profile False
Project COAT-benchmarking
Number of Epochs 10
Run Name AdamWFp8Ao_coat
Full BF16 False
Number of Workers 4
Batch Size 1024
Weight Decay 0
Channels Last False
Optimizer CPU Offload None
Cosine LR Scheduler False
Checkpoint Activations False

Results

W B Chart Nov 15 2024 (1)

W B Chart Nov 15 2024

W B Chart Nov 14 2024

Copy link

pytorch-bot bot commented Nov 6, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1231

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Nov 6, 2024
@MirMustafaAli MirMustafaAli marked this pull request as draft November 6, 2024 11:35
@gau-nernst
Copy link
Collaborator

I was thinking you can just add a flag to the current OptimStateFp8, something like dynamic_range_expansion: bool, instead of subclass-ing it.

@MirMustafaAli
Copy link
Author

I was thinking you can just add a flag to the current OptimStateFp8, something like dynamic_range_expansion: bool, instead of subclass-ing it.

i have added the flag for optimstatefp8. could you verify its right?

@gau-nernst
Copy link
Collaborator

I think this requires a bit more work. You need to verify that you can create an optimizer with this (add test to https://github.com/pytorch/ao/blob/main/test/prototype/test_low_bit_optim.py) as well do some short training runs for sanity checks (using https://github.com/pytorch/ao/blob/main/benchmarks/benchmark_low_bit_adam.py).

I think for merging the PR, we should wait for the official code release to check numeric against them.

If you don't mind, we can discuss more details in GPU-MODE discord group https://discord.gg/gpumode. Just create a thread under torchao and tag me in (@gau.nernst)

@MirMustafaAli
Copy link
Author

MirMustafaAli commented Nov 6, 2024

I understand the situation for merging the PR. Will be glad to work on working on this issue. creating thread in gpumode

yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
* Show a8wxdq load error only when the quant is used

* Update Error check
self.block_size = codes.numel() // scale.numel()
self.sqrt_minmax_exp = sqrt_minmax_exp

def __tensor_flatten__(self):
return self.tensor_attrs, []
Copy link
Collaborator

@gau-nernst gau-nernst Dec 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When k and sqrt_minmax_exp is not None, you need to return them here (in __tensor_flatten__()) also.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should i pass them instead of empty array?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first returned value (currently self.tensor_attrs) is a list of strings containing the names of tensor attributes. In this case, when there is no dynamic range extension, it's just "codes", "scale". However, when there is dynamic range extension, you need to also add "k", "sqrt_minmax_exp". IIRC, when they are None, you are not supposed to include them.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for explaining it. i have added them. looking forward to your approval

@@ -21,8 +21,7 @@ lm_eval
diskcache
pycocotools
tqdm

# Custom CUDA Extensions
git+https://github.com/NVlabs/COAT.git#subdirectory=coat/optimizer/kernels # Custom CUDA Extensions
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't add this. CPU runner will fail to build CUDA extension. We will just test this locally.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. topic: new feature Use this tag if this PR adds a new feature
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants