Skip to content

Commit c2a5ea7

Browse files
committed
Merge branch 'main' into adding-blackjax-support
2 parents b1436d9 + a0e43da commit c2a5ea7

10 files changed

+162
-137
lines changed

pymc/data.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import aesara
2626
import aesara.tensor as at
2727
import numpy as np
28-
import pandas as pd
2928

3029
from aesara.compile.sharedvalue import SharedVariable
3130
from aesara.graph.basic import Apply
@@ -472,7 +471,7 @@ def determine_coords(model, value, dims: Optional[Sequence[str]] = None) -> Dict
472471
coords = {}
473472

474473
# If value is a df or a series, we interpret the index as coords:
475-
if isinstance(value, (pd.Series, pd.DataFrame)):
474+
if hasattr(value, "index"):
476475
dim_name = None
477476
if dims is not None:
478477
dim_name = dims[0]
@@ -482,7 +481,7 @@ def determine_coords(model, value, dims: Optional[Sequence[str]] = None) -> Dict
482481
coords[dim_name] = value.index
483482

484483
# If value is a df, we also interpret the columns as coords:
485-
if isinstance(value, pd.DataFrame):
484+
if hasattr(value, "columns"):
486485
dim_name = None
487486
if dims is not None:
488487
dim_name = dims[1]
@@ -501,7 +500,7 @@ def determine_coords(model, value, dims: Optional[Sequence[str]] = None) -> Dict
501500
for size, dim in zip(value.shape, dims):
502501
coord = model.coords.get(dim, None)
503502
if coord is None:
504-
coords[dim] = pd.RangeIndex(size, name=dim)
503+
coords[dim] = range(size)
505504

506505
return coords
507506

pymc/distributions/discrete.py

