1
+ from __future__ import annotations
2
+
1
3
from functools import partial
2
4
3
5
import jax .numpy as jnp
4
6
import numpy as np
5
7
from jax import jit
6
8
9
+ from s2fft import recursions
10
+ from s2fft .utils import quadrature , quadrature_jax
11
+
7
12
8
13
def inverse_transform (
9
14
flmn : np .ndarray ,
10
- DW : np .ndarray ,
11
15
L : int ,
12
16
N : int ,
17
+ precomps : tuple [np .ndarray , np .ndarray ] | None = None ,
13
18
reality : bool = False ,
14
19
sampling : str = "mw" ,
15
20
) -> np .ndarray :
@@ -18,10 +23,11 @@ def inverse_transform(
18
23
19
24
Args:
20
25
flmn (np.ndarray): Wigner coefficients.
21
- DW (Tuple[np.ndarray, np.ndarray]): Fourier coefficients of the reduced
22
- Wigner d-functions and the corresponding upsampled quadrature weights.
23
26
L (int): Harmonic band-limit.
24
27
N (int): Azimuthal band-limit.
28
+ precomps (tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the
29
+ reduced Wigner d-functions and the corresponding upsampled quadrature
30
+ weights. Defaults to None.
25
31
reality (bool, optional): Whether the signal on the sphere is real. If so,
26
32
conjugate symmetry is exploited to reduce computational costs.
27
33
Defaults to False.
@@ -37,9 +43,6 @@ def inverse_transform(
37
43
f"Fourier-Wigner algorithm does not support { sampling } sampling."
38
44
)
39
45
40
- # EXTRACT VARIOUS PRECOMPUTES
41
- Delta , _ = DW
42
-
43
46
# INDEX VALUES
44
47
n_start_ind = N - 1 if reality else 0
45
48
n_dim = N if reality else 2 * N - 1
@@ -52,15 +55,29 @@ def inverse_transform(
52
55
m = np .arange (- L + 1 - m_offset , L )
53
56
n = np .arange (n_start_ind - N + 1 , N )
54
57
55
- # Calculate fmna = i^(n-m)\sum_L Delta ^l_am Delta ^l_an f^l_mn(2l+1)/(8pi^2)
58
+ # Calculate fmna = i^(n-m)\sum_L delta ^l_am delta ^l_an f^l_mn(2l+1)/(8pi^2)
56
59
x = np .zeros ((xnlm_size , xnlm_size , n_dim ), dtype = flmn .dtype )
57
- x [m_offset :, m_offset :] = np .einsum (
58
- "nlm,lam,lan,l->amn" ,
59
- flmn [n_start_ind :],
60
- Delta ,
61
- Delta [:, :, L - 1 + n ],
62
- (2 * np .arange (L ) + 1 ) / (8 * np .pi ** 2 ),
63
- )
60
+ flmn = np .einsum ("nlm,l->nlm" , flmn , (2 * np .arange (L ) + 1 ) / (8 * np .pi ** 2 ))
61
+
62
+ # PRECOMPUTE TRANSFORM
63
+ if precomps is not None :
64
+ delta , _ = precomps
65
+ x = np .zeros ((xnlm_size , xnlm_size , n_dim ), dtype = flmn .dtype )
66
+ x [m_offset :, m_offset :] = np .einsum (
67
+ "nlm,lam,lan->amn" , flmn [n_start_ind :], delta , delta [:, :, L - 1 + n ]
68
+ )
69
+
70
+ # OTF TRANSFORM
71
+ else :
72
+ delta_el = np .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = np .float64 )
73
+ for el in range (L ):
74
+ delta_el = recursions .risbo .compute_full (delta_el , np .pi / 2 , L , el )
75
+ x [m_offset :, m_offset :] += np .einsum (
76
+ "nm,am,an->amn" ,
77
+ flmn [n_start_ind :, el ],
78
+ delta_el ,
79
+ delta_el [:, L - 1 + n ],
80
+ )
64
81
65
82
# APPLY SIGN FUNCTION AND PHASE SHIFT
66
83
x = np .einsum ("amn,m,n,a->nam" , x , 1j ** (- m ), 1j ** (n ), np .exp (1j * m * theta0 ))
@@ -77,12 +94,12 @@ def inverse_transform(
77
94
return np .fft .ifft2 (x , axes = (0 , 2 ), norm = "forward" )
78
95
79
96
80
- @partial (jit , static_argnums = (2 , 3 , 4 , 5 ))
97
+ @partial (jit , static_argnums = (1 , 2 , 4 , 5 ))
81
98
def inverse_transform_jax (
82
99
flmn : jnp .ndarray ,
83
- DW : jnp .ndarray ,
84
100
L : int ,
85
101
N : int ,
102
+ precomps : tuple [jnp .ndarray , jnp .ndarray ] | None = None ,
86
103
reality : bool = False ,
87
104
sampling : str = "mw" ,
88
105
) -> jnp .ndarray :
@@ -91,10 +108,11 @@ def inverse_transform_jax(
91
108
92
109
Args:
93
110
flmn (jnp.ndarray): Wigner coefficients.
94
- DW (Tuple[np.ndarray, np.ndarray]): Fourier coefficients of the reduced
95
- Wigner d-functions and the corresponding upsampled quadrature weights.
96
111
L (int): Harmonic band-limit.
97
112
N (int): Azimuthal band-limit.
113
+ precomps (tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the
114
+ reduced Wigner d-functions and the corresponding upsampled quadrature
115
+ weights. Defaults to None.
98
116
reality (bool, optional): Whether the signal on the sphere is real. If so,
99
117
conjugate symmetry is exploited to reduce computational costs.
100
118
Defaults to False.
@@ -110,9 +128,6 @@ def inverse_transform_jax(
110
128
f"Fourier-Wigner algorithm does not support { sampling } sampling."
111
129
)
112
130
113
- # EXTRACT VARIOUS PRECOMPUTES
114
- Delta , _ = DW
115
-
116
131
# INDEX VALUES
117
132
n_start_ind = N - 1 if reality else 0
118
133
n_dim = N if reality else 2 * N - 1
@@ -125,14 +140,32 @@ def inverse_transform_jax(
125
140
m = jnp .arange (- L + 1 - m_offset , L )
126
141
n = jnp .arange (n_start_ind - N + 1 , N )
127
142
128
- # Calculate fmna = i^(n-m)\sum_L Delta ^l_am Delta ^l_an f^l_mn(2l+1)/(8pi^2)
143
+ # Calculate fmna = i^(n-m)\sum_L delta ^l_am delta ^l_an f^l_mn(2l+1)/(8pi^2)
129
144
x = jnp .zeros ((xnlm_size , xnlm_size , n_dim ), dtype = jnp .complex128 )
130
145
flmn = jnp .einsum ("nlm,l->nlm" , flmn , (2 * jnp .arange (L ) + 1 ) / (8 * jnp .pi ** 2 ))
131
- x = x .at [m_offset :, m_offset :].set (
132
- jnp .einsum (
133
- "nlm,lam,lan->amn" , flmn [n_start_ind :], Delta , Delta [:, :, L - 1 + n ]
146
+
147
+ # PRECOMPUTE TRANSFORM
148
+ if precomps is not None :
149
+ delta , _ = precomps
150
+ x = x .at [m_offset :, m_offset :].set (
151
+ jnp .einsum (
152
+ "nlm,lam,lan->amn" , flmn [n_start_ind :], delta , delta [:, :, L - 1 + n ]
153
+ )
134
154
)
135
- )
155
+
156
+ # OTF TRANSFORM
157
+ else :
158
+ delta_el = jnp .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = jnp .float64 )
159
+ for el in range (L ):
160
+ delta_el = recursions .risbo_jax .compute_full (delta_el , jnp .pi / 2 , L , el )
161
+ x = x .at [m_offset :, m_offset :].add (
162
+ jnp .einsum (
163
+ "nm,am,an->amn" ,
164
+ flmn [n_start_ind :, el ],
165
+ delta_el ,
166
+ delta_el [:, L - 1 + n ],
167
+ )
168
+ )
136
169
137
170
# APPLY SIGN FUNCTION AND PHASE SHIFT
138
171
x = jnp .einsum ("amn,m,n,a->nam" , x , 1j ** (- m ), 1j ** (n ), jnp .exp (1j * m * theta0 ))
@@ -151,9 +184,9 @@ def inverse_transform_jax(
151
184
152
185
def forward_transform (
153
186
f : np .ndarray ,
154
- DW : np .ndarray ,
155
187
L : int ,
156
188
N : int ,
189
+ precomps : tuple [np .ndarray , np .ndarray ] | None = None ,
157
190
reality : bool = False ,
158
191
sampling : str = "mw" ,
159
192
) -> np .ndarray :
@@ -162,10 +195,11 @@ def forward_transform(
162
195
163
196
Args:
164
197
f (np.ndarray): Function sampled on the rotation group.
165
- DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced
166
- Wigner d-functions and the corresponding upsampled quadrature weights.
167
198
L (int): Harmonic band-limit.
168
199
N (int): Azimuthal band-limit.
200
+ precomps (tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the
201
+ reduced Wigner d-functions and the corresponding upsampled quadrature
202
+ weights. Defaults to None.
169
203
reality (bool, optional): Whether the signal on the sphere is real. If so,
170
204
conjugate symmetry is exploited to reduce computational costs.
171
205
Defaults to False.
@@ -181,9 +215,6 @@ def forward_transform(
181
215
f"Fourier-Wigner algorithm does not support { sampling } sampling."
182
216
)
183
217
184
- # EXTRACT VARIOUS PRECOMPUTES
185
- Delta , Quads = DW
186
-
187
218
# INDEX VALUES
188
219
n_start_ind = N - 1 if reality else 0
189
220
m_offset = 1 if sampling .lower () == "mwss" else 0
@@ -223,14 +254,39 @@ def forward_transform(
223
254
# NB: Our convention here is conjugate to that of SSHT, in which
224
255
# the weights are conjugate but applied flipped and therefore are
225
256
# equivalent. To avoid flipping here we simply conjugate the weights.
226
- x = np .einsum ("nbm,b->nbm" , x , Quads )
257
+
258
+ if precomps is not None :
259
+ # PRECOMPUTE TRANSFORM
260
+ delta , quads = precomps
261
+ else :
262
+ # OTF TRANSFORM
263
+ delta = None
264
+ # COMPUTE QUADRATURE WEIGHTS
265
+ quads = np .zeros (4 * L - 3 , dtype = np .complex128 )
266
+ for mm in range (- 2 * (L - 1 ), 2 * (L - 1 ) + 1 ):
267
+ quads [mm + 2 * (L - 1 )] = quadrature .mw_weights (- mm )
268
+ quads = np .fft .ifft (np .fft .ifftshift (quads ), norm = "forward" )
269
+
270
+ # APPLY QUADRATURE
271
+ x = np .einsum ("nbm,b->nbm" , x , quads )
227
272
228
273
# COMPUTE GMM BY FFT
229
274
x = np .fft .fft (x , axis = 1 , norm = "forward" )
230
275
x = np .fft .fftshift (x , axes = 1 )[:, L - 1 : 3 * L - 2 ]
231
276
232
- # Calculate flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
233
- x = np .einsum ("nam,lam,lan->nlm" , x , Delta , Delta [:, :, L - 1 + n ])
277
+ # CALCULATE flmn = i^(n-m)\sum_t delta^l_tm delta^l_tn G_mnt
278
+ if delta is not None :
279
+ # PRECOMPUTE TRANSFORM
280
+ x = np .einsum ("nam,lam,lan->nlm" , x , delta , delta [:, :, L - 1 + n ])
281
+ else :
282
+ # OTF TRANSFORM
283
+ delta_el = np .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = np .float64 )
284
+ xx = np .zeros ((x .shape [0 ], L , x .shape [- 1 ]), dtype = x .dtype )
285
+ for el in range (L ):
286
+ delta_el = recursions .risbo .compute_full (delta_el , np .pi / 2 , L , el )
287
+ xx [:, el ] = np .einsum ("nam,am,an->nm" , x , delta_el , delta_el [:, L - 1 + n ])
288
+ x = xx
289
+
234
290
x = np .einsum ("nbm,m,n->nbm" , x , 1j ** (m ), 1j ** (- n ))
235
291
236
292
# SYMMETRY REFLECT FOR N < 0
@@ -246,12 +302,12 @@ def forward_transform(
246
302
return x * (2.0 * np .pi ) ** 2
247
303
248
304
249
- @partial (jit , static_argnums = (2 , 3 , 4 , 5 ))
305
+ @partial (jit , static_argnums = (1 , 2 , 4 , 5 ))
250
306
def forward_transform_jax (
251
307
f : jnp .ndarray ,
252
- DW : jnp .ndarray ,
253
308
L : int ,
254
309
N : int ,
310
+ precomps : tuple [jnp .ndarray , jnp .ndarray ] | None = None ,
255
311
reality : bool = False ,
256
312
sampling : str = "mw" ,
257
313
) -> jnp .ndarray :
@@ -260,10 +316,11 @@ def forward_transform_jax(
260
316
261
317
Args:
262
318
f (jnp.ndarray): Function sampled on the rotation group.
263
- DW (Tuple[jnp.ndarray, jnp.ndarray]): Fourier coefficients of the reduced
264
- Wigner d-functions and the corresponding upsampled quadrature weights.
265
319
L (int): Harmonic band-limit.
266
320
N (int): Azimuthal band-limit.
321
+ precomps (tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the
322
+ reduced Wigner d-functions and the corresponding upsampled quadrature
323
+ weights. Defaults to None.
267
324
reality (bool, optional): Whether the signal on the sphere is real. If so,
268
325
conjugate symmetry is exploited to reduce computational costs.
269
326
Defaults to False.
@@ -279,9 +336,6 @@ def forward_transform_jax(
279
336
f"Fourier-Wigner algorithm does not support { sampling } sampling."
280
337
)
281
338
282
- # EXTRACT VARIOUS PRECOMPUTES
283
- Delta , Quads = DW
284
-
285
339
# INDEX VALUES
286
340
n_start_ind = N - 1 if reality else 0
287
341
m_offset = 1 if sampling .lower () == "mwss" else 0
@@ -321,14 +375,41 @@ def forward_transform_jax(
321
375
# NB: Our convention here is conjugate to that of SSHT, in which
322
376
# the weights are conjugate but applied flipped and therefore are
323
377
# equivalent. To avoid flipping here we simply conjugate the weights.
324
- x = jnp .einsum ("nbm,b->nbm" , x , Quads )
378
+
379
+ if precomps is not None :
380
+ # PRECOMPUTE TRANSFORM
381
+ delta , quads = precomps
382
+ else :
383
+ # OTF TRANSFORM
384
+ delta = None
385
+ # COMPUTE QUADRATURE WEIGHTS
386
+ quads = jnp .zeros (4 * L - 3 , dtype = jnp .complex128 )
387
+ for mm in range (- 2 * (L - 1 ), 2 * (L - 1 ) + 1 ):
388
+ quads = quads .at [mm + 2 * (L - 1 )].set (quadrature_jax .mw_weights (- mm ))
389
+ quads = jnp .fft .ifft (jnp .fft .ifftshift (quads ), norm = "forward" )
390
+
391
+ # APPLY QUADRATURE
392
+ x = jnp .einsum ("nbm,b->nbm" , x , quads )
325
393
326
394
# COMPUTE GMM BY FFT
327
395
x = jnp .fft .fft (x , axis = 1 , norm = "forward" )
328
396
x = jnp .fft .fftshift (x , axes = 1 )[:, L - 1 : 3 * L - 2 ]
329
397
330
- # Calculate flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
331
- x = jnp .einsum ("nam,lam,lan->nlm" , x , Delta , Delta [:, :, L - 1 + n ])
398
+ # Calculate flmn = i^(n-m)\sum_t delta^l_tm delta^l_tn G_mnt
399
+ if delta is not None :
400
+ # PRECOMPUTE TRANSFORM
401
+ x = jnp .einsum ("nam,lam,lan->nlm" , x , delta , delta [:, :, L - 1 + n ])
402
+ else :
403
+ # OTF TRANSFORM
404
+ delta_el = jnp .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = jnp .float64 )
405
+ xx = jnp .zeros ((x .shape [0 ], L , x .shape [- 1 ]), dtype = x .dtype )
406
+ for el in range (L ):
407
+ delta_el = recursions .risbo_jax .compute_full (delta_el , jnp .pi / 2 , L , el )
408
+ xx = xx .at [:, el ].set (
409
+ jnp .einsum ("nam,am,an->nm" , x , delta_el , delta_el [:, L - 1 + n ])
410
+ )
411
+ x = xx
412
+
332
413
x = jnp .einsum ("nbm,m,n->nbm" , x , 1j ** (m ), 1j ** (- n ))
333
414
334
415
# SYMMETRY REFLECT FOR N < 0
0 commit comments