Enable smart weight loading #2758

Merged
11 commits merged on Jan 8, 2024
1 change: 1 addition & 0 deletions src/otx/core/config/__init__.py
@@ -39,6 +39,7 @@ class TrainConfig:
debug: Optional[str]
train: bool
test: bool
resume: bool = False

seed: Optional[int] = None
checkpoint: Optional[str] = None
23 changes: 20 additions & 3 deletions src/otx/core/data/module.py
@@ -5,9 +5,11 @@
from __future__ import annotations

import logging as log
from dataclasses import dataclass
from typing import TYPE_CHECKING

from datumaro import Dataset as DmDataset
from datumaro.components.annotation import AnnotationType
from lightning import LightningDataModule
from torch.utils.data import DataLoader

@@ -24,6 +26,21 @@
from .dataset.base import OTXDataset


@dataclass
class DataMetaInfo:
"""Meta information of OTXDataModule.

This meta information will be used by OTXLitModule.
"""

class_names: list[str]

@property
def num_classes(self) -> int:
"""Return number of classes."""
return len(self.class_names)


class OTXDataModule(LightningDataModule):
"""LightningDataModule extension for OTX pipeline."""

@@ -46,9 +63,9 @@ def __init__(

VIDEO_EXTENSIONS.append(".mp4")

-        dataset = DmDataset.import_from(
-            self.config.data_root,
-            format=self.config.data_format,
+        dataset = DmDataset.import_from(self.config.data_root, format=self.config.data_format)
+        self.meta_info = DataMetaInfo(
+            class_names=[category.name for category in dataset.categories()[AnnotationType.label]],
         )

config_mapping = {
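For reference, a minimal sketch (not part of the diff) of how the new DataMetaInfo dataclass behaves once the datamodule fills it with the dataset's label names:

    from otx.core.data.module import DataMetaInfo

    meta_info = DataMetaInfo(class_names=["car", "bus", "truck"])
    print(meta_info.num_classes)  # 3, derived from class_names

The datamodule builds this from the datumaro label categories, and OTXLitModule later consumes it for checkpoint class matching.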
2 changes: 2 additions & 0 deletions src/otx/core/engine/test.py
@@ -25,9 +25,11 @@ def test(cfg: TrainConfig) -> tuple[Trainer, dict[str, Any]]:

log.info(f"Instantiating datamodule <{cfg.data}>")
datamodule = OTXDataModule(task=cfg.base.task, config=cfg.data)
data_meta_info = datamodule.meta_info

log.info(f"Instantiating model <{cfg.model}>")
model: LightningModule = hydra.utils.instantiate(cfg.model)
model.meta_info = data_meta_info

log.info(f"Instantiating trainer <{cfg.trainer}>")
trainer: Trainer = hydra.utils.instantiate(cfg.trainer)
13 changes: 12 additions & 1 deletion src/otx/core/engine/train.py
@@ -54,6 +54,7 @@ def train(
Returns:
A tuple with Pytorch Lightning Trainer and Python dict of metrics
"""
import torch
from lightning import seed_everything

from otx.core.data.module import OTXDataModule
@@ -69,9 +70,11 @@

log.info(f"Instantiating datamodule <{cfg.data}>")
datamodule = OTXDataModule(task=cfg.base.task, config=cfg.data)
data_meta_info = datamodule.meta_info

log.info(f"Instantiating model <{cfg.model}>")
model: OTXLitModule = hydra.utils.instantiate(cfg.model)
model.meta_info = data_meta_info

if otx_model is not None:
if not isinstance(otx_model, OTXModel):
@@ -108,7 +111,15 @@

if cfg.train:
log.info("Starting training!")
-        trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.checkpoint)
+        if cfg.resume:
+            trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.checkpoint)
+        else:
+            # load weight
+            if cfg.checkpoint is not None:
+                loaded_checkpoint = torch.load(cfg.checkpoint)
+                model.load_state_dict(loaded_checkpoint["state_dict"])
+            # train
+            trainer.fit(model=model, datamodule=datamodule)

train_metrics = trainer.callback_metrics

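To make the branch above concrete, here is a small sketch of the two modes (assumed names: model, datamodule, trainer, and a placeholder checkpoint path "prev.ckpt"; not part of the diff):

    import torch

    # resume=True: Lightning restores weights plus optimizer/epoch state from the checkpoint.
    trainer.fit(model=model, datamodule=datamodule, ckpt_path="prev.ckpt")

    # resume=False: only model weights are loaded up front; OTXLitModule.load_state_dict compares
    # the checkpoint's meta_info with the current one and, on mismatch, registers the remapping hook.
    state = torch.load("prev.ckpt")
    model.load_state_dict(state["state_dict"])
    trainer.fit(model=model, datamodule=datamodule)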
51 changes: 50 additions & 1 deletion src/otx/core/model/entity/base.py
@@ -6,7 +6,7 @@
from __future__ import annotations

from abc import abstractmethod
-from typing import Any, Generic
+from typing import TYPE_CHECKING, Any, Generic

from torch import nn

@@ -16,6 +16,9 @@
T_OTXBatchPredEntity,
)

if TYPE_CHECKING:
import torch


class OTXModel(nn.Module, Generic[T_OTXBatchDataEntity, T_OTXBatchPredEntity]):
"""Base class for the models used in OTX."""
@@ -57,3 +60,49 @@ def forward(
if self._customize_outputs != OTXModel._customize_outputs
else outputs
)

def register_load_state_dict_pre_hook(self, model_classes: list[str], ckpt_classes: list[str]) -> None:
"""Register load_state_dict_pre_hook.

Args:
model_classes (list[str]): Class names from training data.
ckpt_classes (list[str]): Class names from checkpoint state dictionary.
"""
self.model_classes = model_classes
self.ckpt_classes = ckpt_classes
self._register_load_state_dict_pre_hook(self.load_state_dict_pre_hook)

def load_state_dict_pre_hook(self, state_dict: dict[str, torch.Tensor], prefix: str, *args, **kwargs) -> None:
"""Modify input state_dict according to class name matching before weight loading."""
model2ckpt = self.map_class_names(self.model_classes, self.ckpt_classes)

for param_name in self.state_dict():
model_param = self.state_dict()[param_name].clone()
ckpt_param = state_dict[prefix + param_name]
if model_param.shape != ckpt_param.shape:
for model_t, ckpt_t in enumerate(model2ckpt):
if ckpt_t >= 0:
model_param[model_t].copy_(ckpt_param[ckpt_t])

# Replace checkpoint weight by mixed weights
state_dict[prefix + param_name] = model_param

@staticmethod
def map_class_names(src_classes: list[str], dst_classes: list[str]) -> list[int]:
"""Computes src to dst index mapping.

src2dst[src_idx] = dst_idx
# according to class name matching, -1 for non-matched ones
assert(len(src2dst) == len(src_classes))
ex)
src_classes = ['person', 'car', 'tree']
dst_classes = ['tree', 'person', 'sky', 'ball']
-> Returns src2dst = [1, -1, 0]
"""
src2dst = []
for src_class in src_classes:
if src_class in dst_classes:
src2dst.append(dst_classes.index(src_class))
else:
src2dst.append(-1)
return src2dst
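A quick worked check of map_class_names, mirroring the example in its docstring (the class names are illustrative only):

    from otx.core.model.entity.base import OTXModel

    src = ["person", "car", "tree"]
    dst = ["tree", "person", "sky", "ball"]
    assert OTXModel.map_class_names(src, dst) == [1, -1, 0]  # "car" has no match, hence -1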
54 changes: 54 additions & 0 deletions src/otx/core/model/module/base.py
@@ -4,13 +4,15 @@
"""Class definition for base lightning module used in OTX."""
from __future__ import annotations

import logging
from typing import Any

import torch
from lightning import LightningModule
from torch import Tensor

from otx.core.data.entity.base import OTXBatchDataEntity
from otx.core.data.module import DataMetaInfo
from otx.core.model.entity.base import OTXModel


@@ -30,6 +32,7 @@ def __init__(
self.optimizer = optimizer
self.scheduler = scheduler
self.torch_compile = torch_compile
self._meta_info: DataMetaInfo | None = None

# this line allows to access init params with 'self.hparams' attribute
# also ensures init params will be stored in ckpt
@@ -105,7 +108,58 @@ def configure_optimizers(self) -> dict[str, Any]:
}
return {"optimizer": optimizer}

def register_load_state_dict_pre_hook(self, model_classes: list[str], ckpt_classes: list[str]) -> None:
"""Register self.model's load_state_dict_pre_hook.

Args:
model_classes (list[str]): Class names from training data.
ckpt_classes (list[str]): Class names from checkpoint state dictionary.
"""
self.model.register_load_state_dict_pre_hook(model_classes, ckpt_classes)

def state_dict(self) -> dict[str, Any]:
"""Return state dictionary of model entity with meta information.

Returns:
A dictionary containing the model state together with its meta information.

"""
state_dict = super().state_dict()
state_dict["meta_info"] = self.meta_info
return state_dict

def load_state_dict(self, state_dict: dict[str, Any]) -> None:
"""Load state dictionary from checkpoint state dictionary.

If checkpoint's meta_info and OTXLitModule's meta_info are different,
load_state_pre_hook for smart weight loading will be registered.
"""
ckpt_meta_info = state_dict.pop("meta_info", None)
if ckpt_meta_info and self.meta_info and ckpt_meta_info != self.meta_info:
logger = logging.getLogger()
logger.info(
f"Data classes from checkpoint: {ckpt_meta_info.class_names} -> "
f"Data classes from training data: {self.meta_info.class_names}",
)
self.register_load_state_dict_pre_hook(
self.meta_info.class_names,
ckpt_meta_info.class_names,
)
super().load_state_dict(state_dict)

@property
def lr_scheduler_monitor_key(self) -> str:
"""Metric name that the learning rate scheduler monitor."""
return "val/loss"

@property
def meta_info(self) -> DataMetaInfo:
"""Meta information of OTXLitModule."""
if self._meta_info is None:
err_msg = "meta_info is referenced before assignment"
raise ValueError(err_msg)
return self._meta_info

@meta_info.setter
def meta_info(self, meta_info: DataMetaInfo) -> None:
self._meta_info = meta_info
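A minimal sketch of the flow this enables (assumed names: module is an already-built OTXLitModule subclass, "prev.ckpt" is a placeholder checkpoint trained on ["car", "truck"]; not part of the diff):

    import torch

    from otx.core.data.module import DataMetaInfo

    module.meta_info = DataMetaInfo(class_names=["car", "bus", "truck"])  # from the new training data
    checkpoint = torch.load("prev.ckpt")
    module.load_state_dict(checkpoint["state_dict"])
    # meta_info differs, so load_state_dict_pre_hook remaps the matching class rows
    # ("car" -> row 0, "truck" -> row 2) and leaves the new "bus" row untouched.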
40 changes: 40 additions & 0 deletions tests/unit/core/model/entity/test_base.py
@@ -0,0 +1,40 @@
import torch
from otx.core.model.entity.base import OTXModel


class MockNNModule(torch.nn.Module):
def __init__(self, num_classes):
super().__init__()
self.backbone = torch.nn.Linear(3, 1024)
self.head = torch.nn.Linear(3, num_classes)


class TestOTXModel:
def test_smart_weight_loading(self, mocker) -> None:
mocker.patch.object(OTXModel, "_create_model", return_value=MockNNModule(2))
prev_model = OTXModel()

mocker.patch.object(OTXModel, "_create_model", return_value=MockNNModule(3))
current_model = OTXModel()

prev_classes = ["car", "truck"]
current_classes = ["car", "bus", "truck"]
indices = torch.Tensor([0, 2]).to(torch.int32)

current_model.register_load_state_dict_pre_hook(current_classes, prev_classes)
current_model.load_state_dict(prev_model.state_dict())

assert torch.all(
current_model.state_dict()["model.backbone.weight"] == prev_model.state_dict()["model.backbone.weight"],
)
assert torch.all(
current_model.state_dict()["model.backbone.bias"] == prev_model.state_dict()["model.backbone.bias"],
)
assert torch.all(
current_model.state_dict()["model.head.weight"].index_select(0, indices)
== prev_model.state_dict()["model.head.weight"],
)
assert torch.all(
current_model.state_dict()["model.head.bias"].index_select(0, indices)
== prev_model.state_dict()["model.head.bias"],
)
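For readers following the assertions: the indices tensor encodes where each checkpoint class lands in the new three-class head, which can be checked directly (sketch only):

    prev_classes = ["car", "truck"]
    current_classes = ["car", "bus", "truck"]
    # each checkpoint class keeps its weights, placed at its index in the current class list
    assert [current_classes.index(c) for c in prev_classes] == [0, 2]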