Skip to content

Commit c57769c

Browse files
ricardoV94michaelosthege
authored andcommitted
Rename _replace_rvs_in_graphs and fix bug when replacing input
1 parent e1060de commit c57769c

File tree

3 files changed

+31
-7
lines changed

3 files changed

+31
-7
lines changed

pymc/logprob/utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def ignore_logprob_multiple_vars(
337337
making each "unmeasurable", whereas a sequential call to `ignore_logprob`
338338
would not do this correctly.
339339
"""
340-
from pymc.pytensorf import _replace_rvs_in_graphs
340+
from pymc.pytensorf import _replace_vars_in_graphs
341341

342342
measurable_vars_to_unmeasurable_vars = {
343343
measurable_var: ignore_logprob(measurable_var) for measurable_var in vars
@@ -353,5 +353,5 @@ def replacement_fn(var, replacements):
353353

354354
return []
355355

356-
unmeasurable_vars, _ = _replace_rvs_in_graphs(graphs=vars, replacement_fn=replacement_fn)
356+
unmeasurable_vars, _ = _replace_vars_in_graphs(graphs=vars, replacement_fn=replacement_fn)
357357
return unmeasurable_vars

pymc/pytensorf.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -205,19 +205,22 @@ def expand(var):
205205
yield from walk(graphs, expand, bfs=False)
206206

207207

208-
def _replace_rvs_in_graphs(
208+
def _replace_vars_in_graphs(
209209
graphs: Iterable[TensorVariable],
210210
replacement_fn: Callable[[TensorVariable], Dict[TensorVariable, TensorVariable]],
211211
**kwargs,
212212
) -> Tuple[List[TensorVariable], Dict[TensorVariable, TensorVariable]]:
213-
"""Replace random variables in graphs
213+
"""Replace variables in graphs.
214214
215215
This will *not* recompute test values.
216216
217217
Parameters
218218
----------
219219
graphs
220220
The graphs in which random variables are to be replaced.
221+
replacement_fn
222+
A callable called on each graph output that populates a replacement dictionary and returns
223+
nodes that should be investigated further.
221224
222225
Returns
223226
-------
@@ -256,7 +259,8 @@ def expand_replace(var):
256259
toposort = fg.toposort()
257260
sorted_replacements = sorted(
258261
tuple(replacements.items()),
259-
key=lambda pair: toposort.index(pair[0].owner),
262+
# Root inputs don't have owner, we give them negative priority -1
263+
key=lambda pair: toposort.index(pair[0].owner) if pair[0].owner is not None else -1,
260264
reverse=True,
261265
)
262266
fg.replace_all(sorted_replacements, import_missing=True)
@@ -317,7 +321,7 @@ def populate_replacements(
317321
equiv = clone_get_equiv(inputs, graphs, False, False, {})
318322
graphs = [equiv[n] for n in graphs]
319323

320-
graphs, _ = _replace_rvs_in_graphs(
324+
graphs, _ = _replace_vars_in_graphs(
321325
graphs,
322326
replacement_fn=populate_replacements,
323327
**kwargs,
@@ -385,7 +389,7 @@ def poulate_replacements(rv, replacements):
385389
# replacements if that is not a simple input variable
386390
return [value]
387391

388-
graphs, _ = _replace_rvs_in_graphs(
392+
graphs, _ = _replace_vars_in_graphs(
389393
graphs,
390394
replacement_fn=poulate_replacements,
391395
**kwargs,

tests/test_pytensorf.py

+20
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import pytest
2525
import scipy.sparse as sps
2626

27+
from pytensor import shared
2728
from pytensor.compile.builders import OpFromGraph
2829
from pytensor.graph.basic import Variable, equal_computations
2930
from pytensor.tensor.random.basic import normal, uniform
@@ -40,6 +41,7 @@
4041
from pymc.exceptions import NotConstantValueError
4142
from pymc.logprob.utils import ParameterValueError
4243
from pymc.pytensorf import (
44+
_replace_vars_in_graphs,
4345
collect_default_updates,
4446
compile_pymc,
4547
constant_fold,
@@ -821,3 +823,21 @@ def test_interdependent_transformed_rvs(self, reversed, test_deprecated_fn):
821823
),
822824
[expected_x, expected_y, expected_z, expected_w],
823825
)
826+
827+
def test_replace_input(self):
828+
inp = shared(0.0, name="inp")
829+
x = pm.Normal.dist(inp)
830+
831+
assert x.eval() < 50
832+
833+
new_inp = inp + 100
834+
835+
def replacement_fn(var, replacements):
836+
if var is x:
837+
replacements[x.owner.inputs[3]] = new_inp
838+
839+
return []
840+
841+
[new_x], _ = _replace_vars_in_graphs([x], replacement_fn=replacement_fn)
842+
843+
assert new_x.eval() > 50

0 commit comments

Comments
 (0)