1
1
from __future__ import annotations
2
2
3
3
from copy import copy
4
+ from typing import Any , Tuple
4
5
5
6
import cloudpickle
6
7
from sortedcontainers import SortedDict , SortedSet
7
8
8
9
from adaptive .learner .base_learner import BaseLearner
10
+ from adaptive .types import Int
9
11
from adaptive .utils import assign_defaults , partial_function_from_dataframe
10
12
11
13
try :
16
18
except ModuleNotFoundError :
17
19
with_pandas = False
18
20
21
+ try :
22
+ from typing import TypeAlias
23
+ except ImportError :
24
+ from typing_extensions import TypeAlias
25
+
26
+
27
+ PointType : TypeAlias = Tuple [Int , Any ]
28
+
19
29
20
30
class _IgnoreFirstArgument :
21
31
"""Remove the first argument from the call signature.
@@ -30,7 +40,7 @@ class _IgnoreFirstArgument:
30
40
def __init__ (self , function ):
31
41
self .function = function
32
42
33
- def __call__ (self , index_point , * args , ** kwargs ):
43
+ def __call__ (self , index_point : PointType , * args , ** kwargs ):
34
44
index , point = index_point
35
45
return self .function (point , * args , ** kwargs )
36
46
@@ -81,7 +91,9 @@ def new(self) -> SequenceLearner:
81
91
"""Return a new `~adaptive.SequenceLearner` without the data."""
82
92
return SequenceLearner (self ._original_function , self .sequence )
83
93
84
- def ask (self , n , tell_pending = True ):
94
+ def ask (
95
+ self , n : int , tell_pending : bool = True
96
+ ) -> tuple [list [PointType ], list [float ]]:
85
97
indices = []
86
98
points = []
87
99
loss_improvements = []
@@ -99,40 +111,40 @@ def ask(self, n, tell_pending=True):
99
111
100
112
return points , loss_improvements
101
113
102
- def loss (self , real = True ):
114
+ def loss (self , real : bool = True ) -> float :
103
115
if not (self ._to_do_indices or self .pending_points ):
104
- return 0
116
+ return 0.0
105
117
else :
106
118
npoints = self .npoints + (0 if real else len (self .pending_points ))
107
119
return (self ._ntotal - npoints ) / self ._ntotal
108
120
109
- def remove_unfinished (self ):
121
+ def remove_unfinished (self ) -> None :
110
122
for i in self .pending_points :
111
123
self ._to_do_indices .add (i )
112
124
self .pending_points = set ()
113
125
114
- def tell (self , point , value ) :
126
+ def tell (self , point : PointType , value : Any ) -> None :
115
127
index , point = point
116
128
self .data [index ] = value
117
129
self .pending_points .discard (index )
118
130
self ._to_do_indices .discard (index )
119
131
120
- def tell_pending (self , point ) :
132
+ def tell_pending (self , point : PointType ) -> None :
121
133
index , point = point
122
134
self .pending_points .add (index )
123
135
self ._to_do_indices .discard (index )
124
136
125
- def done (self ):
137
+ def done (self ) -> bool :
126
138
return not self ._to_do_indices and not self .pending_points
127
139
128
- def result (self ):
140
+ def result (self ) -> list [ Any ] :
129
141
"""Get the function values in the same order as ``sequence``."""
130
142
if not self .done ():
131
143
raise Exception ("Learner is not yet complete." )
132
144
return list (self .data .values ())
133
145
134
146
@property
135
- def npoints (self ):
147
+ def npoints (self ) -> int :
136
148
return len (self .data )
137
149
138
150
def to_dataframe (
@@ -213,16 +225,18 @@ def load_dataframe(
213
225
y_name : str, optional
214
226
The ``y_name`` used in ``to_dataframe``, by default "y"
215
227
"""
216
- self .tell_many (df [[index_name , x_name ]].values , df [y_name ].values )
228
+ indices = df [index_name ].values
229
+ xs = df [x_name ].values
230
+ self .tell_many (zip (indices , xs ), df [y_name ].values )
217
231
if with_default_function_args :
218
232
self .function = partial_function_from_dataframe (
219
233
self ._original_function , df , function_prefix
220
234
)
221
235
222
- def _get_data (self ):
236
+ def _get_data (self ) -> dict [ int , Any ] :
223
237
return self .data
224
238
225
- def _set_data (self , data ) :
239
+ def _set_data (self , data : dict [ int , Any ]) -> None :
226
240
if data :
227
241
indices , values = zip (* data .items ())
228
242
# the points aren't used by tell, so we can safely pass None
0 commit comments