From 818297aac40ef6ee7a57152dd2405e3a2a274616 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Kun=C3=A1k?= <38215643+Adamusen@users.noreply.github.com> Date: Tue, 7 Jan 2025 14:52:42 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=A9=B9=20[Fix]=20load=20new=20checkpoints?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Modified save_load_weights function to properly load weights from both old .pt and new .ckpt files. --- yolo/model/yolo.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/yolo/model/yolo.py b/yolo/model/yolo.py index cc9ce20..8d463c1 100644 --- a/yolo/model/yolo.py +++ b/yolo/model/yolo.py @@ -127,29 +127,35 @@ def save_load_weights(self, weights: Union[Path, OrderedDict]): """ if isinstance(weights, Path): weights = torch.load(weights, map_location=torch.device("cpu"), weights_only=False) - if "model_state_dict" in weights: + if "model_state_dict" in weights: #.pt weights = weights["model_state_dict"] + elif "state_dict" in weights: #.ckpt + weights = weights["state_dict"] - model_state_dict = self.model.state_dict() + model_state_dict = self.state_dict() - # TODO1: autoload old version weight # TODO2: weight transform if num_class difference error_dict = {"Mismatch": set(), "Not Found": set()} for model_key, model_weight in model_state_dict.items(): - if model_key not in weights: + weights_key = model_key + if weights_key not in weights: #.ckpt + weights_key = "model." + model_key + if weights_key not in weights: #.pt old + weights_key = model_key.removeprefix("model.") + if weights_key not in weights: error_dict["Not Found"].add(tuple(model_key.split(".")[:-2])) continue - if model_weight.shape != weights[model_key].shape: + if model_weight.shape != weights[weights_key].shape: error_dict["Mismatch"].add(tuple(model_key.split(".")[:-2])) continue - model_state_dict[model_key] = weights[model_key] + model_state_dict[model_key] = weights[weights_key] for error_name, error_set in error_dict.items(): for weight_name in error_set: logger.warning(f":warning: Weight {error_name} for key: {'.'.join(weight_name)}") - self.model.load_state_dict(model_state_dict) + self.load_state_dict(model_state_dict) def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, class_num: int = 80) -> YOLO: