@@ -613,7 +613,7 @@ def dist(
        cls,
        *dist_params,
        class_name: str,
-        random: Callable,
+        dist: Callable,
        logp: Optional[Callable] = None,
        logcdf: Optional[Callable] = None,
        moment: Optional[Callable] = None,
@@ -622,7 +622,7 @@ def dist(
        **kwargs,
    ):
        warnings.warn(
-            "CustomDist with symbolic random graph is still experimental. Expect bugs!",
+            "CustomDist with dist function is still experimental. Expect bugs!",
            UserWarning,
        )

@@ -644,7 +644,7 @@ def dist(
            class_name=class_name,
            logp=logp,
            logcdf=logcdf,
-            random=random,
+            dist=dist,
            moment=moment,
            ndim_supp=ndim_supp,
            **kwargs,
@@ -655,7 +655,7 @@ def rv_op(
        cls,
        *dist_params,
        class_name: str,
-        random: Callable,
+        dist: Callable,
        logp: Optional[Callable],
        logcdf: Optional[Callable],
        moment: Optional[Callable],
@@ -666,16 +666,16 @@ def rv_op(
        dummy_size_param = size.type()
        dummy_dist_params = [dist_param.type() for dist_param in dist_params]
        with BlockModelAccess(
-            error_msg_on_access="Model variables cannot be created in the random function. Use the `.dist` API"
+            error_msg_on_access="Model variables cannot be created in the dist function. Use the `.dist` API"
        ):
-            dummy_rv = random(*dummy_dist_params, dummy_size_param)
+            dummy_rv = dist(*dummy_dist_params, dummy_size_param)
        dummy_params = [dummy_size_param] + dummy_dist_params
        dummy_updates_dict = collect_default_updates(dummy_params, (dummy_rv,))

        rv_type = type(
            f"CustomSymbolicDistRV_{class_name}",
            (CustomSymbolicDistRV,),
-            # If logp is not provided, we infer it from the random graph
+            # If logp is not provided, we try to infer it from the dist graph
            dict(
                inline_logprob=logp is None,
            ),
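The hunk above is the core of the new code path: the user-supplied `dist` callable is evaluated once on dummy symbolic inputs (one per parameter, plus `size`) to obtain the graph that the op will wrap. A minimal sketch of that calling convention, assuming an illustrative callable named `shifted_exp` that is not part of PyMC:

.. code-block:: python

    import pymc as pm
    import pytensor.tensor as pt

    # Illustrative dist callable: symbolic parameters plus a symbolic size,
    # returning a PyTensor graph built from simpler PyMC distributions.
    def shifted_exp(lam, shift, size):
        return pm.Exponential.dist(lam, size=size) + shift

    # Mimic what rv_op does with dummy inputs of the right types:
    lam, shift = pt.scalar("lam"), pt.scalar("shift")
    size = pt.vector("size", dtype="int64")
    dummy_rv = shifted_exp(lam, shift, size)  # a TensorVariable, not a numerical draw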
@@ -697,11 +697,11 @@ def custom_dist_get_moment(op, rv, size, *params):
            return moment(rv, size, *params[: len(params)])

        @_change_dist_size.register(rv_type)
-        def change_custom_symbolic_dist_size(op, dist, new_size, expand):
-            node = dist.owner
+        def change_custom_symbolic_dist_size(op, rv, new_size, expand):
+            node = rv.owner

            if expand:
-                shape = tuple(dist.shape)
+                shape = tuple(rv.shape)
                old_size = shape[: len(shape) - node.op.ndim_supp]
                new_size = tuple(new_size) + tuple(old_size)
            new_size = at.as_tensor(new_size, ndim=1, dtype="int64")
@@ -711,7 +711,7 @@ def change_custom_symbolic_dist_size(op, dist, new_size, expand):
            # OpFromGraph has to be recreated if the size type changes!
            dummy_size_param = new_size.type()
            dummy_dist_params = [dist_param.type() for dist_param in old_dist_params]
-            dummy_rv = random(*dummy_dist_params, dummy_size_param)
+            dummy_rv = dist(*dummy_dist_params, dummy_size_param)
            dummy_params = [dummy_size_param] + dummy_dist_params
            dummy_updates_dict = collect_default_updates(dummy_params, (dummy_rv,))
            new_rv_op = rv_type(
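For context on what this resize path enables, a hedged usage sketch, assuming `change_dist_size` is importable from `pymc.distributions.shape_utils` as in recent PyMC versions and reusing the illustrative `shifted_exp` callable from above:

.. code-block:: python

    import pymc as pm
    from pymc.distributions.shape_utils import change_dist_size

    def shifted_exp(lam, shift, size):
        return pm.Exponential.dist(lam, size=size) + shift

    x = pm.CustomDist.dist(2.0, 1.0, dist=shifted_exp, class_name="shifted_exp")
    # The registered change_custom_symbolic_dist_size rebuilds the OpFromGraph with
    # the new size; expand=True prepends the new dims to the existing ones.
    y = change_dist_size(x, new_size=(5,), expand=True)
    print(pm.draw(y).shape)  # (5,)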
@@ -737,17 +737,18 @@ class CustomDist:
    This class can be used to wrap black-box random and logp methods for use in
    forward and mcmc sampling.

-    A user can provide a `random` function that returns numerical draws (e.g., via
-    NumPy routines) or an Aesara graph that represents the random graph when evaluated.
+    A user can provide a `dist` function that returns a PyTensor graph built from
+    simpler PyMC distributions, which represents the distribution. This graph is
+    used to take random draws, and to infer the logp expression automatically
+    when not provided by the user.

-    A user can provide a `logp` function that must return an Aesara graph that
-    represents the logp graph when evaluated. This is used for mcmc sampling. In some
-    cases, if a user provides a `random` function that returns an Aesara graph, PyMC
-    will be able to automatically derive the appropriate `logp` graph when performing
-    MCMC sampling.
+    Alternatively, a user can provide a `random` function that returns numerical
+    draws (e.g., via NumPy routines), and a `logp` function that must return a
+    PyTensor graph that represents the logp graph when evaluated. This is used for
+    mcmc sampling.

    Additionally, a user can provide a `logcdf` and `moment` functions that must return
-    an Aesara graph that computes those quantities. These may be used by other PyMC
+    a PyTensor graph that computes those quantities. These may be used by other PyMC
    routines.

    Parameters
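A small hedged sketch of the `dist` path described above (the callable and class names are illustrative): the same graph serves both forward draws and automatic logp inference.

.. code-block:: python

    import pymc as pm

    # Illustrative dist callable: a graph built from a simpler PyMC distribution.
    def dist_fn(mu, size):
        return pm.Normal.dist(mu, 1.0, size=size)

    x = pm.CustomDist.dist(0.0, dist=dist_fn, class_name="wrapped_normal")
    print(pm.draw(x, draws=3))     # random draws come from the graph
    print(pm.logp(x, 0.5).eval())  # logp is inferred automatically from the graph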
@@ -765,11 +766,20 @@ class CustomDist:
        different methods across separate models, be sure to use distinct
        class_names.

+    dist: Optional[Callable]
+        A callable that returns a PyTensor graph built from simpler PyMC distributions
+        which represents the distribution. This can be used by PyMC to take random draws
+        as well as to infer the logp of the distribution in some cases. In that case
+        it's not necessary to implement ``random`` or ``logp`` functions.
+
+        It must have the following signature: ``dist(*dist_params, size)``.
+        The symbolic tensor distribution parameters are passed as positional arguments in
+        the same order as they are supplied when the ``CustomDist`` is constructed.
+
    random : Optional[Callable]
-        A callable that can be used to 1) generate random draws from the distribution or
-        2) returns an Aesara graph that represents the random draws.
+        A callable that can be used to generate random draws from the distribution.

-        If 1) it must have the following signature: ``random(*dist_params, rng=None, size=None)``.
+        It must have the following signature: ``random(*dist_params, rng=None, size=None)``.
        The numerical distribution parameters are passed as positional arguments in the
        same order as they are supplied when the ``CustomDist`` is constructed.
        The keyword arguments are ``rng``, which will provide the random variable's
@@ -778,9 +788,6 @@ class CustomDist:
        error will be raised when trying to draw random samples from the distribution's
        prior or posterior predictive.

-        If 2) it must have the following signature: ``random(*dist_params, size)``.
-        The symbolic tensor distribution parameters are passed as postional arguments in
-        the same order as they are supplied when the ``CustomDist`` is constructed.
    logp : Optional[Callable]
        A callable that calculates the log probability of some given ``value``
        conditioned on certain distribution parameter values. It must have the
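The remaining `random` path takes a purely numerical sampler with the signature described above. A hedged sketch, with an illustrative NumPy-based callable:

.. code-block:: python

    import pymc as pm

    # Illustrative black-box sampler: random(*dist_params, rng=None, size=None).
    # `mu` arrives as a numerical value, `rng` as a numpy.random.Generator, and
    # `size` as the requested draw shape.
    def random_fn(mu, rng=None, size=None):
        return rng.normal(loc=mu, scale=1.0, size=size)

    y = pm.CustomDist.dist(0.0, random=random_fn, class_name="blackbox_normal")
    print(pm.draw(y, draws=3))  # forward sampling works; MCMC would need a logp callable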
@@ -789,8 +796,8 @@ class CustomDist:
        are the tensors that hold the values of the distribution parameters.
        This function must return an PyTensor tensor.

-        When the `random` function is specified and returns an `Aesara` graph, PyMC
-        will try to automatically infer the `logp` when this is not provided.
+        When the `dist` function is specified, PyMC will try to automatically
+        infer the `logp` when this is not provided.

        Otherwise, a ``NotImplementedError`` will be raised when trying to compute the
        distribution's logp.
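For the black-box `random` path, the `logp` callable described here has the signature `logp(value, *dist_params)` and must return a PyTensor tensor. A hedged sketch pairing it with the illustrative sampler above:

.. code-block:: python

    import numpy as np
    import pymc as pm

    def random_fn(mu, rng=None, size=None):
        return rng.normal(loc=mu, scale=1.0, size=size)

    # Illustrative logp callable: log-density of a unit-scale Normal, written with
    # operations that stay symbolic on the `value` and `mu` tensors.
    def logp_fn(value, mu):
        return -0.5 * np.log(2 * np.pi) - 0.5 * (value - mu) ** 2

    z = pm.CustomDist.dist(0.0, random=random_fn, logp=logp_fn, class_name="blackbox_normal")
    print(pm.logp(z, 0.5).eval())  # evaluates the user-supplied logp graph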
@@ -818,11 +825,11 @@ class CustomDist:
        The list of number of dimensions in the support of each of the distribution's
        parameters. If ``None``, it is assumed that all parameters are scalars, hence
        the number of dimensions of their support will be 0. This is not needed if an
-        Aesara random function is provided
+        PyTensor dist function is provided.
    dtype : str
        The dtype of the distribution. All draws and observations passed into the
-        distribution will be cast onto this dtype. This is not needed if an Aesara
-        random function is provided, which should already return the right dtype!
+        distribution will be cast onto this dtype. This is not needed if a PyTensor
+        dist function is provided, which should already return the right dtype!

    kwargs :
        Extra keyword arguments are passed to the parent's class ``__new__`` method.
@@ -884,16 +891,16 @@ def random(
            )
            prior = pm.sample_prior_predictive(10)

-    Provide a random function that creates an Aesara random graph. PyMC can
-    automatically infer that the logp of this variable corresponds to a shifted
-    Exponential distribution.
+    Provide a dist function that creates a PyTensor graph built from other
+    PyMC distributions. PyMC can automatically infer that the logp of this
+    variable corresponds to a shifted Exponential distribution.

    .. code-block:: python

        import pymc as pm
        from pytensor.tensor import TensorVariable

-        def random(
+        def dist(
            lam: TensorVariable,
            shift: TensorVariable,
            size: TensorVariable,
@@ -907,16 +914,16 @@ def random(
                "custom_dist",
                lam,
                shift,
-                random=random,
+                dist=dist,
                observed=[-1, -1, 0],
            )

            prior = pm.sample_prior_predictive()
            posterior = pm.sample()

-    Provide a random function that creates an Aesara random graph. PyMC can
-    automatically infer that the logp of this variable corresponds to a
-    modified-PERT distribution.
+    Provide a dist function that creates a PyTensor graph built from other
+    PyMC distributions. PyMC can automatically infer that the logp of
+    this variable corresponds to a modified-PERT distribution.

    .. code-block:: python

@@ -940,7 +947,7 @@ def pert(
            peak = pm.Normal("peak", 50, 10)
            high = pm.Normal("high", 100, 10)
            lmbda = 4
-            pm.CustomDist("pert", low, peak, high, lmbda, random=pert, observed=[30, 35, 73])
+            pm.CustomDist("pert", low, peak, high, lmbda, dist=pert, observed=[30, 35, 73])

        m.point_logps()

@@ -950,6 +957,7 @@ def __new__(
        cls,
        name,
        *dist_params,
+        dist: Optional[Callable] = None,
        random: Optional[Callable] = None,
        logp: Optional[Callable] = None,
        logcdf: Optional[Callable] = None,
@@ -968,12 +976,13 @@ def __new__(
                "parameters are positional arguments."
            )
        dist_params = cls.parse_dist_params(dist_params)
-        if cls.is_symbolic_random(random, dist_params):
+        cls.check_valid_dist_random(dist, random, dist_params)
+        if dist is not None:
            return _CustomSymbolicDist(
                name,
                *dist_params,
                class_name=name,
-                random=random,
+                dist=dist,
                logp=logp,
                logcdf=logcdf,
                moment=moment,
@@ -1001,6 +1010,7 @@ def dist(
        cls,
        *dist_params,
        class_name: str,
+        dist: Optional[Callable] = None,
        random: Optional[Callable] = None,
        logp: Optional[Callable] = None,
        logcdf: Optional[Callable] = None,
@@ -1011,11 +1021,12 @@ def dist(
        **kwargs,
    ):
        dist_params = cls.parse_dist_params(dist_params)
-        if cls.is_symbolic_random(random, dist_params):
+        cls.check_valid_dist_random(dist, random, dist_params)
+        if dist is not None:
            return _CustomSymbolicDist.dist(
                *dist_params,
                class_name=class_name,
-                random=random,
+                dist=dist,
                logp=logp,
                logcdf=logcdf,
                moment=moment,
@@ -1048,6 +1059,16 @@ def parse_dist_params(cls, dist_params):
            )
        return [as_tensor_variable(param) for param in dist_params]

+    @classmethod
+    def check_valid_dist_random(cls, dist, random, dist_params):
+        if dist is not None and random is not None:
+            raise ValueError("Cannot provide both dist and random functions")
+        if random is not None and cls.is_symbolic_random(random, dist_params):
+            raise TypeError(
+                "API change: function passed to `random` argument should no longer return a PyTensor graph. "
+                "Pass such function to the `dist` argument instead."
+            )
+
    @classmethod
    def is_symbolic_random(self, random, dist_params):
        if random is None:
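The new `check_valid_dist_random` helper makes the two failure modes explicit. A hedged sketch of both, with illustrative callable names:

.. code-block:: python

    import pymc as pm

    def dist_fn(mu, size):
        return pm.Normal.dist(mu, 1.0, size=size)

    def random_fn(mu, rng=None, size=None):
        return rng.normal(loc=mu, size=size)

    # Supplying both callables is rejected up front:
    try:
        pm.CustomDist.dist(0.0, dist=dist_fn, random=random_fn, class_name="bad")
    except ValueError as err:
        print(err)  # Cannot provide both dist and random functions

    # Passing a graph-returning function through `random` hits the API-change error:
    try:
        pm.CustomDist.dist(0.0, random=dist_fn, class_name="old_api")
    except TypeError as err:
        print(err)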
@@ -1056,7 +1077,7 @@ def is_symbolic_random(self, random, dist_params):
        try:
            size = normalize_size_param(None)
            with BlockModelAccess(
-                error_msg_on_access="Model variables cannot be created in the random function. Use the `.dist` API"
+                error_msg_on_access="Model variables cannot be created in the random function. Use the `.dist` API to create such variables."
            ):
                out = random(*dist_params, size)
        except BlockModelAccessError: