Skip to content

Commit

Permalink
Rename ask->ask_trials / _ask->ask
Browse files Browse the repository at this point in the history
  • Loading branch information
RemiLehe committed Jun 26, 2024
1 parent 110b48d commit d1b7721
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion doc/source/examples/ps_line_sampling.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ where :math:`x_0` and :math:`x_1` have a default values of :math:`5` and

all_trials = []
while True:
trial = gen.ask(1)
trial = gen.ask_trials(1)
if trial:
all_trials.append(trial[0])
else:
Expand Down
2 changes: 1 addition & 1 deletion doc/source/examples/ps_random_sampling.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ distribution such as:

all_trials = []
while len(all_trials) <= 100:
trial = gen.ask(1)
trial = gen.ask_trials(1)
if trial:
all_trials.append(trial[0])
else:
Expand Down
4 changes: 2 additions & 2 deletions optimas/gen_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def persistent_generator(H, persis_info, gen_specs, libE_info):
# Store this information in the format expected by libE
H_o = np.zeros(number_of_gen_points, dtype=gen_specs["out"])
for i in range(number_of_gen_points):
generated_trials = generator.ask(1)
generated_trials = generator.ask_trials(1)
if generated_trials:
trial = generated_trials[0]
for var, val in zip(
Expand Down Expand Up @@ -110,7 +110,7 @@ def persistent_generator(H, persis_info, gen_specs, libE_info):
ev = Evaluation(parameter=par, value=y)
trial.complete_evaluation(ev)
# Register trial with unknown SEM
generator.tell([trial])
generator.tell_trials([trial])
# Set the number of points to generate to that number:
number_of_gen_points = min(n + n_failed_gens, max_evals - n_gens)
n_failed_gens = 0
Expand Down
4 changes: 2 additions & 2 deletions optimas/generators/ax/service/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def model(self) -> AxModelManager:
"""Get access to the underlying model using an `AxModelManager`."""
return self._model

def _ask(self, trials: List[Trial]) -> List[Trial]:
def ask(self, trials: List[Trial]) -> List[Trial]:
"""Fill in the parameter values of the requested trials."""
for trial in trials:
parameters, trial_id = self._ax_client.get_next_trial(
Expand All @@ -154,7 +154,7 @@ def _ask(self, trials: List[Trial]) -> List[Trial]:
trial.ax_trial_id = trial_id
return trials

def _tell(self, trials: List[Trial]) -> None:
def tell(self, trials: List[Trial]) -> None:
"""Incorporate evaluated trials into Ax client."""
for trial in trials:
try:
Expand Down
12 changes: 6 additions & 6 deletions optimas/generators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def n_evaluated_trials(self) -> int:
n_evaluated += 1
return n_evaluated

def ask(self, n_trials: int) -> List[Trial]:
def ask_trials(self, n_trials: int) -> List[Trial]:
"""Ask the generator to suggest the next ``n_trials`` to evaluate.
Parameters
Expand Down Expand Up @@ -215,7 +215,7 @@ def ask(self, n_trials: int) -> List[Trial]:
)
)
# Ask the generator to fill them.
gen_trials = self._ask(gen_trials)
gen_trials = self.ask(gen_trials)
# Keep only trials that have been given data.
for trial in gen_trials:
if len(trial.parameter_values) > 0:
Expand All @@ -236,7 +236,7 @@ def ask(self, n_trials: int) -> List[Trial]:
trials.append(trial)
return trials

def tell(
def tell_trials(
self, trials: List[Trial], allow_saving_model: Optional[bool] = True
) -> None:
"""Give trials back to generator once they have been evaluated.
Expand All @@ -250,7 +250,7 @@ def tell(
incorporating the evaluated trials. By default ``True``.
"""
self._tell(trials)
self.tell(trials)
for trial in trials:
if trial not in self._given_trials:
self._add_external_evaluated_trial(trial)
Expand Down Expand Up @@ -290,7 +290,7 @@ def incorporate_history(self, history: np.ndarray) -> None:
trials = self._create_trials_from_external_data(
history_ended, ignore_unrecognized_parameters=True
)
self.tell(trials, allow_saving_model=False)
self.tell_trials(trials, allow_saving_model=False)
# Communicate to history array whether the trial has been ignored.
for trial in trials:
i = np.where(history["trial_index"] == trial.index)[0][0]
Expand Down Expand Up @@ -578,7 +578,7 @@ def get_libe_specs(self) -> Dict:
libE_specs = {}
return libE_specs

def _ask(self, trials: List[Trial]) -> List[Trial]:
def ask(self, trials: List[Trial]) -> List[Trial]:
"""Ask method to be implemented by the Generator subclasses.
Parameters
Expand Down

0 comments on commit d1b7721

Please sign in to comment.