From d40fffec26c44a35685f8396d84ca2dcbc5430c3 Mon Sep 17 00:00:00 2001 From: Maximilien Le Clei Date: Thu, 20 Feb 2025 23:57:53 +0000 Subject: [PATCH] add param to not shuffle train dataloader --- cneuromax/fitting/deeplearning/datamodule/base.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/cneuromax/fitting/deeplearning/datamodule/base.py b/cneuromax/fitting/deeplearning/datamodule/base.py index fd8387bf..bb5d2b74 100644 --- a/cneuromax/fitting/deeplearning/datamodule/base.py +++ b/cneuromax/fitting/deeplearning/datamodule/base.py @@ -54,6 +54,7 @@ class BaseDataModuleConfig: this value skips the num workers search in :func:`.find_good_per_device_num_workers` which is not recommended for resource efficiency. + shuffle_train_dataset shuffle_val_dataset """ @@ -62,6 +63,7 @@ class BaseDataModuleConfig: max_per_device_batch_size: An[int, ge(1)] | None = None fixed_per_device_batch_size: An[int, ge(1)] | None = None fixed_per_device_num_workers: An[int, ge(0)] | None = None + shuffle_train_dataset: bool = True shuffle_val_dataset: bool = True @@ -174,7 +176,10 @@ def train_dataloader(self: "BaseDataModule") -> DataLoader[Tensor]: A new training :class:`torch.utils.data.DataLoader` instance. """ - return self.x_dataloader(dataset=self.datasets.train) + return self.x_dataloader( + dataset=self.datasets.train, + shuffle=self.config.shuffle_train_dataset, + ) def val_dataloader(self: "BaseDataModule") -> DataLoader[Tensor]: """Calls :meth:`x_dataloader` w/ :attr:`datasets` ``.val``.