From 89a65260abc77e238f61b1e227564c1da98e84a7 Mon Sep 17 00:00:00 2001
From: henrytsui000
Date: Thu, 21 Nov 2024 15:18:47 +0800
Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=91=EF=B8=8F=20[Fix]=20broadcast=20of?=
 =?UTF-8?q?=20EMA=20and=20sync=5Fdist=20only=20in=20val?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 yolo/tools/solver.py      | 9 +++++----
 yolo/utils/model_utils.py | 3 +--
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/yolo/tools/solver.py b/yolo/tools/solver.py
index 8b7e056..a867c76 100644
--- a/yolo/tools/solver.py
+++ b/yolo/tools/solver.py
@@ -63,9 +63,11 @@ def validation_step(self, batch, batch_idx):
     def on_validation_epoch_end(self):
         epoch_metrics = self.metric.compute()
         del epoch_metrics["classes"]
-        self.log_dict(epoch_metrics, prog_bar=True, rank_zero_only=True)
+        self.log_dict(epoch_metrics, prog_bar=True, sync_dist=True, rank_zero_only=True)
         self.log_dict(
-            {"PyCOCO/AP @ .5:.95": epoch_metrics["map"], "PyCOCO/AP @ .5": epoch_metrics["map_50"]}, rank_zero_only=True
+            {"PyCOCO/AP @ .5:.95": epoch_metrics["map"], "PyCOCO/AP @ .5": epoch_metrics["map_50"]},
+            sync_dist=True,
+            rank_zero_only=True,
         )
         self.metric.reset()

@@ -101,10 +103,9 @@ def training_step(self, batch, batch_idx):
             prog_bar=True,
             on_epoch=True,
             batch_size=batch_size,
-            sync_dist=True,
             rank_zero_only=True,
         )
-        self.log_dict(lr_dict, prog_bar=False, logger=True, on_epoch=False, sync_dist=True, rank_zero_only=True)
+        self.log_dict(lr_dict, prog_bar=False, logger=True, on_epoch=False, rank_zero_only=True)
         return loss * batch_size

     def configure_optimizers(self):
diff --git a/yolo/utils/model_utils.py b/yolo/utils/model_utils.py
index 17ad606..3c79915 100644
--- a/yolo/utils/model_utils.py
+++ b/yolo/utils/model_utils.py
@@ -53,8 +53,7 @@ def setup(self, trainer, pl_module, stage):
     def on_validation_start(self, trainer: "Trainer", pl_module: "LightningModule"):
         for param, ema_param in zip(pl_module.ema.parameters(), self.ema_parameters):
             param.data.copy_(ema_param)
-            if dist.is_initialized():
-                dist.broadcast(param, src=0)
+            trainer.strategy.broadcast(param)

     @rank_zero_only
     @no_grad()