From aee9e17444372d63aea6b2a03d025b69dd5a32e5 Mon Sep 17 00:00:00 2001
From: Bas Nijholt
Date: Sun, 2 Jun 2024 17:49:05 -0700
Subject: [PATCH] Implement using generator

---
 adaptive/learner/balancing_learner.py | 58 ++++++++++++++++-----------
 1 file changed, 35 insertions(+), 23 deletions(-)

diff --git a/adaptive/learner/balancing_learner.py b/adaptive/learner/balancing_learner.py
index 07512f5b..78fe59fd 100644
--- a/adaptive/learner/balancing_learner.py
+++ b/adaptive/learner/balancing_learner.py
@@ -3,7 +3,7 @@
 import itertools
 import sys
 from collections import defaultdict
-from collections.abc import Iterable, Sequence
+from collections.abc import Generator, Iterable, Sequence
 from contextlib import suppress
 from functools import partial
 from operator import itemgetter
@@ -126,11 +126,10 @@ def __init__(
         self._cdims_default = cdims

         if len({learner.__class__ for learner in self.learners}) > 1:
-            raise TypeError(
-                "A BalacingLearner can handle only one type" " of learners."
-            )
+            raise TypeError("A BalacingLearner can handle only one type of learners.")

         self.strategy: STRATEGY_TYPE = strategy
+        self._gen: Generator | None = None

     def new(self) -> BalancingLearner:
         """Create a new `BalancingLearner` with the same parameters."""
@@ -288,27 +287,16 @@ def _ask_and_tell_based_on_cycle(
     def _ask_and_tell_based_on_sequential(
         self, n: int
     ) -> tuple[list[tuple[Int, Any]], list[float]]:
+        if self._gen is None:
+            self._gen = _sequential_generator(self.learners)
         points: list[tuple[Int, Any]] = []
         loss_improvements: list[float] = []
-        learner_index = 0
-
-        while len(points) < n:
-            learner = self.learners[learner_index]
-            if learner.done():  # type: ignore[attr-defined]
-                if learner_index == len(self.learners) - 1:
-                    break
-                learner_index += 1
-                continue
-
-            point, loss_improvement = learner.ask(n=1)
-            if not point:  # if learner is exhausted, we don't get points
-                if learner_index == len(self.learners) - 1:
-                    break
-                learner_index += 1
-                continue
-            points.append((learner_index, point[0]))
-            loss_improvements.append(loss_improvement[0])
-            self.tell_pending((learner_index, point[0]))
+        for learner_index, point, loss_improvement in self._gen:
+            points.append((learner_index, point))
+            loss_improvements.append(loss_improvement)
+            self.tell_pending((learner_index, point))
+            if len(points) >= n:
+                break

         return points, loss_improvements

@@ -629,3 +617,27 @@ def __getstate__(self) -> tuple[list[BaseLearner], CDIMS_TYPE, STRATEGY_TYPE]:
     def __setstate__(self, state: tuple[list[BaseLearner], CDIMS_TYPE, STRATEGY_TYPE]):
         learners, cdims, strategy = state
         self.__init__(learners, cdims=cdims, strategy=strategy)  # type: ignore[misc]
+
+
+def _sequential_generator(
+    learners: list[BaseLearner],
+) -> Generator[tuple[int, Any, float], None, None]:
+    learner_index = 0
+    if not hasattr(learners[0], "done"):
+        msg = "All learners must have a `done` method to use the 'sequential' strategy."
+        raise ValueError(msg)
+    while True:
+        learner = learners[learner_index]
+        if learner.done():  # type: ignore[attr-defined]
+            if learner_index == len(learners) - 1:
+                return
+            learner_index += 1
+            continue
+
+        point, loss_improvement = learner.ask(n=1)
+        if not point:  # if learner is exhausted, we don't get points
+            if learner_index == len(learners) - 1:
+                return
+            learner_index += 1
+            continue
+        yield learner_index, point[0], loss_improvement[0]
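
For reference, the snippet below is a minimal, standalone sketch of the drain-then-advance
pattern that `_sequential_generator` implements, so the behavior can be tried outside of
`BalancingLearner`. `MockLearner` and its `budget` attribute are hypothetical stand-ins for
a real learner exposing `done()` and `ask()`; they are not part of the adaptive API.

    from collections.abc import Generator
    from typing import Any


    class MockLearner:
        """Hypothetical stand-in for a learner exposing `done()` and `ask()`."""

        def __init__(self, name: str, budget: int) -> None:
            self.name = name
            self.budget = budget  # how many points this learner can still provide

        def done(self) -> bool:
            return self.budget == 0

        def ask(self, n: int = 1) -> tuple[list[Any], list[float]]:
            if self.budget == 0:
                return [], []
            self.budget -= 1
            return [f"{self.name}-{self.budget}"], [1.0]


    def sequential_generator(
        learners: list[MockLearner],
    ) -> Generator[tuple[int, Any, float], None, None]:
        # Drain learners[0] completely before moving on to learners[1], etc.
        learner_index = 0
        while learner_index < len(learners):
            learner = learners[learner_index]
            if learner.done():
                learner_index += 1
                continue
            points, loss_improvements = learner.ask(n=1)
            if not points:  # exhausted without reporting done()
                learner_index += 1
                continue
            yield learner_index, points[0], loss_improvements[0]


    gen = sequential_generator([MockLearner("a", 2), MockLearner("b", 1)])
    print(list(gen))
    # [(0, 'a-1', 1.0), (0, 'a-0', 1.0), (1, 'b-0', 1.0)]

Caching the generator on `self._gen` keeps the scan position between calls, so each
`_ask_and_tell_based_on_sequential` call resumes at the learner where the previous one
stopped rather than rescanning all finished learners from index 0, as the removed
`while` loop did.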