3
3
import random
4
4
from collections import OrderedDict
5
5
from collections .abc import Iterable
6
- from typing import Any , Callable , List , Optional , Tuple , Union
6
+ from typing import Any , Callable , Dict , List , Optional , Set , Tuple , Union
7
7
8
8
import numpy as np
9
9
import scipy .spatial
13
13
14
14
from adaptive .learner .base_learner import BaseLearner , uses_nth_neighbors
15
15
from adaptive .learner .triangulation import (
16
+ Point ,
17
+ Simplex ,
16
18
Triangulation ,
17
19
circumsphere ,
18
20
fast_det ,
@@ -40,7 +42,7 @@ def volume(simplex: List[Tuple[float, float]], ys: None = None,) -> float:
40
42
return vol
41
43
42
44
43
- def orientation (simplex ):
45
+ def orientation (simplex : np . ndarray ):
44
46
matrix = np .subtract (simplex [:- 1 ], simplex [- 1 ])
45
47
# See https://www.jstor.org/stable/2315353
46
48
sign , _logdet = np .linalg .slogdet (matrix )
@@ -339,12 +341,14 @@ def __init__(
339
341
340
342
self .function = func
341
343
self ._tri = None
342
- self ._losses = dict ()
344
+ self ._losses : Dict [ Simplex , float ] = dict ()
343
345
344
- self ._pending_to_simplex = dict () # vertex → simplex
346
+ self ._pending_to_simplex : Dict [ Point , Simplex ] = dict () # vertex → simplex
345
347
346
348
# triangulation of the pending points inside a specific simplex
347
- self ._subtriangulations = dict () # simplex → triangulation
349
+ self ._subtriangulations : Dict [
350
+ Simplex , Triangulation
351
+ ] = dict () # simplex → triangulation
348
352
349
353
# scale to unit hypercube
350
354
# for the input
@@ -456,7 +460,7 @@ def tell(self, point: Tuple[float, ...], value: Union[float, np.ndarray],) -> No
456
460
to_delete , to_add = tri .add_point (point , simplex , transform = self ._transform )
457
461
self ._update_losses (to_delete , to_add )
458
462
459
- def _simplex_exists (self , simplex : Any ) -> bool : # XXX: specify simplex: Any
463
+ def _simplex_exists (self , simplex : Simplex ) -> bool :
460
464
simplex = tuple (sorted (simplex ))
461
465
return simplex in self .tri .simplices
462
466
@@ -498,7 +502,7 @@ def tell_pending(self, point: Tuple[float, ...], *, simplex=None,) -> None:
498
502
self ._update_subsimplex_losses (simpl , to_add )
499
503
500
504
def _try_adding_pending_point_to_simplex (
501
- self , point : Tuple [ float , ...], simplex : Any , # XXX: specify simplex: Any
505
+ self , point : Point , simplex : Simplex ,
502
506
) -> Any :
503
507
# try to insert it
504
508
if not self .tri .point_in_simplex (point , simplex ):
@@ -512,8 +516,8 @@ def _try_adding_pending_point_to_simplex(
512
516
return self ._subtriangulations [simplex ].add_point (point )
513
517
514
518
def _update_subsimplex_losses (
515
- self , simplex : Any , new_subsimplices : Any
516
- ) -> None : # XXX: specify simplex: Any
519
+ self , simplex : Simplex , new_subsimplices : Set [ Simplex ]
520
+ ) -> None :
517
521
loss = self ._losses [simplex ]
518
522
519
523
loss_density = loss / self .tri .volume (simplex )
@@ -534,7 +538,7 @@ def ask(self, n: int, tell_pending: bool = True) -> Any:
534
538
else :
535
539
return self ._ask_and_tell_pending (n )
536
540
537
- def _ask_bound_point (self ,) -> Tuple [Tuple [ float , ...] , float ]:
541
+ def _ask_bound_point (self ,) -> Tuple [Point , float ]:
538
542
# get the next bound point that is still available
539
543
new_point = next (
540
544
p
@@ -544,7 +548,7 @@ def _ask_bound_point(self,) -> Tuple[Tuple[float, ...], float]:
544
548
self .tell_pending (new_point )
545
549
return new_point , np .inf
546
550
547
- def _ask_point_without_known_simplices (self ,) -> Tuple [Tuple [ float , ...] , float ]:
551
+ def _ask_point_without_known_simplices (self ,) -> Tuple [Point , float ]:
548
552
assert not self ._bounds_available
549
553
# pick a random point inside the bounds
550
554
# XXX: change this into picking a point based on volume loss
@@ -585,7 +589,7 @@ def _pop_highest_existing_simplex(self) -> Any:
585
589
" be a simplex available if LearnerND.tri() is not None."
586
590
)
587
591
588
- def _ask_best_point (self ,) -> Tuple [Tuple [ float , ...] , float ]:
592
+ def _ask_best_point (self ,) -> Tuple [Point , float ]:
589
593
assert self .tri is not None
590
594
591
595
loss , simplex , subsimplex = self ._pop_highest_existing_simplex ()
@@ -612,7 +616,7 @@ def _bounds_available(self) -> bool:
612
616
for p in self ._bounds_points
613
617
)
614
618
615
- def _ask (self ,) -> Tuple [Tuple [ float , ...] , float ]:
619
+ def _ask (self ,) -> Tuple [Point , float ]:
616
620
if self ._bounds_available :
617
621
return self ._ask_bound_point () # O(1)
618
622
@@ -624,7 +628,7 @@ def _ask(self,) -> Tuple[Tuple[float, ...], float]:
624
628
625
629
return self ._ask_best_point () # O(log N)
626
630
627
- def _compute_loss (self , simplex : Any ) -> float : # XXX: specify simplex: Any
631
+ def _compute_loss (self , simplex : Simplex ) -> float :
628
632
# get the loss
629
633
vertices = self .tri .get_vertices (simplex )
630
634
values = [self .data [tuple (v )] for v in vertices ]
@@ -663,7 +667,7 @@ def _compute_loss(self, simplex: Any) -> float: # XXX: specify simplex: Any
663
667
)
664
668
)
665
669
666
- def _update_losses (self , to_delete : set , to_add : set ) -> None :
670
+ def _update_losses (self , to_delete : Set [ Simplex ] , to_add : Set [ Simplex ] ) -> None :
667
671
# XXX: add the points outside the triangulation to this as well
668
672
pending_points_unbound = set ()
669
673
@@ -733,13 +737,11 @@ def _recompute_all_losses(self) -> None:
733
737
)
734
738
735
739
@property
736
- def _scale (self ) -> Union [ float , np . int64 ] :
740
+ def _scale (self ) -> float :
737
741
# get the output scale
738
742
return self ._max_value - self ._min_value
739
743
740
- def _update_range (
741
- self , new_output : Union [List [int ], float , float , np .ndarray ]
742
- ) -> bool :
744
+ def _update_range (self , new_output : Union [List [int ], float , np .ndarray ]) -> bool :
743
745
if self ._min_value is None or self ._max_value is None :
744
746
# this is the first point, nothing to do, just set the range
745
747
self ._min_value = np .min (new_output )
@@ -790,7 +792,7 @@ def remove_unfinished(self) -> None:
790
792
# Plotting related stuff #
791
793
##########################
792
794
793
- def plot (self , n = None , tri_alpha = 0 ):
795
+ def plot (self , n : Optional [ int ] = None , tri_alpha : float = 0 ):
794
796
"""Plot the function we want to learn, only works in 2D.
795
797
796
798
Parameters
@@ -851,7 +853,7 @@ def plot(self, n=None, tri_alpha=0):
851
853
852
854
return im .opts (style = im_opts ) * tris .opts (style = tri_opts , ** no_hover )
853
855
854
- def plot_slice (self , cut_mapping , n = None ):
856
+ def plot_slice (self , cut_mapping : Dict [ int , float ], n : Optional [ int ] = None ):
855
857
"""Plot a 1D or 2D interpolated slice of a N-dimensional function.
856
858
857
859
Parameters
@@ -921,7 +923,7 @@ def plot_slice(self, cut_mapping, n=None):
921
923
else :
922
924
raise ValueError ("Only 1 or 2-dimensional plots can be generated." )
923
925
924
- def plot_3D (self , with_triangulation = False ):
926
+ def plot_3D (self , with_triangulation : bool = False ):
925
927
"""Plot the learner's data in 3D using plotly.
926
928
927
929
Does *not* work with the
@@ -1010,7 +1012,7 @@ def _set_data(self, data: OrderedDict) -> None:
1010
1012
if data :
1011
1013
self .tell_many (* zip (* data .items ()))
1012
1014
1013
- def _get_iso (self , level = 0.0 , which = "surface" ):
1015
+ def _get_iso (self , level : float = 0.0 , which : str = "surface" ):
1014
1016
if which == "surface" :
1015
1017
if self .ndim != 3 or self .vdim != 1 :
1016
1018
raise Exception (
@@ -1081,7 +1083,9 @@ def _get_vertex_index(a, b):
1081
1083
1082
1084
return vertices , faces_or_lines
1083
1085
1084
- def plot_isoline (self , level = 0.0 , n = None , tri_alpha = 0 ):
1086
+ def plot_isoline (
1087
+ self , level : float = 0.0 , n : Optional [int ] = None , tri_alpha : float = 0
1088
+ ):
1085
1089
"""Plot the isoline at a specific level, only works in 2D.
1086
1090
1087
1091
Parameters
@@ -1121,7 +1125,7 @@ def plot_isoline(self, level=0.0, n=None, tri_alpha=0):
1121
1125
contour = contour .opts (style = contour_opts )
1122
1126
return plot * contour
1123
1127
1124
- def plot_isosurface (self , level = 0.0 , hull_opacity = 0.2 ):
1128
+ def plot_isosurface (self , level : float = 0.0 , hull_opacity : float = 0.2 ):
1125
1129
"""Plots a linearly interpolated isosurface.
1126
1130
1127
1131
This is the 3D analog of an isoline. Does *not* work with the
@@ -1159,7 +1163,7 @@ def plot_isosurface(self, level=0.0, hull_opacity=0.2):
1159
1163
hull_mesh = self ._get_hull_mesh (opacity = hull_opacity )
1160
1164
return plotly .offline .iplot ([isosurface , hull_mesh ])
1161
1165
1162
- def _get_hull_mesh (self , opacity = 0.2 ):
1166
+ def _get_hull_mesh (self , opacity : float = 0.2 ):
1163
1167
plotly = ensure_plotly ()
1164
1168
hull = scipy .spatial .ConvexHull (self ._bounds_points )
1165
1169
0 commit comments