Skip to content

Commit 3f1a2d0

Browse files
On-the-fly support for Fourier Wigner transforms with Risbo recursions (#260)
* add on-the-fly support for Fourier Wigner transforms * update custom ops test for new fourier wigner variable ordering * Some minor updates to #260 (#298) * Variable name and type hint cleanup * Refactor to remove repeated code --------- Co-authored-by: Matt Graham <[email protected]>
1 parent 83296b9 commit 3f1a2d0

File tree

3 files changed

+174
-83
lines changed

3 files changed

+174
-83
lines changed

s2fft/precompute_transforms/fourier_wigner.py

+126-45
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
1+
from __future__ import annotations
2+
13
from functools import partial
24

35
import jax.numpy as jnp
46
import numpy as np
57
from jax import jit
68

9+
from s2fft import recursions
10+
from s2fft.utils import quadrature, quadrature_jax
11+
712

813
def inverse_transform(
914
flmn: np.ndarray,
10-
DW: np.ndarray,
1115
L: int,
1216
N: int,
17+
precomps: tuple[np.ndarray, np.ndarray] | None = None,
1318
reality: bool = False,
1419
sampling: str = "mw",
1520
) -> np.ndarray:
@@ -18,10 +23,11 @@ def inverse_transform(
1823
1924
Args:
2025
flmn (np.ndarray): Wigner coefficients.
21-
DW (Tuple[np.ndarray, np.ndarray]): Fourier coefficients of the reduced
22-
Wigner d-functions and the corresponding upsampled quadrature weights.
2326
L (int): Harmonic band-limit.
2427
N (int): Azimuthal band-limit.
28+
precomps (tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the
29+
reduced Wigner d-functions and the corresponding upsampled quadrature
30+
weights. Defaults to None.
2531
reality (bool, optional): Whether the signal on the sphere is real. If so,
2632
conjugate symmetry is exploited to reduce computational costs.
2733
Defaults to False.
@@ -37,9 +43,6 @@ def inverse_transform(
3743
f"Fourier-Wigner algorithm does not support {sampling} sampling."
3844
)
3945

40-
# EXTRACT VARIOUS PRECOMPUTES
41-
Delta, _ = DW
42-
4346
# INDEX VALUES
4447
n_start_ind = N - 1 if reality else 0
4548
n_dim = N if reality else 2 * N - 1
@@ -52,15 +55,29 @@ def inverse_transform(
5255
m = np.arange(-L + 1 - m_offset, L)
5356
n = np.arange(n_start_ind - N + 1, N)
5457

55-
# Calculate fmna = i^(n-m)\sum_L Delta^l_am Delta^l_an f^l_mn(2l+1)/(8pi^2)
58+
# Calculate fmna = i^(n-m)\sum_L delta^l_am delta^l_an f^l_mn(2l+1)/(8pi^2)
5659
x = np.zeros((xnlm_size, xnlm_size, n_dim), dtype=flmn.dtype)
57-
x[m_offset:, m_offset:] = np.einsum(
58-
"nlm,lam,lan,l->amn",
59-
flmn[n_start_ind:],
60-
Delta,
61-
Delta[:, :, L - 1 + n],
62-
(2 * np.arange(L) + 1) / (8 * np.pi**2),
63-
)
60+
flmn = np.einsum("nlm,l->nlm", flmn, (2 * np.arange(L) + 1) / (8 * np.pi**2))
61+
62+
# PRECOMPUTE TRANSFORM
63+
if precomps is not None:
64+
delta, _ = precomps
65+
x = np.zeros((xnlm_size, xnlm_size, n_dim), dtype=flmn.dtype)
66+
x[m_offset:, m_offset:] = np.einsum(
67+
"nlm,lam,lan->amn", flmn[n_start_ind:], delta, delta[:, :, L - 1 + n]
68+
)
69+
70+
# OTF TRANSFORM
71+
else:
72+
delta_el = np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64)
73+
for el in range(L):
74+
delta_el = recursions.risbo.compute_full(delta_el, np.pi / 2, L, el)
75+
x[m_offset:, m_offset:] += np.einsum(
76+
"nm,am,an->amn",
77+
flmn[n_start_ind:, el],
78+
delta_el,
79+
delta_el[:, L - 1 + n],
80+
)
6481

6582
# APPLY SIGN FUNCTION AND PHASE SHIFT
6683
x = np.einsum("amn,m,n,a->nam", x, 1j ** (-m), 1j ** (n), np.exp(1j * m * theta0))
@@ -77,12 +94,12 @@ def inverse_transform(
7794
return np.fft.ifft2(x, axes=(0, 2), norm="forward")
7895

7996

80-
@partial(jit, static_argnums=(2, 3, 4, 5))
97+
@partial(jit, static_argnums=(1, 2, 4, 5))
8198
def inverse_transform_jax(
8299
flmn: jnp.ndarray,
83-
DW: jnp.ndarray,
84100
L: int,
85101
N: int,
102+
precomps: tuple[jnp.ndarray, jnp.ndarray] | None = None,
86103
reality: bool = False,
87104
sampling: str = "mw",
88105
) -> jnp.ndarray:
@@ -91,10 +108,11 @@ def inverse_transform_jax(
91108
92109
Args:
93110
flmn (jnp.ndarray): Wigner coefficients.
94-
DW (Tuple[np.ndarray, np.ndarray]): Fourier coefficients of the reduced
95-
Wigner d-functions and the corresponding upsampled quadrature weights.
96111
L (int): Harmonic band-limit.
97112
N (int): Azimuthal band-limit.
113+
precomps (tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the
114+
reduced Wigner d-functions and the corresponding upsampled quadrature
115+
weights. Defaults to None.
98116
reality (bool, optional): Whether the signal on the sphere is real. If so,
99117
conjugate symmetry is exploited to reduce computational costs.
100118
Defaults to False.
@@ -110,9 +128,6 @@ def inverse_transform_jax(
110128
f"Fourier-Wigner algorithm does not support {sampling} sampling."
111129
)
112130

113-
# EXTRACT VARIOUS PRECOMPUTES
114-
Delta, _ = DW
115-
116131
# INDEX VALUES
117132
n_start_ind = N - 1 if reality else 0
118133
n_dim = N if reality else 2 * N - 1
@@ -125,14 +140,32 @@ def inverse_transform_jax(
125140
m = jnp.arange(-L + 1 - m_offset, L)
126141
n = jnp.arange(n_start_ind - N + 1, N)
127142

128-
# Calculate fmna = i^(n-m)\sum_L Delta^l_am Delta^l_an f^l_mn(2l+1)/(8pi^2)
143+
# Calculate fmna = i^(n-m)\sum_L delta^l_am delta^l_an f^l_mn(2l+1)/(8pi^2)
129144
x = jnp.zeros((xnlm_size, xnlm_size, n_dim), dtype=jnp.complex128)
130145
flmn = jnp.einsum("nlm,l->nlm", flmn, (2 * jnp.arange(L) + 1) / (8 * jnp.pi**2))
131-
x = x.at[m_offset:, m_offset:].set(
132-
jnp.einsum(
133-
"nlm,lam,lan->amn", flmn[n_start_ind:], Delta, Delta[:, :, L - 1 + n]
146+
147+
# PRECOMPUTE TRANSFORM
148+
if precomps is not None:
149+
delta, _ = precomps
150+
x = x.at[m_offset:, m_offset:].set(
151+
jnp.einsum(
152+
"nlm,lam,lan->amn", flmn[n_start_ind:], delta, delta[:, :, L - 1 + n]
153+
)
134154
)
135-
)
155+
156+
# OTF TRANSFORM
157+
else:
158+
delta_el = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64)
159+
for el in range(L):
160+
delta_el = recursions.risbo_jax.compute_full(delta_el, jnp.pi / 2, L, el)
161+
x = x.at[m_offset:, m_offset:].add(
162+
jnp.einsum(
163+
"nm,am,an->amn",
164+
flmn[n_start_ind:, el],
165+
delta_el,
166+
delta_el[:, L - 1 + n],
167+
)
168+
)
136169

137170
# APPLY SIGN FUNCTION AND PHASE SHIFT
138171
x = jnp.einsum("amn,m,n,a->nam", x, 1j ** (-m), 1j ** (n), jnp.exp(1j * m * theta0))
@@ -151,9 +184,9 @@ def inverse_transform_jax(
151184

152185
def forward_transform(
153186
f: np.ndarray,
154-
DW: np.ndarray,
155187
L: int,
156188
N: int,
189+
precomps: tuple[np.ndarray, np.ndarray] | None = None,
157190
reality: bool = False,
158191
sampling: str = "mw",
159192
) -> np.ndarray:
@@ -162,10 +195,11 @@ def forward_transform(
162195
163196
Args:
164197
f (np.ndarray): Function sampled on the rotation group.
165-
DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced
166-
Wigner d-functions and the corresponding upsampled quadrature weights.
167198
L (int): Harmonic band-limit.
168199
N (int): Azimuthal band-limit.
200+
precomps (tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the
201+
reduced Wigner d-functions and the corresponding upsampled quadrature
202+
weights. Defaults to None.
169203
reality (bool, optional): Whether the signal on the sphere is real. If so,
170204
conjugate symmetry is exploited to reduce computational costs.
171205
Defaults to False.
@@ -181,9 +215,6 @@ def forward_transform(
181215
f"Fourier-Wigner algorithm does not support {sampling} sampling."
182216
)
183217

184-
# EXTRACT VARIOUS PRECOMPUTES
185-
Delta, Quads = DW
186-
187218
# INDEX VALUES
188219
n_start_ind = N - 1 if reality else 0
189220
m_offset = 1 if sampling.lower() == "mwss" else 0
@@ -223,14 +254,39 @@ def forward_transform(
223254
# NB: Our convention here is conjugate to that of SSHT, in which
224255
# the weights are conjugate but applied flipped and therefore are
225256
# equivalent. To avoid flipping here we simply conjugate the weights.
226-
x = np.einsum("nbm,b->nbm", x, Quads)
257+
258+
if precomps is not None:
259+
# PRECOMPUTE TRANSFORM
260+
delta, quads = precomps
261+
else:
262+
# OTF TRANSFORM
263+
delta = None
264+
# COMPUTE QUADRATURE WEIGHTS
265+
quads = np.zeros(4 * L - 3, dtype=np.complex128)
266+
for mm in range(-2 * (L - 1), 2 * (L - 1) + 1):
267+
quads[mm + 2 * (L - 1)] = quadrature.mw_weights(-mm)
268+
quads = np.fft.ifft(np.fft.ifftshift(quads), norm="forward")
269+
270+
# APPLY QUADRATURE
271+
x = np.einsum("nbm,b->nbm", x, quads)
227272

228273
# COMPUTE GMM BY FFT
229274
x = np.fft.fft(x, axis=1, norm="forward")
230275
x = np.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2]
231276

232-
# Calculate flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
233-
x = np.einsum("nam,lam,lan->nlm", x, Delta, Delta[:, :, L - 1 + n])
277+
# CALCULATE flmn = i^(n-m)\sum_t delta^l_tm delta^l_tn G_mnt
278+
if delta is not None:
279+
# PRECOMPUTE TRANSFORM
280+
x = np.einsum("nam,lam,lan->nlm", x, delta, delta[:, :, L - 1 + n])
281+
else:
282+
# OTF TRANSFORM
283+
delta_el = np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64)
284+
xx = np.zeros((x.shape[0], L, x.shape[-1]), dtype=x.dtype)
285+
for el in range(L):
286+
delta_el = recursions.risbo.compute_full(delta_el, np.pi / 2, L, el)
287+
xx[:, el] = np.einsum("nam,am,an->nm", x, delta_el, delta_el[:, L - 1 + n])
288+
x = xx
289+
234290
x = np.einsum("nbm,m,n->nbm", x, 1j ** (m), 1j ** (-n))
235291

236292
# SYMMETRY REFLECT FOR N < 0
@@ -246,12 +302,12 @@ def forward_transform(
246302
return x * (2.0 * np.pi) ** 2
247303

248304

249-
@partial(jit, static_argnums=(2, 3, 4, 5))
305+
@partial(jit, static_argnums=(1, 2, 4, 5))
250306
def forward_transform_jax(
251307
f: jnp.ndarray,
252-
DW: jnp.ndarray,
253308
L: int,
254309
N: int,
310+
precomps: tuple[jnp.ndarray, jnp.ndarray] | None = None,
255311
reality: bool = False,
256312
sampling: str = "mw",
257313
) -> jnp.ndarray:
@@ -260,10 +316,11 @@ def forward_transform_jax(
260316
261317
Args:
262318
f (jnp.ndarray): Function sampled on the rotation group.
263-
DW (Tuple[jnp.ndarray, jnp.ndarray]): Fourier coefficients of the reduced
264-
Wigner d-functions and the corresponding upsampled quadrature weights.
265319
L (int): Harmonic band-limit.
266320
N (int): Azimuthal band-limit.
321+
precomps (tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the
322+
reduced Wigner d-functions and the corresponding upsampled quadrature
323+
weights. Defaults to None.
267324
reality (bool, optional): Whether the signal on the sphere is real. If so,
268325
conjugate symmetry is exploited to reduce computational costs.
269326
Defaults to False.
@@ -279,9 +336,6 @@ def forward_transform_jax(
279336
f"Fourier-Wigner algorithm does not support {sampling} sampling."
280337
)
281338

282-
# EXTRACT VARIOUS PRECOMPUTES
283-
Delta, Quads = DW
284-
285339
# INDEX VALUES
286340
n_start_ind = N - 1 if reality else 0
287341
m_offset = 1 if sampling.lower() == "mwss" else 0
@@ -321,14 +375,41 @@ def forward_transform_jax(
321375
# NB: Our convention here is conjugate to that of SSHT, in which
322376
# the weights are conjugate but applied flipped and therefore are
323377
# equivalent. To avoid flipping here we simply conjugate the weights.
324-
x = jnp.einsum("nbm,b->nbm", x, Quads)
378+
379+
if precomps is not None:
380+
# PRECOMPUTE TRANSFORM
381+
delta, quads = precomps
382+
else:
383+
# OTF TRANSFORM
384+
delta = None
385+
# COMPUTE QUADRATURE WEIGHTS
386+
quads = jnp.zeros(4 * L - 3, dtype=jnp.complex128)
387+
for mm in range(-2 * (L - 1), 2 * (L - 1) + 1):
388+
quads = quads.at[mm + 2 * (L - 1)].set(quadrature_jax.mw_weights(-mm))
389+
quads = jnp.fft.ifft(jnp.fft.ifftshift(quads), norm="forward")
390+
391+
# APPLY QUADRATURE
392+
x = jnp.einsum("nbm,b->nbm", x, quads)
325393

326394
# COMPUTE GMM BY FFT
327395
x = jnp.fft.fft(x, axis=1, norm="forward")
328396
x = jnp.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2]
329397

330-
# Calculate flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
331-
x = jnp.einsum("nam,lam,lan->nlm", x, Delta, Delta[:, :, L - 1 + n])
398+
# Calculate flmn = i^(n-m)\sum_t delta^l_tm delta^l_tn G_mnt
399+
if delta is not None:
400+
# PRECOMPUTE TRANSFORM
401+
x = jnp.einsum("nam,lam,lan->nlm", x, delta, delta[:, :, L - 1 + n])
402+
else:
403+
# OTF TRANSFORM
404+
delta_el = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64)
405+
xx = jnp.zeros((x.shape[0], L, x.shape[-1]), dtype=x.dtype)
406+
for el in range(L):
407+
delta_el = recursions.risbo_jax.compute_full(delta_el, jnp.pi / 2, L, el)
408+
xx = xx.at[:, el].set(
409+
jnp.einsum("nam,am,an->nm", x, delta_el, delta_el[:, L - 1 + n])
410+
)
411+
x = xx
412+
332413
x = jnp.einsum("nbm,m,n->nbm", x, 1j ** (m), 1j ** (-n))
333414

334415
# SYMMETRY REFLECT FOR N < 0

0 commit comments

Comments
 (0)