dataloader drop_last, BS finder bin_search, smarter wandb media logging
MaximilienLC committed Feb 23, 2025
1 parent e99851b commit 0b86d8e
Showing 3 changed files with 12 additions and 13 deletions.
1 change: 1 addition & 0 deletions cneuromax/fitting/deeplearning/datamodule/base.py
@@ -166,6 +166,7 @@ def x_dataloader(
             num_workers=self.per_device_num_workers,
             collate_fn=self.collate_fn,
             pin_memory=self.pin_memory,
+            drop_last=True,
         )

     @final
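Note on the change above: `drop_last=True` makes the loader discard a final incomplete batch, so every step sees a uniform batch size. A minimal standalone sketch of the behavior, assuming plain PyTorch rather than any code from this repository:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10).float())  # 10 samples

# drop_last=True discards the trailing partial batch of 2 samples.
sizes = [b[0].shape[0] for b in DataLoader(dataset, batch_size=4, drop_last=True)]
print(sizes)  # [4, 4]

# The default drop_last=False keeps it, so the last batch is smaller.
sizes = [b[0].shape[0] for b in DataLoader(dataset, batch_size=4, drop_last=False)]
print(sizes)  # [4, 4, 2]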
22 changes: 10 additions & 12 deletions cneuromax/fitting/deeplearning/litmodule/base.py
@@ -23,7 +23,7 @@ class BaseLitModuleConfig:
     Args:
         wandb_column_names
-        wandb_train_log_interval
+        wandb_train_log_interval: `0` means no logging.
         wandb_num_samples
     """

@@ -135,31 +135,29 @@ def on_load_checkpoint(  # noqa: D102
         self.curr_val_epoch = checkpoint["curr_val_epoch"]
         return super().on_load_checkpoint(checkpoint)

-    def on_train_batch_start(  # noqa: D102
-        self: "BaseLitModule",
-        *args: Any,  # noqa: ANN401
-        **kwargs: Any,  # noqa: ANN401
-    ) -> None:
+    def on_validation_end(self: "BaseLitModule") -> None:  # noqa: D102
         self.wandb_train_data = []
-        super().on_train_batch_start(*args, **kwargs)
-
-    def on_validation_start(self: "BaseLitModule") -> None:  # noqa: D102
         self.wandb_val_data = []
-        super().on_validation_start()
+        super().on_validation_end()

     def optimizer_step(  # noqa: D102
         self: "BaseLitModule",
         *args: Any,  # noqa: ANN401
         **kwargs: Any,  # noqa: ANN401
     ) -> None:
         super().optimizer_step(*args, **kwargs)
-        if self.curr_train_step % self.config.wandb_train_log_interval == 0:
+        if (
+            self.config.wandb_train_log_interval
+            and self.curr_train_step % self.config.wandb_train_log_interval
+            == 0
+        ):
             self.log_table(self.wandb_train_data)
         self.curr_train_step += 1

     def on_validation_epoch_end(self: "BaseLitModule") -> None:  # noqa: D102
         super().on_validation_epoch_end()
-        self.log_table(self.wandb_val_data)
+        if self.config.wandb_train_log_interval:
+            self.log_table(self.wandb_val_data)
         self.curr_val_epoch += 1

     def update_wandb_data_before_log(
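Note on the `optimizer_step` and `on_validation_epoch_end` changes above: `step % 0` raises ZeroDivisionError in Python, so the truthiness check on `wandb_train_log_interval` must run first; a `0` interval then cleanly disables both the train and validation table logging. A standalone sketch of the guard pattern, using a hypothetical helper name rather than the repository's code:

def should_log(step: int, interval: int) -> bool:
    # interval == 0 means logging is disabled; short-circuiting also
    # avoids evaluating `step % 0`, which raises ZeroDivisionError.
    return bool(interval) and step % interval == 0

assert should_log(0, 5)        # step 0 logs when the interval is 5
assert not should_log(3, 5)    # off-interval step
assert not should_log(3, 0)    # interval 0: never logs, never crashes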
2 changes: 1 addition & 1 deletion cneuromax/fitting/deeplearning/utils/lightning.py
@@ -220,7 +220,7 @@ def find_good_per_device_batch_size(
         launcher_config.cpus_per_task
     )
     batch_size_finder = BatchSizeFinder(
-        mode="power",
+        mode="binsearch",
         batch_arg_name="per_device_batch_size",
         max_trials=int(math.log2(max_per_device_batch_size)),
     )
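Note on the mode change above: Lightning's `BatchSizeFinder` in "power" mode doubles the batch size until a trial fails and keeps the last working value, while "binsearch" then also binary-searches between the last success and the first failure, usually settling on a larger usable batch size. A minimal usage sketch, assuming a standard `lightning` install; the model and datamodule are placeholders:

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import BatchSizeFinder

# After the doubling phase hits a failure, "binsearch" narrows the
# search between the last successful and the first failing size.
finder = BatchSizeFinder(
    mode="binsearch",
    batch_arg_name="per_device_batch_size",  # attribute the tuner scales
    max_trials=10,                           # caps the doubling phase
)
trainer = Trainer(callbacks=[finder], max_epochs=1)
# trainer.fit(model, datamodule=datamodule)  # supply real objects here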
