Fix a bug where the dino_v2 model can't be run with HPO (#3518)
* add reduce function to dino backbone

* add unit test

* update integration test

* change name
eunwoosh authored May 20, 2024
1 parent c6715f7 commit f4040de
Showing 5 changed files with 44 additions and 21 deletions.
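
For context on the fix: HPO runs each trial in a separate worker process, so the model object has to survive a pickle round trip. The DINOv2 backbones are created by torch.hub.load, whose classes live in a dynamically downloaded module that pickle presumably cannot re-import in the worker. The commit therefore records each wrapper's constructor arguments and returns them from __reduce__, so pickle rebuilds the object by re-running __init__ instead of serializing the loaded backbone. A minimal sketch of the pattern (illustrative only, hypothetical names, not the repository's code):

import pickle

class PicklableWrapper:
    """Hypothetical stand-in for the patched model classes."""

    def __init__(self, backbone_name: str, num_classes: int):
        # Capture the constructor arguments up front, as the commit does
        # via get_class_initial_arguments().
        self._init_args = (backbone_name, num_classes)
        self.backbone_name = backbone_name
        self.num_classes = num_classes

    def __reduce__(self):
        # pickle rebuilds the object as PicklableWrapper(*self._init_args),
        # re-running __init__ (and, in the real classes, torch.hub.load).
        return (PicklableWrapper, self._init_args)

restored = pickle.loads(pickle.dumps(PicklableWrapper("dinov2_vits14", 10)))
assert restored.backbone_name == "dinov2_vits14"
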
5 changes: 5 additions & 0 deletions src/otx/algo/classification/dino_v2.py
@@ -23,6 +23,7 @@
from otx.core.model.classification import OTXMulticlassClsModel
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.types.label import LabelInfoTypes
from otx.utils.utils import get_class_initial_arguments

if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
@@ -46,6 +47,7 @@ def __init__(
num_classes: int,
):
super().__init__()
self._init_args = get_class_initial_arguments()
self.backbone = torch.hub.load(
repo_or_dir="facebookresearch/dinov2",
model=backbone,
@@ -75,6 +77,9 @@ def forward(self, imgs: torch.Tensor, labels: torch.Tensor | None = None, **kwar
return self.loss(logits, labels)
return self.softmax(logits)

def __reduce__(self):
return (DINOv2, self._init_args)


class DINOv2RegisterClassifier(OTXMulticlassClsModel):
"""DINO-v2 Classification Model with register."""
5 changes: 5 additions & 0 deletions src/otx/algo/segmentation/backbones/dinov2.py
@@ -13,6 +13,7 @@

from otx.algo.modules.base_module import BaseModule
from otx.algo.utils.mmengine_utils import load_checkpoint_to_model, load_from_http
from otx.utils.utils import get_class_initial_arguments


class DinoVisionTransformer(BaseModule):
@@ -27,6 +28,7 @@ def __init__(
pretrained_weights: str | None = None,
):
super().__init__(init_cfg)
self._init_args = get_class_initial_arguments()
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True # noqa: SLF001, ARG005
self.backbone = torch.hub.load(repo_or_dir="facebookresearch/dinov2", model=name)
if freeze_backbone:
@@ -70,3 +72,6 @@ def load_pretrained_weights(self, pretrained: str | None = None, prefix: str = "
print(f"init weight - {pretrained}")
if checkpoint is not None:
load_checkpoint_to_model(self, checkpoint, prefix=prefix)

def __reduce__(self):
return (DinoVisionTransformer, self._init_args)
10 changes: 10 additions & 0 deletions src/otx/utils/utils.py
@@ -176,3 +176,13 @@ def can_pass_tile_config(model_cls: type[OTXModel]) -> bool:
"""
tile_config_param = inspect.signature(model_cls).parameters.get("tile_config")
return tile_config_param is not None


def get_class_initial_arguments() -> tuple:
"""Return arguments of class initilization. This function should be called in '__init__' function.
Returns:
tuple: class arguments.
"""
keywords, _, _, values = inspect.getargvalues(inspect.stack()[1].frame)
return tuple(values[key] for key in keywords[1:])
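
The helper works by frame introspection: inspect.stack()[1].frame is the frame of the __init__ that called it, and inspect.getargvalues returns that frame's argument names together with its current local values; slicing off the first name drops self. A standalone sketch of the same mechanism (assumed behavior, not repository code):

import inspect

def capture_caller_args() -> tuple:
    # inspect.getargvalues returns ArgInfo(args, varargs, varkw, locals);
    # stack()[1].frame is the caller's frame.
    arg_names, _, _, frame_locals = inspect.getargvalues(inspect.stack()[1].frame)
    return tuple(frame_locals[name] for name in arg_names[1:])  # skip `self`

class Point:
    def __init__(self, x: int, y: int):
        # Must run before x or y are rebound, since it reads live locals.
        self.captured = capture_caller_args()

assert Point(4, 5).captured == (4, 5)
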
34 changes: 13 additions & 21 deletions tests/integration/cli/test_cli.py
@@ -428,9 +428,9 @@ def test_otx_ov_test(
assert len(metric_result) > 0


@pytest.mark.parametrize("task", pytest.TASK_LIST)
@pytest.mark.parametrize("recipe", pytest.RECIPE_LIST, ids=lambda x: "/".join(Path(x).parts[-2:]))
def test_otx_hpo_e2e(
task: OTXTaskType,
recipe: str,
tmp_path: Path,
fxt_accelerator: str,
fxt_target_dataset_per_task: dict,
@@ -447,30 +447,22 @@ def test_otx_hpo_e2e(
Returns:
None
"""
if task not in DEFAULT_CONFIG_PER_TASK:
pytest.skip(f"Task {task} is not supported in the auto-configuration.")
if task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING:
pytest.skip("ZERO_SHOT_VISUAL_PROMPTING doesn't support HPO.")
task = recipe.split("/")[-2]
model_name = recipe.split("/")[-1].split(".")[0]

# Need to change model to stfpm because default anomaly model is 'padim' which doesn't support HPO
model_cfg = []
if task in {
OTXTaskType.ANOMALY_CLASSIFICATION,
OTXTaskType.ANOMALY_DETECTION,
OTXTaskType.ANOMALY_SEGMENTATION,
}:
model_cfg = ["--config", str(DEFAULT_CONFIG_PER_TASK[task].parent / "stfpm.yaml")]
if task.upper() == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING:
pytest.skip("ZERO_SHOT_VISUAL_PROMPTING doesn't support HPO.")
if "padim" in recipe:
pytest.skip("padim model doesn't support HPO.")

task = task.lower()
tmp_path_hpo = tmp_path / f"otx_hpo_{task}"
tmp_path_hpo = tmp_path / f"otx_hpo_{model_name}"
tmp_path_hpo.mkdir(parents=True)

command_cfg = [
"otx",
"train",
*model_cfg,
"--task",
task.upper(),
"--config",
recipe,
"--data_root",
fxt_target_dataset_per_task[task],
"--work_dir",
@@ -482,7 +474,7 @@
"--run_hpo",
"true",
"--hpo_config.expected_time_ratio",
"2",
"1",
"--hpo_config.num_workers",
"1",
*fxt_cli_override_command_per_task[task],
@@ -500,7 +492,7 @@
if task.startswith("anomaly"):
return

assert len([val for val in hpo_work_dor.rglob("*.json") if str(val.stem).isdigit()]) == 2
assert len([val for val in hpo_work_dor.rglob("*.json") if str(val.stem).isdigit()]) == 1


@pytest.mark.parametrize("task", pytest.TASK_LIST)
11 changes: 11 additions & 0 deletions tests/unit/utils/test_utils.py
@@ -5,6 +5,7 @@
import pytest
from otx.utils.utils import (
find_file_recursively,
get_class_initial_arguments,
get_decimal_point,
get_using_dot_delimited_key,
remove_matched_files,
@@ -99,3 +100,13 @@ def test_remove_matched_files_no_file_to_remove(temporary_dir_w_some_txt):
remove_matched_files(temporary_dir_w_some_txt, "*.log")

assert len(list(temporary_dir_w_some_txt.rglob("*.txt"))) == 5


def test_get_class_initial_arguments():
class FakeCls:
def __init__(self, a, b):
self.init_args = get_class_initial_arguments()

fake_cls = FakeCls(4, 5)

assert fake_cls.init_args == (4, 5)
