
Commit 320135e

Use separate argument in CustomDist for functions that return symbolic representations
1 parent 22b4446 commit 320135e
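
In short, `CustomDist` previously overloaded the `random` argument: it accepted either a function returning numerical draws or a function returning a symbolic graph. After this commit the symbolic case gets its own `dist` argument, and passing a graph-returning function to `random` raises a `TypeError`. A minimal sketch of the new usage, mirroring the shifted-Exponential example in the updated docstring (the helper name `shifted_exponential` is illustrative, not part of the commit):

    import pymc as pm
    from pytensor.tensor import TensorVariable

    def shifted_exponential(lam: TensorVariable, shift: TensorVariable, size: TensorVariable) -> TensorVariable:
        # Return a PyTensor graph built from simpler PyMC distributions;
        # PyMC uses it for forward draws and to infer the logp automatically.
        return pm.Exponential.dist(lam, size=size) + shift

    with pm.Model():
        lam = pm.HalfNormal("lam")
        # Before this commit the same callable would have been passed as `random=...`
        pm.CustomDist("x", lam, -1.0, dist=shifted_exponential, observed=[-1, -1, 0])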

File tree

2 files changed: +87 −55

pymc/distributions/distribution.py

+65 −44
@@ -613,7 +613,7 @@ def dist(
         cls,
         *dist_params,
         class_name: str,
-        random: Callable,
+        dist: Callable,
         logp: Optional[Callable] = None,
         logcdf: Optional[Callable] = None,
         moment: Optional[Callable] = None,
@@ -622,7 +622,7 @@ def dist(
         **kwargs,
     ):
         warnings.warn(
-            "CustomDist with symbolic random graph is still experimental. Expect bugs!",
+            "CustomDist with dist function is still experimental. Expect bugs!",
             UserWarning,
         )

@@ -644,7 +644,7 @@ def dist(
             class_name=class_name,
             logp=logp,
             logcdf=logcdf,
-            random=random,
+            dist=dist,
             moment=moment,
             ndim_supp=ndim_supp,
             **kwargs,
@@ -655,7 +655,7 @@ def rv_op(
         cls,
         *dist_params,
         class_name: str,
-        random: Callable,
+        dist: Callable,
         logp: Optional[Callable],
         logcdf: Optional[Callable],
         moment: Optional[Callable],
@@ -666,16 +666,16 @@ def rv_op(
         dummy_size_param = size.type()
         dummy_dist_params = [dist_param.type() for dist_param in dist_params]
         with BlockModelAccess(
-            error_msg_on_access="Model variables cannot be created in the random function. Use the `.dist` API"
+            error_msg_on_access="Model variables cannot be created in the dist function. Use the `.dist` API"
         ):
-            dummy_rv = random(*dummy_dist_params, dummy_size_param)
+            dummy_rv = dist(*dummy_dist_params, dummy_size_param)
         dummy_params = [dummy_size_param] + dummy_dist_params
         dummy_updates_dict = collect_default_updates(dummy_params, (dummy_rv,))

         rv_type = type(
             f"CustomSymbolicDistRV_{class_name}",
             (CustomSymbolicDistRV,),
-            # If logp is not provided, we infer it from the random graph
+            # If logp is not provided, we try to infer it from the dist graph
             dict(
                 inline_logprob=logp is None,
             ),
@@ -697,11 +697,11 @@ def custom_dist_get_moment(op, rv, size, *params):
             return moment(rv, size, *params[: len(params)])

         @_change_dist_size.register(rv_type)
-        def change_custom_symbolic_dist_size(op, dist, new_size, expand):
-            node = dist.owner
+        def change_custom_symbolic_dist_size(op, rv, new_size, expand):
+            node = rv.owner

             if expand:
-                shape = tuple(dist.shape)
+                shape = tuple(rv.shape)
                 old_size = shape[: len(shape) - node.op.ndim_supp]
                 new_size = tuple(new_size) + tuple(old_size)
             new_size = at.as_tensor(new_size, ndim=1, dtype="int64")
@@ -711,7 +711,7 @@ def change_custom_symbolic_dist_size(op, dist, new_size, expand):
             # OpFromGraph has to be recreated if the size type changes!
             dummy_size_param = new_size.type()
             dummy_dist_params = [dist_param.type() for dist_param in old_dist_params]
-            dummy_rv = random(*dummy_dist_params, dummy_size_param)
+            dummy_rv = dist(*dummy_dist_params, dummy_size_param)
             dummy_params = [dummy_size_param] + dummy_dist_params
             dummy_updates_dict = collect_default_updates(dummy_params, (dummy_rv,))
             new_rv_op = rv_type(
@@ -737,17 +737,18 @@ class CustomDist:
     This class can be used to wrap black-box random and logp methods for use in
     forward and mcmc sampling.

-    A user can provide a `random` function that returns numerical draws (e.g., via
-    NumPy routines) or an Aesara graph that represents the random graph when evaluated.
+    A user can provide a `dist` function that returns a PyTensor graph built from
+    simpler PyMC distributions, which represents the distribution. This graph is
+    used to take random draws, and to infer the logp expression automatically
+    when not provided by the user.

-    A user can provide a `logp` function that must return an Aesara graph that
-    represents the logp graph when evaluated. This is used for mcmc sampling. In some
-    cases, if a user provides a `random` function that returns an Aesara graph, PyMC
-    will be able to automatically derive the appropriate `logp` graph when performing
-    MCMC sampling.
+    Alternatively, a user can provide a `random` function that returns numerical
+    draws (e.g., via NumPy routines), and a `logp` function that must return a
+    PyTensor graph that represents the logp graph when evaluated. This is used for
+    mcmc sampling.

     Additionally, a user can provide a `logcdf` and `moment` functions that must return
-    an Aesara graph that computes those quantities. These may be used by other PyMC
+    a PyTensor graph that computes those quantities. These may be used by other PyMC
     routines.

     Parameters
@@ -765,11 +766,20 @@ class CustomDist:
         different methods across separate models, be sure to use distinct
         class_names.

+    dist: Optional[Callable]
+        A callable that returns a PyTensor graph built from simpler PyMC distributions
+        which represents the distribution. This can be used by PyMC to take random draws
+        as well as to infer the logp of the distribution in some cases. In that case
+        it's not necessary to implement ``random`` or ``logp`` functions.
+
+        It must have the following signature: ``dist(*dist_params, size)``.
+        The symbolic tensor distribution parameters are passed as positional arguments in
+        the same order as they are supplied when the ``CustomDist`` is constructed.
+
     random : Optional[Callable]
-        A callable that can be used to 1) generate random draws from the distribution or
-        2) returns an Aesara graph that represents the random draws.
+        A callable that can be used to generate random draws from the distribution

-        If 1) it must have the following signature: ``random(*dist_params, rng=None, size=None)``.
+        It must have the following signature: ``random(*dist_params, rng=None, size=None)``.
         The numerical distribution parameters are passed as positional arguments in the
         same order as they are supplied when the ``CustomDist`` is constructed.
         The keyword arguments are ``rng``, which will provide the random variable's
@@ -778,9 +788,6 @@ class CustomDist:
         error will be raised when trying to draw random samples from the distribution's
         prior or posterior predictive.

-        If 2) it must have the following signature: ``random(*dist_params, size)``.
-        The symbolic tensor distribution parameters are passed as postional arguments in
-        the same order as they are supplied when the ``CustomDist`` is constructed.
     logp : Optional[Callable]
         A callable that calculates the log probability of some given ``value``
         conditioned on certain distribution parameter values. It must have the
@@ -789,8 +796,8 @@ class CustomDist:
         are the tensors that hold the values of the distribution parameters.
         This function must return an PyTensor tensor.

-        When the `random` function is specified and returns an `Aesara` graph, PyMC
-        will try to automatically infer the `logp` when this is not provided.
+        When the `dist` function is specified, PyMC will try to automatically
+        infer the `logp` when this is not provided.

         Otherwise, a ``NotImplementedError`` will be raised when trying to compute the
         distribution's logp.
@@ -818,11 +825,11 @@ class CustomDist:
         The list of number of dimensions in the support of each of the distribution's
         parameters. If ``None``, it is assumed that all parameters are scalars, hence
         the number of dimensions of their support will be 0. This is not needed if an
-        Aesara random function is provided
+        PyTensor dist function is provided.
     dtype : str
         The dtype of the distribution. All draws and observations passed into the
-        distribution will be cast onto this dtype. This is not needed if an Aesara
-        random function is provided, which should already return the right dtype!
+        distribution will be cast onto this dtype. This is not needed if a PyTensor
+        dist function is provided, which should already return the right dtype!
     kwargs :
         Extra keyword arguments are passed to the parent's class ``__new__`` method.

@@ -884,16 +891,16 @@ def random(
             )
             prior = pm.sample_prior_predictive(10)

-    Provide a random function that creates an Aesara random graph. PyMC can
-    automatically infer that the logp of this variable corresponds to a shifted
-    Exponential distribution.
+    Provide a dist function that creates a PyTensor graph built from other
+    PyMC distributions. PyMC can automatically infer that the logp of this
+    variable corresponds to a shifted Exponential distribution.

     .. code-block:: python

         import pymc as pm
         from pytensor.tensor import TensorVariable

-        def random(
+        def dist(
             lam: TensorVariable,
             shift: TensorVariable,
             size: TensorVariable,
@@ -907,16 +914,16 @@ def random(
                 "custom_dist",
                 lam,
                 shift,
-                random=random,
+                dist=dist,
                 observed=[-1, -1, 0],
             )

             prior = pm.sample_prior_predictive()
             posterior = pm.sample()

-    Provide a random function that creates an Aesara random graph. PyMC can
-    automatically infer that the logp of this variable corresponds to a
-    modified-PERT distribution.
+    Provide a dist function that creates a PyTensor graph built from other
+    PyMC distributions. PyMC can automatically infer that the logp of
+    this variable corresponds to a modified-PERT distribution.

     .. code-block:: python

@@ -940,7 +947,7 @@ def pert(
             peak = pm.Normal("peak", 50, 10)
             high = pm.Normal("high", 100, 10)
             lmbda = 4
-            pm.CustomDist("pert", low, peak, high, lmbda, random=pert, observed=[30, 35, 73])
+            pm.CustomDist("pert", low, peak, high, lmbda, dist=pert, observed=[30, 35, 73])

         m.point_logps()

@@ -950,6 +957,7 @@ def __new__(
         cls,
         name,
         *dist_params,
+        dist: Optional[Callable] = None,
         random: Optional[Callable] = None,
         logp: Optional[Callable] = None,
         logcdf: Optional[Callable] = None,
@@ -968,12 +976,13 @@ def __new__(
                 "parameters are positional arguments."
             )
         dist_params = cls.parse_dist_params(dist_params)
-        if cls.is_symbolic_random(random, dist_params):
+        cls.check_valid_dist_random(dist, random, dist_params)
+        if dist is not None:
             return _CustomSymbolicDist(
                 name,
                 *dist_params,
                 class_name=name,
-                random=random,
+                dist=dist,
                 logp=logp,
                 logcdf=logcdf,
                 moment=moment,
@@ -1001,6 +1010,7 @@ def dist(
         cls,
         *dist_params,
         class_name: str,
+        dist: Optional[Callable] = None,
         random: Optional[Callable] = None,
         logp: Optional[Callable] = None,
         logcdf: Optional[Callable] = None,
@@ -1011,11 +1021,12 @@ def dist(
         **kwargs,
     ):
         dist_params = cls.parse_dist_params(dist_params)
-        if cls.is_symbolic_random(random, dist_params):
+        cls.check_valid_dist_random(dist, random, dist_params)
+        if dist is not None:
             return _CustomSymbolicDist.dist(
                 *dist_params,
                 class_name=class_name,
-                random=random,
+                dist=dist,
                 logp=logp,
                 logcdf=logcdf,
                 moment=moment,
@@ -1048,6 +1059,16 @@ def parse_dist_params(cls, dist_params):
             )
         return [as_tensor_variable(param) for param in dist_params]

+    @classmethod
+    def check_valid_dist_random(cls, dist, random, dist_params):
+        if dist is not None and random is not None:
+            raise ValueError("Cannot provide both dist and random functions")
+        if random is not None and cls.is_symbolic_random(random, dist_params):
+            raise TypeError(
+                "API change: function passed to `random` argument should no longer return a PyTensor graph. "
+                "Pass such function to the `dist` argument instead."
+            )
+
     @classmethod
     def is_symbolic_random(self, random, dist_params):
         if random is None:
@@ -1056,7 +1077,7 @@ def is_symbolic_random(self, random, dist_params):
         try:
             size = normalize_size_param(None)
             with BlockModelAccess(
-                error_msg_on_access="Model variables cannot be created in the random function. Use the `.dist` API"
+                error_msg_on_access="Model variables cannot be created in the random function. Use the `.dist` API to create such variables."
             ):
                 out = random(*dist_params, size)
         except BlockModelAccessError:
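
For contrast with the `dist` path above, the purely numerical path keeps the `random` argument. Per the updated docstring it must have the signature ``random(*dist_params, rng=None, size=None)`` and return NumPy draws, and a `logp` callable returning a PyTensor tensor is needed for MCMC. A hedged sketch under those assumptions (the shifted-Exponential logp is chosen only for illustration):

    import pymc as pm

    def random(lam, shift, rng=None, size=None):
        # Numerical draws via the supplied NumPy Generator
        return rng.exponential(scale=1 / lam, size=size) + shift

    def logp(value, lam, shift):
        # Must return a PyTensor tensor
        return pm.logp(pm.Exponential.dist(lam), value - shift)

    with pm.Model():
        lam = pm.HalfNormal("lam")
        pm.CustomDist("x", lam, -1.0, random=random, logp=logp, observed=[-1, -1, 0])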

pymc/tests/distributions/test_distribution.py

+22 −11
@@ -368,7 +368,7 @@ def test_dist(self):

 class TestCustomSymbolicDist:
     def test_basic(self):
-        def custom_random(mu, sigma, size):
+        def custom_dist(mu, sigma, size):
             return at.exp(pm.Normal.dist(mu, sigma, size=size))

         with Model() as m:
@@ -379,7 +379,7 @@ def custom_random(mu, sigma, size):
                 "lognormal",
                 mu,
                 sigma,
-                random=custom_random,
+                dist=custom_dist,
                 size=(10,),
                 transform=log,
                 initval=np.ones(10),
@@ -401,7 +401,7 @@ def custom_random(mu, sigma, size):
         np.testing.assert_allclose(m.compile_logp()(ip), ref_m.compile_logp()(ip))

     def test_random_multiple_rngs(self):
-        def custom_random(p, sigma, size):
+        def custom_dist(p, sigma, size):
             idx = pm.Bernoulli.dist(p=p)
             comps = pm.Normal.dist([-sigma, sigma], 1e-1, size=(*size, 2)).T
             return comps[idx]
@@ -411,7 +411,7 @@ def custom_random(p, sigma, size):
             0.5,
             10.0,
             class_name="customdist",
-            random=custom_random,
+            dist=custom_dist,
             size=(10,),
         )

@@ -426,7 +426,7 @@ def custom_random(p, sigma, size):
         assert np.unique(draws).size == 20

     def test_custom_methods(self):
-        def custom_random(mu, size):
+        def custom_dist(mu, size):
             if rv_size_is_none(size):
                 return mu
             return at.full(size, mu)
@@ -444,7 +444,7 @@ def custom_logcdf(value, mu):
         customdist = CustomDist.dist(
             [np.e, np.e],
             class_name="customdist",
-            random=custom_random,
+            dist=custom_dist,
             moment=custom_moment,
             logp=custom_logp,
             logcdf=custom_logcdf,
@@ -458,15 +458,15 @@ def custom_logcdf(value, mu):
         np.testing.assert_allclose(logcdf(customdist, [0, 0]).eval(), [np.e + 3, np.e + 3])

     def test_change_size(self):
-        def custom_random(mu, sigma, size):
+        def custom_dist(mu, sigma, size):
             return at.exp(pm.Normal.dist(mu, sigma, size=size))

         with pytest.warns(UserWarning, match="experimental"):
             lognormal = CustomDist.dist(
                 0,
                 1,
                 class_name="lognormal",
-                random=custom_random,
+                dist=custom_dist,
                 size=(10,),
             )
         assert isinstance(lognormal.owner.op, CustomSymbolicDistRV)
@@ -481,15 +481,26 @@ def custom_random(mu, sigma, size):
         assert tuple(new_lognormal.shape.eval()) == (2, 5, 10)

     def test_error_model_access(self):
-        def random(size):
+        def custom_dist(size):
             return pm.Flat("Flat", size=size)

         with pm.Model() as m:
             with pytest.raises(
                 BlockModelAccessError,
-                match="Model variables cannot be created in the random function",
+                match="Model variables cannot be created in the dist function",
             ):
-                CustomDist("custom_dist", random=random)
+                CustomDist("custom_dist", dist=custom_dist)
+
+    def test_api_change_error(self):
+        def old_random(size):
+            return pm.Flat.dist(size=size)
+
+        # Old API raises
+        with pytest.raises(TypeError, match="API change: function passed to `random` argument"):
+            pm.CustomDist.dist(random=old_random, class_name="custom_dist")
+
+        # New API is fine
+        pm.CustomDist.dist(dist=old_random, class_name="custom_dist")


 class TestSymbolicRandomVariable:
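
The added test exercises the guard introduced in ``check_valid_dist_random``: a graph-returning callable passed to `random` now raises a ``TypeError``, and supplying both `dist` and `random` raises a ``ValueError``. A small sketch of the expected behaviour under the post-commit API (the name `graph_fn` is illustrative):

    import pymc as pm
    import pytest

    def graph_fn(size):
        # Returns a PyTensor graph, so it belongs in `dist`, not `random`
        return pm.Flat.dist(size=size)

    with pytest.raises(TypeError, match="API change"):
        pm.CustomDist.dist(random=graph_fn, class_name="demo")

    with pytest.raises(ValueError, match="Cannot provide both"):
        pm.CustomDist.dist(dist=graph_fn, random=graph_fn, class_name="demo")

    # The same callable is accepted through the new `dist` argument
    pm.CustomDist.dist(dist=graph_fn, class_name="demo")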