+17-7
Original file line numberDiff line numberDiff line change
@@ -971,12 +971,12 @@ class HyperGeometric(Discrete):
971971
972972
Parameters
973973
----------
974-
N : integer
975-
Total size of the population
976-
k : integer
977-
Number of successful individuals in the population
978-
n : integer
979-
Number of samples drawn from the population
974+
N : tensor_like of integer
975+
Total size of the population (N > 0)
976+
k : tensor_like of integer
977+
Number of successful individuals in the population (0 <= k <= N)
978+
n : tensor_like of integer
979+
Number of samples drawn from the population (0 <= n <= N)
980980
"""
981981

982982
rv_op = hypergeometric
@@ -1004,6 +1004,10 @@ def logp(value, good, bad, n):
10041004
value : numeric
10051005
Value(s) for which log-probability is calculated. If the log probabilities for multiple
10061006
values are desired the values must be provided in a numpy array or Aesara tensor
1007+
good : integer, array_like or TensorVariable
1008+
Number of successful individuals in the population. Alias for parameter :math:`k`.
1009+
bad : integer, array_like or TensorVariable
1010+
Number of unsuccessful individuals in the population. Alias for :math:`N-k`.
10071011
10081012
Returns
10091013
-------
@@ -1042,8 +1046,14 @@ def logcdf(value, good, bad, n):
10421046
10431047
Parameters
10441048
----------
1045-
value: numeric
1049+
value : numeric
10461050
Value for which log CDF is calculated.
1051+
good : integer
1052+
Number of successful individuals in the population. Alias for parameter :math:`k`.
1053+
bad : integer
1054+
Number of unsuccessful individuals in the population. Alias for :math:`N-k`.
1055+
n : integer
1056+
Number of samples drawn from the population (0 <= n <= N)
10471057
10481058
Returns
10491059
-------

pymc/model.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -631,9 +631,7 @@ def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs):
631631
raise ValueError(f"Can only compute the gradient of continuous types: {var}")
632632

633633
if tempered:
634-
# TODO: Should this differ from self.datalogpt,
635-
# where the potential terms are added to the observations?
636-
costs = [self.varlogpt + self.potentiallogpt, self.observedlogpt]
634+
costs = [self.varlogpt, self.datalogpt]
637635
else:
638636
costs = [self.logpt()]
639637

pymc/sampling.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -181,14 +181,14 @@ def assign_step_methods(model, step=None, methods=STEP_METHODS, step_kwargs=None
181181
Parameters
182182
----------
183183
model : Model object
184-
A fully-specified model object
185-
step : step function or vector of step functions
184+
A fully-specified model object.
185+
step : step function or iterable of step functions, optional
186186
One or more step functions that have been assigned to some subset of
187187
the model's parameters. Defaults to ``None`` (no assigned variables).
188-
methods : vector of step method classes
188+
methods : iterable of step method classes, optional
189189
The set of step methods from which the function may choose. Defaults
190190
to the main step methods provided by PyMC.
191-
step_kwargs : dict
191+
step_kwargs : dict, optional
192192
Parameters for the samplers. Keys are the lower case names of
193193
the step method, values a dict of arguments.
194194

pymc/sampling_jax.py

+25-24
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
# pylint: skip-file
21
import os
32
import re
43
import sys
54
import warnings
65

76
from functools import partial
8-
from typing import Callable, List, Optional
7+
from typing import Callable, Dict, List, Optional, Sequence, Union
98

9+
from pymc.initial_point import StartDict
1010
from pymc.sampling import _init_jitter
1111

1212
xla_flags = os.getenv("XLA_FLAGS", "")
@@ -132,7 +132,7 @@ def _sample_stats_to_xarray(posterior):
132132
return data
133133

134134

135-
def _get_log_likelihood(model, samples):
135+
def _get_log_likelihood(model: Model, samples) -> Dict:
136136
"""Compute log-likelihood for all observations"""
137137
data = {}
138138
for v in model.observed_RVs:
@@ -144,8 +144,13 @@ def _get_log_likelihood(model, samples):
144144

145145

146146
def _get_batched_jittered_initial_points(
147-
model, chains, initvals, random_seed, jitter=True, jitter_max_retries=10
148-
):
147+
model: Model,
148+
chains: int,
149+
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]],
150+
random_seed: int,
151+
jitter: bool = True,
152+
jitter_max_retries: int = 10,
153+
) -> Union[np.ndarray, List[np.ndarray]]:
149154
"""Get jittered initial point in format expected by NumPyro MCMC kernel
150155
151156
Returns
@@ -154,10 +159,8 @@ def _get_batched_jittered_initial_points(
154159
list with one item per variable and number of chains as batch dimension.
155160
Each item has shape `(chains, *var.shape)`
156161
"""
157-
if isinstance(random_seed, (int, np.integer)):
158-
random_seed = np.random.default_rng(random_seed).integers(2**30, size=chains)
159-
elif not isinstance(random_seed, (list, tuple, np.ndarray)):
160-
raise ValueError(f"The `seeds` must be int or array-like. Got {type(random_seed)} instead.")
162+
163+
random_seed = np.random.default_rng(random_seed).integers(2**30, size=chains)
161164

162165
assert len(random_seed) == chains
163166

@@ -373,19 +376,19 @@ def sample_blackjax_nuts(
373376

374377

375378
def sample_numpyro_nuts(
376-
draws=1000,
377-
tune=1000,
378-
chains=4,
379-
target_accept=0.8,
380-
random_seed=None,
381-
initvals=None,
382-
model=None,
379+
draws: int = 1000,
380+
tune: int = 1000,
381+
chains: int = 4,
382+
target_accept: float = 0.8,
383+
random_seed: int = None,
384+
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
385+
model: Optional[Model] = None,
383386
var_names=None,
384-
progress_bar=True,
385-
keep_untransformed=False,
386-
chain_method="parallel",
387-
idata_kwargs=None,
388-
nuts_kwargs=None,
387+
progress_bar: bool = True,
388+
keep_untransformed: bool = False,
389+
chain_method: str = "parallel",
390+
idata_kwargs: Optional[Dict] = None,
391+
nuts_kwargs: Optional[Dict] = None,
389392
):
390393
"""
391394
Draw samples from the posterior using the NUTS method from the ``numpyro`` library.
@@ -456,9 +459,7 @@ def sample_numpyro_nuts(
456459
dims = {}
457460

458461
if random_seed is None:
459-
random_seed = model.rng_seeder.randint(
460-
2**30, dtype=np.int64, size=chains if chains > 1 else None
461-
)
462+
random_seed = model.rng_seeder.randint(2**30, dtype=np.int64)
462463

463464
tic1 = datetime.now()
464465
print("Compiling...", file=sys.stdout)

0 commit comments

Comments (0)