@@ -195,33 +195,26 @@ def logpt(
195
195
getattr (_var .tag , "total_size" , None ), rv_value_var .shape , rv_value_var .ndim
196
196
)
197
197
198
- # Unlike aeppl, PyMC's logpt is expected to plug in the values variables to corresponding
199
- # RVs automatically unless the values are explicity set to None. Hence we iterate through
200
- # the graph to find RVs and construct a new RVs to values dictionary.
198
+ # Aeppl needs all rv-values pairs, not just that of the requested var.
199
+ # Hence we iterate through the graph to collect them.
201
200
tmp_rvs_to_values = rv_values .copy ()
202
201
transform_map = {}
203
202
for node in io_toposort (graph_inputs (var ), var ):
204
- if isinstance (node .op , RandomVariable ):
205
- curr_var = node .out
203
+ try :
204
+ curr_vars = [node .default_output ()]
205
+ except ValueError :
206
+ curr_vars = node .outputs
207
+ for curr_var in curr_vars :
206
208
rv_value_var = getattr (
207
- curr_var .tag , "observations" , getattr (curr_var .tag , "value_var" , curr_var )
209
+ curr_var .tag , "observations" , getattr (curr_var .tag , "value_var" , None )
208
210
)
211
+ if rv_value_var is None :
212
+ continue
209
213
rv_value = rv_values .get (curr_var , rv_value_var )
210
214
tmp_rvs_to_values [curr_var ] = rv_value
211
215
# Along with value variables we also check for transforms if any.
212
216
if hasattr (rv_value_var .tag , "transform" ) and transformed :
213
217
transform_map [rv_value ] = rv_value_var .tag .transform
214
- # The condition below is a hackish way of excluding the value variable for the
215
- # RV being indexed in case of Advanced Indexing of RVs. It gets added by the
216
- # logic above but aeppl does not expect us to include it in the dictionary of
217
- # {RV:values} given to it.
218
- if isinstance (node .op , subtensor_types ):
219
- curr_var = node .out
220
- if (
221
- curr_var in tmp_rvs_to_values .keys ()
222
- and curr_var .owner .inputs [0 ] in tmp_rvs_to_values .keys ()
223
- ):
224
- tmp_rvs_to_values .pop (curr_var .owner .inputs [0 ])
225
218
226
219
transform_opt = TransformValuesOpt (transform_map )
227
220
temp_logp_var_dict = factorized_joint_logprob (
0 commit comments