Skip to content

Commit fd1e8b1

Browse files
authored
Some minor updates to #260 (#298)
* Variable name and type hint cleanup * Refactor to remove repeated code
1 parent 1d41923 commit fd1e8b1

File tree

3 files changed

+113
-118
lines changed

3 files changed

+113
-118
lines changed

s2fft/precompute_transforms/fourier_wigner.py

+80-87
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from functools import partial
24

35
import jax.numpy as jnp
@@ -12,7 +14,7 @@ def inverse_transform(
1214
flmn: np.ndarray,
1315
L: int,
1416
N: int,
15-
DW: np.ndarray = None,
17+
precomps: tuple[np.ndarray, np.ndarray] | None = None,
1618
reality: bool = False,
1719
sampling: str = "mw",
1820
) -> np.ndarray:
@@ -23,9 +25,9 @@ def inverse_transform(
2325
flmn (np.ndarray): Wigner coefficients.
2426
L (int): Harmonic band-limit.
2527
N (int): Azimuthal band-limit.
26-
DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced
27-
Wigner d-functions and the corresponding upsampled quadrature weights.
28-
Defaults to None.
28+
precomps (tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the
29+
reduced Wigner d-functions and the corresponding upsampled quadrature
30+
weights. Defaults to None.
2931
reality (bool, optional): Whether the signal on the sphere is real. If so,
3032
conjugate symmetry is exploited to reduce computational costs.
3133
Defaults to False.
@@ -53,28 +55,28 @@ def inverse_transform(
5355
m = np.arange(-L + 1 - m_offset, L)
5456
n = np.arange(n_start_ind - N + 1, N)
5557

56-
# Calculate fmna = i^(n-m)\sum_L Delta^l_am Delta^l_an f^l_mn(2l+1)/(8pi^2)
58+
# Calculate fmna = i^(n-m)\sum_L delta^l_am delta^l_an f^l_mn(2l+1)/(8pi^2)
5759
x = np.zeros((xnlm_size, xnlm_size, n_dim), dtype=flmn.dtype)
5860
flmn = np.einsum("nlm,l->nlm", flmn, (2 * np.arange(L) + 1) / (8 * np.pi**2))
5961

6062
# PRECOMPUTE TRANSFORM
61-
if DW is not None:
62-
Delta, _ = DW
63+
if precomps is not None:
64+
delta, _ = precomps
6365
x = np.zeros((xnlm_size, xnlm_size, n_dim), dtype=flmn.dtype)
6466
x[m_offset:, m_offset:] = np.einsum(
65-
"nlm,lam,lan->amn", flmn[n_start_ind:], Delta, Delta[:, :, L - 1 + n]
67+
"nlm,lam,lan->amn", flmn[n_start_ind:], delta, delta[:, :, L - 1 + n]
6668
)
6769

6870
# OTF TRANSFORM
6971
else:
70-
Delta_el = np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64)
72+
delta_el = np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64)
7173
for el in range(L):
72-
Delta_el = recursions.risbo.compute_full(Delta_el, np.pi / 2, L, el)
74+
delta_el = recursions.risbo.compute_full(delta_el, np.pi / 2, L, el)
7375
x[m_offset:, m_offset:] += np.einsum(
7476
"nm,am,an->amn",
7577
flmn[n_start_ind:, el],
76-
Delta_el,
77-
Delta_el[:, L - 1 + n],
78+
delta_el,
79+
delta_el[:, L - 1 + n],
7880
)
7981

8082
# APPLY SIGN FUNCTION AND PHASE SHIFT
@@ -97,7 +99,7 @@ def inverse_transform_jax(
9799
flmn: jnp.ndarray,
98100
L: int,
99101
N: int,
100-
DW: jnp.ndarray = None,
102+
precomps: tuple[jnp.ndarray, jnp.ndarray] | None = None,
101103
reality: bool = False,
102104
sampling: str = "mw",
103105
) -> jnp.ndarray:
@@ -108,9 +110,9 @@ def inverse_transform_jax(
108110
flmn (jnp.ndarray): Wigner coefficients.
109111
L (int): Harmonic band-limit.
110112
N (int): Azimuthal band-limit.
111-
DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced
112-
Wigner d-functions and the corresponding upsampled quadrature weights.
113-
Defaults to None.
113+
precomps (tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the
114+
reduced Wigner d-functions and the corresponding upsampled quadrature
115+
weights. Defaults to None.
114116
reality (bool, optional): Whether the signal on the sphere is real. If so,
115117
conjugate symmetry is exploited to reduce computational costs.
116118
Defaults to False.
@@ -138,30 +140,30 @@ def inverse_transform_jax(
138140
m = jnp.arange(-L + 1 - m_offset, L)
139141
n = jnp.arange(n_start_ind - N + 1, N)
140142

141-
# Calculate fmna = i^(n-m)\sum_L Delta^l_am Delta^l_an f^l_mn(2l+1)/(8pi^2)
143+
# Calculate fmna = i^(n-m)\sum_L delta^l_am delta^l_an f^l_mn(2l+1)/(8pi^2)
142144
x = jnp.zeros((xnlm_size, xnlm_size, n_dim), dtype=jnp.complex128)
143145
flmn = jnp.einsum("nlm,l->nlm", flmn, (2 * jnp.arange(L) + 1) / (8 * jnp.pi**2))
144146

145147
# PRECOMPUTE TRANSFORM
146-
if DW is not None:
147-
Delta, _ = DW
148+
if precomps is not None:
149+
delta, _ = precomps
148150
x = x.at[m_offset:, m_offset:].set(
149151
jnp.einsum(
150-
"nlm,lam,lan->amn", flmn[n_start_ind:], Delta, Delta[:, :, L - 1 + n]
152+
"nlm,lam,lan->amn", flmn[n_start_ind:], delta, delta[:, :, L - 1 + n]
151153
)
152154
)
153155

154156
# OTF TRANSFORM
155157
else:
156-
Delta_el = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64)
158+
delta_el = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64)
157159
for el in range(L):
158-
Delta_el = recursions.risbo_jax.compute_full(Delta_el, jnp.pi / 2, L, el)
160+
delta_el = recursions.risbo_jax.compute_full(delta_el, jnp.pi / 2, L, el)
159161
x = x.at[m_offset:, m_offset:].add(
160162
jnp.einsum(
161163
"nm,am,an->amn",
162164
flmn[n_start_ind:, el],
163-
Delta_el,
164-
Delta_el[:, L - 1 + n],
165+
delta_el,
166+
delta_el[:, L - 1 + n],
165167
)
166168
)
167169

@@ -184,7 +186,7 @@ def forward_transform(
184186
f: np.ndarray,
185187
L: int,
186188
N: int,
187-
DW: np.ndarray = None,
189+
precomps: tuple[np.ndarray, np.ndarray] | None = None,
188190
reality: bool = False,
189191
sampling: str = "mw",
190192
) -> np.ndarray:
@@ -195,9 +197,9 @@ def forward_transform(
195197
f (np.ndarray): Function sampled on the rotation group.
196198
L (int): Harmonic band-limit.
197199
N (int): Azimuthal band-limit.
198-
DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced
199-
Wigner d-functions and the corresponding upsampled quadrature weights.
200-
Defaults to None.
200+
precomps (tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the
201+
reduced Wigner d-functions and the corresponding upsampled quadrature
202+
weights. Defaults to None.
201203
reality (bool, optional): Whether the signal on the sphere is real. If so,
202204
conjugate symmetry is exploited to reduce computational costs.
203205
Defaults to False.
@@ -253,43 +255,38 @@ def forward_transform(
253255
# the weights are conjugate but applied flipped and therefore are
254256
# equivalent. To avoid flipping here we simply conjugate the weights.
255257

256-
# PRECOMPUTE TRANSFORM
257-
if DW is not None:
258-
# EXTRACT VARIOUS PRECOMPUTES
259-
Delta, Quads = DW
260-
261-
# APPLY QUADRATURE
262-
x = np.einsum("nbm,b->nbm", x, Quads)
263-
264-
# COMPUTE GMM BY FFT
265-
x = np.fft.fft(x, axis=1, norm="forward")
266-
x = np.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2]
267-
268-
# CALCULATE flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
269-
x = np.einsum("nam,lam,lan->nlm", x, Delta, Delta[:, :, L - 1 + n])
270-
271-
# OTF TRANSFORM
258+
if precomps is not None:
259+
# PRECOMPUTE TRANSFORM
260+
delta, quads = precomps
272261
else:
262+
# OTF TRANSFORM
263+
delta = None
273264
# COMPUTE QUADRATURE WEIGHTS
274-
Quads = np.zeros(4 * L - 3, dtype=np.complex128)
265+
quads = np.zeros(4 * L - 3, dtype=np.complex128)
275266
for mm in range(-2 * (L - 1), 2 * (L - 1) + 1):
276-
Quads[mm + 2 * (L - 1)] = quadrature.mw_weights(-mm)
277-
Quads = np.fft.ifft(np.fft.ifftshift(Quads), norm="forward")
267+
quads[mm + 2 * (L - 1)] = quadrature.mw_weights(-mm)
268+
quads = np.fft.ifft(np.fft.ifftshift(quads), norm="forward")
278269

279-
# APPLY QUADRATURE
280-
x = np.einsum("nbm,b->nbm", x, Quads)
270+
# APPLY QUADRATURE
271+
x = np.einsum("nbm,b->nbm", x, quads)
281272

282-
# COMPUTE GMM BY FFT
283-
x = np.fft.fft(x, axis=1, norm="forward")
284-
x = np.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2]
273+
# COMPUTE GMM BY FFT
274+
x = np.fft.fft(x, axis=1, norm="forward")
275+
x = np.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2]
285276

286-
# CALCULATE flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
287-
Delta_el = np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64)
277+
# CALCULATE flmn = i^(n-m)\sum_t delta^l_tm delta^l_tn G_mnt
278+
if delta is not None:
279+
# PRECOMPUTE TRANSFORM
280+
x = np.einsum("nam,lam,lan->nlm", x, delta, delta[:, :, L - 1 + n])
281+
else:
282+
# OTF TRANSFORM
283+
delta_el = np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64)
288284
xx = np.zeros((x.shape[0], L, x.shape[-1]), dtype=x.dtype)
289285
for el in range(L):
290-
Delta_el = recursions.risbo.compute_full(Delta_el, np.pi / 2, L, el)
291-
xx[:, el] = np.einsum("nam,am,an->nm", x, Delta_el, Delta_el[:, L - 1 + n])
286+
delta_el = recursions.risbo.compute_full(delta_el, np.pi / 2, L, el)
287+
xx[:, el] = np.einsum("nam,am,an->nm", x, delta_el, delta_el[:, L - 1 + n])
292288
x = xx
289+
293290
x = np.einsum("nbm,m,n->nbm", x, 1j ** (m), 1j ** (-n))
294291

