Skip to content

Commit

Permalink
Mergeback 2.1.0 to develop (#3787)
Browse files Browse the repository at this point in the history
* Revert #3579 (#3753)

* remove config attributes on OTXDataModule

* remove dataformat check test

---------

Co-authored-by: Harim Kang <harim.kang@intel.com>
  • Loading branch information
yunchu and harimkang authored Aug 6, 2024
1 parent 0f08b40 commit f618a37
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 52 deletions.
13 changes: 0 additions & 13 deletions src/otx/core/data/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

import torch
from datumaro import Dataset as DmDataset
from datumaro import Environment
from lightning import LightningDataModule
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader, RandomSampler
Expand Down Expand Up @@ -105,18 +104,6 @@ def __init__(

VIDEO_EXTENSIONS.append(".mp4")

# Data Format Check
available_data_formats = Environment().detect_dataset(str(self.data_root))
if not available_data_formats:
msg = f"Invalid data root: {self.data_root}. Please check if the data root is valid."
raise ValueError(msg)
if self.data_format not in available_data_formats:
log.warning(
f"Invalid data format: {self.data_format}. Available formats: {available_data_formats} "
f"Replace data_format: {self.data_format} -> {available_data_formats[0]}.",
)
self.data_format = available_data_formats[0]

dataset = DmDataset.import_from(self.data_root, format=self.data_format)
if self.task != "H_LABEL_CLS":
dataset = pre_filtering(
Expand Down
40 changes: 1 addition & 39 deletions tests/unit/core/data/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@

from pathlib import Path
from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock

import pytest
from datumaro.components.environment import Environment
from importlib_resources import files
from lightning.pytorch.loggers import CSVLogger
from omegaconf import DictConfig, OmegaConf
Expand Down Expand Up @@ -162,43 +161,6 @@ def test_init_input_size(
assert fxt_config.val_subset.input_size == (1200, 1200)
assert fxt_config.test_subset.input_size == (800, 800)

def test_data_format_check(
self,
mock_dm_dataset,
mock_otx_dataset_factory,
mock_data_filtering,
fxt_config,
caplog,
) -> None:
# Dataset will have "train_0", "train_1", "val_0", ..., "test_1" subsets
mock_dm_subsets = {f"{name}_{idx}": MagicMock() for name in ["train", "val", "test"] for idx in range(2)}
mock_dm_dataset.return_value.subsets.return_value = mock_dm_subsets
with patch.object(Environment, "detect_dataset", return_value=["voc", "voc_classification"]):
_ = OTXDataModule(
task=fxt_config.task,
data_format=fxt_config.data_format,
data_root=fxt_config.data_root,
train_subset=fxt_config.train_subset,
val_subset=fxt_config.val_subset,
test_subset=fxt_config.test_subset,
)

assert "Invalid data format:" in caplog.text
assert "Replace data_format:" in caplog.text

with patch.object(Environment, "detect_dataset", return_value=[]), pytest.raises(
ValueError,
match="Invalid data root:",
):
_ = OTXDataModule(
task=fxt_config.task,
data_format=fxt_config.data_format,
data_root=fxt_config.data_root,
train_subset=fxt_config.train_subset,
val_subset=fxt_config.val_subset,
test_subset=fxt_config.test_subset,
)

@pytest.fixture()
def fxt_real_tv_cls_config(self) -> DictConfig:
cfg_path = files("otx") / "recipe" / "_base_" / "data" / "torchvision_base.yaml"
Expand Down

0 comments on commit f618a37

Please sign in to comment.