Skip to content

Commit bcfff6d

Browse files
committed
Prevent Model access in random function of CustomDist
1 parent 439a973 commit bcfff6d

File tree

4 files changed

+44
-3
lines changed

4 files changed

+44
-3
lines changed

pymc/distributions/distribution.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
find_size,
4747
shape_from_dims,
4848
)
49+
from pymc.exceptions import BlockModelAccessError
4950
from pymc.logprob.abstract import (
5051
MeasurableVariable,
5152
_get_measurable_outputs,
@@ -54,6 +55,7 @@
5455
_logprob,
5556
)
5657
from pymc.logprob.rewriting import logprob_rewrites_db
58+
from pymc.model import BlockModelAccess
5759
from pymc.printing import str_for_dist
5860
from pymc.pytensorf import collect_default_updates, convert_observed_data
5961
from pymc.util import UNSET, _add_future_warning_tag
@@ -662,7 +664,10 @@ def rv_op(
662664
size = normalize_size_param(size)
663665
dummy_size_param = size.type()
664666
dummy_dist_params = [dist_param.type() for dist_param in dist_params]
665-
dummy_rv = random(*dummy_dist_params, dummy_size_param)
667+
with BlockModelAccess(
668+
error_msg_on_access="Model variables cannot be created in the random function. Use the `.dist` API"
669+
):
670+
dummy_rv = random(*dummy_dist_params, dummy_size_param)
666671
dummy_params = [dummy_size_param] + dummy_dist_params
667672
dummy_updates_dict = collect_default_updates(dummy_params, (dummy_rv,))
668673

@@ -1050,7 +1055,12 @@ def is_symbolic_random(self, random, dist_params):
10501055
# Try calling random with symbolic inputs
10511056
try:
10521057
size = normalize_size_param(None)
1053-
out = random(*dist_params, size)
1058+
with BlockModelAccess(
1059+
error_msg_on_access="Model variables cannot be created in the random function. Use the `.dist` API"
1060+
):
1061+
out = random(*dist_params, size)
1062+
except BlockModelAccessError:
1063+
raise
10541064
except Exception:
10551065
# If it fails we assume it was not
10561066
return False

pymc/exceptions.py

+4
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,7 @@ class TruncationError(RuntimeError):
8282

8383
class NotConstantValueError(ValueError):
8484
pass
85+
86+
87+
class BlockModelAccessError(RuntimeError):
88+
pass

pymc/model.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,13 @@
5353
from pymc.data import GenTensorVariable, Minibatch
5454
from pymc.distributions.logprob import _joint_logp
5555
from pymc.distributions.transforms import _default_transform
56-
from pymc.exceptions import ImputationWarning, SamplingError, ShapeError, ShapeWarning
56+
from pymc.exceptions import (
57+
BlockModelAccessError,
58+
ImputationWarning,
59+
SamplingError,
60+
ShapeError,
61+
ShapeWarning,
62+
)
5763
from pymc.initial_point import make_initial_point_fn
5864
from pymc.pytensorf import (
5965
PointFunc,
@@ -195,6 +201,8 @@ def get_context(cls, error_if_none=True) -> Optional[T]:
195201
if error_if_none:
196202
raise TypeError(f"No {cls} on context stack")
197203
return None
204+
if isinstance(candidate, BlockModelAccess):
205+
raise BlockModelAccessError(candidate.error_msg_on_access)
198206
return candidate
199207

200208
def get_contexts(cls) -> List[T]:
@@ -1798,6 +1806,13 @@ def point_logps(self, point=None, round_vals=2):
17981806
Model._context_class = Model
17991807

18001808

1809+
class BlockModelAccess(Model):
1810+
"""This class can be used to prevent user access to Model contexts"""
1811+
1812+
def __init__(self, *args, error_msg_on_access="Model access is blocked", **kwargs):
1813+
self.error_msg_on_access = error_msg_on_access
1814+
1815+
18011816
def set_data(new_data, model=None, *, coords=None):
18021817
"""Sets the value of one or more data container variables. Note that the shape is also
18031818
dynamic, it is updated when the value is changed. See the examples below for two common

pymc/tests/distributions/test_distribution.py

+12
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
)
4646
from pymc.distributions.shape_utils import change_dist_size, rv_size_is_none, to_tuple
4747
from pymc.distributions.transforms import log
48+
from pymc.exceptions import BlockModelAccessError
4849
from pymc.logprob.abstract import get_measurable_outputs, logcdf
4950
from pymc.model import Model
5051
from pymc.sampling import draw, sample
@@ -479,6 +480,17 @@ def custom_random(mu, sigma, size):
479480
assert isinstance(new_lognormal.owner.op, CustomSymbolicDistRV)
480481
assert tuple(new_lognormal.shape.eval()) == (2, 5, 10)
481482

483+
def test_error_model_access(self):
484+
def random(size):
485+
return pm.Flat("Flat", size=size)
486+
487+
with pm.Model() as m:
488+
with pytest.raises(
489+
BlockModelAccessError,
490+
match="Model variables cannot be created in the random function",
491+
):
492+
CustomDist("custom_dist", random=random)
493+
482494

483495
class TestSymbolicRandomVarible:
484496
def test_inline(self):

0 commit comments

Comments
 (0)