Skip to content

Commit 91c1ebd

Browse files
committed
improve triangulation and LearnerND typing
1 parent f21f19d commit 91c1ebd

File tree

2 files changed

+80
-69
lines changed

2 files changed

+80
-69
lines changed

Diff for: adaptive/learner/learnerND.py

+30-26
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import random
44
from collections import OrderedDict
55
from collections.abc import Iterable
6-
from typing import Any, Callable, List, Optional, Tuple, Union
6+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
77

88
import numpy as np
99
import scipy.spatial
@@ -13,6 +13,8 @@
1313

1414
from adaptive.learner.base_learner import BaseLearner, uses_nth_neighbors
1515
from adaptive.learner.triangulation import (
16+
Point,
17+
Simplex,
1618
Triangulation,
1719
circumsphere,
1820
fast_det,
@@ -40,7 +42,7 @@ def volume(simplex: List[Tuple[float, float]], ys: None = None,) -> float:
4042
return vol
4143

4244

43-
def orientation(simplex):
45+
def orientation(simplex: np.ndarray):
4446
matrix = np.subtract(simplex[:-1], simplex[-1])
4547
# See https://www.jstor.org/stable/2315353
4648
sign, _logdet = np.linalg.slogdet(matrix)
@@ -339,12 +341,14 @@ def __init__(
339341

340342
self.function = func
341343
self._tri = None
342-
self._losses = dict()
344+
self._losses: Dict[Simplex, float] = dict()
343345

344-
self._pending_to_simplex = dict() # vertex → simplex
346+
self._pending_to_simplex: Dict[Point, Simplex] = dict() # vertex → simplex
345347

346348
# triangulation of the pending points inside a specific simplex
347-
self._subtriangulations = dict() # simplex → triangulation
349+
self._subtriangulations: Dict[
350+
Simplex, Triangulation
351+
] = dict() # simplex → triangulation
348352

349353
# scale to unit hypercube
350354
# for the input
@@ -456,7 +460,7 @@ def tell(self, point: Tuple[float, ...], value: Union[float, np.ndarray],) -> No
456460
to_delete, to_add = tri.add_point(point, simplex, transform=self._transform)
457461
self._update_losses(to_delete, to_add)
458462

459-
def _simplex_exists(self, simplex: Any) -> bool: # XXX: specify simplex: Any
463+
def _simplex_exists(self, simplex: Simplex) -> bool:
460464
simplex = tuple(sorted(simplex))
461465
return simplex in self.tri.simplices
462466

@@ -498,7 +502,7 @@ def tell_pending(self, point: Tuple[float, ...], *, simplex=None,) -> None:
498502
self._update_subsimplex_losses(simpl, to_add)
499503

500504
def _try_adding_pending_point_to_simplex(
501-
self, point: Tuple[float, ...], simplex: Any, # XXX: specify simplex: Any
505+
self, point: Point, simplex: Simplex,
502506
) -> Any:
503507
# try to insert it
504508
if not self.tri.point_in_simplex(point, simplex):
@@ -512,8 +516,8 @@ def _try_adding_pending_point_to_simplex(
512516
return self._subtriangulations[simplex].add_point(point)
513517

514518
def _update_subsimplex_losses(
515-
self, simplex: Any, new_subsimplices: Any
516-
) -> None: # XXX: specify simplex: Any
519+
self, simplex: Simplex, new_subsimplices: Set[Simplex]
520+
) -> None:
517521
loss = self._losses[simplex]
518522

519523
loss_density = loss / self.tri.volume(simplex)
@@ -534,7 +538,7 @@ def ask(self, n: int, tell_pending: bool = True) -> Any:
534538
else:
535539
return self._ask_and_tell_pending(n)
536540

537-
def _ask_bound_point(self,) -> Tuple[Tuple[float, ...], float]:
541+
def _ask_bound_point(self,) -> Tuple[Point, float]:
538542
# get the next bound point that is still available
539543
new_point = next(
540544
p
@@ -544,7 +548,7 @@ def _ask_bound_point(self,) -> Tuple[Tuple[float, ...], float]:
544548
self.tell_pending(new_point)
545549
return new_point, np.inf
546550

547-
def _ask_point_without_known_simplices(self,) -> Tuple[Tuple[float, ...], float]:
551+
def _ask_point_without_known_simplices(self,) -> Tuple[Point, float]:
548552
assert not self._bounds_available
549553
# pick a random point inside the bounds
550554
# XXX: change this into picking a point based on volume loss
@@ -585,7 +589,7 @@ def _pop_highest_existing_simplex(self) -> Any:
585589
" be a simplex available if LearnerND.tri() is not None."
586590
)
587591

588-
def _ask_best_point(self,) -> Tuple[Tuple[float, ...], float]:
592+
def _ask_best_point(self,) -> Tuple[Point, float]:
589593
assert self.tri is not None
590594

591595
loss, simplex, subsimplex = self._pop_highest_existing_simplex()
@@ -612,7 +616,7 @@ def _bounds_available(self) -> bool:
612616
for p in self._bounds_points
613617
)
614618

615-
def _ask(self,) -> Tuple[Tuple[float, ...], float]:
619+
def _ask(self,) -> Tuple[Point, float]:
616620
if self._bounds_available:
617621
return self._ask_bound_point() # O(1)
618622

@@ -624,7 +628,7 @@ def _ask(self,) -> Tuple[Tuple[float, ...], float]:
624628

625629
return self._ask_best_point() # O(log N)
626630

627-
def _compute_loss(self, simplex: Any) -> float: # XXX: specify simplex: Any
631+
def _compute_loss(self, simplex: Simplex) -> float:
628632
# get the loss
629633
vertices = self.tri.get_vertices(simplex)
630634
values = [self.data[tuple(v)] for v in vertices]
@@ -663,7 +667,7 @@ def _compute_loss(self, simplex: Any) -> float: # XXX: specify simplex: Any
663667
)
664668
)
665669

