1
+ from __future__ import annotations
2
+
1
3
from functools import partial
2
4
3
5
import jax .numpy as jnp
@@ -12,7 +14,7 @@ def inverse_transform(
12
14
flmn : np .ndarray ,
13
15
L : int ,
14
16
N : int ,
15
- DW : np .ndarray = None ,
17
+ precomps : tuple [ np .ndarray , np . ndarray ] | None = None ,
16
18
reality : bool = False ,
17
19
sampling : str = "mw" ,
18
20
) -> np .ndarray :
@@ -23,9 +25,9 @@ def inverse_transform(
23
25
flmn (np.ndarray): Wigner coefficients.
24
26
L (int): Harmonic band-limit.
25
27
N (int): Azimuthal band-limit.
26
- DW (Tuple [np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced
27
- Wigner d-functions and the corresponding upsampled quadrature weights.
28
- Defaults to None.
28
+ precomps (tuple [np.ndarray, np.ndarray], optional): Fourier coefficients of the
29
+ reduced Wigner d-functions and the corresponding upsampled quadrature
30
+ weights. Defaults to None.
29
31
reality (bool, optional): Whether the signal on the sphere is real. If so,
30
32
conjugate symmetry is exploited to reduce computational costs.
31
33
Defaults to False.
@@ -53,28 +55,28 @@ def inverse_transform(
53
55
m = np .arange (- L + 1 - m_offset , L )
54
56
n = np .arange (n_start_ind - N + 1 , N )
55
57
56
- # Calculate fmna = i^(n-m)\sum_L Delta ^l_am Delta ^l_an f^l_mn(2l+1)/(8pi^2)
58
+ # Calculate fmna = i^(n-m)\sum_L delta ^l_am delta ^l_an f^l_mn(2l+1)/(8pi^2)
57
59
x = np .zeros ((xnlm_size , xnlm_size , n_dim ), dtype = flmn .dtype )
58
60
flmn = np .einsum ("nlm,l->nlm" , flmn , (2 * np .arange (L ) + 1 ) / (8 * np .pi ** 2 ))
59
61
60
62
# PRECOMPUTE TRANSFORM
61
- if DW is not None :
62
- Delta , _ = DW
63
+ if precomps is not None :
64
+ delta , _ = precomps
63
65
x = np .zeros ((xnlm_size , xnlm_size , n_dim ), dtype = flmn .dtype )
64
66
x [m_offset :, m_offset :] = np .einsum (
65
- "nlm,lam,lan->amn" , flmn [n_start_ind :], Delta , Delta [:, :, L - 1 + n ]
67
+ "nlm,lam,lan->amn" , flmn [n_start_ind :], delta , delta [:, :, L - 1 + n ]
66
68
)
67
69
68
70
# OTF TRANSFORM
69
71
else :
70
- Delta_el = np .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = np .float64 )
72
+ delta_el = np .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = np .float64 )
71
73
for el in range (L ):
72
- Delta_el = recursions .risbo .compute_full (Delta_el , np .pi / 2 , L , el )
74
+ delta_el = recursions .risbo .compute_full (delta_el , np .pi / 2 , L , el )
73
75
x [m_offset :, m_offset :] += np .einsum (
74
76
"nm,am,an->amn" ,
75
77
flmn [n_start_ind :, el ],
76
- Delta_el ,
77
- Delta_el [:, L - 1 + n ],
78
+ delta_el ,
79
+ delta_el [:, L - 1 + n ],
78
80
)
79
81
80
82
# APPLY SIGN FUNCTION AND PHASE SHIFT
@@ -97,7 +99,7 @@ def inverse_transform_jax(
97
99
flmn : jnp .ndarray ,
98
100
L : int ,
99
101
N : int ,
100
- DW : jnp .ndarray = None ,
102
+ precomps : tuple [ jnp .ndarray , jnp . ndarray ] | None = None ,
101
103
reality : bool = False ,
102
104
sampling : str = "mw" ,
103
105
) -> jnp .ndarray :
@@ -108,9 +110,9 @@ def inverse_transform_jax(
108
110
flmn (jnp.ndarray): Wigner coefficients.
109
111
L (int): Harmonic band-limit.
110
112
N (int): Azimuthal band-limit.
111
- DW (Tuple [np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced
112
- Wigner d-functions and the corresponding upsampled quadrature weights.
113
- Defaults to None.
113
+ precomps (tuple [np.ndarray, np.ndarray], optional): Fourier coefficients of the
114
+ reduced Wigner d-functions and the corresponding upsampled quadrature
115
+ weights. Defaults to None.
114
116
reality (bool, optional): Whether the signal on the sphere is real. If so,
115
117
conjugate symmetry is exploited to reduce computational costs.
116
118
Defaults to False.
@@ -138,30 +140,30 @@ def inverse_transform_jax(
138
140
m = jnp .arange (- L + 1 - m_offset , L )
139
141
n = jnp .arange (n_start_ind - N + 1 , N )
140
142
141
- # Calculate fmna = i^(n-m)\sum_L Delta ^l_am Delta ^l_an f^l_mn(2l+1)/(8pi^2)
143
+ # Calculate fmna = i^(n-m)\sum_L delta ^l_am delta ^l_an f^l_mn(2l+1)/(8pi^2)
142
144
x = jnp .zeros ((xnlm_size , xnlm_size , n_dim ), dtype = jnp .complex128 )
143
145
flmn = jnp .einsum ("nlm,l->nlm" , flmn , (2 * jnp .arange (L ) + 1 ) / (8 * jnp .pi ** 2 ))
144
146
145
147
# PRECOMPUTE TRANSFORM
146
- if DW is not None :
147
- Delta , _ = DW
148
+ if precomps is not None :
149
+ delta , _ = precomps
148
150
x = x .at [m_offset :, m_offset :].set (
149
151
jnp .einsum (
150
- "nlm,lam,lan->amn" , flmn [n_start_ind :], Delta , Delta [:, :, L - 1 + n ]
152
+ "nlm,lam,lan->amn" , flmn [n_start_ind :], delta , delta [:, :, L - 1 + n ]
151
153
)
152
154
)
153
155
154
156
# OTF TRANSFORM
155
157
else :
156
- Delta_el = jnp .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = jnp .float64 )
158
+ delta_el = jnp .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = jnp .float64 )
157
159
for el in range (L ):
158
- Delta_el = recursions .risbo_jax .compute_full (Delta_el , jnp .pi / 2 , L , el )
160
+ delta_el = recursions .risbo_jax .compute_full (delta_el , jnp .pi / 2 , L , el )
159
161
x = x .at [m_offset :, m_offset :].add (
160
162
jnp .einsum (
161
163
"nm,am,an->amn" ,
162
164
flmn [n_start_ind :, el ],
163
- Delta_el ,
164
- Delta_el [:, L - 1 + n ],
165
+ delta_el ,
166
+ delta_el [:, L - 1 + n ],
165
167
)
166
168
)
167
169
@@ -184,7 +186,7 @@ def forward_transform(
184
186
f : np .ndarray ,
185
187
L : int ,
186
188
N : int ,
187
- DW : np .ndarray = None ,
189
+ precomps : tuple [ np .ndarray , np . ndarray ] | None = None ,
188
190
reality : bool = False ,
189
191
sampling : str = "mw" ,
190
192
) -> np .ndarray :
@@ -195,9 +197,9 @@ def forward_transform(
195
197
f (np.ndarray): Function sampled on the rotation group.
196
198
L (int): Harmonic band-limit.
197
199
N (int): Azimuthal band-limit.
198
- DW (Tuple [np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced
199
- Wigner d-functions and the corresponding upsampled quadrature weights.
200
- Defaults to None.
200
+ precomps (tuple [np.ndarray, np.ndarray], optional): Fourier coefficients of the
201
+ reduced Wigner d-functions and the corresponding upsampled quadrature
202
+ weights. Defaults to None.
201
203
reality (bool, optional): Whether the signal on the sphere is real. If so,
202
204
conjugate symmetry is exploited to reduce computational costs.
203
205
Defaults to False.
@@ -253,43 +255,38 @@ def forward_transform(
253
255
# the weights are conjugate but applied flipped and therefore are
254
256
# equivalent. To avoid flipping here we simply conjugate the weights.
255
257
256
- # PRECOMPUTE TRANSFORM
257
- if DW is not None :
258
- # EXTRACT VARIOUS PRECOMPUTES
259
- Delta , Quads = DW
260
-
261
- # APPLY QUADRATURE
262
- x = np .einsum ("nbm,b->nbm" , x , Quads )
263
-
264
- # COMPUTE GMM BY FFT
265
- x = np .fft .fft (x , axis = 1 , norm = "forward" )
266
- x = np .fft .fftshift (x , axes = 1 )[:, L - 1 : 3 * L - 2 ]
267
-
268
- # CALCULATE flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
269
- x = np .einsum ("nam,lam,lan->nlm" , x , Delta , Delta [:, :, L - 1 + n ])
270
-
271
- # OTF TRANSFORM
258
+ if precomps is not None :
259
+ # PRECOMPUTE TRANSFORM
260
+ delta , quads = precomps
272
261
else :
262
+ # OTF TRANSFORM
263
+ delta = None
273
264
# COMPUTE QUADRATURE WEIGHTS
274
- Quads = np .zeros (4 * L - 3 , dtype = np .complex128 )
265
+ quads = np .zeros (4 * L - 3 , dtype = np .complex128 )
275
266
for mm in range (- 2 * (L - 1 ), 2 * (L - 1 ) + 1 ):
276
- Quads [mm + 2 * (L - 1 )] = quadrature .mw_weights (- mm )
277
- Quads = np .fft .ifft (np .fft .ifftshift (Quads ), norm = "forward" )
267
+ quads [mm + 2 * (L - 1 )] = quadrature .mw_weights (- mm )
268
+ quads = np .fft .ifft (np .fft .ifftshift (quads ), norm = "forward" )
278
269
279
- # APPLY QUADRATURE
280
- x = np .einsum ("nbm,b->nbm" , x , Quads )
270
+ # APPLY QUADRATURE
271
+ x = np .einsum ("nbm,b->nbm" , x , quads )
281
272
282
- # COMPUTE GMM BY FFT
283
- x = np .fft .fft (x , axis = 1 , norm = "forward" )
284
- x = np .fft .fftshift (x , axes = 1 )[:, L - 1 : 3 * L - 2 ]
273
+ # COMPUTE GMM BY FFT
274
+ x = np .fft .fft (x , axis = 1 , norm = "forward" )
275
+ x = np .fft .fftshift (x , axes = 1 )[:, L - 1 : 3 * L - 2 ]
285
276
286
- # CALCULATE flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
287
- Delta_el = np .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = np .float64 )
277
+ # CALCULATE flmn = i^(n-m)\sum_t delta^l_tm delta^l_tn G_mnt
278
+ if delta is not None :
279
+ # PRECOMPUTE TRANSFORM
280
+ x = np .einsum ("nam,lam,lan->nlm" , x , delta , delta [:, :, L - 1 + n ])
281
+ else :
282
+ # OTF TRANSFORM
283
+ delta_el = np .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = np .float64 )
288
284
xx = np .zeros ((x .shape [0 ], L , x .shape [- 1 ]), dtype = x .dtype )
289
285
for el in range (L ):
290
- Delta_el = recursions .risbo .compute_full (Delta_el , np .pi / 2 , L , el )
291
- xx [:, el ] = np .einsum ("nam,am,an->nm" , x , Delta_el , Delta_el [:, L - 1 + n ])
286
+ delta_el = recursions .risbo .compute_full (delta_el , np .pi / 2 , L , el )
287
+ xx [:, el ] = np .einsum ("nam,am,an->nm" , x , delta_el , delta_el [:, L - 1 + n ])
292
288
x = xx
289
+
293
290
x = np .einsum ("nbm,m,n->nbm" , x , 1j ** (m ), 1j ** (- n ))
294
291
295
292
# SYMMETRY REFLECT FOR N < 0
@@ -310,7 +307,7 @@ def forward_transform_jax(
310
307
f : jnp .ndarray ,
311
308
L : int ,
312
309
N : int ,
313
- DW : jnp .ndarray = None ,
310
+ precomps : tuple [ jnp .ndarray , jnp . ndarray ] | None = None ,
314
311
reality : bool = False ,
315
312
sampling : str = "mw" ,
316
313
) -> jnp .ndarray :
@@ -321,9 +318,9 @@ def forward_transform_jax(
321
318
f (jnp.ndarray): Function sampled on the rotation group.
322
319
L (int): Harmonic band-limit.
323
320
N (int): Azimuthal band-limit.
324
- DW (Tuple [np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced
325
- Wigner d-functions and the corresponding upsampled quadrature weights.
326
- Defaults to None.
321
+ precomps (tuple [np.ndarray, np.ndarray], optional): Fourier coefficients of the
322
+ reduced Wigner d-functions and the corresponding upsampled quadrature
323
+ weights. Defaults to None.
327
324
reality (bool, optional): Whether the signal on the sphere is real. If so,
328
325
conjugate symmetry is exploited to reduce computational costs.
329
326
Defaults to False.
@@ -379,41 +376,37 @@ def forward_transform_jax(
379
376
# the weights are conjugate but applied flipped and therefore are
380
377
# equivalent. To avoid flipping here we simply conjugate the weights.
381
378
382
- # PRECOMPUTE TRANSFORM
383
- if DW is not None :
384
- # EXTRACT VARIOUS PRECOMPUTES
385
- Delta , Quads = DW
386
-
387
- # APPLY QUADRATURE
388
- x = jnp .einsum ("nbm,b->nbm" , x , Quads )
389
-
390
- # COMPUTE GMM BY FFT
391
- x = jnp .fft .fft (x , axis = 1 , norm = "forward" )
392
- x = jnp .fft .fftshift (x , axes = 1 )[:, L - 1 : 3 * L - 2 ]
393
-
394
- # Calculate flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
395
- x = jnp .einsum ("nam,lam,lan->nlm" , x , Delta , Delta [:, :, L - 1 + n ])
396
-
379
+ if precomps is not None :
380
+ # PRECOMPUTE TRANSFORM
381
+ delta , quads = precomps
397
382
else :
398
- Quads = jnp .zeros (4 * L - 3 , dtype = jnp .complex128 )
383
+ # OTF TRANSFORM
384
+ delta = None
385
+ # COMPUTE QUADRATURE WEIGHTS
386
+ quads = jnp .zeros (4 * L - 3 , dtype = jnp .complex128 )
399
387
for mm in range (- 2 * (L - 1 ), 2 * (L - 1 ) + 1 ):
400
- Quads = Quads .at [mm + 2 * (L - 1 )].set (quadrature_jax .mw_weights (- mm ))
401
- Quads = jnp .fft .ifft (jnp .fft .ifftshift (Quads ), norm = "forward" )
388
+ quads = quads .at [mm + 2 * (L - 1 )].set (quadrature_jax .mw_weights (- mm ))
389
+ quads = jnp .fft .ifft (jnp .fft .ifftshift (quads ), norm = "forward" )
402
390
403
- # APPLY QUADRATURE
404
- x = jnp .einsum ("nbm,b->nbm" , x , Quads )
391
+ # APPLY QUADRATURE
392
+ x = jnp .einsum ("nbm,b->nbm" , x , quads )
405
393
406
- # COMPUTE GMM BY FFT
407
- x = jnp .fft .fft (x , axis = 1 , norm = "forward" )
408
- x = jnp .fft .fftshift (x , axes = 1 )[:, L - 1 : 3 * L - 2 ]
394
+ # COMPUTE GMM BY FFT
395
+ x = jnp .fft .fft (x , axis = 1 , norm = "forward" )
396
+ x = jnp .fft .fftshift (x , axes = 1 )[:, L - 1 : 3 * L - 2 ]
409
397
410
- # CALCULATE flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
411
- Delta_el = jnp .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = jnp .float64 )
398
+ # Calculate flmn = i^(n-m)\sum_t delta^l_tm delta^l_tn G_mnt
399
+ if delta is not None :
400
+ # PRECOMPUTE TRANSFORM
401
+ x = jnp .einsum ("nam,lam,lan->nlm" , x , delta , delta [:, :, L - 1 + n ])
402
+ else :
403
+ # OTF TRANSFORM
404
+ delta_el = jnp .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = jnp .float64 )
412
405
xx = jnp .zeros ((x .shape [0 ], L , x .shape [- 1 ]), dtype = x .dtype )
413
406
for el in range (L ):
414
- Delta_el = recursions .risbo_jax .compute_full (Delta_el , jnp .pi / 2 , L , el )
407
+ delta_el = recursions .risbo_jax .compute_full (delta_el , jnp .pi / 2 , L , el )
415
408
xx = xx .at [:, el ].set (
416
- jnp .einsum ("nam,am,an->nm" , x , Delta_el , Delta_el [:, L - 1 + n ])
409
+ jnp .einsum ("nam,am,an->nm" , x , delta_el , delta_el [:, L - 1 + n ])
417
410
)
418
411
x = xx
419
412
0 commit comments