Skip to content

Commit 8c31248

Browse files
committed
ENH: Simplify CuPy asarray and to_device
1 parent 205c967 commit 8c31248

File tree

3 files changed

+25
-56
lines changed

3 files changed

+25
-56
lines changed

Diff for: array_api_compat/common/_helpers.py

+17-31
Original file line numberDiff line numberDiff line change
@@ -781,42 +781,28 @@ def _cupy_to_device(
781781
/,
782782
stream: int | Any | None = None,
783783
) -> _CupyArray:
784-
import cupy as cp # pyright: ignore[reportMissingTypeStubs]
785-
from cupy.cuda import Device as _Device # pyright: ignore
786-
from cupy.cuda import stream as stream_module # pyright: ignore
787-
from cupy_backends.cuda.api import runtime # pyright: ignore
784+
import cupy as cp
788785

789-
if device == x.device:
790-
return x
791-
elif device == "cpu":
786+
if device == "cpu":
792787
# allowing us to use `to_device(x, "cpu")`
793788
# is useful for portable test swapping between
794789
# host and device backends
795790
return x.get()
796-
elif not isinstance(device, _Device):
797-
raise ValueError(f"Unsupported device {device!r}")
798-
else:
799-
# see cupy/cupy#5985 for the reason how we handle device/stream here
800-
prev_device: Any = runtime.getDevice() # pyright: ignore[reportUnknownMemberType]
801-
prev_stream = None
802-
if stream is not None:
803-
prev_stream: Any = stream_module.get_current_stream() # pyright: ignore
804-
# stream can be an int as specified in __dlpack__, or a CuPy stream
805-
if isinstance(stream, int):
806-
stream = cp.cuda.ExternalStream(stream) # pyright: ignore
807-
elif isinstance(stream, cp.cuda.Stream): # pyright: ignore[reportUnknownMemberType]
808-
pass
809-
else:
810-
raise ValueError("the input stream is not recognized")
811-
stream.use() # pyright: ignore[reportUnknownMemberType]
812-
try:
813-
runtime.setDevice(device.id) # pyright: ignore[reportUnknownMemberType]
814-
arr = x.copy()
815-
finally:
816-
runtime.setDevice(prev_device) # pyright: ignore[reportUnknownMemberType]
817-
if stream is not None:
818-
prev_stream.use()
819-
return arr
791+
if not isinstance(device, cp.cuda.Device):
792+
raise TypeError(f"Unsupported device type {device!r}")
793+
794+
if stream is None:
795+
with device:
796+
return cp.asarray(x)
797+
798+
# stream can be an int as specified in __dlpack__, or a CuPy stream
799+
if isinstance(stream, int):
800+
stream = cp.cuda.ExternalStream(stream)
801+
elif not isinstance(stream, cp.cuda.Stream):
802+
raise TypeError(f"Unsupported stream type {stream!r}")
803+
804+
with device, stream:
805+
return cp.asarray(x)
820806

821807

822808
def _torch_to_device(

Diff for: array_api_compat/cupy/_aliases.py

+8-22
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,6 @@
6464
finfo = get_xp(cp)(_aliases.finfo)
6565
iinfo = get_xp(cp)(_aliases.iinfo)
6666

67-
_copy_default = object()
68-
6967

7068
# asarray also adds the copy keyword, which is not present in numpy 1.0.
7169
def asarray(
@@ -79,7 +77,7 @@ def asarray(
7977
*,
8078
dtype: Optional[DType] = None,
8179
device: Optional[Device] = None,
82-
copy: Optional[bool] = _copy_default,
80+
copy: Optional[bool] = None,
8381
**kwargs,
8482
) -> Array:
8583
"""
@@ -89,25 +87,13 @@ def asarray(
8987
specification for more details.
9088
"""
9189
with cp.cuda.Device(device):
92-
# cupy is like NumPy 1.26 (except without _CopyMode). See the comments
93-
# in asarray in numpy/_aliases.py.
94-
if copy is not _copy_default:
95-
# A future version of CuPy will change the meaning of copy=False
96-
# to mean no-copy. We don't know for certain what version it will
97-
# be yet, so to avoid breaking that version, we use a different
98-
# default value for copy so asarray(obj) with no copy kwarg will
99-
# always do the copy-if-needed behavior.
100-
101-
# This will still need to be updated to remove the
102-
# NotImplementedError for copy=False, but at least this won't
103-
# break the default or existing behavior.
104-
if copy is None:
105-
copy = False
106-
elif copy is False:
107-
raise NotImplementedError("asarray(copy=False) is not yet supported in cupy")
108-
kwargs['copy'] = copy
109-
110-
return cp.array(obj, dtype=dtype, **kwargs)
90+
if copy is None:
91+
return cp.asarray(obj, dtype=dtype, **kwargs)
92+
else:
93+
res = cp.array(obj, dtype=dtype, copy=copy, **kwargs)
94+
if not copy and res is not obj:
95+
raise ValueError("Unable to avoid copy while creating an array as requested")
96+
return res
11197

11298

11399
def astype(

Diff for: cupy-xfails.txt

-3
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,6 @@ array_api_tests/test_array_object.py::test_scalar_casting[__index__(int64)]
1111
# testsuite bug (https://github.com/data-apis/array-api-tests/issues/172)
1212
array_api_tests/test_array_object.py::test_getitem
1313

14-
# copy=False is not yet implemented
15-
array_api_tests/test_creation_functions.py::test_asarray_arrays
16-
1714
# attributes are np.float32 instead of float
1815
# (see also https://github.com/data-apis/array-api/issues/405)
1916
array_api_tests/test_data_type_functions.py::test_finfo[float32]

0 commit comments

Comments
 (0)