295292
# SYMMETRY REFLECT FOR N < 0
@@ -310,7 +307,7 @@ def forward_transform_jax(
310307
f: jnp.ndarray,
311308
L: int,
312309
N: int,
313-
DW: jnp.ndarray = None,
310+
precomps: tuple[jnp.ndarray, jnp.ndarray] | None = None,
314311
reality: bool = False,
315312
sampling: str = "mw",
316313
) -> jnp.ndarray:
@@ -321,9 +318,9 @@ def forward_transform_jax(
321318
f (jnp.ndarray): Function sampled on the rotation group.
322319
L (int): Harmonic band-limit.
323320
N (int): Azimuthal band-limit.
324-
DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced
325-
Wigner d-functions and the corresponding upsampled quadrature weights.
326-
Defaults to None.
321+
precomps (tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the
322+
reduced Wigner d-functions and the corresponding upsampled quadrature
323+
weights. Defaults to None.
327324
reality (bool, optional): Whether the signal on the sphere is real. If so,
328325
conjugate symmetry is exploited to reduce computational costs.
329326
Defaults to False.
@@ -379,41 +376,37 @@ def forward_transform_jax(
379376
# the weights are conjugate but applied flipped and therefore are
380377
# equivalent. To avoid flipping here we simply conjugate the weights.
381378

382-
# PRECOMPUTE TRANSFORM
383-
if DW is not None:
384-
# EXTRACT VARIOUS PRECOMPUTES
385-
Delta, Quads = DW
386-
387-
# APPLY QUADRATURE
388-
x = jnp.einsum("nbm,b->nbm", x, Quads)
389-
390-
# COMPUTE GMM BY FFT
391-
x = jnp.fft.fft(x, axis=1, norm="forward")
392-
x = jnp.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2]
393-
394-
# Calculate flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
395-
x = jnp.einsum("nam,lam,lan->nlm", x, Delta, Delta[:, :, L - 1 + n])
396-
379+
if precomps is not None:
380+
# PRECOMPUTE TRANSFORM
381+
delta, quads = precomps
397382
else:
398-
Quads = jnp.zeros(4 * L - 3, dtype=jnp.complex128)
383+
# OTF TRANSFORM
384+
delta = None
385+
# COMPUTE QUADRATURE WEIGHTS
386+
quads = jnp.zeros(4 * L - 3, dtype=jnp.complex128)
399387
for mm in range(-2 * (L - 1), 2 * (L - 1) + 1):
400-
Quads = Quads.at[mm + 2 * (L - 1)].set(quadrature_jax.mw_weights(-mm))
401-
Quads = jnp.fft.ifft(jnp.fft.ifftshift(Quads), norm="forward")
388+
quads = quads.at[mm + 2 * (L - 1)].set(quadrature_jax.mw_weights(-mm))
389+
quads = jnp.fft.ifft(jnp.fft.ifftshift(quads), norm="forward")
402390

403-
# APPLY QUADRATURE
404-
x = jnp.einsum("nbm,b->nbm", x, Quads)
391+
# APPLY QUADRATURE
392+
x = jnp.einsum("nbm,b->nbm", x, quads)
405393

406-
# COMPUTE GMM BY FFT
407-
x = jnp.fft.fft(x, axis=1, norm="forward")
408-
x = jnp.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2]
394+
# COMPUTE GMM BY FFT
395+
x = jnp.fft.fft(x, axis=1, norm="forward")
396+
x = jnp.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2]
409397

