Skip to content

Commit b1436d9

Browse files
committed
Pre-commit fix
1 parent 1bf85f0 commit b1436d9

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pymc/sampling_jax.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def _get_batched_jittered_initial_points(
155155
Each item has shape `(chains, *var.shape)`
156156
"""
157157
if isinstance(random_seed, (int, np.integer)):
158-
random_seed = np.random.default_rng(random_seed).integers(2 ** 30, size=chains)
158+
random_seed = np.random.default_rng(random_seed).integers(2**30, size=chains)
159159
elif not isinstance(random_seed, (list, tuple, np.ndarray)):
160160
raise ValueError(f"The `seeds` must be int or array-like. Got {type(random_seed)} instead.")
161161

@@ -457,7 +457,7 @@ def sample_numpyro_nuts(
457457

458458
if random_seed is None:
459459
random_seed = model.rng_seeder.randint(
460-
2 ** 30, dtype=np.int64, size=chains if chains > 1 else None
460+
2**30, dtype=np.int64, size=chains if chains > 1 else None
461461
)
462462

463463
tic1 = datetime.now()

0 commit comments

Comments
 (0)