[low-bit optim] Add COAT for float8 optimizer #1231

Status: Draft. MirMustafaAli wants to merge 75 commits into base branch main from add_coat_optimizer.

Changes from 1 commit (75 commits total)
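The commit history below threads a new dynamic range expansion option through the FP8 optimizer classes (AdamFp8, AdamWFp8). A minimal usage sketch follows, assuming the `dynamic_range_expansion` keyword named in the commit messages is exposed on the optimizer constructor; the flag is this PR's proposal, not part of torchao's released API:

```python
# Hedged sketch: `dynamic_range_expansion=True` is the flag this PR appears to add
# (per the commit messages below); it is not part of torchao's released API.
import torch
from torchao.prototype.low_bit_optim import AdamWFp8

model = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)
optim = AdamWFp8(model.parameters(), lr=1e-3, dynamic_range_expansion=True)

loss = model(torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)).sum()
loss.backward()
optim.step()
optim.zero_grad()
```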
c62fcd3
added dynamic range expansion
MirMustafaAli Nov 6, 2024
b11b4f6
created optimstate with DRE class
MirMustafaAli Nov 6, 2024
b887e69
implement copy_.default for OptimStateFp8WithDynamicRangeExpansion cl…
MirMustafaAli Nov 6, 2024
ab02605
implements _to_copy
MirMustafaAli Nov 6, 2024
43f5c08
removed implemented classes
MirMustafaAli Nov 6, 2024
7d98b15
dynamic_range_expansion -> apply_dynamic_range_expansion
MirMustafaAli Nov 6, 2024
e03c79c
add DRE flags to class
MirMustafaAli Nov 6, 2024
5faa7de
implementing contraction for dequantize
MirMustafaAli Nov 6, 2024
b43c88f
copy k values as well for copy method
MirMustafaAli Nov 6, 2024
ac41627
added dynamic range expansion
MirMustafaAli Nov 6, 2024
79c9461
created optimstate with DRE class
MirMustafaAli Nov 6, 2024
47a7bb0
implement copy_.default for OptimStateFp8WithDynamicRangeExpansion cl…
MirMustafaAli Nov 6, 2024
a162f94
implements _to_copy
MirMustafaAli Nov 6, 2024
5c1a3f4
removed implemented classes
MirMustafaAli Nov 6, 2024
1458d65
dynamic_range_expansion -> apply_dynamic_range_expansion
MirMustafaAli Nov 6, 2024
42fbb09
add DRE flags to class
MirMustafaAli Nov 6, 2024
9d1c00c
implementing contraction for dequantize
MirMustafaAli Nov 6, 2024
7bc6ea4
copy k values as well for copy method
MirMustafaAli Nov 6, 2024
7be5a6b
Merge branch 'add_coat_optimizer' of https://github.com/MirMustafaAli…
MirMustafaAli Nov 8, 2024
70937c8
combine range_expansion into quantize_fp8 function
MirMustafaAli Nov 8, 2024
3583de7
passing apply_range_expansion to quantize_fp8
MirMustafaAli Nov 8, 2024
c754893
remove apply_dynamic_range_expansion method
MirMustafaAli Nov 8, 2024
c47b987
pass destination's dynamic range expansion variable to quantize fp8
MirMustafaAli Nov 8, 2024
7a754ce
change type annotation to optional
MirMustafaAli Nov 8, 2024
4d37d86
k is none when dynamic range expansion is False
MirMustafaAli Nov 9, 2024
2d1834a
referencing paper for calculation of dynamic range expansion
MirMustafaAli Nov 9, 2024
3d0d5d6
replaced condition check using variable k
MirMustafaAli Nov 9, 2024
c413ac4
added parameter dynamic_range_expansion
MirMustafaAli Nov 9, 2024
c3f5d29
pass bool condition for quantizing src tensor
MirMustafaAli Nov 9, 2024
1ec9335
readded the torchversion safe_global exports
MirMustafaAli Nov 9, 2024
122530e
initialize k to none and later assign value if dynamic range expansio…
MirMustafaAli Nov 9, 2024
77e1371
conditional statement by checking if k is None instead of directly ap…
MirMustafaAli Nov 9, 2024
366743c
checking if k is available in dst to copy it
MirMustafaAli Nov 9, 2024
38951ae
matching parameters counts with constructor of optimStateFp8
MirMustafaAli Nov 12, 2024
4b3fb6b
copy to k tensor only if k is not None
MirMustafaAli Nov 12, 2024
7185b00
passing k tensor if values are available
MirMustafaAli Nov 12, 2024
0d7edae
providing dynamic range expansion to the adamfloat8 class
MirMustafaAli Nov 12, 2024
58ff635
change of _subclass_zeros from static method to normal class method
MirMustafaAli Nov 12, 2024
6c536a9
added dynamic range expansion to adamwfp8
MirMustafaAli Nov 12, 2024
767ccab
adding smoke test for additional parameters for float8 optimizers
MirMustafaAli Nov 12, 2024
8fa5e3d
added new line
MirMustafaAli Nov 13, 2024
f34bfdd
remove newline
MirMustafaAli Nov 13, 2024
41598a0
removed optim_addon parameter
MirMustafaAli Nov 13, 2024
c189dc7
rename test_optim_addon to test_optim_fp8_coat_smoke
MirMustafaAli Nov 13, 2024
6bb49ea
code formatting
MirMustafaAli Nov 13, 2024
6707425
Merge branch 'main' into add_coat_optimizer
MirMustafaAli Nov 13, 2024
b1aea26
Moved device compatibility check for FP8 optimizer tests from pytest …
MirMustafaAli Nov 13, 2024
92ca7b2
formatting for `ruff check F,I`
MirMustafaAli Nov 13, 2024
861423d
removing duplicate
MirMustafaAli Nov 13, 2024
7661b61
checking if device is cuda before calling device capability
MirMustafaAli Nov 13, 2024
e1fa683
Updating Readme with dynamic range Expansion and Reference to Paper
MirMustafaAli Nov 13, 2024
62eac8b
Merge branch 'main' into add_coat_optimizer
MirMustafaAli Nov 15, 2024
1f8f153
Merge branch 'main' into add_coat_optimizer
MirMustafaAli Nov 26, 2024
f9d0aa1
Merge branch 'main' into add_coat_optimizer
MirMustafaAli Dec 15, 2024
6eba1d1
Merge branch 'pytorch:main' into add_coat_optimizer
MirMustafaAli Dec 21, 2024
e1ce12d
removal of block_size parameter
MirMustafaAli Dec 23, 2024
8b8126a
added geometric mean to expand function
MirMustafaAli Dec 23, 2024
3004a53
geometric mean to subclass_fp8
MirMustafaAli Dec 23, 2024
0bbba59
merged un staged commit
MirMustafaAli Dec 23, 2024
a18f58f
Revert "merged un staged commit"
MirMustafaAli Dec 23, 2024
6a401b4
adding github qoptim library to test expand optim function
MirMustafaAli Dec 24, 2024
fa40e16
testcase for numerical accuracy with coat library
MirMustafaAli Dec 25, 2024
529ff66
remove print statement
MirMustafaAli Dec 25, 2024
bb6b707
removed libraries which failed cpu runner
MirMustafaAli Dec 25, 2024
c9af6ce
adding k and sqrt_minmax_exp when dynamic range expansion has been en…
MirMustafaAli Dec 25, 2024
f530c42
ruff formatting
MirMustafaAli Dec 25, 2024
024b465
Merge branch 'pytorch:main' into add_coat_optimizer
MirMustafaAli Dec 25, 2024
8de5220
tensor unflatten when dynamic range expansion is true
MirMustafaAli Dec 26, 2024
dc8cc60
Merge branch 'pytorch:main' into add_coat_optimizer
MirMustafaAli Dec 30, 2024
8588055
added test case to check numerical comparison between eager, aot_eage…
MirMustafaAli Jan 5, 2025
bcfa131
ruff formatting
MirMustafaAli Jan 5, 2025
a6b9618
Merge branch 'pytorch:main' into add_coat_optimizer
MirMustafaAli Jan 5, 2025
65ac4ec
Merge branch 'pytorch:main' into add_coat_optimizer
MirMustafaAli Jan 25, 2025
a31b0d4
removed testing of torch.compile accuracy
MirMustafaAli Jan 25, 2025
ec7b2db
Merge branch 'pytorch:main' into add_coat_optimizer
MirMustafaAli Jan 31, 2025
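Several commits above fold range expansion into quantize_fp8 and add a matching contraction on dequantize ("combine range_expansion into quantize_fp8 function", "implementing contraction for dequantize", "k is none when dynamic range expansion is False"). The sketch below only illustrates that general idea; the per-block k computation, clamping, and function names are assumptions for illustration, not this PR's code:

```python
# Illustrative sketch of COAT-style dynamic range expansion around blockwise FP8
# quantization. Not the PR's implementation; names and formulas are assumptions.
import math

import torch

FP8_DTYPE = torch.float8_e4m3fn


def quantize_fp8_with_expansion(x: torch.Tensor, block_size: int, expand: bool):
    shape = x.shape
    x = x.float().view(-1, block_size)

    k = None
    if expand:
        # Per-block exponent k chosen so that |x|**k roughly fills the E4M3 dynamic
        # range (ratio of log dynamic ranges). This exact formula is an assumption.
        absmax = x.abs().amax(-1, keepdim=True).clamp_min(1e-12)
        absmin = x.abs().amin(-1, keepdim=True).clamp_min(1e-12)
        fp8_range = torch.finfo(FP8_DTYPE).max / torch.finfo(FP8_DTYPE).smallest_normal
        k = math.log(fp8_range) / torch.log(absmax / absmin).clamp_min(1e-12)
        x = x.sign() * x.abs().pow(k)

    # Standard blockwise absmax scaling into the FP8 representable range.
    scale = x.abs().amax(-1, keepdim=True).clamp_min(1e-12) / torch.finfo(FP8_DTYPE).max
    codes = (x / scale).to(FP8_DTYPE)
    # k is None when expansion is disabled, mirroring the commits above.
    return codes.view(shape), scale, k


def dequantize_fp8_with_contraction(codes, scale, k):
    x = codes.float().view(scale.shape[0], -1) * scale
    if k is not None:
        # Contraction: invert the expansion applied at quantization time.
        x = x.sign() * x.abs().pow(1.0 / k)
    return x.view(codes.shape)
```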
Viewing commit 1f8f153bc6c7f5ede9eadf6be4f444e6ec8f0685: Merge branch 'main' into add_coat_optimizer
MirMustafaAli committed Nov 26, 2024
torchao/prototype/low_bit_optim/adam.py: 15 changes (5 additions, 10 deletions)

@@ -55,16 +55,11 @@ def _subclass_zeros(self, p: Tensor, signed: bool):
         raise NotImplementedError
 
     def _new_buffer(self, p: Tensor, signed: bool):
-        if p.numel() >= 4096 and p.numel() % self.block_size == 0:
-            if isinstance(p, DTensor):
-                out = DTensor.from_local(
-                    local_tensor=self._subclass_zeros(p.to_local(), signed),
-                    device_mesh=p.device_mesh,
-                    placements=p.placements,
-                    run_check=False,
-                )
-            else:
-                out = self._subclass_zeros(p, signed)
+        local_p = p.to_local() if isinstance(p, DTensor) else p
+
+        # follow bitsandbytes, only quantize tensors >= 4096 values
+        if local_p.numel() >= 4096 and local_p.numel() % self.block_size == 0:
+            out = self._subclass_zeros(local_p, signed, self.block_size)
         else:
             out = torch.zeros_like(local_p)
 
(remaining lines of this hunk collapsed)
You are viewing a condensed version of this merge commit.
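Because the hunk is collapsed, the DTensor handling that replaces the removed branch is not visible. Below is a plausible reconstruction of how the refactored method might continue, under the assumption that the buffer is wrapped back into a DTensor after allocation; it is inferred from the removed lines, not the file's verbatim contents:

```python
# Hedged sketch of the full _new_buffer method on the optimizer class in adam.py.
# The DTensor re-wrapping at the end is an assumption about the collapsed lines.
import torch
from torch import Tensor
from torch.distributed.tensor import DTensor  # import path may differ by torch version


def _new_buffer(self, p: Tensor, signed: bool):
    local_p = p.to_local() if isinstance(p, DTensor) else p

    # follow bitsandbytes, only quantize tensors >= 4096 values
    if local_p.numel() >= 4096 and local_p.numel() % self.block_size == 0:
        out = self._subclass_zeros(local_p, signed, self.block_size)
    else:
        out = torch.zeros_like(local_p)

    # Assumed continuation: re-wrap the local buffer for DTensor parameters,
    # mirroring the DTensor.from_local call in the removed branch above.
    if isinstance(p, DTensor):
        out = DTensor.from_local(
            local_tensor=out,
            device_mesh=p.device_mesh,
            placements=p.placements,
            run_check=False,
        )
    return out
```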