@@ -205,19 +205,22 @@ def expand(var):
205
205
yield from walk (graphs , expand , bfs = False )
206
206
207
207
208
- def _replace_rvs_in_graphs (
208
+ def _replace_vars_in_graphs (
209
209
graphs : Iterable [TensorVariable ],
210
210
replacement_fn : Callable [[TensorVariable ], Dict [TensorVariable , TensorVariable ]],
211
211
** kwargs ,
212
212
) -> Tuple [List [TensorVariable ], Dict [TensorVariable , TensorVariable ]]:
213
- """Replace random variables in graphs
213
+ """Replace variables in graphs.
214
214
215
215
This will *not* recompute test values.
216
216
217
217
Parameters
218
218
----------
219
219
graphs
220
220
The graphs in which random variables are to be replaced.
221
+ replacement_fn
222
+ A callable called on each graph output that populates a replacement dictionary and returns
223
+ nodes that should be investigated further.
221
224
222
225
Returns
223
226
-------
@@ -256,7 +259,8 @@ def expand_replace(var):
256
259
toposort = fg .toposort ()
257
260
sorted_replacements = sorted (
258
261
tuple (replacements .items ()),
259
- key = lambda pair : toposort .index (pair [0 ].owner ),
262
+ # Root inputs don't have owner, we give them negative priority -1
263
+ key = lambda pair : toposort .index (pair [0 ].owner ) if pair [0 ].owner is not None else - 1 ,
260
264
reverse = True ,
261
265
)
262
266
fg .replace_all (sorted_replacements , import_missing = True )
@@ -317,7 +321,7 @@ def populate_replacements(
317
321
equiv = clone_get_equiv (inputs , graphs , False , False , {})
318
322
graphs = [equiv [n ] for n in graphs ]
319
323
320
- graphs , _ = _replace_rvs_in_graphs (
324
+ graphs , _ = _replace_vars_in_graphs (
321
325
graphs ,
322
326
replacement_fn = populate_replacements ,
323
327
** kwargs ,
@@ -385,7 +389,7 @@ def poulate_replacements(rv, replacements):
385
389
# replacements if that is not a simple input variable
386
390
return [value ]
387
391
388
- graphs , _ = _replace_rvs_in_graphs (
392
+ graphs , _ = _replace_vars_in_graphs (
389
393
graphs ,
390
394
replacement_fn = poulate_replacements ,
391
395
** kwargs ,
0 commit comments