Skip to content

Commit fbc62d5

Browse files
committed
Remove IR Ops from final logprob graph
1 parent 4d0360c commit fbc62d5

File tree

6 files changed

+84
-31
lines changed

6 files changed

+84
-31
lines changed

pymc/logprob/basic.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
_logprob_helper,
6565
get_measurable_outputs,
6666
)
67-
from pymc.logprob.rewriting import construct_ir_fgraph
67+
from pymc.logprob.rewriting import cleanup_ir, construct_ir_fgraph
6868
from pymc.logprob.transforms import RVTransform, TransformValuesRewrite
6969
from pymc.logprob.utils import rvs_to_value_vars
7070

@@ -107,6 +107,7 @@ def logp(
107107
fgraph, _, _ = construct_ir_fgraph({rv: value})
108108
[(ir_rv, ir_value)] = fgraph.preserve_rv_mappings.rv_values.items()
109109
expr = _logprob_helper(ir_rv, ir_value, **kwargs)
110+
cleanup_ir([expr])
110111
if warn_missing_rvs:
111112
_warn_rvs_in_inferred_graph(expr)
112113
return expr
@@ -124,6 +125,7 @@ def logcdf(
124125
fgraph, rv_values, _ = construct_ir_fgraph({rv: value})
125126
[ir_rv] = fgraph.outputs
126127
expr = _logcdf_helper(ir_rv, value, **kwargs)
128+
cleanup_ir([expr])
127129
if warn_missing_rvs:
128130
_warn_rvs_in_inferred_graph(expr)
129131
return expr
@@ -141,6 +143,7 @@ def icdf(
141143
fgraph, rv_values, _ = construct_ir_fgraph({rv: value})
142144
[ir_rv] = fgraph.outputs
143145
expr = _icdf_helper(ir_rv, value, **kwargs)
146+
cleanup_ir([expr])
144147
if warn_missing_rvs:
145148
_warn_rvs_in_inferred_graph(expr)
146149
return expr
@@ -321,6 +324,8 @@ def factorized_joint_logprob(
321324
f"The logprob terms of the following value variables could not be derived: {missing_value_terms}"
322325
)
323326

327+
cleanup_ir(logprob_vars.values())
328+
324329
return logprob_vars
325330

326331

pymc/logprob/rewriting.py

+28-6
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
3535
# SOFTWARE.
3636

37-
from typing import Dict, Optional, Tuple
37+
from typing import Dict, Optional, Sequence, Tuple
3838

3939
import pytensor.tensor as pt
4040

@@ -43,11 +43,17 @@
4343
from pytensor.graph.features import Feature
4444
from pytensor.graph.fg import FunctionGraph
4545
from pytensor.graph.rewriting.basic import GraphRewriter, node_rewriter
46-
from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabaseQuery, SequenceDB
46+
from pytensor.graph.rewriting.db import (
47+
EquilibriumDB,
48+
LocalGroupDB,
49+
RewriteDatabaseQuery,
50+
SequenceDB,
51+
TopoDB,
52+
)
4753
from pytensor.tensor.elemwise import DimShuffle, Elemwise
4854
from pytensor.tensor.extra_ops import BroadcastTo
4955
from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
50-
from pytensor.tensor.rewriting.basic import register_canonicalize, register_useless
56+
from pytensor.tensor.rewriting.basic import register_canonicalize
5157
from pytensor.tensor.rewriting.shape import ShapeFeature
5258
from pytensor.tensor.subtensor import (
5359
AdvancedIncSubtensor,
@@ -191,9 +197,8 @@ def local_lift_DiracDelta(fgraph, node):
191197
return new_node.outputs
192198

193199

194-
@register_useless
195-
@node_rewriter((DiracDelta,))
196-
def local_remove_DiracDelta(fgraph, node):
200+
@node_rewriter([DiracDelta])
201+
def remove_DiracDelta(fgraph, node):
197202
r"""Remove `DiracDelta`\s."""
198203
dd_val = node.inputs[0]
199204
return [dd_val]
@@ -270,6 +275,17 @@ def incsubtensor_rv_replace(fgraph, node):
270275

271276
logprob_rewrites_db.register("post-canonicalize", optdb.query("+canonicalize"), "basic")
272277

278+
# Rewrites that remove IR Ops
279+
cleanup_ir_rewrites_db = LocalGroupDB()
280+
cleanup_ir_rewrites_db.name = "cleanup_ir_rewrites_db"
281+
logprob_rewrites_db.register(
282+
"cleanup_ir_rewrites",
283+
TopoDB(cleanup_ir_rewrites_db, order="out_to_in", ignore_newtrees=True, failure_callback=None),
284+
"cleanup",
285+
)
286+
287+
cleanup_ir_rewrites_db.register("remove_DiracDelta", remove_DiracDelta, "cleanup")
288+
273289

274290
def construct_ir_fgraph(
275291
rv_values: Dict[Variable, Variable],
@@ -351,3 +367,9 @@ def construct_ir_fgraph(
351367
fgraph.replace_all(new_to_old)
352368

353369
return fgraph, rv_values, memo
370+
371+
372+
def cleanup_ir(vars: Sequence[Variable]) -> None:
373+
fgraph = FunctionGraph(outputs=vars, clone=False)
374+
ir_rewriter = logprob_rewrites_db.query(RewriteDatabaseQuery(include=["cleanup"]))
375+
ir_rewriter.rewrite(fgraph)

pymc/logprob/transforms.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,6 @@
8888
tanh,
8989
true_div,
9090
)
91-
from pytensor.tensor.rewriting.basic import (
92-
register_specialize,
93-
register_stabilize,
94-
register_useless,
95-
)
9691
from pytensor.tensor.var import TensorVariable
9792

9893
from pymc.logprob.abstract import (
@@ -106,7 +101,11 @@
106101
_logprob,
107102
_logprob_helper,
108103
)
109-
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
104+
from pymc.logprob.rewriting import (
105+
PreserveRVMappings,
106+
cleanup_ir_rewrites_db,
107+
measurable_ir_rewrites_db,
108+
)
110109
from pymc.logprob.utils import check_potential_measurability, ignore_logprob
111110

112111

@@ -134,15 +133,20 @@ def grad(self, args, g_outs):
134133
transformed_variable = TransformedVariable()
135134

136135

137-
@register_specialize
138-
@register_stabilize
139-
@register_useless
140136
@node_rewriter([TransformedVariable])
141137
def remove_TransformedVariables(fgraph, node):
142138
if isinstance(node.op, TransformedVariable):
143139
return [node.inputs[0]]
144140

145141

142+
cleanup_ir_rewrites_db.register(
143+
"remove_TransformedVariables",
144+
remove_TransformedVariables,
145+
"cleanup",
146+
"transform",
147+
)
148+
149+
146150
class RVTransform(abc.ABC):
147151
ndim_supp = None
148152

tests/logprob/test_rewriting.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import pytest
4141
import scipy.stats.distributions as sp
4242

43+
from pytensor.graph import ancestors
4344
from pytensor.graph.rewriting.basic import in2out
4445
from pytensor.graph.rewriting.utils import rewrite_graph
4546
from pytensor.tensor.elemwise import DimShuffle, Elemwise
@@ -50,8 +51,10 @@
5051
Subtensor,
5152
)
5253

54+
from pymc.distributions.transforms import logodds
5355
from pymc.logprob.basic import factorized_joint_logprob
54-
from pymc.logprob.rewriting import local_lift_DiracDelta
56+
from pymc.logprob.rewriting import cleanup_ir, local_lift_DiracDelta
57+
from pymc.logprob.transforms import TransformedVariable, TransformValuesRewrite
5558
from pymc.logprob.utils import DiracDelta, dirac_delta
5659

5760

@@ -88,10 +91,23 @@ def test_local_lift_DiracDelta():
8891

8992
def test_local_remove_DiracDelta():
9093
c_at = pt.vector()
91-
dd_at = dirac_delta(c_at)
94+
dd_at = dirac_delta(c_at) + dirac_delta(5)
95+
assert sum(isinstance(v.owner.op, DiracDelta) for v in ancestors([dd_at]) if v.owner) == 2
96+
97+
cleanup_ir([dd_at])
98+
assert not any(isinstance(v.owner.op, DiracDelta) for v in ancestors([dd_at]) if v.owner)
99+
100+
101+
def test_local_remove_TransformedVariable():
102+
p_rv = pt.random.beta(1, 1, name="p")
103+
p_vv = p_rv.clone()
104+
105+
tr = TransformValuesRewrite({p_vv: logodds})
106+
[p_logp] = factorized_joint_logprob({p_rv: p_vv}, extra_rewrites=tr).values()
92107

93-
fn = pytensor.function([c_at], dd_at)
94-
assert not any(isinstance(node.op, DiracDelta) for node in fn.maker.fgraph.toposort())
108+
assert not any(
109+
isinstance(v.owner.op, TransformedVariable) for v in ancestors([p_logp]) if v.owner
110+
)
95111

96112

97113
@pytest.mark.parametrize(

tests/logprob/test_transforms.py

+8-11
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,7 @@ def test_original_values_output_dict():
561561
assert p_vv in logp_dict
562562

563563

564+
@pytest.mark.filterwarnings("error")
564565
def test_mixture_transform():
565566
"""Make sure that non-`RandomVariable` `MeasurableVariable`s can be transformed.
566567
@@ -588,21 +589,17 @@ def test_mixture_transform():
588589

589590
transform_rewrite = TransformValuesRewrite({y_vv: LogTransform()})
590591

591-
with pytest.warns(None) as record:
592-
# This shouldn't raise any warnings
593-
logp_trans = factorized_joint_logprob(
594-
{Y_rv: y_vv, I_rv: i_vv},
595-
extra_rewrites=transform_rewrite,
596-
use_jacobian=False,
597-
)
598-
logp_trans_combined = pt.sum([pt.sum(factor) for factor in logp_trans.values()])
599-
600-
assert not record.list
592+
logp_trans = factorized_joint_logprob(
593+
{Y_rv: y_vv, I_rv: i_vv},
594+
extra_rewrites=transform_rewrite,
595+
use_jacobian=False,
596+
)
597+
logp_trans_combined = pt.sum([pt.sum(factor) for factor in logp_trans.values()])
601598

602599
# The untransformed graph should be the same as the transformed graph after
603600
# replacing the `Y_rv` value variable with a transformed version of itself
604601
logp_nt_fg = FunctionGraph(outputs=[logp_no_trans_comb], clone=False)
605-
y_trans = transformed_variable(pt.exp(y_vv), y_vv)
602+
y_trans = pt.exp(y_vv)
606603
y_trans.name = "y_log"
607604
logp_nt_fg.replace(y_vv, y_trans)
608605
logp_nt = logp_nt_fg.outputs[0]

tests/test_model.py

+9
Original file line numberDiff line numberDiff line change
@@ -1625,3 +1625,12 @@ def test_invalid_observed_value(self, capfd):
16251625
"Some of the observed values of variable y are associated with a non-finite logp" in out
16261626
)
16271627
assert "value = 0.53 -> logp = -inf" in out
1628+
1629+
1630+
def test_model_logp_fast_compile():
1631+
# Issue #5618
1632+
with pm.Model() as m:
1633+
pm.Dirichlet("a", np.ones(3))
1634+
1635+
with pytensor.config.change_flags(mode="FAST_COMPILE"):
1636+
assert m.point_logps() == {"a": -1.5}

0 commit comments

Comments
 (0)