|
34 | 34 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
35 | 35 | # SOFTWARE.
|
36 | 36 |
|
37 |
| -from typing import Dict, Optional, Tuple |
| 37 | +from typing import Dict, Optional, Sequence, Tuple |
38 | 38 |
|
39 | 39 | import pytensor.tensor as pt
|
40 | 40 |
|
|
43 | 43 | from pytensor.graph.features import Feature
|
44 | 44 | from pytensor.graph.fg import FunctionGraph
|
45 | 45 | from pytensor.graph.rewriting.basic import GraphRewriter, node_rewriter
|
46 |
| -from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabaseQuery, SequenceDB |
| 46 | +from pytensor.graph.rewriting.db import ( |
| 47 | + EquilibriumDB, |
| 48 | + LocalGroupDB, |
| 49 | + RewriteDatabaseQuery, |
| 50 | + SequenceDB, |
| 51 | + TopoDB, |
| 52 | +) |
47 | 53 | from pytensor.tensor.elemwise import DimShuffle, Elemwise
|
48 | 54 | from pytensor.tensor.extra_ops import BroadcastTo
|
49 | 55 | from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
|
50 |
| -from pytensor.tensor.rewriting.basic import register_canonicalize, register_useless |
| 56 | +from pytensor.tensor.rewriting.basic import register_canonicalize |
51 | 57 | from pytensor.tensor.rewriting.shape import ShapeFeature
|
52 | 58 | from pytensor.tensor.subtensor import (
|
53 | 59 | AdvancedIncSubtensor,
|
@@ -191,9 +197,8 @@ def local_lift_DiracDelta(fgraph, node):
|
191 | 197 | return new_node.outputs
|
192 | 198 |
|
193 | 199 |
|
194 |
| -@register_useless |
195 |
| -@node_rewriter((DiracDelta,)) |
196 |
| -def local_remove_DiracDelta(fgraph, node): |
| 200 | +@node_rewriter([DiracDelta]) |
| 201 | +def remove_DiracDelta(fgraph, node): |
197 | 202 | r"""Remove `DiracDelta`\s."""
|
198 | 203 | dd_val = node.inputs[0]
|
199 | 204 | return [dd_val]
|
@@ -270,6 +275,17 @@ def incsubtensor_rv_replace(fgraph, node):
|
270 | 275 |
|
271 | 276 | logprob_rewrites_db.register("post-canonicalize", optdb.query("+canonicalize"), "basic")
|
272 | 277 |
|
| 278 | +# Rewrites that remove IR Ops |
| 279 | +cleanup_ir_rewrites_db = LocalGroupDB() |
| 280 | +cleanup_ir_rewrites_db.name = "cleanup_ir_rewrites_db" |
| 281 | +logprob_rewrites_db.register( |
| 282 | + "cleanup_ir_rewrites", |
| 283 | + TopoDB(cleanup_ir_rewrites_db, order="out_to_in", ignore_newtrees=True, failure_callback=None), |
| 284 | + "cleanup", |
| 285 | +) |
| 286 | + |
| 287 | +cleanup_ir_rewrites_db.register("remove_DiracDelta", remove_DiracDelta, "cleanup") |
| 288 | + |
273 | 289 |
|
274 | 290 | def construct_ir_fgraph(
|
275 | 291 | rv_values: Dict[Variable, Variable],
|
@@ -351,3 +367,9 @@ def construct_ir_fgraph(
|
351 | 367 | fgraph.replace_all(new_to_old)
|
352 | 368 |
|
353 | 369 | return fgraph, rv_values, memo
|
| 370 | + |
| 371 | + |
| 372 | +def cleanup_ir(vars: Sequence[Variable]) -> None: |
| 373 | + fgraph = FunctionGraph(outputs=vars, clone=False) |
| 374 | + ir_rewriter = logprob_rewrites_db.query(RewriteDatabaseQuery(include=["cleanup"])) |
| 375 | + ir_rewriter.rewrite(fgraph) |
0 commit comments