Skip to content

Commit b17a60d

Browse files
committed
Ignore named variables that are not traceable in get_vars_in_point_list
1 parent fbc62d5 commit b17a60d

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

pymc/sampling/forward.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@ def get_vars_in_point_list(trace, model):
8989
names_in_trace = list(trace[0])
9090
else:
9191
names_in_trace = trace.varnames
92-
vars_in_trace = [model[v] for v in names_in_trace if v in model]
92+
traceable_varnames = {var.name for var in (model.free_RVs + model.deterministics)}
93+
vars_in_trace = [model[v] for v in names_in_trace if v in traceable_varnames]
9394
return vars_in_trace
9495

9596

tests/sampling/test_forward.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1634,11 +1634,13 @@ def test_get_vars_in_point_list():
16341634
with pm.Model() as modelA:
16351635
pm.Normal("a", 0, 1)
16361636
pm.Normal("b", 0, 1)
1637+
pm.Normal("d", 0, 1)
16371638
with pm.Model() as modelB:
16381639
a = pm.Normal("a", 0, 1)
16391640
pm.Normal("c", 0, 1)
1641+
pm.ConstantData("d", 0)
16401642

1641-
point_list = [{"a": 0, "b": 0}]
1643+
point_list = [{"a": 0, "b": 0, "d": 0}]
16421644
vars_in_trace = get_vars_in_point_list(point_list, modelB)
16431645
assert set(vars_in_trace) == {a}
16441646

0 commit comments

Comments
 (0)