Skip to content

Commit aee9e17

Browse files
committed
Implement using generator
1 parent 3523343 commit aee9e17

File tree

1 file changed

+35
-23
lines changed

1 file changed

+35
-23
lines changed

adaptive/learner/balancing_learner.py

+35-23
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import itertools
44
import sys
55
from collections import defaultdict
6-
from collections.abc import Iterable, Sequence
6+
from collections.abc import Generator, Iterable, Sequence
77
from contextlib import suppress
88
from functools import partial
99
from operator import itemgetter
@@ -126,11 +126,10 @@ def __init__(
126126
self._cdims_default = cdims
127127

128128
if len({learner.__class__ for learner in self.learners}) > 1:
129-
raise TypeError(
130-
"A BalacingLearner can handle only one type" " of learners."
131-
)
129+
raise TypeError("A BalacingLearner can handle only one type of learners.")
132130

133131
self.strategy: STRATEGY_TYPE = strategy
132+
self._gen: Generator | None = None
134133

135134
def new(self) -> BalancingLearner:
136135
"""Create a new `BalancingLearner` with the same parameters."""
@@ -288,27 +287,16 @@ def _ask_and_tell_based_on_cycle(
288287
def _ask_and_tell_based_on_sequential(
289288
self, n: int
290289
) -> tuple[list[tuple[Int, Any]], list[float]]:
290+
if self._gen is None:
291+
self._gen = _sequential_generator(self.learners)
291292
points: list[tuple[Int, Any]] = []
292293
loss_improvements: list[float] = []
293-
learner_index = 0
294-
295-
while len(points) < n:
296-
learner = self.learners[learner_index]
297-
if learner.done(): # type: ignore[attr-defined]
298-
if learner_index == len(self.learners) - 1:
299-
break
300-
learner_index += 1
301-
continue
302-
303-
point, loss_improvement = learner.ask(n=1)
304-
if not point: # if learner is exhausted, we don't get points
305-
if learner_index == len(self.learners) - 1:
306-
break
307-
learner_index += 1
308-
continue
309-
points.append((learner_index, point[0]))
310-
loss_improvements.append(loss_improvement[0])
311-
self.tell_pending((learner_index, point[0]))
294+
for learner_index, point, loss_improvement in self._gen:
295+
points.append((learner_index, point))
296+
loss_improvements.append(loss_improvement)
297+
self.tell_pending((learner_index, point))
298+
if len(points) >= n:
299+
break
312300

313301
return points, loss_improvements
314302

@@ -629,3 +617,27 @@ def __getstate__(self) -> tuple[list[BaseLearner], CDIMS_TYPE, STRATEGY_TYPE]:
629617
def __setstate__(self, state: tuple[list[BaseLearner], CDIMS_TYPE, STRATEGY_TYPE]):
630618
learners, cdims, strategy = state
631619
self.__init__(learners, cdims=cdims, strategy=strategy) # type: ignore[misc]
620+
621+
622+
def _sequential_generator(
623+
learners: list[BaseLearner],
624+
) -> Generator[tuple[int, Any, float], None, None]:
625+
learner_index = 0
626+
if not hasattr(learners[0], "done"):
627+
msg = "All learners must have a `done` method to use the 'sequential' strategy."
628+
raise ValueError(msg)
629+
while True:
630+
learner = learners[learner_index]
631+
if learner.done(): # type: ignore[attr-defined]
632+
if learner_index == len(learners) - 1:
633+
return
634+
learner_index += 1
635+
continue
636+
637+
point, loss_improvement = learner.ask(n=1)
638+
if not point: # if learner is exhausted, we don't get points
639+
if learner_index == len(learners) - 1:
640+
return
641+
learner_index += 1
642+
continue
643+
yield learner_index, point[0], loss_improvement[0]

0 commit comments

Comments
 (0)