Skip to content

Commit d5673a1

Browse files
committed
ENH: CuPy creation functions to respect device= parameter
1 parent 205c967 commit d5673a1

File tree

3 files changed

+81
-93
lines changed

3 files changed

+81
-93
lines changed

Diff for: array_api_compat/common/_aliases.py

+23-23
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import inspect
88
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Sequence, cast
99

10-
from ._helpers import _check_device, array_namespace
10+
from ._helpers import _device_ctx, array_namespace
1111
from ._helpers import device as _get_device
1212
from ._helpers import is_cupy_namespace as _is_cupy_namespace
1313
from ._typing import Array, Device, DType, Namespace
@@ -32,8 +32,8 @@ def arange(
3232
device: Device | None = None,
3333
**kwargs: object,
3434
) -> Array:
35-
_check_device(xp, device)
36-
return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)
35+
with _device_ctx(xp, device):
36+
return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)
3737

3838

3939
def empty(
@@ -44,8 +44,8 @@ def empty(
4444
device: Device | None = None,
4545
**kwargs: object,
4646
) -> Array:
47-
_check_device(xp, device)
48-
return xp.empty(shape, dtype=dtype, **kwargs)
47+
with _device_ctx(xp, device):
48+
return xp.empty(shape, dtype=dtype, **kwargs)
4949

5050

5151
def empty_like(
@@ -57,8 +57,8 @@ def empty_like(
5757
device: Device | None = None,
5858
**kwargs: object,
5959
) -> Array:
60-
_check_device(xp, device)
61-
return xp.empty_like(x, dtype=dtype, **kwargs)
60+
with _device_ctx(xp, device, like=x):
61+
return xp.empty_like(x, dtype=dtype, **kwargs)
6262

6363

6464
def eye(
@@ -72,8 +72,8 @@ def eye(
7272
device: Device | None = None,
7373
**kwargs: object,
7474
) -> Array:
75-
_check_device(xp, device)
76-
return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)
75+
with _device_ctx(xp, device):
76+
return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)
7777

7878

7979
def full(
@@ -85,8 +85,8 @@ def full(
8585
device: Device | None = None,
8686
**kwargs: object,
8787
) -> Array:
88-
_check_device(xp, device)
89-
return xp.full(shape, fill_value, dtype=dtype, **kwargs)
88+
with _device_ctx(xp, device):
89+
return xp.full(shape, fill_value, dtype=dtype, **kwargs)
9090

9191

9292
def full_like(
@@ -99,8 +99,8 @@ def full_like(
9999
device: Device | None = None,
100100
**kwargs: object,
101101
) -> Array:
102-
_check_device(xp, device)
103-
return xp.full_like(x, fill_value, dtype=dtype, **kwargs)
102+
with _device_ctx(xp, device, like=x):
103+
return xp.full_like(x, fill_value, dtype=dtype, **kwargs)
104104

105105

106106
def linspace(
@@ -115,8 +115,8 @@ def linspace(
115115
endpoint: bool = True,
116116
**kwargs: object,
117117
) -> Array:
118-
_check_device(xp, device)
119-
return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)
118+
with _device_ctx(xp, device):
119+
return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)
120120

121121

122122
def ones(
@@ -127,8 +127,8 @@ def ones(
127127
device: Device | None = None,
128128
**kwargs: object,
129129
) -> Array:
130-
_check_device(xp, device)
131-
return xp.ones(shape, dtype=dtype, **kwargs)
130+
with _device_ctx(xp, device):
131+
return xp.ones(shape, dtype=dtype, **kwargs)
132132

133133

134134
def ones_like(
@@ -140,8 +140,8 @@ def ones_like(
140140
device: Device | None = None,
141141
**kwargs: object,
142142
) -> Array:
143-
_check_device(xp, device)
144-
return xp.ones_like(x, dtype=dtype, **kwargs)
143+
with _device_ctx(xp, device, like=x):
144+
return xp.ones_like(x, dtype=dtype, **kwargs)
145145

146146

147147
def zeros(
@@ -152,8 +152,8 @@ def zeros(
152152
device: Device | None = None,
153153
**kwargs: object,
154154
) -> Array:
155-
_check_device(xp, device)
156-
return xp.zeros(shape, dtype=dtype, **kwargs)
155+
with _device_ctx(xp, device):
156+
return xp.zeros(shape, dtype=dtype, **kwargs)
157157

158158

159159
def zeros_like(
@@ -165,8 +165,8 @@ def zeros_like(
165165
device: Device | None = None,
166166
**kwargs: object,
167167
) -> Array:
168-
_check_device(xp, device)
169-
return xp.zeros_like(x, dtype=dtype, **kwargs)
168+
with _device_ctx(xp, device, like=x):
169+
return xp.zeros_like(x, dtype=dtype, **kwargs)
170170

171171

172172
# np.unique() is split into four functions in the array API:

Diff for: array_api_compat/common/_helpers.py

+48-47
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88

99
from __future__ import annotations
1010

11+
import contextlib
1112
import inspect
1213
import math
1314
import sys
1415
import warnings
15-
from collections.abc import Collection
16+
from collections.abc import Collection, Generator
1617
from typing import (
1718
TYPE_CHECKING,
1819
Any,
@@ -663,26 +664,42 @@ def your_function(x, y):
663664
get_namespace = array_namespace
664665

665666

666-
def _check_device(bare_xp: Namespace, device: Device) -> None: # pyright: ignore[reportUnusedFunction]
667-
"""
668-
Validate dummy device on device-less array backends.
669-
670-
Notes
671-
-----
672-
This function is also invoked by CuPy, which does have multiple devices
673-
if there are multiple GPUs available.
674-
However, CuPy multi-device support is currently impossible
675-
without using the global device or a context manager:
667+
def _device_ctx(
668+
bare_xp: Namespace, device: Device, like: Array | None = None
669+
) -> Generator[None]:
670+
"""Context manager which changes the current device in CuPy.
676671
677-
https://github.com/data-apis/array-api-compat/pull/293
672+
Used internally by array creation functions in common._aliases.
678673
"""
679-
if bare_xp is sys.modules.get("numpy"):
680-
if device not in ("cpu", None):
674+
if device is None:
675+
if like is None:
676+
return contextlib.nullcontext()
677+
device = _device(like)
678+
679+
if bare_xp is sys.modules.get('numpy'):
680+
if device != "cpu":
681681
raise ValueError(f"Unsupported device for NumPy: {device!r}")
682+
return contextlib.nullcontext()
682683

683-
elif bare_xp is sys.modules.get("dask.array"):
684-
if device not in ("cpu", _DASK_DEVICE, None):
684+
if bare_xp is sys.modules.get('dask.array'):
685+
if device not in ("cpu", _DASK_DEVICE):
685686
raise ValueError(f"Unsupported device for Dask: {device!r}")
687+
return contextlib.nullcontext()
688+
689+
if bare_xp is sys.modules.get('cupy'):
690+
if not isinstance(device, bare_xp.cuda.Device):
691+
raise TypeError(f"device is not a cupy.cuda.Device: {device!r}")
692+
return device
693+
694+
# PyTorch doesn't have a "current device" context manager and you
695+
# can't use array creation functions from common._aliases.
696+
raise AssertionError("unreachable") # pragma: nocover
697+
698+
699+
def _check_device(bare_xp: Namespace, device: Device) -> None:
700+
"""Validate dummy device on device-less array backends."""
701+
with _device_ctx(bare_xp, device):
702+
pass
686703

687704

688705
# Placeholder object to represent the dask device
@@ -781,42 +798,26 @@ def _cupy_to_device(
781798
/,
782799
stream: int | Any | None = None,
783800
) -> _CupyArray:
784-
import cupy as cp # pyright: ignore[reportMissingTypeStubs]
785-
from cupy.cuda import Device as _Device # pyright: ignore
786-
from cupy.cuda import stream as stream_module # pyright: ignore
787-
from cupy_backends.cuda.api import runtime # pyright: ignore
801+
import cupy as cp
788802

789-
if device == x.device:
790-
return x
791-
elif device == "cpu":
803+
if device == "cpu":
792804
# allowing us to use `to_device(x, "cpu")`
793805
# is useful for portable test swapping between
794806
# host and device backends
795807
return x.get()
796-
elif not isinstance(device, _Device):
797-
raise ValueError(f"Unsupported device {device!r}")
798-
else:
799-
# see cupy/cupy#5985 for the reason how we handle device/stream here
800-
prev_device: Any = runtime.getDevice() # pyright: ignore[reportUnknownMemberType]
801-
prev_stream = None
802-
if stream is not None:
803-
prev_stream: Any = stream_module.get_current_stream() # pyright: ignore
804-
# stream can be an int as specified in __dlpack__, or a CuPy stream
805-
if isinstance(stream, int):
806-
stream = cp.cuda.ExternalStream(stream) # pyright: ignore
807-
elif isinstance(stream, cp.cuda.Stream): # pyright: ignore[reportUnknownMemberType]
808-
pass
809-
else:
810-
raise ValueError("the input stream is not recognized")
811-
stream.use() # pyright: ignore[reportUnknownMemberType]
812-
try:
813-
runtime.setDevice(device.id) # pyright: ignore[reportUnknownMemberType]
814-
arr = x.copy()
815-
finally:
816-
runtime.setDevice(prev_device) # pyright: ignore[reportUnknownMemberType]
817-
if stream is not None:
818-
prev_stream.use()
819-
return arr
808+
if not isinstance(device, cp.cuda.Device):
809+
raise TypeError(f"Unsupported device {device!r}")
810+
811+
# stream can be an int as specified in __dlpack__, or a CuPy stream
812+
if isinstance(stream, int):
813+
stream = cp.cuda.ExternalStream(stream)
814+
elif stream is None:
815+
stream = contextlib.nullcontext()
816+
elif not isinstance(stream, cp.cuda.Stream):
817+
raise TypeError('the input stream is not recognized')
818+
819+
with device, stream:
820+
return cp.asarray(x)
820821

821822

822823
def _torch_to_device(

Diff for: array_api_compat/cupy/_aliases.py

+10-23
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,6 @@
6464
finfo = get_xp(cp)(_aliases.finfo)
6565
iinfo = get_xp(cp)(_aliases.iinfo)
6666

67-
_copy_default = object()
68-
6967

7068
# asarray also adds the copy keyword, which is not present in numpy 1.0.
7169
def asarray(
@@ -79,7 +77,7 @@ def asarray(
7977
*,
8078
dtype: Optional[DType] = None,
8179
device: Optional[Device] = None,
82-
copy: Optional[bool] = _copy_default,
80+
copy: Optional[bool] = None,
8381
**kwargs,
8482
) -> Array:
8583
"""
@@ -88,26 +86,15 @@ def asarray(
8886
See the corresponding documentation in the array library and/or the array API
8987
specification for more details.
9088
"""
91-
with cp.cuda.Device(device):
92-
# cupy is like NumPy 1.26 (except without _CopyMode). See the comments
93-
# in asarray in numpy/_aliases.py.
94-
if copy is not _copy_default:
95-
# A future version of CuPy will change the meaning of copy=False
96-
# to mean no-copy. We don't know for certain what version it will
97-
# be yet, so to avoid breaking that version, we use a different
98-
# default value for copy so asarray(obj) with no copy kwarg will
99-
# always do the copy-if-needed behavior.
100-
101-
# This will still need to be updated to remove the
102-
# NotImplementedError for copy=False, but at least this won't
103-
# break the default or existing behavior.
104-
if copy is None:
105-
copy = False
106-
elif copy is False:
107-
raise NotImplementedError("asarray(copy=False) is not yet supported in cupy")
108-
kwargs['copy'] = copy
109-
110-
return cp.array(obj, dtype=dtype, **kwargs)
89+
if copy is False:
90+
raise NotImplementedError("asarray(copy=False) is not yet supported in cupy")
91+
92+
like = obj if isinstance(obj, cp.ndarray) else None
93+
with _helpers._device_ctx(cp, device, like=like):
94+
if copy is None:
95+
return cp.asarray(obj, dtype=dtype, **kwargs)
96+
else:
97+
return cp.array(obj, dtype=dtype, copy=True, **kwargs)
11198

11299

113100
def astype(

0 commit comments

Comments
 (0)