diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 684ccb68a3..be2444921d 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -115,6 +115,7 @@ jobs:
 
           - |
             tests/backends/test_mcbackend.py
+            tests/backends/test_zarr.py
             tests/distributions/test_truncated.py
             tests/logprob/test_abstract.py
             tests/logprob/test_basic.py
@@ -240,6 +241,7 @@ jobs:
 
           - |
             tests/backends/test_arviz.py
+            tests/backends/test_zarr.py
             tests/variational/test_updates.py
       fail-fast: false
     runs-on: ${{ matrix.os }}
diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml
index 71b6c78ed4..de0572e0a2 100644
--- a/conda-envs/environment-dev.yml
+++ b/conda-envs/environment-dev.yml
@@ -19,6 +19,7 @@ dependencies:
 - scipy>=1.4.1
 - typing-extensions>=3.7.4
 - threadpoolctl>=3.1.0
+- zarr>=2.5.0,<3
 # Extra dependencies for dev, testing and docs build
 - ipython>=7.16
 - jax
diff --git a/conda-envs/environment-docs.yml b/conda-envs/environment-docs.yml
index f795fca078..c399a3e24a 100644
--- a/conda-envs/environment-docs.yml
+++ b/conda-envs/environment-docs.yml
@@ -17,6 +17,7 @@ dependencies:
 - scipy>=1.4.1
 - typing-extensions>=3.7.4
 - threadpoolctl>=3.1.0
+- zarr>=2.5.0,<3
 # Extra dependencies for docs build
 - ipython>=7.16
 - jax
diff --git a/conda-envs/environment-jax.yml b/conda-envs/environment-jax.yml
index 48649a617d..39deb8a41a 100644
--- a/conda-envs/environment-jax.yml
+++ b/conda-envs/environment-jax.yml
@@ -10,6 +10,7 @@ dependencies:
 - cachetools>=4.2.1
 - cloudpickle
 - h5py>=2.7
+- zarr>=2.5.0,<3
 # Jaxlib version must not be greater than jax version!
 - blackjax>=1.2.2
 - jax>=0.4.28
diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml
index e6fe9857e0..79c57a44c6 100644
--- a/conda-envs/environment-test.yml
+++ b/conda-envs/environment-test.yml
@@ -21,6 +21,7 @@ dependencies:
 - scipy>=1.4.1
 - typing-extensions>=3.7.4
 - threadpoolctl>=3.1.0
+- zarr>=2.5.0,<3
 # Extra dependencies for testing
 - ipython>=7.16
 - pre-commit>=2.8.0
diff --git a/conda-envs/windows-environment-dev.yml b/conda-envs/windows-environment-dev.yml
index ee5bd206f4..bbcba9149f 100644
--- a/conda-envs/windows-environment-dev.yml
+++ b/conda-envs/windows-environment-dev.yml
@@ -20,6 +20,7 @@ dependencies:
 - scipy>=1.4.1
 - typing-extensions>=3.7.4
 - threadpoolctl>=3.1.0
+- zarr>=2.5.0,<3
 # Extra dependencies for dev, testing and docs build
 - ipython>=7.16
 - myst-nb<=1.0.0
diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml
index fa59852830..399fab811b 100644
--- a/conda-envs/windows-environment-test.yml
+++ b/conda-envs/windows-environment-test.yml
@@ -23,6 +23,7 @@ dependencies:
 - scipy>=1.4.1
 - typing-extensions>=3.7.4
 - threadpoolctl>=3.1.0
+- zarr>=2.5.0,<3
 # Extra dependencies for testing
 - ipython>=7.16
 - pre-commit>=2.8.0
diff --git a/docs/source/api/backends.rst b/docs/source/api/backends.rst
index ca00a56d81..8f0c76f453 100644
--- a/docs/source/api/backends.rst
+++ b/docs/source/api/backends.rst
@@ -20,3 +20,5 @@ Internal structures
    NDArray
    base.BaseTrace
    base.MultiTrace
+   zarr.ZarrTrace
+   zarr.ZarrChain
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 74ac0d9746..b9afc12e73 100755
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -309,6 +309,7 @@
     "python": ("https://docs.python.org/3/", None),
     "scipy": ("https://docs.scipy.org/doc/scipy/", None),
     "xarray": ("https://docs.xarray.dev/en/stable/", None),
+    "zarr": ("https://zarr.readthedocs.io/en/stable/", None),
 }
 
 
diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py
index cd007cf3c0..8bcba42301 100644
--- a/pymc/backends/__init__.py
+++ b/pymc/backends/__init__.py
@@ -72,9 +72,11 @@
 from pymc.backends.arviz import predictions_to_inference_data, to_inference_data
 from pymc.backends.base import BaseTrace, IBaseTrace
 from pymc.backends.ndarray import NDArray
+from pymc.backends.zarr import ZarrTrace
 from pymc.blocking import PointType
 from pymc.model import Model
 from pymc.step_methods.compound import BlockedStep, CompoundStep
+from pymc.util import get_random_generator
 
 HAS_MCB = False
 try:
@@ -102,11 +104,13 @@ def _init_trace(
     model: Model,
     trace_vars: list[TensorVariable] | None = None,
     initial_point: PointType | None = None,
+    rng: np.random.Generator | None = None,
 ) -> BaseTrace:
     """Initialize a trace backend for a chain."""
+    rng_ = get_random_generator(rng)
     strace: BaseTrace
     if trace is None:
-        strace = NDArray(model=model, vars=trace_vars, test_point=initial_point)
+        strace = NDArray(model=model, vars=trace_vars, test_point=initial_point, rng=rng_)
     elif isinstance(trace, BaseTrace):
         if len(trace) > 0:
             raise ValueError("Continuation of traces is no longer supported.")
@@ -120,15 +124,29 @@ def _init_trace(
 
 def init_traces(
     *,
-    backend: TraceOrBackend | None,
+    backend: TraceOrBackend | ZarrTrace | None,
     chains: int,
     expected_length: int,
     step: BlockedStep | CompoundStep,
     initial_point: PointType,
     model: Model,
     trace_vars: list[TensorVariable] | None = None,
+    tune: int = 0,
+    rng: np.random.Generator | None = None,
 ) -> tuple[RunType | None, Sequence[IBaseTrace]]:
     """Initialize a trace recorder for each chain."""
+    if isinstance(backend, ZarrTrace):
+        backend.init_trace(
+            chains=chains,
+            draws=expected_length - tune,
+            tune=tune,
+            step=step,
+            model=model,
+            vars=trace_vars,
+            test_point=initial_point,
+            rng=rng,
+        )
+        return None, backend.straces
     if HAS_MCB and isinstance(backend, Backend):
         return init_chain_adapters(
             backend=backend,
@@ -136,6 +154,7 @@ def init_traces(
             initial_point=initial_point,
             step=step,
             model=model,
+            rng=rng,
         )
 
     assert backend is None or isinstance(backend, BaseTrace)
@@ -148,7 +167,8 @@ def init_traces(
             model=model,
             trace_vars=trace_vars,
             initial_point=initial_point,
+            rng=rng_,
         )
-        for chain_number in range(chains)
+        for chain_number, rng_ in enumerate(get_random_generator(rng).spawn(chains))
     ]
     return None, traces
diff --git a/pymc/backends/base.py b/pymc/backends/base.py
index 5a2a043a39..1188efbfaf 100644
--- a/pymc/backends/base.py
+++ b/pymc/backends/base.py
@@ -34,7 +34,7 @@
 
 from pymc.backends.report import SamplerReport
 from pymc.model import modelcontext
-from pymc.pytensorf import compile
+from pymc.pytensorf import compile, copy_function_with_new_rngs
 from pymc.util import get_var_name
 
 logger = logging.getLogger(__name__)
@@ -159,6 +159,7 @@ def __init__(
         fn=None,
         var_shapes=None,
         var_dtypes=None,
+        rng=None,
     ):
         model = modelcontext(model)
 
@@ -177,6 +178,8 @@ def __init__(
                 on_unused_input="ignore",
             )
             fn.trust_input = True
+        if rng is not None:
+            fn = copy_function_with_new_rngs(fn=fn, rng=rng)
 
         # Get variable shapes. Most backends will need this
         # information.
diff --git a/pymc/backends/mcbackend.py b/pymc/backends/mcbackend.py
index 3d2c8fd9e7..b6342a2182 100644
--- a/pymc/backends/mcbackend.py
+++ b/pymc/backends/mcbackend.py
@@ -29,7 +29,7 @@
 
 from pymc.backends.base import IBaseTrace
 from pymc.model import Model
-from pymc.pytensorf import PointFunc
+from pymc.pytensorf import PointFunc, copy_function_with_new_rngs
 from pymc.step_methods.compound import (
     BlockedStep,
     CompoundStep,
@@ -38,6 +38,7 @@
     flat_statname,
     flatten_steps,
 )
+from pymc.util import get_random_generator
 
 _log = logging.getLogger(__name__)
 
@@ -96,7 +97,11 @@ class ChainRecordAdapter(IBaseTrace):
     """Wraps an McBackend ``Chain`` as an ``IBaseTrace``."""
 
     def __init__(
-        self, chain: mcb.Chain, point_fn: PointFunc, stats_bijection: StatsBijection
+        self,
+        chain: mcb.Chain,
+        point_fn: PointFunc,
+        stats_bijection: StatsBijection,
+        rng: np.random.Generator | None = None,
     ) -> None:
         # Assign attributes required by IBaseTrace
         self.chain = chain.cmeta.chain_number
@@ -107,8 +112,11 @@ def __init__(
             for sstats in stats_bijection._stat_groups
         ]
 
+        self._rng = rng
         self._chain = chain
         self._point_fn = point_fn
+        if rng is not None:
+            self._point_fn = copy_function_with_new_rngs(self._point_fn, rng)
         self._statsbj = stats_bijection
         super().__init__()
 
@@ -257,6 +265,7 @@ def init_chain_adapters(
     initial_point: Mapping[str, np.ndarray],
     step: CompoundStep | BlockedStep,
     model: Model,
+    rng: np.random.Generator | None,
 ) -> tuple[mcb.Run, list[ChainRecordAdapter]]:
     """Create an McBackend metadata description for the MCMC run.
 
@@ -286,7 +295,8 @@ def init_chain_adapters(
             chain=run.init_chain(chain_number=chain_number),
             point_fn=point_fn,
             stats_bijection=statsbj,
+            rng=rng_,
         )
-        for chain_number in range(chains)
+        for chain_number, rng_ in enumerate(get_random_generator(rng).spawn(chains))
     ]
     return run, adapters
diff --git a/pymc/backends/zarr.py b/pymc/backends/zarr.py
new file mode 100644
index 0000000000..2fb7134303
--- /dev/null
+++ b/pymc/backends/zarr.py
@@ -0,0 +1,867 @@
+#   Copyright 2024 The PyMC Developers
+#
+#   Licensed under the Apache License, Version 2.0 (the "License");
+#   you may not use this file except in compliance with the License.
+#   You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+#   Unless required by applicable law or agreed to in writing, software
+#   distributed under the License is distributed on an "AS IS" BASIS,
+#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#   See the License for the specific language governing permissions and
+#   limitations under the License.
+from collections.abc import Callable, Mapping, MutableMapping, Sequence
+from typing import Any
+
+import arviz as az
+import numcodecs
+import numpy as np
+import xarray as xr
+import zarr
+
+from arviz.data.base import make_attrs
+from arviz.data.inference_data import WARMUP_TAG
+from numcodecs.abc import Codec
+from pytensor.tensor.variable import TensorVariable
+
+import pymc
+
+from pymc.backends.arviz import (
+    coords_and_dims_for_inferencedata,
+    find_constants,
+    find_observations,
+)
+from pymc.backends.base import BaseTrace
+from pymc.blocking import StatDtype, StatShape
+from pymc.model.core import Model, modelcontext
+from pymc.pytensorf import copy_function_with_new_rngs
+from pymc.step_methods.compound import (
+    BlockedStep,
+    CompoundStep,
+    StatsBijection,
+    get_stats_dtypes_shapes_from_steps,
+)
+from pymc.util import (
+    UNSET,
+    _UnsetType,
+    get_default_varnames,
+    get_random_generator,
+    is_transformed_name,
+)
+
+try:
+    from zarr.storage import BaseStore, default_compressor
+    from zarr.sync import Synchronizer
+
+    _zarr_available = True
+except ImportError:
+    _zarr_available = False
+
+
+class ZarrChain(BaseTrace):
+    """Interface object to interact with a single chain in a :class:`~.ZarrTrace`.
+
+    Parameters
+    ----------
+    store : zarr.storage.BaseStore | collections.abc.MutableMapping
+        The store object where the zarr groups and arrays will be stored and read from.
+        This store must exist before creating a ``ZarrChain`` object. ``ZarrChain``
+        objects are only intended to be used as interfaces to the individual chains of
+        :class:`~.ZarrTrace` objects. This means that the :class:`~.ZarrTrace` should
+        be the one that creates the store that is then provided to a ``ZarrChain``.
+    stats_bijection : pymc.step_methods.compound.StatsBijection
+        An object that maps between a list of step method stats and a dictionary of
+        said stats with the accompanying stepper index.
+    synchronizer : zarr.sync.Synchronizer | None
+        The synchronizer to use for the underlying zarr arrays.
+    model : Model
+        If None, the model is taken from the `with` context.
+    vars : Sequence[TensorVariable] | None
+        Sampling values will be stored for these variables. If None,
+        `model.unobserved_RVs` is used.
+    test_point : dict[str, numpy.ndarray] | None
+        This is not used and is inherited from the signature of :class:`~.BaseTrace`,
+        which uses it to determine the shape and dtype of `vars`.
+    draws_per_chunk : int
+        The number of draws that make up a chunk in the variable's posterior array.
+        The interface only writes the samples to the store once a chunk is completely
+        filled.
+    """
+
+    def __init__(
+        self,
+        store: BaseStore | MutableMapping,
+        stats_bijection: StatsBijection,
+        synchronizer: Synchronizer | None = None,
+        model: Model | None = None,
+        vars: Sequence[TensorVariable] | None = None,
+        test_point: dict[str, np.ndarray] | None = None,
+        draws_per_chunk: int = 1,
+        fn: Callable | None = None,
+    ):
+        if not _zarr_available:
+            raise RuntimeError("You must install zarr to be able to create ZarrChain instances")
+        super().__init__(name="zarr", model=model, vars=vars, test_point=test_point, fn=fn)
+        self._step_method: BlockedStep | CompoundStep | None = None
+        self.unconstrained_variables = {
+            var.name for var in self.vars if is_transformed_name(var.name)
+        }
+        self.draw_idx = 0
+        self._buffers: dict[str, dict[str, list]] = {
+            "posterior": {},
+            "sample_stats": {},
+        }
+        self._buffered_draws = 0
+        self.draws_per_chunk = int(draws_per_chunk)
+        assert self.draws_per_chunk > 0
+        self._posterior = zarr.open_group(
+            store, synchronizer=synchronizer, path="posterior", mode="a"
+        )
+        if self.unconstrained_variables:
+            self._unconstrained_posterior = zarr.open_group(
+                store, synchronizer=synchronizer, path="unconstrained_posterior", mode="a"
+            )
+            self._buffers["unconstrained_posterior"] = {}
+        self._sample_stats = zarr.open_group(
+            store, synchronizer=synchronizer, path="sample_stats", mode="a"
+        )
+        self._sampling_state = zarr.open_group(
+            store, synchronizer=synchronizer, path="_sampling_state", mode="a"
+        )
+        self.stats_bijection = stats_bijection
+
+    def link_stepper(self, step_method: BlockedStep | CompoundStep):
+        """Provide a reference to the step method used during sampling.
+
+        This reference can be used to facilitate writing the stepper's sampling state
+        each time the samples are flushed into the storage.
+        """
+        self._step_method = step_method
+
+    def setup(self, draws: int, chain: int, sampler_vars: Sequence[dict] | None):  # type: ignore[override]
+        self.chain = chain
+        self.total_draws = draws
+        self.draws_until_flush = min([self.draws_per_chunk, draws - self.draw_idx])
+        self.clear_buffers()
+
+    def clear_buffers(self):
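+        """Clear the internal buffers and reset the buffered draw counter."""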
+        for group in self._buffers:
+            self._buffers[group] = {}
+        self._buffered_draws = 0
+
+    def buffer(self, group, var_name, value):
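+        """Append a variable's value to the chosen group's internal buffer."""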
+        buffer = self._buffers[group]
+        if var_name not in buffer:
+            buffer[var_name] = []
+        buffer[var_name].append(value)
+
+    def record(
+        self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]]
+    ) -> bool | None:
+        """Record the step method's returned draw and stats.
+
+        The draws and stats are first stored in an internal buffer. Once the buffer is
+        filled, the samples and stats are written (flushed) onto the desired zarr store.
+
+        Returns
+        -------
+        flushed : bool | None
+            Returns ``True`` only if the data was written onto the desired zarr store.
+            Any other time that the recorded draw and stats are written into the
+            internal buffer, ``None`` is returned.
+
+        See Also
+        --------
+        :meth:`~ZarrChain.flush`
+        """
+        unconstrained_variables = self.unconstrained_variables
+        for var_name, var_value in zip(self.varnames, self.fn(**draw)):
+            if var_name in unconstrained_variables:
+                self.buffer(group="unconstrained_posterior", var_name=var_name, value=var_value)
+            else:
+                self.buffer(group="posterior", var_name=var_name, value=var_value)
+        for var_name, var_value in self.stats_bijection.map(stats).items():
+            self.buffer(group="sample_stats", var_name=var_name, value=var_value)
+        self._buffered_draws += 1
+        if self._buffered_draws == self.draws_until_flush:
+            self.flush()
+            return True
+        return None
+
+    def record_sampling_state(self, step: BlockedStep | CompoundStep | None = None):
+        """Record the sampling state information to the store's ``_sampling_state`` group.
+
+        The sampling state includes the number of draws taken so far (``draw_idx``) and
+        the step method's ``sampling_state``.
+
+        Parameters
+        ----------
+        step : BlockedStep | CompoundStep | None
+            The step method from which to take the ``sampling_state``. If ``None``,
+            the ``step`` is taken to be the step method that was linked to the
+            ``ZarrChain`` when calling :meth:`~ZarrChain.link_stepper`. If this method was never
+            called, no step method ``sampling_state`` information is stored in the
+            chain.
+        """
+        if step is None:
+            step = self._step_method
+        if step is not None:
+            self.store_sampling_state(step.sampling_state)
+        self._sampling_state.draw_idx.set_coordinate_selection(self.chain, self.draw_idx)
+
+    def store_sampling_state(self, sampling_state):
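+        """Write the step method's ``sampling_state`` to this chain's slot in the store."""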
+        self._sampling_state.sampling_state.set_coordinate_selection(
+            self.chain, np.array([sampling_state], dtype="object")
+        )
+
+    def flush(self):
+        """Write the data stored in the internal buffer to the desired zarr store.
+
+        After writing the draws and stats returned by each step of the step method,
+        :meth:`~ZarrChain.record_sampling_state` is called, the internal buffer is cleared, and
+        the number of steps until the next flush is determined.
+        """
+        chain = self.chain
+        draw_slice = slice(self.draw_idx, self.draw_idx + self.draws_until_flush)
+        for group_name, buffer in self._buffers.items():
+            group = getattr(self, f"_{group_name}")
+            for var_name, var_value in buffer.items():
+                group[var_name].set_orthogonal_selection(
+                    (chain, draw_slice),
+                    np.stack(var_value),
+                )
+        self.draw_idx += self.draws_until_flush
+        self.record_sampling_state()
+        self.clear_buffers()
+        self.draws_until_flush = min([self.draws_per_chunk, self.total_draws - self.draw_idx])
+
+
+FILL_VALUE_TYPE = float | int | bool | str | np.datetime64 | np.timedelta64 | None
+DEFAULT_FILL_VALUES: dict[Any, FILL_VALUE_TYPE] = {
+    np.floating: np.nan,
+    np.integer: 0,
+    np.bool_: False,
+    np.str_: "",
+    np.datetime64: np.datetime64(0, "Y"),
+    np.timedelta64: np.timedelta64(0, "Y"),
+}
+
+
+def get_initial_fill_value_and_codec(
+    dtype: Any,
+) -> tuple[FILL_VALUE_TYPE, np.typing.DTypeLike, numcodecs.abc.Codec | None]:
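+    """Get the fill value, dtype and object codec with which to initialize a zarr array of the given dtype."""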
+    _dtype = np.dtype(dtype)
+    fill_value: FILL_VALUE_TYPE = None
+    codec = None
+    try:
+        fill_value = DEFAULT_FILL_VALUES[_dtype]
+    except KeyError:
+        for key in DEFAULT_FILL_VALUES:
+            if np.issubdtype(_dtype, key):
+                fill_value = DEFAULT_FILL_VALUES[key]
+                break
+        else:
+            codec = numcodecs.Pickle()
+    return fill_value, _dtype, codec
+
+
+class ZarrTrace:
+    """Object that stores and enables access to MCMC draws stored in a :class:`zarr.hierarchy.Group` objects.
+
+    This class creates a zarr hierarchy to represent the sampling information, which is
+    intended to mimic :class:`arviz.InferenceData`. The hierarchy looks like this:
+
+    | root
+    | |--> constant_data
+    | |--> observed_data
+    | |--> posterior
+    | |--> unconstrained_posterior
+    | |--> sample_stats
+    | |--> warmup_posterior
+    | |--> warmup_unconstrained_posterior
+    | |--> warmup_sample_stats
+    | |--> _sampling_state
+
+    The root group is created when the ``ZarrTrace`` object is initialized. The rest of
+    the groups are created once :meth:`~ZarrTrace.init_trace` is called, with a few exceptions:
+    ``unconstrained_posterior`` is only created if ``include_transformed = True``, and the
+    groups prefixed with ``warmup_`` are created only after calling
+    :meth:`~ZarrTrace.split_warmup_groups`.
+
+    Since ``ZarrTrace`` objects are intended to be as close to
+    :class:`arviz.InferenceData` objects as possible, the groups store the dimension
+    and coordinate information following the `xarray zarr standard <https://xarray.pydata.org/en/v2023.11.0/internals/zarr-encoding-spec.html>`_.
+
+    Parameters
+    ----------
+    store : zarr.storage.BaseStore | collections.abc.MutableMapping | None
+        The store object where the zarr groups and arrays will be stored and read from.
+        Any zarr compatible storage object works. Keep in mind that if ``None`` is
+        provided, a :class:`zarr.storage.MemoryStore` will be used, which means that
+        information won't be visible to other processes and won't persist after the
+        ``ZarrTrace`` life-cycle ends. If you want to have persistent storage, please
+        use one of the multiple disk backed zarr storage options, e.g.
+        :class:`~zarr.storage.DirectoryStore` or :class:`~zarr.storage.ZipStore`.
+    synchronizer : zarr.sync.Synchronizer | None
+        The synchronizer to use for the underlying zarr arrays.
+    compressor : numcodecs.abc.Codec | None | pymc.util.UNSET
+        The compressor to use for the underlying zarr arrays. If ``None``, no compressor
+        is used. If ``UNSET``, zarr's default compressor is used.
+    draws_per_chunk : int
+        The number of draws that make up a chunk in the variable's posterior array.
+        Each variable's array shape is set to ``(n_chains, n_draws, *rv_shape)``, but
+        the chunks are set to ``(1, draws_per_chunk, *rv_shape)``. This means that each
+        chain will have its own chunk to read from or write to, allowing concurrent
+        write operations from different chains to proceed without interfering with each
+        other, and letting multiple draws belong to the same chunk. A variable's core
+        dimensions, however, are never split across different chunks.
+    include_transformed : bool
+        If ``True``, the transformed, unconstrained value variables are included in the
+        storage group.
+
+    Notes
+    -----
+    ``ZarrTrace`` objects represent the storage information. If the underlying store
+    persists on disk or over the network (e.g. with a :class:`zarr.storage.FSStore`),
+    multiple processes will be able to concurrently access the same storage and read
+    from or write to it.
+
+    The intended division of labour is for ``ZarrTrace`` to handle the creation and
+    management of the zarr group and storage objects and arrays, and for individual
+    :class:`~.ZarrChain` objects to handle recording MCMC samples to the trace. This
+    division was chosen to stay close to the way the existing
+    `pymc.backends.base.MultiTrace` and `pymc.backends.ndarray.NDArray` objects work
+    with the existing samplers.
+
+    One extra feature of ``ZarrTrace`` is that it enables direct access to any array's
+    metadata. ``ZarrTrace`` takes advantage of this to tag arrays as ``deterministic``
+    or ``freeRV`` depending on what kind of variable they were in the defining model.
+
+    See Also
+    --------
+    :class:`~pymc.backends.zarr.ZarrChain`
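+
+    Examples
+    --------
+    A minimal sketch of how a ``ZarrTrace`` can be passed to :func:`pymc.sample`
+    (the model and the default in-memory store are illustrative choices)::
+
+        import pymc as pm
+        from pymc.backends.zarr import ZarrTrace
+
+        with pm.Model() as model:
+            x = pm.Normal("x")
+            trace = ZarrTrace()  # store=None defaults to an in-memory store
+            idata = pm.sample(trace=trace, tune=100, draws=100, cores=1)
+        # The recorded draws remain accessible through the trace itself
+        idata2 = trace.to_inferencedata(save_warmup=True)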
+    """
+
+    def __init__(
+        self,
+        store: BaseStore | MutableMapping | None = None,
+        synchronizer: Synchronizer | None = None,
+        compressor: Codec | None | _UnsetType = UNSET,
+        draws_per_chunk: int = 1,
+        include_transformed: bool = False,
+    ):
+        if not _zarr_available:
+            raise RuntimeError("You must install zarr to be able to create ZarrTrace instances")
+        self.synchronizer = synchronizer
+        if compressor is UNSET:
+            compressor = default_compressor
+        self.compressor = compressor
+        self.root = zarr.group(
+            store=store,
+            overwrite=True,
+            synchronizer=synchronizer,
+        )
+
+        self.draws_per_chunk = int(draws_per_chunk)
+        assert self.draws_per_chunk >= 1
+
+        self.include_transformed = include_transformed
+
+        self._is_base_setup = False
+
+    def groups(self) -> list[str]:
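+        """Return the names of the groups in the root zarr hierarchy."""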
+        return [str(group_name) for group_name, _ in self.root.groups()]
+
+    @property
+    def posterior(self) -> zarr.Group:
+        return self.root.posterior
+
+    @property
+    def unconstrained_posterior(self) -> zarr.Group:
+        return self.root.unconstrained_posterior
+
+    @property
+    def sample_stats(self) -> zarr.Group:
+        return self.root.sample_stats
+
+    @property
+    def constant_data(self) -> zarr.Group:
+        return self.root.constant_data
+
+    @property
+    def observed_data(self) -> zarr.Group:
+        return self.root.observed_data
+
+    @property
+    def _sampling_state(self) -> zarr.Group:
+        return self.root._sampling_state
+
+    def init_trace(
+        self,
+        chains: int,
+        draws: int,
+        tune: int,
+        step: BlockedStep | CompoundStep,
+        model: Model | None = None,
+        vars: Sequence[TensorVariable] | None = None,
+        test_point: dict[str, np.ndarray] | None = None,
+        rng: np.random.Generator | None = None,
+    ):
+        """Initialize the trace groups and arrays.
+
+        This function creates and fills with default values the groups below the
+        ``ZarrTrace.root`` group. It creates the ``constant_data``, ``observed_data``,
+        ``posterior``, ``unconstrained_posterior`` (if ``include_transformed = True``),
+        ``sample_stats``, and ``_sampling_state`` zarr groups, and all of the relevant
+        arrays that must be stored there.
+
+        Every array in the posterior and sample stats groups will have the
+        ``(chains, tune + draws)`` batch dimensions to the left of the core dimensions
+        of the model's random variable or the step method's stat shape. The warmup (tuning
+        draws) and the posterior samples are split at a later stage, once
+        :meth:`~ZarrTrace.split_warmup_groups` is called.
+
+        After the creation of the zarr hierarchy, it initializes the list of
+        :class:`~pymc.backends.zarr.ZarrChain` instances (one for each chain) under the
+        ``straces`` attribute. These objects serve as the interface to record draws and
+        samples generated by the step methods for each chain.
+
+        Parameters
+        ----------
+        chains : int
+            The number of chains to use to initialize the arrays.
+        draws : int
+            The number of posterior draws to use to initialize the arrays.
+        tune : int
+            The number of tuning steps to use to initialize the arrays.
+        step : pymc.step_methods.compound.BlockedStep | pymc.step_methods.compound.CompoundStep
+            The step method that will be used to generate the draws and stats.
+        model : pymc.model.core.Model | None
+            If None, the model is taken from the ``with`` context.
+        vars : Sequence[TensorVariable] | None
+            Sampling values will be stored for these variables. If ``None``,
+            ``model.unobserved_RVs`` is used.
+        test_point : dict[str, numpy.ndarray] | None
+            This is not used and is a product of the inheritance of :class:`ZarrChain`
+            from :class:`~.BaseTrace`, which uses it to determine the shape and dtype
+            of `vars`.
+        rng : numpy.random.Generator | None
+            A random generator to use to seed the shared random generators that are
+            present in the pytensor function that maps samples drawn by step methods
+            onto samples in the posterior trace. Note that this only does anything
+            if there are deterministic variables that are generated by raw pytensor
+            random variables.
+        """
+        if self._is_base_setup:
+            raise RuntimeError("The ZarrTrace has already been initialized")  # pragma: no cover
+        model = modelcontext(model)
+        self.model = model
+        self.coords, self.vars_to_dims = coords_and_dims_for_inferencedata(model)
+        if vars is None:
+            vars = model.unobserved_value_vars
+
+        unnamed_vars = {var for var in vars if var.name is None}
+        assert not unnamed_vars, f"Can't trace unnamed variables: {unnamed_vars}"
+        self.varnames = get_default_varnames(
+            [var.name for var in vars], include_transformed=self.include_transformed
+        )
+        self.vars = [var for var in vars if var.name in self.varnames]
+
+        self.fn = model.compile_fn(
+            self.vars,
+            inputs=model.value_vars,
+            on_unused_input="ignore",
+            point_fn=False,
+        )
+
+        # Get variable shapes. Most backends will need this
+        # information.
+        if test_point is None:
+            test_point = model.initial_point()
+        var_values = list(zip(self.varnames, self.fn(**test_point)))
+        self.var_dtype_shapes = {
+            var: (value.dtype, value.shape)
+            for var, value in var_values
+            if not is_transformed_name(var)
+        }
+        extra_var_attrs = {
+            var: {
+                "kind": "freeRV"
+                if is_transformed_name(var) or model[var] in model.free_RVs
+                else "deterministic"
+            }
+            for var in self.var_dtype_shapes
+        }
+        self.unc_var_dtype_shapes = {
+            var: (value.dtype, value.shape) for var, value in var_values if is_transformed_name(var)
+        }
+        extra_unc_var_attrs = {var: {"kind": "freeRV"} for var in self.unc_var_dtype_shapes}
+
+        self.create_group(
+            name="constant_data",
+            data_dict=find_constants(self.model),
+        )
+
+        self.create_group(
+            name="observed_data",
+            data_dict=find_observations(self.model),
+        )
+
+        # Create the posterior that includes warmup draws
+        self.init_group_with_empty(
+            group=self.root.create_group(name="posterior", overwrite=True),
+            var_dtype_and_shape=self.var_dtype_shapes,
+            chains=chains,
+            draws=tune + draws,
+            extra_var_attrs=extra_var_attrs,
+        )
+
+        # Create the unconstrained posterior group that includes warmup draws
+        if self.include_transformed and self.unc_var_dtype_shapes:
+            self.init_group_with_empty(
+                group=self.root.create_group(name="unconstrained_posterior", overwrite=True),
+                var_dtype_and_shape=self.unc_var_dtype_shapes,
+                chains=chains,
+                draws=tune + draws,
+                extra_var_attrs=extra_unc_var_attrs,
+            )
+
+        # Create the sample stats that include warmup draws
+        stats_dtypes_shapes = get_stats_dtypes_shapes_from_steps(
+            [step] if isinstance(step, BlockedStep) else step.methods
+        )
+        self.init_group_with_empty(
+            group=self.root.create_group(name="sample_stats", overwrite=True),
+            var_dtype_and_shape=stats_dtypes_shapes,
+            chains=chains,
+            draws=tune + draws,
+        )
+
+        self.init_sampling_state_group(tune=tune, chains=chains)
+
+        self.straces = [
+            ZarrChain(
+                store=self.root.store,
+                synchronizer=self.synchronizer,
+                model=self.model,
+                vars=self.vars,
+                test_point=test_point,
+                stats_bijection=StatsBijection(step.stats_dtypes),
+                draws_per_chunk=self.draws_per_chunk,
+                fn=copy_function_with_new_rngs(self.fn, rng_),
+            )
+            for rng_ in get_random_generator(rng).spawn(chains)
+        ]
+        for chain, strace in enumerate(self.straces):
+            strace.setup(draws=tune + draws, chain=chain, sampler_vars=None)
+
+    def split_warmup_groups(self):
+        """Split the warmup and standard groups.
+
+        This method takes the entries in the arrays in the posterior, sample_stats
+        and unconstrained_posterior that happened in the tuning phase and moves them
+        into the warmup_ groups. If the ``warmup_posterior`` group already exists, then
+        nothing is done.
+
+        See Also
+        --------
+        :meth:`~ZarrTrace.split_warmup`
+        """
+        if "warmup_posterior" not in self.groups():
+            self.split_warmup("posterior", error_if_already_split=False)
+            self.split_warmup("sample_stats", error_if_already_split=False)
+            try:
+                self.split_warmup("unconstrained_posterior", error_if_already_split=False)
+            except KeyError:
+                pass
+
+    @property
+    def tuning_steps(self):
+        try:
+            return int(self._sampling_state.tuning_steps.get_basic_selection())
+        except AttributeError:  # pragma: no cover
+            raise ValueError(
+                "ZarrTrace has not been initialized and there is no tuning step information available"
+            )
+
+    @property
+    def sampling_time(self):
+        try:
+            return float(self._sampling_state.sampling_time.get_basic_selection())
+        except AttributeError:  # pragma: no cover
+            raise ValueError(
+                "ZarrTrace has not been initialized and there is no sampling time information available"
+            )
+
+    @sampling_time.setter
+    def sampling_time(self, value):
+        self._sampling_state.sampling_time.set_basic_selection((), float(value))
+
+    def init_sampling_state_group(self, tune: int, chains: int):
+        state = self.root.create_group(name="_sampling_state", overwrite=True)
+        sampling_state = state.empty(
+            name="sampling_state",
+            overwrite=True,
+            shape=(chains,),
+            chunks=(1,),
+            dtype="object",
+            object_codec=numcodecs.Pickle(),
+            compressor=self.compressor,
+        )
+        sampling_state.attrs.update({"_ARRAY_DIMENSIONS": ["chain"]})
+        draw_idx = state.array(
+            name="draw_idx",
+            overwrite=True,
+            data=np.zeros(chains, dtype="int"),
+            chunks=(1,),
+            dtype="int",
+            fill_value=-1,
+            compressor=self.compressor,
+        )
+        draw_idx.attrs.update({"_ARRAY_DIMENSIONS": ["chain"]})
+
+        state.array(
+            name="tuning_steps",
+            data=tune,
+            overwrite=True,
+            dtype="int",
+            fill_value=0,
+            compressor=self.compressor,
+        )
+        state.array(
+            name="sampling_time",
+            data=0.0,
+            dtype="float",
+            fill_value=0.0,
+            compressor=self.compressor,
+        )
+        state.array(
+            name="sampling_start_time",
+            data=0.0,
+            dtype="float",
+            fill_value=0.0,
+            compressor=self.compressor,
+        )
+
+        chain = state.array(
+            name="chain",
+            data=np.arange(chains),
+            compressor=self.compressor,
+        )
+
+        chain.attrs.update({"_ARRAY_DIMENSIONS": ["chain"]})
+
+        state.empty(
+            name="global_warnings",
+            dtype="object",
+            object_codec=numcodecs.Pickle(),
+            shape=(0,),
+        )
+
+    def init_group_with_empty(
+        self,
+        group: zarr.Group,
+        var_dtype_and_shape: dict[str, tuple[StatDtype, StatShape]],
+        chains: int,
+        draws: int,
+        extra_var_attrs: dict | None = None,
+    ) -> zarr.Group:
+        group_coords: dict[str, Any] = {"chain": range(chains), "draw": range(draws)}
+        for name, (_dtype, shape) in var_dtype_and_shape.items():
+            fill_value, dtype, object_codec = get_initial_fill_value_and_codec(_dtype)
+            shape = shape or ()
+            array = group.full(
+                name=name,
+                dtype=dtype,
+                fill_value=fill_value,
+                object_codec=object_codec,
+                shape=(chains, draws, *shape),
+                chunks=(1, self.draws_per_chunk, *shape),
+                compressor=self.compressor,
+            )
+            try:
+                dims = self.vars_to_dims[name]
+                for dim in dims:
+                    group_coords[dim] = self.coords[dim]
+            except KeyError:
+                dims = []
+                for i, shape_i in enumerate(shape):
+                    dim = f"{name}_dim_{i}"
+                    dims.append(dim)
+                    group_coords[dim] = np.arange(shape_i, dtype="int")
+            dims = ("chain", "draw", *dims)
+            attrs = extra_var_attrs[name] if extra_var_attrs is not None else {}
+            attrs.update({"_ARRAY_DIMENSIONS": dims})
+            array.attrs.update(attrs)
+        for dim, coord in group_coords.items():
+            array = group.array(
+                name=dim,
+                data=coord,
+                fill_value=None,
+                compressor=self.compressor,
+            )
+            array.attrs.update({"_ARRAY_DIMENSIONS": [dim]})
+        return group
+
+    def create_group(self, name: str, data_dict: dict[str, np.ndarray]) -> zarr.Group | None:
+        group: zarr.Group | None = None
+        if data_dict:
+            group_coords = {}
+            group = self.root.create_group(name=name, overwrite=True)
+            for var_name, var_value in data_dict.items():
+                fill_value, dtype, object_codec = get_initial_fill_value_and_codec(var_value.dtype)
+                array = group.array(
+                    name=var_name,
+                    data=var_value,
+                    fill_value=fill_value,
+                    dtype=dtype,
+                    object_codec=object_codec,
+                    compressor=self.compressor,
+                )
+                try:
+                    dims = self.vars_to_dims[var_name]
+                    for dim in dims:
+                        group_coords[dim] = self.coords[dim]
+                except KeyError:
+                    dims = []
+                    for i in range(var_value.ndim):
+                        dim = f"{var_name}_dim_{i}"
+                        dims.append(dim)
+                        group_coords[dim] = np.arange(var_value.shape[i], dtype="int")
+                array.attrs.update({"_ARRAY_DIMENSIONS": dims})
+            for dim, coord in group_coords.items():
+                array = group.array(
+                    name=dim,
+                    data=coord,
+                    fill_value=None,
+                    compressor=self.compressor,
+                )
+                array.attrs.update({"_ARRAY_DIMENSIONS": [dim]})
+        return group
+
+    def split_warmup(self, group_name: str, error_if_already_split: bool = True):
+        """Split the arrays of a group into the warmup and regular groups.
+
+        This function takes the first ``self.tuning_steps`` draws of the supplied
+        ``group_name`` and moves them into a new zarr group called
+        ``f"warmup_{group_name}"``.
+
+        Parameters
+        ----------
+        group_name : str
+            The name of the group that should be split.
+        error_if_already_split : bool
+            If ``True`` and the ``f"warmup_{group_name}"`` group already exists in
+            the root hierarchy, a ``RuntimeError`` is raised. If this flag is ``False``
+            but the warmup group already exists, the contents of that group are
+            overwritten.
+        """
+        if error_if_already_split and f"{WARMUP_TAG}{group_name}" in self.groups():
+            raise RuntimeError(f"Warmup data for {group_name} has already been split")
+        posterior_group = self.root[group_name]
+        tune = self.tuning_steps
+        warmup_group = self.root.create_group(f"{WARMUP_TAG}{group_name}", overwrite=True)
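+        # If no tuning steps were taken there is nothing to split: drop the freshly
+        # created (empty) warmup group and leave the original group untouched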
+        if tune == 0:
+            try:
+                self.root.pop(f"{WARMUP_TAG}{group_name}")
+            except KeyError:
+                pass
+            return
+        for name, array in posterior_group.arrays():
+            array_attrs = array.attrs.asdict()
+            if name == "draw":
+                warmup_array = warmup_group.array(
+                    name="draw",
+                    data=np.arange(tune),
+                    dtype="int",
+                    compressor=self.compressor,
+                )
+                posterior_array = posterior_group.array(
+                    name=name,
+                    data=np.arange(len(array) - tune),
+                    dtype="int",
+                    overwrite=True,
+                    compressor=self.compressor,
+                )
+                posterior_array.attrs.update(array_attrs)
+            else:
+                dims = array.attrs["_ARRAY_DIMENSIONS"]
+                warmup_idx: slice | tuple[slice, slice]
+                if len(dims) >= 2 and dims[:2] == ["chain", "draw"]:
+                    must_overwrite_posterior = True
+                    warmup_idx = (slice(None), slice(None, tune, None))
+                    posterior_idx = (slice(None), slice(tune, None, None))
+                else:
+                    must_overwrite_posterior = False
+                    warmup_idx = slice(None)
+                fill_value, dtype, object_codec = get_initial_fill_value_and_codec(array.dtype)
+                warmup_array = warmup_group.array(
+                    name=name,
+                    data=array[warmup_idx],
+                    chunks=array.chunks,
+                    dtype=dtype,
+                    fill_value=fill_value,
+                    object_codec=object_codec,
+                    compressor=self.compressor,
+                )
+                if must_overwrite_posterior:
+                    posterior_array = posterior_group.array(
+                        name=name,
+                        data=array[posterior_idx],
+                        chunks=array.chunks,
+                        dtype=dtype,
+                        fill_value=fill_value,
+                        object_codec=object_codec,
+                        overwrite=True,
+                        compressor=self.compressor,
+                    )
+                    posterior_array.attrs.update(array_attrs)
+            warmup_array.attrs.update(array_attrs)
+
+    def to_inferencedata(self, save_warmup: bool = False) -> az.InferenceData:
+        """Convert ``ZarrTrace`` to :class:`~.arviz.InferenceData`.
+
+        This converts all the groups in the ``ZarrTrace.root`` hierarchy into an
+        ``InferenceData`` object. The only exception is that ``_sampling_state`` is
+        excluded.
+
+        Parameters
+        ----------
+        save_warmup : bool
+            If ``True``, all of the warmup groups are stored in the inference data
+            object.
+
+        Notes
+        -----
+        ``xarray`` and in turn ``arviz`` require the zarr groups to have consolidated
+        metadata. To achieve this, a new consolidated store is constructed by calling
+        :func:`zarr.consolidate_metadata` on the root's store. This means that the
+        returned ``InferenceData`` object will operate on a different storage unit
+        than the calling ``ZarrTrace``, so future changes to the ``ZarrTrace`` won't be
+        automatically reflected in the returned ``InferenceData`` object.
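+
+        Examples
+        --------
+        A minimal sketch, assuming ``trace`` is a ``ZarrTrace`` on which sampling has
+        already finished::
+
+            idata = trace.to_inferencedata(save_warmup=True)
+            idata.to_netcdf("results.nc")  # hypothetical output path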
+        """
+        self.split_warmup_groups()
+        # Xarray complains if we try to open a zarr hierarchy that doesn't have consolidated metadata
+        consolidated_root = zarr.consolidate_metadata(self.root.store)
+        # The ConsolidatedMetadataStore looks like an empty store from xarray's point of
+        # view, so we need to grab the underlying store to keep xarray from producing
+        # completely empty arrays
+        store = consolidated_root.store.store
+        groups = {}
+        try:
+            global_attrs = {
+                "tuning_steps": self.tuning_steps,
+                "sampling_time": self.sampling_time,
+            }
+        except AttributeError:
+            global_attrs = {}  # pragma: no cover
+        for name, _ in self.root.groups():
+            if name.startswith("_") or (not save_warmup and name.startswith(WARMUP_TAG)):
+                continue
+            data = xr.open_zarr(store, group=name, mask_and_scale=False)
+            attrs = {**data.attrs, **global_attrs}
+            data.attrs = make_attrs(attrs=attrs, library=pymc)
+            groups[name] = data.load() if az.rcParams["data.load"] == "eager" else data
+        return az.InferenceData(**groups)
diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py
index f665d5931c..7e07cfa5b5 100644
--- a/pymc/pytensorf.py
+++ b/pymc/pytensorf.py
@@ -14,7 +14,7 @@
 import warnings
 
 from collections.abc import Callable, Generator, Iterable, Sequence
-from typing import cast
+from typing import cast, overload
 
 import numpy as np
 import pandas as pd
@@ -22,6 +22,7 @@
 import pytensor.tensor as pt
 import scipy.sparse as sps
 
+from pytensor import shared
 from pytensor.compile import Function, Mode, get_mode
 from pytensor.compile.builders import OpFromGraph
 from pytensor.gradient import grad
@@ -37,12 +38,13 @@
 )
 from pytensor.graph.fg import FunctionGraph, Output
 from pytensor.graph.op import Op
+from pytensor.link.jax.linker import JAXLinker
 from pytensor.scalar.basic import Cast
 from pytensor.scan.op import Scan
 from pytensor.tensor.basic import _as_tensor_variable
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.random.op import RandomVariable
-from pytensor.tensor.random.type import RandomType
+from pytensor.tensor.random.type import RandomGeneratorType, RandomType
 from pytensor.tensor.random.var import RandomGeneratorSharedVariable
 from pytensor.tensor.rewriting.basic import topo_unconditional_constant_folding
 from pytensor.tensor.rewriting.shape import ShapeFeature
@@ -51,7 +53,7 @@
 from pytensor.tensor.variable import TensorVariable
 
 from pymc.exceptions import NotConstantValueError
-from pymc.util import makeiter
+from pymc.util import RandomGeneratorState, makeiter, random_generator_from_state
 from pymc.vartypes import continuous_types, isgenerator, typefilter
 
 PotentialShapeType = int | np.ndarray | Sequence[int | Variable] | TensorVariable
@@ -1163,3 +1165,64 @@ def normalize_rng_param(rng: None | Variable) -> Variable:
             "The type of rng should be an instance of either RandomGeneratorType or RandomStateType"
         )
     return rng
+
+
+@overload
+def copy_function_with_new_rngs(
+    fn: PointFunc, rng: np.random.Generator | RandomGeneratorState
+) -> PointFunc: ...
+
+
+@overload
+def copy_function_with_new_rngs(
+    fn: Function, rng: np.random.Generator | RandomGeneratorState
+) -> Function: ...
+
+
+def copy_function_with_new_rngs(
+    fn: Function, rng: np.random.Generator | RandomGeneratorState
+) -> Function:
+    """Copy a compiled pytensor function and replace the random Generators with spawns.
+
+    Parameters
+    ----------
+    fn : pytensor.compile.function.types.Function | pymc.util.PointFunc
+        The compiled function
+    rng : numpy.random.Generator | RandomGeneratorState
+        The random generator or its state
+
+    Returns
+    -------
+    fn_out : pytensor.compile.function.types.Function | pymc.pytensorf.PointFunc
+        A copy of the input function with the shared random generator states set to
+        spawns of the supplied ``rng``. If the function has no shared random generators
+        in it, the input ``fn`` is returned without any changes.
+        If ``fn`` is a :class:`~pymc.pytensorf.PointFunc` instance, and the inner
+        pytensor function has random variables, then the inner pytensor function is
+        copied, setting new random generators, and a new ``PointFunc`` instance is
+        returned.
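+
+    Examples
+    --------
+    A minimal sketch of reseeding a compiled function; the graph below is an
+    illustrative assumption, not part of the API::
+
+        import numpy as np
+        import pytensor
+        import pytensor.tensor as pt
+
+        x = pt.random.normal(size=())
+        fn = pytensor.function([], x)
+        # The copy draws from spawns of the provided generator instead of the
+        # original shared random Generators
+        fn_copy = copy_function_with_new_rngs(fn, np.random.default_rng(42))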
+    """
+    # Copy the function and replace any shared RNGs
+    # This is needed so that it can work correctly with multiple traces
+    # This will be costly if set_rng is called too often!
+    rng_gen = rng if isinstance(rng, np.random.Generator) else random_generator_from_state(rng)
+    fn_ = fn.f if isinstance(fn, PointFunc) else fn
+    shared_rngs = [var for var in fn_.get_shared() if isinstance(var.type, RandomGeneratorType)]
+    n_shared_rngs = len(shared_rngs)
+    if n_shared_rngs > 0 and isinstance(fn_.maker.linker, JAXLinker):
+        # Reseeding RVs in the JAX backend requires different logic, because the SharedVariables
+        # used internally are not the ones that `function.get_shared()` returns.
+        warnings.warn(
+            "At the moment, it is not possible to set the random generator's key for "
+            "JAX linked functions. This means that the draws yielded by the random "
+            "variables that are requested by 'Deterministic' will not be reproducible."
+        )
+        return fn
+    swap = {
+        old_shared_rng: shared(new_rng, borrow=True)
+        for old_shared_rng, new_rng in zip(shared_rngs, rng_gen.spawn(n_shared_rngs), strict=True)
+    }
+    if isinstance(fn, PointFunc):
+        return PointFunc(fn.f.copy(swap=swap)) if n_shared_rngs > 0 else fn
+    else:
+        return fn.copy(swap=swap) if n_shared_rngs > 0 else fn
diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py
index b2d643a5f1..8d7972832d 100644
--- a/pymc/sampling/mcmc.py
+++ b/pymc/sampling/mcmc.py
@@ -26,6 +26,7 @@
     Any,
     Literal,
     TypeAlias,
+    cast,
     overload,
 )
 
@@ -40,6 +41,7 @@
 from rich.theme import Theme
 from threadpoolctl import threadpool_limits
 from typing_extensions import Protocol
+from zarr.storage import MemoryStore
 
 import pymc as pm
 
@@ -50,6 +52,7 @@
     find_observations,
 )
 from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains
+from pymc.backends.zarr import ZarrChain, ZarrTrace
 from pymc.blocking import DictToArrayBijection
 from pymc.exceptions import SamplingError
 from pymc.initial_point import PointType, StartDict, make_initial_point_fns_per_chain
@@ -503,7 +506,7 @@ def sample(
     model: Model | None = None,
     compile_kwargs: dict | None = None,
     **kwargs,
-) -> InferenceData | MultiTrace:
+) -> InferenceData | MultiTrace | ZarrTrace:
     r"""Draw samples from the posterior using the given step methods.
 
     Multiple step methods are supported via compound step methods.
@@ -570,7 +573,13 @@ def sample(
         Number of iterations of initializer. Only works for 'ADVI' init methods.
     trace : backend, optional
         A backend instance or None.
-        If None, the NDArray backend is used.
+        If ``None``, a ``MultiTrace`` object with underlying ``NDArray`` trace objects
+        is used. If ``trace`` is a :class:`~pymc.backends.zarr.ZarrTrace` instance,
+        the drawn samples will be written onto the desired storage while sampling is
+        ongoing. This means that sampling runs that, for whatever reason, die in the
+        middle of their execution will have written their partial results onto the
+        storage. If the storage persists on disk, these results should be available
+        even after a server crash. See :class:`~pymc.backends.zarr.ZarrTrace` for more
+        information.
     discard_tuned_samples : bool
         Whether to discard posterior samples of the tune interval.
     compute_convergence_checks : bool, default=True
@@ -607,8 +616,12 @@ def sample(
 
     Returns
     -------
-    trace : pymc.backends.base.MultiTrace or arviz.InferenceData
-        A ``MultiTrace`` or ArviZ ``InferenceData`` object that contains the samples.
+    trace : pymc.backends.base.MultiTrace | pymc.backends.zarr.ZarrTrace | arviz.InferenceData
+        A ``MultiTrace``, :class:`~arviz.InferenceData` or
+        :class:`~pymc.backends.zarr.ZarrTrace` object that contains the samples. A
+        ``ZarrTrace`` is only returned if the supplied ``trace`` argument is a
+        ``ZarrTrace`` instance. Refer to :class:`~pymc.backends.zarr.ZarrTrace` for
+        the benefits this backend provides.
 
     Notes
     -----
@@ -741,7 +754,7 @@ def joined_blas_limiter():
     rngs = get_random_generator(random_seed).spawn(chains)
     random_seed_list = [rng.integers(2**30) for rng in rngs]
 
-    if not discard_tuned_samples and not return_inferencedata:
+    if not discard_tuned_samples and not return_inferencedata and not isinstance(trace, ZarrTrace):
         warnings.warn(
             "Tuning samples will be included in the returned `MultiTrace` object, which can lead to"
             " complications in your downstream analysis. Please consider to switch to `InferenceData`:\n"
@@ -852,6 +865,8 @@ def joined_blas_limiter():
         trace_vars=trace_vars,
         initial_point=initial_points[0],
         model=model,
+        tune=tune,
+        rng=rngs[0].spawn(1)[0],
     )
 
     sample_args = {
@@ -934,7 +949,7 @@ def joined_blas_limiter():
     # into a function to make it easier to test and refactor.
     return _sample_return(
         run=run,
-        traces=traces,
+        traces=trace if isinstance(trace, ZarrTrace) else traces,
         tune=tune,
         t_sampling=t_sampling,
         discard_tuned_samples=discard_tuned_samples,
@@ -949,7 +964,7 @@ def joined_blas_limiter():
 def _sample_return(
     *,
     run: RunType | None,
-    traces: Sequence[IBaseTrace],
+    traces: Sequence[IBaseTrace] | ZarrTrace,
     tune: int,
     t_sampling: float,
     discard_tuned_samples: bool,
@@ -958,18 +973,69 @@ def _sample_return(
     keep_warning_stat: bool,
     idata_kwargs: dict[str, Any],
     model: Model,
-) -> InferenceData | MultiTrace:
+) -> InferenceData | MultiTrace | ZarrTrace:
     """Pick/slice chains, run diagnostics and convert to the desired return type.
 
     Final step of `pm.sampler`.
     """
+    if isinstance(traces, ZarrTrace):
+        # Split warmup from posterior samples
+        traces.split_warmup_groups()
+
+        # Set sampling time
+        traces.sampling_time = t_sampling
+
+        # Compute number of actual draws per chain
+        total_draws_per_chain = traces._sampling_state.draw_idx[:]
+        n_chains = len(traces.straces)
+        desired_tune = traces.tuning_steps
+        desired_draw = len(traces.posterior.draw)
+        tuning_steps_per_chain = np.clip(total_draws_per_chain, 0, desired_tune)
+        draws_per_chain = total_draws_per_chain - tuning_steps_per_chain
+
+        total_n_tune = tuning_steps_per_chain.sum()
+        total_draws = draws_per_chain.sum()
+
+        _log.info(
+            f'Sampling {n_chains} chain{"s" if n_chains > 1 else ""} for {desired_tune:_d} desired tune and {desired_draw:_d} desired draw iterations '
+            f"(Actually sampled {total_n_tune:_d} tune and {total_draws:_d} draws total) "
+            f"took {t_sampling:.0f} seconds."
+        )
+
+        if compute_convergence_checks or return_inferencedata:
+            idata = traces.to_inferencedata(save_warmup=not discard_tuned_samples)
+            log_likelihood = idata_kwargs.pop("log_likelihood", False)
+            if log_likelihood:
+                from pymc.stats.log_density import compute_log_likelihood
+
+                idata = compute_log_likelihood(
+                    idata,
+                    var_names=None if log_likelihood is True else log_likelihood,
+                    extend_inferencedata=True,
+                    model=model,
+                    sample_dims=["chain", "draw"],
+                    progressbar=False,
+                )
+            if compute_convergence_checks:
+                warns = run_convergence_checks(idata, model)
+                for warn in warns:
+                    traces._sampling_state.global_warnings.append(np.array([warn]))
+                log_warnings(warns)
+
+            if return_inferencedata:
+                # By default we drop the "warning" stat, which contains `SamplerWarning`
+                # objects that cannot be stored with `.to_netcdf()`.
+                if not keep_warning_stat:
+                    return drop_warning_stat(idata)
+                return idata
+        return traces
+
     # Pick and slice chains to keep the maximum number of samples
     if discard_tuned_samples:
         traces, length = _choose_chains(traces, tune)
     else:
         traces, length = _choose_chains(traces, 0)
     mtrace = MultiTrace(traces)[:length]
-
     # count the number of tune/draw iterations that happened
     # ideally via the "tune" statistic, but not all samplers record it!
     if "tune" in mtrace.stat_names:
@@ -1212,6 +1278,8 @@ def _iter_sample(
     step.set_rng(rng)
 
     point = start
+    if isinstance(trace, ZarrChain):
+        trace.link_stepper(step)
 
     try:
         step.tune = bool(tune)
@@ -1233,13 +1301,14 @@ def _iter_sample(
                 )
 
             yield diverging
-    except KeyboardInterrupt:
-        trace.close()
-        raise
-    except BaseException:
+    except BaseException:  # also catches KeyboardInterrupt
+        if isinstance(trace, ZarrChain):
+            trace.record_sampling_state(step=step)
         trace.close()
         raise
     else:
+        if isinstance(trace, ZarrChain):
+            trace.record_sampling_state(step=step)
         trace.close()
 
 
@@ -1298,6 +1367,19 @@ def _mp_sample(
 
     # We did draws += tune in pm.sample
     draws -= tune
+    zarr_chains: list[ZarrChain] | None = None
+    zarr_recording = False
+    if all(isinstance(trace, ZarrChain) for trace in traces):
+        if isinstance(cast(ZarrChain, traces[0])._posterior.store, MemoryStore):
+            warnings.warn(
+                "Parallel sampling with MemoryStore zarr store wont write the processes "
+                "step method sampling state. If you wish to be able to access the step "
+                "method sampling state, please use a different storage backend, e.g. "
+                "DirectoryStore or ZipStore"
+            )
+        else:
+            zarr_chains = cast(list[ZarrChain], traces)
+            zarr_recording = True
 
     sampler = ps.ParallelSampler(
         draws=draws,
@@ -1311,13 +1393,16 @@ def _mp_sample(
         progressbar_theme=progressbar_theme,
         blas_cores=blas_cores,
         mp_ctx=mp_ctx,
+        zarr_chains=zarr_chains,
     )
     try:
         try:
             with sampler:
                 for draw in sampler:
                     strace = traces[draw.chain]
-                    strace.record(draw.point, draw.stats)
+                    if not zarr_recording:
+                        # Zarr recording happens in each process
+                        strace.record(draw.point, draw.stats)
                     log_warning_stats(draw.stats)
 
                     if callback is not None:
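The ``MemoryStore`` warning above exists because dict-backed stores are process-local. A self-contained illustration (assuming zarr 2.x) showing that a child process's writes never reach the parent's copy of the store:

```python
import multiprocessing as mp

import zarr

def child_write(store):
    # Whether the child is forked or spawned, it operates on a copy of the
    # dict-backed MemoryStore, so this array never appears in the parent.
    root = zarr.group(store=store)
    root.zeros("draws", shape=(10,))

if __name__ == "__main__":
    store = zarr.MemoryStore()
    proc = mp.Process(target=child_write, args=(store,))
    proc.start()
    proc.join()
    print(list(zarr.open_group(store=store).arrays()))  # [] -- nothing shared
```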
diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py
index 67417e0d8f..28e74d5e8a 100644
--- a/pymc/sampling/parallel.py
+++ b/pymc/sampling/parallel.py
@@ -22,6 +22,7 @@
 
 from collections import namedtuple
 from collections.abc import Sequence
+from typing import cast
 
 import cloudpickle
 import numpy as np
@@ -31,6 +32,7 @@
 from rich.theme import Theme
 from threadpoolctl import threadpool_limits
 
+from pymc.backends.zarr import ZarrChain
 from pymc.blocking import DictToArrayBijection
 from pymc.exceptions import SamplingError
 from pymc.util import (
@@ -104,13 +106,27 @@ def __init__(
         tune: int,
         rng_state: RandomGeneratorState,
         blas_cores,
+        chain: int,
+        zarr_chains: list[ZarrChain] | bytes | None = None,
+        zarr_chains_is_pickled: bool = False,
     ):
-        # For some strange reason, spawn multiprocessing doesn't copy the rng
-        # seed sequence, so we have to rebuild it from scratch
+        # Because of https://github.com/numpy/numpy/issues/27727, we can't send
+        # the rng instance to the child process because pickling (copying) loses
+        # the seed sequence state information. For this reason, we send a
+        # RandomGeneratorState instead.
         rng = random_generator_from_state(rng_state)
         self._msg_pipe = msg_pipe
         self._step_method = step_method
         self._step_method_is_pickled = step_method_is_pickled
+        self.chain = chain
+        self._zarr_recording = False
+        self._zarr_chain: ZarrChain | None = None
+        if zarr_chains_is_pickled:
+            self._zarr_chain = cloudpickle.loads(zarr_chains)[self.chain]
+        elif zarr_chains is not None:
+            self._zarr_chain = cast(list[ZarrChain], zarr_chains)[self.chain]
+        self._zarr_recording = self._zarr_chain is not None
+
         self._shared_point = shared_point
         self._rng = rng
         self._draws = draws
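The workaround referenced in the comment above, in isolation (a sketch using the same ``pymc.util`` helpers this diff already relies on):

```python
import numpy as np

from pymc.util import get_state_from_generator, random_generator_from_state

rng = np.random.default_rng(42)
# Ship the picklable state across the process boundary instead of the
# Generator itself, then rebuild an equivalent Generator in the child.
state = get_state_from_generator(rng)
child_rng = random_generator_from_state(state)
assert child_rng.integers(2**30) == np.random.default_rng(42).integers(2**30)
```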
@@ -135,6 +151,7 @@ def run(self):
                 # We do not create this in __init__, as pickling this
                 # would destroy the shared memory.
                 self._unpickle_step_method()
+                self._link_step_to_zarrchain()
                 self._point = self._make_numpy_refs()
                 self._start_loop()
             except KeyboardInterrupt:
@@ -148,6 +165,10 @@ def run(self):
             finally:
                 self._msg_pipe.close()
 
+    def _link_step_to_zarrchain(self):
+        if self._zarr_recording:
+            self._zarr_chain.link_stepper(self._step_method)
+
     def _wait_for_abortion(self):
         while True:
             msg = self._recv_msg()
@@ -170,6 +191,7 @@ def _recv_msg(self):
         return self._msg_pipe.recv()
 
     def _start_loop(self):
+        zarr_recording = self._zarr_recording
         self._step_method.set_rng(self._rng)
 
         draw = 0
@@ -199,6 +221,8 @@ def _start_loop(self):
             if msg[0] == "abort":
                 raise KeyboardInterrupt()
             elif msg[0] == "write_next":
+                if zarr_recording:
+                    self._zarr_chain.record(point, stats)
                 self._write_point(point)
                 is_last = draw + 1 == self._draws + self._tune
                 self._msg_pipe.send(("writing_done", is_last, draw, tuning, stats))
@@ -225,6 +249,8 @@ def __init__(
         start: dict[str, np.ndarray],
         blas_cores,
         mp_ctx,
+        zarr_chains: list[ZarrChain] | None = None,
+        zarr_chains_pickled: bytes | None = None,
     ):
         self.chain = chain
         process_name = f"worker_chain_{chain}"
@@ -247,6 +273,16 @@ def __init__(
         self._readable = True
         self._num_samples = 0
 
+        zarr_chains_send: list[ZarrChain] | bytes | None = None
+        if zarr_chains_pickled is not None:
+            zarr_chains_send = zarr_chains_pickled
+        elif zarr_chains is not None:
+            if mp_ctx.get_start_method() == "spawn":
+                raise ValueError(
+                    "please provide a pre-pickled zarr_chains when multiprocessing start method is 'spawn'"
+                )
+            zarr_chains_send = zarr_chains
+
         if step_method_pickled is not None:
             step_method_send = step_method_pickled
         else:
@@ -270,6 +306,9 @@ def __init__(
                 tune,
                 get_state_from_generator(rng),
                 blas_cores,
+                self.chain,
+                zarr_chains_send,
+                zarr_chains_pickled is not None,
             ),
         )
         self._process.start()
@@ -392,6 +431,7 @@ def __init__(
         progressbar_theme: Theme | None = default_progress_theme,
         blas_cores: int | None = None,
         mp_ctx=None,
+        zarr_chains: list[ZarrChain] | None = None,
     ):
         if any(len(arg) != chains for arg in [rngs, start_points]):
             raise ValueError(f"Number of rngs and start_points must be {chains}.")
@@ -412,8 +452,15 @@ def __init__(
             mp_ctx = multiprocessing.get_context(mp_ctx)
 
         step_method_pickled = None
+        zarr_chains_pickled = None
+        self.zarr_recording = False
+        if zarr_chains is not None:
+            assert all(isinstance(zarr_chain, ZarrChain) for zarr_chain in zarr_chains)
+            self.zarr_recording = True
         if mp_ctx.get_start_method() != "fork":
             step_method_pickled = cloudpickle.dumps(step_method, protocol=-1)
+            if zarr_chains is not None:
+                zarr_chains_pickled = cloudpickle.dumps(zarr_chains, protocol=-1)
 
         self._samplers = [
             ProcessAdapter(
@@ -426,6 +473,8 @@ def __init__(
                 start,
                 blas_cores,
                 mp_ctx,
+                zarr_chains=zarr_chains,
+                zarr_chains_pickled=zarr_chains_pickled,
             )
             for chain, rng, start in zip(range(chains), rngs, start_points)
         ]
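Condensed, the per-worker hand-off implemented above looks like this (a sketch; the function name is illustrative, not public API):

```python
import cloudpickle

def resolve_zarr_chain(chain, zarr_chains, zarr_chains_is_pickled):
    # Mirrors the _Process constructor: a worker either unpickles the full
    # list of ZarrChains and keeps only its own entry, or indexes the list
    # it was handed directly (only safe with the "fork" start method).
    if zarr_chains_is_pickled:
        return cloudpickle.loads(zarr_chains)[chain]
    elif zarr_chains is not None:
        return zarr_chains[chain]
    return None
```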
diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py
index 4e5a229960..b8a7ba593a 100644
--- a/pymc/sampling/population.py
+++ b/pymc/sampling/population.py
@@ -27,6 +27,7 @@
 from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
 
 from pymc.backends.base import BaseTrace
+from pymc.backends.zarr import ZarrChain
 from pymc.initial_point import PointType
 from pymc.model import Model, modelcontext
 from pymc.stats.convergence import log_warning_stats
@@ -36,6 +37,7 @@
     PopulationArrayStepShared,
     StatsType,
 )
+from pymc.step_methods.compound import StepMethodState
 from pymc.step_methods.metropolis import DEMetropolis
 from pymc.util import CustomProgress
 
@@ -81,6 +83,11 @@ def _sample_population(
         Show progress bars? (defaults to True)
     parallelize : bool
         Setting for multiprocess parallelization
+    traces : Sequence[BaseTrace]
+        A sequence of chain traces where the sampling results will be stored. Can be
+        a sequence of :py:class:`~pymc.backends.ndarray.NDArray`,
+        :py:class:`~pymc.backends.mcbackend.ChainRecordAdapter`, or
+        :py:class:`~pymc.backends.zarr.ZarrChain`.
     """
     warn_population_size(
         step=step,
@@ -263,6 +270,9 @@ def _run_secondary(c, stepper_dumps, secondary_end, task, progress):
                 # receiving a None is the signal to exit
                 if incoming is None:
                     break
+                elif incoming == ("sampling_state",):
+                    secondary_end.send((c, stepper.sampling_state))
+                    continue
                 tune_stop, population = incoming
                 if tune_stop:
                     stepper.stop_tuning()
@@ -307,6 +317,14 @@ def step(self, tune_stop: bool, population) -> list[tuple[PointType, StatsType]]
                 updates.append(self._steppers[c].step(population[c]))
         return updates
 
+    def request_sampling_state(self, chain) -> StepMethodState:
+        if self.is_parallelized:
+            self._primary_ends[chain].send(("sampling_state",))
+            _, sampling_state = self._primary_ends[chain].recv()
+        else:
+            sampling_state = self._steppers[chain].sampling_state
+        return sampling_state
+
 
 def _prepare_iter_population(
     *,
@@ -332,6 +350,11 @@ def _prepare_iter_population(
         Start points for each chain
     parallelize : bool
         Setting for multiprocess parallelization
+    traces : Sequence[BaseTrace]
+        A sequence of chain traces where the sampling results will be stored. Can be
+        a sequence of :py:class:`~pymc.backends.ndarray.NDArray`,
+        :py:class:`~pymc.backends.mcbackend.ChainRecordAdapter`, or
+        :py:class:`~pymc.backends.zarr.ZarrChain`.
     tune : int
         Number of iterations to tune.
     rngs: sequence of random Generators
@@ -411,8 +434,11 @@ def _iter_population(
         the helper object for (parallelized) stepping of chains
     steppers : list
         The step methods for each chain
-    traces : list
-        Traces for each chain
+    traces : Sequence[BaseTrace]
+        A sequence of chain traces where the sampling results will be stored. Can be
+        a sequence of :py:class:`~pymc.backends.ndarray.NDArray`,
+        :py:class:`~pymc.backends.mcbackend.ChainRecordAdapter`, or
+        :py:class:`~pymc.backends.zarr.ZarrChain`.
     points : list
         population of chain states
 
@@ -432,8 +458,11 @@ def _iter_population(
                 # apply the update to the points and record to the traces
                 for c, strace in enumerate(traces):
                     points[c], stats = updates[c]
-                    strace.record(points[c], stats)
+                    flushed = strace.record(points[c], stats)
                     log_warning_stats(stats)
+                    if flushed and isinstance(strace, ZarrChain):
+                        sampling_state = popstep.request_sampling_state(c)
+                        strace.store_sampling_state(sampling_state)
                 # yield the state of all chains in parallel
                 yield i
     except KeyboardInterrupt:
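The ``request_sampling_state`` round trip added above is a plain request/response over the chains' existing pipes. A self-contained sketch of the pattern (illustrative names, not PyMC API):

```python
import multiprocessing

def worker(conn):
    # Mirrors _run_secondary: answer state requests, exit on None.
    state = {"step": 0}
    while True:
        incoming = conn.recv()
        if incoming is None:
            break
        elif incoming == ("sampling_state",):
            conn.send((0, state))

if __name__ == "__main__":
    primary_end, secondary_end = multiprocessing.Pipe()
    proc = multiprocessing.Process(target=worker, args=(secondary_end,))
    proc.start()
    primary_end.send(("sampling_state",))
    _, state = primary_end.recv()  # -> {"step": 0}
    primary_end.send(None)
    proc.join()
```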
diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py
index d0393afd57..1fcb3d2673 100644
--- a/pymc/step_methods/compound.py
+++ b/pymc/step_methods/compound.py
@@ -22,6 +22,7 @@
 
 from abc import ABC, abstractmethod
 from collections.abc import Iterable, Mapping, Sequence
+from dataclasses import field
 from enum import IntEnum, unique
 from typing import Any
 
@@ -96,6 +97,7 @@ def infer_warn_stats_info(
 
 @dataclass_state
 class StepMethodState(DataClassState):
+    var_names: list[str] = field(metadata={"tensor_name": True, "frozen": True})
     rng: RandomGeneratorState
 
 
diff --git a/pymc/step_methods/state.py b/pymc/step_methods/state.py
index e24276cf14..ec7bbbae48 100644
--- a/pymc/step_methods/state.py
+++ b/pymc/step_methods/state.py
@@ -12,7 +12,7 @@
 #   See the License for the specific language governing permissions and
 #   limitations under the License.
 from copy import deepcopy
-from dataclasses import Field, dataclass, fields
+from dataclasses import MISSING, Field, dataclass, fields
 from typing import Any, ClassVar
 
 import numpy as np
@@ -67,7 +67,16 @@ def sampling_state(self) -> DataClassState:
         state_class = self._state_class
         kwargs = {}
         for field in fields(state_class):
-            val = getattr(self, field.name)
+            is_tensor_name = field.metadata.get("tensor_name", False)
+            val: Any
+            if is_tensor_name:
+                val = [var.name for var in getattr(self, "vars")]
+            else:
+                val = getattr(self, field.name, field.default)
+            if val is MISSING:
+                raise AttributeError(
+                    f"{type(self).__name__!r} object has no attribute {field.name!r}"
+                )
             _val: Any
             if isinstance(val, WithSamplingState):
                 _val = val.sampling_state
@@ -85,11 +94,17 @@ def sampling_state(self, state: DataClassState):
             state, state_class
         ), f"Encountered invalid state class '{state.__class__}'. State must be '{state_class}'"
         for field in fields(state_class):
+            is_tensor_name = field.metadata.get("tensor_name", False)
             state_val = deepcopy(getattr(state, field.name))
             if isinstance(state_val, RandomGeneratorState):
                 state_val = random_generator_from_state(state_val)
-            self_val = getattr(self, field.name)
             is_frozen = field.metadata.get("frozen", False)
+            self_val: Any
+            if is_tensor_name:
+                self_val = [var.name for var in getattr(self, "vars")]
+                assert is_frozen
+            else:
+                self_val = getattr(self, field.name, field.default)
             if is_frozen:
                 if not equal_dataclass_values(state_val, self_val):
                     raise ValueError(
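To make the ``tensor_name`` metadata convention concrete, here is a minimal, self-contained sketch (not PyMC code): the state dataclass stores variable *names*, while the live object holds variable objects exposing ``.name``:

```python
from dataclasses import dataclass, field, fields

@dataclass
class FakeVar:
    name: str

@dataclass
class FakeState:
    var_names: list[str] = field(metadata={"tensor_name": True, "frozen": True})

class FakeStepper:
    def __init__(self, vars):
        self.vars = vars

    def snapshot(self) -> FakeState:
        kwargs = {}
        for f in fields(FakeState):
            if f.metadata.get("tensor_name", False):
                # Serialize tensors by name only, mirroring the logic above.
                kwargs[f.name] = [var.name for var in self.vars]
        return FakeState(**kwargs)

stepper = FakeStepper([FakeVar("a"), FakeVar("b_log__")])
assert stepper.snapshot().var_names == ["a", "b_log__"]
```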
diff --git a/pymc/util.py b/pymc/util.py
index 8a059d7e0d..63576676eb 100644
--- a/pymc/util.py
+++ b/pymc/util.py
@@ -13,6 +13,7 @@
 #   limitations under the License.
 
 import functools
+import re
 import warnings
 
 from collections import namedtuple
@@ -276,7 +277,12 @@ def drop_warning_stat(idata: arviz.InferenceData) -> arviz.InferenceData:
     nidata = arviz.InferenceData(attrs=idata.attrs)
     for gname, group in idata.items():
         if "sample_stat" in gname:
-            group = group.drop_vars(names=["warning", "warning_dim_0"], errors="ignore")
+            warning_vars = [
+                name
+                for name in group.data_vars
+                if name == "warning" or re.match(r"sampler_\d+__warning", str(name))
+            ]
+            group = group.drop_vars(names=[*warning_vars, "warning_dim_0"], errors="ignore")
         nidata.add_groups({gname: group}, coords=group.coords, dims=group.dims)
     return nidata
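A quick check of what the broadened filter now drops (illustrative only; the per-sampler names come from ``CompoundStep`` stats, as exercised in the tests below):

```python
import re

names = ["warning", "sampler_0__warning", "sampler_12__warning", "diverging"]
dropped = [
    n for n in names if n == "warning" or re.match(r"sampler_\d+__warning", n)
]
assert dropped == ["warning", "sampler_0__warning", "sampler_12__warning"]
```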
 
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 56f7f964fc..e7e3644aae 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -32,3 +32,4 @@ threadpoolctl>=3.1.0
 types-cachetools
 typing-extensions>=3.7.4
 watermark
+zarr>=2.5.0,<3
diff --git a/tests/backends/test_zarr.py b/tests/backends/test_zarr.py
new file mode 100644
index 0000000000..32f508ef1a
--- /dev/null
+++ b/tests/backends/test_zarr.py
@@ -0,0 +1,538 @@
+#   Copyright 2024 The PyMC Developers
+#
+#   Licensed under the Apache License, Version 2.0 (the "License");
+#   you may not use this file except in compliance with the License.
+#   You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+#   Unless required by applicable law or agreed to in writing, software
+#   distributed under the License is distributed on an "AS IS" BASIS,
+#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#   See the License for the specific language governing permissions and
+#   limitations under the License.
+import itertools
+
+from dataclasses import asdict
+
+import numpy as np
+import pytest
+import xarray as xr
+import zarr
+
+from arviz import InferenceData
+
+import pymc as pm
+
+from pymc.backends.zarr import ZarrTrace
+from pymc.stats.convergence import SamplerWarning
+from pymc.step_methods import NUTS, CompoundStep, Metropolis
+from pymc.step_methods.state import equal_dataclass_values
+from tests.helpers import equal_sampling_states
+
+
+@pytest.fixture(scope="module")
+def model():
+    time_int = np.array([np.timedelta64(np.timedelta64(i, "h"), "ns") for i in range(25)])
+    coords = {
+        "dim_int": range(3),
+        "dim_str": ["A", "B"],
+        "dim_time": np.datetime64("2024-10-16") + time_int,
+        "dim_interval": time_int,
+    }
+    rng = np.random.default_rng(42)
+    with pm.Model(coords=coords) as model:
+        data1 = pm.Data("data1", np.ones(3, dtype="bool"), dims=["dim_int"])
+        data2 = pm.Data("data2", np.ones(3, dtype="bool"))
+        time = pm.Data("time", time_int / np.timedelta64(1, "h"), dims="dim_time")
+
+        a = pm.Normal("a", shape=(len(coords["dim_int"]), len(coords["dim_str"])))
+        b = pm.Normal("b", dims=["dim_int", "dim_str"])
+        c = pm.Deterministic("c", a + b, dims=["dim_int", "dim_str"])
+
+        d = pm.LogNormal("d", dims="dim_time")
+        e = pm.Deterministic("e", (time + d)[:, None] + c[0], dims=["dim_interval", "dim_str"])
+
+        obs = pm.Normal(
+            "obs",
+            mu=e,
+            observed=rng.normal(size=(len(coords["dim_time"]), len(coords["dim_str"]))),
+            dims=["dim_time", "dim_str"],
+        )
+
+    return model
+
+
+@pytest.fixture(params=["include_transformed", "discard_transformed"])
+def include_transformed(request):
+    return request.param == "include_transformed"
+
+
+@pytest.fixture(params=["frequent_writes", "sparse_writes"])
+def draws_per_chunk(request):
+    spec = {
+        "frequent_writes": 1,
+        "sparse_writes": 7,
+    }
+    return spec[request.param]
+
+
+@pytest.fixture(params=["single_step", "compound_step"])
+def model_step(request, model):
+    rng = np.random.default_rng(42)
+    with model:
+        if request.param == "single_step":
+            step = NUTS(rng=rng)
+        else:
+            rngs = rng.spawn(2)
+            step = CompoundStep(
+                [
+                    Metropolis(vars=model["a"], rng=rngs[0]),
+                    NUTS(vars=[rv for rv in model.value_vars if rv.name != "a"], rng=rngs[1]),
+                ]
+            )
+    return step
+
+
+def test_record(model, model_step, include_transformed, draws_per_chunk):
+    store = zarr.TempStore()
+    trace = ZarrTrace(
+        store=store, include_transformed=include_transformed, draws_per_chunk=draws_per_chunk
+    )
+    draws = 5
+    tune = 5
+    trace.init_trace(chains=1, draws=draws, tune=tune, model=model, step=model_step)
+
+    # Assert that init was successful
+    expected_groups = {
+        "_sampling_state",
+        "sample_stats",
+        "posterior",
+        "constant_data",
+        "observed_data",
+    }
+    if include_transformed:
+        expected_groups.add("unconstrained_posterior")
+    assert {group_name for group_name, _ in trace.root.groups()} == expected_groups
+
+    # Record samples from the ZarrChain
+    manually_collected_warmup_draws = []
+    manually_collected_warmup_stats = []
+    manually_collected_draws = []
+    manually_collected_stats = []
+    point = model.initial_point()
+    for draw in range(tune + draws):
+        tuning = draw < tune
+        if not tuning:
+            model_step.stop_tuning()
+        point, stats = model_step.step(point)
+        if tuning:
+            manually_collected_warmup_draws.append(point)
+            manually_collected_warmup_stats.append(stats)
+        else:
+            manually_collected_draws.append(point)
+            manually_collected_stats.append(stats)
+        trace.straces[0].record(point, stats)
+    trace.straces[0].record_sampling_state(model_step)
+    assert {group_name for group_name, _ in trace.root.groups()} == expected_groups
+
+    # Assert split warmup
+    trace.split_warmup("posterior")
+    trace.split_warmup("sample_stats")
+    expected_groups = {
+        "_sampling_state",
+        "sample_stats",
+        "posterior",
+        "warmup_sample_stats",
+        "warmup_posterior",
+        "constant_data",
+        "observed_data",
+    }
+    if include_transformed:
+        trace.split_warmup("unconstrained_posterior")
+        expected_groups.add("unconstrained_posterior")
+        expected_groups.add("warmup_unconstrained_posterior")
+    assert {group_name for group_name, _ in trace.root.groups()} == expected_groups
+
+    # Assert observed data is correct
+    assert set(dict(trace.observed_data.arrays())) == {"obs", "dim_time", "dim_str"}
+    assert list(trace.observed_data.obs.attrs["_ARRAY_DIMENSIONS"]) == ["dim_time", "dim_str"]
+    np.testing.assert_array_equal(trace.observed_data.dim_time[:], model.coords["dim_time"])
+    np.testing.assert_array_equal(trace.observed_data.dim_str[:], model.coords["dim_str"])
+
+    # Assert constant data is correct
+    assert set(dict(trace.constant_data.arrays())) == {
+        "data1",
+        "data2",
+        "data2_dim_0",
+        "time",
+        "dim_time",
+        "dim_int",
+    }
+    assert list(trace.constant_data.data1.attrs["_ARRAY_DIMENSIONS"]) == ["dim_int"]
+    assert list(trace.constant_data.data2.attrs["_ARRAY_DIMENSIONS"]) == ["data2_dim_0"]
+    assert list(trace.constant_data.time.attrs["_ARRAY_DIMENSIONS"]) == ["dim_time"]
+    np.testing.assert_array_equal(trace.constant_data.dim_time[:], model.coords["dim_time"])
+    np.testing.assert_array_equal(trace.constant_data.dim_int[:], model.coords["dim_int"])
+
+    # Assert that the posterior holds all free RVs and deterministics, and that the
+    # unconstrained posterior has correct shapes and kinds
+    assert {rv.name for rv in model.free_RVs + model.deterministics} <= set(
+        dict(trace.posterior.arrays())
+    )
+    if include_transformed:
+        assert {"d_log__", "chain", "draw", "d_log___dim_0"} == set(
+            dict(trace.unconstrained_posterior.arrays())
+        )
+        assert list(trace.unconstrained_posterior.d_log__.attrs["_ARRAY_DIMENSIONS"]) == [
+            "chain",
+            "draw",
+            "d_log___dim_0",
+        ]
+        assert trace.unconstrained_posterior.d_log__.attrs["kind"] == "freeRV"
+        np.testing.assert_array_equal(trace.unconstrained_posterior.chain, np.arange(1))
+        np.testing.assert_array_equal(trace.unconstrained_posterior.draw, np.arange(draws))
+        np.testing.assert_array_equal(
+            trace.unconstrained_posterior.d_log___dim_0, np.arange(len(model.coords["dim_time"]))
+        )
+
+    # Assert posterior has correct shapes and kinds
+    posterior_dims = set()
+    for kind, rv_name in [
+        (kind, rv.name)
+        for kind, rv in itertools.chain(
+            itertools.zip_longest([], model.free_RVs, fillvalue="freeRV"),
+            itertools.zip_longest([], model.deterministics, fillvalue="deterministic"),
+        )
+    ]:
+        if rv_name == "a":
+            expected_dims = ["a_dim_0", "a_dim_1"]
+        else:
+            expected_dims = model.named_vars_to_dims[rv_name]
+        posterior_dims |= set(expected_dims)
+        assert list(trace.posterior[rv_name].attrs["_ARRAY_DIMENSIONS"]) == [
+            "chain",
+            "draw",
+            *expected_dims,
+        ]
+        assert trace.posterior[rv_name].attrs["kind"] == kind
+    for posterior_dim in posterior_dims:
+        try:
+            model_coord = model.coords[posterior_dim]
+        except KeyError:
+            model_coord = {
+                "a_dim_0": np.arange(len(model.coords["dim_int"])),
+                "a_dim_1": np.arange(len(model.coords["dim_str"])),
+                "chain": np.arange(1),
+                "draw": np.arange(draws),
+            }[posterior_dim]
+        np.testing.assert_array_equal(trace.posterior[posterior_dim][:], model_coord)
+
+    # Assert sample stats have correct shape
+    stats_bijection = trace.straces[0].stats_bijection
+    for draw_idx, (draw, stat) in enumerate(
+        zip(manually_collected_draws, manually_collected_stats)
+    ):
+        stat = stats_bijection.map(stat)
+        for var, value in draw.items():
+            if var in trace.posterior.arrays():
+                assert np.array_equal(trace.posterior[var][0, draw_idx], value)
+        for var, value in stat.items():
+            sample_stats = trace.root["sample_stats"]
+            stat_val = sample_stats[var][0, draw_idx]
+            if not isinstance(stat_val, SamplerWarning):
+                unequal_stats = stat_val != value
+            else:
+                unequal_stats = not equal_dataclass_values(asdict(stat_val), asdict(value))
+            if unequal_stats and not (np.isnan(stat_val) and np.isnan(value)):
+                raise AssertionError(f"{var} value does not match: {stat_val} != {value}")
+
+    # Assert manually collected warmup samples match
+    for draw_idx, (draw, stat) in enumerate(
+        zip(manually_collected_warmup_draws, manually_collected_warmup_stats)
+    ):
+        stat = stats_bijection.map(stat)
+        for var, value in draw.items():
+            if var == "d_log__":
+                if not include_transformed:
+                    continue
+                posterior = trace.root["warmup_unconstrained_posterior"]
+            else:
+                posterior = trace.root["warmup_posterior"]
+            if var in posterior.arrays():
+                assert np.array_equal(posterior[var][0, draw_idx], value)
+        for var, value in stat.items():
+            sample_stats = trace.root["warmup_sample_stats"]
+            stat_val = sample_stats[var][0, draw_idx]
+            if not isinstance(stat_val, SamplerWarning):
+                unequal_stats = stat_val != value
+            else:
+                unequal_stats = not equal_dataclass_values(asdict(stat_val), asdict(value))
+            if unequal_stats and not (np.isnan(stat_val) and np.isnan(value)):
+                raise AssertionError(f"{var} value does not match: {stat_val} != {value}")
+
+    # Assert manually collected posterior samples match
+    for draw_idx, (draw, stat) in enumerate(
+        zip(manually_collected_draws, manually_collected_stats)
+    ):
+        stat = stats_bijection.map(stat)
+        for var, value in draw.items():
+            if var == "d_log__":
+                if not include_transformed:
+                    continue
+                posterior = trace.root["unconstrained_posterior"]
+            else:
+                posterior = trace.root["posterior"]
+            if var in posterior.arrays():
+                assert np.array_equal(posterior[var][0, draw_idx], value)
+        for var, value in stat.items():
+            sample_stats = trace.root["sample_stats"]
+            stat_val = sample_stats[var][0, draw_idx]
+            if not isinstance(stat_val, SamplerWarning):
+                unequal_stats = stat_val != value
+            else:
+                unequal_stats = not equal_dataclass_values(asdict(stat_val), asdict(value))
+            if unequal_stats and not (np.isnan(stat_val) and np.isnan(value)):
+                raise AssertionError(f"{var} value does not match: {stat_val} != {value}")
+
+    # Assert sampling_state is correct
+    assert list(trace._sampling_state.draw_idx[:]) == [draws + tune]
+    assert equal_sampling_states(
+        trace._sampling_state.sampling_state[0],
+        model_step.sampling_state,
+    )
+
+    # Assert to inference data returns the expected groups
+    idata = trace.to_inferencedata(save_warmup=True)
+    expected_groups = {
+        "posterior",
+        "constant_data",
+        "observed_data",
+        "sample_stats",
+        "warmup_posterior",
+        "warmup_sample_stats",
+    }
+    if include_transformed:
+        expected_groups.add("unconstrained_posterior")
+        expected_groups.add("warmup_unconstrained_posterior")
+    assert set(idata.groups()) == expected_groups
+    for group in idata.groups():
+        for name, value in itertools.chain(
+            idata[group].data_vars.items(), idata[group].coords.items()
+        ):
+            try:
+                array = getattr(trace, group)[name][:]
+            except AttributeError:
+                array = trace.root[group][name][:]
+            if "sample_stats" in group and "warning" in name:
+                continue
+            np.testing.assert_array_equal(array, value)
+
+
+@pytest.mark.parametrize("tune", [0, 5, 10])
+def test_split_warmup(tune, model, model_step, include_transformed):
+    store = zarr.MemoryStore()
+    trace = ZarrTrace(store=store, include_transformed=include_transformed)
+    draws = 10 - tune
+    trace.init_trace(chains=1, draws=draws, tune=tune, model=model, step=model_step)
+
+    trace.split_warmup("posterior")
+    trace.split_warmup("sample_stats")
+    assert len(trace.root.posterior.draw) == draws
+    assert len(trace.root.sample_stats.draw) == draws
+    if tune == 0:
+        with pytest.raises(KeyError):
+            trace.root["warmup_posterior"]
+    else:
+        assert len(trace.root["warmup_posterior"].draw) == tune
+        assert len(trace.root["warmup_sample_stats"].draw) == tune
+
+        with pytest.raises(RuntimeError):
+            trace.split_warmup("posterior")
+
+        for var_name, posterior_array in trace.posterior.arrays():
+            dims = posterior_array.attrs["_ARRAY_DIMENSIONS"]
+            if len(dims) >= 2 and dims[1] == "draw":
+                assert posterior_array.shape[1] == draws
+                assert trace.root["warmup_posterior"][var_name].shape[1] == tune
+        for var_name, sample_stats_array in trace.sample_stats.arrays():
+            dims = sample_stats_array.attrs["_ARRAY_DIMENSIONS"]
+            if len(dims) >= 2 and dims[1] == "draw":
+                assert sample_stats_array.shape[1] == draws
+                assert trace.root["warmup_sample_stats"][var_name].shape[1] == tune
+
+
+@pytest.fixture(scope="function", params=["discard_tuning", "keep_tuning"])
+def discard_tuned_samples(request):
+    return request.param == "discard_tuning"
+
+
+@pytest.fixture(scope="function", params=["return_idata", "return_zarr"])
+def return_inferencedata(request):
+    return request.param == "return_idata"
+
+
+@pytest.fixture(
+    scope="function", params=[True, False], ids=["keep_warning_stat", "discard_warning_stat"]
+)
+def keep_warning_stat(request):
+    return request.param
+
+
+@pytest.fixture(
+    scope="function", params=[True, False], ids=["parallel_sampling", "sequential_sampling"]
+)
+def parallel(request):
+    return request.param
+
+
+@pytest.fixture(scope="function", params=[True, False], ids=["compute_loglike", "no_loglike"])
+def log_likelihood(request):
+    return request.param
+
+
+def test_sample(
+    model,
+    model_step,
+    include_transformed,
+    discard_tuned_samples,
+    return_inferencedata,
+    keep_warning_stat,
+    parallel,
+    log_likelihood,
+    draws_per_chunk,
+):
+    if not return_inferencedata and not log_likelihood:
+        pytest.skip(
+            reason="log_likelihood is only computed if an inference data object is returned"
+        )
+    store = zarr.TempStore()
+    trace = ZarrTrace(
+        store=store, include_transformed=include_transformed, draws_per_chunk=draws_per_chunk
+    )
+    tune = 2
+    draws = 3
+    if parallel:
+        chains = 2
+        cores = 2
+    else:
+        chains = 1
+        cores = 1
+    with model:
+        out_trace = pm.sample(
+            draws=draws,
+            tune=tune,
+            chains=chains,
+            cores=cores,
+            trace=trace,
+            step=model_step,
+            discard_tuned_samples=discard_tuned_samples,
+            return_inferencedata=return_inferencedata,
+            keep_warning_stat=keep_warning_stat,
+            idata_kwargs={"log_likelihood": log_likelihood},
+        )
+
+    if not return_inferencedata:
+        assert isinstance(out_trace, ZarrTrace)
+        assert out_trace.root.store is trace.root.store
+    else:
+        assert isinstance(out_trace, InferenceData)
+
+    expected_groups = {"posterior", "constant_data", "observed_data", "sample_stats"}
+    if include_transformed:
+        expected_groups |= {"unconstrained_posterior"}
+    if not return_inferencedata or not discard_tuned_samples:
+        expected_groups |= {"warmup_posterior", "warmup_sample_stats"}
+        if include_transformed:
+            expected_groups |= {"warmup_unconstrained_posterior"}
+    if not return_inferencedata:
+        expected_groups |= {"_sampling_state"}
+    elif log_likelihood:
+        expected_groups |= {"log_likelihood"}
+    assert set(out_trace.groups()) == expected_groups
+
+    if return_inferencedata:
+        warning_stat = (
+            "sampler_1__warning" if isinstance(model_step, CompoundStep) else "sampler_0__warning"
+        )
+        if keep_warning_stat:
+            assert warning_stat in out_trace.sample_stats
+        else:
+            assert warning_stat not in out_trace.sample_stats
+
+    # Assert that all variables have non-empty samples (no NaNs)
+    if return_inferencedata:
+        assert all(
+            (not np.any(np.isnan(v))) and v.shape[:2] == (chains, draws)
+            for v in out_trace.posterior.data_vars.values()
+        )
+    else:
+        dimensions = {*model.coords, "a_dim_0", "a_dim_1", "chain", "draw"}
+        assert all(
+            (not np.any(np.isnan(v[:]))) and v.shape[:2] == (chains, draws)
+            for name, v in out_trace.posterior.arrays()
+            if name not in dimensions
+        )
+
+    # Assert that the trace has valid sampling state stored for each chain
+    for step_method_state in trace._sampling_state.sampling_state[:]:
+        # We have no access to the actual step method that was used by each chain in pymc.sample.
+        # The best way to see if the step method state is valid is by trying to set
+        # the model_step sampling state to the one stored in the trace.
+        model_step.sampling_state = step_method_state
+
+
+def test_sampling_consistency(
+    model,
+    model_step,
+    include_transformed,
+    draws_per_chunk,
+):
+    # Test that pm.sample will generate the same posterior and sampling state
+    # regardless of whether sampling was done in parallel or not.
+    store1 = zarr.TempStore()
+    parallel_trace = ZarrTrace(
+        store=store1, include_transformed=include_transformed, draws_per_chunk=draws_per_chunk
+    )
+    store2 = zarr.TempStore()
+    sequential_trace = ZarrTrace(
+        store=store2, include_transformed=include_transformed, draws_per_chunk=draws_per_chunk
+    )
+    tune = 2
+    draws = 3
+    chains = 2
+    random_seed = 12345
+    initial_step_state = model_step.sampling_state
+    with model:
+        parallel_idata = pm.sample(
+            draws=draws,
+            tune=tune,
+            chains=chains,
+            cores=chains,
+            trace=parallel_trace,
+            step=model_step,
+            discard_tuned_samples=True,
+            return_inferencedata=True,
+            keep_warning_stat=False,
+            idata_kwargs={"log_likelihood": False},
+            random_seed=random_seed,
+        )
+        model_step.sampling_state = initial_step_state
+        sequential_idata = pm.sample(
+            draws=draws,
+            tune=tune,
+            chains=chains,
+            cores=1,
+            trace=sequential_trace,
+            step=model_step,
+            discard_tuned_samples=True,
+            return_inferencedata=True,
+            keep_warning_stat=False,
+            idata_kwargs={"log_likelihood": False},
+            random_seed=random_seed,
+        )
+    for chain in range(chains):
+        assert equal_sampling_states(
+            parallel_trace._sampling_state.sampling_state[chain],
+            sequential_trace._sampling_state.sampling_state[chain],
+        )
+    xr.testing.assert_equal(parallel_idata.posterior, sequential_idata.posterior)
diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py
index 41b068e042..2330c043be 100644
--- a/tests/sampling/test_mcmc.py
+++ b/tests/sampling/test_mcmc.py
@@ -909,3 +909,49 @@ def test_sample(self, seeded_test):
         np.testing.assert_allclose(
             x_pred, pp_trace1.posterior_predictive["obs"].mean(("chain", "draw")), atol=1e-1
         )
+
+
+@pytest.fixture(scope="function", params=[None, "mcbackend", "zarr"])
+def trace_backend(request):
+    if request.param is None:
+        return None
+    elif request.param == "mcbackend":
+        try:
+            import mcbackend as mcb
+        except ImportError:
+            pytest.skip("Requires McBackend to be installed.")
+        return mcb.NumPyBackend()
+    elif request.param == "zarr":
+        try:
+            trace = pm.backends.zarr.ZarrTrace()
+        except RuntimeError:
+            pytest.skip("Requires zarr to be installed")
+        return trace
+
+
+@pytest.fixture(scope="function", params=["FAST_COMPILE", "NUMBA", "JAX"])
+def pytensor_mode(request):
+    return request.param
+
+
+def test_random_deterministics(trace_backend, pytensor_mode):
+    with pm.Model() as m:
+        x = pm.Bernoulli("x", p=0.5) * 0  # Force it to be zero
+        pm.Deterministic("y", x + pm.Normal.dist())
+
+        if pytensor_mode == "JAX":
+            expected_warning = (
+                "At the moment, it is not possible to set the random generator's key for "
+                "JAX linked functions. This means that the draws yielded by the random "
+                "variables that are requested by 'Deterministic' will not be reproducible."
+            )
+            with pytest.warns(UserWarning, match=expected_warning):
+                with pytensor.config.change_flags(mode=pytensor_mode):
+                    idata1 = pm.sample(tune=0, draws=1, random_seed=1, trace=trace_backend)
+                    idata2 = pm.sample(tune=0, draws=1, random_seed=1, trace=trace_backend)
+            assert not idata1.posterior.equals(idata2.posterior)
+        else:
+            with pytensor.config.change_flags(mode=pytensor_mode):
+                idata1 = pm.sample(tune=0, draws=1, random_seed=1, trace=trace_backend)
+                idata2 = pm.sample(tune=0, draws=1, random_seed=1, trace=trace_backend)
+            assert idata1.posterior.equals(idata2.posterior)