Commit 5de9f26

Merge branch 'main' into cupy_device
2 parents: 96d8f5e + 3e5fdc0

20 files changed: +127 −82 lines

Diff for: array_api_compat/common/_aliases.py (+19 −2)

@@ -5,7 +5,7 @@
 from __future__ import annotations

 import inspect
-from typing import NamedTuple, Optional, Sequence, Tuple, Union
+from typing import Any, NamedTuple, Optional, Sequence, Tuple, Union

 from ._typing import Array, Device, DType, Namespace
 from ._helpers import (
@@ -609,13 +609,30 @@ def sign(x: Array, /, xp: Namespace, **kwargs) -> Array:
     out[xp.isnan(x)] = xp.nan
     return out[()]

+
+def finfo(type_: DType | Array, /, xp: Namespace) -> Any:
+    # It is surprisingly difficult to recognize a dtype apart from an array.
+    # np.int64 is not the same as np.asarray(1).dtype!
+    try:
+        return xp.finfo(type_)
+    except (ValueError, TypeError):
+        return xp.finfo(type_.dtype)
+
+
+def iinfo(type_: DType | Array, /, xp: Namespace) -> Any:
+    try:
+        return xp.iinfo(type_)
+    except (ValueError, TypeError):
+        return xp.iinfo(type_.dtype)
+
+
 __all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
            'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
            'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
            'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
            'std', 'var', 'cumulative_sum', 'cumulative_prod', 'clip', 'permute_dims',
            'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
            'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
-           'unstack', 'sign']
+           'unstack', 'sign', 'finfo', 'iinfo']

 _all_ignore = ['inspect', 'array_namespace', 'NamedTuple']
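The try/except dispatch above is needed because backends disagree about what finfo/iinfo accept. A quick illustration with plain NumPy and PyTorch (assuming both are installed; nothing below comes from this repo):

    import numpy as np
    import torch

    # np.int64 is a scalar *type*; an array's .dtype is a np.dtype *instance*,
    # so dtypes and arrays cannot reliably be told apart up front:
    np.int64 is np.asarray(1).dtype        # False

    # torch.finfo accepts only dtypes, not tensors, hence the .dtype fallback:
    torch.finfo(torch.float32).eps         # works
    try:
        torch.finfo(torch.ones(3))         # TypeError: not a torch.dtype
    except TypeError:
        torch.finfo(torch.ones(3).dtype)   # the wrapper's fallback path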

Diff for: array_api_compat/common/_helpers.py (+38 −38)

@@ -598,6 +598,44 @@ def your_function(x, y):
 get_namespace = array_namespace


+def _device_ctx(
+    bare_xp: Namespace, device: Device, like: Array | None = None
+) -> Generator[None]:
+    """Context manager which changes the current device in CuPy.
+
+    Used internally by array creation functions in common._aliases.
+    """
+    if device is None:
+        if like is None:
+            return contextlib.nullcontext()
+        device = _device(like)
+
+    if bare_xp is sys.modules.get('numpy'):
+        if device != "cpu":
+            raise ValueError(f"Unsupported device for NumPy: {device!r}")
+        return contextlib.nullcontext()
+
+    if bare_xp is sys.modules.get('dask.array'):
+        if device not in ("cpu", _DASK_DEVICE):
+            raise ValueError(f"Unsupported device for Dask: {device!r}")
+        return contextlib.nullcontext()
+
+    if bare_xp is sys.modules.get('cupy'):
+        if not isinstance(device, bare_xp.cuda.Device):
+            raise TypeError(f"device is not a cupy.cuda.Device: {device!r}")
+        return device
+
+    # PyTorch doesn't have a "current device" context manager and you
+    # can't use array creation functions from common._aliases.
+    raise AssertionError("unreachable")  # pragma: nocover
+
+
+def _check_device(bare_xp: Namespace, device: Device) -> None:
+    """Validate dummy device on device-less array backends."""
+    with _device_ctx(bare_xp, device):
+        pass
+
+
 # Placeholder object to represent the dask device
 # when the array backend is not the CPU.
 # (since it is not easy to tell which device a dask array is on)
@@ -607,7 +645,6 @@ def __repr__(self):

 _DASK_DEVICE = _dask_device()

-
 # device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray
 # or cupy.ndarray. They are not included in array objects of this library
 # because this library just reuses the respective ndarray classes without
@@ -799,43 +836,6 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
     return x.to_device(device, stream=stream)


-def _device_ctx(
-    bare_xp: Namespace, device: Device, like: Array | None = None
-) -> Generator[None]:
-    """Context manager which changes the current device in CuPy.
-
-    Used internally by array creation functions in common._aliases.
-    """
-    if device is None:
-        if like is None:
-            return contextlib.nullcontext()
-        device = _device(like)
-
-    if bare_xp is sys.modules.get('numpy'):
-        if device != "cpu":
-            raise ValueError(f"Unsupported device for NumPy: {device!r}")
-        return contextlib.nullcontext()
-
-    if bare_xp is sys.modules.get('dask.array'):
-        if device not in ("cpu", _DASK_DEVICE):
-            raise ValueError(f"Unsupported device for Dask: {device!r}")
-        return contextlib.nullcontext()
-
-    if bare_xp is sys.modules.get('cupy'):
-        if not isinstance(device, bare_xp.cuda.Device):
-            raise TypeError(f"device is not a cupy.cuda.Device: {device!r}")
-        return device
-
-    # PyTorch doesn't have a "current device" context manager and you
-    # can't use array creation functions from common._aliases.
-    raise AssertionError("unreachable")  # pragma: nocover
-
-
-def _check_device(bare_xp: Namespace, device: Device) -> None:
-    with _device_ctx(bare_xp, device):
-        pass
-
-
 def size(x: Array) -> int | None:
     """
     Return the total number of elements of x.
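For orientation, a minimal sketch of how a creation alias in common._aliases can use this helper; the `empty` wrapper below is illustrative, not the exact code of that module:

    import contextlib

    def empty(shape, xp, *, dtype=None, device=None, **kwargs):
        # Both return values of _device_ctx work in a with-statement:
        # contextlib.nullcontext() trivially, and cupy.cuda.Device because
        # it is itself a context manager that switches the current device.
        with _device_ctx(xp, device):
            return xp.empty(shape, dtype=dtype, **kwargs)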

Diff for: array_api_compat/common/_linalg.py (+2)

@@ -174,3 +174,5 @@ def trace(
            'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm',
            'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal',
            'trace']
+
+_all_ignore = ['math', 'normalize_axis_tuple', 'get_xp', 'np', 'isdtype']

Diff for: array_api_compat/cupy/__init__.py (−3)

@@ -8,9 +8,6 @@

 # See the comment in the numpy __init__.py
 __import__(__package__ + '.linalg')
-
 __import__(__package__ + '.fft')

-from ..common._helpers import * # noqa: F401,F403
-
 __array_api_version__ = '2024.12'

Diff for: array_api_compat/cupy/_aliases.py (+3 −1)

@@ -61,6 +61,8 @@
 matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
 tensordot = get_xp(cp)(_aliases.tensordot)
 sign = get_xp(cp)(_aliases.sign)
+finfo = get_xp(cp)(_aliases.finfo)
+iinfo = get_xp(cp)(_aliases.iinfo)


 # asarray also adds the copy keyword, which is not present in numpy 1.0.
@@ -87,7 +89,7 @@ def asarray(
     if copy is False:
         raise NotImplementedError("asarray(copy=False) is not yet supported in cupy")

-    like = obj if _helpers.is_cupy_array(obj) else None
+    like = obj if isinstance(obj, cp.ndarray) else None
     with _helpers._device_ctx(cp, device, like=like):
         if copy is None:
             return cp.asarray(obj, dtype=dtype, **kwargs)
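Passing `like=obj` lets `_device_ctx` pin the output to the input's device when `device=None`. The CuPy mechanism it leans on is the stock Device context manager; a sketch assuming a machine with CuPy and at least one CUDA device:

    import cupy as cp

    with cp.cuda.Device(0):           # cupy.cuda.Device is a context manager
        a = cp.asarray([1.0, 2.0])    # allocated on device 0
    assert a.device == cp.cuda.Device(0)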

Diff for: array_api_compat/dask/array/__init__.py (+1)

@@ -5,5 +5,6 @@

 __array_api_version__ = '2024.12'

+# See the comment in the numpy __init__.py
 __import__(__package__ + '.linalg')
 __import__(__package__ + '.fft')

Diff for: array_api_compat/dask/array/_aliases.py (+4 −5)

@@ -5,8 +5,6 @@
 import numpy as np
 from numpy import (
     # dtypes
-    iinfo,
-    finfo,
     bool_ as bool,
     float32,
     float64,
@@ -133,6 +131,8 @@ def arange(
 matmul = get_xp(np)(_aliases.matmul)
 tensordot = get_xp(np)(_aliases.tensordot)
 sign = get_xp(np)(_aliases.sign)
+finfo = get_xp(np)(_aliases.finfo)
+iinfo = get_xp(np)(_aliases.iinfo)


 # asarray also adds the copy keyword, which is not present in numpy 1.0.
@@ -346,10 +346,9 @@ def count_nonzero(
     '__array_namespace_info__', 'asarray', 'astype', 'acos',
     'acosh', 'asin', 'asinh', 'atan', 'atan2',
     'atanh', 'bitwise_left_shift', 'bitwise_invert',
-    'bitwise_right_shift', 'concat', 'pow', 'iinfo', 'finfo', 'can_cast',
+    'bitwise_right_shift', 'concat', 'pow', 'can_cast',
     'result_type', 'bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64',
-    'uint8', 'uint16', 'uint32', 'uint64',
-    'complex64', 'complex128', 'iinfo', 'finfo',
+    'uint8', 'uint16', 'uint32', 'uint64', 'complex64', 'complex128',
     'can_cast', 'count_nonzero', 'result_type']

 _all_ignore = ["array_namespace", "get_xp", "da", "np"]

Diff for: array_api_compat/numpy/__init__.py (−9)

@@ -14,17 +14,8 @@
 # It doesn't overwrite np.linalg from above. The import is generated
 # dynamically so that the library can be vendored.
 __import__(__package__ + '.linalg')
-
 __import__(__package__ + '.fft')

 from .linalg import matrix_transpose, vecdot # noqa: F401

-from ..common._helpers import * # noqa: F403
-
-try:
-    # Used in asarray(). Not present in older versions.
-    from numpy import _CopyMode # noqa: F401
-except ImportError:
-    pass
-
 __array_api_version__ = '2024.12'

Diff for: array_api_compat/numpy/_aliases.py (+3 −1)

@@ -61,6 +61,8 @@
 matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
 tensordot = get_xp(np)(_aliases.tensordot)
 sign = get_xp(np)(_aliases.sign)
+finfo = get_xp(np)(_aliases.finfo)
+iinfo = get_xp(np)(_aliases.iinfo)


 def _supports_buffer_protocol(obj):
@@ -86,7 +88,7 @@ def asarray(
     *,
     dtype: Optional[DType] = None,
     device: Optional[Device] = None,
-    copy: "Optional[Union[bool, np._CopyMode]]" = None,
+    copy: Optional[Union[bool, np._CopyMode]] = None,
     **kwargs,
 ) -> Array:
     """

Diff for: array_api_compat/torch/__init__.py (+2 −4)

@@ -9,16 +9,14 @@
         or 'cpu' in n
         or 'backward' in n):
         continue
-    exec(n + ' = torch.' + n)
+    exec(f"{n} = torch.{n}")
+del n

 # These imports may overwrite names from the import * above.
 from ._aliases import * # noqa: F403

 # See the comment in the numpy __init__.py
 __import__(__package__ + '.linalg')
-
 __import__(__package__ + '.fft')

-from ..common._helpers import * # noqa: F403
-
 __array_api_version__ = '2024.12'

Diff for: array_api_compat/torch/_aliases.py (+29 −7)

@@ -2,12 +2,13 @@

 from functools import reduce as _reduce, wraps as _wraps
 from builtins import all as _builtin_all, any as _builtin_any
-from typing import List, Optional, Sequence, Tuple, Union
+from typing import Any, List, Optional, Sequence, Tuple, Union

 import torch

 from .._internal import get_xp
 from ..common import _aliases
+from ..common._typing import NestedSequence, SupportsBufferProtocol
 from ._info import __array_namespace_info__
 from ._typing import Array, Device, DType

@@ -207,6 +208,28 @@ def can_cast(from_: Union[DType, Array], to: DType, /) -> bool:
 remainder = _two_arg(torch.remainder)
 subtract = _two_arg(torch.subtract)

+
+def asarray(
+    obj: (
+        Array
+        | bool | int | float | complex
+        | NestedSequence[bool | int | float | complex]
+        | SupportsBufferProtocol
+    ),
+    /,
+    *,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    copy: bool | None = None,
+    **kwargs: Any,
+) -> Array:
+    # torch.asarray does not respect input->output device propagation
+    # https://github.com/pytorch/pytorch/issues/150199
+    if device is None and isinstance(obj, torch.Tensor):
+        device = obj.device
+    return torch.asarray(obj, dtype=dtype, device=device, copy=copy, **kwargs)
+
+
 # These wrappers are mostly based on the fact that pytorch uses 'dim' instead
 # of 'axis'.
@@ -227,6 +250,9 @@ def min(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep
 unstack = get_xp(torch)(_aliases.unstack)
 cumulative_sum = get_xp(torch)(_aliases.cumulative_sum)
 cumulative_prod = get_xp(torch)(_aliases.cumulative_prod)
+finfo = get_xp(torch)(_aliases.finfo)
+iinfo = get_xp(torch)(_aliases.iinfo)
+

 # torch.sort also returns a tuple
 # https://github.com/pytorch/pytorch/issues/70921
@@ -282,7 +308,6 @@ def prod(x: Array,
          dtype: Optional[DType] = None,
          keepdims: bool = False,
          **kwargs) -> Array:
-    x = torch.asarray(x)
     ndim = x.ndim

     # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
@@ -318,7 +343,6 @@ def sum(x: Array,
         dtype: Optional[DType] = None,
         keepdims: bool = False,
         **kwargs) -> Array:
-    x = torch.asarray(x)
     ndim = x.ndim

     # https://github.com/pytorch/pytorch/issues/29137.
@@ -348,7 +372,6 @@ def any(x: Array,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
         keepdims: bool = False,
         **kwargs) -> Array:
-    x = torch.asarray(x)
     ndim = x.ndim
     if axis == ():
         return x.to(torch.bool)
@@ -373,7 +396,6 @@ def all(x: Array,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
         keepdims: bool = False,
         **kwargs) -> Array:
-    x = torch.asarray(x)
     ndim = x.ndim
     if axis == ():
         return x.to(torch.bool)
@@ -816,7 +838,7 @@ def sign(x: Array, /) -> Array:
     return out


-__all__ = ['__array_namespace_info__', 'result_type', 'can_cast',
+__all__ = ['__array_namespace_info__', 'asarray', 'result_type', 'can_cast',
            'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add',
            'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
            'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero',
@@ -832,6 +854,6 @@ def sign(x: Array, /) -> Array:
            'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
            'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
            'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype',
-           'take', 'take_along_axis', 'sign']
+           'take', 'take_along_axis', 'sign', 'finfo', 'iinfo']

 _all_ignore = ['torch', 'get_xp']
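A usage sketch of the new torch asarray wrapper; it assumes a CUDA-enabled PyTorch build (substitute "cpu" otherwise):

    import torch
    import array_api_compat.torch as xp

    t = torch.ones(3, device="cuda")
    # Bare torch.asarray may not preserve the input's device in all cases
    # (pytorch#150199); the wrapper pins device=obj.device when device is None.
    y = xp.asarray(t)
    assert y.device == t.device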

Diff for: array_api_compat/torch/_typing.py (+2 −3)

@@ -1,4 +1,3 @@
-__all__ = ["Array", "DType", "Device"]
+__all__ = ["Array", "Device", "DType"]

-from torch import dtype as DType, Tensor as Array
-from ..common._typing import Device
+from torch import device as Device, dtype as DType, Tensor as Array

Diff for: cupy-xfails.txt (+3 −2)

@@ -14,9 +14,10 @@ array_api_tests/test_array_object.py::test_getitem
 # copy=False is not yet implemented
 array_api_tests/test_creation_functions.py::test_asarray_arrays

-# finfo test is testing that the result is a float instead of float32 (see
-# also https://github.com/data-apis/array-api/issues/405)
+# attributes are np.float32 instead of float
+# (see also https://github.com/data-apis/array-api/issues/405)
 array_api_tests/test_data_type_functions.py::test_finfo[float32]
+array_api_tests/test_data_type_functions.py::test_finfo[complex64]

 # Some array attributes are missing, and we do not wrap the array object
 array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
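The updated xfail comment refers to the types of the finfo attributes, which is easy to reproduce with plain NumPy:

    import numpy as np

    eps = np.finfo(np.float32).eps
    type(eps)   # <class 'numpy.float32'>, but the standard expects a Python float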

Diff for: dask-xfails.txt (+3 −1)

@@ -12,8 +12,10 @@ array_api_tests/test_array_object.py::test_getitem_masking
 # zero division error, and typeerror: tuple indices must be integers or slices not tuple
 array_api_tests/test_creation_functions.py::test_eye

-# finfo(float32).eps returns float32 but should return float
+# attributes are np.float32 instead of float
+# (see also https://github.com/data-apis/array-api/issues/405)
 array_api_tests/test_data_type_functions.py::test_finfo[float32]
+array_api_tests/test_data_type_functions.py::test_finfo[complex64]

 # out[-1]=dask.array<getitem ...> but should be some floating number
 # (I think the test is not forcing the op to be computed?)
