@@ -85,7 +85,7 @@ def polyagamma_cdf(*args, **kwargs):
85
85
normal_lcdf ,
86
86
zvalue ,
87
87
)
88
- from pymc .distributions .distribution import Continuous
88
+ from pymc .distributions .distribution import DIST_PARAMETER_TYPES , Continuous
89
89
from pymc .distributions .shape_utils import rv_size_is_none
90
90
from pymc .math import invlogit , logdiffexp , logit
91
91
from pymc .util import UNSET
@@ -692,12 +692,12 @@ class TruncatedNormal(BoundedContinuous):
692
692
@classmethod
693
693
def dist (
694
694
cls ,
695
- mu : Optional [Union [ float , np . ndarray ] ] = None ,
696
- sigma : Optional [Union [ float , np . ndarray ] ] = None ,
697
- tau : Optional [Union [ float , np . ndarray ] ] = None ,
698
- sd : Optional [Union [ float , np . ndarray ] ] = None ,
699
- lower : Optional [Union [ float , np . ndarray ] ] = None ,
700
- upper : Optional [Union [ float , np . ndarray ] ] = None ,
695
+ mu : Optional [DIST_PARAMETER_TYPES ] = None ,
696
+ sigma : Optional [DIST_PARAMETER_TYPES ] = None ,
697
+ tau : Optional [DIST_PARAMETER_TYPES ] = None ,
698
+ sd : Optional [DIST_PARAMETER_TYPES ] = None ,
699
+ lower : Optional [DIST_PARAMETER_TYPES ] = None ,
700
+ upper : Optional [DIST_PARAMETER_TYPES ] = None ,
701
701
transform : str = "auto" ,
702
702
* args ,
703
703
** kwargs ,
0 commit comments