@@ -120,6 +120,33 @@ def is_torch_array(x):
120
120
# TODO: Should we reject ndarray subclasses?
121
121
return isinstance (x , torch .Tensor )
122
122
123
+ def is_paddle_array (x ):
124
+ """
125
+ Return True if `x` is a Paddle tensor.
126
+
127
+ This function does not import Paddle if it has not already been imported
128
+ and is therefore cheap to use.
129
+
130
+ See Also
131
+ --------
132
+
133
+ array_namespace
134
+ is_array_api_obj
135
+ is_numpy_array
136
+ is_cupy_array
137
+ is_dask_array
138
+ is_jax_array
139
+ is_pydata_sparse_array
140
+ """
141
+ # Avoid importing paddle if it isn't already
142
+ if 'paddle' not in sys .modules :
143
+ return False
144
+
145
+ import paddle
146
+
147
+ # TODO: Should we reject ndarray subclasses?
148
+ return paddle .is_tensor (x )
149
+
123
150
def is_ndonnx_array (x ):
124
151
"""
125
152
Return True if `x` is a ndonnx Array.
@@ -252,6 +279,7 @@ def is_array_api_obj(x):
252
279
or is_dask_array (x ) \
253
280
or is_jax_array (x ) \
254
281
or is_pydata_sparse_array (x ) \
282
+ or is_paddle_array (x ) \
255
283
or hasattr (x , '__array_namespace__' )
256
284
257
285
def _compat_module_name ():
@@ -319,6 +347,27 @@ def is_torch_namespace(xp) -> bool:
319
347
return xp .__name__ in {'torch' , _compat_module_name () + '.torch' }
320
348
321
349
350
+ def is_paddle_namespace (xp ) -> bool :
351
+ """
352
+ Returns True if `xp` is a Paddle namespace.
353
+
354
+ This includes both Paddle itself and the version wrapped by array-api-compat.
355
+
356
+ See Also
357
+ --------
358
+
359
+ array_namespace
360
+ is_numpy_namespace
361
+ is_cupy_namespace
362
+ is_ndonnx_namespace
363
+ is_dask_namespace
364
+ is_jax_namespace
365
+ is_pydata_sparse_namespace
366
+ is_array_api_strict_namespace
367
+ """
368
+ return xp .__name__ in {'paddle' , _compat_module_name () + '.paddle' }
369
+
370
+
322
371
def is_ndonnx_namespace (xp ):
323
372
"""
324
373
Returns True if `xp` is an NDONNX namespace.
@@ -543,6 +592,14 @@ def your_function(x, y):
543
592
else :
544
593
import jax .experimental .array_api as jnp
545
594
namespaces .add (jnp )
595
+ elif is_paddle_array (x ):
596
+ if _use_compat :
597
+ _check_api_version (api_version )
598
+ from .. import paddle as paddle_namespace
599
+ namespaces .add (paddle_namespace )
600
+ else :
601
+ import paddle
602
+ namespaces .add (paddle )
546
603
elif is_pydata_sparse_array (x ):
547
604
if use_compat is True :
548
605
_check_api_version (api_version )
@@ -660,6 +717,16 @@ def device(x: Array, /) -> Device:
660
717
return "cpu"
661
718
# Return the device of the constituent array
662
719
return device (inner )
720
+ elif is_paddle_array (x ):
721
+ raw_place_str = str (x .place )
722
+ if "gpu_pinned" in raw_place_str :
723
+ return "cpu"
724
+ elif "cpu" in raw_place_str :
725
+ return "cpu"
726
+ elif "gpu" in raw_place_str :
727
+ return "gpu"
728
+ raise NotImplementedError (f"Unsupported device { raw_place_str } " )
729
+
663
730
return x .device
664
731
665
732
# Prevent shadowing, used below
@@ -709,6 +776,14 @@ def _torch_to_device(x, device, /, stream=None):
709
776
raise NotImplementedError
710
777
return x .to (device )
711
778
779
+ def _paddle_to_device (x , device , / , stream = None ):
780
+ if stream is not None :
781
+ raise NotImplementedError (
782
+ "paddle.Tensor.to() do not support stream argument yet"
783
+ )
784
+ return x .to (device )
785
+
786
+
712
787
def to_device (x : Array , device : Device , / , * , stream : Optional [Union [int , Any ]] = None ) -> Array :
713
788
"""
714
789
Copy the array from the device on which it currently resides to the specified ``device``.
@@ -781,6 +856,8 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
781
856
# In JAX v0.4.31 and older, this import adds to_device method to x.
782
857
import jax .experimental .array_api # noqa: F401
783
858
return x .to_device (device , stream = stream )
859
+ elif is_paddle_array (x ):
860
+ return _paddle_to_device (x , device , stream = stream )
784
861
elif is_pydata_sparse_array (x ) and device == _device (x ):
785
862
# Perform trivial check to return the same array if
786
863
# device is same instead of err-ing.
@@ -819,6 +896,8 @@ def size(x):
819
896
"is_torch_namespace" ,
820
897
"is_ndonnx_array" ,
821
898
"is_ndonnx_namespace" ,
899
+ "is_paddle_array" ,
900
+ "is_paddle_namespace" ,
822
901
"is_pydata_sparse_array" ,
823
902
"is_pydata_sparse_namespace" ,
824
903
"size" ,
0 commit comments