From 1aaaac6328476249371799b92ced3edcbaac8d18 Mon Sep 17 00:00:00 2001 From: Lordmau5 Date: Fri, 21 Apr 2023 23:25:27 +0200 Subject: [PATCH] fix(train): only save checkpoints on main device --- src/so_vits_svc_fork/train.py | 17 +++++++++++++---- src/so_vits_svc_fork/utils.py | 5 +++-- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/so_vits_svc_fork/train.py b/src/so_vits_svc_fork/train.py index 925f1ee4..e2fe9be8 100644 --- a/src/so_vits_svc_fork/train.py +++ b/src/so_vits_svc_fork/train.py @@ -254,10 +254,20 @@ def stft( torch.stft = stft def on_train_end(self) -> None: - if not self.tuning: - self.save_checkpoints(adjust=0) + self.save_checkpoints(adjust=0) def save_checkpoints(self, adjust=1): + if self.tuning or self.trainer.sanity_checking: + return + + # only save checkpoints if we are on the main device + if ( + hasattr(self.device, "index") + and self.device.index != None + and self.device.index != 0 + ): + return + # `on_train_end` will be the actual epoch, not a -1, so we have to call it with `adjust = 0` current_epoch = self.current_epoch + adjust total_batch_idx = self.total_batch_idx - 1 + adjust @@ -547,5 +557,4 @@ def validation_step(self, batch, batch_idx): ) def on_validation_end(self) -> None: - if not self.trainer.sanity_checking and not self.tuning: - self.save_checkpoints() + self.save_checkpoints() diff --git a/src/so_vits_svc_fork/utils.py b/src/so_vits_svc_fork/utils.py index 36d512e8..c5e08f6d 100644 --- a/src/so_vits_svc_fork/utils.py +++ b/src/so_vits_svc_fork/utils.py @@ -324,8 +324,9 @@ def clean_checkpoints( to_delete_list = list(group_items)[:-n_ckpts_to_keep] for to_delete in to_delete_list: - LOG.info(f"Removing {to_delete}") - to_delete.unlink() + if to_delete.exists(): + LOG.info(f"Removing {to_delete}") + to_delete.unlink() def latest_checkpoint_path(dir_path: Path | str, regex: str = "G_*.pth") -> Path | None: