From 16903540520cc7a404b5a24f3ccf7d7e144878af Mon Sep 17 00:00:00 2001 From: henrytsui000 Date: Thu, 21 Nov 2024 14:47:47 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=B8=20[Add]=20try-except=20in=20loadin?= =?UTF-8?q?g=20cache=20files?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- yolo/config/task/validation.yaml | 2 +- yolo/tools/data_loader.py | 10 +++++++++- yolo/tools/solver.py | 4 ++-- yolo/utils/dataset_utils.py | 9 ++++++++- yolo/utils/model_utils.py | 2 ++ 5 files changed, 22 insertions(+), 5 deletions(-) diff --git a/yolo/config/task/validation.yaml b/yolo/config/task/validation.yaml index bb30d3d..80f50dd 100644 --- a/yolo/config/task/validation.yaml +++ b/yolo/config/task/validation.yaml @@ -7,7 +7,7 @@ data: shuffle: False pin_memory: True data_augment: {} - dynamic_shape: True + dynamic_shape: False nms: min_confidence: 0.0001 min_iou: 0.7 diff --git a/yolo/tools/data_loader.py b/yolo/tools/data_loader.py index cf0782b..568f50a 100644 --- a/yolo/tools/data_loader.py +++ b/yolo/tools/data_loader.py @@ -56,7 +56,15 @@ def load_data(self, dataset_path: Path, phase_name: str): data = self.filter_data(dataset_path, phase_name, self.dynamic_shape) torch.save(data, cache_path) else: - data = torch.load(cache_path, weights_only=False) + try: + data = torch.load(cache_path, weights_only=False) + except Exception as e: + logger.error( + f":rotating_light: Failed to load the cache at '{cache_path}'.\n" + ":rotating_light: This may be caused by using cache from different other YOLO.\n" + ":rotating_light: Please clean the cache and try running again." + ) + raise e logger.info(f":package: Loaded {phase_name} cache") return data diff --git a/yolo/tools/solver.py b/yolo/tools/solver.py index b8c777f..8b7e056 100644 --- a/yolo/tools/solver.py +++ b/yolo/tools/solver.py @@ -56,7 +56,6 @@ def validation_step(self, batch, batch_idx): "map": batch_metrics["map"], "map_50": batch_metrics["map_50"], }, - on_step=True, batch_size=batch_size, ) return predicts @@ -102,9 +101,10 @@ def training_step(self, batch, batch_idx): prog_bar=True, on_epoch=True, batch_size=batch_size, + sync_dist=True, rank_zero_only=True, ) - self.log_dict(lr_dict, prog_bar=False, logger=True, on_epoch=False, rank_zero_only=True) + self.log_dict(lr_dict, prog_bar=False, logger=True, on_epoch=False, sync_dist=True, rank_zero_only=True) return loss * batch_size def configure_optimizers(self): diff --git a/yolo/utils/dataset_utils.py b/yolo/utils/dataset_utils.py index f659090..dad7a78 100644 --- a/yolo/utils/dataset_utils.py +++ b/yolo/utils/dataset_utils.py @@ -115,7 +115,14 @@ def scale_segmentation( def tensorlize(data): - img_paths, bboxes, img_ratios = zip(*data) + try: + img_paths, bboxes, img_ratios = zip(*data) + except ValueError as e: + logger.error( + ":rotating_light: This may be caused by using old cache or another version of YOLO's cache.\n" + ":rotating_light: Please clean the cache and try running again." + ) + raise e max_box = max(bbox.size(0) for bbox in bboxes) padded_bbox_list = [] for bbox in bboxes: diff --git a/yolo/utils/model_utils.py b/yolo/utils/model_utils.py index 1b4b4c8..17ad606 100644 --- a/yolo/utils/model_utils.py +++ b/yolo/utils/model_utils.py @@ -47,6 +47,8 @@ def __init__(self, decay: float = 0.9999, tau: float = 500): def setup(self, trainer, pl_module, stage): pl_module.ema = deepcopy(pl_module.model) self.ema_parameters = [param.clone().detach().to(pl_module.device) for param in pl_module.parameters()] + for param in pl_module.ema.parameters(): + param.requires_grad = False def on_validation_start(self, trainer: "Trainer", pl_module: "LightningModule"): for param, ema_param in zip(pl_module.ema.parameters(), self.ema_parameters):