Skip to content

Commit

Permalink
Fix/delete old ckpts (#174)
Browse files Browse the repository at this point in the history
* fix: distributed training

* fix: typing

* release: version 0.5.1
  • Loading branch information
michele-milesi authored Dec 19, 2023
1 parent 0453284 commit 9218948
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ warn_return_any = false


[tool.bumpver]
current_version = "0.5.0"
current_version = "0.5.1"
version_pattern = "MAJOR.MINOR.PATCH[PYTAGNUM]"
commit_message = "bump version {old_version} -> {new_version}"
tag_message = "v{new_version}"
Expand Down
2 changes: 1 addition & 1 deletion sheeprl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
np.int = np.int64
np.bool = bool

__version__ = "0.5.0"
__version__ = "0.5.1"


# Replace `moviepy.decorators.use_clip_fps_by_default` method to work with python 3.8, 3.9, and 3.10
Expand Down
6 changes: 3 additions & 3 deletions sheeprl/utils/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def on_checkpoint_coupled(
fabric.save(ckpt_path, state)
if replay_buffer is not None:
self._experiment_consistent_rb(replay_buffer, rb_state)
if self.keep_last:
if fabric.is_global_zero and self.keep_last:
self._delete_old_checkpoints(pathlib.Path(ckpt_path).parent)

def on_checkpoint_player(
Expand All @@ -71,7 +71,7 @@ def on_checkpoint_player(
fabric.save(ckpt_path, state)
if replay_buffer is not None:
self._experiment_consistent_rb(replay_buffer, rb_state)
if self.keep_last:
if fabric.is_global_zero and self.keep_last:
self._delete_old_checkpoints(pathlib.Path(ckpt_path).parent)

def on_checkpoint_trainer(
Expand Down Expand Up @@ -138,7 +138,7 @@ def _experiment_consistent_rb(
# reinsert the open episodes to continue the training
rb._open_episodes = state

def _delete_old_checkpoints(self, ckpt_folder: str | pathlib.Path):
def _delete_old_checkpoints(self, ckpt_folder: pathlib.Path):
ckpts = list(sorted(ckpt_folder.glob("*.ckpt"), key=os.path.getmtime))
if len(ckpts) > self.keep_last:
to_delete = ckpts[: -self.keep_last]
Expand Down

0 comments on commit 9218948

Please sign in to comment.