diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 8792aa2e..f998481c 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -325,31 +325,30 @@ def _asarray( else: COPY_FALSE = (False,) COPY_TRUE = (True,) - if copy in COPY_FALSE: + if copy in COPY_FALSE and namespace != "dask.array": # copy=False is not yet implemented in xp.asarray raise NotImplementedError("copy=False is not yet implemented") - if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)) or hasattr(obj, "__array__"): - #print('hit me') + if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)): if dtype is not None and obj.dtype != dtype: copy = True - #print(copy) if copy in COPY_TRUE: - copy_kwargs = {} - if namespace != "dask.array": - copy_kwargs["copy"] = True - else: - # No copy kw in dask.asarray so we go thorugh np.asarray first - # (like dask also does) but copy after - if dtype is None: - # Same dtype copy is no-op in dask - #print("in here?") - return obj.copy() - import numpy as np - #print(obj) - obj = np.asarray(obj).copy() - #print(obj) - return xp.array(obj, dtype=dtype, **copy_kwargs) + return xp.array(obj, copy=True, dtype=dtype) return obj + elif namespace == "dask.array": + if copy in COPY_TRUE: + if dtype is None: + return obj.copy() + # Go through numpy, since dask copy is no-op by default + import numpy as np + obj = np.array(obj, dtype=dtype, copy=True) + return xp.array(obj, dtype=dtype) + else: + import dask.array as da + import numpy as np + if not isinstance(obj, da.Array): + obj = np.asarray(obj, dtype=dtype) + return da.from_array(obj) + return obj return xp.asarray(obj, dtype=dtype, **kwargs) diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index 03f16e89..7f5b2c6e 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -16,10 +16,11 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from ...common._typing import Array + from typing import Literal -# cupy.linalg doesn't have __all__. If it is added, replace this with +# dask.array.linalg doesn't have __all__. If it is added, replace this with # -# from cupy.linalg import __all__ as linalg_all +# from dask.array.linalg import __all__ as linalg_all _n = {} exec('from dask.array.linalg import *', _n) del _n['__builtins__'] @@ -32,7 +33,15 @@ QRResult = _linalg.QRResult SlogdetResult = _linalg.SlogdetResult SVDResult = _linalg.SVDResult -qr = get_xp(da)(_linalg.qr) +# TODO: use the QR wrapper once dask +# supports the mode keyword on QR +# https://github.com/dask/dask/issues/10388 +#qr = get_xp(da)(_linalg.qr) +def qr(x: Array, mode: Literal['reduced', 'complete'] = 'reduced', + **kwargs) -> QRResult: + if mode != "reduced": + raise ValueError("dask arrays only support using mode='reduced'") + return QRResult(*da.linalg.qr(x, **kwargs)) cholesky = get_xp(da)(_linalg.cholesky) matrix_rank = get_xp(da)(_linalg.matrix_rank) matrix_norm = get_xp(da)(_linalg.matrix_norm) @@ -44,7 +53,7 @@ def svd(x: Array, full_matrices: bool = True, **kwargs) -> SVDResult: if full_matrices: raise ValueError("full_matrics=True is not supported by dask.") - return da.linalg.svd(x, **kwargs) + return da.linalg.svd(x, coerce_signs=False, **kwargs) def svdvals(x: Array) -> Array: # TODO: can't avoid computing U or V for dask diff --git a/dask-xfails.txt b/dask-xfails.txt index ecde5420..0d74ecbb 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -37,6 +37,9 @@ array_api_tests/test_data_type_functions.py::test_finfo[float32] # (I think the test is not forcing the op to be computed?) array_api_tests/test_creation_functions.py::test_linspace +# out.shape=(2,) but should be (1,) +array_api_tests/test_indexing_functions.py::test_take + # out=-0, but should be +0 array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] diff --git a/tests/test_common.py b/tests/test_common.py index 66076bfe..22b98d83 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -60,3 +60,23 @@ def test_to_device_host(library): # here is that we can test portably after calling # to_device(x, "cpu") to return to host assert_allclose(x, expected) + + +@pytest.mark.parametrize("target_library,func", is_functions.items()) +@pytest.mark.parametrize("source_library", is_functions.keys()) +def test_asarray(source_library, target_library, func, request): + if source_library == "dask.array" and target_library == "torch": + # Allow rest of test to execute instead of immediately xfailing + # xref https://github.com/pandas-dev/pandas/issues/38902 + + # TODO: remove xfail once + # https://github.com/dask/dask/issues/8260 is resolved + request.node.add_marker(pytest.mark.xfail(reason="Bug in dask raising error on conversion")) + src_lib = import_(source_library, wrapper=True) + tgt_lib = import_(target_library, wrapper=True) + is_tgt_type = globals()[func] + + a = src_lib.asarray([1, 2, 3]) + b = tgt_lib.asarray(a) + + assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}"