Skip to content

Commit 57d3c5a

Browse files
authored
Merge pull request #110 from lithomas1/last-dask-fixes
Last dask fixes
2 parents 6cc8008 + 826cde4 commit 57d3c5a

File tree

4 files changed

+54
-23
lines changed

4 files changed

+54
-23
lines changed

Diff for: array_api_compat/common/_aliases.py

+18-19
Original file line numberDiff line numberDiff line change
@@ -325,31 +325,30 @@ def _asarray(
325325
else:
326326
COPY_FALSE = (False,)
327327
COPY_TRUE = (True,)
328-
if copy in COPY_FALSE:
328+
if copy in COPY_FALSE and namespace != "dask.array":
329329
# copy=False is not yet implemented in xp.asarray
330330
raise NotImplementedError("copy=False is not yet implemented")
331-
if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)) or hasattr(obj, "__array__"):
332-
#print('hit me')
331+
if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)):
333332
if dtype is not None and obj.dtype != dtype:
334333
copy = True
335-
#print(copy)
336334
if copy in COPY_TRUE:
337-
copy_kwargs = {}
338-
if namespace != "dask.array":
339-
copy_kwargs["copy"] = True
340-
else:
341-
# No copy kw in dask.asarray so we go thorugh np.asarray first
342-
# (like dask also does) but copy after
343-
if dtype is None:
344-
# Same dtype copy is no-op in dask
345-
#print("in here?")
346-
return obj.copy()
347-
import numpy as np
348-
#print(obj)
349-
obj = np.asarray(obj).copy()
350-
#print(obj)
351-
return xp.array(obj, dtype=dtype, **copy_kwargs)
335+
return xp.array(obj, copy=True, dtype=dtype)
352336
return obj
337+
elif namespace == "dask.array":
338+
if copy in COPY_TRUE:
339+
if dtype is None:
340+
return obj.copy()
341+
# Go through numpy, since dask copy is no-op by default
342+
import numpy as np
343+
obj = np.array(obj, dtype=dtype, copy=True)
344+
return xp.array(obj, dtype=dtype)
345+
else:
346+
import dask.array as da
347+
import numpy as np
348+
if not isinstance(obj, da.Array):
349+
obj = np.asarray(obj, dtype=dtype)
350+
return da.from_array(obj)
351+
return obj
353352

354353
return xp.asarray(obj, dtype=dtype, **kwargs)
355354

Diff for: array_api_compat/dask/array/linalg.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@
1616
from typing import TYPE_CHECKING
1717
if TYPE_CHECKING:
1818
from ...common._typing import Array
19+
from typing import Literal
1920

20-
# cupy.linalg doesn't have __all__. If it is added, replace this with
21+
# dask.array.linalg doesn't have __all__. If it is added, replace this with
2122
#
22-
# from cupy.linalg import __all__ as linalg_all
23+
# from dask.array.linalg import __all__ as linalg_all
2324
_n = {}
2425
exec('from dask.array.linalg import *', _n)
2526
del _n['__builtins__']
@@ -32,7 +33,15 @@
3233
QRResult = _linalg.QRResult
3334
SlogdetResult = _linalg.SlogdetResult
3435
SVDResult = _linalg.SVDResult
35-
qr = get_xp(da)(_linalg.qr)
36+
# TODO: use the QR wrapper once dask
37+
# supports the mode keyword on QR
38+
# https://github.com/dask/dask/issues/10388
39+
#qr = get_xp(da)(_linalg.qr)
40+
def qr(x: Array, mode: Literal['reduced', 'complete'] = 'reduced',
41+
**kwargs) -> QRResult:
42+
if mode != "reduced":
43+
raise ValueError("dask arrays only support using mode='reduced'")
44+
return QRResult(*da.linalg.qr(x, **kwargs))
3645
cholesky = get_xp(da)(_linalg.cholesky)
3746
matrix_rank = get_xp(da)(_linalg.matrix_rank)
3847
matrix_norm = get_xp(da)(_linalg.matrix_norm)
@@ -44,7 +53,7 @@
4453
def svd(x: Array, full_matrices: bool = True, **kwargs) -> SVDResult:
4554
if full_matrices:
4655
raise ValueError("full_matrics=True is not supported by dask.")
47-
return da.linalg.svd(x, **kwargs)
56+
return da.linalg.svd(x, coerce_signs=False, **kwargs)
4857

4958
def svdvals(x: Array) -> Array:
5059
# TODO: can't avoid computing U or V for dask

Diff for: dask-xfails.txt

+3
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ array_api_tests/test_data_type_functions.py::test_finfo[float32]
3737
# (I think the test is not forcing the op to be computed?)
3838
array_api_tests/test_creation_functions.py::test_linspace
3939

40+
# out.shape=(2,) but should be (1,)
41+
array_api_tests/test_indexing_functions.py::test_take
42+
4043
# out=-0, but should be +0
4144
array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
4245
array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]

Diff for: tests/test_common.py

+20
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,23 @@ def test_to_device_host(library):
6060
# here is that we can test portably after calling
6161
# to_device(x, "cpu") to return to host
6262
assert_allclose(x, expected)
63+
64+
65+
@pytest.mark.parametrize("target_library,func", is_functions.items())
66+
@pytest.mark.parametrize("source_library", is_functions.keys())
67+
def test_asarray(source_library, target_library, func, request):
68+
if source_library == "dask.array" and target_library == "torch":
69+
# Allow rest of test to execute instead of immediately xfailing
70+
# xref https://github.com/pandas-dev/pandas/issues/38902
71+
72+
# TODO: remove xfail once
73+
# https://github.com/dask/dask/issues/8260 is resolved
74+
request.node.add_marker(pytest.mark.xfail(reason="Bug in dask raising error on conversion"))
75+
src_lib = import_(source_library, wrapper=True)
76+
tgt_lib = import_(target_library, wrapper=True)
77+
is_tgt_type = globals()[func]
78+
79+
a = src_lib.asarray([1, 2, 3])
80+
b = tgt_lib.asarray(a)
81+
82+
assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}"

0 commit comments

Comments
 (0)