Best practice for mixed precision training #57
Comments
You can use fastai, which already provides mixed precision training out of the box. Also, I've already implemented CLIP training code in fastai, so you can take a look at it. You just need to add a few lines for loading the open-source CLIP weights for finetuning; my code is currently for training from scratch, but it should only be a couple of lines of change. The code is here, along with an example training script.
Thanks @KeremTurgutlu. So the basic idea is to convert the parameters to half precision and then use torch autocast?
You can read the materials at https://pytorch.org/docs/stable/amp.html.
We did the memory-intensive tensor operations (matmuls, convolutions) in fp16, while doing aggregation (batchnorms/layernorms) in fp32. We also kept the model weights and gradients in fp32 (which only takes constant memory w.r.t. the batch size) so that Adam only sees fp32 numbers. For a general introduction to mixed-precision training, I found this video tutorial and this document by NVIDIA particularly useful.
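To make the per-operation dtype handling concrete, here is a minimal sketch of the "aggregation in fp32" part, assuming the normalization layer's own parameters are kept in fp32 (the class name is illustrative; CLIP's clip/model.py uses a similar LayerNorm subclass):

```python
import torch
import torch.nn as nn


class FP32LayerNorm(nn.LayerNorm):
    """LayerNorm that computes its statistics in fp32, even for fp16 activations."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        # upcast the activations for the normalization, then cast the result back
        # (assumes this module's weight/bias were left in fp32)
        out = super().forward(x.to(torch.float32))
        return out.to(orig_dtype)
```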
@jongwook Thanks. So fp32 for the optimization step and fp16 for the forward/backward passes; switching back and forth works!
Hello, I ran into the same problem.
@cupcakefromchina Convert the parameters and grads to fp32 before applying Adam, then convert them back, inside your train function; see the sketch below.
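A minimal sketch of that pattern, assuming an fp16 CLIP model and a standard Adam optimizer (`convert_models_to_fp32` is an illustrative helper; `convert_weights` is the fp16-conversion function in clip/model.py):

```python
import torch
import clip


def convert_models_to_fp32(model: torch.nn.Module) -> None:
    # upcast parameters and their gradients so Adam only ever sees fp32 numbers
    for p in model.parameters():
        p.data = p.data.float()
        if p.grad is not None:
            p.grad.data = p.grad.data.float()


def train_step(model, optimizer, loss):
    # `loss` is assumed to come from a forward pass run with the fp16 model
    optimizer.zero_grad()
    loss.backward()

    convert_models_to_fp32(model)      # fp32 params + grads for the Adam update
    optimizer.step()
    clip.model.convert_weights(model)  # back to fp16 for the next forward pass
```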
Thank you for your solution. But there is still a question: which object does "clip" represent? I don't understand this function and can't find the API in the original code.
@cupcakefromchina here
@qingerVT I was trying to train on my own dataset too and was facing the exact same issue. Thanks for pointing out the details about mixed precision training, will definitely try it out!
Yes, I have successfully trained CLIP on my own data, and performance is close to the open-source model with just 5.5 million image-text pairs. I am working on wrapping up the code for both training from scratch and fine-tuning. In the meantime you can watch this branch; it should be complete in 1-2 days.
Yes, the loss is decreasing. The function convert_weights() is in clip/model.py; try importing it from there, as in the snippet below.
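For instance, a minimal usage sketch (the model name and device are placeholders):

```python
import clip

model, preprocess = clip.load("ViT-B/32", device="cuda", jit=False)
model.float()                      # e.g. after upcasting everything to fp32 for the update
clip.model.convert_weights(model)  # casts the applicable weights back to fp16
```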
Can I ask what the difference is between
vs
Ah, no difference :)
@jongwook @qingerVT Hi, I have a simple question. I understand the referenced solution for mixed precision training, and I know that converting to fp16 saves CUDA memory with a negligible loss of accuracy. However, what are the advantages of that method over the usual approaches, such as Apex and torch.cuda.amp, which also work well?
@nbl97 Not much, just that we wanted granular control over which operation runs in which dtype. Now that torch AMP is mature, I'd use it if I were to start a project from scratch.
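For anyone starting fresh, here is roughly what the standard torch.cuda.amp pattern looks like (`model`, `optimizer`, `loader`, and `compute_contrastive_loss` are placeholders; the model is assumed to take image/text batches and return logits, as CLIP's forward does):

```python
import torch

scaler = torch.cuda.amp.GradScaler()

for images, texts in loader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():  # runs matmuls/convs in fp16, keeps reductions in fp32
        logits_per_image, logits_per_text = model(images, texts)
        loss = compute_contrastive_loss(logits_per_image, logits_per_text)

    scaler.scale(loss).backward()    # scale the loss so fp16 gradients don't underflow
    scaler.step(optimizer)           # unscales gradients; skips the step on inf/nan
    scaler.update()
```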
@jongwook As we all know, converting fp32 to fp16 may reduce accuracy; however, what happens if we convert fp16 to fp32? When using the pretrained CLIP model on my own task, what is the performance if I always use fp32 and never convert the model? Do you have any experience with that? Thanks in advance.
I am trying to fine-tune CLIP models on new datasets. What's the best practice for mixed precision training?
Using Adam, I got errors with either nan or inf, since the eps attribute is hard to specify for both half and float32 parameters. My workaround is to divide the parameters into two groups and specify a different eps for each, as sketched below. Are there any better solutions?
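For reference, a rough sketch of that two-group workaround, assuming the split is done by parameter dtype and that `model` is the loaded CLIP model (the eps values are illustrative):

```python
import torch

fp16_params = [p for p in model.parameters() if p.dtype == torch.float16]
fp32_params = [p for p in model.parameters() if p.dtype == torch.float32]

# 1e-8 underflows to zero in fp16, so the fp16 group gets a larger eps
optimizer = torch.optim.Adam(
    [
        {"params": fp16_params, "eps": 1e-4},
        {"params": fp32_params, "eps": 1e-8},
    ],
    lr=1e-5,
)
```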