-
Notifications
You must be signed in to change notification settings - Fork 132
Fix local_fill_sink
rewrite for multiple output Elemwise Ops
#773
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix local_fill_sink
rewrite for multiple output Elemwise Ops
#773
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #773 +/- ##
==========================================
- Coverage 80.85% 80.85% -0.01%
==========================================
Files 162 162
Lines 47019 47016 -3
Branches 11504 11501 -3
==========================================
- Hits 38018 38014 -4
- Misses 6750 6751 +1
Partials 2251 2251
|
ba17e8f
to
381b88f
Compare
The changes get rid of the eager sink at the local node rewriter level. This was actually not working because the nested replacements referenced variables that were never part of the original fgraph and those replacements were being ignored altogether. Instead we wrap the rewrite in an in2out that will safely achieve the intended behavior.
a, b = inp.owner.inputs | ||
if b.type.dtype != inp.dtype: | ||
# The input was implicitly casted by the fill operation | ||
b = b.cast(inp.dtype) | ||
models.append(a) | ||
inputs.append(b) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was also a potential source of bugs in the old rewrite. Ops may behave fundamentally different if the input types change so we shouldn't let that happen
381b88f
to
7019ddf
Compare
@node_rewriter([Elemwise]) | ||
def local_fill_sink(fgraph, node): | ||
""" | ||
f(fill(a, b), fill(c, d), e) -> fill(c, fill(a, f(b, d, e))) | ||
f need to be an elemwise that isn't a fill. | ||
""" | ||
if not hasattr(node, "op") or not isinstance(node.op, Elemwise) or node.op == fill: | ||
if isinstance(node.op.scalar_op, Second): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The extra checks were only needed for the recursive call of this rewrite. A default call will never call it on an node without Op that is not already Elemwise
For reference fill is identical to |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not able to say anything about the changes here because I'm not familiar with the parts being modified. But I can say this indeed fixes the example I shared on the linked PR, so this is good.
Description
Reported by @tomicapretto in pymc-devs/pymc#7315 (review)
The rewrite implicitly assumed Elemwise nodes have a single output which is not true. The reported issue involved the gradient of BetaIncGrad which includes a ScalarLoop Op with multiple outputs.
The changes get rid of the eager sink at the local node rewriter level. This was actually not working because the nested replacements referenced variables that were never part of the original fgraph and those replacements were being ignored altogether. Instead we wrap the rewrite in an in2out that will safely achieve the intended behavior.
The new test works as a regression for the bug, in that if we were to call the old rewrite (or just try to compile the graph) it would lead to the reported issue when finding the MultiOutput Op.
Checklist
Type of change