Skip to content

Commit f96da2d

Browse files
committed
Simplify tmp_rvs_to_values generation in logpt
1 parent bdd4d19 commit f96da2d

File tree

1 file changed

+10
-17
lines changed

1 file changed

+10
-17
lines changed

pymc/distributions/logprob.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -195,33 +195,26 @@ def logpt(
195195
getattr(_var.tag, "total_size", None), rv_value_var.shape, rv_value_var.ndim
196196
)
197197

198-
# Unlike aeppl, PyMC's logpt is expected to plug in the values variables to corresponding
199-
# RVs automatically unless the values are explicity set to None. Hence we iterate through
200-
# the graph to find RVs and construct a new RVs to values dictionary.
198+
# Aeppl needs all rv-values pairs, not just that of the requested var.
199+
# Hence we iterate through the graph to collect them.
201200
tmp_rvs_to_values = rv_values.copy()
202201
transform_map = {}
203202
for node in io_toposort(graph_inputs(var), var):
204-
if isinstance(node.op, RandomVariable):
205-
curr_var = node.out
203+
try:
204+
curr_vars = [node.default_output()]
205+
except ValueError:
206+
curr_vars = node.outputs
207+
for curr_var in curr_vars:
206208
rv_value_var = getattr(
207-
curr_var.tag, "observations", getattr(curr_var.tag, "value_var", curr_var)
209+
curr_var.tag, "observations", getattr(curr_var.tag, "value_var", None)
208210
)
211+
if rv_value_var is None:
212+
continue
209213
rv_value = rv_values.get(curr_var, rv_value_var)
210214
tmp_rvs_to_values[curr_var] = rv_value
211215
# Along with value variables we also check for transforms if any.
212216
if hasattr(rv_value_var.tag, "transform") and transformed:
213217
transform_map[rv_value] = rv_value_var.tag.transform
214-
# The condition below is a hackish way of excluding the value variable for the
215-
# RV being indexed in case of Advanced Indexing of RVs. It gets added by the
216-
# logic above but aeppl does not expect us to include it in the dictionary of
217-
# {RV:values} given to it.
218-
if isinstance(node.op, subtensor_types):
219-
curr_var = node.out
220-
if (
221-
curr_var in tmp_rvs_to_values.keys()
222-
and curr_var.owner.inputs[0] in tmp_rvs_to_values.keys()
223-
):
224-
tmp_rvs_to_values.pop(curr_var.owner.inputs[0])
225218

226219
transform_opt = TransformValuesOpt(transform_map)
227220
temp_logp_var_dict = factorized_joint_logprob(

0 commit comments

Comments
 (0)