Cast bfloat16 to float32 for Numpy conversions #29755
Conversation
Thanks! Could you add a test for this?
Is there any way we can keep track of the converted weights? This fixes the entering-into-numpy-hinterland issue, but if I've understood correctly, it will end up producing far larger TF models, which won't be the same on equivalence tests.
@amyeroberts thankfully, the NumPy values are never assigned directly as TF weights! The way our weight loading works, a TF model is first created with random weights, and then we loop over the crossloaded values and assign them to those existing weights. As a result, the exact dtype we use to load the weights from PyTorch doesn't matter, as long as it doesn't lose any precision (which is why I upcast to float32 for safety here). If the TF model is in a lower precision, the values are simply downcast when they're assigned, so the loading dtype doesn't inflate the final model size or break equivalence tests.

That said, our support for full-bfloat16 TF models is still a little shaky, but fixing that is probably a separate PR!
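As a rough illustration of why the float32 upcast is safe (a standalone sketch, not code from this PR): every `bfloat16` value is exactly representable in `float32`, so the round trip below recovers the original tensor exactly.

```python
import torch

# bfloat16 -> float32 is lossless: bfloat16 is float32 with a truncated mantissa,
# so every bfloat16 value has an exact float32 representation.
t_bf16 = torch.randn(4, 4, dtype=torch.bfloat16)
t_f32 = t_bf16.to(torch.float32)   # upcast before converting to NumPy
np_array = t_f32.numpy()           # t_bf16.numpy() would raise: NumPy has no bfloat16

round_trip = torch.from_numpy(np_array).to(torch.bfloat16)
assert torch.equal(round_trip, t_bf16)  # no precision lost in the round trip
```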
@Rocketknight1 Great! All we need is a test, then we're good to go 🚀
Force-pushed from b67d54c to c830f36.
@amyeroberts test added!
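For reference, a rough sketch of the kind of test described here (the model class and config sizes are illustrative, not the actual test added in this PR):

```python
import tempfile

import torch
from transformers import BertConfig, BertModel, TFBertModel

# Save a PyTorch checkpoint with bfloat16 weights, then check that the TF class
# can crossload it without hitting the NumPy bfloat16 conversion error.
config = BertConfig(
    hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64
)
pt_model = BertModel(config).to(torch.bfloat16)

with tempfile.TemporaryDirectory() as tmpdir:
    pt_model.save_pretrained(tmpdir)
    tf_model = TFBertModel.from_pretrained(tmpdir, from_pt=True)  # should not raise
```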
Thanks!
There is no direct conversion from Torch <-> TF, so passing tensors between the two requires us to go through NumPy. Unfortunately, NumPy doesn't support `bfloat16`. This patch fixes an issue when TF tries to load a PyTorch checkpoint whose weights have been stored in `bfloat16`, by upcasting the weights to `float32` before the NumPy conversion; they can be downcast later when they're assigned as the TF weights.
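A minimal sketch of the pattern the patch describes (the helper name and shapes below are made up for illustration; this is not the actual Transformers loading code):

```python
import numpy as np
import tensorflow as tf
import torch

def pt_tensor_to_numpy(t: torch.Tensor) -> np.ndarray:
    """Hypothetical helper: convert a PyTorch tensor to NumPy, upcasting bfloat16 first."""
    # NumPy has no bfloat16 dtype, so .numpy() raises on bfloat16 tensors.
    # The upcast to float32 is lossless, so no precision is lost here.
    if t.dtype == torch.bfloat16:
        t = t.to(torch.float32)
    return t.detach().cpu().numpy()

# Crossloading sketch: the TF weight already exists with its own dtype, and the
# converted value is assigned into it (here both are float32, so it's a straight copy).
pt_weight = torch.randn(2, 3, dtype=torch.bfloat16)
tf_weight = tf.Variable(tf.zeros((2, 3), dtype=tf.float32))
tf_weight.assign(pt_tensor_to_numpy(pt_weight))
```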