410-
# CALCULATE flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
411-
Delta_el = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64)
398+
# Calculate flmn = i^(n-m)\sum_t delta^l_tm delta^l_tn G_mnt
399+
if delta is not None:
400+
# PRECOMPUTE TRANSFORM
401+
x = jnp.einsum("nam,lam,lan->nlm", x, delta, delta[:, :, L - 1 + n])
402+
else:
403+
# OTF TRANSFORM
404+
delta_el = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64)
412405
xx = jnp.zeros((x.shape[0], L, x.shape[-1]), dtype=x.dtype)
413406
for el in range(L):
414-
Delta_el = recursions.risbo_jax.compute_full(Delta_el, jnp.pi / 2, L, el)
407+
delta_el = recursions.risbo_jax.compute_full(delta_el, jnp.pi / 2, L, el)
415408
xx = xx.at[:, el].set(
416-
jnp.einsum("nam,am,an->nm", x, Delta_el, Delta_el[:, L - 1 + n])
409+
jnp.einsum("nam,am,an->nm", x, delta_el, delta_el[:, L - 1 + n])
417410
)
418411
x = xx
419412

tests/test_fourier_wigner.py

+12-10
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,16 @@ def test_inverse_fourier_wigner_transform(
4747
)
4848
f = so3.inverse(samples.flmn_3d_to_1d(flmn, L, N), params)
4949

