From 1d404e285dc85458c1e4590dc67206ca6ca1ec33 Mon Sep 17 00:00:00 2001
From: henrytsui000
Date: Sat, 23 Nov 2024 02:27:45 +0800
Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A7=20[Update]=20EMA=20with=20sync=20s?=
 =?UTF-8?q?tate=5Fdict?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 yolo/utils/model_utils.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/yolo/utils/model_utils.py b/yolo/utils/model_utils.py
index 3c79915..0c85917 100644
--- a/yolo/utils/model_utils.py
+++ b/yolo/utils/model_utils.py
@@ -37,31 +37,31 @@ def lerp(start: float, end: float, step: Union[int, float], total: int = 1):
 
 
 class EMA(Callback):
-    def __init__(self, decay: float = 0.9999, tau: float = 500):
+    def __init__(self, decay: float = 0.9999, tau: float = 2000):
         super().__init__()
         logger.info(":chart_with_upwards_trend: Enable Model EMA")
         self.decay = decay
         self.tau = tau
         self.step = 0
+        self.ema_state_dict = None
 
     def setup(self, trainer, pl_module, stage):
         pl_module.ema = deepcopy(pl_module.model)
-        self.ema_parameters = [param.clone().detach().to(pl_module.device) for param in pl_module.parameters()]
+        self.tau /= trainer.world_size
         for param in pl_module.ema.parameters():
             param.requires_grad = False
 
     def on_validation_start(self, trainer: "Trainer", pl_module: "LightningModule"):
-        for param, ema_param in zip(pl_module.ema.parameters(), self.ema_parameters):
-            param.data.copy_(ema_param)
-            trainer.strategy.broadcast(param)
+        if self.ema_state_dict is None:
+            self.ema_state_dict = deepcopy(pl_module.model.state_dict())
+        pl_module.ema.load_state_dict(self.ema_state_dict)
 
-    @rank_zero_only
     @no_grad()
     def on_train_batch_end(self, trainer: "Trainer", pl_module: "LightningModule", *args, **kwargs) -> None:
         self.step += 1
         decay_factor = self.decay * (1 - exp(-self.step / self.tau))
-        for param, ema_param in zip(pl_module.parameters(), self.ema_parameters):
-            ema_param.data.copy_(lerp(param.detach(), ema_param, decay_factor))
+        for key, param in pl_module.model.state_dict().items():
+            self.ema_state_dict[key] = lerp(param.detach(), self.ema_state_dict[key], decay_factor)
 
 
 def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
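
Below is a minimal, self-contained sketch (not part of the patch) of the state_dict-based EMA rule the diff introduces: a full ema_state_dict is kept and blended with the live weights using decay_factor = decay * (1 - exp(-step / tau)), so early steps track the model closely and later steps average more heavily. The toy nn.Linear model, the synthetic weight perturbation, and the use of torch.lerp in place of the repository's own lerp helper are assumptions for illustration only.

# sketch_ema_state_dict.py -- illustrative only; mirrors the patched EMA
# callback's update rule outside of Lightning. The toy model and the fake
# optimizer step below are hypothetical, not part of the repository.
from copy import deepcopy
from math import exp

import torch
from torch import nn

decay, tau = 0.9999, 2000          # defaults from the patched __init__
model = nn.Linear(4, 2)            # stand-in for pl_module.model
ema_state_dict = deepcopy(model.state_dict())

for step in range(1, 101):         # pretend training steps
    with torch.no_grad():          # fake an optimizer update
        for p in model.parameters():
            p.add_(0.01 * torch.randn_like(p))
    # warm-up: decay_factor starts near 0 and approaches `decay` as step >> tau
    decay_factor = decay * (1 - exp(-step / tau))
    for key, param in model.state_dict().items():
        # torch.lerp(a, b, w) = a + w * (b - a): keep `decay_factor` of the EMA value
        ema_state_dict[key] = torch.lerp(param.detach(), ema_state_dict[key], decay_factor)

# at validation time the patch loads the blended weights into the EMA copy
ema_model = deepcopy(model)
ema_model.load_state_dict(ema_state_dict)

Because the patch drops @rank_zero_only and the broadcast, each rank now applies this update to its own copy; presumably the intent is that, with DDP keeping model weights identical across ranks after each step, the per-rank EMA state dicts stay in sync without explicit communication.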