Skip to content

Commit 367f60c

Browse files
Refactoring printing module with aeppl/printing.py with additions
1 parent 434333f commit 367f60c

File tree

5 files changed

+565
-229
lines changed

5 files changed

+565
-229
lines changed

pymc/distributions/discrete.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1249,9 +1249,13 @@ def _zero_inflated_mixture(*, name, nonzero_p, nonzero_dist, **kwargs):
12491249
nonzero_dist,
12501250
]
12511251
if name is not None:
1252-
return Mixture(name, weights, comp_dists, **kwargs)
1252+
out_rv = Mixture(name, weights, comp_dists, **kwargs)
12531253
else:
1254-
return Mixture.dist(weights, comp_dists, **kwargs)
1254+
out_rv = Mixture.dist(weights, comp_dists, **kwargs)
1255+
1256+
out_rv.is_zero_inflated = True
1257+
1258+
return out_rv
12551259

12561260

12571261
class ZeroInflatedPoisson:

pymc/distributions/distribution.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import contextvars
1515
import functools
1616
import sys
17-
import types
1817
import warnings
1918

2019
from abc import ABCMeta
@@ -57,7 +56,6 @@
5756
)
5857
from pymc.logprob.rewriting import logprob_rewrites_db
5958
from pymc.model import BlockModelAccess
60-
from pymc.printing import str_for_dist
6159
from pymc.pytensorf import collect_default_updates, convert_observed_data
6260
from pymc.util import UNSET, _add_future_warning_tag
6361
from pymc.vartypes import string_types
@@ -317,12 +315,6 @@ def __new__(
317315
initval=initval,
318316
)
319317

320-
# add in pretty-printing support
321-
rv_out.str_repr = types.MethodType(str_for_dist, rv_out)
322-
rv_out._repr_latex_ = types.MethodType(
323-
functools.partial(str_for_dist, formatting="latex"), rv_out
324-
)
325-
326318
rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
327319
rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
328320
rv_out.random = _make_nice_attr_error("rv.random()", "pm.draw(rv)")

pymc/distributions/mixture.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ class Mixture(Distribution):
162162
"""
163163

164164
rv_type = MarginalMixtureRV
165+
is_zero_inflated = False
165166

166167
@classmethod
167168
def dist(cls, w, comp_dists, **kwargs):

pymc/model.py

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import functools
15+
import itertools
1616
import threading
1717
import types
1818
import warnings
@@ -589,10 +589,7 @@ def __init__(
589589

590590
from pymc.printing import str_for_model
591591

592-
self.str_repr = types.MethodType(str_for_model, self)
593-
self._repr_latex_ = types.MethodType(
594-
functools.partial(str_for_model, formatting="latex"), self
595-
)
592+
self._repr_latex_ = types.MethodType(str_for_model, self)
596593

597594
@property
598595
def model(self):
@@ -2015,17 +2012,17 @@ def Deterministic(name, var, model=None, dims=None):
20152012
model.deterministics.append(var)
20162013
model.add_named_variable(var, dims)
20172014

2018-
from pymc.printing import str_for_potential_or_deterministic
2015+
# from pymc.printing import str_for_potential_or_deterministic
20192016

2020-
var.str_repr = types.MethodType(
2021-
functools.partial(str_for_potential_or_deterministic, dist_name="Deterministic"), var
2022-
)
2023-
var._repr_latex_ = types.MethodType(
2024-
functools.partial(
2025-
str_for_potential_or_deterministic, dist_name="Deterministic", formatting="latex"
2026-
),
2027-
var,
2028-
)
2017+
# var.str_repr = types.MethodType(
2018+
# functools.partial(str_for_potential_or_deterministic, dist_name="Deterministic"), var
2019+
# )
2020+
# var._repr_latex_ = types.MethodType(
2021+
# functools.partial(
2022+
# str_for_potential_or_deterministic, dist_name="Deterministic", formatting="latex"
2023+
# ),
2024+
# var,
2025+
# )
20292026

20302027
return var
20312028

@@ -2047,16 +2044,16 @@ def Potential(name, var, model=None):
20472044
model.potentials.append(var)
20482045
model.add_named_variable(var)
20492046

2050-
from pymc.printing import str_for_potential_or_deterministic
2051-
2052-
var.str_repr = types.MethodType(
2053-
functools.partial(str_for_potential_or_deterministic, dist_name="Potential"), var
2054-
)
2055-
var._repr_latex_ = types.MethodType(
2056-
functools.partial(
2057-
str_for_potential_or_deterministic, dist_name="Potential", formatting="latex"
2058-
),
2059-
var,
2060-
)
2047+
# from pymc.printing import str_for_potential_or_deterministic
2048+
2049+
# var.str_repr = types.MethodType(
2050+
# functools.partial(str_for_potential_or_deterministic, dist_name="Potential"), var
2051+
# )
2052+
# var._repr_latex_ = types.MethodType(
2053+
# functools.partial(
2054+
# str_for_potential_or_deterministic, dist_name="Potential", formatting="latex"
2055+
# ),
2056+
# var,
2057+
# )
20612058

20622059
return var

0 commit comments

Comments
 (0)