Skip to content

Commit

Permalink
🔧 [Update] EMA with sync state_dict
Browse files Browse the repository at this point in the history
  • Loading branch information
henrytsui000 committed Nov 22, 2024
1 parent 89a6526 commit 1d404e2
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions yolo/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,31 +37,31 @@ def lerp(start: float, end: float, step: Union[int, float], total: int = 1):


class EMA(Callback):
    """Callback that keeps an exponential moving average (EMA) of the model's state_dict.

    The averaged tensors are stored in ``self.ema_state_dict`` (keyed like
    ``pl_module.model.state_dict()``) and are loaded into ``pl_module.ema``, a
    deep copy of ``pl_module.model``, at the start of every validation run.

    Args:
        decay: Upper bound of the per-step blending factor.
        tau: Warm-up time constant. The effective factor at step ``s`` is
            ``decay * (1 - exp(-s / tau))``, so it ramps from 0 towards ``decay``.
            The value is divided by ``trainer.world_size`` in ``setup``.
    """

    def __init__(self, decay: float = 0.9999, tau: float = 2000):
        super().__init__()
        logger.info(":chart_with_upwards_trend: Enable Model EMA")
        self.decay = decay
        # Keep the configured value so that ``setup`` can derive the effective
        # ``tau`` from it every time instead of rescaling ``self.tau`` in place.
        self._base_tau = tau
        self.tau = tau
        self.step = 0
        # Filled lazily with a deep copy of the model's state_dict.
        self.ema_state_dict = None

    def setup(self, trainer, pl_module, stage):
        """Attach a frozen copy of the model as ``pl_module.ema`` and scale ``tau``."""
        pl_module.ema = deepcopy(pl_module.model)
        # Derived from the configured value, so calling ``setup`` more than once
        # (e.g. once per stage) does not divide ``tau`` repeatedly.
        self.tau = self._base_tau / trainer.world_size
        for param in pl_module.ema.parameters():
            param.requires_grad = False

    def on_validation_start(self, trainer: "Trainer", pl_module: "LightningModule"):
        """Load the current averaged weights into ``pl_module.ema``."""
        if self.ema_state_dict is None:
            # No training step has produced an average yet: start from the
            # model's current weights.
            self.ema_state_dict = deepcopy(pl_module.model.state_dict())
        pl_module.ema.load_state_dict(self.ema_state_dict)

    # NOTE(review): presumably ``rank_zero_only`` restricts this hook to global
    # rank 0, in which case ``ema_state_dict`` on other ranks is never updated
    # here -- confirm that validation on those ranks does not rely on it.
    @rank_zero_only
    @no_grad()
    def on_train_batch_end(self, trainer: "Trainer", pl_module: "LightningModule", *args, **kwargs) -> None:
        """Blend the model's current state_dict into the running average."""
        self.step += 1
        decay_factor = self.decay * (1 - exp(-self.step / self.tau))
        model_state = pl_module.model.state_dict()
        if self.ema_state_dict is None:
            # Training started without a preceding validation pass; seed the
            # average here instead of failing on ``None[key]`` below.
            self.ema_state_dict = deepcopy(model_state)
        for key, param in model_state.items():
            # ``lerp`` is defined elsewhere in this file; assumed to return
            # ``start + (end - start) * step / total`` -- TODO confirm.
            self.ema_state_dict[key] = lerp(param.detach(), self.ema_state_dict[key], decay_factor)


def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
Expand Down

0 comments on commit 1d404e2

Please sign in to comment.