666-
def _update_losses(self, to_delete: set, to_add: set) -> None:
670+
def _update_losses(self, to_delete: Set[Simplex], to_add: Set[Simplex]) -> None:
667671
# XXX: add the points outside the triangulation to this as well
668672
pending_points_unbound = set()
669673

@@ -733,13 +737,11 @@ def _recompute_all_losses(self) -> None:
733737
)
734738

735739
@property
736-
def _scale(self) -> Union[float, np.int64]:
740+
def _scale(self) -> float:
737741
# get the output scale
738742
return self._max_value - self._min_value
739743

740-
def _update_range(
741-
self, new_output: Union[List[int], float, float, np.ndarray]
742-
) -> bool:
744+
def _update_range(self, new_output: Union[List[int], float, np.ndarray]) -> bool:
743745
if self._min_value is None or self._max_value is None:
744746
# this is the first point, nothing to do, just set the range
745747
self._min_value = np.min(new_output)
@@ -790,7 +792,7 @@ def remove_unfinished(self) -> None:
790792
# Plotting related stuff #
791793
##########################
792794

793-
def plot(self, n=None, tri_alpha=0):
795+
def plot(self, n: Optional[int] = None, tri_alpha: float = 0):
794796
"""Plot the function we want to learn, only works in 2D.
795797
796798
Parameters
@@ -851,7 +853,7 @@ def plot(self, n=None, tri_alpha=0):
851853

852854
return im.opts(style=im_opts) * tris.opts(style=tri_opts, **no_hover)
853855

854-
def plot_slice(self, cut_mapping, n=None):
856+
def plot_slice(self, cut_mapping: Dict[int, float], n: Optional[int] = None):
855857
"""Plot a 1D or 2D interpolated slice of a N-dimensional function.
856858
857859
Parameters
@@ -921,7 +923,7 @@ def plot_slice(self, cut_mapping, n=None):
921923
else:
922924
raise ValueError("Only 1 or 2-dimensional plots can be generated.")
923925

924-
def plot_3D(self, with_triangulation=False):
926+
def plot_3D(self, with_triangulation: bool = False):
925927
"""Plot the learner's data in 3D using plotly.
926928
927929
Does *not* work with the
@@ -1010,7 +1012,7 @@ def _set_data(self, data: OrderedDict) -> None:
10101012
if data:
10111013
self.tell_many(*zip(*data.items()))
10121014

1013-
def _get_iso(self, level=0.0, which="surface"):
1015+
def _get_iso(self, level: float = 0.0, which: str = "surface"):
10141016
if which == "surface":
10151017
if self.ndim != 3 or self.vdim != 1:
10161018
raise Exception(
@@ -1081,7 +1083,9 @@ def _get_vertex_index(a, b):
10811083

10821084
return vertices, faces_or_lines
10831085

1084-
def plot_isoline(self, level=0.0, n=None, tri_alpha=0):
1086+
def plot_isoline(
1087+
self, level: float = 0.0, n: Optional[int] = None, tri_alpha: float = 0
1088+
):
10851089
"""Plot the isoline at a specific level, only works in 2D.
10861090
10871091
Parameters
@@ -1121,7 +1125,7 @@ def plot_isoline(self, level=0.0, n=None, tri_alpha=0):
11211125
contour = contour.opts(style=contour_opts)
11221126
return plot * contour
11231127

1124-
def plot_isosurface(self, level=0.0, hull_opacity=0.2):
1128+
def plot_isosurface(self, level: float = 0.0, hull_opacity: float = 0.2):
11251129
"""Plots a linearly interpolated isosurface.
11261130
11271131
This is the 3D analog of an isoline. Does *not* work with the
@@ -1159,7 +1163,7 @@ def plot_isosurface(self, level=0.0, hull_opacity=0.2):
11591163
hull_mesh = self._get_hull_mesh(opacity=hull_opacity)
11601164
return plotly.offline.iplot([isosurface, hull_mesh])
11611165

1162-
def _get_hull_mesh(self, opacity=0.2):
1166+
def _get_hull_mesh(self, opacity: float = 0.2):
11631167
plotly = ensure_plotly()
11641168
hull = scipy.spatial.ConvexHull(self._bounds_points)
11651169

0 commit comments

Comments
 (0)