Skip to content

Commit

Permalink
🩹 [Fix] load new checkpoints
Browse files Browse the repository at this point in the history
Modified save_load_weights function to properly load weights from both old .pt and new .ckpt files.
  • Loading branch information
Adamusen authored Jan 7, 2025
1 parent fa548df commit 818297a
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions yolo/model/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,29 +127,35 @@ def save_load_weights(self, weights: Union[Path, OrderedDict]):
"""
if isinstance(weights, Path):
weights = torch.load(weights, map_location=torch.device("cpu"), weights_only=False)
if "model_state_dict" in weights:
if "model_state_dict" in weights: #.pt
weights = weights["model_state_dict"]
elif "state_dict" in weights: #.ckpt
weights = weights["state_dict"]

model_state_dict = self.model.state_dict()
model_state_dict = self.state_dict()

# TODO1: autoload old version weight
# TODO2: weight transform if num_class difference

error_dict = {"Mismatch": set(), "Not Found": set()}
for model_key, model_weight in model_state_dict.items():
if model_key not in weights:
weights_key = model_key
if weights_key not in weights: #.ckpt
weights_key = "model." + model_key
if weights_key not in weights: #.pt old
weights_key = model_key.removeprefix("model.")
if weights_key not in weights:
error_dict["Not Found"].add(tuple(model_key.split(".")[:-2]))
continue
if model_weight.shape != weights[model_key].shape:
if model_weight.shape != weights[weights_key].shape:
error_dict["Mismatch"].add(tuple(model_key.split(".")[:-2]))
continue
model_state_dict[model_key] = weights[model_key]
model_state_dict[model_key] = weights[weights_key]

for error_name, error_set in error_dict.items():
for weight_name in error_set:
logger.warning(f":warning: Weight {error_name} for key: {'.'.join(weight_name)}")

self.model.load_state_dict(model_state_dict)
self.load_state_dict(model_state_dict)


def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, class_num: int = 80) -> YOLO:
Expand Down

0 comments on commit 818297a

Please sign in to comment.