diff --git a/pixi.lock b/pixi.lock index 56ee7c5..2153c4e 100644 --- a/pixi.lock +++ b/pixi.lock @@ -5256,7 +5256,7 @@ packages: - pypi: . name: array-api-extra version: 0.7.2.dev0 - sha256: 74777bddfe6ab8d3ced9e5d1c645cb95c637707a45de9e96c88fc3b41723e3af + sha256: 68490b5f2feb7687422f882f54bb2a93c687425b984a69ecd58c9d6d73653139 requires_dist: - array-api-compat>=1.11.2,<2 requires_python: '>=3.10' diff --git a/pyproject.toml b/pyproject.toml index 6765190..9d897cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -213,8 +213,8 @@ filterwarnings = ["error"] log_cli_level = "INFO" testpaths = ["tests"] markers = [ - "skip_xp_backend(library, *, reason=None): Skip test for a specific backend", - "xfail_xp_backend(library, *, reason=None): Xfail test for a specific backend", + "skip_xp_backend(library, /, *, reason=None): Skip test for a specific backend", + "xfail_xp_backend(library, /, *, reason=None, strict=None): Xfail test for a specific backend", ] diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index 319297c..301a851 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -195,7 +195,9 @@ def xp_assert_close( ) -def xfail(request: pytest.FixtureRequest, reason: str) -> None: +def xfail( + request: pytest.FixtureRequest, *, reason: str, strict: bool | None = None +) -> None: """ XFAIL the currently running test. @@ -209,5 +211,13 @@ def xfail(request: pytest.FixtureRequest, reason: str) -> None: ``request`` argument of the test function. reason : str Reason for the expected failure. + strict: bool, optional + If True, the test will be marked as failed if it passes. + If False, the test will be marked as passed if it fails. + Default: ``xfail_strict`` value in ``pyproject.toml``, or False if absent. """ - request.node.add_marker(pytest.mark.xfail(reason=reason)) + if strict is not None: + marker = pytest.mark.xfail(reason=reason, strict=strict) + else: + marker = pytest.mark.xfail(reason=reason) + request.node.add_marker(marker) diff --git a/tests/conftest.py b/tests/conftest.py index 410a87f..5676cc0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,6 @@ """Pytest fixtures.""" from collections.abc import Callable, Generator -from contextlib import suppress from functools import partial, wraps from types import ModuleType from typing import ParamSpec, TypeVar, cast @@ -34,20 +33,29 @@ def library(request: pytest.FixtureRequest) -> Backend: # numpydoc ignore=PR01, """ elem = cast(Backend, request.param) - for marker_name, skip_or_xfail in ( - ("skip_xp_backend", pytest.skip), - ("xfail_xp_backend", partial(xfail, request)), + for marker_name, skip_or_xfail, allow_kwargs in ( + ("skip_xp_backend", pytest.skip, {"reason"}), + ("xfail_xp_backend", partial(xfail, request), {"reason", "strict"}), ): for marker in request.node.iter_markers(marker_name): - library = marker.kwargs.get("library") or marker.args[0] # type: ignore[no-untyped-usage] - if not isinstance(library, Backend): - msg = f"argument of {marker_name} must be a Backend enum" + if len(marker.args) != 1: # pyright: ignore[reportUnknownArgumentType] + msg = f"Expected exactly one positional argument; got {marker.args}" raise TypeError(msg) + if not isinstance(marker.args[0], Backend): + msg = f"Argument of {marker_name} must be a Backend enum" + raise TypeError(msg) + if invalid_kwargs := set(marker.kwargs) - allow_kwargs: # pyright: ignore[reportUnknownArgumentType] + msg = f"Unexpected kwarg(s): {invalid_kwargs}" + raise TypeError(msg) + + library: Backend = marker.args[0] + reason: str | None = marker.kwargs.get("reason", None) + strict: bool | None = marker.kwargs.get("strict", None) + if library == elem: - reason = str(library) - with suppress(KeyError): - reason += ":" + cast(str, marker.kwargs["reason"]) - skip_or_xfail(reason=reason) + reason = f"{library}: {reason}" if reason else str(library) # pyright: ignore[reportUnknownArgumentType] + kwargs = {"strict": strict} if strict is not None else {} + skip_or_xfail(reason=reason, **kwargs) # pyright: ignore[reportUnknownArgumentType] return elem diff --git a/tests/test_at.py b/tests/test_at.py index 4ccf584..fa9bcdc 100644 --- a/tests/test_at.py +++ b/tests/test_at.py @@ -115,11 +115,15 @@ def assert_copy( pytest.param( *(True, 1, 1), marks=( - pytest.mark.skip_xp_backend( # test passes when copy=False - Backend.JAX, reason="bool mask update with shaped rhs" + pytest.mark.xfail_xp_backend( + Backend.JAX, + reason="bool mask update with shaped rhs", + strict=False, # test passes when copy=False ), - pytest.mark.skip_xp_backend( # test passes when copy=False - Backend.JAX_GPU, reason="bool mask update with shaped rhs" + pytest.mark.xfail_xp_backend( + Backend.JAX_GPU, + reason="bool mask update with shaped rhs", + strict=False, # test passes when copy=False ), pytest.mark.xfail_xp_backend( Backend.DASK, reason="bool mask update with shaped rhs" diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 4e40f09..652e12e 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -196,7 +196,7 @@ def test_device(self, xp: ModuleType, device: Device): y = apply_where(x % 2 == 0, x, self.f1, fill_value=x) assert get_device(y) == device - @pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no isdtype") + @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype") @pytest.mark.filterwarnings("ignore::RuntimeWarning") # overflows, etc. @hypothesis.settings( # The xp and library fixtures are not regenerated between hypothesis iterations diff --git a/tests/test_helpers.py b/tests/test_helpers.py index ebd4811..a104e93 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -27,7 +27,7 @@ lazy_xp_function(in1d, jax_jit=False, static_argnames=("assume_unique", "invert", "xp")) -@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no unique_inverse") +@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no unique_inverse") @pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="no unique_inverse") class TestIn1D: # cover both code paths