Accelerated training with floating point fp16 #7
The same thing is happening with me; I'm unable to use it with automatic mixed precision (PyTorch).
Hey, I just implemented AMP in this way and it seems to be working:
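For reference, the general two-backward-pass AMP pattern the replies refer to looks roughly like the sketch below (not the exact code from this comment). The toy setup (`nn.Linear`, `CrossEntropyLoss`, random tensors) is purely illustrative, and note that nothing here ever unscales the gradients, which later comments identify as the main problem.

```python
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast
from sam import SAM  # the SAM optimizer from this repository

device = "cuda"
model = nn.Linear(10, 2).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = SAM(model.parameters(), torch.optim.SGD, lr=0.1, momentum=0.9)
scaler = GradScaler()

inputs = torch.randn(32, 10, device=device)
targets = torch.randint(0, 2, (32,), device=device)

# First forward/backward pass: perturb the weights using the current gradients
with autocast():
    loss = criterion(model(inputs), targets)
scaler.scale(loss).backward()
optimizer.first_step(zero_grad=True)   # the gradients are still scaled at this point

# Second forward/backward pass at the perturbed weights "w + e(w)"
with autocast():
    loss = criterion(model(inputs), targets)
scaler.scale(loss).backward()
optimizer.second_step(zero_grad=True)  # the base optimizer steps on scaled gradients,
                                       # so the scale factor leaks into the effective learning rate
```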
Thanks for the reply!
As for the backward passes, I keep the 1st step the same as before and only use the scaled loss for the 2nd backward. It works, but I'm not sure whether it's the best solution. If I use the scaled loss for the 1st backward, a NaN loss always appears.
Is it also possible to initialize the full optimizer?
@alexriedel1 I've tried to initialize AMP with the optimizer, but it doesn't work.
@milliema Yes, that's absolutely explainable: SAM needs two backward passes through the network instead of one with simple SGD, so it should take roughly double the time to train.
@alexriedel1 Hello. I found your comment while looking for a way to apply AMP. Why did you do the backward pass using the mean of the loss? Is there any problem with this? Is it okay to not use the scaler?
I didn't fully implement the AMP method as proposed. I think using the scaler will be no problem. Reducing the loss to its mean just depends on your loss function; for example, PyTorch's BCELoss already uses mean reduction by default.
@alexriedel1 OK, I got it. I think the problem when applying AMP
Yes, the original solution does not unscale the gradients, which would lead to the scaling factor interfering with the learning rate.
In theory you should be able to run something similar by doing the following during training; however, since you're not calling the scaler's own step, you lose its automatic skipping of updates with inf/NaN gradients.
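Reading between the lines, the suggestion is to unscale the gradients before each SAM step and call `update()` in between so the scaler's bookkeeping is reset. A minimal sketch of that idea, reusing the toy setup from the snippet above (an interpretation, not verbatim from any linked example):

```python
scaler = GradScaler()

# First pass: backward on the scaled loss, then unscale before perturbing the weights
with autocast():
    loss = criterion(model(inputs), targets)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)             # p.grad now holds true-scale gradients
optimizer.first_step(zero_grad=True)   # w -> w + e(w)
scaler.update()                        # also resets the scaler's per-optimizer inf bookkeeping

# Second pass at the perturbed weights
with autocast():
    loss = criterion(model(inputs), targets)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
optimizer.second_step(zero_grad=True)  # restore w and take the base optimizer step
scaler.update()

# Caveat: because scaler.step(optimizer) is never called, iterations with inf/NaN
# gradients are not skipped automatically; the longer solution further down handles
# this by inspecting found_inf_per_device manually.
```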
Here are my two cents on this issue. TL;DR: use the following code, and be ready to fall back to the regular single-step optimization whenever inf/NaN gradients are detected.

```python
# Patched SAM steps that can optionally run under autocast.
# `do_nothing_context_mgr` is a no-op context manager (e.g. contextlib.nullcontext).
@torch.no_grad()
def first_step(self, zero_grad=False, mixed_precision=False):
    with autocast() if mixed_precision else do_nothing_context_mgr():
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)
            for p in group["params"]:
                if p.grad is None:
                    continue
                self.state[p]["old_p"] = p.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"
    if zero_grad:
        self.zero_grad()


@torch.no_grad()
def second_step(self, zero_grad=False, mixed_precision=False):
    with autocast() if mixed_precision else do_nothing_context_mgr():
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.data = self.state[p]["old_p"]  # get back to "w" from "w + e(w)"


@torch.no_grad()
def step(self, closure=None):
    self.base_optimizer.step(closure)
```

Using the PyTorch MNIST tutorial as a starting point, the proposed solution goes as follows:
```python
def train(
    args, model, device, train_loader, optimizer, first_step_scaler, second_step_scaler, epoch
):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        enable_running_stats(model)

        # First forward pass
        with autocast():
            output = model(data)
            loss = F.nll_loss(output, target)
        first_step_scaler.scale(loss).backward()

        # We unscale manually for two reasons: (1) SAM's first step adds the gradient
        # to the weights directly, so the gradient must be unscaled; (2) unscale_ checks
        # whether any gradient is inf and updates optimizer_state["found_inf_per_device"]
        # accordingly. We use optimizer_state["found_inf_per_device"] to decide whether
        # to apply SAM's first step or not.
        first_step_scaler.unscale_(optimizer)
        optimizer_state = first_step_scaler._per_optimizer_states[id(optimizer)]

        # Check whether any gradients are inf/NaN
        inf_grad_cnt = sum(v.item() for v in optimizer_state["found_inf_per_device"].values())

        if inf_grad_cnt == 0:
            # If the gradients are valid, apply SAM's first step
            optimizer.first_step(zero_grad=True, mixed_precision=True)
            sam_first_step_applied = True
        else:
            # If any gradient is invalid, skip SAM and revert to a single optimization step
            optimizer.zero_grad()
            sam_first_step_applied = False

        # Update the scaler with no impact on the model (weights or gradients). This update
        # resets optimizer_state["found_inf_per_device"], so it is applied after computing
        # inf_grad_cnt. Note that zero_grad() has no impact on update(), because update()
        # relies on optimizer_state["found_inf_per_device"].
        first_step_scaler.update()

        disable_running_stats(model)

        # Second forward pass
        with autocast():
            output = model(data)
            loss = F.nll_loss(output, target)
        second_step_scaler.scale(loss).backward()

        if sam_first_step_applied:
            # If SAM's first step was applied, apply the second step (restore the weights)
            optimizer.second_step(mixed_precision=True)

        second_step_scaler.step(optimizer)
        second_step_scaler.update()
```

where

```python
base_optimizer = torch.optim.SGD  # define an optimizer for the "sharpness-aware" update
optimizer = SAM(model.parameters(), base_optimizer, lr=0.1, momentum=0.9)
first_step_scaler = GradScaler(2 ** 8)   # a small initial scale acts as a warmup
second_step_scaler = GradScaler(2 ** 8)  # a small initial scale acts as a warmup
```

How is this tested?
What is the main catch?
Why does the network enter this state in the first place?
One thing I don't like about the proposed solution is that it is verbose; I wish someone would propose a more concise one. I found the following resources helpful while investigating this issue.
Can it work well?
@ahmdtaha
@alibalapour @ahmdtaha have you found out how to implement gradient accumulation in the code?
Thanks for the work!
I'd like to know whether the original code is also applicable to accelerated training, i.e. using automatic mixed precision such as fp16. I tried to adopt SAM in my own training code with Apex fp16, but a NaN loss occurs and the computed gradient norm is NaN. When I switch to fp32, everything goes well. Is SAM incompatible with fp16? What would you suggest to make the code work with fp16? Thanks!