Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tune] [PBT] [Doc] Add example PBT notebook #28519

Merged
merged 23 commits into from
Oct 4, 2022
Merged
Changes from 1 commit
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
d4ffcde
Add better PBT logging for exploit, explore
justinvyu Sep 3, 2022
d227adc
Simplify perturb logic
justinvyu Sep 6, 2022
dfef665
Add operation tracking and logging for PBT perturbs
justinvyu Sep 7, 2022
de064ab
[Debug] Temporary fix for PBT checkpoint setting and loading
justinvyu Sep 3, 2022
b07f4aa
Add forced checkpoint logic for PBT
justinvyu Sep 7, 2022
aff2bb4
Add example notebook walking through paper toy example
justinvyu Sep 8, 2022
ea54f45
Add animation, 4 trial experiment, and more explanations to notebook
justinvyu Sep 9, 2022
421dd03
Add 4 trial gif, separate make_animation function
justinvyu Sep 9, 2022
6077337
Add open in colab button
justinvyu Sep 9, 2022
8521074
Add __init__.py to pbt_visualization doc module
justinvyu Sep 13, 2022
f08ccfe
Merge branch 'master' of https://github.com/ray-project/ray into pbt_…
justinvyu Sep 14, 2022
21eac42
Rerun 2 trial PBT to generate a better visual
justinvyu Sep 14, 2022
0bdd8c8
Clean-up tune examples TOC into sub-sections + add PBT notebook into TOC
justinvyu Sep 14, 2022
9726837
Add test for PBT mutations logging + fix for empty `hyperparam_mutati…
justinvyu Sep 27, 2022
550fd44
Fix some wording in the notebook
justinvyu Sep 27, 2022
1aa30c3
Merge branch 'master' of https://github.com/ray-project/ray into pbt_…
justinvyu Sep 27, 2022
d5f6357
Revert commits related to forced checkpoint
justinvyu Sep 27, 2022
b9763aa
Fix failing `_exploit` tests
justinvyu Sep 28, 2022
03e6d17
Improve pbt example notebook explanations (mention async behavior), a…
justinvyu Sep 29, 2022
f965de7
Clean up references to the pbt example notebook
justinvyu Sep 29, 2022
df6940b
Remove missing reference to utils file
justinvyu Sep 29, 2022
65ecf6d
Improve documentation and typing hints + fix shift noop case
justinvyu Sep 30, 2022
755bebf
Add assertion for matching keys in summarize_hyperparam_changes
justinvyu Sep 30, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add forced checkpoint logic for PBT
Signed-off-by: Justin Yu <justinvyu@berkeley.edu>
justinvyu committed Sep 8, 2022
commit b07f4aaa360ade6be36879255256de144cd646a3
5 changes: 5 additions & 0 deletions python/ray/air/_internal/checkpoint_manager.py
Original file line number Diff line number Diff line change
@@ -72,6 +72,8 @@ def __init__(
self.metrics = flatten_dict(metrics) if metrics else {}
self.node_ip = node_ip or self.metrics.get(NODE_IP, None)

self.forced = False

if (
dir_or_data is not None
and storage_mode == CheckpointStorage.MEMORY
@@ -82,6 +84,9 @@ def __init__(
f"as their data. Got: {dir_or_data}"
)

def set_forced_bit(self, val: bool):
self.forced = val

def commit(self, path: Optional[Path] = None) -> None:
"""Commit checkpoint to disk, if needed.

22 changes: 10 additions & 12 deletions python/ray/tune/execution/checkpoint_manager.py
Original file line number Diff line number Diff line change
@@ -72,8 +72,10 @@ def handle_checkpoint(self, checkpoint: _TrackedCheckpoint):
)
self._process_persistent_checkpoint(checkpoint)

def on_checkpoint(self, checkpoint: _TrackedCheckpoint):
def on_checkpoint(self, checkpoint: _TrackedCheckpoint, force: bool = False):
"""Ray Tune's entrypoint"""
if force:
checkpoint.set_forced_bit(True)
# Todo (krfricke): Replace with handle_checkpoint.
self.handle_checkpoint(checkpoint)

@@ -100,17 +102,13 @@ def newest_persistent_checkpoint(self):
@property
def newest_checkpoint(self):
"""Returns the newest checkpoint (based on training iteration)."""
# NOTE: For PBT Debugging purposes only.
# TODO(justinvyu): Remove this.
# print("[DEBUGGING] newest_memory_checkpoint.id = ", self.newest_memory_checkpoint.id)
# print("[DEBUGGING] newest_persistent_checkpoint.id = ", self.newest_persistent_checkpoint.id)

# newest_checkpoint = max(
# [self.newest_memory_checkpoint, self.newest_persistent_checkpoint],
# key=lambda c: c.id,
# )
# return newest_checkpoint
return self.newest_memory_checkpoint
checkpoints = [self.newest_memory_checkpoint, self.newest_persistent_checkpoint]
# Always prefer forced checkpoints
# If multiple are forced, take the one with the highest checkpoint id
checkpoints.sort(key=lambda c: c.id)
checkpoints.sort(key=lambda c: c.forced)
newest_checkpoint = checkpoints[-1]
return newest_checkpoint

@property
def newest_memory_checkpoint(self):
4 changes: 2 additions & 2 deletions python/ray/tune/experiment/trial.py
Original file line number Diff line number Diff line change
@@ -683,13 +683,13 @@ def clear_checkpoint(self):
self.restoring_from = None
self.invalidate_json_state()

def on_checkpoint(self, checkpoint: _TrackedCheckpoint):
def on_checkpoint(self, checkpoint: _TrackedCheckpoint, force: bool = False):
"""Hook for handling checkpoints taken by the Trainable.

Args:
checkpoint: Checkpoint taken.
"""
self.checkpoint_manager.on_checkpoint(checkpoint)
self.checkpoint_manager.on_checkpoint(checkpoint, force=force)
self.invalidate_json_state()

def on_restore(self):
5 changes: 4 additions & 1 deletion python/ray/tune/schedulers/pbt.py
Original file line number Diff line number Diff line change
@@ -692,7 +692,10 @@ def _exploit(
trial_executor.set_status(trial, Trial.PAUSED)
trial.set_experiment_tag(new_tag)
trial.set_config(new_config)
trial.on_checkpoint(new_state.last_checkpoint)
# Create a shallow copy to keep forced bit modifications isolated, while
# avoiding copying the underlying `dir_or_data`
exploit_checkpoint = copy.copy(new_state.last_checkpoint)
trial.on_checkpoint(exploit_checkpoint, force=True)

self._num_perturbations += 1
# Transfer over the last perturbation time as well