Skip to content

Commit 9860573

Browse files
authored
Typehint SequenceLearner (#366)
1 parent 50fae43 commit 9860573

File tree

1 file changed

+27
-13
lines changed

1 file changed

+27
-13
lines changed

adaptive/learner/sequence_learner.py

+27-13
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from __future__ import annotations
22

33
from copy import copy
4+
from typing import Any, Tuple
45

56
import cloudpickle
67
from sortedcontainers import SortedDict, SortedSet
78

89
from adaptive.learner.base_learner import BaseLearner
10+
from adaptive.types import Int
911
from adaptive.utils import assign_defaults, partial_function_from_dataframe
1012

1113
try:
@@ -16,6 +18,14 @@
1618
except ModuleNotFoundError:
1719
with_pandas = False
1820

21+
try:
22+
from typing import TypeAlias
23+
except ImportError:
24+
from typing_extensions import TypeAlias
25+
26+
27+
PointType: TypeAlias = Tuple[Int, Any]
28+
1929

2030
class _IgnoreFirstArgument:
2131
"""Remove the first argument from the call signature.
@@ -30,7 +40,7 @@ class _IgnoreFirstArgument:
3040
def __init__(self, function):
3141
self.function = function
3242

33-
def __call__(self, index_point, *args, **kwargs):
43+
def __call__(self, index_point: PointType, *args, **kwargs):
3444
index, point = index_point
3545
return self.function(point, *args, **kwargs)
3646

@@ -81,7 +91,9 @@ def new(self) -> SequenceLearner:
8191
"""Return a new `~adaptive.SequenceLearner` without the data."""
8292
return SequenceLearner(self._original_function, self.sequence)
8393

84-
def ask(self, n, tell_pending=True):
94+
def ask(
95+
self, n: int, tell_pending: bool = True
96+
) -> tuple[list[PointType], list[float]]:
8597
indices = []
8698
points = []
8799
loss_improvements = []
@@ -99,40 +111,40 @@ def ask(self, n, tell_pending=True):
99111

100112
return points, loss_improvements
101113

102-
def loss(self, real=True):
114+
def loss(self, real: bool = True) -> float:
103115
if not (self._to_do_indices or self.pending_points):
104-
return 0
116+
return 0.0
105117
else:
106118
npoints = self.npoints + (0 if real else len(self.pending_points))
107119
return (self._ntotal - npoints) / self._ntotal
108120

109-
def remove_unfinished(self):
121+
def remove_unfinished(self) -> None:
110122
for i in self.pending_points:
111123
self._to_do_indices.add(i)
112124
self.pending_points = set()
113125

114-
def tell(self, point, value):
126+
def tell(self, point: PointType, value: Any) -> None:
115127
index, point = point
116128
self.data[index] = value
117129
self.pending_points.discard(index)
118130
self._to_do_indices.discard(index)
119131

120-
def tell_pending(self, point):
132+
def tell_pending(self, point: PointType) -> None:
121133
index, point = point
122134
self.pending_points.add(index)
123135
self._to_do_indices.discard(index)
124136

125-
def done(self):
137+
def done(self) -> bool:
126138
return not self._to_do_indices and not self.pending_points
127139

128-
def result(self):
140+
def result(self) -> list[Any]:
129141
"""Get the function values in the same order as ``sequence``."""
130142
if not self.done():
131143
raise Exception("Learner is not yet complete.")
132144
return list(self.data.values())
133145

134146
@property
135-
def npoints(self):
147+
def npoints(self) -> int:
136148
return len(self.data)
137149

138150
def to_dataframe(
@@ -213,16 +225,18 @@ def load_dataframe(
213225
y_name : str, optional
214226
The ``y_name`` used in ``to_dataframe``, by default "y"
215227
"""
216-
self.tell_many(df[[index_name, x_name]].values, df[y_name].values)
228+
indices = df[index_name].values
229+
xs = df[x_name].values
230+
self.tell_many(zip(indices, xs), df[y_name].values)
217231
if with_default_function_args:
218232
self.function = partial_function_from_dataframe(
219233
self._original_function, df, function_prefix
220234
)
221235

222-
def _get_data(self):
236+
def _get_data(self) -> dict[int, Any]:
223237
return self.data
224238

225-
def _set_data(self, data):
239+
def _set_data(self, data: dict[int, Any]) -> None:
226240
if data:
227241
indices, values = zip(*data.items())
228242
# the points aren't used by tell, so we can safely pass None

0 commit comments

Comments
 (0)