Skip to content

Commit 8e5cc94

Browse files
add paddle support in array-api-compat
1 parent ee25aae commit 8e5cc94

19 files changed

+2088
-97
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
name: Array API Tests (Paddle Latest)
2+
3+
on: [push, pull_request]
4+
5+
jobs:
6+
array-api-tests-paddle:
7+
uses: ./.github/workflows/array-api-tests.yml
8+
with:
9+
package-name: paddle
10+
extra-env-vars: |
11+
ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64

array_api_compat/common/_helpers.py

+79
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,33 @@ def is_torch_array(x):
120120
# TODO: Should we reject ndarray subclasses?
121121
return isinstance(x, torch.Tensor)
122122

123+
def is_paddle_array(x):
124+
"""
125+
Return True if `x` is a Paddle tensor.
126+
127+
This function does not import Paddle if it has not already been imported
128+
and is therefore cheap to use.
129+
130+
See Also
131+
--------
132+
133+
array_namespace
134+
is_array_api_obj
135+
is_numpy_array
136+
is_cupy_array
137+
is_dask_array
138+
is_jax_array
139+
is_pydata_sparse_array
140+
"""
141+
# Avoid importing paddle if it isn't already
142+
if 'paddle' not in sys.modules:
143+
return False
144+
145+
import paddle
146+
147+
# TODO: Should we reject ndarray subclasses?
148+
return paddle.is_tensor(x)
149+
123150
def is_ndonnx_array(x):
124151
"""
125152
Return True if `x` is a ndonnx Array.
@@ -252,6 +279,7 @@ def is_array_api_obj(x):
252279
or is_dask_array(x) \
253280
or is_jax_array(x) \
254281
or is_pydata_sparse_array(x) \
282+
or is_paddle_array(x) \
255283
or hasattr(x, '__array_namespace__')
256284

257285
def _compat_module_name():
@@ -319,6 +347,27 @@ def is_torch_namespace(xp) -> bool:
319347
return xp.__name__ in {'torch', _compat_module_name() + '.torch'}
320348

321349

350+
def is_paddle_namespace(xp) -> bool:
351+
"""
352+
Returns True if `xp` is a Paddle namespace.
353+
354+
This includes both Paddle itself and the version wrapped by array-api-compat.
355+
356+
See Also
357+
--------
358+
359+
array_namespace
360+
is_numpy_namespace
361+
is_cupy_namespace
362+
is_ndonnx_namespace
363+
is_dask_namespace
364+
is_jax_namespace
365+
is_pydata_sparse_namespace
366+
is_array_api_strict_namespace
367+
"""
368+
return xp.__name__ in {'paddle', _compat_module_name() + '.paddle'}
369+
370+
322371
def is_ndonnx_namespace(xp):
323372
"""
324373
Returns True if `xp` is an NDONNX namespace.
@@ -543,6 +592,14 @@ def your_function(x, y):
543592
else:
544593
import jax.experimental.array_api as jnp
545594
namespaces.add(jnp)
595+
elif is_paddle_array(x):
596+
if _use_compat:
597+
_check_api_version(api_version)
598+
from .. import paddle as paddle_namespace
599+
namespaces.add(paddle_namespace)
600+
else:
601+
import paddle
602+
namespaces.add(paddle)
546603
elif is_pydata_sparse_array(x):
547604
if use_compat is True:
548605
_check_api_version(api_version)
@@ -660,6 +717,16 @@ def device(x: Array, /) -> Device:
660717
return "cpu"
661718
# Return the device of the constituent array
662719
return device(inner)
720+
elif is_paddle_array(x):
721+
raw_place_str = str(x.place)
722+
if "gpu_pinned" in raw_place_str:
723+
return "cpu"
724+
elif "cpu" in raw_place_str:
725+
return "cpu"
726+
elif "gpu" in raw_place_str:
727+
return "gpu"
728+
raise NotImplementedError(f"Unsupported device {raw_place_str}")
729+
663730
return x.device
664731

665732
# Prevent shadowing, used below
@@ -709,6 +776,14 @@ def _torch_to_device(x, device, /, stream=None):
709776
raise NotImplementedError
710777
return x.to(device)
711778

779+
def _paddle_to_device(x, device, /, stream=None):
780+
if stream is not None:
781+
raise NotImplementedError(
782+
"paddle.Tensor.to() do not support stream argument yet"
783+
)
784+
return x.to(device)
785+
786+
712787
def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array:
713788
"""
714789
Copy the array from the device on which it currently resides to the specified ``device``.
@@ -781,6 +856,8 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
781856
# In JAX v0.4.31 and older, this import adds to_device method to x.
782857
import jax.experimental.array_api # noqa: F401
783858
return x.to_device(device, stream=stream)
859+
elif is_paddle_array(x):
860+
return _paddle_to_device(x, device, stream=stream)
784861
elif is_pydata_sparse_array(x) and device == _device(x):
785862
# Perform trivial check to return the same array if
786863
# device is same instead of err-ing.
@@ -819,6 +896,8 @@ def size(x):
819896
"is_torch_namespace",
820897
"is_ndonnx_array",
821898
"is_ndonnx_namespace",
899+
"is_paddle_array",
900+
"is_paddle_namespace",
822901
"is_pydata_sparse_array",
823902
"is_pydata_sparse_namespace",
824903
"size",

array_api_compat/paddle/__init__.py

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from paddle import * # noqa: F403
2+
3+
# Several names are not included in the above import *
4+
import paddle
5+
6+
for n in dir(paddle):
7+
if (
8+
n.startswith("_")
9+
or n.endswith("_")
10+
or "gpu" in n
11+
or "cpu" in n
12+
or "backward" in n
13+
):
14+
continue
15+
exec(n + " = paddle." + n)
16+
exec("asarray = paddle.to_tensor")
17+
18+
# These imports may overwrite names from the import * above.
19+
from ._aliases import * # noqa: F403
20+
21+
# See the comment in the numpy __init__.py
22+
__import__(__package__ + ".linalg")
23+
24+
__import__(__package__ + ".fft")
25+
26+
from ..common._helpers import * # noqa: F403
27+
28+
__array_api_version__ = "2023.12"

0 commit comments

Comments
 (0)