Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tune] [PBT] [Doc] Add example PBT notebook #28519

Merged
merged 23 commits into from
Oct 4, 2022
Merged
Changes from 1 commit
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
d4ffcde
Add better PBT logging for exploit, explore
justinvyu Sep 3, 2022
d227adc
Simplify perturb logic
justinvyu Sep 6, 2022
dfef665
Add operation tracking and logging for PBT perturbs
justinvyu Sep 7, 2022
de064ab
[Debug] Temporary fix for PBT checkpoint setting and loading
justinvyu Sep 3, 2022
b07f4aa
Add forced checkpoint logic for PBT
justinvyu Sep 7, 2022
aff2bb4
Add example notebook walking through paper toy example
justinvyu Sep 8, 2022
ea54f45
Add animation, 4 trial experiment, and more explanations to notebook
justinvyu Sep 9, 2022
421dd03
Add 4 trial gif, separate make_animation function
justinvyu Sep 9, 2022
6077337
Add open in colab button
justinvyu Sep 9, 2022
8521074
Add __init__.py to pbt_visualization doc module
justinvyu Sep 13, 2022
f08ccfe
Merge branch 'master' of https://github.com/ray-project/ray into pbt_…
justinvyu Sep 14, 2022
21eac42
Rerun 2 trial PBT to generate a better visual
justinvyu Sep 14, 2022
0bdd8c8
Clean-up tune examples TOC into sub-sections + add PBT notebook into TOC
justinvyu Sep 14, 2022
9726837
Add test for PBT mutations logging + fix for empty `hyperparam_mutati…
justinvyu Sep 27, 2022
550fd44
Fix some wording in the notebook
justinvyu Sep 27, 2022
1aa30c3
Merge branch 'master' of https://github.com/ray-project/ray into pbt_…
justinvyu Sep 27, 2022
d5f6357
Revert commits related to forced checkpoint
justinvyu Sep 27, 2022
b9763aa
Fix failing `_exploit` tests
justinvyu Sep 28, 2022
03e6d17
Improve pbt example notebook explanations (mention async behavior), a…
justinvyu Sep 29, 2022
f965de7
Clean up references to the pbt example notebook
justinvyu Sep 29, 2022
df6940b
Remove missing reference to utils file
justinvyu Sep 29, 2022
65ecf6d
Improve documentation and typing hints + fix shift noop case
justinvyu Sep 30, 2022
755bebf
Add assertion for matching keys in summarize_hyperparam_changes
justinvyu Sep 30, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Improve documentation and typing hints + fix shift noop case
Signed-off-by: Justin Yu <justinvyu@berkeley.edu>
justinvyu committed Sep 30, 2022
commit 65ecf6d24da764750de154f42f48995b73ae5887
80 changes: 62 additions & 18 deletions python/ray/tune/schedulers/pbt.py
Original file line number Diff line number Diff line change
@@ -52,7 +52,8 @@ def _explore(
perturbation_factors: Tuple[float],
custom_explore_fn: Optional[Callable],
) -> Tuple[Dict, Dict]:
"""Return a config perturbed as specified.
"""Return a perturbed config and string descriptors of the operations performed
on the original config to produce the new config.

Args:
config: Original hyperparameter configuration.
@@ -73,6 +74,7 @@ def _explore(
new_config = copy.deepcopy(config)
for key, distribution in mutations.items():
if isinstance(distribution, dict):
# Handle nested hyperparameter configs by recursively perturbing them
nested_new_config, nested_ops = _explore(
config[key],
mutations[key],
@@ -82,34 +84,55 @@ def _explore(
)
new_config.update({key: nested_new_config})
operations.update({key: nested_ops})
elif isinstance(distribution, list):
elif isinstance(distribution, (list, tuple)):
# Case 1: Hyperparameter resample distribution is a list/tuple
if (
random.random() < resample_probability
or config[key] not in distribution
):
# Resample a value from the list with `resample_probability`
new_config[key] = random.choice(distribution)
operations[key] = "resample"
else:
# Otherwise, perturb by shifting to the left or right of the list
shift = random.choice([-1, 1])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a comment here that explains what we're doing?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally, let's add a few comments for this whole exploration block

new_idx = distribution.index(config[key]) + shift
old_idx = distribution.index(config[key])
new_idx = old_idx + shift
new_idx = min(max(new_idx, 0), len(distribution) - 1)
new_config[key] = distribution[new_idx]
operations[key] = f"shift {'left' if shift == -1 else 'right'}"
else:
operations[key] = (
f"shift {'left' if shift == -1 else 'right'}"
f"{' (noop)' if old_idx == new_idx else ''}"
)
elif isinstance(distribution, (Domain, Callable)):
# Case 2: Hyperparameter resample distribution is:
# 1. a function (ex: lambda: np.random.uniform(0, 1))
# 2. tune search Domain (ex: tune.uniform(0, 1))
if random.random() < resample_probability:
# Resample a value from the function/domain with `resample_probability`
new_config[key] = (
distribution.sample(None)
if isinstance(distribution, Domain)
else distribution()
)
operations[key] = "resample"
else:
# Otherwise, perturb by multiplying the hyperparameter by one
# of the `perturbation_factors`
perturbation_factor = random.choice(perturbation_factors)
new_config[key] = config[key] * perturbation_factor
operations[key] = f"* {perturbation_factor}"
if isinstance(config[key], int):
# If this hyperparameter started out as an integer (ex: `batch_size`),
# convert the new value back
new_config[key] = int(new_config[key])
else:
raise ValueError(
f"Unsupported hyperparameter distribution type: {type(distribution)}"
)
if custom_explore_fn:
# The user can perform any additional hyperparameter exploration
# via `custom_explore_fn`
new_config = custom_explore_fn(new_config)
assert new_config is not None, "Custom explore fn failed to return new config"
return new_config, operations
@@ -125,14 +148,18 @@ def _make_experiment_tag(orig_tag: str, config: Dict, mutations: Dict) -> str:


def _fill_config(
config: Dict, attr: str, search_space: Union[Callable, Domain, list, dict]
config: Dict, attr: str, search_space: Union[dict, list, tuple, Callable, Domain]
):
"""Add attr to config by sampling from search_space."""
if callable(search_space):
"""Add attr to config by sampling from search_space.

This is a helper used to set initial hyperparameter values if the user doesn't
specify them in the Tuner `param_space`.
"""
if isinstance(search_space, Callable):
config[attr] = search_space()
elif isinstance(search_space, Domain):
config[attr] = search_space.sample(None)
elif isinstance(search_space, list):
elif isinstance(search_space, (list, tuple)):
config[attr] = random.choice(search_space)
elif isinstance(search_space, dict):
config[attr] = {}
@@ -192,11 +219,12 @@ class PopulationBasedTraining(FIFOScheduler):
specifies the distribution of a continuous parameter. You must
use tune.choice, tune.uniform, tune.loguniform, etc. Arbitrary
tune.sample_from objects are not supported.
A key can also hold a dict for nested hyperparameters.
You must specify at least one of `hyperparam_mutations` or
`custom_explore_fn`.
Tune will use the search space provided by
`hyperparam_mutations` for the initial samples if the
corresponding attributes are not present in `config`.
Tune will sample the search space provided by
`hyperparam_mutations` for the initial hyperparameter values if the
corresponding hyperparameters are not present in a trial's initial `config`.
quantile_fraction: Parameters are transferred from the top
`quantile_fraction` fraction of trials to the bottom
`quantile_fraction` fraction. Needs to be between 0 and 0.5.
@@ -271,7 +299,9 @@ def __init__(
mode: Optional[str] = None,
perturbation_interval: float = 60.0,
burn_in_period: float = 0.0,
hyperparam_mutations: Dict = None,
hyperparam_mutations: Dict[
str, Union[dict, list, tuple, Callable, Domain]
] = None,
quantile_fraction: float = 0.25,
resample_probability: float = 0.25,
perturbation_factors: Tuple[float, float] = (1.2, 0.8),
@@ -282,10 +312,10 @@ def __init__(
):
hyperparam_mutations = hyperparam_mutations or {}
for value in hyperparam_mutations.values():
if not (isinstance(value, (list, dict, Domain)) or callable(value)):
if not isinstance(value, (dict, list, tuple, Domain, Callable)):
raise TypeError(
"`hyperparam_mutation` values must be either "
"a List, Dict, a tune search space object, or "
"a List, Tuple, Dict, a tune search space object, or "
"a callable."
)
if isinstance(value, Function):
@@ -614,7 +644,7 @@ def _log_config_on_step(
with open(trial_path, "a+") as f:
f.write(json.dumps(policy, cls=SafeFallbackEncoder) + "\n")

def _get_new_config(self, trial, trial_to_clone) -> Tuple[Dict, Dict]:
def _get_new_config(self, trial: Trial, trial_to_clone: Trial) -> Tuple[Dict, Dict]:
"""Gets new config for trial by exploring trial_to_clone's config."""
return _explore(
trial_to_clone.config,
@@ -625,8 +655,21 @@ def _get_new_config(self, trial, trial_to_clone) -> Tuple[Dict, Dict]:
)

def _summarize_hyperparam_changes(
self, old_params, new_params, operations, prefix=""
):
self, old_params: Dict, new_params: Dict, operations: Dict, prefix: str = ""
) -> str:
"""Generates a summary of hyperparameter changes from a PBT "explore" step.

Args:
old_params: Old values of hyperparameters that are perturbed to generate
the new config
new_params: The newly generated hyperparameter config from PBT exploration
operations: Map of hyperparam -> string describing the mutation operation
performed on it to generate the value in `new_params`
prefix: Helper argument to format nested dict hyperparam configs

Returns:
summary_str: The hyperparameter change summary to print/log.
"""
summary_str = ""
if not old_params:
return summary_str
@@ -635,6 +678,7 @@ def _summarize_hyperparam_changes(
new_val = new_params[param_name]
summary_str += f"{prefix}{param_name} : "
if isinstance(old_val, Dict):
# Handle nested hyperparameters by recursively summarizing
summary_str += "\n"
summary_str += self._summarize_hyperparam_changes(
old_val, new_val, operations[param_name], prefix=prefix + " " * 4
10 changes: 10 additions & 0 deletions python/ray/tune/tests/test_trial_scheduler.py
Original file line number Diff line number Diff line change
@@ -1205,6 +1205,16 @@ def explore_fn(
{3, 4, 8, 10},
)

# Check that tuple also works
assertProduces(lambda: explore_fn({"v": 4}, {"v": (3, 4, 8, 10)}, 0.0), {3, 8})
assertProduces(lambda: explore_fn({"v": 3}, {"v": (3, 4, 8, 10)}, 0.0), {3, 4})

# Passing in invalid types should raise an error
with self.assertRaises(ValueError):
explore_fn({"v": 4}, {"v": {3, 4, 8, 10}}, 0.0)
with self.assertRaises(ValueError):
explore_fn({"v": 4}, {"v": "invalid"}, 0.0)

# Continuous case
assertProduces(
lambda: explore_fn(
14 changes: 13 additions & 1 deletion python/ray/tune/tests/test_trial_scheduler_pbt.py
Original file line number Diff line number Diff line change
@@ -632,10 +632,16 @@ class DummyTrial:
def __init__(self, config):
self.config = config

def test_config(hyperparam_mutations, old_config, print_summary=False):
def test_config(
hyperparam_mutations,
old_config,
resample_probability=0.25,
print_summary=False,
):
scheduler = PopulationBasedTraining(
time_attr="training_iteration",
hyperparam_mutations=hyperparam_mutations,
resample_probability=resample_probability,
)
new_config, operations = scheduler._get_new_config(
None, DummyTrial(old_config)
@@ -691,6 +697,12 @@ def test_config(hyperparam_mutations, old_config, print_summary=False):
] + ["resample"]
assert operations["c"]["e"]["f"] in ["shift left", "shift right", "resample"]

# 4. Test shift that results in noop
hyperparam_mutations = {"a": [1]}
scheduler, new_config, operations = test_config(
hyperparam_mutations, {"a": 1}, resample_probability=0
)
assert operations["a"] in ["shift left (noop)", "shift right (noop)"]

if __name__ == "__main__":
import pytest