diff --git a/pixi.lock b/pixi.lock
index 56ee7c56..2153c4e2 100644
--- a/pixi.lock
+++ b/pixi.lock
@@ -5256,7 +5256,7 @@ packages:
 - pypi: .
   name: array-api-extra
   version: 0.7.2.dev0
-  sha256: 74777bddfe6ab8d3ced9e5d1c645cb95c637707a45de9e96c88fc3b41723e3af
+  sha256: 68490b5f2feb7687422f882f54bb2a93c687425b984a69ecd58c9d6d73653139
   requires_dist:
   - array-api-compat>=1.11.2,<2
   requires_python: '>=3.10'
diff --git a/pyproject.toml b/pyproject.toml
index 67651904..9d897cc0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -213,8 +213,8 @@ filterwarnings = ["error"]
 log_cli_level = "INFO"
 testpaths = ["tests"]
 markers = [
-  "skip_xp_backend(library, *, reason=None): Skip test for a specific backend",
-  "xfail_xp_backend(library, *, reason=None): Xfail test for a specific backend",
+  "skip_xp_backend(library, /, *, reason=None): Skip test for a specific backend",
+  "xfail_xp_backend(library, /, *, reason=None, strict=None): Xfail test for a specific backend",
 ]
 
 
diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py
index 319297c8..301a851f 100644
--- a/src/array_api_extra/_lib/_testing.py
+++ b/src/array_api_extra/_lib/_testing.py
@@ -195,7 +195,9 @@ def xp_assert_close(
         )
 
 
-def xfail(request: pytest.FixtureRequest, reason: str) -> None:
+def xfail(
+    request: pytest.FixtureRequest, *, reason: str, strict: bool | None = None
+) -> None:
     """
     XFAIL the currently running test.
 
@@ -209,5 +211,13 @@ def xfail(request: pytest.FixtureRequest, reason: str) -> None:
         ``request`` argument of the test function.
     reason : str
         Reason for the expected failure.
+    strict: bool, optional
+        If True, the test will be marked as failed if it passes.
+        If False, the test will be marked as passed if it fails.
+        Default: ``xfail_strict`` value in ``pyproject.toml``, or False if absent.
     """
-    request.node.add_marker(pytest.mark.xfail(reason=reason))
+    if strict is not None:
+        marker = pytest.mark.xfail(reason=reason, strict=strict)
+    else:
+        marker = pytest.mark.xfail(reason=reason)
+    request.node.add_marker(marker)
diff --git a/tests/conftest.py b/tests/conftest.py
index 410a87ff..5676cc0d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,7 +1,6 @@
 """Pytest fixtures."""
 
 from collections.abc import Callable, Generator
-from contextlib import suppress
 from functools import partial, wraps
 from types import ModuleType
 from typing import ParamSpec, TypeVar, cast
@@ -34,20 +33,29 @@ def library(request: pytest.FixtureRequest) -> Backend:  # numpydoc ignore=PR01,
     """
     elem = cast(Backend, request.param)
 
-    for marker_name, skip_or_xfail in (
-        ("skip_xp_backend", pytest.skip),
-        ("xfail_xp_backend", partial(xfail, request)),
+    for marker_name, skip_or_xfail, allow_kwargs in (
+        ("skip_xp_backend", pytest.skip, {"reason"}),
+        ("xfail_xp_backend", partial(xfail, request), {"reason", "strict"}),
     ):
         for marker in request.node.iter_markers(marker_name):
-            library = marker.kwargs.get("library") or marker.args[0]  # type: ignore[no-untyped-usage]
-            if not isinstance(library, Backend):
-                msg = f"argument of {marker_name} must be a Backend enum"
+            if len(marker.args) != 1:  # pyright: ignore[reportUnknownArgumentType]
+                msg = f"Expected exactly one positional argument; got {marker.args}"
                 raise TypeError(msg)
+            if not isinstance(marker.args[0], Backend):
+                msg = f"Argument of {marker_name} must be a Backend enum"
+                raise TypeError(msg)
+            if invalid_kwargs := set(marker.kwargs) - allow_kwargs:  # pyright: ignore[reportUnknownArgumentType]
+                msg = f"Unexpected kwarg(s): {invalid_kwargs}"
+                raise TypeError(msg)
+
+            library: Backend = marker.args[0]
+            reason: str | None = marker.kwargs.get("reason", None)
+            strict: bool | None = marker.kwargs.get("strict", None)
+
             if library == elem:
-                reason = str(library)
-                with suppress(KeyError):
-                    reason += ":" + cast(str, marker.kwargs["reason"])
-                skip_or_xfail(reason=reason)
+                reason = f"{library}: {reason}" if reason else str(library)  # pyright: ignore[reportUnknownArgumentType]
+                kwargs = {"strict": strict} if strict is not None else {}
+                skip_or_xfail(reason=reason, **kwargs)  # pyright: ignore[reportUnknownArgumentType]
 
     return elem
 
diff --git a/tests/test_at.py b/tests/test_at.py
index 4ccf584e..fa9bcdc8 100644
--- a/tests/test_at.py
+++ b/tests/test_at.py
@@ -115,11 +115,15 @@ def assert_copy(
         pytest.param(
             *(True, 1, 1),
             marks=(
-                pytest.mark.skip_xp_backend(  # test passes when copy=False
-                    Backend.JAX, reason="bool mask update with shaped rhs"
+                pytest.mark.xfail_xp_backend(
+                    Backend.JAX,
+                    reason="bool mask update with shaped rhs",
+                    strict=False,  # test passes when copy=False
                 ),
-                pytest.mark.skip_xp_backend(  # test passes when copy=False
-                    Backend.JAX_GPU, reason="bool mask update with shaped rhs"
+                pytest.mark.xfail_xp_backend(
+                    Backend.JAX_GPU,
+                    reason="bool mask update with shaped rhs",
+                    strict=False,  # test passes when copy=False
                 ),
                 pytest.mark.xfail_xp_backend(
                     Backend.DASK, reason="bool mask update with shaped rhs"
diff --git a/tests/test_funcs.py b/tests/test_funcs.py
index 4e40f09b..652e12ef 100644
--- a/tests/test_funcs.py
+++ b/tests/test_funcs.py
@@ -196,7 +196,7 @@ def test_device(self, xp: ModuleType, device: Device):
         y = apply_where(x % 2 == 0, x, self.f1, fill_value=x)
         assert get_device(y) == device
 
-    @pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no isdtype")
+    @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
     @pytest.mark.filterwarnings("ignore::RuntimeWarning")  # overflows, etc.
     @hypothesis.settings(
         # The xp and library fixtures are not regenerated between hypothesis iterations
diff --git a/tests/test_helpers.py b/tests/test_helpers.py
index ebd4811f..a104e93c 100644
--- a/tests/test_helpers.py
+++ b/tests/test_helpers.py
@@ -27,7 +27,7 @@
 lazy_xp_function(in1d, jax_jit=False, static_argnames=("assume_unique", "invert", "xp"))
 
 
-@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no unique_inverse")
+@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no unique_inverse")
 @pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="no unique_inverse")
 class TestIn1D:
     # cover both code paths