From 92189485f6ee8e063fe5fbd05ce6f0865a6e6559 Mon Sep 17 00:00:00 2001 From: michele-milesi <74559684+michele-milesi@users.noreply.github.com> Date: Tue, 19 Dec 2023 15:10:35 +0100 Subject: [PATCH] Fix/delete old ckpts (#174) * fix: distributed training * fix: typing * release: version 0.5.1 --- pyproject.toml | 2 +- sheeprl/__init__.py | 2 +- sheeprl/utils/callback.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9441d342..80c03e6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -196,7 +196,7 @@ warn_return_any = false [tool.bumpver] -current_version = "0.5.0" +current_version = "0.5.1" version_pattern = "MAJOR.MINOR.PATCH[PYTAGNUM]" commit_message = "bump version {old_version} -> {new_version}" tag_message = "v{new_version}" diff --git a/sheeprl/__init__.py b/sheeprl/__init__.py index a5a2bd8a..c81a990b 100644 --- a/sheeprl/__init__.py +++ b/sheeprl/__init__.py @@ -50,7 +50,7 @@ np.int = np.int64 np.bool = bool -__version__ = "0.5.0" +__version__ = "0.5.1" # Replace `moviepy.decorators.use_clip_fps_by_default` method to work with python 3.8, 3.9, and 3.10 diff --git a/sheeprl/utils/callback.py b/sheeprl/utils/callback.py index c40bdfd9..577b1561 100644 --- a/sheeprl/utils/callback.py +++ b/sheeprl/utils/callback.py @@ -52,7 +52,7 @@ def on_checkpoint_coupled( fabric.save(ckpt_path, state) if replay_buffer is not None: self._experiment_consistent_rb(replay_buffer, rb_state) - if self.keep_last: + if fabric.is_global_zero and self.keep_last: self._delete_old_checkpoints(pathlib.Path(ckpt_path).parent) def on_checkpoint_player( @@ -71,7 +71,7 @@ def on_checkpoint_player( fabric.save(ckpt_path, state) if replay_buffer is not None: self._experiment_consistent_rb(replay_buffer, rb_state) - if self.keep_last: + if fabric.is_global_zero and self.keep_last: self._delete_old_checkpoints(pathlib.Path(ckpt_path).parent) def on_checkpoint_trainer( @@ -138,7 +138,7 @@ def _experiment_consistent_rb( # reinsert the open episodes to continue the training rb._open_episodes = state - def _delete_old_checkpoints(self, ckpt_folder: str | pathlib.Path): + def _delete_old_checkpoints(self, ckpt_folder: pathlib.Path): ckpts = list(sorted(ckpt_folder.glob("*.ckpt"), key=os.path.getmtime)) if len(ckpts) > self.keep_last: to_delete = ckpts[: -self.keep_last]