From 131df712fe8c04cc34c0d38763a8e8f5a2164120 Mon Sep 17 00:00:00 2001
From: Luciano Paz <luciano.paz.neuro@gmail.com>
Date: Wed, 25 Sep 2024 11:53:58 +0200
Subject: [PATCH 1/7] Fix dangling step in test_population
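
The parametrized population-sampler tests constructed a fresh
DEMetropolis() inline instead of passing the `step` instance that had
already been created for the parametrized stepper, leaving that step
dangling and unused. Pass `step` through so the parametrized stepper
is the one that actually samples.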

---
 tests/sampling/test_population.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/sampling/test_population.py b/tests/sampling/test_population.py
index 1f145dbcaf..4e3d91bcbb 100644
--- a/tests/sampling/test_population.py
+++ b/tests/sampling/test_population.py
@@ -65,7 +65,7 @@ def test_nonparallelized_chains_are_random(self):
                     cores=1,
                     draws=20,
                     tune=0,
-                    step=DEMetropolis(),
+                    step=step,
                     compute_convergence_checks=False,
                 )
                 samples = idata.posterior["x"].values[:, 5]
@@ -82,7 +82,7 @@ def test_parallelized_chains_are_random(self):
                     cores=4,
                     draws=20,
                     tune=0,
-                    step=DEMetropolis(),
+                    step=step,
                     compute_convergence_checks=False,
                 )
                 samples = idata.posterior["x"].values[:, 5]

From 1c54e45aec273857ad33f78d4528488b3cd8a3ea Mon Sep 17 00:00:00 2001
From: Luciano Paz <luciano.paz.neuro@gmail.com>
Date: Mon, 23 Sep 2024 11:48:03 +0200
Subject: [PATCH 2/7] Detach step methods from numpy global random state
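
Step methods now create and store their own numpy.random.Generator
instead of drawing from (and reseeding) the global numpy random state.
`pm.sample` spawns one independent Generator per chain and hands it to
the step methods through a `set_rng` method; potentials receive their
own child Generators as well.

A rough sketch of the user-facing seeding behavior (illustrative only,
not lifted verbatim from the test suite):

    import numpy as np
    import pymc as pm

    with pm.Model():
        x = pm.Normal("x")
        # ints and Generators are both accepted; one independent
        # Generator is spawned per chain
        idata = pm.sample(chains=4, random_seed=42)
        idata = pm.sample(chains=4, random_seed=np.random.default_rng(42))
        # Legacy RandomState objects now raise a TypeError, because
        # their seeding mechanism cannot spawn independent child streams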

---
 pymc/math.py                           |  4 +-
 pymc/sampling/mcmc.py                  | 71 +++++++++++++---------
 pymc/sampling/parallel.py              | 28 +++++----
 pymc/sampling/population.py            | 25 ++++----
 pymc/step_methods/arraystep.py         | 43 ++++++++++----
 pymc/step_methods/compound.py          | 10 +++-
 pymc/step_methods/hmc/base_hmc.py      | 20 +++++--
 pymc/step_methods/hmc/hmc.py           | 14 ++++-
 pymc/step_methods/hmc/nuts.py          | 18 ++++--
 pymc/step_methods/hmc/quadpotential.py | 80 +++++++++++++++++++------
 pymc/step_methods/metropolis.py        | 82 ++++++++++++++++----------
 pymc/step_methods/slicer.py            | 19 +++---
 pymc/util.py                           | 64 ++++++++++++++++++--
 tests/sampling/test_forward.py         | 22 +++----
 tests/sampling/test_parallel.py        |  4 +-
 tests/step_methods/hmc/test_nuts.py    |  9 ++-
 tests/step_methods/test_metropolis.py  | 27 +++++----
 tests/step_methods/test_slicer.py      |  4 ++
 18 files changed, 379 insertions(+), 165 deletions(-)

diff --git a/pymc/math.py b/pymc/math.py
index b85ffe63ce..b5fc50a8eb 100644
--- a/pymc/math.py
+++ b/pymc/math.py
@@ -292,10 +292,10 @@ def logdiffexp_numpy(a, b):
 invlogit = sigmoid
 
 
-def logbern(log_p):
+def logbern(log_p, rng=None):
     if np.isnan(log_p):
         raise FloatingPointError("log_p can't be nan.")
-    return np.log(np.random.uniform()) < log_p
+    return np.log((rng or np.random).uniform()) < log_p
 
 
 def logit(p):
diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py
index 32d2702ff2..02143ede05 100644
--- a/pymc/sampling/mcmc.py
+++ b/pymc/sampling/mcmc.py
@@ -71,6 +71,7 @@
     _get_seeds_per_chain,
     default_progress_theme,
     drop_warning_stat,
+    get_random_generator,
     get_untransformed_name,
     is_transformed_name,
 )
@@ -489,10 +490,15 @@ def sample(
     cores : int
         The number of chains to run in parallel. If ``None``, set to the number of CPUs in the
         system, but at most 4.
-    random_seed : int, array-like of int, RandomState or Generator, optional
-        Random seed(s) used by the sampling steps. If a list, tuple or array of ints
-        is passed, each entry will be used to seed each chain. A ValueError will be
-        raised if the length does not match the number of chains.
+    random_seed : int, array-like of int, or Generator, optional
+        Random seed(s) used by the sampling steps. Each step will create its own
+        :py:class:`~numpy.random.Generator` object to make its random draws in a way that is
+        independent of all other steppers and all other chains. If a list, tuple or array of ints
+        is passed, each entry will be used to seed the creation of ``Generator`` objects.
+        A ``ValueError`` will be raised if the length does not match the number of chains.
+        A ``TypeError`` will be raised if a :py:class:`~numpy.random.RandomState` object is passed.
+        We no longer support ``RandomState`` objects because their seeding mechanism does not allow
+        easy spawning of new independent random streams that are needed by the step methods.
     progressbar : bool, optional default=True
         Whether or not to display a progress bar in the command line. The bar shows the percentage
         of completion, the sampling speed in samples per second (SPS), and the estimated remaining
@@ -684,7 +690,8 @@ def joined_blas_limiter():
 
     if random_seed == -1:
         random_seed = None
-    random_seed_list = _get_seeds_per_chain(random_seed, chains)
+    rngs = get_random_generator(random_seed).spawn(chains)
+    random_seed_list = [rng.integers(2**30) for rng in rngs]
 
     if not discard_tuned_samples and not return_inferencedata:
         warnings.warn(
@@ -832,11 +839,7 @@ def joined_blas_limiter():
     if parallel:
-        # For parallel sampling we can pass the list of random seeds directly, as
-        # global seeding will only be called inside each process
-        sample_args["random_seed"] = random_seed_list
+        sample_args["rngs"] = rngs
     else:
-        # We pass None if the original random seed was None. The single core sampler
-        # methods will only set a global seed when it is not None.
-        sample_args["random_seed"] = random_seed if random_seed is None else random_seed_list
+        sample_args["rngs"] = rngs
 
     t_start = time.time()
     if parallel:
@@ -987,7 +994,7 @@ def _sample_many(
     chains: int,
     traces: Sequence[IBaseTrace],
     start: Sequence[PointType],
-    random_seed: Sequence[RandomSeed] | None,
+    rngs: Sequence[np.random.Generator],
     step: Step,
     callback: SamplingIteratorCallback | None = None,
     **kwargs,
@@ -1002,8 +1009,8 @@ def _sample_many(
         Total number of chains to sample.
     start: list
         Starting points for each chain
-    random_seed: list of random seeds, optional
-        A list of seeds, one for each chain
+    rngs: list of random Generators
+        A list of :py:class:`~numpy.random.Generator` objects, one for each chain
     step: function
         Step function
     """
@@ -1014,7 +1021,7 @@ def _sample_many(
             start=start[i],
             step=step,
             trace=traces[i],
-            random_seed=None if random_seed is None else random_seed[i],
+            rng=rngs[i],
             callback=callback,
             **kwargs,
         )
@@ -1025,7 +1032,7 @@ def _sample(
     *,
     chain: int,
     progressbar: bool,
-    random_seed: RandomSeed,
+    rng: np.random.Generator,
     start: PointType,
     draws: int,
     step: Step,
@@ -1073,7 +1080,7 @@ def _sample(
         chain=chain,
         tune=tune,
         model=model,
-        random_seed=random_seed,
+        rng=rng,
         callback=callback,
     )
     _pbar_data = {"chain": chain, "divergences": 0}
@@ -1112,8 +1119,8 @@ def _iter_sample(
     trace: IBaseTrace,
     chain: int = 0,
     tune: int = 0,
+    rng: np.random.Generator,
     model: Model | None = None,
-    random_seed: RandomSeed = None,
     callback: SamplingIteratorCallback | None = None,
 ) -> Iterator[bool]:
     """Generator for sampling one chain. (Used in singleprocess sampling.)
@@ -1147,8 +1154,7 @@ def _iter_sample(
     if draws < 1:
         raise ValueError("Argument `draws` must be greater than 0.")
 
-    if random_seed is not None:
-        np.random.seed(random_seed)
+    step.set_rng(rng)
 
     point = start
 
@@ -1191,7 +1197,7 @@ def _mp_sample(
     step,
     chains: int,
     cores: int,
-    random_seed: Sequence[RandomSeed],
+    rngs: Sequence[np.random.Generator],
     start: Sequence[PointType],
     progressbar: bool = True,
     progressbar_theme: Theme | None = default_progress_theme,
@@ -1216,8 +1222,8 @@ def _mp_sample(
         The number of chains to sample.
     cores : int
         The number of chains to run in parallel.
-    random_seed : list of random seeds
-        Random seeds for each chain.
+    rngs : list of random Generators
+        A list of :py:class:`~numpy.random.Generator` objects, one for each chain
     start : list
         Starting points for each chain.
         Dicts must contain numeric (transformed) initial values for all (transformed) free variables.
@@ -1245,7 +1251,7 @@ def _mp_sample(
         tune=tune,
         chains=chains,
         cores=cores,
-        seeds=random_seed,
+        rngs=rngs,
         start_points=start,
         step_method=step,
         progressbar=progressbar,
@@ -1444,12 +1450,12 @@ def init_nuts(
         mean = np.mean(apoints_data, axis=0)
         var = np.ones_like(mean)
         n = len(var)
-        potential = quadpotential.QuadPotentialDiagAdapt(n, mean, var, 10)
+        potential = quadpotential.QuadPotentialDiagAdapt(n, mean, var, 10, rng=random_seed_list[0])
     elif init == "jitter+adapt_diag":
         mean = np.mean(apoints_data, axis=0)
         var = np.ones_like(mean)
         n = len(var)
-        potential = quadpotential.QuadPotentialDiagAdapt(n, mean, var, 10)
+        potential = quadpotential.QuadPotentialDiagAdapt(n, mean, var, 10, rng=random_seed_list[0])
     elif init == "jitter+adapt_diag_grad":
         mean = np.mean(apoints_data, axis=0)
         var = np.ones_like(mean)
@@ -1466,6 +1472,7 @@ def init_nuts(
             alpha=0.02,
             use_grads=True,
             stop_adaptation=stop_adaptation,
+            rng=random_seed_list[0],
         )
     elif init == "advi+adapt_diag":
         approx = pm.fit(
@@ -1486,7 +1493,9 @@ def init_nuts(
         mean = approx.mean.get_value()
         weight = 50
         n = len(cov)
-        potential = quadpotential.QuadPotentialDiagAdapt(n, mean, cov, weight)
+        potential = quadpotential.QuadPotentialDiagAdapt(
+            n, mean, cov, weight, rng=random_seed_list[0]
+        )
     elif init == "advi":
         approx = pm.fit(
             random_seed=random_seed_list[0],
@@ -1502,7 +1511,7 @@ def init_nuts(
         )
         initial_points = [approx_sample[i] for i in range(chains)]
         cov = approx.std.eval() ** 2
-        potential = quadpotential.QuadPotentialDiag(cov)
+        potential = quadpotential.QuadPotentialDiag(cov, rng=random_seed_list[0])
     elif init == "advi_map":
         start = pm.find_MAP(include_transformed=True, seed=random_seed_list[0])
         approx = pm.MeanField(model=model, start=start)
@@ -1519,28 +1528,32 @@ def init_nuts(
         )
         initial_points = [approx_sample[i] for i in range(chains)]
         cov = approx.std.eval() ** 2
-        potential = quadpotential.QuadPotentialDiag(cov)
+        potential = quadpotential.QuadPotentialDiag(cov, rng=random_seed_list[0])
     elif init == "map":
         start = pm.find_MAP(include_transformed=True, seed=random_seed_list[0])
         cov = -pm.find_hessian(point=start, negate_output=False)
         initial_points = [start] * chains
-        potential = quadpotential.QuadPotentialFull(cov)
+        potential = quadpotential.QuadPotentialFull(cov, rng=random_seed_list[0])
     elif init == "adapt_full":
         mean = np.mean(apoints_data * chains, axis=0)
         initial_point = initial_points[0]
         initial_point_model_size = sum(initial_point[n.name].size for n in model.value_vars)
         cov = np.eye(initial_point_model_size)
-        potential = quadpotential.QuadPotentialFullAdapt(initial_point_model_size, mean, cov, 10)
+        potential = quadpotential.QuadPotentialFullAdapt(
+            initial_point_model_size, mean, cov, 10, rng=random_seed_list[0]
+        )
     elif init == "jitter+adapt_full":
         mean = np.mean(apoints_data, axis=0)
         initial_point = initial_points[0]
         initial_point_model_size = sum(initial_point[n.name].size for n in model.value_vars)
         cov = np.eye(initial_point_model_size)
-        potential = quadpotential.QuadPotentialFullAdapt(initial_point_model_size, mean, cov, 10)
+        potential = quadpotential.QuadPotentialFullAdapt(
+            initial_point_model_size, mean, cov, 10, rng=random_seed_list[0]
+        )
     else:
         raise ValueError(f"Unknown initializer: {init}.")
 
-    step = pm.NUTS(potential=potential, model=model, **kwargs)
+    step = pm.NUTS(potential=potential, model=model, rng=random_seed_list[0], **kwargs)
 
     # Filter deterministics from initial_points
     value_var_names = [var.name for var in model.value_vars]
diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py
index a34947c706..4b76e53a97 100644
--- a/pymc/sampling/parallel.py
+++ b/pymc/sampling/parallel.py
@@ -33,7 +33,7 @@
 
 from pymc.blocking import DictToArrayBijection
 from pymc.exceptions import SamplingError
-from pymc.util import CustomProgress, RandomSeed, default_progress_theme
+from pymc.util import CustomProgress, default_progress_theme
 
 logger = logging.getLogger(__name__)
 
@@ -93,15 +93,18 @@ def __init__(
         shared_point,
         draws: int,
         tune: int,
-        seed,
+        rng: np.random.Generator,
+        seed_seq: np.random.SeedSequence,
         blas_cores,
     ):
+        # The spawn multiprocessing context doesn't copy the rng's seed
+        # sequence, so we rebuild the Generator from the one passed explicitly
+        rng = np.random.Generator(type(rng.bit_generator)(seed_seq))
         self._msg_pipe = msg_pipe
         self._step_method = step_method
         self._step_method_is_pickled = step_method_is_pickled
         self._shared_point = shared_point
-        self._seed = seed
-        self._at_seed = seed + 1
+        self._rng = rng
         self._draws = draws
         self._tune = tune
         self._blas_cores = blas_cores
@@ -159,7 +162,7 @@ def _recv_msg(self):
         return self._msg_pipe.recv()
 
     def _start_loop(self):
-        np.random.seed(self._seed)
+        self._step_method.set_rng(self._rng)
 
         draw = 0
         tuning = True
@@ -210,7 +213,7 @@ def __init__(
         step_method,
         step_method_pickled,
         chain: int,
-        seed,
+        rng: np.random.Generator,
         start: dict[str, np.ndarray],
         blas_cores,
         mp_ctx,
@@ -260,7 +263,8 @@ def __init__(
                 self._shared_point,
                 draws,
                 tune,
-                seed,
+                rng,
+                rng.bit_generator.seed_seq,
                 blas_cores,
             ),
         )
@@ -379,7 +383,7 @@ def __init__(
         tune: int,
         chains: int,
         cores: int,
-        seeds: Sequence["RandomSeed"],
+        rngs: Sequence[np.random.Generator],
         start_points: Sequence[dict[str, np.ndarray]],
         step_method,
         progressbar: bool = True,
@@ -387,8 +391,8 @@ def __init__(
         blas_cores: int | None = None,
         mp_ctx=None,
     ):
-        if any(len(arg) != chains for arg in [seeds, start_points]):
-            raise ValueError(f"Number of seeds and start_points must be {chains}.")
+        if any(len(arg) != chains for arg in [rngs, start_points]):
+            raise ValueError(f"Number of rngs and start_points must be {chains}.")
 
         if mp_ctx is None or isinstance(mp_ctx, str):
             # Closes issue https://github.com/pymc-devs/pymc/issues/3849
@@ -416,12 +420,12 @@ def __init__(
                 step_method,
                 step_method_pickled,
                 chain,
-                seed,
+                rng,
                 start,
                 blas_cores,
                 mp_ctx,
             )
-            for chain, seed, start in zip(range(chains), seeds, start_points)
+            for chain, rng, start in zip(range(chains), rngs, start_points)
         ]
 
         self._inactive = self._samplers.copy()
diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py
index 4d5ced3f52..c0dc813b5c 100644
--- a/pymc/sampling/population.py
+++ b/pymc/sampling/population.py
@@ -37,7 +37,7 @@
     StatsType,
 )
 from pymc.step_methods.metropolis import DEMetropolis
-from pymc.util import CustomProgress, RandomSeed
+from pymc.util import CustomProgress
 
 __all__ = ()
 
@@ -53,7 +53,7 @@ def _sample_population(
     initial_points: Sequence[PointType],
     draws: int,
     start: Sequence[PointType],
-    random_seed: RandomSeed,
+    rngs: Sequence[np.random.Generator],
     step: BlockedStep | CompoundStep,
     tune: int,
     model: Model,
@@ -70,7 +70,8 @@ def _sample_population(
         The number of samples to draw
     start : list
         Start points for each chain
-    random_seed : single random seed, optional
+    rngs : sequence of random Generators
+        A list of :py:class:`~numpy.random.Generator` objects, one for each chain
     step : function
         Step function (should be or contain a population step method)
     tune : int
@@ -96,7 +97,7 @@ def _sample_population(
         traces=traces,
         tune=tune,
         model=model,
-        random_seed=random_seed,
+        rngs=rngs,
         progressbar=progressbar,
     )
 
@@ -248,8 +249,6 @@ def _run_secondary(c, stepper_dumps, secondary_end, task, progress):
         progress : progress.Progress
             The progress bar
         """
-        # re-seed each child process to make them unique
-        np.random.seed(None)
         try:
             stepper = cloudpickle.loads(stepper_dumps)
             # the stepper is not necessarily a PopulationArraySharedStep itself,
@@ -317,8 +316,8 @@ def _prepare_iter_population(
     parallelize: bool,
     traces: Sequence[BaseTrace],
     tune: int,
+    rngs: Sequence[np.random.Generator],
     model=None,
-    random_seed: RandomSeed = None,
     progressbar=True,
 ) -> Iterator[int]:
     """Prepare a PopulationStepper and traces for population sampling.
@@ -335,8 +334,9 @@ def _prepare_iter_population(
         Setting for multiprocess parallelization
     tune : int
         Number of iterations to tune.
+    rngs : sequence of random Generators
+        A list of :py:class:`~numpy.random.Generator` objects, one for each chain
     model : Model (optional if in ``with`` context)
-    random_seed : single random seed, optional
     progressbar : bool
         ``progressbar`` argument for the ``PopulationStepper``, (defaults to True)
 
@@ -352,9 +352,6 @@ def _prepare_iter_population(
     if draws < 1:
         raise ValueError("Argument `draws` should be above 0.")
 
-    if random_seed is not None:
-        np.random.seed(random_seed)
-
     # The initialization of traces, samplers and points must happen in the right order:
     # 1. population of points is created
     # 2. steppers are initialized and linked to the points object
@@ -366,13 +363,17 @@ def _prepare_iter_population(
 
     # 2. Set up the steppers
     steppers: list[Step] = []
-    for c in range(nchains):
+    assert (
+        len(rngs) == nchains
+    ), f"There must be one random Generator per chain. Got {len(rngs)} instead of {nchains}"
+    for c, rng in enumerate(rngs):
         # need independent samplers for each chain
         # it is important to copy the actual steppers (but not the delta_logp)
         if isinstance(step, CompoundStep):
             chainstep = CompoundStep([copy(m) for m in step.methods])
         else:
             chainstep = copy(step)
+        chainstep.set_rng(rng)
         # link population samplers to the shared population state
         for sm in chainstep.methods if isinstance(step, CompoundStep) else [chainstep]:
             if isinstance(sm, PopulationArrayStepShared):
diff --git a/pymc/step_methods/arraystep.py b/pymc/step_methods/arraystep.py
index ca6036ecc6..602dfd6e51 100644
--- a/pymc/step_methods/arraystep.py
+++ b/pymc/step_methods/arraystep.py
@@ -18,12 +18,10 @@
 
 import numpy as np
 
-from numpy.random import uniform
-
 from pymc.blocking import DictToArrayBijection, PointType, RaveledVars, StatsType
 from pymc.model import modelcontext
 from pymc.step_methods.compound import BlockedStep
-from pymc.util import get_var_name
+from pymc.util import RandomGenerator, get_random_generator, get_var_name
 
 __all__ = ["ArrayStep", "ArrayStepShared", "metrop_select"]
 
@@ -39,13 +37,18 @@ class ArrayStep(BlockedStep):
     fs: list of logp PyTensor functions
     allvars: Boolean (default False)
     blocked: Boolean (default True)
+    rng: RandomGenerator
+        An object that can be used to produce the step method's
+        :py:class:`~numpy.random.Generator` object. Refer to
+        :py:func:`pymc.util.get_random_generator` for more information.
     """
 
-    def __init__(self, vars, fs, allvars=False, blocked=True):
+    def __init__(self, vars, fs, allvars=False, blocked=True, rng: RandomGenerator = None):
         self.vars = vars
         self.fs = fs
         self.allvars = allvars
         self.blocked = blocked
+        self.rng = get_random_generator(rng)
 
     def step(self, point: PointType) -> tuple[PointType, StatsType]:
         partial_funcs_and_point: list[Callable | PointType] = [
@@ -79,17 +82,22 @@ class ArrayStepShared(BlockedStep):
     and unmapping overhead as well as moving fewer variables around.
     """
 
-    def __init__(self, vars, shared, blocked=True):
+    def __init__(self, vars, shared, blocked=True, rng: RandomGenerator = None):
         """
         Parameters
         ----------
         vars: list of sampling value variables
         shared: dict of PyTensor variable -> shared variable
         blocked: Boolean (default True)
+        rng: RandomGenerator
+            An object that can be used to produce the step method's
+            :py:class:`~numpy.random.Generator` object. Refer to
+            :py:func:`pymc.util.get_random_generator` for more information.
         """
         self.vars = vars
         self.shared = {get_var_name(var): shared for var, shared in shared.items()}
         self.blocked = blocked
+        self.rng = get_random_generator(rng)
 
     def step(self, point: PointType) -> tuple[PointType, StatsType]:
         for name, shared_var in self.shared.items():
@@ -120,13 +128,17 @@ class PopulationArrayStepShared(ArrayStepShared):
     Works by linking a list of Points that is updated as the chains are iterated.
     """
 
-    def __init__(self, vars, shared, blocked=True):
+    def __init__(self, vars, shared, blocked=True, rng: RandomGenerator = None):
         """
         Parameters
         ----------
         vars: list of sampling value variables
         shared: dict of PyTensor variable -> shared variable
         blocked: Boolean (default True)
+        rng: RandomGenerator
+            An object that can be used to produce the step method's
+            :py:class:`~numpy.random.Generator` object. Refer to
+            :py:func:`pymc.util.get_random_generator` for more information.
         """
         self.population = None
         self.this_chain = None
@@ -155,7 +167,14 @@ def link_population(self, population, chain_index):
 
 class GradientSharedStep(ArrayStepShared):
     def __init__(
-        self, vars, model=None, blocked=True, dtype=None, logp_dlogp_func=None, **pytensor_kwargs
+        self,
+        vars,
+        model=None,
+        blocked=True,
+        dtype=None,
+        logp_dlogp_func=None,
+        rng: RandomGenerator = None,
+        **pytensor_kwargs,
     ):
         model = modelcontext(model)
 
@@ -166,14 +185,16 @@ def __init__(
 
         self._logp_dlogp_func = func
 
-        super().__init__(vars, func._extra_vars_shared, blocked)
+        super().__init__(vars, func._extra_vars_shared, blocked, rng=rng)
 
     def step(self, point) -> tuple[PointType, StatsType]:
         self._logp_dlogp_func._extra_are_set = True
         return super().step(point)
 
 
-def metrop_select(mr: np.ndarray, q: np.ndarray, q0: np.ndarray) -> tuple[np.ndarray, bool]:
+def metrop_select(
+    mr: np.ndarray, q: np.ndarray, q0: np.ndarray, rng: np.random.Generator
+) -> tuple[np.ndarray, bool]:
     """Perform rejection/acceptance step for Metropolis class samplers.
 
     Returns the new sample q if a uniform random number is less than the
@@ -185,6 +206,8 @@ def metrop_select(mr: np.ndarray, q: np.ndarray, q0: np.ndarray) -> tuple[np.nda
     mr: float, Metropolis acceptance rate
     q: proposed sample
     q0: current sample
+    rng: numpy.random.Generator
+        A random number generator object
 
     Returns
     -------
@@ -193,7 +216,5 @@ def metrop_select(mr: np.ndarray, q: np.ndarray, q0: np.nda
     # Compare acceptance ratio to uniform random number
-    # TODO XXX: This `uniform` is not given a model-specific RNG state, which
-    # means that sampler runs that use it will not be reproducible.
-    if np.isfinite(mr) and np.log(uniform()) < mr:
+    if np.isfinite(mr) and np.log(rng.uniform()) < mr:
         return q, True
     else:
         return q0, False
diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py
index 7c0d8563ca..1c1d6fbb50 100644
--- a/pymc/step_methods/compound.py
+++ b/pymc/step_methods/compound.py
@@ -31,6 +31,7 @@
 
 from pymc.blocking import PointType, StatDtype, StatsDict, StatShape, StatsType
 from pymc.model import modelcontext
+from pymc.util import get_random_generator
 
 __all__ = ("Competence", "CompoundStep")
 
@@ -143,15 +144,18 @@ def __new__(cls, *args, **kwargs):
             # In this case we create a separate sampler for each var
             # and append them to a CompoundStep
             steps = []
-            for var in vars:
+            rngs = get_random_generator(kwargs.pop("rng", None)).spawn(len(vars))
+            for var, rng in zip(vars, rngs):
                 step = super().__new__(cls)
                 step.stats_dtypes = stats_dtypes
                 step.stats_dtypes_shapes = stats_dtypes_shapes
                 # If we don't return the instance we have to manually
                 # call __init__
-                step.__init__([var], *args, **kwargs)
+                _kwargs = kwargs.copy()
+                _kwargs["rng"] = rng
+                step.__init__([var], *args, **_kwargs)
                 # Hack for creating the class correctly when unpickling.
-                step.__newargs = ([var], *args), kwargs
+                step.__newargs = ([var], *args), _kwargs
                 steps.append(step)
 
             return CompoundStep(steps)
diff --git a/pymc/step_methods/hmc/base_hmc.py b/pymc/step_methods/hmc/base_hmc.py
index def6829d26..b320ed8194 100644
--- a/pymc/step_methods/hmc/base_hmc.py
+++ b/pymc/step_methods/hmc/base_hmc.py
@@ -29,6 +29,7 @@
 from pymc.stats.convergence import SamplerWarning, WarningType
 from pymc.step_methods import step_sizes
 from pymc.step_methods.arraystep import GradientSharedStep
+from pymc.step_methods.compound import StepMethodState
 from pymc.step_methods.hmc import integration
 from pymc.step_methods.hmc.integration import IntegrationError, State
 from pymc.step_methods.hmc.quadpotential import QuadPotentialDiagAdapt, quad_potential
@@ -75,6 +76,7 @@ def __init__(
         t0=10,
         adapt_step_size=True,
         step_rand=None,
+        rng=None,
         **pytensor_kwargs,
     ):
         """Set up Hamiltonian samplers with common structures.
@@ -98,6 +100,14 @@ def __init__(
         potential: Potential, optional
             An object that represents the Hamiltonian with methods `velocity`,
             `energy`, and `random` methods.
+        rng: RandomGenerator
+            An object that can be used to produce the step method's
+            :py:class:`~numpy.random.Generator` object. Refer to
+            :py:func:`pymc.util.get_random_generator` for more information. The
+            resulting ``Generator`` object will be stored in the step method
+            and used for accept/reject random selections. The step's ``Generator``
+            will also be used to spawn independent ``Generators`` that will be used
+            by the ``potential`` attribute.
         **pytensor_kwargs: passed to PyTensor functions
         """
         self._model = modelcontext(model)
@@ -106,7 +116,9 @@ def __init__(
             vars = self._model.continuous_value_vars
         else:
             vars = get_value_vars_from_user_vars(vars, self._model)
-        super().__init__(vars, blocked=blocked, model=self._model, dtype=dtype, **pytensor_kwargs)
+        super().__init__(
+            vars, blocked=blocked, model=self._model, dtype=dtype, rng=rng, **pytensor_kwargs
+        )
 
         self.adapt_step_size = adapt_step_size
         self.Emax = Emax
@@ -131,7 +143,7 @@ def __init__(
         if scaling is None and potential is None:
             mean = floatX(np.zeros(size))
             var = floatX(np.ones(size))
-            potential = QuadPotentialDiagAdapt(size, mean, var, 10)
+            potential = QuadPotentialDiagAdapt(size, mean, var, 10, rng=self.rng.spawn(1)[0])
 
         if isinstance(scaling, dict):
             point = Point(scaling, model=self._model)
@@ -143,7 +155,7 @@ def __init__(
         if potential is not None:
             self.potential = potential
         else:
-            self.potential = quad_potential(scaling, is_cov)
+            self.potential = quad_potential(scaling, is_cov, rng=self.rng.spawn(1)[0])
 
         self.integrator = integration.CpuLeapfrogIntegrator(self.potential, self._logp_dlogp_func)
 
@@ -193,7 +205,7 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
         self.step_size = step_size
 
         if self._step_rand is not None:
-            step_size = self._step_rand(step_size)
+            step_size = self._step_rand(step_size, rng=self.rng)
 
         hmc_step = self._hamiltonian_step(start, p0.data, step_size)
 
diff --git a/pymc/step_methods/hmc/hmc.py b/pymc/step_methods/hmc/hmc.py
index 3c43509883..106faee501 100644
--- a/pymc/step_methods/hmc/hmc.py
+++ b/pymc/step_methods/hmc/hmc.py
@@ -27,8 +27,8 @@
 __all__ = ["HamiltonianMC"]
 
 
-def unif(step_size, elow=0.85, ehigh=1.15):
-    return np.random.uniform(elow, ehigh) * step_size
+def unif(step_size, elow=0.85, ehigh=1.15, rng: np.random.Generator | None = None):
+    return (rng or np.random).uniform(elow, ehigh) * step_size
 
 
 class HamiltonianMC(BaseHMC):
@@ -113,6 +113,14 @@ def __init__(self, vars=None, path_length=2.0, max_steps=1024, **kwargs):
             The maximum number of leapfrog steps.
         model: pymc.Model
             The model
+        rng : RandomGenerator
+            An object that can be used to produce the step method's
+            :py:class:`~numpy.random.Generator` object. Refer to
+            :py:func:`pymc.util.get_random_generator` for more information. The
+            resulting ``Generator`` object will be stored in the step method
+            and used for accept/reject random selections. The step's ``Generator``
+            will also be used to spawn independent ``Generators`` that will be used
+            by the ``potential`` attribute.
         **kwargs: passed to BaseHMC
         """
         kwargs.setdefault("step_rand", unif)
@@ -151,7 +159,7 @@ def _hamiltonian_step(self, start, p0, step_size: float) -> HMCStepData:
 
         accept_stat = min(1, np.exp(-energy_change))
 
-        if div_info is not None or np.random.rand() >= accept_stat:
+        if div_info is not None or self.rng.random() >= accept_stat:
             end = start
             accepted = False
         else:
diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py
index 541303bdf3..3c4b4e6800 100644
--- a/pymc/step_methods/hmc/nuts.py
+++ b/pymc/step_methods/hmc/nuts.py
@@ -169,6 +169,14 @@ def __init__(self, vars=None, max_treedepth=10, early_max_treedepth=8, **kwargs)
             of the scaling matrix.
         model: pymc.Model
             The model
+        rng : RandomGenerator
+            An object that can be used to produce the step method's
+            :py:class:`~numpy.random.Generator` object. Refer to
+            :py:func:`pymc.util.get_random_generator` for more information. The
+            resulting ``Generator`` object will be stored in the step method
+            and used for accept/reject random selections. The step's ``Generator``
+            will also be used to spawn independent ``Generators`` that will be used
+            by the ``potential`` attribute.
         kwargs: passed to BaseHMC
 
         Notes
@@ -189,11 +197,11 @@ def _hamiltonian_step(self, start, p0, step_size):
         else:
             max_treedepth = self.max_treedepth
 
-        tree = _Tree(len(p0), self.integrator, start, step_size, self.Emax)
+        tree = _Tree(len(p0), self.integrator, start, step_size, self.Emax, rng=self.rng)
 
         reached_max_treedepth = False
         for _ in range(max_treedepth):
-            direction = logbern(np.log(0.5)) * 2 - 1
+            direction = logbern(np.log(0.5), rng=self.rng) * 2 - 1
             divergence_info, turning = tree.extend(direction)
 
             if divergence_info or turning:
@@ -233,6 +241,7 @@ def __init__(
         start: State,
         step_size: float,
         Emax: float,
+        rng: np.random.Generator,
     ):
         """Binary tree from the NUTS algorithm.
 
@@ -254,6 +263,7 @@ def __init__(
         self.step_size = step_size
         self.Emax = Emax
         self.start_energy = start.energy
+        self.rng = rng
 
         self.left = self.right = start
         self.proposal = Proposal(start.q.data, start.q_grad, start.energy, start.model_logp, 0)
@@ -302,7 +312,7 @@ def extend(self, direction):
             return diverging, turning
 
         size1, size2 = self.log_size, tree.log_size
-        if logbern(size2 - size1):
+        if logbern(size2 - size1, rng=self.rng):
             self.proposal = tree.proposal
 
         self.log_size = np.logaddexp(self.log_size, tree.log_size)
@@ -390,7 +400,7 @@ def _build_subtree(self, left, depth, epsilon):
                 turning = turning | turning1 | turning2
 
             log_size = np.logaddexp(tree1.log_size, tree2.log_size)
-            if logbern(tree2.log_size - log_size):
+            if logbern(tree2.log_size - log_size, rng=self.rng):
                 proposal = tree2.proposal
             else:
                 proposal = tree1.proposal
diff --git a/pymc/step_methods/hmc/quadpotential.py b/pymc/step_methods/hmc/quadpotential.py
index 4f975ff95c..abddaaf35f 100644
--- a/pymc/step_methods/hmc/quadpotential.py
+++ b/pymc/step_methods/hmc/quadpotential.py
@@ -22,7 +22,6 @@
 import pytensor
 import scipy.linalg
 
-from numpy.random import normal
 from scipy.sparse import issparse
 
 from pymc.pytensorf import floatX
@@ -38,7 +37,7 @@
 ]
 
 
-def quad_potential(C, is_cov):
+def quad_potential(C, is_cov, rng=None):
     """
     Compute a QuadPotential object from a scaling matrix.
 
@@ -49,6 +48,10 @@ def quad_potential(C, is_cov):
         vector treated as diagonal matrix.
     is_cov: Boolean
         whether C is provided as a covariance matrix or hessian
+    rng: RandomGenerator
+        An object that can be used to produce the potential's
+        :py:class:`~numpy.random.Generator` object. Refer to
+        :py:func:`pymc.util.get_random_generator` for more information.
 
     Returns
     -------
@@ -58,21 +61,21 @@ def quad_potential(C, is_cov):
         if not chol_available:
             raise ImportError("Sparse mass matrices require scikits.sparse")
         elif is_cov:
-            return QuadPotentialSparse(C)
+            return QuadPotentialSparse(C, rng=rng)
         else:
             raise ValueError("Sparse precision matrices are not supported")
 
     partial_check_positive_definite(C)
     if C.ndim == 1:
         if is_cov:
-            return QuadPotentialDiag(C)
+            return QuadPotentialDiag(C, rng=rng)
         else:
-            return QuadPotentialDiag(1.0 / C)
+            return QuadPotentialDiag(1.0 / C, rng=rng)
     else:
         if is_cov:
-            return QuadPotentialFull(C)
+            return QuadPotentialFull(C, rng=rng)
         else:
-            return QuadPotentialFullInv(C)
+            return QuadPotentialFullInv(C, rng=rng)
 
 
 def partial_check_positive_definite(C):
@@ -100,6 +103,9 @@ def __str__(self):
 class QuadPotential:
     dtype: np.dtype
 
+    def __init__(self, rng=None):
+        self.rng = np.random.default_rng(rng)
+
     @overload
     def velocity(self, x: np.ndarray, out: None) -> np.ndarray: ...
 
@@ -172,6 +178,7 @@ def __init__(
         discard_window=50,
         early_update=False,
         store_mass_matrix_trace=False,
+        rng=None,
     ):
         """Set up a diagonal mass matrix.
 
@@ -202,6 +209,8 @@ def __init__(
         store_mass_matrix_trace : bool
             If true, store the mass matrix at each step of the adaptation. Only for debugging
             purposes.
+        rng : Generator | int | None
+            Numpy random number generator
         """
         if initial_diag is not None and initial_diag.ndim != 1:
             raise ValueError("Initial diagonal must be one-dimensional.")
@@ -234,6 +243,8 @@ def __init__(
         self._store_mass_matrix_trace = store_mass_matrix_trace
         self._mass_trace = []
 
+        super().__init__(rng=rng)
+
         self.reset()
 
     def reset(self):
@@ -264,7 +275,7 @@ def velocity_energy(self, x, v_out):
 
     def random(self):
         """Draw random value from QuadPotential."""
-        vals = normal(size=self._n).astype(self.dtype)
+        vals = self.rng.normal(size=self._n).astype(self.dtype)
         return self._inv_stds * vals
 
     def _update_from_weightvar(self, weightvar):
@@ -405,7 +416,7 @@ def current_mean(self, out=None):
 
 
 class QuadPotentialDiagAdaptExp(QuadPotentialDiagAdapt):
-    def __init__(self, *args, alpha, use_grads=False, stop_adaptation=None, **kwargs):
+    def __init__(self, *args, alpha, use_grads=False, stop_adaptation=None, rng=None, **kwargs):
         """Set up a diagonal mass matrix.
 
         Parameters
@@ -430,11 +441,15 @@ def __init__(self, *args, alpha, use_grads=False, stop_adaptation=None, **kwargs
         store_mass_matrix_trace : bool
             If true, store the mass matrix at each step of the adaptation. Only for debugging
             purposes.
+        rng: RandomGenerator
+            An object that can be used to produce the potential's
+            :py:class:`~numpy.random.Generator` object. Refer to
+            :py:func:`pymc.util.get_random_generator` for more information.
         """
         if len(args) > 3:
             raise ValueError("Unsupported arguments to QuadPotentialDiagAdaptExp")
 
-        super().__init__(*args, **kwargs)
+        super().__init__(*args, rng=rng, **kwargs)
         self._alpha = alpha
         self._use_grads = use_grads
 
@@ -488,13 +503,19 @@ def _update_from_variances(self, var_estimator, inv_var_estimator):
 class QuadPotentialDiag(QuadPotential):
     """Quad potential using a diagonal covariance matrix."""
 
-    def __init__(self, v, dtype=None):
+    def __init__(self, v, dtype=None, rng=None):
         """Use a vector to represent a diagonal matrix for a covariance matrix.
 
         Parameters
         ----------
         v: vector, 0 <= ndim <= 1
            Diagonal of covariance matrix for the potential vector
+        dtype :
+            The dtype to assign to the resulting momentum
+        rng : RandomGenerator
+            An object that can be used to produce the potential's
+            :py:class:`~numpy.random.Generator` object. Refer to
+            :py:func:`pymc.util.get_random_generator` for more information.
         """
         if dtype is None:
             dtype = pytensor.config.floatX
@@ -505,6 +526,7 @@ def __init__(self, v, dtype=None):
         self.s = s
         self.inv_s = 1.0 / s
         self.v = v
+        self.rng = np.random.default_rng(rng)
 
     def velocity(self, x, out=None):
         """Compute the current velocity at a position in parameter space."""
@@ -515,7 +537,7 @@ def velocity(self, x, out=None):
 
     def random(self):
         """Draw random value from QuadPotential."""
-        return floatX(normal(size=self.s.shape)) * self.inv_s
+        return floatX(self.rng.normal(size=self.s.shape)) * self.inv_s
 
     def energy(self, x, velocity=None):
         """Compute kinetic energy at a position in parameter space."""
@@ -532,18 +554,25 @@ def velocity_energy(self, x, v_out):
 class QuadPotentialFullInv(QuadPotential):
     """QuadPotential object for Hamiltonian calculations using inverse of covariance matrix."""
 
-    def __init__(self, A, dtype=None):
+    def __init__(self, A, dtype=None, rng=None):
         """Compute the lower cholesky decomposition of the potential.
 
         Parameters
         ----------
         A: matrix, ndim = 2
            Inverse of covariance matrix for the potential vector
+        dtype :
+            The dtype to assign to the resulting momentum
+        rng : RandomGenerator
+            An object that can be used to produce the potential's
+            :py:class:`~numpy.random.Generator` object. Refer to
+            :py:func:`pymc.util.get_random_generator` for more information.
         """
         if dtype is None:
             dtype = pytensor.config.floatX
         self.dtype = dtype
         self.L = floatX(scipy.linalg.cholesky(A, lower=True))
+        self.rng = np.random.default_rng(rng)
 
     def velocity(self, x, out=None):
         """Compute the current velocity at a position in parameter space."""
@@ -554,7 +583,7 @@ def velocity(self, x, out=None):
 
     def random(self):
         """Draw random value from QuadPotential."""
-        n = floatX(normal(size=self.L.shape[0]))
+        n = floatX(self.rng.normal(size=self.L.shape[0]))
         return np.dot(self.L, n)
 
     def energy(self, x, velocity=None):
@@ -572,13 +601,19 @@ def velocity_energy(self, x, v_out):
 class QuadPotentialFull(QuadPotential):
     """Basic QuadPotential object for Hamiltonian calculations."""
 
-    def __init__(self, cov, dtype=None):
+    def __init__(self, cov, dtype=None, rng=None):
         """Compute the lower cholesky decomposition of the potential.
 
         Parameters
         ----------
         A: matrix, ndim = 2
             scaling matrix for the potential vector
+        dtype :
+            The dtype to assign to the resulting momentum
+        rng : RandomGenerator
+            An object that can be used to produce the potential's
+            :py:class:`~numpy.random.Generator` object. Refer to
+            :py:func:`pymc.util.get_random_generator` for more information.
         """
         if dtype is None:
             dtype = pytensor.config.floatX
@@ -586,6 +621,7 @@ def __init__(self, cov, dtype=None):
         self._cov = np.array(cov, dtype=self.dtype, copy=True)
         self._chol = scipy.linalg.cholesky(self._cov, lower=True)
         self._n = len(self._cov)
+        self.rng = np.random.default_rng(rng)
 
     def velocity(self, x, out=None):
         """Compute the current velocity at a position in parameter space."""
@@ -593,7 +629,7 @@ def velocity(self, x, out=None):
 
     def random(self):
         """Draw random value from QuadPotential."""
-        vals = np.random.normal(size=self._n).astype(self.dtype)
+        vals = self.rng.normal(size=self._n).astype(self.dtype)
         return scipy.linalg.solve_triangular(self._chol.T, vals, overwrite_b=True)
 
     def energy(self, x, velocity=None):
@@ -623,6 +659,7 @@ def __init__(
         adaptation_window_multiplier=2,
         update_window=1,
         dtype=None,
+        rng=None,
     ):
         warnings.warn("QuadPotentialFullAdapt is an experimental feature")
 
@@ -652,6 +689,8 @@ def __init__(
         self.adaptation_window_multiplier = float(adaptation_window_multiplier)
         self._update_window = int(update_window)
 
+        self.rng = np.random.default_rng(rng)
+
         self.reset()
 
     def reset(self):
@@ -772,18 +811,23 @@ def current_mean(self):
     import pytensor.sparse
 
     class QuadPotentialSparse(QuadPotential):
-        def __init__(self, A):
+        def __init__(self, A, rng=None):
             """Compute a sparse cholesky decomposition of the potential.
 
             Parameters
             ----------
             A: matrix, ndim = 2
                 scaling matrix for the potential vector
+            rng : RandomGenerator
+                An object that can be used to produce the potential's
+                :py:class:`~numpy.random.Generator` object. Refer to
+                :py:func:`pymc.util.get_random_generator` for more information.
             """
             self.A = A
             self.size = A.shape[0]
             self.factor = factor = cholmod.cholesky(A)
             self.d_sqrt = np.sqrt(factor.D())
+            self.rng = np.random.default_rng(rng)
 
         def velocity(self, x):
             """Compute the current velocity at a position in parameter space."""
@@ -792,7 +836,7 @@ def velocity(self, x):
 
         def random(self):
             """Draw random value from QuadPotential."""
-            n = floatX(normal(size=self.size))
+            n = floatX(self.rng.normal(size=self.size))
             n /= self.d_sqrt
             n = self.factor.solve_Lt(n)
             n = self.factor.apply_Pt(n)
diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py
index d752999ec1..aa5101dbb0 100644
--- a/pymc/step_methods/metropolis.py
+++ b/pymc/step_methods/metropolis.py
@@ -116,7 +116,6 @@ class Metropolis(ArrayStepShared):
 
     name = "metropolis"
 
-    default_blocked = False
     stats_dtypes_shapes = {
         "accept": (np.float64, []),
         "accepted": (np.float64, []),
@@ -134,6 +133,7 @@ def __init__(
         tune_interval=100,
         model=None,
         mode=None,
+        rng=None,
         **kwargs,
     ):
         """Create an instance of a Metropolis stepper
@@ -157,6 +157,10 @@ def __init__(
             Optional model for sampling step. Defaults to None (taken from context).
         mode: string or `Mode` instance.
             compilation mode passed to PyTensor functions
+        rng: RandomGenerator
+            An object that can be used to produce the step method's
+            :py:class:`~numpy.random.Generator` object. Refer to
+            :py:func:`pymc.util.get_random_generator` for more information.
         """
 
         model = pm.modelcontext(model)
@@ -223,7 +227,7 @@ def __init__(
 
         shared = pm.make_shared_replacements(initial_values, vars, model)
         self.delta_logp = delta_logp(initial_values, model.logp(), vars, shared)
-        super().__init__(vars, shared)
+        super().__init__(vars, shared, rng=rng)
 
     def reset_tuning(self):
         """Resets the tuned sampler parameters to their initial values."""
@@ -243,7 +247,7 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
             self.steps_until_tune = self.tune_interval
             self.accepted_sum[:] = 0
 
-        delta = self.proposal_dist() * self.scaling
+        delta = self.proposal_dist(rng=self.rng) * self.scaling
 
         if self.any_discrete:
             if self.all_discrete:
@@ -260,11 +264,11 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
             q0d = q0d.copy()
             q_temp = q0d.copy()
             # Shuffle order of updates (probably we don't need to do this in every step)
-            np.random.shuffle(self.enum_dims)
+            self.rng.shuffle(self.enum_dims)
             for i in self.enum_dims:
                 q_temp[i] = q[i]
                 accept_rate_i = self.delta_logp(q_temp, q0d)
-                q_temp_, accepted_i = metrop_select(accept_rate_i, q_temp, q0d)
+                q_temp_, accepted_i = metrop_select(accept_rate_i, q_temp, q0d, rng=self.rng)
                 q_temp[i] = q0d[i] = q_temp_[i]
                 self.accept_rate_iter[i] = accept_rate_i
                 self.accepted_iter[i] = accepted_i
@@ -272,7 +276,7 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
             q = q_temp
         else:
             accept_rate = self.delta_logp(q, q0d)
-            q, accepted = metrop_select(accept_rate, q, q0d)
+            q, accepted = metrop_select(accept_rate, q, q0d, rng=self.rng)
             self.accept_rate_iter = accept_rate
             self.accepted_iter = accepted
             self.accepted_sum += accepted
@@ -357,7 +361,10 @@ class BinaryMetropolis(ArrayStep):
         The frequency of tuning. Defaults to 100 iterations.
     model: PyMC Model
         Optional model for sampling step. Defaults to None (taken from context).
-
+    rng: RandomGenerator
+        An object that can be used to produce the step method's
+        :py:class:`~numpy.random.Generator` object. Refer to
+        :py:func:`pymc.util.get_random_generator` for more information.
     """
 
     name = "binary_metropolis"
@@ -393,7 +400,7 @@ def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]:
         # Convert adaptive_scale_factor to a jump probability
         p_jump = 1.0 - 0.5**self.scaling
 
-        rand_array = nr.random(q0.shape)
+        rand_array = self.rng.random(q0.shape)
         q = np.copy(q0)
         # Locations where switches occur, according to p_jump
         switch_locs = rand_array < p_jump
@@ -401,7 +408,7 @@ def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]:
         logp_q = logp(RaveledVars(q, point_map_info))
 
         accept = logp_q - logp_q0
-        q_new, accepted = metrop_select(accept, q, q0)
+        q_new, accepted = metrop_select(accept, q, q0, rng=self.rng)
         self.accepted += accepted
 
         stats = {
@@ -453,7 +460,10 @@ class BinaryGibbsMetropolis(ArrayStep):
         which resulting in more efficient antithetical sampling. Default is 0.8
     model: PyMC Model
         Optional model for sampling step. Defaults to None (taken from context).
-
+    rng: RandomGenerator
+        An object that can be used to produce the step method's
+        :py:class:`~numpy.random.Generator` object. Refer to
+        :py:func:`pymc.util.get_random_generator` for more information.
     """
 
     name = "binary_gibbs_metropolis"
@@ -498,7 +508,7 @@ def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]:
         logp: Callable[[RaveledVars], np.ndarray] = args[0]
         order = self.order
         if self.shuffle_dims:
-            nr.shuffle(order)
+            self.rng.shuffle(order)
 
         q = RaveledVars(np.copy(apoint.data), apoint.point_map_info)
 
@@ -507,10 +517,12 @@ def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]:
         for idx in order:
             # No need to do metropolis update if the same value is proposed,
             # as you will get the same value regardless of accepted or reject
-            if nr.rand() < self.transit_p:
+            if self.rng.random() < self.transit_p:
                 curr_val, q.data[idx] = q.data[idx], True - q.data[idx]
                 logp_prop = logp(q)
-                q.data[idx], accepted = metrop_select(logp_prop - logp_curr, q.data[idx], curr_val)
+                q.data[idx], accepted = metrop_select(
+                    logp_prop - logp_curr, q.data[idx], curr_val, rng=self.rng
+                )
                 if accepted:
                     logp_curr = logp_prop
 
@@ -561,7 +573,7 @@ class CategoricalGibbsMetropolis(ArrayStep):
         "tune": (bool, []),
     }
 
-    def __init__(self, vars, proposal="uniform", order="random", model=None):
+    def __init__(self, vars, proposal="uniform", order="random", model=None, rng=None):
         model = pm.modelcontext(model)
 
         vars = get_value_vars_from_user_vars(vars, model)
@@ -615,7 +627,7 @@ def __init__(self, vars, proposal="uniform", order="random", model=None):
         # that indicates whether a draw was done in a tuning phase.
         self.tune = True
 
-        super().__init__(vars, [model.compile_logp()])
+        super().__init__(vars, [model.compile_logp()], rng=rng)
 
     def reset_tuning(self):
         # There are no tuning parameters in this step method.
@@ -628,15 +640,17 @@ def astep_unif(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType
 
         dimcats = self.dimcats
         if self.shuffle_dims:
-            nr.shuffle(dimcats)
+            self.rng.shuffle(dimcats)
 
         q = RaveledVars(np.copy(q0), point_map_info)
         logp_curr = logp(q)
 
         for dim, k in dimcats:
-            curr_val, q.data[dim] = q.data[dim], sample_except(k, q.data[dim])
+            curr_val, q.data[dim] = q.data[dim], sample_except(k, q.data[dim], rng=self.rng)
             logp_prop = logp(q)
-            q.data[dim], accepted = metrop_select(logp_prop - logp_curr, q.data[dim], curr_val)
+            q.data[dim], accepted = metrop_select(
+                logp_prop - logp_curr, q.data[dim], curr_val, rng=self.rng
+            )
             if accepted:
                 logp_curr = logp_prop
 
@@ -652,7 +666,7 @@ def astep_prop(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType
 
         dimcats = self.dimcats
         if self.shuffle_dims:
-            nr.shuffle(dimcats)
+            self.rng.shuffle(dimcats)
 
         q = RaveledVars(np.copy(q0), point_map_info)
         logp_curr = logp(q)
@@ -677,9 +691,9 @@ def metropolis_proportional(self, q, logp, logp_curr, dim, k):
         probs = scipy.special.softmax(log_probs, axis=0)
         prob_curr, probs[given_cat] = probs[given_cat], 0.0
         probs /= 1.0 - prob_curr
-        proposed_cat = nr.choice(candidates, p=probs)
+        proposed_cat = self.rng.choice(candidates, p=probs)
         accept_ratio = (1.0 - prob_curr) / (1.0 - probs[proposed_cat])
-        if not np.isfinite(accept_ratio) or nr.uniform() >= accept_ratio:
+        if not np.isfinite(accept_ratio) or self.rng.uniform() >= accept_ratio:
             q.data[dim] = given_cat
             return logp_curr
         q.data[dim] = proposed_cat
@@ -739,6 +753,10 @@ class DEMetropolis(PopulationArrayStepShared):
         Optional model for sampling step. Defaults to None (taken from context).
     mode:  string or `Mode` instance.
         compilation mode passed to PyTensor functions
+    rng: RandomGenerator
+        An object that can be used to produce the step method's
+        :py:class:`~numpy.random.Generator` object. Refer to
+        :py:func:`pymc.util.get_random_generator` for more information.
 
     References
     ----------
@@ -821,7 +839,7 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
             self.steps_until_tune = self.tune_interval
             self.accepted = 0
 
-        epsilon = self.proposal_dist() * self.scaling
+        epsilon = self.proposal_dist(rng=self.rng) * self.scaling
 
         # differential evolution proposal
         # select two other chains
@@ -832,7 +850,7 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
         q = floatX(q0d + self.lamb * (r1.data - r2.data) + epsilon)
 
         accept = self.delta_logp(q, q0d)
-        q_new, accepted = metrop_select(accept, q, q0d)
+        q_new, accepted = metrop_select(accept, q, q0d, rng=self.rng)
         self.accepted += accepted
 
         self.steps_until_tune -= 1
@@ -883,6 +901,10 @@ class DEMetropolisZ(ArrayStepShared):
         Optional model for sampling step. Defaults to None (taken from context).
     mode:  string or `Mode` instance.
         compilation mode passed to PyTensor functions
+    rng: RandomGenerator
+        An object that can be used to produce the step method's
+        :py:class:`~numpy.random.Generator` object. Refer to
+        :py:func:`pymc.util.get_random_generator` for more information.
 
     References
     ----------
@@ -986,17 +1008,17 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
             self.steps_until_tune = self.tune_interval
             self.accepted = 0
 
-        epsilon = self.proposal_dist() * self.scaling
+        epsilon = self.proposal_dist(rng=self.rng) * self.scaling
 
         it = len(self._history)
         # use the DE-MCMC-Z proposal scheme as soon as the history has 2 entries
         if it > 1:
             # differential evolution proposal
             # select two other chains
-            iz1 = np.random.randint(it)
-            iz2 = np.random.randint(it)
+            iz1 = self.rng.integers(it)
+            iz2 = self.rng.integers(it)
             while iz2 == iz1:
-                iz2 = np.random.randint(it)
+                iz2 = self.rng.integers(it)
 
             z1 = self._history[iz1]
             z2 = self._history[iz2]
@@ -1007,7 +1029,7 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
             q = floatX(q0d + epsilon)
 
         accept = self.delta_logp(q, q0d)
-        q_new, accepted = metrop_select(accept, q, q0d)
+        q_new, accepted = metrop_select(accept, q, q0d, rng=self.rng)
         self.accepted += accepted
         self._history.append(q_new)
 
@@ -1039,8 +1061,8 @@ def competence(var, has_grad):
         return Competence.COMPATIBLE
 
 
-def sample_except(limit, excluded):
-    candidate = nr.choice(limit - 1)
+def sample_except(limit, excluded, rng: np.random.Generator):
+    candidate = rng.choice(limit - 1)
     if candidate >= excluded:
         candidate += 1
     return candidate
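
For illustration (not part of the patch): ``sample_except`` draws uniformly
from ``{0, ..., limit - 1}`` excluding ``excluded``, by drawing one of
``limit - 1`` candidates and shifting past the excluded index. A minimal
sketch, assuming it is imported from ``pymc.step_methods.metropolis``:

    import numpy as np

    from pymc.step_methods.metropolis import sample_except

    rng = np.random.default_rng(0)
    # draw many indices from {0, ..., 4} \ {2}
    draws = {sample_except(5, 2, rng) for _ in range(200)}
    assert 2 not in draws
    assert draws <= {0, 1, 3, 4}
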
diff --git a/pymc/step_methods/slicer.py b/pymc/step_methods/slicer.py
index 3a9d90800a..3e096aeb9f 100644
--- a/pymc/step_methods/slicer.py
+++ b/pymc/step_methods/slicer.py
@@ -16,7 +16,6 @@
 
 
 import numpy as np
-import numpy.random as nr
 
 from pymc.blocking import RaveledVars, StatsType
 from pymc.model import modelcontext
@@ -47,6 +46,10 @@ class Slice(ArrayStepShared):
         Optional model for sampling step. It will be taken from the context if not provided.
     iter_limit : int, default np.inf
         Maximum number of iterations for the slice sampler.
+    rng: RandomGenerator
+        An object that can be used to produce the step method's
+        :py:class:`~numpy.random.Generator` object. Refer to
+        :py:func:`pymc.util.get_random_generator` for more information.
 
     """
 
@@ -58,7 +61,9 @@ class Slice(ArrayStepShared):
         "nstep_in": (int, []),
     }
 
-    def __init__(self, vars=None, w=1.0, tune=True, model=None, iter_limit=np.inf, **kwargs):
+    def __init__(
+        self, vars=None, w=1.0, tune=True, model=None, iter_limit=np.inf, rng=None, **kwargs
+    ):
         model = modelcontext(model)
         self.w = np.asarray(w).copy()
         self.tune = tune
@@ -78,7 +83,7 @@ def __init__(self, vars=None, w=1.0, tune=True, model=None, iter_limit=np.inf, *
         self.logp = compile_pymc([raveled_inp], logp)
         self.logp.trust_input = True
 
-        super().__init__(vars, shared)
+        super().__init__(vars, shared, rng=rng)
 
     def astep(self, apoint: RaveledVars) -> tuple[RaveledVars, StatsType]:
         # The arguments are determined by the list passed via `super().__init__(..., fs, ...)`
@@ -96,10 +101,10 @@ def astep(self, apoint: RaveledVars) -> tuple[RaveledVars, StatsType]:
         logp = self.logp
         for i, wi in enumerate(self.w):
             # uniformly sample from 0 to p(q), but in log space
-            y = logp(q) - nr.standard_exponential()
+            y = logp(q) - self.rng.standard_exponential()
 
             # Create initial interval
-            ql[i] = q[i] - nr.uniform() * wi  # q[i] + r * w
+            ql[i] = q[i] - self.rng.uniform() * wi  # q[i] + r * w
             qr[i] = ql[i] + wi  # Equivalent to q[i] + (1-r) * w
 
             # Stepping out procedure
@@ -120,14 +125,14 @@ def astep(self, apoint: RaveledVars) -> tuple[RaveledVars, StatsType]:
             nstep_out += cnt
 
             cnt = 0
-            q[i] = nr.uniform(ql[i], qr[i])
+            q[i] = self.rng.uniform(ql[i], qr[i])
             while y > logp(q):  # Changed leq to lt, to accommodate for locally flat posteriors
                 # Sample uniformly from slice
                 if q[i] > q0_val[i]:
                     qr[i] = q[i]
                 elif q[i] < q0_val[i]:
                     ql[i] = q[i]
-                q[i] = nr.uniform(ql[i], qr[i])
+                q[i] = self.rng.uniform(ql[i], qr[i])
                 cnt += 1
                 if cnt > self.iter_limit:
                     raise RuntimeError(LOOP_ERR_MSG % self.iter_limit)
diff --git a/pymc/util.py b/pymc/util.py
index fe55813385..7733d41b60 100644
--- a/pymc/util.py
+++ b/pymc/util.py
@@ -16,6 +16,7 @@
 import warnings
 
 from collections.abc import Sequence
+from copy import deepcopy
 from typing import NewType, cast
 
 import arviz
@@ -399,6 +400,7 @@ def wrapped(**kwargs):
 
 RandomSeed = None | int | Sequence[int] | np.ndarray
 RandomState = RandomSeed | np.random.RandomState | np.random.Generator
+RandomGenerator = RandomSeed | np.random.Generator | np.random.BitGenerator
 
 
 def _get_seeds_per_chain(
@@ -431,10 +433,15 @@ def _get_unique_seeds_per_chain(integers_fn):
             seeds = [int(seed) for seed in integers_fn(2**30, dtype=np.int64, size=chains)]
         return seeds
 
-    if random_state is None or isinstance(random_state, int):
-        if chains == 1 and isinstance(random_state, int):
-            return (random_state,)
-        return _get_unique_seeds_per_chain(np.random.default_rng(random_state).integers)
+    try:
+        int_random_state = int(random_state)  # type: ignore
+    except Exception:
+        int_random_state = None
+
+    if random_state is None or int_random_state is not None:
+        if chains == 1 and int_random_state is not None:
+            return (int_random_state,)
+        return _get_unique_seeds_per_chain(np.random.default_rng(int_random_state).integers)
     if isinstance(random_state, np.random.Generator):
         return _get_unique_seeds_per_chain(random_state.integers)
     if isinstance(random_state, np.random.RandomState):
@@ -578,3 +585,52 @@ def update(
                 **fields,
             )
         return None
+
+
+def get_random_generator(
+    seed: RandomGenerator | np.random.RandomState = None, copy: bool = True
+) -> np.random.Generator:
+    """Build a :py:class:`~numpy.random.Generator` object from a suitable seed.
+
+    Parameters
+    ----------
+    seed : None | int | Sequence[int] | numpy.random.Generator | numpy.random.BitGenerator | numpy.random.RandomState
+        A suitable seed to use to generate the :py:class:`~numpy.random.Generator` object.
+        For more details on suitable seeds, refer to :py:func:`numpy.random.default_rng`.
+    copy : bool
+        Whether to copy the seed object before feeding it to
+        :py:func:`numpy.random.default_rng`. If ``copy`` is ``False`` and the seed
+        is a ``Generator`` or ``BitGenerator`` object, the returned ``Generator``
+        reuses it: a ``Generator`` seed is returned as-is, and a ``BitGenerator``
+        seed becomes the ``bit_generator`` attribute of a new ``Generator``. To
+        avoid sharing the underlying bit generator, leave ``copy`` set to ``True``
+        (the default).
+
+    Returns
+    -------
+    rng : numpy.random.Generator
+        The result of passing the input ``seed`` (or a copy of it) through
+        :py:func:`numpy.random.default_rng`.
+
+    Raises
+    ------
+    TypeError:
+        If the supplied ``seed`` is a :py:class:`~numpy.random.RandomState` object. We
+        do not support using these legacy objects because their seeding strategy is not
+        amenable to spawning new independent random streams.
+    """
+    if isinstance(seed, np.random.RandomState):
+        raise TypeError(
+            "Cannot create a random Generator from a RandomStream object. "
+            "Please provide a random seed, BitGenerator or Generator instead."
+        )
+    if copy:
+        # If seed is a numpy.random.Generator or numpy.random.BitGenerator,
+        # numpy.random.default_rng will reuse the exact same object: it returns
+        # seed itself in the former case, and a new Generator backed by the same
+        # BitGenerator in the latter. That would let one generator be shared
+        # across many users, so we deepcopy by default.
+        seed = deepcopy(seed)
+    return np.random.default_rng(seed)
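
For illustration (not part of the patch), a minimal sketch of the ``copy``
semantics implemented above, assuming the patched ``pymc.util``:

    import numpy as np

    from pymc.util import get_random_generator

    rng = np.random.default_rng(42)

    # copy=True (the default): the result is backed by a deep copy of `rng`,
    # so drawing from it does not advance `rng`
    copied = get_random_generator(rng)
    assert copied is not rng
    assert copied.random() == np.random.default_rng(42).random()

    # copy=False: numpy.random.default_rng returns the very same object
    shared = get_random_generator(rng, copy=False)
    assert shared is rng
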
diff --git a/tests/sampling/test_forward.py b/tests/sampling/test_forward.py
index 24579bae02..dd408e5f86 100644
--- a/tests/sampling/test_forward.py
+++ b/tests/sampling/test_forward.py
@@ -497,7 +497,7 @@ def test_normal_scalar(self):
             assert ppc["a"].shape == (nchains, ndraws)
 
             # test default case
-            random_state = np.random.RandomState(20160911)
+            random_state = np.random.default_rng(20160911)
             idata_ppc = pm.sample_posterior_predictive(
                 trace, var_names=["a"], random_seed=random_state
             )
@@ -623,9 +623,9 @@ def test_model_not_drawable_prior(self, seeded_test):
             assert samples["foo"].shape == (1, 40, 200)
 
     def test_model_shared_variable(self):
-        rng = np.random.RandomState(9832)
+        rng = np.random.default_rng(9832)
 
-        x = rng.randn(100)
+        x = rng.normal(size=100)
         y = x > 0
         x_shared = pytensor.shared(x)
         y_shared = pytensor.shared(y)
@@ -656,10 +656,10 @@ def test_model_shared_variable(self):
         npt.assert_allclose(post_pred["p"], expected_p)
 
     def test_deterministic_of_observed(self):
-        rng = np.random.RandomState(8442)
+        rng = np.random.default_rng(8442)
 
-        meas_in_1 = pm.pytensorf.floatX(2 + 4 * rng.randn(10))
-        meas_in_2 = pm.pytensorf.floatX(5 + 4 * rng.randn(10))
+        meas_in_1 = pm.pytensorf.floatX(2 + 4 * rng.normal(size=10))
+        meas_in_2 = pm.pytensorf.floatX(5 + 4 * rng.normal(size=10))
         nchains = 2
         with pm.Model() as model:
             mu_in_1 = pm.Normal("mu_in_1", 0, 2)
@@ -696,10 +696,10 @@ def test_deterministic_of_observed(self):
             npt.assert_allclose(ppc["in_1"] + ppc["in_2"], ppc["out"], rtol=rtol)
 
     def test_deterministic_of_observed_modified_interface(self):
-        rng = np.random.RandomState(4982)
+        rng = np.random.default_rng(4982)
 
-        meas_in_1 = pm.pytensorf.floatX(2 + 4 * rng.randn(100))
-        meas_in_2 = pm.pytensorf.floatX(5 + 4 * rng.randn(100))
+        meas_in_1 = pm.pytensorf.floatX(2 + 4 * rng.normal(size=100))
+        meas_in_2 = pm.pytensorf.floatX(5 + 4 * rng.normal(size=100))
         with pm.Model() as model:
             mu_in_1 = pm.Normal("mu_in_1", 0, 1, initval=0)
             sigma_in_1 = pm.HalfNormal("sd_in_1", 1, initval=1)
@@ -1408,7 +1408,7 @@ def test_distinct_rvs():
         Y_rv = pm.Normal("y")
 
         pp_samples = pm.sample_prior_predictive(
-            draws=2, return_inferencedata=False, random_seed=npr.RandomState(2023532)
+            draws=2, return_inferencedata=False, random_seed=npr.default_rng(2023532)
         )
 
     assert X_rv.owner.inputs[0] != Y_rv.owner.inputs[0]
@@ -1418,7 +1418,7 @@ def test_distinct_rvs():
         Y_rv = pm.Normal("y")
 
         pp_samples_2 = pm.sample_prior_predictive(
-            draws=2, return_inferencedata=False, random_seed=npr.RandomState(2023532)
+            draws=2, return_inferencedata=False, random_seed=npr.default_rng(2023532)
         )
 
     assert np.array_equal(pp_samples["y"], pp_samples_2["y"])
diff --git a/tests/sampling/test_parallel.py b/tests/sampling/test_parallel.py
index c69c75fabc..8c71bcac00 100644
--- a/tests/sampling/test_parallel.py
+++ b/tests/sampling/test_parallel.py
@@ -157,7 +157,7 @@ def test_explicit_sample(mp_start_method):
         10,
         step,
         chain=3,
-        seed=1,
+        rng=np.random.default_rng(1),
         mp_ctx=ctx,
         start={"a": floatX(np.array([1.0])), "b_log__": floatX(np.array(2.0))},
         step_method_pickled=step_method_pickled,
@@ -190,7 +190,7 @@ def test_iterator():
         tune=10,
         chains=3,
         cores=2,
-        seeds=[2, 3, 4],
+        rngs=np.random.default_rng(1).spawn(3),
         start_points=[start] * 3,
         step_method=step,
         progressbar=False,
diff --git a/tests/step_methods/hmc/test_nuts.py b/tests/step_methods/hmc/test_nuts.py
index 1bec2d2f46..2bb71b893e 100644
--- a/tests/step_methods/hmc/test_nuts.py
+++ b/tests/step_methods/hmc/test_nuts.py
@@ -36,14 +36,15 @@ class TestNUTSUniform(sf.NutsFixture, sf.UniformFixture):
     min_n_eff = 9000
     rtol = 0.1
     atol = 0.05
+    step_args = {"random_seed": 202010}
 
 
 class TestNUTSUniform2(TestNUTSUniform):
-    step_args = {"target_accept": 0.95}
+    step_args = {"target_accept": 0.95, "random_seed": 202010}
 
 
 class TestNUTSUniform3(TestNUTSUniform):
-    step_args = {"target_accept": 0.80}
+    step_args = {"target_accept": 0.80, "random_seed": 202010}
 
 
 class TestNUTSNormal(sf.NutsFixture, sf.NormalFixture):
@@ -54,6 +55,7 @@ class TestNUTSNormal(sf.NutsFixture, sf.NormalFixture):
     min_n_eff = 10000
     rtol = 0.1
     atol = 0.05
+    step_args = {"random_seed": 123456}
 
 
 class TestNUTSBetaBinomial(sf.NutsFixture, sf.BetaBinomialFixture):
@@ -63,6 +65,7 @@ class TestNUTSBetaBinomial(sf.NutsFixture, sf.BetaBinomialFixture):
     burn = 0
     chains = 2
     min_n_eff = 400
+    step_args = {"random_seed": 202010}
 
 
 class TestNUTSStudentT(sf.NutsFixture, sf.StudentTFixture):
@@ -73,6 +76,7 @@ class TestNUTSStudentT(sf.NutsFixture, sf.StudentTFixture):
     min_n_eff = 1000
     rtol = 0.1
     atol = 0.05
+    step_args = {"random_seed": 202010}
 
 
 @pytest.mark.skip("Takes too long to run")
@@ -92,6 +96,7 @@ class TestNUTSLKJCholeskyCov(sf.NutsFixture, sf.LKJCholeskyCovFixture):
     burn = 0
     chains = 2
     min_n_eff = 200
+    step_args = {"random_seed": 202010}
 
 
 class TestNutsCheckTrace:
diff --git a/tests/step_methods/test_metropolis.py b/tests/step_methods/test_metropolis.py
index 7bfdb645c7..f414a534e8 100644
--- a/tests/step_methods/test_metropolis.py
+++ b/tests/step_methods/test_metropolis.py
@@ -36,6 +36,8 @@
 from tests.helpers import RVsAssignmentStepsTester, StepMethodTester
 from tests.models import mv_simple, mv_simple_discrete, simple_categorical
 
+SEED = sum(ord(c) for c in "test_metropolis")
+
 
 class TestMetropolisUniform(sf.MetropolisFixture, sf.UniformFixture):
     n_samples = 50000
@@ -45,6 +47,7 @@ class TestMetropolisUniform(sf.MetropolisFixture, sf.UniformFixture):
     min_n_eff = 10000
     rtol = 0.1
     atol = 0.05
+    step_args = {"rng": np.random.default_rng(SEED)}
 
 
 class TestMetropolis:
@@ -81,7 +84,7 @@ def test_tuning_reset(self):
             idata = pm.sample(
                 tune=600,
                 draws=500,
-                step=Metropolis(tune=True, scaling=0.1),
+                step=Metropolis(tune=True, scaling=0.1, rng=SEED),
                 cores=1,
                 chains=3,
                 discard_tuned_samples=False,
@@ -113,7 +116,7 @@ def test_tuning_reset(self):
     def test_elemwise_update(self, batched_dist):
         with pm.Model() as m:
             m.register_rv(batched_dist, name="batched_dist")
-            step = pm.Metropolis([batched_dist])
+            step = pm.Metropolis([batched_dist], rng=SEED)
             assert step.elemwise_update == (batched_dist.ndim > 0)
             trace = pm.sample(draws=1000, chains=2, step=step, random_seed=428)
 
@@ -124,7 +127,7 @@ def test_elemwise_update_different_scales(self):
         mu = [1, 2, 3, 4, 5, 100, 1_000, 10_000]
         with pm.Model() as m:
             x = pm.Poisson("x", mu=mu)
-            step = pm.Metropolis([x])
+            step = pm.Metropolis([x], rng=SEED)
             trace = pm.sample(draws=1000, chains=2, step=step, random_seed=128).posterior
 
         np.testing.assert_allclose(trace["x"].mean(("draw", "chain")), mu, rtol=0.1)
@@ -134,7 +137,7 @@ def test_multinomial_no_elemwise_update(self):
         with pm.Model() as m:
             batched_dist = pm.Multinomial("batched_dist", n=5, p=np.ones(4) / 4, shape=(10, 4))
             with pytensor.config.change_flags(mode=fast_unstable_sampling_mode):
-                step = pm.Metropolis([batched_dist])
+                step = pm.Metropolis([batched_dist], rng=SEED)
                 assert not step.elemwise_update
 
 
@@ -167,7 +170,7 @@ def test_tuning_lambda_sequential(self):
             idata = pm.sample(
                 tune=1000,
                 draws=500,
-                step=DEMetropolisZ(tune="lambda", lamb=0.92),
+                step=DEMetropolisZ(tune="lambda", lamb=0.92, rng=SEED),
                 cores=1,
                 chains=3,
                 discard_tuned_samples=False,
@@ -185,7 +188,7 @@ def test_tuning_epsilon_parallel(self):
             idata = pm.sample(
                 tune=1000,
                 draws=500,
-                step=DEMetropolisZ(tune="scaling", scaling=0.002),
+                step=DEMetropolisZ(tune="scaling", scaling=0.002, rng=SEED),
                 cores=2,
                 chains=2,
                 discard_tuned_samples=False,
@@ -203,7 +206,7 @@ def test_tuning_none(self):
             idata = pm.sample(
                 tune=1000,
                 draws=500,
-                step=DEMetropolisZ(tune=None),
+                step=DEMetropolisZ(tune=None, rng=SEED),
                 cores=1,
                 chains=2,
                 discard_tuned_samples=False,
@@ -221,7 +224,7 @@ def test_tuning_reset(self):
             idata = pm.sample(
                 tune=1000,
                 draws=500,
-                step=DEMetropolisZ(tune="scaling", scaling=0.002),
+                step=DEMetropolisZ(tune="scaling", scaling=0.002, rng=SEED),
                 cores=1,
                 chains=3,
                 discard_tuned_samples=False,
@@ -245,7 +248,7 @@ def test_tune_drop_fraction(self):
         draws = 200
         with pm.Model() as pmodel:
             pm.Normal("n", 0, 2, size=(3,))
-            step = DEMetropolisZ(tune_drop_fraction=tune_drop_fraction)
+            step = DEMetropolisZ(tune_drop_fraction=tune_drop_fraction, rng=SEED)
             idata = pm.sample(
                 tune=tune, draws=draws, step=step, cores=1, chains=1, discard_tuned_samples=False
             )
@@ -292,7 +295,7 @@ def test_step_discrete(self):
         unc = np.diag(C) ** 0.5
         check = (("x", np.mean, mu, unc / 10.0), ("x", np.std, unc, unc / 10.0))
         with model:
-            step = Metropolis(S=C, proposal_dist=MultivariateNormalProposal)
+            step = Metropolis(S=C, proposal_dist=MultivariateNormalProposal, rng=123456)
             idata = pm.sample(
                 tune=1000,
                 draws=2000,
@@ -311,7 +314,7 @@ def test_step_categorical(self, proposal):
         unc = C**0.5
         check = (("x", np.mean, mu, unc / 10.0), ("x", np.std, unc, unc / 10.0))
         with model:
-            step = CategoricalGibbsMetropolis([model.x], proposal=proposal)
+            step = CategoricalGibbsMetropolis([model.x], proposal=proposal, rng=SEED)
             idata = pm.sample(
                 tune=1000,
                 draws=2000,
@@ -329,7 +332,7 @@ def test_step_categorical(self, proposal):
         [
             (
                 lambda C, _: Metropolis(
-                    S=C, proposal_dist=MultivariateNormalProposal, blocked=True
+                    S=C, proposal_dist=MultivariateNormalProposal, blocked=True, rng=SEED
                 ),
                 4000,
             ),
diff --git a/tests/step_methods/test_slicer.py b/tests/step_methods/test_slicer.py
index 80435573c0..899d4ec9ec 100644
--- a/tests/step_methods/test_slicer.py
+++ b/tests/step_methods/test_slicer.py
@@ -12,12 +12,15 @@
 #   See the License for the specific language governing permissions and
 #   limitations under the License.
 
+import numpy as np
 import pytest
 
 from pymc.step_methods.slicer import Slice
 from tests import sampler_fixtures as sf
 from tests.helpers import RVsAssignmentStepsTester, StepMethodTester
 
+SEED = 20240920
+
 
 class TestSliceUniform(sf.SliceFixture, sf.UniformFixture):
     n_samples = 10000
@@ -27,6 +30,7 @@ class TestSliceUniform(sf.SliceFixture, sf.UniformFixture):
     min_n_eff = 5000
     rtol = 0.1
     atol = 0.05
+    step_args = {"rng": np.random.default_rng(SEED)}
 
 
 class TestStepSlicer(StepMethodTester):

From c399241d73e5f7eed193faebd24c6f3e0b979f54 Mon Sep 17 00:00:00 2001
From: Luciano Paz <luciano.paz.neuro@gmail.com>
Date: Thu, 19 Sep 2024 09:53:06 +0200
Subject: [PATCH 3/7] Add sampling state base classes

---
 .github/workflows/tests.yml      |   3 +-
 pymc/step_methods/state.py       |  99 +++++++++++++++++++
 tests/helpers.py                 |  16 ++++
 tests/step_methods/test_state.py | 158 +++++++++++++++++++++++++++++++
 4 files changed, 275 insertions(+), 1 deletion(-)
 create mode 100644 pymc/step_methods/state.py
 create mode 100644 tests/step_methods/test_state.py

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 9ac9eff143..0956f17b60 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -103,6 +103,7 @@ jobs:
             tests/ode/test_ode.py
             tests/ode/test_utils.py
             tests/step_methods/hmc/test_quadpotential.py
+            tests/step_methods/test_state.py
 
           - |
             tests/backends/test_mcbackend.py
@@ -197,7 +198,7 @@ jobs:
           - tests/variational/test_approximations.py tests/variational/test_callbacks.py tests/variational/test_inference.py tests/variational/test_opvi.py tests/test_initial_point.py
           - tests/model/test_core.py tests/sampling/test_mcmc.py
           - tests/gp/test_cov.py tests/gp/test_gp.py tests/gp/test_mean.py tests/gp/test_util.py tests/ode/test_ode.py tests/ode/test_utils.py tests/smc/test_smc.py tests/sampling/test_parallel.py
-          - tests/step_methods/test_metropolis.py tests/step_methods/test_slicer.py tests/step_methods/hmc/test_nuts.py tests/step_methods/test_compound.py tests/step_methods/hmc/test_hmc.py
+          - tests/step_methods/test_metropolis.py tests/step_methods/test_slicer.py tests/step_methods/hmc/test_nuts.py tests/step_methods/test_compound.py tests/step_methods/hmc/test_hmc.py tests/step_methods/test_state.py
 
       fail-fast: false
     runs-on: ${{ matrix.os }}
diff --git a/pymc/step_methods/state.py b/pymc/step_methods/state.py
new file mode 100644
index 0000000000..9b85d7784b
--- /dev/null
+++ b/pymc/step_methods/state.py
@@ -0,0 +1,99 @@
+#   Copyright 2024 The PyMC Developers
+#
+#   Licensed under the Apache License, Version 2.0 (the "License");
+#   you may not use this file except in compliance with the License.
+#   You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+#   Unless required by applicable law or agreed to in writing, software
+#   distributed under the License is distributed on an "AS IS" BASIS,
+#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#   See the License for the specific language governing permissions and
+#   limitations under the License.
+from copy import deepcopy
+from dataclasses import Field, dataclass, fields
+from typing import Any, ClassVar
+
+import numpy as np
+
+dataclass_state = dataclass(kw_only=True)
+
+
+@dataclass_state
+class DataClassState:
+    __dataclass_fields__: ClassVar[dict[str, Field[Any]]] = {}
+
+
+def equal_dataclass_values(v1, v2):
+    if v1.__class__ != v2.__class__:
+        return False
+    if isinstance(v1, (list, tuple)):  # noqa: UP038
+        return len(v1) == len(v2) and all(
+            equal_dataclass_values(v1i, v2i) for v1i, v2i in zip(v1, v2, strict=True)
+        )
+    elif isinstance(v1, dict):
+        if set(v1) != set(v2):
+            return False
+        return all(equal_dataclass_values(v1[k], v2[k]) for k in v1)
+    elif isinstance(v1, np.ndarray):
+        return bool(np.array_equal(v1, v2, equal_nan=True))
+    elif isinstance(v1, np.random.Generator):
+        return equal_dataclass_values(v1.bit_generator.state, v2.bit_generator.state)
+    elif isinstance(v1, DataClassState):
+        return set(fields(v1)) == set(fields(v2)) and all(
+            equal_dataclass_values(getattr(v1, f1.name), getattr(v2, f2.name))
+            for f1, f2 in zip(fields(v1), fields(v2), strict=True)
+        )
+    else:
+        return v1 == v2
+
+
+class WithSamplingState:
+    """Mixin class that adds the ``sampling_state`` property to an object.
+
+    The object's type must define ``_state_class`` as a valid
+    :py:class:`~pymc.step_methods.state.DataClassState` subclass. The object's
+    ``sampling_state`` property can then be read or set using instances of
+    that ``_state_class`` type.
+    """
+
+    _state_class: type[DataClassState] = DataClassState
+
+    @property
+    def sampling_state(self) -> DataClassState:
+        state_class = self._state_class
+        kwargs = {}
+        for field in fields(state_class):
+            val = getattr(self, field.name)
+            if isinstance(val, WithSamplingState):
+                _val = val.sampling_state
+            else:
+                _val = val
+            kwargs[field.name] = deepcopy(_val)
+        return state_class(**kwargs)
+
+    @sampling_state.setter
+    def sampling_state(self, state: DataClassState):
+        state_class = self._state_class
+        assert isinstance(
+            state, state_class
+        ), f"Encountered invalid state class '{state.__class__}'. State must be '{state_class}'"
+        for field in fields(state_class):
+            state_val = deepcopy(getattr(state, field.name))
+            self_val = getattr(self, field.name)
+            is_frozen = field.metadata.get("frozen", False)
+            if is_frozen:
+                if not equal_dataclass_values(state_val, self_val):
+                    raise ValueError(
+                        "The received sampling state must have the same values for the "
+                        f"frozen fields. Field {field.name!r} has different values. "
+                        f"Expected {self_val} but got {state_val}"
+                    )
+            else:
+                if isinstance(state_val, DataClassState):
+                    assert isinstance(self_val, WithSamplingState)
+                    self_val.sampling_state = state_val
+                    setattr(self, field.name, self_val)
+                else:
+                    setattr(self, field.name, state_val)
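
For illustration (not part of the patch), a sketch of how
``equal_dataclass_values`` recurses through containers and compares numpy
arrays NaN-aware:

    import numpy as np

    from pymc.step_methods.state import equal_dataclass_values

    a = {"x": np.array([1.0, np.nan]), "y": [1, 2]}
    b = {"x": np.array([1.0, np.nan]), "y": [1, 2]}
    # arrays are compared with np.array_equal(..., equal_nan=True)
    assert equal_dataclass_values(a, b)

    # values of different classes are never equal, even if their items match
    assert not equal_dataclass_values([1, 2], (1, 2))
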
diff --git a/tests/helpers.py b/tests/helpers.py
index c0f210bf8c..c14433711b 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -17,6 +17,8 @@
 import tempfile
 import warnings
 
+<<<<<<< HEAD
+=======
+from copy import deepcopy
+from dataclasses import fields
+>>>>>>> 741b38626 (Fixup state)
 from logging.handlers import BufferingHandler
 
 import numpy as np
@@ -28,6 +30,7 @@
 
 import pymc as pm
 
+from pymc.step_methods.state import equal_dataclass_values
 from pymc.testing import fast_unstable_sampling_mode
 from tests.models import mv_simple, mv_simple_coarse
 
@@ -177,3 +180,16 @@ def continuous_steps(self, step, step_kwargs):
             assert {m.rvs_to_values[c1], m.rvs_to_values[c2]} == set(
                 step([c1, c2], **step_kwargs).vars
             )
+
+
+def equal_sampling_states(this, other):
+    if this.__class__ != other.__class__:
+        return False
+    this_fields = set([f.name for f in fields(this)])
+    other_fields = set([f.name for f in fields(other)])
+    for field in this_fields:
+        this_val = getattr(this, field)
+        other_val = getattr(other, field)
+        if not equal_dataclass_values(this_val, other_val):
+            return False
+    return this_fields == other_fields
diff --git a/tests/step_methods/test_state.py b/tests/step_methods/test_state.py
new file mode 100644
index 0000000000..e6a39264db
--- /dev/null
+++ b/tests/step_methods/test_state.py
@@ -0,0 +1,158 @@
+#   Copyright 2024 The PyMC Developers
+#
+#   Licensed under the Apache License, Version 2.0 (the "License");
+#   you may not use this file except in compliance with the License.
+#   You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+#   Unless required by applicable law or agreed to in writing, software
+#   distributed under the License is distributed on an "AS IS" BASIS,
+#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#   See the License for the specific language governing permissions and
+#   limitations under the License.
+from dataclasses import field
+
+import numpy as np
+import pytest
+
+from pymc.step_methods.state import DataClassState, WithSamplingState, dataclass_state
+from tests.helpers import equal_sampling_states
+
+
+@dataclass_state
+class State1(DataClassState):
+    a: int
+    b: float
+    c: str
+    d: np.ndarray
+    e: list
+    f: dict
+
+
+@dataclass_state
+class State2(DataClassState):
+    mutable_field: float
+    state1: State1
+    extra_info1: np.ndarray = field(metadata={"frozen": True})
+    extra_info2: list = field(metadata={"frozen": True})
+    extra_info3: dict = field(metadata={"frozen": True})
+
+
+class A(WithSamplingState):
+    _state_class = State1
+
+    def __init__(self, a=1, b=2.0, c="c", d=None, e=None, f=None):
+        self.a = a
+        self.b = b
+        self.c = c
+        if d is None:
+            d = np.array([1, 2])
+        if e is None:
+            e = [1, 2, 3]
+        if f is None:
+            f = {"a": 1, "b": "c"}
+        self.d = d
+        self.e = e
+        self.f = f
+
+
+class B(WithSamplingState):
+    _state_class = State2
+
+    def __init__(
+        self,
+        a=1,
+        b=2.0,
+        c="c",
+        d=None,
+        e=None,
+        f=None,
+        mutable_field=1.0,
+        extra_info1=None,
+        extra_info2=None,
+        extra_info3=None,
+    ):
+        self.state1 = A(a=a, b=b, c=c, d=d, e=e, f=f)
+        self.mutable_field = mutable_field
+        if extra_info1 is None:
+            extra_info1 = np.array([3, 4, 5])
+        if extra_info2 is None:
+            extra_info2 = [5, 6, 7]
+        if extra_info3 is None:
+            extra_info3 = {"foo": "bar"}
+        self.extra_info1 = extra_info1
+        self.extra_info2 = extra_info2
+        self.extra_info3 = extra_info3
+
+
+@dataclass_state
+class RngState(DataClassState):
+    rng: np.random.Generator
+
+
+class Step(WithSamplingState):
+    _state_class = RngState
+
+    def __init__(self, rng=None):
+        self.rng = np.random.default_rng(rng)
+
+
+def test_sampling_state():
+    b1 = B()
+    b2 = B(mutable_field=2.0)
+    b3 = B(c=1, extra_info1=np.array([10, 20]))
+    b4 = B(a=2, b=3.0, c="d")
+    b5 = B(c=1)
+    b6 = B(f={"a": 1, "b": "c", "d": None})
+
+    b1_state = b1.sampling_state
+    b2_state = b2.sampling_state
+    b3_state = b3.sampling_state
+    b4_state = b4.sampling_state
+
+    assert equal_sampling_states(b1_state.state1, b2_state.state1)
+    assert not equal_sampling_states(b1_state, b2_state)
+    assert not equal_sampling_states(b1_state, b3_state)
+    assert not equal_sampling_states(b1_state, b4_state)
+
+    b1.sampling_state = b2_state
+    assert equal_sampling_states(b1.sampling_state, b2_state)
+
+    expected_error_message = (
+        "The received sampling state must have the same values for the "
+        "frozen fields. Field 'extra_info1' has different values. "
+        r"Expected \[3 4 5\] but got \[10 20\]"
+    )
+    with pytest.raises(ValueError, match=expected_error_message):
+        b1.sampling_state = b3_state
+
+    with pytest.raises(AssertionError, match="Encountered invalid state class"):
+        b1.sampling_state = b1_state.state1
+
+    b1.sampling_state = b4_state
+    assert equal_sampling_states(b1.sampling_state, b4_state)
+    assert not equal_sampling_states(b1.sampling_state, b5.sampling_state)
+    assert not equal_sampling_states(b1.sampling_state, b6.sampling_state)
+
+
+@pytest.mark.parametrize(
+    "step",
+    [
+        Step(),
+        Step(1),
+        Step(np.random.Generator(np.random.Philox(1))),
+    ],
+    ids=["default_rng", "default_rng(1)", "philox"],
+)
+def test_sampling_state_rng(step):
+    original_state = step.sampling_state
+    values1 = step.rng.random(100)
+
+    final_state = step.sampling_state
+    assert not equal_sampling_states(original_state, final_state)
+
+    step.sampling_state = original_state
+    values2 = step.rng.random(100)
+    assert np.array_equal(values1, values2, equal_nan=True)
+    assert equal_sampling_states(step.sampling_state, final_state)
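
For illustration (not part of the patch), a toy class showing the intended
round-trip: capture ``sampling_state``, mutate the object, restore, and replay
identical draws. ``Walker`` and ``WalkerState`` are hypothetical names:

    import numpy as np

    from pymc.step_methods.state import DataClassState, WithSamplingState, dataclass_state


    @dataclass_state
    class WalkerState(DataClassState):
        position: float
        rng: np.random.Generator


    class Walker(WithSamplingState):
        _state_class = WalkerState

        def __init__(self, rng=None):
            self.position = 0.0
            self.rng = np.random.default_rng(rng)

        def step(self):
            self.position += self.rng.normal()


    walker = Walker(rng=123)
    snapshot = walker.sampling_state  # deep copy of position and rng state
    walker.step()
    moved = walker.position

    walker.sampling_state = snapshot  # rewind
    walker.step()
    assert walker.position == moved  # the same draw is reproduced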

From 5f6ac334d11902c5678f5d8287bfe06666145bf7 Mon Sep 17 00:00:00 2001
From: Luciano Paz <luciano.paz.neuro@gmail.com>
Date: Thu, 19 Sep 2024 09:54:29 +0200
Subject: [PATCH 4/7] Add step method state

---
 pymc/step_methods/arraystep.py |  4 ++--
 pymc/step_methods/compound.py  | 42 +++++++++++++++++++++++++++++++---
 tests/helpers.py               | 26 ++++++++++++++++++---
 3 files changed, 64 insertions(+), 8 deletions(-)

diff --git a/pymc/step_methods/arraystep.py b/pymc/step_methods/arraystep.py
index 602dfd6e51..bddf02f155 100644
--- a/pymc/step_methods/arraystep.py
+++ b/pymc/step_methods/arraystep.py
@@ -142,8 +142,8 @@ def __init__(self, vars, shared, blocked=True, rng: RandomGenerator = None):
         """
         self.population = None
         self.this_chain = None
-        self.other_chains = None
-        return super().__init__(vars, shared, blocked)
+        self.other_chains: list[int] | None = None
+        return super().__init__(vars, shared, blocked, rng=rng)
 
     def link_population(self, population, chain_index):
         """Links the sampler to the population.
diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py
index 1c1d6fbb50..87dd30420a 100644
--- a/pymc/step_methods/compound.py
+++ b/pymc/step_methods/compound.py
@@ -31,7 +31,8 @@
 
 from pymc.blocking import PointType, StatDtype, StatsDict, StatShape, StatsType
 from pymc.model import modelcontext
-from pymc.util import get_random_generator
+from pymc.step_methods.state import DataClassState, WithSamplingState, dataclass_state
+from pymc.util import RandomGenerator, get_random_generator
 
 __all__ = ("Competence", "CompoundStep")
 
@@ -87,7 +88,12 @@ def infer_warn_stats_info(
     return stats_dtypes, sds
 
 
-class BlockedStep(ABC):
+@dataclass_state
+class StepMethodState(DataClassState):
+    rng: np.random.Generator
+
+
+class BlockedStep(ABC, WithSamplingState):
     stats_dtypes: list[dict[str, type]] = []
     """A list containing <=1 dictionary that maps stat names to dtypes.
 
@@ -195,6 +201,9 @@ def stop_tuning(self):
         if hasattr(self, "tune"):
             self.tune = False
 
+    def set_rng(self, rng: RandomGenerator):
+        self.rng = get_random_generator(rng, copy=False)
+
 
 def flat_statname(sampler_idx: int, sname: str) -> str:
     """Get the flat-stats name for a samplers stat."""
@@ -215,10 +224,20 @@ def get_stats_dtypes_shapes_from_steps(
     return result
 
 
-class CompoundStep:
+@dataclass_state
+class CompoundStepState(DataClassState):
+    methods: list[StepMethodState]
+
+    def __init__(self, methods: list[StepMethodState]):
+        self.methods = methods
+
+
+class CompoundStep(WithSamplingState):
     """Step method composed of a list of several other step
     methods applied in sequence."""
 
+    _state_class = CompoundStepState
+
     def __init__(self, methods):
         self.methods = list(methods)
         self.stats_dtypes = []
@@ -250,10 +269,27 @@ def reset_tuning(self):
             if hasattr(method, "reset_tuning"):
                 method.reset_tuning()
 
+    @property
+    def sampling_state(self) -> DataClassState:
+        return CompoundStepState(methods=[method.sampling_state for method in self.methods])
+
+    @sampling_state.setter
+    def sampling_state(self, state: DataClassState):
+        assert isinstance(
+            state, self._state_class
+        ), f"Invalid sampling state class {type(state)}. Expected {self._state_class}"
+        for method, state_method in zip(self.methods, state.methods):
+            method.sampling_state = state_method
+
     @property
     def vars(self) -> list[Variable]:
         return [var for method in self.methods for var in method.vars]
 
+    def set_rng(self, rng: RandomGenerator):
+        _rngs = get_random_generator(rng, copy=False).spawn(len(self.methods))
+        for method, _rng in zip(self.methods, _rngs):
+            method.set_rng(_rng)
+
 
 def flatten_steps(step: BlockedStep | CompoundStep) -> list[BlockedStep]:
     """Flatten a hierarchy of step methods to a list."""
diff --git a/tests/helpers.py b/tests/helpers.py
index c14433711b..b9d5c6d019 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -143,6 +143,21 @@ def step_continuous(self, step_fn, draws, chains=1, tune=1000):
         _, model_coarse, _ = mv_simple_coarse()
         with model:
             step = step_fn(C, model_coarse)
+            orig_step = deepcopy(step)
+            orig_state = step.sampling_state
+            assert equal_sampling_states(step.sampling_state, orig_state)
+
+            ip = model.initial_point()
+            value1, _ = step.step(ip)
+            final_state = step.sampling_state
+            step.sampling_state = orig_state
+
+            value2, _ = step.step(ip)
+
+            assert equal_sampling_states(step.sampling_state, final_state)
+            assert equal_dataclass_values(value1, value2)
+
+            step.sampling_state = orig_state
             with warnings.catch_warnings():
                 warnings.filterwarnings("ignore", "More chains .* than draws .*", UserWarning)
                 idata = pm.sample(
@@ -162,6 +177,14 @@ def step_continuous(self, step_fn, draws, chains=1, tune=1000):
             self.check_stat(check, idata)
             self.check_stat_dtype(idata, step)
 
+            curr_state = step.sampling_state
+            assert not equal_sampling_states(orig_state, curr_state)
+
+            orig_step.sampling_state = curr_state
+
+            assert equal_sampling_states(orig_step.sampling_state, curr_state)
+            assert orig_step.sampling_state is not curr_state
+
 
 class RVsAssignmentStepsTester:
     """

From ca2c60b44fa19fcdede6d51d69a78e9be1d3c449 Mon Sep 17 00:00:00 2001
From: Luciano Paz <luciano.paz.neuro@gmail.com>
Date: Thu, 19 Sep 2024 09:54:59 +0200
Subject: [PATCH 5/7] Add metropolis sampling state

---
 pymc/step_methods/metropolis.py       | 110 +++++++++++++++++++++++---
 tests/models.py                       |  11 +++
 tests/step_methods/test_metropolis.py |  57 ++++++++++++-
 3 files changed, 166 insertions(+), 12 deletions(-)

diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py
index aa5101dbb0..e6f9d9dc77 100644
--- a/pymc/step_methods/metropolis.py
+++ b/pymc/step_methods/metropolis.py
@@ -12,6 +12,8 @@
 #   See the License for the specific language governing permissions and
 #   limitations under the License.
 from collections.abc import Callable
+from dataclasses import field
+from typing import Any
 
 import numpy as np
 import numpy.random as nr
@@ -40,7 +42,8 @@
     StatsType,
     metrop_select,
 )
-from pymc.step_methods.compound import Competence
+from pymc.step_methods.compound import Competence, StepMethodState
+from pymc.step_methods.state import dataclass_state
 
 __all__ = [
     "Metropolis",
@@ -111,11 +114,31 @@ def __call__(self, num_draws=None, rng: np.random.Generator | None = None):
             return np.dot(self.chol, b)
 
 
+@dataclass_state
+class MetropolisState(StepMethodState):
+    scaling: np.ndarray
+    tune: bool
+    steps_until_tune: float
+    tune_interval: float
+    accepted_sum: np.ndarray
+    accept_rate_iter: np.ndarray
+    accepted_iter: np.ndarray
+    enum_dims: np.ndarray
+
+    discrete: np.ndarray = field(metadata={"frozen": True})
+    any_discrete: bool = field(metadata={"frozen": True})
+    all_discrete: bool = field(metadata={"frozen": True})
+    elemwise_update: bool = field(metadata={"frozen": True})
+    _untuned_settings: dict[str, np.ndarray | float] = field(metadata={"frozen": True})
+    mode: Any = field(metadata={"frozen": True})
+
+
 class Metropolis(ArrayStepShared):
     """Metropolis-Hastings sampling step"""
 
     name = "metropolis"
 
+    default_blocked = False
     stats_dtypes_shapes = {
         "accept": (np.float64, []),
         "accepted": (np.float64, []),
@@ -123,6 +146,8 @@ class Metropolis(ArrayStepShared):
         "scaling": (np.float64, []),
     }
 
+    _state_class = MetropolisState
+
     def __init__(
         self,
         vars=None,
@@ -346,6 +371,15 @@ def tune(scale, acc_rate):
     )
 
 
+@dataclass_state
+class BinaryMetropolisState(StepMethodState):
+    tune: bool
+    accepted: int
+    scaling: float
+    tune_interval: int
+    steps_until_tune: int
+
+
 class BinaryMetropolis(ArrayStep):
     """Metropolis-Hastings optimized for binary variables
 
@@ -375,7 +409,9 @@ class BinaryMetropolis(ArrayStep):
         "p_jump": (np.float64, []),
     }
 
-    def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None):
+    _state_class = BinaryMetropolisState
+
+    def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None, rng=None):
         model = pm.modelcontext(model)
 
         self.scaling = scaling
@@ -389,7 +425,7 @@ def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None):
         if not all([v.dtype in pm.discrete_types for v in vars]):
             raise ValueError("All variables must be Bernoulli for BinaryMetropolis")
 
-        super().__init__(vars, [model.compile_logp()])
+        super().__init__(vars, [model.compile_logp()], rng=rng)
 
     def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]:
         logp = args[0]
@@ -445,6 +481,14 @@ def competence(var):
         return Competence.INCOMPATIBLE
 
 
+@dataclass_state
+class BinaryGibbsMetropolisState(StepMethodState):
+    tune: bool
+    transit_p: int
+    shuffle_dims: bool
+    order: list
+
+
 class BinaryGibbsMetropolis(ArrayStep):
     """A Metropolis-within-Gibbs step method optimized for binary variables
 
@@ -472,7 +516,9 @@ class BinaryGibbsMetropolis(ArrayStep):
         "tune": (bool, []),
     }
 
-    def __init__(self, vars, order="random", transit_p=0.8, model=None):
+    _state_class = BinaryGibbsMetropolisState
+
+    def __init__(self, vars, order="random", transit_p=0.8, model=None, rng=None):
         model = pm.modelcontext(model)
 
         # Doesn't actually tune, but it's required to emit a sampler stat
@@ -498,7 +544,7 @@ def __init__(self, vars, order="random", transit_p=0.8, model=None):
         if not all([v.dtype in pm.discrete_types for v in vars]):
             raise ValueError("All variables must be binary for BinaryGibbsMetropolis")
 
-        super().__init__(vars, [model.compile_logp()])
+        super().__init__(vars, [model.compile_logp()], rng=rng)
 
     def reset_tuning(self):
         # There are no tuning parameters in this step method.
@@ -557,6 +603,13 @@ def competence(var):
         return Competence.INCOMPATIBLE
 
 
+@dataclass_state
+class CategoricalGibbsMetropolisState(StepMethodState):
+    shuffle_dims: bool
+    dimcats: list[tuple]
+    tune: bool
+
+
 class CategoricalGibbsMetropolis(ArrayStep):
     """A Metropolis-within-Gibbs step method optimized for categorical variables.
 
@@ -573,6 +626,8 @@ class CategoricalGibbsMetropolis(ArrayStep):
         "tune": (bool, []),
     }
 
+    _state_class = CategoricalGibbsMetropolisState
+
     def __init__(self, vars, proposal="uniform", order="random", model=None, rng=None):
         model = pm.modelcontext(model)
 
@@ -728,6 +783,18 @@ def competence(var):
         return Competence.INCOMPATIBLE
 
 
+@dataclass_state
+class DEMetropolisState(StepMethodState):
+    scaling: np.ndarray
+    lamb: float
+    tune: str | None
+    tune_interval: int
+    steps_until_tune: int
+    accepted: int
+
+    mode: Any = field(metadata={"frozen": True})
+
+
 class DEMetropolis(PopulationArrayStepShared):
     """
     Differential Evolution Metropolis sampling step.
@@ -778,6 +845,8 @@ class DEMetropolis(PopulationArrayStepShared):
         "lambda": (np.float64, []),
     }
 
+    _state_class = DEMetropolisState
+
     def __init__(
         self,
         vars=None,
@@ -789,6 +858,7 @@ def __init__(
         tune_interval=100,
         model=None,
         mode=None,
+        rng=None,
         **kwargs,
     ):
         model = pm.modelcontext(model)
@@ -824,7 +894,7 @@ def __init__(
 
         shared = pm.make_shared_replacements(initial_values, vars, model)
         self.delta_logp = delta_logp(initial_values, model.logp(), vars, shared)
-        super().__init__(vars, shared)
+        super().__init__(vars, shared, rng=rng)
 
     def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
         point_map_info = q0.point_map_info
@@ -843,9 +913,11 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
 
         # differential evolution proposal
         # select two other chains
-        ir1, ir2 = np.random.choice(self.other_chains, 2, replace=False)
-        r1 = DictToArrayBijection.map(self.population[ir1])
-        r2 = DictToArrayBijection.map(self.population[ir2])
+        if self.other_chains is None:  # pragma: no cover
+            raise RuntimeError("Population sampler has not been linked to the other chains")
+        ir1, ir2 = self.rng.choice(self.other_chains, 2, replace=False)
+        r1 = DictToArrayBijection.map(self.population[ir1])  # type: ignore
+        r2 = DictToArrayBijection.map(self.population[ir2])  # type: ignore
         # propose a jump
         q = floatX(q0d + self.lamb * (r1.data - r2.data) + epsilon)
 
@@ -872,6 +944,21 @@ def competence(var, has_grad):
         return Competence.COMPATIBLE
 
 
+@dataclass_state
+class DEMetropolisZState(StepMethodState):
+    scaling: np.ndarray
+    lamb: float
+    tune: bool
+    tune_target: str | None
+    tune_interval: int
+    steps_until_tune: int
+    accepted: int
+    _history: list
+
+    _untuned_settings: dict[str, np.ndarray | float] = field(metadata={"frozen": True})
+    mode: Any = field(metadata={"frozen": True})
+
+
 class DEMetropolisZ(ArrayStepShared):
     """
     Adaptive Differential Evolution Metropolis sampling step that uses the past to inform jumps.
@@ -925,6 +1012,8 @@ class DEMetropolisZ(ArrayStepShared):
         "lambda": (np.float64, []),
     }
 
+    _state_class = DEMetropolisZState
+
     def __init__(
         self,
         vars=None,
@@ -937,6 +1026,7 @@ def __init__(
         tune_drop_fraction: float = 0.9,
         model=None,
         mode=None,
+        rng=None,
         **kwargs,
     ):
         model = pm.modelcontext(model)
@@ -984,7 +1074,7 @@ def __init__(
 
         shared = pm.make_shared_replacements(initial_values, vars, model)
         self.delta_logp = delta_logp(initial_values, model.logp(), vars, shared)
-        super().__init__(vars, shared)
+        super().__init__(vars, shared, rng=rng)
 
     def reset_tuning(self):
         """Resets the tuned sampler parameters and history to their initial values."""
diff --git a/tests/models.py b/tests/models.py
index 24f80c7c0b..b66c1dc67d 100644
--- a/tests/models.py
+++ b/tests/models.py
@@ -186,3 +186,14 @@ def simple_normal(bounded_prior=False):
         pm.Normal("X_obs", mu=mu_i, sigma=sigma, observed=x0)
 
     return model.initial_point(), model, None
+
+
+def simple_binary():
+    p1 = 0.5
+    p2 = 0.5
+
+    with pm.Model() as model:
+        pm.Bernoulli("d1", p=p1)
+        pm.Bernoulli("d2", p=p2)
+
+    return model.initial_point(), model, (p1, p2)
diff --git a/tests/step_methods/test_metropolis.py b/tests/step_methods/test_metropolis.py
index f414a534e8..a73538a61b 100644
--- a/tests/step_methods/test_metropolis.py
+++ b/tests/step_methods/test_metropolis.py
@@ -14,6 +14,8 @@
 
 import warnings
 
+from copy import deepcopy
+
 import arviz as az
 import numpy as np
 import numpy.testing as npt
@@ -24,6 +26,7 @@
 
 from pymc.step_methods.metropolis import (
     BinaryGibbsMetropolis,
+    BinaryMetropolis,
     CategoricalGibbsMetropolis,
     DEMetropolis,
     DEMetropolisZ,
@@ -31,10 +34,17 @@
     MultivariateNormalProposal,
     NormalProposal,
 )
+from pymc.step_methods.state import equal_dataclass_values
 from pymc.testing import fast_unstable_sampling_mode
 from tests import sampler_fixtures as sf
-from tests.helpers import RVsAssignmentStepsTester, StepMethodTester
-from tests.models import mv_simple, mv_simple_discrete, simple_categorical
+from tests.helpers import RVsAssignmentStepsTester, StepMethodTester, equal_sampling_states
+from tests.models import (
+    mv_simple,
+    mv_simple_discrete,
+    simple_binary,
+    simple_categorical,
+    simple_model,
+)
 
 SEED = sum(ord(c) for c in "test_metropolis")
 
@@ -47,6 +57,7 @@ class TestMetropolisUniform(sf.MetropolisFixture, sf.UniformFixture):
     min_n_eff = 10000
     rtol = 0.1
     atol = 0.05
+    ks_thin = 10
     step_args = {"rng": np.random.default_rng(SEED)}
 
 
@@ -367,3 +378,45 @@ def test_discrete_steps(self, step, step_kwargs):
     )
     def test_continuous_steps(self, step, step_kwargs):
         self.continuous_steps(step, step_kwargs)
+
+
+@pytest.mark.parametrize(
+    ["step_method", "model_fn"],
+    [
+        [Metropolis, simple_model],
+        [BinaryMetropolis, simple_binary],
+        [BinaryGibbsMetropolis, simple_binary],
+        [CategoricalGibbsMetropolis, simple_categorical],
+        [DEMetropolis, simple_model],
+        [DEMetropolisZ, simple_model],
+    ],
+)
+def test_sampling_state(step_method, model_fn):
+    with pytensor.config.change_flags(mode=fast_unstable_sampling_mode):
+        initial_point, model, _ = model_fn()
+        with model:
+            sampler = step_method(model.value_vars)
+            if hasattr(sampler, "link_population"):
+                sampler.link_population([initial_point] * 100, 0)
+            sampler_orig = deepcopy(sampler)
+            state_orig = sampler_orig.sampling_state
+
+            sample1, stat1 = sampler.step(initial_point)
+            sampler.tune = False
+
+            final_state1 = sampler.sampling_state
+
+            assert not equal_sampling_states(final_state1, state_orig)
+
+            sampler.sampling_state = state_orig
+
+            assert equal_sampling_states(sampler.sampling_state, state_orig)
+
+            sample2, stat2 = sampler.step(initial_point)
+            sampler.tune = False
+
+            final_state2 = sampler.sampling_state
+
+            assert equal_sampling_states(final_state1, final_state2)
+            assert equal_dataclass_values(sample1, sample2)
+            assert equal_dataclass_values(stat1, stat2)
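
For illustration (not part of the patch), a minimal sketch of the state
round-trip that the new ``MetropolisState`` enables, mirroring
``test_sampling_state`` above on a toy one-variable model:

    import pymc as pm

    with pm.Model() as model:
        x = pm.Normal("x")
        step = pm.Metropolis([x], rng=123)

        ip = model.initial_point()
        state0 = step.sampling_state

        draw1, _ = step.step(ip)
        step.sampling_state = state0  # rewind the rng and tuning counters
        draw2, _ = step.step(ip)

        assert draw1["x"] == draw2["x"]  # the transition replays exactly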

From 04fbe64fb2a602290fb088bcfe69c1d7c620fcea Mon Sep 17 00:00:00 2001
From: Luciano Paz <luciano.paz.neuro@gmail.com>
Date: Thu, 19 Sep 2024 09:55:22 +0200
Subject: [PATCH 6/7] Add slice sampling state

---
 pymc/step_methods/slicer.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/pymc/step_methods/slicer.py b/pymc/step_methods/slicer.py
index 3e096aeb9f..2ea4b1f55f 100644
--- a/pymc/step_methods/slicer.py
+++ b/pymc/step_methods/slicer.py
@@ -21,7 +21,8 @@
 from pymc.model import modelcontext
 from pymc.pytensorf import compile_pymc, join_nonshared_inputs, make_shared_replacements
 from pymc.step_methods.arraystep import ArrayStepShared
-from pymc.step_methods.compound import Competence
+from pymc.step_methods.compound import Competence, StepMethodState
+from pymc.step_methods.state import dataclass_state
 from pymc.util import get_value_vars_from_user_vars
 from pymc.vartypes import continuous_types
 
@@ -30,6 +31,14 @@
 LOOP_ERR_MSG = "max slicer iters %d exceeded"
 
 
+@dataclass_state
+class SliceState(StepMethodState):
+    w: np.ndarray
+    tune: bool
+    n_tunes: float
+    iter_limit: float
+
+
 class Slice(ArrayStepShared):
     """
     Univariate slice sampler step method.
@@ -61,6 +70,8 @@ class Slice(ArrayStepShared):
         "nstep_in": (int, []),
     }
 
+    _state_class = SliceState
+
     def __init__(
         self, vars=None, w=1.0, tune=True, model=None, iter_limit=np.inf, rng=None, **kwargs
     ):
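
For illustration (not part of the patch): with the ``rng`` argument from the
earlier commit and the ``SliceState`` added here, a ``Slice`` sampler's tuning
variables and rng can be captured and restored like the Metropolis states. A
minimal sketch with a toy model:

    import pymc as pm

    with pm.Model() as model:
        x = pm.Normal("x")
        step = pm.Slice([x], rng=2024)

        state = step.sampling_state  # a SliceState: w, tune, n_tunes, iter_limit, rng
        step.step(model.initial_point())
        step.sampling_state = state  # restore the captured tuning and rng state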

From af74f2cd46c04a66f3e8e7cad7c3444b44a725c1 Mon Sep 17 00:00:00 2001
From: Luciano Paz <luciano.paz.neuro@gmail.com>
Date: Thu, 19 Sep 2024 09:55:46 +0200
Subject: [PATCH 7/7] Add HMC sampling state

---
 pymc/step_methods/hmc/base_hmc.py      |  34 +++++--
 pymc/step_methods/hmc/hmc.py           |  10 +-
 pymc/step_methods/hmc/nuts.py          |  10 +-
 pymc/step_methods/hmc/quadpotential.py | 123 ++++++++++++++++++++++---
 pymc/step_methods/step_sizes.py        |  21 ++++-
 5 files changed, 178 insertions(+), 20 deletions(-)

diff --git a/pymc/step_methods/hmc/base_hmc.py b/pymc/step_methods/hmc/base_hmc.py
index b320ed8194..87daff649c 100644
--- a/pymc/step_methods/hmc/base_hmc.py
+++ b/pymc/step_methods/hmc/base_hmc.py
@@ -27,14 +27,19 @@
 from pymc.model import Point, modelcontext
 from pymc.pytensorf import floatX
 from pymc.stats.convergence import SamplerWarning, WarningType
-from pymc.step_methods import step_sizes
 from pymc.step_methods.arraystep import GradientSharedStep
 from pymc.step_methods.compound import StepMethodState
 from pymc.step_methods.hmc import integration
 from pymc.step_methods.hmc.integration import IntegrationError, State
-from pymc.step_methods.hmc.quadpotential import QuadPotentialDiagAdapt, quad_potential
+from pymc.step_methods.hmc.quadpotential import (
+    PotentialState,
+    QuadPotentialDiagAdapt,
+    quad_potential,
+)
+from pymc.step_methods.state import dataclass_state
+from pymc.step_methods.step_sizes import DualAverageAdaptation, StepSizeState
 from pymc.tuning import guess_scaling
-from pymc.util import get_value_vars_from_user_vars
+from pymc.util import RandomGenerator, get_random_generator, get_value_vars_from_user_vars
 
 logger = logging.getLogger(__name__)
 
@@ -53,12 +58,27 @@ class HMCStepData(NamedTuple):
     stats: dict[str, Any]
 
 
+@dataclass_state
+class BaseHMCState(StepMethodState):
+    adapt_step_size: bool
+    Emax: float
+    iter_count: int
+    step_size: np.ndarray
+    step_adapt: StepSizeState
+    target_accept: float
+    tune: bool
+    potential: PotentialState
+    _num_divs_sample: int
+
+
 class BaseHMC(GradientSharedStep):
     """Superclass to implement Hamiltonian/hybrid monte carlo."""
 
     integrator: integration.CpuLeapfrogIntegrator
     default_blocked = True
 
+    _state_class = BaseHMCState
+
     def __init__(
         self,
         vars=None,
@@ -134,9 +154,7 @@ def __init__(
         size = sum(v.size for v in nuts_vars)
 
         self.step_size = step_scale / (size**0.25)
-        self.step_adapt = step_sizes.DualAverageAdaptation(
-            self.step_size, target_accept, gamma, k, t0
-        )
+        self.step_adapt = DualAverageAdaptation(self.step_size, target_accept, gamma, k, t0)
         self.target_accept = target_accept
         self.tune = True
 
@@ -268,3 +286,7 @@ def reset_tuning(self, start=None):
     def reset(self, start=None):
         self.tune = True
         self.potential.reset()
+
+    def set_rng(self, rng: RandomGenerator):
+        self.rng = get_random_generator(rng, copy=False)
+        self.potential.set_rng(self.rng.spawn(1)[0])
diff --git a/pymc/step_methods/hmc/hmc.py b/pymc/step_methods/hmc/hmc.py
index 106faee501..a5ebbd7a8c 100644
--- a/pymc/step_methods/hmc/hmc.py
+++ b/pymc/step_methods/hmc/hmc.py
@@ -14,14 +14,16 @@
 
 from __future__ import annotations
 
+from dataclasses import field
 from typing import Any
 
 import numpy as np
 
 from pymc.stats.convergence import SamplerWarning
 from pymc.step_methods.compound import Competence
-from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData
+from pymc.step_methods.hmc.base_hmc import BaseHMC, BaseHMCState, DivergenceInfo, HMCStepData
 from pymc.step_methods.hmc.integration import IntegrationError, State
+from pymc.step_methods.state import dataclass_state
 from pymc.vartypes import discrete_types
 
 __all__ = ["HamiltonianMC"]
@@ -31,6 +33,12 @@ def unif(step_size, elow=0.85, ehigh=1.15, rng: np.random.Generator | None = Non
     return (rng or np.random).uniform(elow, ehigh) * step_size
 
 
+@dataclass_state
+class HamiltonianMCState(BaseHMCState):
+    path_length: float = field(metadata={"frozen": True})
+    max_steps: int = field(metadata={"frozen": True})
+
+
 class HamiltonianMC(BaseHMC):
     R"""A sampler for continuous variables based on Hamiltonian mechanics.
 
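
HamiltonianMCState marks path_length and max_steps with frozen metadata: these
identify the sampler's configuration and are meant to stay fixed, rather than
be overwritten, when a state is restored. The mechanism is plain stdlib
dataclasses; a self-contained sketch with illustrative names:

    from dataclasses import dataclass, field, fields

    @dataclass
    class DemoState:
        iter_count: int                                        # mutable across draws
        path_length: float = field(metadata={"frozen": True})  # fixed configuration

    frozen = [f.name for f in fields(DemoState) if f.metadata.get("frozen")]
    print(frozen)  # ['path_length']
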
diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py
index 3c4b4e6800..9bcde95104 100644
--- a/pymc/step_methods/hmc/nuts.py
+++ b/pymc/step_methods/hmc/nuts.py
@@ -15,6 +15,7 @@
 from __future__ import annotations
 
 from collections import namedtuple
+from dataclasses import field
 
 import numpy as np
 
@@ -23,13 +24,20 @@
 from pymc.stats.convergence import SamplerWarning
 from pymc.step_methods.compound import Competence
 from pymc.step_methods.hmc import integration
-from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData
+from pymc.step_methods.hmc.base_hmc import BaseHMC, BaseHMCState, DivergenceInfo, HMCStepData
 from pymc.step_methods.hmc.integration import IntegrationError, State
+from pymc.step_methods.state import dataclass_state
 from pymc.vartypes import continuous_types
 
 __all__ = ["NUTS"]
 
 
+@dataclass_state
+class NUTSState(BaseHMCState):
+    max_treedepth: int = field(metadata={"frozen": True})
+    early_max_treedepth: int = field(metadata={"frozen": True})
+
+
 class NUTS(BaseHMC):
     r"""A sampler for continuous variables based on Hamiltonian mechanics.
 
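
NUTSState only declares the two tree-depth settings; everything else comes in
through BaseHMCState, since dataclass inheritance concatenates the base-class
fields ahead of the subclass's own. A self-contained sketch of that standard
behaviour (names illustrative):

    from dataclasses import dataclass, fields

    @dataclass
    class BaseState:
        step_size: float

    @dataclass
    class DerivedState(BaseState):
        max_treedepth: int

    # Base-class fields come first in the combined dataclass:
    print([f.name for f in fields(DerivedState)])  # ['step_size', 'max_treedepth']
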
diff --git a/pymc/step_methods/hmc/quadpotential.py b/pymc/step_methods/hmc/quadpotential.py
index abddaaf35f..05da188f9b 100644
--- a/pymc/step_methods/hmc/quadpotential.py
+++ b/pymc/step_methods/hmc/quadpotential.py
@@ -16,7 +16,8 @@
 
 import warnings
 
-from typing import overload
+from dataclasses import field
+from typing import Any, overload
 
 import numpy as np
 import pytensor
@@ -25,6 +26,8 @@
 from scipy.sparse import issparse
 
 from pymc.pytensorf import floatX
+from pymc.step_methods.state import DataClassState, WithSamplingState, dataclass_state
+from pymc.util import RandomGenerator, get_random_generator
 
 __all__ = [
     "quad_potential",
@@ -100,11 +103,18 @@ def __str__(self):
         return f"Scaling is not positive definite: {self.msg}. Check indexes {self.idx}."
 
 
-class QuadPotential:
+@dataclass_state
+class PotentialState(DataClassState):
+    rng: np.random.Generator
+
+
+class QuadPotential(WithSamplingState):
     dtype: np.dtype
 
+    _state_class = PotentialState
+
     def __init__(self, rng=None):
-        self.rng = np.random.default_rng(rng)
+        self.rng = get_random_generator(rng)
 
     @overload
     def velocity(self, x: np.ndarray, out: None) -> np.ndarray: ...
@@ -157,15 +167,42 @@ def reset(self):
     def stats(self):
         return {"largest_eigval": np.nan, "smallest_eigval": np.nan}
 
+    def set_rng(self, rng: RandomGenerator):
+        self.rng = get_random_generator(rng, copy=False)
+
 
 def isquadpotential(value):
     """Check whether an object might be a QuadPotential object."""
     return isinstance(value, QuadPotential)
 
 
+@dataclass_state
+class QuadPotentialDiagAdaptState(PotentialState):
+    _var: np.ndarray
+    _stds: np.ndarray
+    _inv_stds: np.ndarray
+    _foreground_var: WeightedVarianceState
+    _background_var: WeightedVarianceState
+    _n_samples: int
+    adaptation_window: int
+    _mass_trace: list[np.ndarray] | None
+
+    dtype: Any = field(metadata={"frozen": True})
+    _n: int = field(metadata={"frozen": True})
+    _discard_window: int = field(metadata={"frozen": True})
+    _early_update: int = field(metadata={"frozen": True})
+    _initial_mean: np.ndarray = field(metadata={"frozen": True})
+    _initial_diag: np.ndarray = field(metadata={"frozen": True})
+    _initial_weight: np.ndarray = field(metadata={"frozen": True})
+    adaptation_window_multiplier: float = field(metadata={"frozen": True})
+    _store_mass_matrix_trace: bool = field(metadata={"frozen": True})
+
+
 class QuadPotentialDiagAdapt(QuadPotential):
     """Adapt a diagonal mass matrix from the sample variances."""
 
+    _state_class = QuadPotentialDiagAdaptState
+
     def __init__(
         self,
         n,
@@ -346,9 +383,20 @@ def raise_ok(self, map_info):
             raise ValueError("\n".join(errmsg))
 
 
-class _WeightedVariance:
+@dataclass_state
+class WeightedVarianceState(DataClassState):
+    n_samples: int
+    mean: np.ndarray
+    raw_var: np.ndarray
+
+    _dtype: Any = field(metadata={"frozen": True})
+
+
+class _WeightedVariance(WithSamplingState):
     """Online algorithm for computing mean of variance."""
 
+    _state_class = WeightedVarianceState
+
     def __init__(
         self, nelem, initial_mean=None, initial_variance=None, initial_weight=0, dtype="d"
     ):
@@ -390,7 +438,16 @@ def current_mean(self):
         return self.mean.copy(dtype=self._dtype)
 
 
-class _ExpWeightedVariance:
+@dataclass_state
+class ExpWeightedVarianceState(DataClassState):
+    _alpha: float
+    _mean: np.ndarray
+    _var: np.ndarray
+
+
+class _ExpWeightedVariance(WithSamplingState):
+    _state_class = ExpWeightedVarianceState
+
     def __init__(self, n_vars, *, init_mean, init_var, alpha):
         self._variance = init_var
         self._mean = init_mean
@@ -415,7 +472,18 @@ def current_mean(self, out=None):
         return out
 
 
+@dataclass_state
+class QuadPotentialDiagAdaptExpState(QuadPotentialDiagAdaptState):
+    _alpha: float
+    _stop_adaptation: float
+    _variance_estimator: ExpWeightedVarianceState
+
+    _variance_estimator_grad: ExpWeightedVarianceState | None = None
+
+
 class QuadPotentialDiagAdaptExp(QuadPotentialDiagAdapt):
+    _state_class = QuadPotentialDiagAdaptExpState
+
     def __init__(self, *args, alpha, use_grads=False, stop_adaptation=None, rng=None, **kwargs):
         """Set up a diagonal mass matrix.
 
@@ -526,7 +594,7 @@ def __init__(self, v, dtype=None, rng=None):
         self.s = s
         self.inv_s = 1.0 / s
         self.v = v
-        self.rng = np.random.default_rng(rng)
+        self.rng = get_random_generator(rng)
 
     def velocity(self, x, out=None):
         """Compute the current velocity at a position in parameter space."""
@@ -572,7 +640,7 @@ def __init__(self, A, dtype=None, rng=None):
             dtype = pytensor.config.floatX
         self.dtype = dtype
         self.L = floatX(scipy.linalg.cholesky(A, lower=True))
-        self.rng = np.random.default_rng(rng)
+        self.rng = get_random_generator(rng)
 
     def velocity(self, x, out=None):
         """Compute the current velocity at a position in parameter space."""
@@ -621,7 +689,7 @@ def __init__(self, cov, dtype=None, rng=None):
         self._cov = np.array(cov, dtype=self.dtype, copy=True)
         self._chol = scipy.linalg.cholesky(self._cov, lower=True)
         self._n = len(self._cov)
-        self.rng = np.random.default_rng(rng)
+        self.rng = get_random_generator(rng)
 
     def velocity(self, x, out=None):
         """Compute the current velocity at a position in parameter space."""
@@ -646,9 +714,31 @@ def velocity_energy(self, x, v_out):
     __call__ = random
 
 
+@dataclass_state
+class QuadPotentialFullAdaptState(PotentialState):
+    _previous_update: int
+    _cov: np.ndarray
+    _chol: np.ndarray
+    _foreground_cov: WeightedCovarianceState
+    _background_cov: WeightedCovarianceState
+    _n_samples: int
+    adaptation_window: int
+
+    dtype: Any = field(metadata={"frozen": True})
+    _n: int = field(metadata={"frozen": True})
+    _update_window: int = field(metadata={"frozen": True})
+    _initial_mean: np.ndarray = field(metadata={"frozen": True})
+    _initial_cov: np.ndarray = field(metadata={"frozen": True})
+    _initial_weight: np.ndarray = field(metadata={"frozen": True})
+    adaptation_window_multiplier: float = field(metadata={"frozen": True})
+    _chol_error: scipy.linalg.LinAlgError | ValueError | None = None
+
+
 class QuadPotentialFullAdapt(QuadPotentialFull):
     """Adapt a dense mass matrix using the sample covariances."""
 
+    _state_class = QuadPotentialFullAdaptState
+
     def __init__(
         self,
         n,
@@ -689,7 +779,7 @@ def __init__(
         self.adaptation_window_multiplier = float(adaptation_window_multiplier)
         self._update_window = int(update_window)
 
-        self.rng = np.random.default_rng(rng)
+        self.rng = get_random_generator(rng)
 
         self.reset()
 
@@ -742,7 +832,16 @@ def raise_ok(self, vmap):
             raise ValueError(str(self._chol_error))
 
 
-class _WeightedCovariance:
+@dataclass_state
+class WeightedCovarianceState(DataClassState):
+    n_samples: float
+    mean: np.ndarray
+    raw_cov: np.ndarray
+
+    _dtype: Any = field(metadata={"frozen": True})
+
+
+class _WeightedCovariance(WithSamplingState):
     """Online algorithm for computing mean and covariance
 
     This implements the `Welford's algorithm
@@ -752,6 +851,8 @@ class _WeightedCovariance:
 
     """
 
+    _state_class = WeightedCovarianceState
+
     def __init__(
         self,
         nelem,
@@ -827,7 +928,7 @@ def __init__(self, A, rng=None):
             self.size = A.shape[0]
             self.factor = factor = cholmod.cholesky(A)
             self.d_sqrt = np.sqrt(factor.D())
-            self.rng = np.random.default_rng(rng)
+            self.rng = get_random_generator(rng)
 
         def velocity(self, x):
             """Compute the current velocity at a position in parameter space."""
diff --git a/pymc/step_methods/step_sizes.py b/pymc/step_methods/step_sizes.py
index 6c2b7340fd..c0fdb934a3 100644
--- a/pymc/step_methods/step_sizes.py
+++ b/pymc/step_methods/step_sizes.py
@@ -12,14 +12,33 @@
 #   See the License for the specific language governing permissions and
 #   limitations under the License.
 
+
 import numpy as np
 
 from scipy import stats
 
 from pymc.stats.convergence import SamplerWarning, WarningType
+from pymc.step_methods.state import DataClassState, WithSamplingState, dataclass_state
+
+
+@dataclass_state
+class StepSizeState(DataClassState):
+    _log_step: np.ndarray
+    _log_bar: np.ndarray
+    _hbar: float
+    _count: int
+    _mu: np.ndarray
+    _tuned_stats: list
+    _initial_step: np.ndarray
+    _target: float
+    _k: float
+    _t0: float
+    _gamma: float
+
 
+class DualAverageAdaptation(WithSamplingState):
+    _state_class = StepSizeState
 
-class DualAverageAdaptation:
     def __init__(self, initial_step, target, gamma, k, t0):
         self._initial_step = initial_step
         self._target = target