Skip to content

Commit 20740af

Browse files
committed
Rework test and imports
1 parent c08918b commit 20740af

File tree

13 files changed

+100
-103
lines changed

13 files changed

+100
-103
lines changed

array_api_compat/cupy/__init__.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,24 @@
1+
from typing import Final
12
from cupy import * # noqa: F403
23

34
# from cupy import * doesn't overwrite these builtin names
45
from cupy import abs, max, min, round # noqa: F401
56

67
# These imports may overwrite names from the import * above.
78
from ._aliases import * # noqa: F403
9+
from ._info import __array_namespace_info__ # noqa: F401
810

911
# See the comment in the numpy __init__.py
1012
__import__(__package__ + '.linalg')
1113
__import__(__package__ + '.fft')
1214

13-
__array_api_version__ = '2024.12'
15+
__array_api_version__: Final = '2024.12'
16+
17+
__all__ = sorted(
18+
{name for name in globals() if not name.startswith("__")}
19+
- {"Final", "_aliases", "_info", "_typing"}
20+
| {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"}
21+
)
22+
23+
def __dir__() -> list[str]:
24+
return __all__

array_api_compat/cupy/_aliases.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from ..common import _aliases, _helpers
88
from ..common._typing import NestedSequence, SupportsBufferProtocol
99
from .._internal import get_xp
10-
from ._info import __array_namespace_info__
1110
from ._typing import Array, Device, DType
1211

1312
bool = cp.bool_
@@ -155,7 +154,7 @@ def count_nonzero(
155154
else:
156155
unstack = get_xp(cp)(_aliases.unstack)
157156

158-
__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype',
157+
__all__ = _aliases.__all__ + ['asarray', 'astype',
159158
'acos', 'acosh', 'asin', 'asinh', 'atan',
160159
'atan2', 'atanh', 'bitwise_left_shift',
161160
'bitwise_invert', 'bitwise_right_shift',
+15
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,27 @@
11
from typing import Final
22

3+
import dask.array as da
34
from dask.array import * # noqa: F403
45

56
# These imports may overwrite names from the import * above.
7+
from . import _aliases
68
from ._aliases import * # noqa: F403
9+
from ._info import __array_namespace_info__ # noqa: F401
710

811
__array_api_version__: Final = "2024.12"
912

1013
# See the comment in the numpy __init__.py
1114
__import__(__package__ + '.linalg')
1215
__import__(__package__ + '.fft')
16+
17+
def _make_all(base):
18+
return sorted(
19+
set(base)
20+
| set(_aliases.__all__)
21+
| {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"}
22+
)
23+
24+
__all__ = _make_all(da.__all__)
25+
26+
def __dir__() -> list[str]:
27+
return _make_all(dir(da))

array_api_compat/dask/array/_aliases.py

-2
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
NestedSequence,
4242
SupportsBufferProtocol,
4343
)
44-
from ._info import __array_namespace_info__
4544

4645
isdtype = get_xp(np)(_aliases.isdtype)
4746
unstack = get_xp(da)(_aliases.unstack)
@@ -355,7 +354,6 @@ def count_nonzero(
355354

356355

357356
__all__ = [
358-
"__array_namespace_info__",
359357
"count_nonzero",
360358
"bool",
361359
"int8", "int16", "int32", "int64",

array_api_compat/dask/array/fft.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
11
from dask.array.fft import * # noqa: F403
22
# dask.array.fft doesn't have __all__. If it is added, replace this with
3-
#
4-
# from dask.array.fft import __all__ as linalg_all
5-
_n = {}
3+
# from dask.array.fft import __all__ as fft_all
4+
_n: dict[str, object] = {}
65
exec('from dask.array.fft import *', _n)
7-
for k in ("__builtins__", "Sequence", "annotations", "warnings"):
8-
_n.pop(k, None)
96
fft_all = list(_n)
10-
del _n, k
7+
del _n
118

129
from ...common import _fft
1310
from ..._internal import get_xp

array_api_compat/dask/array/linalg.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,8 @@
2020
# from dask.array.linalg import __all__ as linalg_all
2121
_n = {}
2222
exec('from dask.array.linalg import *', _n)
23-
for k in ('__builtins__', 'annotations', 'operator', 'warnings', 'Array'):
24-
_n.pop(k, None)
2523
linalg_all = list(_n)
26-
del _n, k
24+
del _n
2725

2826
EighResult = _linalg.EighResult
2927
QRResult = _linalg.QRResult

array_api_compat/numpy/__init__.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# ruff: noqa: PLC0414
22
from typing import Final
33

4+
import numpy as np
45
from numpy import * # noqa: F403 # pyright: ignore[reportWildcardImportFromLibrary]
56

67
# from numpy import * doesn't overwrite these builtin names
@@ -10,7 +11,9 @@
1011
from numpy import round as round
1112

1213
# These imports may overwrite names from the import * above.
14+
from . import _aliases
1315
from ._aliases import * # noqa: F403
16+
from ._info import __array_namespace_info__ # noqa: F401
1417

1518
# Don't know why, but we have to do an absolute import to import linalg. If we
1619
# instead do
@@ -23,13 +26,15 @@
2326

2427
__import__(__package__ + ".fft")
2528

26-
from ..common._helpers import * # noqa: F403
2729
from .linalg import matrix_transpose, vecdot # noqa: F401
2830

29-
try:
30-
# Used in asarray(). Not present in older versions.
31-
from numpy import _CopyMode # noqa: F401
32-
except ImportError:
33-
pass
34-
3531
__array_api_version__: Final = "2024.12"
32+
33+
__all__ = sorted(
34+
set(np.__all__)
35+
| set(_aliases.__all__)
36+
| {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"}
37+
)
38+
39+
def __dir__() -> list[str]:
40+
return __all__

array_api_compat/numpy/_aliases.py

-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from .._internal import get_xp
1010
from ..common import _aliases, _helpers
1111
from ..common._typing import NestedSequence, SupportsBufferProtocol
12-
from ._info import __array_namespace_info__
1312
from ._typing import Array, Device, DType
1413

1514
if TYPE_CHECKING:
@@ -158,7 +157,6 @@ def count_nonzero(
158157
unstack = get_xp(np)(_aliases.unstack)
159158

160159
__all__ = _aliases.__all__ + [
161-
"__array_namespace_info__",
162160
"asarray",
163161
"astype",
164162
"acos",

array_api_compat/numpy/fft.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from numpy.fft import fft2, ifft2, irfft2, rfft2
2+
from numpy.fft import * # noqa: F403
33

44
from .._internal import get_xp
55
from ..common import _fft
@@ -20,7 +20,7 @@
2020
ifftshift = get_xp(np)(_fft.ifftshift)
2121

2222

23-
__all__ = _fft.__all__ + ["fft2", "ifft2", "irfft2", "rfft2"]
23+
__all__ = sorted(set(np.fft.__all__) | set(_fft.__all__))
2424

2525
def __dir__() -> list[str]:
26-
return __all__
26+
return sorted(set(dir(np.fft)) | set(_fft.__all__))

array_api_compat/numpy/linalg.py

+4-18
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,7 @@
77

88
import numpy as np
99

10-
# intersection of `np.linalg.__all__` on numpy 1.22 and 2.2, minus `_linalg.__all__`
11-
from numpy.linalg import (
12-
LinAlgError,
13-
cond,
14-
det,
15-
eig,
16-
eigvals,
17-
eigvalsh,
18-
inv,
19-
lstsq,
20-
matrix_power,
21-
multi_dot,
22-
norm,
23-
tensorinv,
24-
tensorsolve,
25-
)
10+
from numpy.linalg import * # noqa: F403
2611

2712
from .._internal import get_xp
2813
from ..common import _linalg
@@ -120,7 +105,7 @@ def solve(x1: Array, x2: Array, /) -> Array:
120105
vector_norm = get_xp(np)(_linalg.vector_norm)
121106

122107

123-
__all__ = _linalg.__all__ + [
108+
_all = [
124109
"LinAlgError",
125110
"cond",
126111
"det",
@@ -137,6 +122,7 @@ def solve(x1: Array, x2: Array, /) -> Array:
137122
"tensorsolve",
138123
"vector_norm",
139124
]
125+
__all__ = sorted(set(np.linalg.__all__) | set(_linalg.__all__) | set(_all))
140126

141127
def __dir__() -> list[str]:
142-
return __all__
128+
return sorted(set(dir(np.linalg)) | set(_linalg.__all__) | set(_all))

array_api_compat/torch/__init__.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
from typing import Final
2+
13
from torch import * # noqa: F403
24

35
# Several names are not included in the above import *
6+
_torch_all = set()
47
import torch
58
for n in dir(torch):
69
if (n.startswith('_')
@@ -10,13 +13,25 @@
1013
or 'backward' in n):
1114
continue
1215
exec(f"{n} = torch.{n}")
16+
_torch_all.add(n)
1317
del n
1418

1519
# These imports may overwrite names from the import * above.
20+
import _aliases
1621
from ._aliases import * # noqa: F403
22+
from ._info import __array_namespace_info__ # noqa: F401
1723

1824
# See the comment in the numpy __init__.py
1925
__import__(__package__ + '.linalg')
2026
__import__(__package__ + '.fft')
2127

22-
__array_api_version__ = '2024.12'
28+
__array_api_version__: Final = '2024.12'
29+
30+
__all__ = sorted(
31+
set(_torch_all)
32+
| set(_aliases.__all__)
33+
| {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"}
34+
)
35+
36+
def __dir__() -> list[str]:
37+
return __all__

array_api_compat/torch/_aliases.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from .._internal import get_xp
1010
from ..common import _aliases
1111
from ..common._typing import NestedSequence, SupportsBufferProtocol
12-
from ._info import __array_namespace_info__
1312
from ._typing import Array, Device, DType
1413

1514
_int_dtypes = {
@@ -824,7 +823,7 @@ def sign(x: Array, /) -> Array:
824823
return out
825824

826825

827-
__all__ = ['__array_namespace_info__', 'asarray', 'result_type', 'can_cast',
826+
__all__ = ['asarray', 'result_type', 'can_cast',
828827
'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add',
829828
'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
830829
'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero',

0 commit comments

Comments
 (0)