Skip to content

Last dask fixes #110

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 18 additions & 19 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,31 +325,30 @@ def _asarray(
else:
COPY_FALSE = (False,)
COPY_TRUE = (True,)
if copy in COPY_FALSE:
if copy in COPY_FALSE and namespace != "dask.array":
# copy=False is not yet implemented in xp.asarray
raise NotImplementedError("copy=False is not yet implemented")
if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)) or hasattr(obj, "__array__"):
#print('hit me')
if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)):
if dtype is not None and obj.dtype != dtype:
copy = True
#print(copy)
if copy in COPY_TRUE:
copy_kwargs = {}
if namespace != "dask.array":
copy_kwargs["copy"] = True
else:
# No copy kw in dask.asarray so we go thorugh np.asarray first
# (like dask also does) but copy after
if dtype is None:
# Same dtype copy is no-op in dask
#print("in here?")
return obj.copy()
import numpy as np
#print(obj)
obj = np.asarray(obj).copy()
#print(obj)
return xp.array(obj, dtype=dtype, **copy_kwargs)
return xp.array(obj, copy=True, dtype=dtype)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I reverted back to the original changes here (before my dask PR), but it'd be good to make sure.

return obj
elif namespace == "dask.array":
if copy in COPY_TRUE:
if dtype is None:
return obj.copy()
# Go through numpy, since dask copy is no-op by default
import numpy as np
obj = np.array(obj, dtype=dtype, copy=True)
return xp.array(obj, dtype=dtype)
else:
import dask.array as da
import numpy as np
if not isinstance(obj, da.Array):
obj = np.asarray(obj, dtype=dtype)
return da.from_array(obj)
return obj

return xp.asarray(obj, dtype=dtype, **kwargs)

Expand Down
16 changes: 12 additions & 4 deletions array_api_compat/dask/array/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
if TYPE_CHECKING:
from ...common._typing import Array

# cupy.linalg doesn't have __all__. If it is added, replace this with
# dask.array.linalg doesn't have __all__. If it is added, replace this with
#
# from cupy.linalg import __all__ as linalg_all
# from dask.array.linalg import __all__ as linalg_all
_n = {}
exec('from dask.array.linalg import *', _n)
del _n['__builtins__']
Expand All @@ -32,7 +32,15 @@
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult
qr = get_xp(da)(_linalg.qr)
# TODO: use the QR wrapper once dask
# supports the mode keyword on QR
# https://github.com/dask/dask/issues/10388
#qr = get_xp(da)(_linalg.qr)
def qr(x: Array, mode: Literal['reduced', 'complete'] = 'reduced',

Check failure on line 39 in array_api_compat/dask/array/linalg.py

View workflow job for this annotation

GitHub Actions / check-ruff

Ruff (F405)

array_api_compat/dask/array/linalg.py:39:24: F405 `Literal` may be undefined, or defined from star imports

Check failure on line 39 in array_api_compat/dask/array/linalg.py

View workflow job for this annotation

GitHub Actions / check-ruff

Ruff (F405)

array_api_compat/dask/array/linalg.py:39:33: F405 `reduced` may be undefined, or defined from star imports

Check failure on line 39 in array_api_compat/dask/array/linalg.py

View workflow job for this annotation

GitHub Actions / check-ruff

Ruff (F405)

array_api_compat/dask/array/linalg.py:39:44: F405 `complete` may be undefined, or defined from star imports
**kwargs) -> QRResult:
if mode != "reduced":
raise ValueError("dask arrays only support using mode='reduced'")
return QRResult(*da.linalg.qr(x, **kwargs))
cholesky = get_xp(da)(_linalg.cholesky)
matrix_rank = get_xp(da)(_linalg.matrix_rank)
matrix_norm = get_xp(da)(_linalg.matrix_norm)
Expand All @@ -44,7 +52,7 @@
def svd(x: Array, full_matrices: bool = True, **kwargs) -> SVDResult:
if full_matrices:
raise ValueError("full_matrics=True is not supported by dask.")
return da.linalg.svd(x, **kwargs)
return da.linalg.svd(x, coerce_signs=False, **kwargs)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needed sign dask will do a svd_flip on the data otherwise, which will make signs differ from numpy/scipy.


def svdvals(x: Array) -> Array:
# TODO: can't avoid computing U or V for dask
Expand Down
13 changes: 13 additions & 0 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,16 @@ def test_to_device_host(library):
# here is that we can test portably after calling
# to_device(x, "cpu") to return to host
assert_allclose(x, expected)


@pytest.mark.parametrize("target_library,func", is_functions.items())
@pytest.mark.parametrize("source_library", is_functions.keys())
def test_asarray(source_library, target_library, func):
src_lib = import_(source_library, wrapper=True)
tgt_lib = import_(target_library, wrapper=True)
is_tgt_type = globals()[func]

a = src_lib.asarray([1, 2, 3])
b = tgt_lib.asarray(a)

assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}"
Loading