Commit ce34183

generalize to hf (#203)
1 parent 9423229 commit ce34183

3 files changed: +34 -22 lines changed


cneuromax/fitting/deeplearning/datamodule/base.py (+16 -8)
@@ -4,6 +4,7 @@
 from typing import Annotated as An
 from typing import final
 
+from datasets import Dataset as HFDataset
 from lightning.pytorch import LightningDataModule
 from torch import Tensor
 from torch.utils.data import DataLoader, Dataset
@@ -13,7 +14,10 @@
 
 @dataclass
 class Datasets:
-    """Holds stage-specific :class:`~torch.utils.data.Dataset` objects.
+    """Holds phase-specific :class:`~torch.utils.data.Dataset` objects.
+
+    Using the word ``phase`` to not overload :mod:`lightning` ``stage``
+    terminology used for ``fit``, ``validate`` and ``test``.
 
     Args:
         train: Training dataset.
@@ -22,10 +26,10 @@ class Datasets:
         predict: Prediction dataset.
     """
 
-    train: Dataset[Tensor] | None = None
-    val: Dataset[Tensor] | None = None
-    test: Dataset[Tensor] | None = None
-    predict: Dataset[Tensor] | None = None
+    train: Dataset[Tensor] | HFDataset | None = None
+    val: Dataset[Tensor] | HFDataset | None = None
+    test: Dataset[Tensor] | HFDataset | None = None
+    predict: Dataset[Tensor] | HFDataset | None = None
 
 
 @dataclass
@@ -44,16 +48,18 @@ class BaseDataModuleConfig:
 class BaseDataModule(LightningDataModule, metaclass=ABCMeta):
     """Base :mod:`lightning` ``DataModule``.
 
-    With ``<stage>`` being any of ``train``, ``val``, ``test`` or
+    With ``<phase>`` being any of ``train``, ``val``, ``test`` or
     ``predict``, subclasses need to properly define the
-    ``datasets.<stage>`` attribute(s) for each desired stage.
+    ``datasets.<phase>`` attribute(s) for each desired phase.
 
     Args:
        config: See :class:`BaseDataModuleConfig`.
 
     Attributes:
        config (:class:`BaseDataModuleConfig`)
        datasets (:class:`Datasets`)
+       collate_fn (``callable``): See \
+           :paramref:`torch.utils.data.DataLoader.collate_fn`.
        pin_memory (``bool``): Whether to copy tensors into device\
            pinned memory before returning them (is set to ``True`` by\
            default if :paramref:`~BaseDataModuleConfig.device` is\
@@ -72,6 +78,7 @@ def __init__(self: "BaseDataModule", config: BaseDataModuleConfig) -> None:
         super().__init__()
         self.config = config
         self.datasets = Datasets()
+        self.collate_fn = None
         self.pin_memory = self.config.device == "gpu"
         self.per_device_batch_size = 1
         self.per_device_num_workers = 0
@@ -108,7 +115,7 @@ def state_dict(self: "BaseDataModule") -> dict[str, int]:
     @final
     def x_dataloader(
         self: "BaseDataModule",
-        dataset: Dataset[Tensor] | None,
+        dataset: Dataset[Tensor] | HFDataset | None,
         *,
         shuffle: bool = True,
     ) -> DataLoader[Tensor]:
@@ -134,6 +141,7 @@ def x_dataloader(
             batch_size=self.per_device_batch_size,
             shuffle=shuffle,
             num_workers=self.per_device_num_workers,
+            collate_fn=self.collate_fn,
             pin_memory=self.pin_memory,
         )
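For orientation, a minimal sketch of how a subclass might use the generalized ``datasets.<phase>`` attributes together with the new ``collate_fn`` hook. The dataset, tokenizer and ``setup`` body are illustrative assumptions, not part of this commit:

    from datasets import load_dataset
    from transformers import AutoTokenizer, DataCollatorWithPadding

    from cneuromax.fitting.deeplearning.datamodule.base import BaseDataModule


    class MyHFDataModule(BaseDataModule):  # hypothetical subclass
        def setup(self, stage: str | None = None) -> None:
            tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
            dataset = load_dataset("imdb").map(
                lambda x: tokenizer(x["text"], truncation=True),
                batched=True,
            )
            # ``datasets.<phase>`` now also accepts HF ``Dataset`` objects.
            self.datasets.train = dataset["train"]
            self.datasets.test = dataset["test"]
            # Forwarded to every ``DataLoader`` built by ``x_dataloader``.
            self.collate_fn = DataCollatorWithPadding(tokenizer=tokenizer)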

cneuromax/fitting/deeplearning/litmodule/base.py (+7 -14)
@@ -10,6 +10,7 @@
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler
 
+from cneuromax.fitting.deeplearning.utils.type import Batch_type
 from cneuromax.utils.beartype import one_of
 
 
@@ -86,9 +87,7 @@ def __init__(
     @final
     def stage_step(
         self: "BaseLitModule",
-        batch: Num[Tensor, " ..."]
-        | tuple[Num[Tensor, " ..."], ...]
-        | list[Num[Tensor, " ..."]],
+        batch: Batch_type,
         stage: An[str, one_of("train", "val", "test", "predict")],
     ) -> Num[Tensor, " ..."]:
         """Generic stage wrapper around the :meth:`step` method.
@@ -105,17 +104,15 @@ def stage_step(
             The loss value(s).
         """
         if isinstance(batch, list):
-            tupled_batch: tuple[Num[Tensor, " ..."], ...] = tuple(batch)
-            loss: Num[Tensor, " ..."] = self.step(tupled_batch, stage)
+            batch = tuple(batch)
+        loss: Num[Tensor, " ..."] = self.step(batch, stage)
         self.log(name=f"{stage}/loss", value=loss)
         return loss
 
     @final
     def training_step(
         self: "BaseLitModule",
-        batch: Num[Tensor, " ..."]
-        | tuple[Num[Tensor, " ..."], ...]
-        | list[Num[Tensor, " ..."]],
+        batch: Batch_type,
     ) -> Num[Tensor, " ..."]:
         """Calls :meth:`stage_step` with argument ``stage="train"``.
 
@@ -130,9 +127,7 @@ def training_step(
     @final
     def validation_step(
         self: "BaseLitModule",
-        batch: Num[Tensor, " ..."]
-        | tuple[Num[Tensor, " ..."], ...]
-        | list[Num[Tensor, " ..."]],
+        batch: Batch_type,
         # :paramref:`*args` & :paramref:`**kwargs` type annotations
         # cannot be more specific because of
         # :meth:`LightningModule.validation_step`\'s signature.
@@ -154,9 +149,7 @@ def validation_step(
     @final
     def test_step(
         self: "BaseLitModule",
-        batch: Num[Tensor, " ..."]
-        | tuple[Num[Tensor, " ..."], ...]
-        | list[Num[Tensor, " ..."]],
+        batch: Batch_type,
     ) -> Num[Tensor, " ..."]:
         """Calls :meth:`stage_step` with argument ``stage="test"``.
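To show what the widened annotation permits downstream, a hedged sketch of a subclass whose ``step`` consumes dict-style batches, which ``Batch_type`` newly covers. The ``nnmodule`` attribute and the HF-style ``.loss`` output field are assumptions, not part of this commit:

    from jaxtyping import Num
    from torch import Tensor

    from cneuromax.fitting.deeplearning.litmodule.base import BaseLitModule


    class MyHFLitModule(BaseLitModule):  # hypothetical subclass
        def step(
            self: "MyHFLitModule",
            batch: dict[str, Num[Tensor, " ..."]],
            stage: str,
        ) -> Num[Tensor, " ..."]:
            # HF collators yield dicts such as {"input_ids": ..., "labels": ...}.
            # ``stage_step`` only converts lists to tuples, so the dict reaches
            # ``step`` unchanged.
            outputs = self.nnmodule(**batch)  # assumes an HF-style model/output
            return outputs.loss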
cneuromax/fitting/deeplearning/utils/type.py (+11 -0)
@@ -0,0 +1,11 @@
+"""Typing utilities."""
+
+from jaxtyping import Num
+from torch import Tensor
+
+Batch_type = (
+    Num[Tensor, " ..."]
+    | tuple[Num[Tensor, " ..."], ...]
+    | list[Num[Tensor, " ..."]]
+    | dict[str, Num[Tensor, " ..."]]
+)
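As a quick sanity check of the new alias (values purely illustrative), every batch shape that ``stage_step`` can now receive satisfies ``Batch_type``:

    import torch

    from cneuromax.fitting.deeplearning.utils.type import Batch_type

    x = torch.ones(2, 3)
    batches: list[Batch_type] = [
        x,                 # bare tensor
        (x, x),            # tuple of tensors
        [x, x],            # list of tensors (stage_step converts these to tuples)
        {"input_ids": x},  # dict batch, e.g. from an HF data collator
    ]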
