Commit f6c8aac

Avoid reusing RNGs across distinct RandomVariables
This risk arises when resizing a RandomVariable whose new size depends on the size of the original RandomVariable: the reused RNG variables can end up with wrong update expressions that still depend on the original RandomVariable. A similar risk exists when clone-replacing RandomVariables, as the cloned variables will share the same RNGs. This happened when attempting to use Metropolis with Simulator variables.

* change_rv_size no longer reuses the old RNG
* compile_pymc raises if distinct update expressions are inferred for the same RNG
1 parent 6b5f33a commit f6c8aac
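
To see the failure mode being fixed, here is a minimal sketch (adapted from the test_multiple_updates_same_variable test added in this commit; assumes a PyMC development environment where aesara and pymc are importable):

import aesara
import aesara.tensor as at
import numpy as np

from pymc.aesaraf import compile_pymc

# Two distinct RandomVariables drawing from the same shared RNG: each draw
# needs its own update expression for that one variable, so no single
# update can be correct for both.
rng = aesara.shared(np.random.default_rng(), name="rng")
x = at.random.normal(rng=rng)
y = at.random.normal(rng=rng)

compile_pymc([], [x])  # fine on its own
compile_pymc([], [y])  # fine on its own
try:
    compile_pymc([], [x, y])  # now raises instead of silently keeping one update
except ValueError as err:
    print(err)  # Multiple update expressions found for the variable rng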

7 files changed (+127 -41 lines)

pymc/aesaraf.py

Lines changed: 38 additions & 4 deletions
@@ -864,6 +864,31 @@ def find_rng_nodes(
     ]
 
 
+def replace_rng_nodes(outputs: Sequence[TensorVariable]) -> Sequence[TensorVariable]:
+    """Replace any RNG nodes upstream of outputs by new RNGs of the same type.
+
+    This can be used when combining a pre-existing graph with a cloned one, to ensure
+    RNGs are unique across the two graphs.
+    """
+    rng_nodes = find_rng_nodes(outputs)
+
+    # Nothing to do here
+    if not rng_nodes:
+        return outputs
+
+    graph = FunctionGraph(outputs=outputs, clone=False)
+    new_rng_nodes: List[Union[np.random.RandomState, np.random.Generator]] = []
+    for rng_node in rng_nodes:
+        rng_cls: type
+        if isinstance(rng_node, at.random.var.RandomStateSharedVariable):
+            rng_cls = np.random.RandomState
+        else:
+            rng_cls = np.random.Generator
+        new_rng_nodes.append(aesara.shared(rng_cls(np.random.PCG64())))
+    graph.replace_all(zip(rng_nodes, new_rng_nodes), import_missing=True)
+    return graph.outputs
+
+
 SeedSequenceSeed = Optional[Union[int, Sequence[int], np.ndarray, np.random.SeedSequence]]

@@ -944,12 +969,21 @@ def compile_pymc(
         assert random_var.owner.op is not None
         if isinstance(random_var.owner.op, RandomVariable):
             rng = random_var.owner.inputs[0]
-            if not hasattr(rng, "default_update"):
-                rng_updates[rng] = random_var.owner.outputs[0]
+            if hasattr(rng, "default_update"):
+                update_map = {rng: rng.default_update}
             else:
-                rng_updates[rng] = rng.default_update
+                update_map = {rng: random_var.owner.outputs[0]}
         else:
-            rng_updates.update(random_var.owner.op.update(random_var.owner))
+            update_map = random_var.owner.op.update(random_var.owner)
+        # Check that we are not setting different update expressions for the same variables
+        for rng, update in update_map.items():
+            if rng not in rng_updates:
+                rng_updates[rng] = update
+            # When a variable has multiple outputs, it will be called twice with the same
+            # update expression. We don't want to raise in that case, only if the update
+            # expression is different from the one already registered
+            elif rng_updates[rng] is not update:
+                raise ValueError(f"Multiple update expressions found for the variable {rng}")
 
     # We always reseed random variables as this provides RNGs with no chances of collision
     if rng_updates:
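
A quick sketch of how the new helper is meant to be used (adapted from the test_replace_rng_nodes test added later in this diff; assumes a PyMC development environment):

import aesara
import aesara.tensor as at
import numpy as np

from pymc.aesaraf import replace_rng_nodes

rng = aesara.shared(np.random.default_rng())
x = at.random.normal(rng=rng)

# Cloning a node copies its RNG input by reference
cloned_x = x.owner.clone().default_output()
assert cloned_x.owner.inputs[0] is rng

# replace_rng_nodes rewires the graph in place to a fresh RNG of the same type
(cloned_x,) = replace_rng_nodes([cloned_x])
assert cloned_x.owner.inputs[0] is not rng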

pymc/distributions/shape_utils.py

Lines changed: 19 additions & 12 deletions
@@ -17,6 +17,8 @@
 A collection of common shape operations needed for broadcasting
 samples from probability distributions for stochastic nodes in PyMC.
 """
+import warnings
+
 from functools import singledispatch
 from typing import Optional, Sequence, Tuple, Union

@@ -579,8 +581,8 @@ def change_dist_size(
     Returns
     -------
     A new distribution variable that is equivalent to the original distribution with
-    the new size. The new distribution may reuse the same RandomState/Generator inputs
-    as the original distribution.
+    the new size. The new distribution will not reuse the old RandomState/Generator
+    input, so it will be independent of the original distribution.
 
     Examples
     --------

@@ -618,24 +620,29 @@
 def change_rv_size(op, rv, new_size, expand) -> TensorVariable:
     # Extract the RV node that is to be resized
     rv_node = rv.owner
-    rng, size, dtype, *dist_params = rv_node.inputs
+    old_rng, old_size, dtype, *dist_params = rv_node.inputs
 
     if expand:
-        shape = tuple(rv_node.op._infer_shape(size, dist_params))
-        size = shape[: len(shape) - rv_node.op.ndim_supp]
-        new_size = tuple(new_size) + tuple(size)
+        shape = tuple(rv_node.op._infer_shape(old_size, dist_params))
+        old_size = shape[: len(shape) - rv_node.op.ndim_supp]
+        new_size = tuple(new_size) + tuple(old_size)
 
     # Make sure the new size is a tensor. This dtype-aware conversion helps
     # to not unnecessarily pick up a `Cast` in some cases (see #4652).
     new_size = at.as_tensor(new_size, ndim=1, dtype="int64")
 
-    new_rv_node = rv_node.op.make_node(rng, new_size, dtype, *dist_params)
-    new_rv = new_rv_node.outputs[-1]
+    new_rv = rv_node.op(*dist_params, size=new_size, dtype=dtype)
 
-    # Update "traditional" rng default_update, if that was set for old RV
-    default_update = getattr(rng, "default_update", None)
-    if default_update is not None and default_update is rv_node.outputs[0]:
-        rng.default_update = new_rv_node.outputs[0]
+    # Replicate "traditional" rng default_update, if that was set for old_rng
+    default_update = getattr(old_rng, "default_update", None)
+    if default_update is not None:
+        if default_update is rv_node.outputs[0]:
+            new_rv.owner.inputs[0].default_update = new_rv.owner.outputs[0]
+        else:
+            warnings.warn(
+                f"Update expression of {rv} RNG could not be replicated in resized variable",
+                UserWarning,
+            )
 
     return new_rv
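
A short sketch of the new change_dist_size behaviour (mirroring the updated test_change_rv_size shown later in this diff; assumes a PyMC development environment):

import aesara
import aesara.tensor as at
import numpy as np

from pymc.distributions.shape_utils import change_dist_size

rng = aesara.shared(np.random.default_rng())
rv = at.random.normal(loc=at.as_tensor_variable([1, 2]), rng=rng)

rv_new = change_dist_size(rv, new_size=(3,), expand=True)

# The resized variable carries a fresh RNG input instead of reusing the old one
assert rv_new.owner.inputs[0] is not rng
assert tuple(rv_new.shape.eval()) == (3, 2)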

pymc/distributions/simulator.py

Lines changed: 2 additions & 2 deletions
@@ -248,10 +248,10 @@ def logp(cls, value, sim_op, sim_inputs):
         # TODO: Model rngs should be updated prior to multiprocessing split,
         # in which case this would not be needed. However, that would have to be
         # done for every sampler that may accommodate Simulators
-        rng = aesara.shared(np.random.default_rng())
+        rng = aesara.shared(np.random.default_rng(), name="simulator_rng")
         # Create a new simulatorRV with identical inputs as the original one
         sim_value = sim_op.make_node(rng, *sim_inputs[1:]).default_output()
-        sim_value.name = "sim_value"
+        sim_value.name = "simulator_value"
 
         return sim_op.distance(
             sim_op.epsilon,

pymc/initial_point.py

Lines changed: 3 additions & 13 deletions
@@ -24,7 +24,7 @@
 from aesara.graph.fg import FunctionGraph
 from aesara.tensor.var import TensorVariable
 
-from pymc.aesaraf import compile_pymc, find_rng_nodes, reseed_rngs
+from pymc.aesaraf import compile_pymc, find_rng_nodes, replace_rng_nodes, reseed_rngs
 from pymc.util import get_transformed_name, get_untransformed_name, is_transformed_name
 
 StartDict = Dict[Union[Variable, str], Union[np.ndarray, Variable, str]]

@@ -167,18 +167,8 @@ def make_initial_point_fn(
 
     # Replace original rng shared variables so that we don't mess with them
    # when calling the final seeded function
-    graph = FunctionGraph(outputs=initial_values, clone=False)
-    rng_nodes = find_rng_nodes(graph.outputs)
-    new_rng_nodes: List[Union[np.random.RandomState, np.random.Generator]] = []
-    for rng_node in rng_nodes:
-        rng_cls: type
-        if isinstance(rng_node, at.random.var.RandomStateSharedVariable):
-            rng_cls = np.random.RandomState
-        else:
-            rng_cls = np.random.Generator
-        new_rng_nodes.append(aesara.shared(rng_cls(np.random.PCG64())))
-    graph.replace_all(zip(rng_nodes, new_rng_nodes), import_missing=True)
-    func = compile_pymc(inputs=[], outputs=graph.outputs, mode=aesara.compile.mode.FAST_COMPILE)
+    initial_values = replace_rng_nodes(initial_values)
+    func = compile_pymc(inputs=[], outputs=initial_values, mode=aesara.compile.mode.FAST_COMPILE)
 
     varnames = []
     for var in model.free_RVs:

pymc/step_methods/metropolis.py

Lines changed: 12 additions & 3 deletions
@@ -23,7 +23,14 @@
 
 import pymc as pm
 
-from pymc.aesaraf import compile_pymc, floatX, rvs_to_value_vars
+from pymc.aesaraf import (
+    CallableTensor,
+    compile_pymc,
+    floatX,
+    join_nonshared_inputs,
+    replace_rng_nodes,
+    rvs_to_value_vars,
+)
 from pymc.blocking import DictToArrayBijection, RaveledVars
 from pymc.step_methods.arraystep import (
     ArrayStep,

@@ -1046,12 +1053,14 @@ def sample_except(limit, excluded):
 
 
 def delta_logp(point, logp, vars, shared):
-    [logp0], inarray0 = pm.join_nonshared_inputs(point, [logp], vars, shared)
+    [logp0], inarray0 = join_nonshared_inputs(point, [logp], vars, shared)
 
     tensor_type = inarray0.type
     inarray1 = tensor_type("inarray1")
 
-    logp1 = pm.CallableTensor(logp0)(inarray1)
+    logp1 = CallableTensor(logp0)(inarray1)
+    # Replace any potential duplicated RNG nodes
+    (logp1,) = replace_rng_nodes((logp1,))
 
     f = compile_pymc([inarray1, inarray0], logp1 - logp0)
     f.trust_input = True
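
Why the extra call matters here: CallableTensor clones the logp0 graph onto inarray1, so logp1 would otherwise share its RNG shared variables with logp0. Passing both to compile_pymc would then produce two distinct update expressions for the same RNG and trigger the ValueError introduced in pymc/aesaraf.py above. Giving logp1 fresh RNGs avoids that; this is the Metropolis-with-Simulator failure described in the commit message.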

pymc/tests/distributions/test_shape_utils.py

Lines changed: 16 additions & 7 deletions
@@ -506,7 +506,8 @@ def test_rv_size_is_none():
 
 def test_change_rv_size():
     loc = at.as_tensor_variable([1, 2])
-    rv = normal(loc=loc)
+    rng = aesara.shared(np.random.default_rng())
+    rv = normal(loc=loc, rng=rng)
     assert rv.ndim == 1
     assert tuple(rv.shape.eval()) == (2,)

@@ -525,6 +526,9 @@ def test_change_rv_size():
     assert loc in rv_new_ancestors
     assert rv not in rv_new_ancestors
 
+    # Check that the old rng is not reused
+    assert rv_new.owner.inputs[0] is not rng
+
     rv_newer = change_dist_size(rv_new, new_size=(4,), expand=True)
     assert rv_newer.ndim == 3
     assert tuple(rv_newer.shape.eval()) == (4, 3, 2)

@@ -555,22 +559,27 @@ def test_change_rv_size_default_update():
     rng = aesara.shared(np.random.default_rng(0))
     x = normal(rng=rng)
 
-    # Test that "traditional" default_update is updated
+    # Test that "traditional" default_update is translated to the new rng
     rng.default_update = x.owner.outputs[0]
     new_x = change_dist_size(x, new_size=(2,))
-    assert rng.default_update is not x.owner.outputs[0]
-    assert rng.default_update is new_x.owner.outputs[0]
+    new_rng = new_x.owner.inputs[0]
+    assert rng.default_update is x.owner.outputs[0]
+    assert new_rng.default_update is new_x.owner.outputs[0]
 
-    # Test that "non-traditional" default_update is left unchanged
+    # Test that "non-traditional" default_update raises UserWarning
     next_rng = aesara.shared(np.random.default_rng(1))
     rng.default_update = next_rng
-    new_x = change_dist_size(x, new_size=(2,))
+    with pytest.warns(UserWarning, match="could not be replicated in resized variable"):
+        new_x = change_dist_size(x, new_size=(2,))
+    new_rng = new_x.owner.inputs[0]
     assert rng.default_update is next_rng
+    assert not hasattr(new_rng, "default_update")
 
     # Test that default_update is not set if there was none before
     del rng.default_update
     new_x = change_dist_size(x, new_size=(2,))
-    assert not hasattr(rng, "default_update")
+    new_rng = new_x.owner.inputs[0]
+    assert not hasattr(new_rng, "default_update")
 
 
 def test_change_specify_shape_size_univariate():

pymc/tests/test_aesaraf.py

Lines changed: 37 additions & 0 deletions
@@ -37,6 +37,7 @@
     compile_pymc,
     convert_observed_data,
     extract_obs_data,
+    replace_rng_nodes,
     reseed_rngs,
     rvs_to_value_vars,
     walk_model,

@@ -502,6 +503,42 @@ def test_random_seed(self):
         assert x3_eval == x2_eval
         assert y3_eval == y2_eval
 
+    def test_multiple_updates_same_variable(self):
+        rng = aesara.shared(np.random.default_rng(), name="rng")
+        x = at.random.normal(rng=rng)
+        y = at.random.normal(rng=rng)
+
+        assert compile_pymc([], [x])
+        assert compile_pymc([], [y])
+        msg = "Multiple update expressions found for the variable rng"
+        with pytest.raises(ValueError, match=msg):
+            compile_pymc([], [x, y])
+
+
+def test_replace_rng_nodes():
+    rng = aesara.shared(np.random.default_rng())
+    x = at.random.normal(rng=rng)
+    x_rng, *x_non_rng_inputs = x.owner.inputs
+
+    cloned_x = x.owner.clone().default_output()
+    cloned_x_rng, *cloned_x_non_rng_inputs = cloned_x.owner.inputs
+
+    # RNG inputs are the same across the two variables
+    assert x_rng is cloned_x_rng
+
+    (new_x,) = replace_rng_nodes([cloned_x])
+    new_x_rng, *new_x_non_rng_inputs = new_x.owner.inputs
+
+    # Variables are still the same
+    assert new_x is cloned_x
+
+    # RNG inputs are not the same as before
+    assert new_x_rng is not x_rng
+
+    # All other inputs are the same as before
+    for non_rng_inputs, new_non_rng_inputs in zip(x_non_rng_inputs, new_x_non_rng_inputs):
+        assert non_rng_inputs is new_non_rng_inputs
+
 
 def test_reseed_rngs():
     # Reseed_rngs uses the `PCG64` bit_generator, which is currently the default