Skip to content

Commit

Permalink
Update return type
Browse files Browse the repository at this point in the history
  • Loading branch information
jaegukhyun committed Jan 8, 2024
1 parent 486a2e3 commit b095a26
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/otx/core/model/entity/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class OTXMulticlassClsModel(
"""Base class for the classification models used in OTX."""


def _create_mmpretrain_model(config: DictConfig, load_from: str) -> nn.Module:
def _create_mmpretrain_model(config: DictConfig, load_from: str) -> tuple[nn.Module, list[str]]:
from mmpretrain.models.utils import ClsDataPreprocessor as _ClsDataPreprocessor
from mmpretrain.registry import MODELS

Expand Down
2 changes: 1 addition & 1 deletion src/otx/core/utils/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def build_mm_model(config: DictConfig, model_registry: Registry, load_from: str
return model


def get_classification_layers(config: DictConfig, model_registry: Registry, prefix: str = "") -> nn.Module:
def get_classification_layers(config: DictConfig, model_registry: Registry, prefix: str = "") -> list[str]:
"""Return classification layer names by comparing two different number of classes models."""
sample_config = deepcopy(config)
modify_num_classes(sample_config, 5)
Expand Down

0 comments on commit b095a26

Please sign in to comment.