diff --git a/doc/source/tune-searchalg.rst b/doc/source/tune-searchalg.rst index 0a2bf491a6769..2dae8eaf4abed 100644 --- a/doc/source/tune-searchalg.rst +++ b/doc/source/tune-searchalg.rst @@ -171,7 +171,9 @@ This algorithm requires specifying a search space and objective. You can use `Ax .. code-block:: python - tune.run(... , search_alg=AxSearch(parameter_dicts, ... )) + client = AxClient(enforce_sequential_optimization=False) + client.create_experiment( ... ) + tune.run(... , search_alg=AxSearch(client)) An example of this can be found in `ax_example.py `__. diff --git a/python/ray/tune/examples/ax_example.py b/python/ray/tune/examples/ax_example.py index 07bb7f79a1f32..8620986a26ea9 100644 --- a/python/ray/tune/examples/ax_example.py +++ b/python/ray/tune/examples/ax_example.py @@ -51,11 +51,13 @@ def easy_objective(config, reporter): if __name__ == "__main__": import argparse + from ax.service.ax_client import AxClient parser = argparse.ArgumentParser() parser.add_argument( "--smoke-test", action="store_true", help="Finish quickly for testing") args, _ = parser.parse_known_args() + ray.init() config = { @@ -101,13 +103,14 @@ def easy_objective(config, reporter): "bounds": [0.0, 1.0], }, ] - algo = AxSearch( + client = AxClient(enforce_sequential_optimization=False) + client.create_experiment( parameters=parameters, objective_name="hartmann6", - max_concurrent=4, minimize=True, # Optional, defaults to False. parameter_constraints=["x1 + x2 <= 2.0"], # Optional. outcome_constraints=["l2norm <= 1.25"], # Optional. ) + algo = AxSearch(client, max_concurrent=4) scheduler = AsyncHyperBandScheduler(reward_attr="hartmann6") run(easy_objective, name="ax", search_alg=algo, **config) diff --git a/python/ray/tune/suggest/ax.py b/python/ray/tune/suggest/ax.py index a48852e848643..75b982d670871 100644 --- a/python/ray/tune/suggest/ax.py +++ b/python/ray/tune/suggest/ax.py @@ -6,16 +6,19 @@ import ax except ImportError: ax = None +import logging from ray.tune.suggest.suggestion import SuggestionAlgorithm +logger = logging.getLogger(__name__) + class AxSearch(SuggestionAlgorithm): """A wrapper around Ax to provide trial suggestions. - Requires Ax to be installed. - Ax is an open source tool from Facebook for configuring and - optimizing experiments. More information can be found in https://ax.dev/. + Requires Ax to be installed. Ax is an open source tool from + Facebook for configuring and optimizing experiments. More information + can be found in https://ax.dev/. Parameters: parameters (list[dict]): Parameters in the experiment search space. @@ -48,40 +51,27 @@ class AxSearch(SuggestionAlgorithm): >>> objective_name="hartmann6", max_concurrent=4) """ - def __init__(self, - parameters, - objective_name, - max_concurrent=10, - minimize=False, - parameter_constraints=None, - outcome_constraints=None, - **kwargs): + def __init__(self, ax_client, max_concurrent=10, **kwargs): assert ax is not None, "Ax must be installed!" - from ax.service import ax_client assert type(max_concurrent) is int and max_concurrent > 0 - self._ax = ax_client.AxClient(enforce_sequential_optimization=False) - self._ax.create_experiment( - name="ax", - parameters=parameters, - objective_name=objective_name, - minimize=minimize, - parameter_constraints=parameter_constraints or [], - outcome_constraints=outcome_constraints or [], - ) + self._ax = ax_client + exp = self._ax.experiment + self._objective_name = exp.optimization_config.objective.metric.name + if self._ax._enforce_sequential_optimization: + logger.warning("Detected sequential enforcement. Setting max " + "concurrency to 1.") + max_concurrent = 1 self._max_concurrent = max_concurrent - self._parameters = [d["name"] for d in parameters] - self._objective_name = objective_name + self._parameters = list(exp.parameters) self._live_index_mapping = {} - super(AxSearch, self).__init__(**kwargs) def _suggest(self, trial_id): if self._num_live_trials() >= self._max_concurrent: return None parameters, trial_index = self._ax.get_next_trial() - suggested_config = list(parameters.values()) self._live_index_mapping[trial_id] = trial_index - return dict(zip(self._parameters, suggested_config)) + return parameters def on_trial_result(self, trial_id, result): pass