clarify if `__array_namespace_info__().default_device()` can be `None`
#923
Hi, for what it's worth, we made the deliberate choice to return `None`. The problem is that JAX's existing device placement does not entirely align with the model that the authors of the spec had in mind. For example, under JIT there is no default device, because an array referenced in the Python API may never physically exist. Here's a silly example:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    y = jnp.arange(10)
    return x
```

What device is `y` on? Let's modify this slightly:

```python
@jax.jit
def f(x):
    y = jnp.arange(len(x))
    return x + y
```

What device will `y` be on? Neither of these situations is compatible with the idea of a global default device, so the very notion of `default_device` as envisioned by the array API specification is flawed, and not applicable to frameworks like JAX. Given that, we thought returning `None` was appropriate. If you have other suggestions, I'm open to hearing them! If the specification were changed such that …
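The argument above can be illustrated with a deliberately tiny model of lazy placement. This is a toy sketch, not how JAX or XLA actually assign devices: `Node`, `place`, and the device strings are hypothetical. Values start with no device, and a "compile" step walks back from the output. A dead value like the first example's `y` is therefore never placed, while a constant like the second example's `y` simply follows whatever consumes it.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass(eq=False)
class Node:
    """A value in a toy lazy graph; `device` stays None until placement."""
    name: str
    inputs: List["Node"] = field(default_factory=list)
    device: Optional[str] = None


def _adopt(node: Node, device: str) -> None:
    """Place an unplaced node, and anything unplaced feeding it, on `device`."""
    if node.device is None:
        node.device = device
        for inp in node.inputs:
            _adopt(inp, device)


def place(output: Node) -> None:
    """Toy 'compile' step: place only the nodes that `output` depends on."""

    def visit(node: Node) -> Optional[str]:
        # Bottom-up: a node inherits the single device of its placed inputs.
        devices = {visit(inp) for inp in node.inputs} - {None}
        if len(devices) > 1:
            raise ValueError(f"inputs of {node.name!r} are on different devices")
        if node.device is None and devices:
            node.device = devices.pop()
        # Top-down: device-less inputs (constants such as arange) follow
        # the node that consumes them.
        if node.device is not None:
            for inp in node.inputs:
                _adopt(inp, node.device)
        return node.device

    visit(output)


# Example 1: `y = arange(10)` is unused, so placement never reaches it.
x1 = Node("x", device="gpu:1")
y1 = Node("arange(10)")
place(x1)
print(y1.device)  # None

# Example 2: `y` feeds `x + y`, so it ends up wherever `x` lives.
x2 = Node("x", device="gpu:1")
y2 = Node("arange(len(x))")
out2 = Node("x + y", inputs=[x2, y2])
place(out2)
print(y2.device, out2.device)  # gpu:1 gpu:1
```

In this toy model, asking which device `y` is on before `place` runs has no meaningful answer. That is the situation the comment describes for JAX under `jit`.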
IMHO, I'm happy for JAX to return `None` as the default device.
It's worth pointing out that this behaviour, while definitely desirable and nice to read, is possible only on lazy backends. In fact, the snippet above will crash on PyTorch if `x` does not reside on the default device. (CuPy has blocking design issues here.) A portable version has to propagate the device of `x` explicitly:

```python
from array_api_compat import array_namespace, device

def f(x):
    xp = array_namespace(x)
    # Create y on the same device as the input x.
    y = xp.arange(x.shape[-1], device=device(x))
    return x + y
```

The array-api-compat shims are necessary to support NumPy 1.x, Dask, Sparse, and JAX itself. This pattern follows the guideline of prioritizing input->output propagation over the global and context devices.
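The input->output propagation rule can also be written as a small, library-agnostic helper. This is a sketch, not part of array-api-compat. `result_device` is a hypothetical name, and `SimpleNamespace` objects stand in for real arrays, relying only on the standard's `.device` attribute. Devices are compared with `==` rather than hashed, so the helper does not assume device objects are hashable.

```python
from types import SimpleNamespace
from typing import Any, Optional


def result_device(*arrays: Any, fallback: Optional[Any] = None) -> Optional[Any]:
    """Choose the device for a new array, preferring its inputs' devices.

    - All inputs on one device: use that device (input->output propagation).
    - Inputs on different devices: raise rather than guess.
    - No inputs: return `fallback`; None means "let the library decide".
    """
    devices: list = []
    for a in arrays:
        # Compare with ==, so device objects need not be hashable.
        if not any(a.device == d for d in devices):
            devices.append(a.device)
    if len(devices) > 1:
        raise ValueError(f"inputs are on different devices: {devices}")
    return devices[0] if devices else fallback


x = SimpleNamespace(device="cuda:0")  # stand-in for a real array
w = SimpleNamespace(device="cuda:0")
print(result_device(x, w))            # cuda:0
print(result_device(fallback="cpu"))  # cpu
```

Inside `f`, this would read `y = xp.arange(x.shape[-1], device=result_device(x))`. The standard's creation functions accept `device=None`, which leaves the choice to the library.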
The answer here is that no one cares.

```python
from array_api_compat import array_namespace, device, to_device

def f(x):
    """Return x + arange, prioritizing the default device over x.device."""
    xp = array_namespace(x)
    y = xp.arange(x.shape[-1])
    return to_device(x, device(y)) + y
```

Here, we have some peculiar behaviour:
The spec only says that it returns an object corresponding to the default device (https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.default_device.html#array_api.info.default_device). `jax.numpy` returns `None`, so the question is whether `None` corresponds to the default device or not.
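If the standard were clarified to allow `None`, a conformance-style check might look roughly like the sketch below. This is hypothetical: `check_default_device` and the stub namespaces are illustrative, and only `default_device()` and `devices()` come from the standard's inspection API. The `allow_none` flag models the open question in this issue. Any non-`None` result must compare equal to one of the listed devices.

```python
from types import SimpleNamespace
from typing import Any


def check_default_device(xp: Any, allow_none: bool = True) -> Any:
    """Sanity-check `default_device()` from the inspection API.

    `allow_none` models the open question: whether None is a valid answer
    meaning "no global default device". Otherwise the result must compare
    equal to one of the devices listed by `info.devices()`.
    """
    info = xp.__array_namespace_info__()
    default = info.default_device()
    if default is None:
        if not allow_none:
            raise AssertionError("default_device() returned None")
        return None
    if not any(default == d for d in info.devices()):
        raise AssertionError(f"default device {default!r} not in devices()")
    return default


# Hypothetical stand-ins: an eager library with a concrete default, and a
# lazy library that reports no default device.
eager = SimpleNamespace(__array_namespace_info__=lambda: SimpleNamespace(
    default_device=lambda: "cpu", devices=lambda: ["cpu", "cuda:0"]))
lazy = SimpleNamespace(__array_namespace_info__=lambda: SimpleNamespace(
    default_device=lambda: None, devices=lambda: ["cpu"]))

print(check_default_device(eager))  # cpu
print(check_default_device(lazy))   # None
```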