forked from openvinotoolkit/training_extensions
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodule.py
113 lines (90 loc) · 3.43 KB
/
module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""LightningDataModule extension for OTX."""
from __future__ import annotations
import logging as log
from typing import TYPE_CHECKING
from datumaro import Dataset as DmDataset
from lightning import LightningDataModule
from torch.utils.data import DataLoader
from otx.core.types.task import OTXTaskType
from .factory import OTXDatasetFactory
if TYPE_CHECKING:
from otx.core.config.data import (
DataModuleConfig,
SubsetConfig,
)
from .dataset.base import OTXDataset
class OTXDataModule(LightningDataModule):
    """LightningDataModule extension for OTX pipeline.

    On construction, the whole dataset at ``config.data_root`` is imported with
    Datumaro (using ``config.data_format``). Every subset that has a matching
    entry in ``config.subsets`` is wrapped into an ``OTXDataset`` built by
    ``OTXDatasetFactory``. Subsets without a configuration are skipped with a
    warning.
    """

    def __init__(self, task: OTXTaskType, config: DataModuleConfig) -> None:
        """Constructor.

        Args:
            task: Task type forwarded to ``OTXDatasetFactory.create`` for every
                subset.
            config: Data module configuration. ``data_root`` and ``data_format``
                select what to import; ``subsets`` maps subset names to their
                ``SubsetConfig``.
        """
        super().__init__()
        self.task = task
        self.config = config
        self.subsets: dict[str, OTXDataset] = {}
        self.save_hyperparameters()

        dataset = DmDataset.import_from(
            self.config.data_root,
            format=self.config.data_format,
        )

        for name, dm_subset in dataset.subsets().items():
            # Guard ONLY the config lookup. A KeyError raised while building
            # the dataset itself is a genuine error and must propagate rather
            # than be misreported as a missing configuration.
            try:
                sub_config = self._get_config(name)
            except KeyError:  # noqa: PERF203
                log.warning("%s has no config. Skip it", name)
                continue

            self.subsets[name] = OTXDatasetFactory.create(
                task=self.task,
                dm_subset=dm_subset,
                config=sub_config,
            )
            log.info("Add name: %s, self.subsets: %s", name, self.subsets)

    def _get_config(self, subset: str) -> SubsetConfig:
        """Return the configuration registered for ``subset``.

        Raises:
            KeyError: If ``config.subsets`` has no entry for ``subset``.
        """
        if (config := self.config.subsets.get(subset)) is None:
            msg = f"Config has no '{subset}' subset configuration"
            raise KeyError(msg)
        return config

    def _get_dataset(self, subset: str) -> OTXDataset:
        """Return the dataset built for ``subset`` during construction.

        Raises:
            KeyError: If no dataset was created for ``subset``, either because
                the source data lacks it or because it had no configuration.
        """
        if (dataset := self.subsets.get(subset)) is None:
            msg = (
                f"Dataset has no '{subset}'. Available subsets = {self.subsets.keys()}"
            )
            raise KeyError(msg)
        return dataset

    def _get_dataloader(self, subset: str, *, shuffle: bool) -> DataLoader:
        """Build a DataLoader for ``subset`` from its config and dataset.

        Batch size and worker count come from the subset's config; batches are
        collated with the dataset's own ``collate_fn``.

        Raises:
            KeyError: If ``subset`` has no configuration or no dataset.
        """
        # Config is resolved before the dataset, matching the original lookup order.
        config, dataset = self._get_config(subset), self._get_dataset(subset)
        return DataLoader(
            dataset=dataset,
            batch_size=config.batch_size,
            shuffle=shuffle,
            num_workers=config.num_workers,
            collate_fn=dataset.collate_fn,
        )

    def train_dataloader(self) -> DataLoader:
        """Get train dataloader (samples are shuffled)."""
        return self._get_dataloader("train", shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Get val dataloader (deterministic order, no shuffling)."""
        return self._get_dataloader("val", shuffle=False)

    def test_dataloader(self) -> DataLoader:
        """Get test dataloader (deterministic order, no shuffling)."""
        return self._get_dataloader("test", shuffle=False)

    def setup(self, stage: str) -> None:
        """Setup for each stage.

        Currently a no-op: all subsets are prepared eagerly in ``__init__``.
        """

    def teardown(self, stage: str) -> None:
        """Teardown for each stage.

        Currently a no-op.
        """
        # clean up after fit or test
        # called on every process in DDP