1
- # pylint: skip-file
2
1
import os
3
2
import re
4
3
import sys
5
4
import warnings
6
5
7
6
from functools import partial
8
- from typing import Callable , List , Optional
7
+ from typing import Callable , Dict , List , Optional , Sequence , Union
9
8
9
+ from pymc .initial_point import StartDict
10
10
from pymc .sampling import _init_jitter
11
11
12
12
xla_flags = os .getenv ("XLA_FLAGS" , "" )
@@ -132,7 +132,7 @@ def _sample_stats_to_xarray(posterior):
132
132
return data
133
133
134
134
135
- def _get_log_likelihood (model , samples ):
135
+ def _get_log_likelihood (model : Model , samples ) -> Dict :
136
136
"""Compute log-likelihood for all observations"""
137
137
data = {}
138
138
for v in model .observed_RVs :
@@ -144,8 +144,13 @@ def _get_log_likelihood(model, samples):
144
144
145
145
146
146
def _get_batched_jittered_initial_points (
147
- model , chains , initvals , random_seed , jitter = True , jitter_max_retries = 10
148
- ):
147
+ model : Model ,
148
+ chains : int ,
149
+ initvals : Optional [Union [StartDict , Sequence [Optional [StartDict ]]]],
150
+ random_seed : int ,
151
+ jitter : bool = True ,
152
+ jitter_max_retries : int = 10 ,
153
+ ) -> Union [np .ndarray , List [np .ndarray ]]:
149
154
"""Get jittered initial point in format expected by NumPyro MCMC kernel
150
155
151
156
Returns
@@ -154,10 +159,8 @@ def _get_batched_jittered_initial_points(
154
159
list with one item per variable and number of chains as batch dimension.
155
160
Each item has shape `(chains, *var.shape)`
156
161
"""
157
- if isinstance (random_seed , (int , np .integer )):
158
- random_seed = np .random .default_rng (random_seed ).integers (2 ** 30 , size = chains )
159
- elif not isinstance (random_seed , (list , tuple , np .ndarray )):
160
- raise ValueError (f"The `seeds` must be int or array-like. Got { type (random_seed )} instead." )
162
+
163
+ random_seed = np .random .default_rng (random_seed ).integers (2 ** 30 , size = chains )
161
164
162
165
assert len (random_seed ) == chains
163
166
@@ -373,19 +376,19 @@ def sample_blackjax_nuts(
373
376
374
377
375
378
def sample_numpyro_nuts (
376
- draws = 1000 ,
377
- tune = 1000 ,
378
- chains = 4 ,
379
- target_accept = 0.8 ,
380
- random_seed = None ,
381
- initvals = None ,
382
- model = None ,
379
+ draws : int = 1000 ,
380
+ tune : int = 1000 ,
381
+ chains : int = 4 ,
382
+ target_accept : float = 0.8 ,
383
+ random_seed : int = None ,
384
+ initvals : Optional [ Union [ StartDict , Sequence [ Optional [ StartDict ]]]] = None ,
385
+ model : Optional [ Model ] = None ,
383
386
var_names = None ,
384
- progress_bar = True ,
385
- keep_untransformed = False ,
386
- chain_method = "parallel" ,
387
- idata_kwargs = None ,
388
- nuts_kwargs = None ,
387
+ progress_bar : bool = True ,
388
+ keep_untransformed : bool = False ,
389
+ chain_method : str = "parallel" ,
390
+ idata_kwargs : Optional [ Dict ] = None ,
391
+ nuts_kwargs : Optional [ Dict ] = None ,
389
392
):
390
393
"""
391
394
Draw samples from the posterior using the NUTS method from the ``numpyro`` library.
@@ -456,9 +459,7 @@ def sample_numpyro_nuts(
456
459
dims = {}
457
460
458
461
if random_seed is None :
459
- random_seed = model .rng_seeder .randint (
460
- 2 ** 30 , dtype = np .int64 , size = chains if chains > 1 else None
461
- )
462
+ random_seed = model .rng_seeder .randint (2 ** 30 , dtype = np .int64 )
462
463
463
464
tic1 = datetime .now ()
464
465
print ("Compiling..." , file = sys .stdout )
0 commit comments