Skip to content

Commit

Permalink
add torchao (#894)
Browse files Browse the repository at this point in the history
* add `torchao`

* test rm

* debug `trainer.validate`

* debug p2

* debug p3

* make `drop_last` a config arg
  • Loading branch information
MaximilienLC authored Mar 4, 2025
1 parent 0b86d8e commit 7ff6fda
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion cneuromax/fitting/deeplearning/datamodule/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ class BaseDataModuleConfig:
not recommended for resource efficiency.
shuffle_train_dataset
shuffle_val_dataset
drop_last: See
:paramref:`~torch.utils.data.DataLoader.drop_last`.
"""

data_dir: An[str, not_empty()] = "${config.data_dir}"
Expand All @@ -65,6 +67,7 @@ class BaseDataModuleConfig:
fixed_per_device_num_workers: An[int, ge(0)] | None = None
shuffle_train_dataset: bool = True
shuffle_val_dataset: bool = True
drop_last: bool = False


class BaseDataModule(LightningDataModule, ABC):
Expand Down Expand Up @@ -166,7 +169,7 @@ def x_dataloader(
num_workers=self.per_device_num_workers,
collate_fn=self.collate_fn,
pin_memory=self.pin_memory,
drop_last=True,
drop_last=self.config.drop_last,
)

@final
Expand Down

0 comments on commit 7ff6fda

Please sign in to comment.