Fix local_fill_sink rewrite for multiple output Elemwise Ops #773


Merged

Conversation

ricardoV94
Member

Description

Reported by @tomicapretto in pymc-devs/pymc#7315 (review)

The rewrite implicitly assumed that Elemwise nodes have a single output, which is not true. The reported issue involved the gradient of BetaIncGrad, which includes a ScalarLoop Op with multiple outputs.

The changes get rid of the eager sink at the local node rewriter level. That approach was not actually working: the nested replacements referenced variables that were never part of the original fgraph, so those replacements were being ignored altogether. Instead, we wrap the rewrite in an in2out pass, which safely achieves the intended behavior.

The new test acts as a regression test for the bug: calling the old rewrite (or simply trying to compile the graph) would reproduce the reported issue when it encountered the multiple-output Op.
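As a toy illustration (not PyTensor's actual API — the class and function names here are hypothetical), the failure mode of a single-output assumption looks like this: a rewrite that indexes one output silently drops the rest on a multiple-output node.

```python
# Hypothetical minimal mock of the failure mode: the old rewrite effectively
# handled a single output, silently ignoring the rest on multiple-output
# Elemwise nodes (such as an Elemwise wrapping a ScalarLoop).
class FakeNode:
    def __init__(self, outputs):
        self.outputs = outputs

def old_style_result(node):
    # Implicit single-output assumption: only outputs[0] is considered.
    return [node.outputs[0]]

def fixed_style_result(node):
    # Handle every output of the node.
    return list(node.outputs)

node = FakeNode(outputs=["grad_0", "grad_1"])
assert old_style_result(node) == ["grad_0"]            # second output dropped
assert fixed_style_result(node) == ["grad_0", "grad_1"]
```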

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):


codecov bot commented May 16, 2024

Codecov Report

Attention: Patch coverage is 82.35294% with 3 lines in your changes missing coverage. Please review.

Project coverage is 80.85%. Comparing base (d80c0bf) to head (7019ddf).
Report is 209 commits behind head on main.

Files Patch % Lines
pytensor/tensor/rewriting/basic.py 82.35% 1 Missing and 2 partials ⚠️
Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #773      +/-   ##
==========================================
- Coverage   80.85%   80.85%   -0.01%     
==========================================
  Files         162      162              
  Lines       47019    47016       -3     
  Branches    11504    11501       -3     
==========================================
- Hits        38018    38014       -4     
- Misses       6750     6751       +1     
  Partials     2251     2251              
Files Coverage Δ
pytensor/tensor/rewriting/basic.py 93.75% <82.35%> (-0.57%) ⬇️

... and 2 files with indirect coverage changes

@ricardoV94 ricardoV94 force-pushed the fix_local_fill_sink_multiple_outputs branch 2 times, most recently from ba17e8f to 381b88f Compare May 16, 2024 12:23
Comment on lines +337 to +342
a, b = inp.owner.inputs
if b.type.dtype != inp.dtype:
    # The input was implicitly casted by the fill operation
    b = b.cast(inp.dtype)
models.append(a)
inputs.append(b)
Member Author


This was also a potential source of bugs in the old rewrite. Ops may behave fundamentally differently if their input types change, so we shouldn't let that happen.
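A quick NumPy sketch of the dtype concern (assuming fill's output dtype follows the usual elemwise promotion of its two inputs, which is what makes the explicit b.cast(inp.dtype) above necessary):

```python
import numpy as np

# fill(a, b) is an Elemwise, so its output dtype follows elemwise promotion
# of a's and b's dtypes; NumPy's result_type mirrors that rule.
a_dtype = np.dtype("float64")
b_dtype = np.dtype("int8")
fill_dtype = np.result_type(a_dtype, b_dtype)
assert fill_dtype == np.dtype("float64")

# Sinking the fill pulls the raw b (still int8) past code that previously
# saw a float64 value — without the cast, the consuming Op's input type
# would silently change.
assert b_dtype != fill_dtype
```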

@ricardoV94 ricardoV94 force-pushed the fix_local_fill_sink_multiple_outputs branch from 381b88f to 7019ddf Compare May 16, 2024 12:26
@node_rewriter([Elemwise])
def local_fill_sink(fgraph, node):
    """
    f(fill(a, b), fill(c, d), e) -> fill(c, fill(a, f(b, d, e)))
    f need to be an elemwise that isn't a fill.
    """
-    if not hasattr(node, "op") or not isinstance(node.op, Elemwise) or node.op == fill:
+    if isinstance(node.op.scalar_op, Second):
Member Author


The extra checks were only needed for the recursive call of this rewrite. A normal call will never invoke it on a node without an Op, or on one whose Op is not already an Elemwise.

@ricardoV94
Member Author

For reference, fill is identical to np.broadcast_arrays(x, y)[1]; that is, it broadcasts y to the broadcast shape of x and y.
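This equivalence can be checked directly in NumPy, along with the sinking identity from the rewrite's docstring (f = addition here, chosen purely for illustration):

```python
import numpy as np

def fill(x, y):
    # fill(x, y) == np.broadcast_arrays(x, y)[1]: broadcast y to the
    # common broadcast shape of x and y.
    return np.broadcast_arrays(x, y)[1]

a = np.zeros((3, 1)); b = np.float64(2.0)
c = np.zeros((1, 4)); d = np.float64(3.0)
e = np.ones((3, 4))

# f(fill(a, b), fill(c, d), e) -> fill(c, fill(a, f(b, d, e))), with f = add
lhs = fill(a, b) + fill(c, d) + e
rhs = fill(c, fill(a, b + d + e))
assert lhs.shape == (3, 4)
assert np.array_equal(lhs, rhs)
```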


@tomicapretto tomicapretto left a comment


I'm not able to say anything about the changes here because I'm not familiar with the parts being modified. But I can say this indeed fixes the example I shared on the linked PR, so this is good.

@ricardoV94 ricardoV94 merged commit 8c157a2 into pymc-devs:main May 17, 2024
54 of 55 checks passed
Labels: bug (Something isn't working), graph rewriting

2 participants