50-
delta = None
5150
transform = fw.inverse_transform_jax if method == "jax" else fw.inverse_transform
5251
if delta_method.lower() == "precomp":
53-
delta = (
52+
precomps = (
5453
c.fourier_wigner_kernel_jax(L)
5554
if method == "jax"
5655
else c.fourier_wigner_kernel(L)
5756
)
58-
f_check = transform(flmn, L, N, delta, reality, sampling)
57+
else:
58+
precomps = None
59+
f_check = transform(flmn, L, N, precomps, reality, sampling)
5960
np.testing.assert_allclose(f, f_check.flatten("C"), atol=atol)
6061

6162

@@ -91,15 +92,16 @@ def test_forward_fourier_wigner_transform(
9192
)
9293
flmn = samples.flmn_1d_to_3d(so3.forward(f, params), L, N)
9394

94-
delta = None
9595
transform = fw.forward_transform_jax if method == "jax" else fw.forward_transform
9696
if delta_method.lower() == "precomp":
97-
delta = (
97+
precomps = (
9898
c.fourier_wigner_kernel_jax(L)
9999
if method == "jax"
100100
else c.fourier_wigner_kernel(L)
101101
)
102-
flmn_check = transform(f_3D, L, N, delta, reality, sampling)
102+
else:
103+
precomps = None
104+
flmn_check = transform(f_3D, L, N, precomps, reality, sampling)
103105
np.testing.assert_allclose(flmn, flmn_check, atol=atol)
104106

105107

@@ -121,8 +123,8 @@ def test_inverse_fourier_wigner_transform_high_N(
121123
f = so3.inverse(samples.flmn_3d_to_1d(flmn, L, N), params)
122124

123125
f = f.real if reality else f
124-
delta = c.fourier_wigner_kernel(L)
125-
f_check = fw.inverse_transform(flmn, L, N, delta, reality, sampling)
126+
precomps = c.fourier_wigner_kernel(L)
127+
f_check = fw.inverse_transform(flmn, L, N, precomps, reality, sampling)
126128

127129
np.testing.assert_allclose(f, f_check.flatten("C"), atol=atol)
128130

@@ -151,6 +153,6 @@ def test_forward_fourier_wigner_transform_high_N(
151153
)
152154
flmn_so3 = samples.flmn_1d_to_3d(so3.forward(f_1D, params), L, N)
153155

154-
delta = c.fourier_wigner_kernel_jax(L)
155-
flmn_check = fw.forward_transform_jax(f_3D, L, N, delta, reality, sampling)
156+
precomps = c.fourier_wigner_kernel_jax(L)
157+
flmn_check = fw.forward_transform_jax(f_3D, L, N, precomps, reality, sampling)
156158
np.testing.assert_allclose(flmn_so3, flmn_check, atol=atol)

0 commit comments

Comments
 (0)