Enable smart weight loading #2758

Merged: 11 commits, Jan 8, 2024
1 change: 1 addition & 0 deletions src/otx/core/config/__init__.py
@@ -39,6 +39,7 @@ class TrainConfig:
debug: Optional[str]
train: bool
test: bool
resume: bool = False

seed: Optional[int] = None
checkpoint: Optional[str] = None
23 changes: 20 additions & 3 deletions src/otx/core/data/module.py
@@ -5,9 +5,11 @@
from __future__ import annotations

import logging as log
from dataclasses import dataclass
from typing import TYPE_CHECKING

from datumaro import Dataset as DmDataset
from datumaro.components.annotation import AnnotationType
from lightning import LightningDataModule
from torch.utils.data import DataLoader

@@ -24,6 +26,21 @@
from .dataset.base import OTXDataset


@dataclass
class DataMetaInfo:
"""Meta information of OTXDataModule.

This meta information will be used by OTXLitModule.
"""

class_names: list[str]

@property
def num_classes(self) -> int:
"""Return number of classes."""
return len(self.class_names)


class OTXDataModule(LightningDataModule):
"""LightningDataModule extension for OTX pipeline."""

@@ -46,9 +63,9 @@ def __init__(

VIDEO_EXTENSIONS.append(".mp4")

dataset = DmDataset.import_from(
self.config.data_root,
format=self.config.data_format,
dataset = DmDataset.import_from(self.config.data_root, format=self.config.data_format)
self.meta_info = DataMetaInfo(
class_names=[category.name for category in dataset.categories()[AnnotationType.label]],
)

config_mapping = {
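Since `DataMetaInfo` drives both `num_classes` and the equality check that later triggers smart weight loading, a minimal standalone sketch of its behavior may help; the class names below are illustrative.

```python
from dataclasses import dataclass


@dataclass
class DataMetaInfo:
    """Mirror of the dataclass added above, for illustration only."""

    class_names: list[str]

    @property
    def num_classes(self) -> int:
        return len(self.class_names)


meta_info = DataMetaInfo(class_names=["person", "car", "tree"])
print(meta_info.num_classes)  # 3

# Dataclass equality compares class_names, which is what later lets
# OTXLitModule.load_state_dict detect a class-list mismatch.
print(meta_info == DataMetaInfo(class_names=["person", "car"]))  # False
```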
1 change: 1 addition & 0 deletions src/otx/core/engine/test.py
@@ -28,6 +28,7 @@ def test(cfg: TrainConfig) -> tuple[Trainer, dict[str, Any]]:

log.info(f"Instantiating model <{cfg.model}>")
model: LightningModule = hydra.utils.instantiate(cfg.model)
model.meta_info = datamodule.meta_info

log.info(f"Instantiating trainer <{cfg.trainer}>")
trainer: Trainer = hydra.utils.instantiate(cfg.trainer)
13 changes: 11 additions & 2 deletions src/otx/core/engine/train.py
@@ -54,6 +54,7 @@ def train(
Returns:
A tuple with Pytorch Lightning Trainer and Python dict of metrics
"""
import torch
from lightning import seed_everything

from otx.core.data.module import OTXDataModule
@@ -69,9 +70,9 @@

log.info(f"Instantiating datamodule <{cfg.data}>")
datamodule = OTXDataModule(task=cfg.base.task, config=cfg.data)

log.info(f"Instantiating model <{cfg.model}>")
model: OTXLitModule = hydra.utils.instantiate(cfg.model)
model.meta_info = datamodule.meta_info

if otx_model is not None:
if not isinstance(otx_model, OTXModel):
@@ -108,7 +109,15 @@

if cfg.train:
log.info("Starting training!")
trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.checkpoint)
if cfg.resume:
trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.checkpoint)
else:
# load only the weights to fine-tune the model
if cfg.checkpoint is not None:
loaded_checkpoint = torch.load(cfg.checkpoint)
model.load_state_dict(loaded_checkpoint["state_dict"])
# train
trainer.fit(model=model, datamodule=datamodule)

train_metrics = trainer.callback_metrics

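The branch matters because Lightning's `ckpt_path` restores optimizer, scheduler, and epoch state, while a bare `load_state_dict` restores weights only, which is what fine-tuning wants. A minimal sketch of the contrast, assuming stand-in Lightning objects and a hypothetical checkpoint path:

```python
import torch
from lightning import LightningDataModule, LightningModule, Trainer


def fit_from_checkpoint(
    model: LightningModule,
    datamodule: LightningDataModule,
    checkpoint: str,
    resume: bool,
) -> None:
    """Illustrative helper mirroring the branch added above."""
    trainer = Trainer(max_epochs=10)
    if resume:
        # Lightning restores weights plus optimizer/scheduler state and the
        # epoch counter, so training continues where it stopped.
        trainer.fit(model=model, datamodule=datamodule, ckpt_path=checkpoint)
    else:
        # Weights only: training restarts at epoch 0 with a fresh optimizer,
        # and any registered pre-hook can remap class-dependent parameters.
        loaded = torch.load(checkpoint, map_location="cpu")
        model.load_state_dict(loaded["state_dict"])
        trainer.fit(model=model, datamodule=datamodule)
```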
3 changes: 2 additions & 1 deletion src/otx/core/model/entity/action_classification.py
@@ -13,7 +13,7 @@
)
from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.model.entity.base import OTXModel
from otx.core.utils.build import build_mm_model
from otx.core.utils.build import build_mm_model, get_classification_layers

if TYPE_CHECKING:
from omegaconf import DictConfig
@@ -57,6 +57,7 @@ def device(self) -> device:
else:
return buf.device

self.classification_layers = get_classification_layers(self.config, MODELS, "model.")
return build_mm_model(self.config, MODELS, self.load_from)

def _customize_inputs(self, entity: ActionClsBatchDataEntity) -> dict[str, Any]:
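`get_classification_layers` lives in `src/otx/core/utils/build.py` and is not part of this diff. A plausible sketch of the idea its call sites suggest, assuming classification layers are exactly the parameters whose shape depends on the head's class count; the `head.num_classes` path, the dummy counts, and the `registry.build` call are assumptions, not the actual helper:

```python
import copy


def get_classification_layers_sketch(config, registry, prefix: str = "model.") -> list[str]:
    """Guess which parameter names belong to classification heads.

    Build the same architecture twice with different dummy class counts and
    compare state_dict shapes: any parameter that changes shape (or exists in
    only one of the two models) must depend on the number of classes.
    """
    cfg_a, cfg_b = copy.deepcopy(config), copy.deepcopy(config)
    cfg_a.head.num_classes = 5  # assumed config layout
    cfg_b.head.num_classes = 6

    params_a = registry.build(cfg_a).state_dict()
    params_b = registry.build(cfg_b).state_dict()

    return [
        prefix + name
        for name, param in params_a.items()
        if name not in params_b or param.shape != params_b[name].shape
    ]
```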
3 changes: 2 additions & 1 deletion src/otx/core/model/entity/action_detection.py
@@ -12,7 +12,7 @@
from otx.core.data.entity.action_detection import ActionDetBatchDataEntity, ActionDetBatchPredEntity
from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.model.entity.base import OTXModel
from otx.core.utils.build import build_mm_model
from otx.core.utils.build import build_mm_model, get_classification_layers

if TYPE_CHECKING:
from omegaconf import DictConfig
@@ -56,6 +56,7 @@ def device(self) -> device:
else:
return buf.device

self.classification_layers = get_classification_layers(self.config, MODELS, "model.")
return build_mm_model(self.config, MODELS, self.load_from)

def _customize_inputs(self, entity: ActionDetBatchDataEntity) -> dict[str, Any]:
51 changes: 50 additions & 1 deletion src/otx/core/model/entity/base.py
@@ -6,7 +6,7 @@
from __future__ import annotations

from abc import abstractmethod
from typing import Any, Generic
from typing import TYPE_CHECKING, Any, Generic

from torch import nn

@@ -16,12 +16,16 @@
T_OTXBatchPredEntity,
)

if TYPE_CHECKING:
import torch


class OTXModel(nn.Module, Generic[T_OTXBatchDataEntity, T_OTXBatchPredEntity]):
"""Base class for the models used in OTX."""

def __init__(self) -> None:
super().__init__()
self.classification_layers: list[str] = []
self.model = self._create_model()

@abstractmethod
@@ -57,3 +61,48 @@ def forward(
if self._customize_outputs != OTXModel._customize_outputs
else outputs
)

def register_load_state_dict_pre_hook(self, model_classes: list[str], ckpt_classes: list[str]) -> None:
"""Register load_state_dict_pre_hook.

Args:
model_classes (list[str]): Class names from training data.
ckpt_classes (list[str]): Class names from checkpoint state dictionary.
"""
self.model_classes = model_classes
self.ckpt_classes = ckpt_classes
self._register_load_state_dict_pre_hook(self.load_state_dict_pre_hook)

def load_state_dict_pre_hook(self, state_dict: dict[str, torch.Tensor], prefix: str, *args, **kwargs) -> None:
"""Modify input state_dict according to class name matching before weight loading."""
model2ckpt = self.map_class_names(self.model_classes, self.ckpt_classes)

for param_name in self.classification_layers:
model_param = self.state_dict()[param_name].clone()
ckpt_param = state_dict[prefix + param_name]
for model_t, ckpt_t in enumerate(model2ckpt):
if ckpt_t >= 0:
model_param[model_t].copy_(ckpt_param[ckpt_t])

# Replace the checkpoint weights with the mixed weights
state_dict[prefix + param_name] = model_param

@staticmethod
def map_class_names(src_classes: list[str], dst_classes: list[str]) -> list[int]:
"""Computes src to dst index mapping.

src2dst[src_idx] = dst_idx
# according to class name matching, -1 for non-matched ones
assert(len(src2dst) == len(src_classes))
ex)
src_classes = ['person', 'car', 'tree']
dst_classes = ['tree', 'person', 'sky', 'ball']
-> Returns src2dst = [1, -1, 0]
"""
src2dst = []
for src_class in src_classes:
if src_class in dst_classes:
src2dst.append(dst_classes.index(src_class))
else:
src2dst.append(-1)
return src2dst
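To make the hook's row-copy semantics concrete, a runnable sketch on a bare tensor, using the mapping from the docstring example; the class names and tensor shapes are illustrative:

```python
import torch

# OTXModel.map_class_names(['person', 'car', 'tree'],
#                          ['tree', 'person', 'sky', 'ball']) == [1, -1, 0]
model2ckpt = [1, -1, 0]

ckpt_param = torch.arange(8.0).reshape(4, 2)  # one row per checkpoint class
model_param = torch.zeros(3, 2)               # freshly initialized head

for model_t, ckpt_t in enumerate(model2ckpt):
    if ckpt_t >= 0:
        model_param[model_t].copy_(ckpt_param[ckpt_t])

# Rows 0 ('person') and 2 ('tree') now carry checkpoint weights;
# row 1 ('car') keeps its fresh initialization.
print(model_param)
```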
13 changes: 9 additions & 4 deletions src/otx/core/model/entity/classification.py
@@ -15,7 +15,7 @@
MultilabelClsBatchPredEntity,
)
from otx.core.model.entity.base import OTXModel
from otx.core.utils.build import build_mm_model
from otx.core.utils.build import build_mm_model, get_classification_layers

if TYPE_CHECKING:
from mmpretrain.models.utils import ClsDataPreprocessor
@@ -46,7 +46,8 @@ def device(self) -> device:
else:
return buf.device

return build_mm_model(config, MODELS, load_from)
classification_layers = get_classification_layers(config, MODELS, "model.")
return build_mm_model(config, MODELS, load_from), classification_layers


class MMPretrainMulticlassClsModel(OTXMulticlassClsModel):
@@ -63,7 +64,9 @@ def __init__(self, config: DictConfig) -> None:
super().__init__()

def _create_model(self) -> nn.Module:
return _create_mmpretrain_model(self.config, self.load_from)
model, classification_layers = _create_mmpretrain_model(self.config, self.load_from)
self.classification_layers = classification_layers
return model

def _customize_inputs(self, entity: MulticlassClsBatchDataEntity) -> dict[str, Any]:
from mmpretrain.structures import DataSample
@@ -153,7 +156,9 @@ def __init__(self, config: DictConfig) -> None:
super().__init__()

def _create_model(self) -> nn.Module:
return _create_mmpretrain_model(self.config, self.load_from)
model, classification_layers = _create_mmpretrain_model(self.config, self.load_from)
self.classification_layers = classification_layers
return model

def _customize_inputs(self, entity: MultilabelClsBatchDataEntity) -> dict[str, Any]:
from mmpretrain.structures import DataSample
3 changes: 2 additions & 1 deletion src/otx/core/model/entity/detection.py
@@ -13,7 +13,7 @@
from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.data.entity.detection import DetBatchDataEntity, DetBatchPredEntity
from otx.core.model.entity.base import OTXModel
from otx.core.utils.build import build_mm_model
from otx.core.utils.build import build_mm_model, get_classification_layers

if TYPE_CHECKING:
from mmdet.models.data_preprocessors import DetDataPreprocessor
@@ -58,6 +58,7 @@ def device(self) -> device:
else:
return buf.device

self.classification_layers = get_classification_layers(self.config, MODELS, "model.")
return build_mm_model(self.config, MODELS, self.load_from)

def _customize_inputs(self, entity: DetBatchDataEntity) -> dict[str, Any]:
3 changes: 2 additions & 1 deletion src/otx/core/model/entity/instance_segmentation.py
@@ -17,7 +17,7 @@
InstanceSegBatchPredEntity,
)
from otx.core.model.entity.base import OTXModel
from otx.core.utils.build import build_mm_model
from otx.core.utils.build import build_mm_model, get_classification_layers

if TYPE_CHECKING:
from mmdet.models.data_preprocessors import DetDataPreprocessor
@@ -59,6 +59,7 @@ def device(self) -> device:
else:
return buf.device

self.classification_layers = get_classification_layers(self.config, MODELS, "model.")
return build_mm_model(self.config, MODELS, self.load_from)

def _customize_inputs(self, entity: InstanceSegBatchDataEntity) -> dict[str, Any]:
3 changes: 2 additions & 1 deletion src/otx/core/model/entity/segmentation.py
@@ -10,7 +10,7 @@
from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.data.entity.segmentation import SegBatchDataEntity, SegBatchPredEntity
from otx.core.model.entity.base import OTXModel
from otx.core.utils.build import build_mm_model
from otx.core.utils.build import build_mm_model, get_classification_layers

if TYPE_CHECKING:
from mmseg.models.data_preprocessor import SegDataPreProcessor
@@ -53,6 +53,7 @@ def device(self) -> device:
else:
return buf.device

self.classification_layers = get_classification_layers(self.config, MODELS, "model.")
return build_mm_model(self.config, MODELS, self.load_from)

def _customize_inputs(self, entity: SegBatchDataEntity) -> dict[str, Any]:
61 changes: 61 additions & 0 deletions src/otx/core/model/module/base.py
@@ -4,13 +4,16 @@
"""Class definition for base lightning module used in OTX."""
from __future__ import annotations

import logging
import warnings
from typing import Any

import torch
from lightning import LightningModule
from torch import Tensor

from otx.core.data.entity.base import OTXBatchDataEntity
from otx.core.data.module import DataMetaInfo
from otx.core.model.entity.base import OTXModel


@@ -30,6 +33,7 @@ def __init__(
self.optimizer = optimizer
self.scheduler = scheduler
self.torch_compile = torch_compile
self._meta_info: DataMetaInfo | None = None

# this line allows to access init params with 'self.hparams' attribute
# also ensures init params will be stored in ckpt
@@ -105,7 +109,64 @@ def configure_optimizers(self) -> dict[str, Any]:
}
return {"optimizer": optimizer}

def register_load_state_dict_pre_hook(self, model_classes: list[str], ckpt_classes: list[str]) -> None:
"""Register self.model's load_state_dict_pre_hook.

Args:
model_classes (list[str]): Class names from training data.
ckpt_classes (list[str]): Class names from checkpoint state dictionary.
"""
self.model.register_load_state_dict_pre_hook(model_classes, ckpt_classes)

def state_dict(self) -> dict[str, Any]:
"""Return state dictionary of model entity with meta information.

Returns:
A dictionary containing datamodule state.

"""
state_dict = super().state_dict()
state_dict["meta_info"] = self.meta_info
return state_dict

def load_state_dict(self, state_dict: dict[str, Any], *args, **kwargs) -> None:
"""Load state dictionary from checkpoint state dictionary.

If checkpoint's meta_info and OTXLitModule's meta_info are different,
load_state_pre_hook for smart weight loading will be registered.
"""
ckpt_meta_info = state_dict.pop("meta_info", None)
if ckpt_meta_info and self.meta_info is None:
msg = (
"`state_dict` to load has `meta_info`, but the current model has no `meta_info`. "
"It is recommended to set proper `meta_info` for the incremental learning case."
)
warnings.warn(msg, stacklevel=2)
if ckpt_meta_info and self.meta_info and ckpt_meta_info != self.meta_info:
logger = logging.getLogger()
logger.info(
f"Data classes from checkpoint: {ckpt_meta_info.class_names} -> "
f"Data classes from training data: {self.meta_info.class_names}",
)
self.register_load_state_dict_pre_hook(
self.meta_info.class_names,
ckpt_meta_info.class_names,
)
return super().load_state_dict(state_dict, *args, **kwargs)

@property
def lr_scheduler_monitor_key(self) -> str:
"""Metric name that the learning rate scheduler monitor."""
return "val/loss"

@property
def meta_info(self) -> DataMetaInfo:
"""Meta information of OTXLitModule."""
if self._meta_info is None:
err_msg = "meta_info is referenced before assignment"
raise ValueError(err_msg)
return self._meta_info

@meta_info.setter
def meta_info(self, meta_info: DataMetaInfo) -> None:
self._meta_info = meta_info
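Putting the pieces together, a hedged sketch of the fine-tuning round trip; `module` stands for an OTXLitModule instance, and the checkpoint path and class names are made up:

```python
import torch

from otx.core.data.module import DataMetaInfo


def finetune_load(module, checkpoint_path: str) -> None:
    """Load a checkpoint whose classes differ from the current data."""
    # The datamodule assigns the new meta info first, as train.py does via
    # `model.meta_info = datamodule.meta_info`.
    module.meta_info = DataMetaInfo(class_names=["dog", "bird"])

    # load_state_dict pops "meta_info" from the checkpoint; if the stored
    # class list (say ["cat", "dog"]) differs, the pre-hook is registered:
    # 'dog' rows are copied over, 'bird' rows keep their fresh initialization.
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    module.load_state_dict(ckpt["state_dict"])
```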