diff --git a/python/ray/tune/examples/bohb_example.py b/python/ray/tune/examples/bohb_example.py index 4d3e14572c80b..ecefb8329e588 100644 --- a/python/ray/tune/examples/bohb_example.py +++ b/python/ray/tune/examples/bohb_example.py @@ -70,10 +70,11 @@ def load_checkpoint(self, checkpoint_path): # CS.CategoricalHyperparameter( # "activation", choices=["relu", "tanh"])) + max_iterations = 10 bohb_hyperband = HyperBandForBOHB( time_attr="training_iteration", - max_t=100, - reduction_factor=4, + max_t=max_iterations, + reduction_factor=2, stop_last_trials=False, ) @@ -84,13 +85,15 @@ def load_checkpoint(self, checkpoint_path): tuner = tune.Tuner( MyTrainableClass, - run_config=air.RunConfig(name="bohb_test", stop={"training_iteration": 100}), + run_config=air.RunConfig( + name="bohb_test", stop={"training_iteration": max_iterations} + ), tune_config=tune.TuneConfig( metric="episode_reward_mean", mode="max", scheduler=bohb_hyperband, search_alg=bohb_search, - num_samples=10, + num_samples=32, ), param_space=config, ) diff --git a/python/ray/tune/schedulers/hb_bohb.py b/python/ray/tune/schedulers/hb_bohb.py index 874d2eeae9a01..f1edbad9665c8 100644 --- a/python/ray/tune/schedulers/hb_bohb.py +++ b/python/ray/tune/schedulers/hb_bohb.py @@ -112,6 +112,8 @@ def on_trial_result( trial_runner._search_alg.searcher.on_pause(trial.trial_id) return TrialScheduler.PAUSE action = self._process_bracket(trial_runner, bracket) + if action == TrialScheduler.PAUSE: + trial_runner._search_alg.searcher.on_pause(trial.trial_id) return action def _unpause_trial(self, trial_runner: "trial_runner.TrialRunner", trial: Trial): diff --git a/python/ray/tune/schedulers/hyperband.py b/python/ray/tune/schedulers/hyperband.py index a9cb803fbcb87..3e50d272df8e4 100644 --- a/python/ray/tune/schedulers/hyperband.py +++ b/python/ray/tune/schedulers/hyperband.py @@ -266,16 +266,19 @@ def _process_bracket( # ready the good trials - if trial is too far ahead, don't continue for t in good: - if t.status not in [Trial.PAUSED, Trial.RUNNING]: - raise TuneError( - f"Trial with unexpected good status encountered: {t.status}" - ) if bracket.continue_trial(t): + # The scheduler should have cleaned up this trial already. + assert t.status not in (Trial.ERROR, Trial.TERMINATED), ( + f"Good trial {t.trial_id} is in an invalid state: {t.status}\n" + "Expected trial to be either PAUSED, PENDING, or RUNNING.\n" + "If you encounter this, please file an issue on the Ray Github." + ) if t.status == Trial.PAUSED: self._unpause_trial(trial_runner, t) trial_runner._set_trial_status(t, Trial.PENDING) elif t.status == Trial.RUNNING: action = TrialScheduler.CONTINUE + # else: PENDING trial (from a previous unpause) should stay as is. return action def _unpause_trial(self, trial_runner: "trial_runner.TrialRunner", trial: Trial): diff --git a/python/ray/tune/search/bohb/bohb_search.py b/python/ray/tune/search/bohb/bohb_search.py index f2fafbc3cdb96..ee1c44868369b 100644 --- a/python/ray/tune/search/bohb/bohb_search.py +++ b/python/ray/tune/search/bohb/bohb_search.py @@ -272,10 +272,10 @@ def to_wrapper(self, trial_id: str, result: Dict) -> _BOHBJobWrapper: # TODO(team-ml): Refactor alongside HyperBandForBOHB def on_pause(self, trial_id: str): self.paused.add(trial_id) - self.running.remove(trial_id) + self.running.discard(trial_id) def on_unpause(self, trial_id: str): - self.paused.remove(trial_id) + self.paused.discard(trial_id) self.running.add(trial_id) @staticmethod