clarify if __array_namespace_info().default_device() can be None #923

Open
ev-br opened this issue Apr 9, 2025 · 2 comments
@ev-br
Member

ev-br commented Apr 9, 2025

The spec only says it returns an object corresponding to the default device (https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.default_device.html#array_api.info.default_device).

jax.numpy returns None, so the question is whether None corresponds to the default device or not.

In [5]: import jax.numpy as jnp

In [6]: jnp.__array_namespace_info__().default_device() is None
Out[6]: True
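
To make the ambiguity concrete, here is a minimal sketch of a consumer-side check that distinguishes the two cases. The `make_namespace` helper and the `numpy_like` / `jax_like` objects are hypothetical stand-ins so the snippet runs without any array library installed; the `"cpu"` value is an assumption about what another backend might report. Only the `__array_namespace_info__().default_device()` call comes from the spec.

```python
from types import SimpleNamespace


def make_namespace(default_device):
    """Hypothetical stand-in for an array namespace; it only models default_device()."""
    info = SimpleNamespace(default_device=lambda: default_device)
    return SimpleNamespace(__array_namespace_info__=lambda: info)


def has_concrete_default_device(xp):
    """Return True if the namespace reports a default device other than None."""
    return xp.__array_namespace_info__().default_device() is not None


numpy_like = make_namespace("cpu")  # assumption: a backend that names its default device
jax_like = make_namespace(None)     # mirrors the jax.numpy result shown above

print(has_concrete_default_device(numpy_like))
print(has_concrete_default_device(jax_like))
```

Any generic code that compares an array's device against the reported default would need a branch like this if None is an allowed return value.
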
@jakevdp

jakevdp commented Apr 9, 2025

Hi - for what it's worth, we made the deliberate choice to return None here in order to make JAX's existing device placement semantics work with the specifications of the array API standard.

The problem is that JAX's existing device placement does not entirely align with the model that the authors of the spec had in mind. For example, under JIT, there is no default device, because the array referenced in the Python API may not ever physically exist. Here's a silly example:

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  y = jnp.arange(10)  # never used, so the compiler removes it
  return x

What device is y on here? That question cannot be answered, because the compiler will recognize that y = jnp.arange(10) is dead code, and will eliminate this from the program: y in this program will never exist as an actual buffer on a device.

Let's modify this slightly:

@jax.jit
def f(x):
  y = jnp.arange(len(x))
  return x + y

What device will y be on now? Here this will be determined contextually by the compiler: because y only interacts with x, the compiler will allocate its buffer on the same device (or devices) as x.

Neither of these situations is compatible with the idea of a global default device, so the very notion of "default_device" as envisioned by the array API specification is flawed and not applicable to frameworks like JAX. Given that, we thought returning None from default_device would be the least bad approach. After all, None is a valid argument to device in all cases, and explicitly passing device=None results in the same behavior as not passing device at all; that behavior seemed to align with the notion of a "default".

If you have other suggestions, I'm open to hearing them! If the specification were changed such that default_device could not return None, I suppose our best option would probably be to define some NoDevice singleton that has the same semantics as None does currently.
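
For illustration only, a NoDevice singleton along those lines might look like the sketch below. The names `NoDeviceType`, `NoDevice`, and `normalize_device` are hypothetical and do not exist in JAX; the sketch only shows one way a sentinel could carry the same meaning that None has today.

```python
class NoDeviceType:
    """Hypothetical sentinel meaning 'let the framework decide device placement'."""

    _instance = None

    def __new__(cls):
        # Always hand back the same object so identity checks (`is`) work.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __repr__(self):
        return "NoDevice"


NoDevice = NoDeviceType()


def normalize_device(device):
    """Map the sentinel back to None, the value existing placement code expects."""
    return None if device is NoDevice else device
```

With such a sentinel, default_device() could return a non-None object while every function accepting a device argument would treat NoDevice and None identically.
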

@crusaderky
Contributor

IMHO, I'm happy for JAX to return None as the default device.
But yes, it needs to be spelled out explicitly.

def f(x):
  y = jnp.arange(len(x))
  return x + y

What device will y be on now? Here this will be determined contextually by the compiler: because y only interacts with x, the compiler will allocate its buffer on the same device (or devices) as x.

It's worth pointing out that this behaviour, while definitely desirable and nice to read, is possible only on lazy backends. In fact, the snippet above will crash on PyTorch if x is not on the default device. (CuPy has blocking design issues here.)
As a result, the current best practice for Array API agnostic functions is:

from array_api_compat import array_namespace, device

def f(x):
    xp = array_namespace(x)
    y = xp.arange(x.shape[-1], device=device(x))
    return x + y

The array-api-compat shims are necessary to support NumPy 1.x, Dask, Sparse, and JAX itself.

This pattern follows the guideline of prioritizing input->output propagation over global and context device:
https://data-apis.org/array-api/latest/design_topics/device_support.html#semantics

Preserve device assignment as much as possible (e.g. output arrays from a function are expected to be on the same device as input arrays to the function).

@jax.jit
def f(x):
  y = jnp.arange(10)
  return x

What device is y on here? That question cannot be answered

The answer here is that no one cares.
A much more interesting example would be:

from array_api_compat import array_namespace, device, to_device

def f(x):
    """Return x+arange, prioritizing the default device over x.device"""
    xp = array_namespace(x)
    y = xp.arange(x.shape[-1])
    return to_device(x, device(y)) + y

Here, we have some peculiar behaviour:

  • PyTorch returns an object on whatever device was set with torch.set_default_device;
  • Eager JAX returns an object on jax_default_device;
  • Jitted JAX currently crashes because to_device doesn't expect device to return None. Returning None is the array-api-compat workaround for jax-ml/jax#26000 ("Missing .device attribute inside @jax.jit"). Realistically, though, it would make sense if it returned x.device.
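
One way a consumer could sidestep the jitted crash is to skip the transfer when the target device is reported as None. This is a minimal sketch under that assumption, not the behaviour of array-api-compat: `to_device_if_known` and `fake_transfer` are hypothetical names, and in real code the `transfer` argument would be something like `array_api_compat.to_device`.

```python
def to_device_if_known(x, target, transfer):
    """Move x to target, unless target is None (device not knowable, e.g. under jit)."""
    if target is None:
        # Leave x where it is, so the result stays on x's own device.
        return x
    return transfer(x, target)


# Recording stand-in for a real transfer function, so the sketch runs on its own.
calls = []


def fake_transfer(x, target):
    calls.append(target)
    return x


unchanged = to_device_if_known([1, 2], None, fake_transfer)
moved = to_device_if_known([1, 2], "cpu", fake_transfer)
print(calls)
```

Under this guard a None device leaves x in place, which matches one reading of the x.device outcome suggested above.
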
