Skip to content

Commit af14bea

Browse files
thomasjpfantwiecki
authored andcommitted
MNT Removes pandas series from point_logps
1 parent 11e69ea commit af14bea

File tree

3 files changed

+15
-18
lines changed

3 files changed

+15
-18
lines changed

pymc/model.py

+11-15
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
from aesara.tensor.random.var import RandomStateSharedVariable
4747
from aesara.tensor.sharedvar import ScalarSharedVariable
4848
from aesara.tensor.var import TensorVariable
49-
from pandas import Series
5049

5150
from pymc.aesaraf import (
5251
compile_pymc,
@@ -1673,7 +1672,7 @@ def check_start_vals(self, start):
16731672

16741673
initial_eval = self.point_logps(point=elem)
16751674

1676-
if not np.all(np.isfinite(initial_eval)):
1675+
if not all(np.isfinite(v) for v in initial_eval.values()):
16771676
raise SamplingError(
16781677
"Initial evaluation of model at starting point failed!\n"
16791678
f"Starting values:\n{elem}\n\n"
@@ -1700,24 +1699,21 @@ def point_logps(self, point=None, round_vals=2):
17001699
17011700
Returns
17021701
-------
1703-
Pandas Series
1702+
log_probability_of_point : dict
1703+
Log probability of `point`.
17041704
"""
17051705
if point is None:
17061706
point = self.compute_initial_point()
17071707

17081708
factors = self.basic_RVs + self.potentials
1709-
return Series(
1710-
{
1711-
factor.name: np.round(np.asarray(factor_logp), round_vals)
1712-
for factor, factor_logp in zip(
1713-
factors,
1714-
self.compile_fn([at.sum(factor) for factor in self.logpt(factors, sum=False)])(
1715-
point
1716-
),
1717-
)
1718-
},
1719-
name="Point log-probability",
1720-
)
1709+
factor_logps_fn = [at.sum(factor) for factor in self.logpt(factors, sum=False)]
1710+
return {
1711+
factor.name: np.round(np.asarray(factor_logp), round_vals)
1712+
for factor, factor_logp in zip(
1713+
factors,
1714+
self.compile_fn(factor_logps_fn)(point),
1715+
)
1716+
}
17211717

17221718

17231719
# this is really disgusting, but it breaks a self-loop: I can't pass Model

pymc/step_methods/hmc/base_hmc.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,9 @@ def astep(self, q0):
157157

158158
if not np.isfinite(start.energy):
159159
model = self._model
160-
check_test_point = model.point_logps()
161-
error_logp = check_test_point.loc[
160+
check_test_point_dict = model.point_logps()
161+
check_test_point = np.asarray(list(check_test_point_dict.values()))
162+
error_logp = check_test_point[
162163
(np.abs(check_test_point) >= 1e20) | np.isnan(check_test_point)
163164
]
164165
self.potential.raise_ok(q0.point_map_info)

pymc/tests/test_transforms.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def test_interval_near_boundary():
233233
pm.Uniform("x", initval=x0, lower=lb, upper=ub)
234234

235235
log_prob = model.point_logps()
236-
np.testing.assert_allclose(log_prob, np.array([-52.68]))
236+
np.testing.assert_allclose(list(log_prob.values()), np.array([-52.68]))
237237

238238

239239
def test_circular():

0 commit comments

Comments
 (0)