Skip to content

Commit

Permalink
fix(train): only save checkpoints on main device
Browse files Browse the repository at this point in the history
  • Loading branch information
Lordmau5 authored Apr 21, 2023
1 parent 30a08d5 commit 1aaaac6
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
17 changes: 13 additions & 4 deletions src/so_vits_svc_fork/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,10 +254,20 @@ def stft(
torch.stft = stft

def on_train_end(self) -> None:
if not self.tuning:
self.save_checkpoints(adjust=0)
self.save_checkpoints(adjust=0)

def save_checkpoints(self, adjust=1):
if self.tuning or self.trainer.sanity_checking:
return

# only save checkpoints if we are on the main device
if (
hasattr(self.device, "index")
and self.device.index != None
and self.device.index != 0
):
return

# In `on_train_end`, `current_epoch` is already the actual epoch (not one behind), so that hook must call this with `adjust=0`
current_epoch = self.current_epoch + adjust
total_batch_idx = self.total_batch_idx - 1 + adjust
Expand Down Expand Up @@ -547,5 +557,4 @@ def validation_step(self, batch, batch_idx):
)

def on_validation_end(self) -> None:
if not self.trainer.sanity_checking and not self.tuning:
self.save_checkpoints()
self.save_checkpoints()
5 changes: 3 additions & 2 deletions src/so_vits_svc_fork/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,8 +324,9 @@ def clean_checkpoints(
to_delete_list = list(group_items)[:-n_ckpts_to_keep]

for to_delete in to_delete_list:
LOG.info(f"Removing {to_delete}")
to_delete.unlink()
if to_delete.exists():
LOG.info(f"Removing {to_delete}")
to_delete.unlink()


def latest_checkpoint_path(dir_path: Path | str, regex: str = "G_*.pth") -> Path | None:
Expand Down

0 comments on commit 1aaaac6

Please sign in to comment.