4
4
from typing import Annotated as An
5
5
from typing import final
6
6
7
+ from datasets import Dataset as HFDataset
7
8
from lightning .pytorch import LightningDataModule
8
9
from torch import Tensor
9
10
from torch .utils .data import DataLoader , Dataset
13
14
14
15
@dataclass
15
16
class Datasets :
16
- """Holds stage-specific :class:`~torch.utils.data.Dataset` objects.
17
+ """Holds phase-specific :class:`~torch.utils.data.Dataset` objects.
18
+
19
+ The word ``phase`` is used to avoid overloading the :mod:`lightning` ``stage``
20
+ terminology used for ``fit``, ``validate`` and ``test``.
17
21
18
22
Args:
19
23
train: Training dataset.
@@ -22,10 +26,10 @@ class Datasets:
22
26
predict: Prediction dataset.
23
27
"""
24
28
25
- train : Dataset [Tensor ] | None = None
26
- val : Dataset [Tensor ] | None = None
27
- test : Dataset [Tensor ] | None = None
28
- predict : Dataset [Tensor ] | None = None
29
+ train : Dataset [Tensor ] | HFDataset | None = None
30
+ val : Dataset [Tensor ] | HFDataset | None = None
31
+ test : Dataset [Tensor ] | HFDataset | None = None
32
+ predict : Dataset [Tensor ] | HFDataset | None = None
29
33
30
34
31
35
@dataclass
@@ -44,16 +48,18 @@ class BaseDataModuleConfig:
44
48
class BaseDataModule (LightningDataModule , metaclass = ABCMeta ):
45
49
"""Base :mod:`lightning` ``DataModule``.
46
50
47
- With ``<stage >`` being any of ``train``, ``val``, ``test`` or
51
+ With ``<phase>`` being any of ``train``, ``val``, ``test`` or
48
52
``predict``, subclasses need to properly define the
49
- ``datasets.<stage >`` attribute(s) for each desired stage .
53
+ ``datasets.<phase>`` attribute(s) for each desired phase.
50
54
51
55
Args:
52
56
config: See :class:`BaseDataModuleConfig`.
53
57
54
58
Attributes:
55
59
config (:class:`BaseDataModuleConfig`)
56
60
datasets (:class:`Datasets`)
61
+ collate_fn (``callable``): See \
62
+ :paramref:`torch.utils.data.DataLoader.collate_fn`.
57
63
pin_memory (``bool``): Whether to copy tensors into device\
58
64
pinned memory before returning them (is set to ``True`` by\
59
65
default if :paramref:`~BaseDataModuleConfig.device` is\
@@ -72,6 +78,7 @@ def __init__(self: "BaseDataModule", config: BaseDataModuleConfig) -> None:
72
78
super ().__init__ ()
73
79
self .config = config
74
80
self .datasets = Datasets ()
81
+ self .collate_fn = None
75
82
self .pin_memory = self .config .device == "gpu"
76
83
self .per_device_batch_size = 1
77
84
self .per_device_num_workers = 0
@@ -108,7 +115,7 @@ def state_dict(self: "BaseDataModule") -> dict[str, int]:
108
115
@final
109
116
def x_dataloader (
110
117
self : "BaseDataModule" ,
111
- dataset : Dataset [Tensor ] | None ,
118
+ dataset : Dataset [Tensor ] | HFDataset | None ,
112
119
* ,
113
120
shuffle : bool = True ,
114
121
) -> DataLoader [Tensor ]:
@@ -134,6 +141,7 @@ def x_dataloader(
134
141
batch_size = self .per_device_batch_size ,
135
142
shuffle = shuffle ,
136
143
num_workers = self .per_device_num_workers ,
144
+ collate_fn = self .collate_fn ,
137
145
pin_memory = self .pin_memory ,
138
146
)
139
147
0 commit comments