diff --git a/yolo/model/yolo.py b/yolo/model/yolo.py index cc9ce20..8d463c1 100644 --- a/yolo/model/yolo.py +++ b/yolo/model/yolo.py @@ -127,29 +127,35 @@ def save_load_weights(self, weights: Union[Path, OrderedDict]): """ if isinstance(weights, Path): weights = torch.load(weights, map_location=torch.device("cpu"), weights_only=False) - if "model_state_dict" in weights: + if "model_state_dict" in weights: #.pt weights = weights["model_state_dict"] + elif "state_dict" in weights: #.ckpt + weights = weights["state_dict"] - model_state_dict = self.model.state_dict() + model_state_dict = self.state_dict() - # TODO1: autoload old version weight # TODO2: weight transform if num_class difference error_dict = {"Mismatch": set(), "Not Found": set()} for model_key, model_weight in model_state_dict.items(): - if model_key not in weights: + weights_key = model_key + if weights_key not in weights: #.ckpt + weights_key = "model." + model_key + if weights_key not in weights: #.pt old + weights_key = model_key.removeprefix("model.") + if weights_key not in weights: error_dict["Not Found"].add(tuple(model_key.split(".")[:-2])) continue - if model_weight.shape != weights[model_key].shape: + if model_weight.shape != weights[weights_key].shape: error_dict["Mismatch"].add(tuple(model_key.split(".")[:-2])) continue - model_state_dict[model_key] = weights[model_key] + model_state_dict[model_key] = weights[weights_key] for error_name, error_set in error_dict.items(): for weight_name in error_set: logger.warning(f":warning: Weight {error_name} for key: {'.'.join(weight_name)}") - self.model.load_state_dict(model_state_dict) + self.load_state_dict(model_state_dict) def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, class_num: int = 80) -> YOLO: