Skip to content

Commit cf24ea8

Browse files
committed
ENH: CuPy creation functions to respect device= parameter
1 parent 16978e6 commit cf24ea8

File tree

3 files changed

+84
-90
lines changed

3 files changed

+84
-90
lines changed

Diff for: array_api_compat/common/_aliases.py

+23-23
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from ._typing import Array, Device, DType, Namespace
1111
from ._helpers import (
1212
array_namespace,
13-
_check_device,
13+
_device_ctx,
1414
device as _get_device,
1515
is_cupy_namespace as _is_cupy_namespace
1616
)
@@ -31,8 +31,8 @@ def arange(
3131
device: Optional[Device] = None,
3232
**kwargs,
3333
) -> Array:
34-
_check_device(xp, device)
35-
return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)
34+
with _device_ctx(xp, device):
35+
return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)
3636

3737
def empty(
3838
shape: Union[int, Tuple[int, ...]],
@@ -42,8 +42,8 @@ def empty(
4242
device: Optional[Device] = None,
4343
**kwargs,
4444
) -> Array:
45-
_check_device(xp, device)
46-
return xp.empty(shape, dtype=dtype, **kwargs)
45+
with _device_ctx(xp, device):
46+
return xp.empty(shape, dtype=dtype, **kwargs)
4747

4848
def empty_like(
4949
x: Array,
@@ -54,8 +54,8 @@ def empty_like(
5454
device: Optional[Device] = None,
5555
**kwargs,
5656
) -> Array:
57-
_check_device(xp, device)
58-
return xp.empty_like(x, dtype=dtype, **kwargs)
57+
with _device_ctx(xp, device, like=x):
58+
return xp.empty_like(x, dtype=dtype, **kwargs)
5959

6060
def eye(
6161
n_rows: int,
@@ -68,8 +68,8 @@ def eye(
6868
device: Optional[Device] = None,
6969
**kwargs,
7070
) -> Array:
71-
_check_device(xp, device)
72-
return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)
71+
with _device_ctx(xp, device):
72+
return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)
7373

7474
def full(
7575
shape: Union[int, Tuple[int, ...]],
@@ -80,8 +80,8 @@ def full(
8080
device: Optional[Device] = None,
8181
**kwargs,
8282
) -> Array:
83-
_check_device(xp, device)
84-
return xp.full(shape, fill_value, dtype=dtype, **kwargs)
83+
with _device_ctx(xp, device):
84+
return xp.full(shape, fill_value, dtype=dtype, **kwargs)
8585

8686
def full_like(
8787
x: Array,
@@ -93,8 +93,8 @@ def full_like(
9393
device: Optional[Device] = None,
9494
**kwargs,
9595
) -> Array:
96-
_check_device(xp, device)
97-
return xp.full_like(x, fill_value, dtype=dtype, **kwargs)
96+
with _device_ctx(xp, device, like=x):
97+
return xp.full_like(x, fill_value, dtype=dtype, **kwargs)
9898

9999
def linspace(
100100
start: Union[int, float],
@@ -108,8 +108,8 @@ def linspace(
108108
endpoint: bool = True,
109109
**kwargs,
110110
) -> Array:
111-
_check_device(xp, device)
112-
return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)
111+
with _device_ctx(xp, device):
112+
return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)
113113

114114
def ones(
115115
shape: Union[int, Tuple[int, ...]],
@@ -119,8 +119,8 @@ def ones(
119119
device: Optional[Device] = None,
120120
**kwargs,
121121
) -> Array:
122-
_check_device(xp, device)
123-
return xp.ones(shape, dtype=dtype, **kwargs)
122+
with _device_ctx(xp, device):
123+
return xp.ones(shape, dtype=dtype, **kwargs)
124124

125125
def ones_like(
126126
x: Array,
@@ -131,8 +131,8 @@ def ones_like(
131131
device: Optional[Device] = None,
132132
**kwargs,
133133
) -> Array:
134-
_check_device(xp, device)
135-
return xp.ones_like(x, dtype=dtype, **kwargs)
134+
with _device_ctx(xp, device, like=x):
135+
return xp.ones_like(x, dtype=dtype, **kwargs)
136136

137137
def zeros(
138138
shape: Union[int, Tuple[int, ...]],
@@ -142,8 +142,8 @@ def zeros(
142142
device: Optional[Device] = None,
143143
**kwargs,
144144
) -> Array:
145-
_check_device(xp, device)
146-
return xp.zeros(shape, dtype=dtype, **kwargs)
145+
with _device_ctx(xp, device):
146+
return xp.zeros(shape, dtype=dtype, **kwargs)
147147

148148
def zeros_like(
149149
x: Array,
@@ -154,8 +154,8 @@ def zeros_like(
154154
device: Optional[Device] = None,
155155
**kwargs,
156156
) -> Array:
157-
_check_device(xp, device)
158-
return xp.zeros_like(x, dtype=dtype, **kwargs)
157+
with _device_ctx(xp, device, like=x):
158+
return xp.zeros_like(x, dtype=dtype, **kwargs)
159159

160160
# np.unique() is split into four functions in the array API:
161161
# unique_all, unique_counts, unique_inverse, and unique_values (this is done

Diff for: array_api_compat/common/_helpers.py

+51-44
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@
77
"""
88
from __future__ import annotations
99

10+
import contextlib
1011
import sys
1112
import math
1213
import inspect
1314
import warnings
15+
from collections.abc import Generator
1416
from typing import Optional, Union, Any
1517

1618
from ._typing import Array, Device, Namespace
@@ -596,26 +598,42 @@ def your_function(x, y):
596598
get_namespace = array_namespace
597599

598600

599-
def _check_device(bare_xp, device):
600-
"""
601-
Validate dummy device on device-less array backends.
601+
def _device_ctx(
602+
bare_xp: Namespace, device: Device, like: Array | None = None
603+
) -> Generator[None]:
604+
"""Context manager which changes the current device in CuPy.
602605
603-
Notes
604-
-----
605-
This function is also invoked by CuPy, which does have multiple devices
606-
if there are multiple GPUs available.
607-
However, CuPy multi-device support is currently impossible
608-
without using the global device or a context manager:
609-
610-
https://github.com/data-apis/array-api-compat/pull/293
606+
Used internally by array creation functions in common._aliases.
611607
"""
608+
if device is None:
609+
if like is None:
610+
return contextlib.nullcontext()
611+
device = _device(like)
612+
612613
if bare_xp is sys.modules.get('numpy'):
613-
if device not in ("cpu", None):
614+
if device != "cpu":
614615
raise ValueError(f"Unsupported device for NumPy: {device!r}")
616+
return contextlib.nullcontext()
615617

616-
elif bare_xp is sys.modules.get('dask.array'):
617-
if device not in ("cpu", _DASK_DEVICE, None):
618+
if bare_xp is sys.modules.get('dask.array'):
619+
if device not in ("cpu", _DASK_DEVICE):
618620
raise ValueError(f"Unsupported device for Dask: {device!r}")
621+
return contextlib.nullcontext()
622+
623+
if bare_xp is sys.modules.get('cupy'):
624+
if not isinstance(device, bare_xp.cuda.Device):
625+
raise TypeError(f"device is not a cupy.cuda.Device: {device!r}")
626+
return device
627+
628+
# PyTorch doesn't have a "current device" context manager and you
629+
# can't use array creation functions from common._aliases.
630+
raise AssertionError("unreachable") # pragma: nocover
631+
632+
633+
def _check_device(bare_xp: Namespace, device: Device) -> None:
634+
"""Validate dummy device on device-less array backends."""
635+
with _device_ctx(bare_xp, device):
636+
pass
619637

620638

621639
# Placeholder object to represent the dask device
@@ -703,50 +721,39 @@ def device(x: Array, /) -> Device:
703721
# Prevent shadowing, used below
704722
_device = device
705723

724+
706725
# Based on cupy.array_api.Array.to_device
707726
def _cupy_to_device(x, device, /, stream=None):
708727
import cupy as cp
709-
from cupy.cuda import Device as _Device
710-
from cupy.cuda import stream as stream_module
711-
from cupy_backends.cuda.api import runtime
712728

713-
if device == x.device:
714-
return x
715-
elif device == "cpu":
729+
if device == "cpu":
716730
# allowing us to use `to_device(x, "cpu")`
717731
# is useful for portable test swapping between
718732
# host and device backends
719733
return x.get()
720-
elif not isinstance(device, _Device):
721-
raise ValueError(f"Unsupported device {device!r}")
722-
else:
723-
# see cupy/cupy#5985 for the reason how we handle device/stream here
724-
prev_device = runtime.getDevice()
725-
prev_stream: stream_module.Stream = None
726-
if stream is not None:
727-
prev_stream = stream_module.get_current_stream()
728-
# stream can be an int as specified in __dlpack__, or a CuPy stream
729-
if isinstance(stream, int):
730-
stream = cp.cuda.ExternalStream(stream)
731-
elif isinstance(stream, cp.cuda.Stream):
732-
pass
733-
else:
734-
raise ValueError('the input stream is not recognized')
735-
stream.use()
736-
try:
737-
runtime.setDevice(device.id)
738-
arr = x.copy()
739-
finally:
740-
runtime.setDevice(prev_device)
741-
if stream is not None:
742-
prev_stream.use()
743-
return arr
734+
if not isinstance(device, cp.cuda.Device):
735+
raise TypeError(f"Unsupported device {device!r}")
736+
737+
# see cupy/cupy#5985 for the reason how we handle device/stream here
738+
739+
# stream can be an int as specified in __dlpack__, or a CuPy stream
740+
if isinstance(stream, int):
741+
stream = cp.cuda.ExternalStream(stream)
742+
elif stream is None:
743+
stream = contextlib.nullcontext()
744+
elif not isinstance(stream, cp.cuda.Stream):
745+
raise TypeError('the input stream is not recognized')
746+
747+
with device, stream:
748+
return cp.asarray(x)
749+
744750

745751
def _torch_to_device(x, device, /, stream=None):
746752
if stream is not None:
747753
raise NotImplementedError
748754
return x.to(device)
749755

756+
750757
def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array:
751758
"""
752759
Copy the array from the device on which it currently resides to the specified ``device``.

Diff for: array_api_compat/cupy/_aliases.py

+10-23
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,6 @@
6464
finfo = get_xp(cp)(_aliases.finfo)
6565
iinfo = get_xp(cp)(_aliases.iinfo)
6666

67-
_copy_default = object()
68-
6967

7068
# asarray also adds the copy keyword, which is not present in numpy 1.0.
7169
def asarray(
@@ -79,7 +77,7 @@ def asarray(
7977
*,
8078
dtype: Optional[DType] = None,
8179
device: Optional[Device] = None,
82-
copy: Optional[bool] = _copy_default,
80+
copy: Optional[bool] = None,
8381
**kwargs,
8482
) -> Array:
8583
"""
@@ -88,26 +86,15 @@ def asarray(
8886
See the corresponding documentation in the array library and/or the array API
8987
specification for more details.
9088
"""
91-
with cp.cuda.Device(device):
92-
# cupy is like NumPy 1.26 (except without _CopyMode). See the comments
93-
# in asarray in numpy/_aliases.py.
94-
if copy is not _copy_default:
95-
# A future version of CuPy will change the meaning of copy=False
96-
# to mean no-copy. We don't know for certain what version it will
97-
# be yet, so to avoid breaking that version, we use a different
98-
# default value for copy so asarray(obj) with no copy kwarg will
99-
# always do the copy-if-needed behavior.
100-
101-
# This will still need to be updated to remove the
102-
# NotImplementedError for copy=False, but at least this won't
103-
# break the default or existing behavior.
104-
if copy is None:
105-
copy = False
106-
elif copy is False:
107-
raise NotImplementedError("asarray(copy=False) is not yet supported in cupy")
108-
kwargs['copy'] = copy
109-
110-
return cp.array(obj, dtype=dtype, **kwargs)
89+
if copy is False:
90+
raise NotImplementedError("asarray(copy=False) is not yet supported in cupy")
91+
92+
like = obj if isinstance(obj, cp.ndarray) else None
93+
with _helpers._device_ctx(cp, device, like=like):
94+
if copy is None:
95+
return cp.asarray(obj, dtype=dtype, **kwargs)
96+
else:
97+
return cp.array(obj, dtype=dtype, copy=True, **kwargs)
11198

11299

113100
def astype(

0 commit comments

Comments
 (0)