Skip to content

Commit 38867dd

Browse files
authored
Update dist parameter hints (#5315)
1 parent 2634f41 commit 38867dd

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

pymc/distributions/continuous.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def polyagamma_cdf(*args, **kwargs):
8585
normal_lcdf,
8686
zvalue,
8787
)
88-
from pymc.distributions.distribution import Continuous
88+
from pymc.distributions.distribution import DIST_PARAMETER_TYPES, Continuous
8989
from pymc.distributions.shape_utils import rv_size_is_none
9090
from pymc.math import invlogit, logdiffexp, logit
9191
from pymc.util import UNSET
@@ -692,12 +692,12 @@ class TruncatedNormal(BoundedContinuous):
692692
@classmethod
693693
def dist(
694694
cls,
695-
mu: Optional[Union[float, np.ndarray]] = None,
696-
sigma: Optional[Union[float, np.ndarray]] = None,
697-
tau: Optional[Union[float, np.ndarray]] = None,
698-
sd: Optional[Union[float, np.ndarray]] = None,
699-
lower: Optional[Union[float, np.ndarray]] = None,
700-
upper: Optional[Union[float, np.ndarray]] = None,
695+
mu: Optional[DIST_PARAMETER_TYPES] = None,
696+
sigma: Optional[DIST_PARAMETER_TYPES] = None,
697+
tau: Optional[DIST_PARAMETER_TYPES] = None,
698+
sd: Optional[DIST_PARAMETER_TYPES] = None,
699+
lower: Optional[DIST_PARAMETER_TYPES] = None,
700+
upper: Optional[DIST_PARAMETER_TYPES] = None,
701701
transform: str = "auto",
702702
*args,
703703
**kwargs,

pymc/distributions/distribution.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@
6161
"NoDistribution",
6262
]
6363

64+
DIST_PARAMETER_TYPES = Union[np.ndarray, int, float, TensorVariable]
65+
6466
vectorized_ppc = contextvars.ContextVar(
6567
"vectorized_ppc", default=None
6668
) # type: contextvars.ContextVar[Optional[Callable]]

0 commit comments

Comments
 (0)