add param to not shuffle train dataloader
MaximilienLC committed Feb 20, 2025
1 parent 360e433 commit d40fffe
Showing 1 changed file with 6 additions and 1 deletion.
cneuromax/fitting/deeplearning/datamodule/base.py (6 additions, 1 deletion)
@@ -54,6 +54,7 @@ class BaseDataModuleConfig:
             this value skips the num workers search in
             :func:`.find_good_per_device_num_workers` which is
             not recommended for resource efficiency.
+        shuffle_train_dataset
         shuffle_val_dataset
     """

@@ -62,6 +63,7 @@ class BaseDataModuleConfig:
     max_per_device_batch_size: An[int, ge(1)] | None = None
     fixed_per_device_batch_size: An[int, ge(1)] | None = None
     fixed_per_device_num_workers: An[int, ge(0)] | None = None
+    shuffle_train_dataset: bool = True
     shuffle_val_dataset: bool = True
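
Both flags default to True, so existing configurations keep their previous behavior; training-set shuffling can now be disabled explicitly. A minimal usage sketch (hypothetical: the real config may carry additional fields not visible in this hunk):

    # Hypothetical instantiation; only shuffle_train_dataset is the
    # point of interest, the other fields keep their defaults from above.
    config = BaseDataModuleConfig(shuffle_train_dataset=False)
    # Validation batches remain shuffled (shuffle_val_dataset=True),
    # while training batches now arrive in dataset order.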


@@ -174,7 +176,10 @@ def train_dataloader(self: "BaseDataModule") -> DataLoader[Tensor]:
             A new training
             :class:`torch.utils.data.DataLoader` instance.
         """
-        return self.x_dataloader(dataset=self.datasets.train)
+        return self.x_dataloader(
+            dataset=self.datasets.train,
+            shuffle=self.config.shuffle_train_dataset,
+        )

     def val_dataloader(self: "BaseDataModule") -> DataLoader[Tensor]:
         """Calls :meth:`x_dataloader` w/ :attr:`datasets` ``.val``.
