
Commit

reset scaler (#1999)
mvpatel2000 authored Feb 24, 2023
1 parent 850e34d commit 0926cbf
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion composer/trainer/trainer.py
@@ -15,6 +15,7 @@
 import re
 import time
 import warnings
+from collections import defaultdict
 from copy import deepcopy
 from pathlib import Path
 from typing import Any, Callable, ContextManager, Dict, Iterable, List, Optional, Sequence, TextIO, Tuple, Union, cast
@@ -24,7 +25,7 @@
 import torch.distributed
 import torch.nn as nn
 import torch.utils.data
-from torch.cuda.amp.grad_scaler import GradScaler
+from torch.cuda.amp.grad_scaler import GradScaler, _refresh_per_optimizer_state
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DataLoader, DistributedSampler
 from torchmetrics import Metric
@@ -257,6 +258,8 @@ def _adjust_grad_accum(state: State, device_batch_size: int):
     del state.loss
     for optimizer in state.optimizers:
         optimizer.zero_grad(set_to_none=True)
+    if state.scaler is not None:
+        state.scaler._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
     torch.cuda.empty_cache()


@@ -285,6 +288,8 @@ def _adjust_device_train_microbatch_size(state: State):
     del state.loss
     for optimizer in state.optimizers:
         optimizer.zero_grad(set_to_none=True)
+    if state.scaler is not None:
+        state.scaler._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
     torch.cuda.empty_cache()
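
For context, a minimal sketch of the pattern both hunks apply (the helper name reset_after_aborted_step is hypothetical, not Composer's API): when a train step is aborted on CUDA OOM and retried with a smaller microbatch, partial gradients are discarded and the GradScaler's per-optimizer state is reset, so the retry does not inherit stale unscale/step bookkeeping from the aborted step.

    # Hypothetical helper mirroring the lines this commit adds to
    # _adjust_grad_accum and _adjust_device_train_microbatch_size.
    from collections import defaultdict
    from typing import Iterable, Optional

    import torch
    from torch.cuda.amp.grad_scaler import GradScaler, _refresh_per_optimizer_state


    def reset_after_aborted_step(optimizers: Iterable[torch.optim.Optimizer],
                                 scaler: Optional[GradScaler]) -> None:
        for optimizer in optimizers:
            # Free gradient memory from the aborted step rather than keeping zeroed tensors.
            optimizer.zero_grad(set_to_none=True)
        if scaler is not None:
            # GradScaler tracks each optimizer's stage (READY -> UNSCALED -> STEPPED) plus
            # any found_inf results. Rebuilding this defaultdict mirrors what
            # GradScaler.update() does at the end of a normal step, returning every
            # optimizer to READY so the retried step starts from a clean slate.
            scaler._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
        # Release cached allocator blocks so the retry can actually use the freed memory.
        torch.cuda.empty_cache()

_per_optimizer_states and _refresh_per_optimizer_state are private PyTorch internals; they are touched directly here presumably because GradScaler offers no public way to discard an in-flight step's state.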

