7
7
import inspect
8
8
from typing import TYPE_CHECKING , Any , NamedTuple , Optional , Sequence , cast
9
9
10
- from ._helpers import _check_device , array_namespace
10
+ from ._helpers import _device_ctx , array_namespace
11
11
from ._helpers import device as _get_device
12
12
from ._helpers import is_cupy_namespace as _is_cupy_namespace
13
13
from ._typing import Array , Device , DType , Namespace
@@ -32,8 +32,8 @@ def arange(
32
32
device : Device | None = None ,
33
33
** kwargs : object ,
34
34
) -> Array :
35
- _check_device (xp , device )
36
- return xp .arange (start , stop = stop , step = step , dtype = dtype , ** kwargs )
35
+ with _device_ctx (xp , device ):
36
+ return xp .arange (start , stop = stop , step = step , dtype = dtype , ** kwargs )
37
37
38
38
39
39
def empty (
@@ -44,8 +44,8 @@ def empty(
44
44
device : Device | None = None ,
45
45
** kwargs : object ,
46
46
) -> Array :
47
- _check_device (xp , device )
48
- return xp .empty (shape , dtype = dtype , ** kwargs )
47
+ with _device_ctx (xp , device ):
48
+ return xp .empty (shape , dtype = dtype , ** kwargs )
49
49
50
50
51
51
def empty_like (
@@ -57,8 +57,8 @@ def empty_like(
57
57
device : Device | None = None ,
58
58
** kwargs : object ,
59
59
) -> Array :
60
- _check_device (xp , device )
61
- return xp .empty_like (x , dtype = dtype , ** kwargs )
60
+ with _device_ctx (xp , device , like = x ):
61
+ return xp .empty_like (x , dtype = dtype , ** kwargs )
62
62
63
63
64
64
def eye (
@@ -72,8 +72,8 @@ def eye(
72
72
device : Device | None = None ,
73
73
** kwargs : object ,
74
74
) -> Array :
75
- _check_device (xp , device )
76
- return xp .eye (n_rows , M = n_cols , k = k , dtype = dtype , ** kwargs )
75
+ with _device_ctx (xp , device ):
76
+ return xp .eye (n_rows , M = n_cols , k = k , dtype = dtype , ** kwargs )
77
77
78
78
79
79
def full (
@@ -85,8 +85,8 @@ def full(
85
85
device : Device | None = None ,
86
86
** kwargs : object ,
87
87
) -> Array :
88
- _check_device (xp , device )
89
- return xp .full (shape , fill_value , dtype = dtype , ** kwargs )
88
+ with _device_ctx (xp , device ):
89
+ return xp .full (shape , fill_value , dtype = dtype , ** kwargs )
90
90
91
91
92
92
def full_like (
@@ -99,8 +99,8 @@ def full_like(
99
99
device : Device | None = None ,
100
100
** kwargs : object ,
101
101
) -> Array :
102
- _check_device (xp , device )
103
- return xp .full_like (x , fill_value , dtype = dtype , ** kwargs )
102
+ with _device_ctx (xp , device , like = x ):
103
+ return xp .full_like (x , fill_value , dtype = dtype , ** kwargs )
104
104
105
105
106
106
def linspace (
@@ -115,8 +115,8 @@ def linspace(
115
115
endpoint : bool = True ,
116
116
** kwargs : object ,
117
117
) -> Array :
118
- _check_device (xp , device )
119
- return xp .linspace (start , stop , num , dtype = dtype , endpoint = endpoint , ** kwargs )
118
+ with _device_ctx (xp , device ):
119
+ return xp .linspace (start , stop , num , dtype = dtype , endpoint = endpoint , ** kwargs )
120
120
121
121
122
122
def ones (
@@ -127,8 +127,8 @@ def ones(
127
127
device : Device | None = None ,
128
128
** kwargs : object ,
129
129
) -> Array :
130
- _check_device (xp , device )
131
- return xp .ones (shape , dtype = dtype , ** kwargs )
130
+ with _device_ctx (xp , device ):
131
+ return xp .ones (shape , dtype = dtype , ** kwargs )
132
132
133
133
134
134
def ones_like (
@@ -140,8 +140,8 @@ def ones_like(
140
140
device : Device | None = None ,
141
141
** kwargs : object ,
142
142
) -> Array :
143
- _check_device (xp , device )
144
- return xp .ones_like (x , dtype = dtype , ** kwargs )
143
+ with _device_ctx (xp , device , like = x ):
144
+ return xp .ones_like (x , dtype = dtype , ** kwargs )
145
145
146
146
147
147
def zeros (
@@ -152,8 +152,8 @@ def zeros(
152
152
device : Device | None = None ,
153
153
** kwargs : object ,
154
154
) -> Array :
155
- _check_device (xp , device )
156
- return xp .zeros (shape , dtype = dtype , ** kwargs )
155
+ with _device_ctx (xp , device ):
156
+ return xp .zeros (shape , dtype = dtype , ** kwargs )
157
157
158
158
159
159
def zeros_like (
@@ -165,8 +165,8 @@ def zeros_like(
165
165
device : Device | None = None ,
166
166
** kwargs : object ,
167
167
) -> Array :
168
- _check_device (xp , device )
169
- return xp .zeros_like (x , dtype = dtype , ** kwargs )
168
+ with _device_ctx (xp , device , like = x ):
169
+ return xp .zeros_like (x , dtype = dtype , ** kwargs )
170
170
171
171
172
172
# np.unique() is split into four functions in the array API:
0 commit comments