
Fix local_fill_sink rewrite for multiple output Elemwise Ops #773


Merged
65 changes: 31 additions & 34 deletions pytensor/tensor/rewriting/basic.py
@@ -41,6 +41,7 @@
 )
 from pytensor.graph.rewriting.db import RewriteDatabase
 from pytensor.raise_op import Assert, CheckAndRaise, assert_op
+from pytensor.scalar.basic import Second
 from pytensor.tensor.basic import (
     Alloc,
     AllocEmpty,
@@ -320,56 +321,52 @@
     return new_outs


-@register_canonicalize("shape_unsafe")
 @node_rewriter([Elemwise])
 def local_fill_sink(fgraph, node):
     """
     f(fill(a, b), fill(c, d), e) -> fill(c, fill(a, f(b, d, e)))
     f need to be an elemwise that isn't a fill.
     """
-    if not hasattr(node, "op") or not isinstance(node.op, Elemwise) or node.op == fill:
+    if isinstance(node.op.scalar_op, Second):
Member Author: The extra checks were only needed for the recursive call of this rewrite. A default call will never invoke it on a node that has no op or whose op is not already an Elemwise.

         return False

     models = []
     inputs = []
     for inp in node.inputs:
         if inp.owner and inp.owner.op == fill:
-            models.append(inp.owner.inputs[0])
-            inputs.append(inp.owner.inputs[1])
+            a, b = inp.owner.inputs
+            if b.type.dtype != inp.dtype:
+                # The input was implicitly casted by the fill operation
+                b = b.cast(inp.dtype)

Codecov / codecov/patch: Added line #L340 was not covered by tests.
+            models.append(a)
+            inputs.append(b)
Member Author (on lines +337 to +342): This was also a potential source of bugs in the old rewrite. Ops may behave fundamentally differently if the input types change, so we shouldn't let that happen.

         else:
             inputs.append(inp)

     if not models:
         return False
-    c = node.op(*inputs)
-    for model in models:
-        if (
-            model.type.dtype != c.type.dtype
-            or model.type.broadcastable != c.type.broadcastable
-        ):
-            c = fill(model, c)
-
-    # The newly created node c doesn't has 'clients',
-    # so this iteration is took place with node.outputs[0]
-    # TODO: This should just be a WalkingGraphRewrite!
-    replacements = {node.outputs[0]: c}
-    for client, cl_idx in fgraph.clients[node.outputs[0]]:
-        if (
-            hasattr(client, "op")
-            and isinstance(client.op, Elemwise)
-            and client.op != fill
-        ):
-            client_inputs = client.inputs[:]
-            client_inputs[cl_idx] = c
-            new_client = client.op(*client_inputs)
-
-            # Add clients to new_client
-            fgraph.clients[new_client.owner.outputs[0]] = fgraph.clients[
-                client.outputs[0]
-            ]
-            r = local_fill_sink.transform(fgraph, new_client.owner)
-            if not r:
-                continue
-            replacements.update(r)
-    return replacements
+    outputs = node.op.make_node(*inputs).outputs
+
+    # Check if we need to propagate the fill to the new outputs
+    # It's enough to check the first output, as Elemwise outputs must all have the same shapes
+    # Note: There are orderings that may require fewer fills.
+    old_bcast_pattern = node.outputs[0].type.broadcastable
+    models_iter = iter(models)
+    while old_bcast_pattern != outputs[0].type.broadcastable:
+        model = next(models_iter)
+        # Only apply this model if it would actually do anything
+        if broadcasted_by(outputs[0], model):
+            outputs = [fill(model, output) for output in outputs]
+
+    return outputs
+
+
+# The rewrite is wrapped in an in2out GraphRewriter
+# so that fill can be sunk down to the terminal nodes in a single pass through the graph
+# without triggering other rewrites after each local substitution
+topological_fill_sink = in2out(local_fill_sink)
+register_canonicalize(topological_fill_sink, "shape_unsafe")


 @register_specialize("shape_unsafe")
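
For readers unfamiliar with the rewrite, here is a minimal sketch (not part of the diff) of what fill sinking does end to end. It mirrors the setup of the new test below and assumes only names visible in this PR (`topological_fill_sink`, `pt.fill`, `FunctionGraph`):

```python
# Minimal sketch of fill sinking. fill(a, b) broadcasts b to a's shape,
# so exp(fill(z, x)) is equivalent to fill(z, exp(x)); the rewrite
# pushes the fill down toward the graph outputs.
import pytensor.tensor as pt
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.rewriting.basic import topological_fill_sink

x = pt.vector("x", shape=(1,))     # broadcastable input
z = pt.vector("z", shape=(None,))  # shape model
out = pt.exp(pt.fill(z, x))        # fill sits above the elemwise exp

fg = FunctionGraph([x, z], [out], copy_inputs=False)
topological_fill_sink.rewrite(fg)
# fg.outputs[0] now computes fill(z, exp(x)): the fill has been sunk
# past exp in a single in2out pass through the graph.
```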
18 changes: 18 additions & 0 deletions tests/tensor/rewriting/test_basic.py
@@ -19,6 +19,7 @@
 from pytensor.graph.rewriting.utils import rewrite_graph
 from pytensor.printing import debugprint, pprint
 from pytensor.raise_op import Assert, CheckAndRaise
+from pytensor.scalar import Composite, float64
 from pytensor.tensor.basic import (
     Alloc,
     Join,
@@ -64,6 +65,7 @@
     local_merge_alloc,
     local_useless_alloc,
     local_useless_elemwise,
+    topological_fill_sink,
 )
 from pytensor.tensor.rewriting.math import local_lift_transpose_through_dot
 from pytensor.tensor.rewriting.shape import ShapeFeature
@@ -1992,3 +1994,19 @@ def test_shape_unsafe_tag():
     fn = function([x, y], out, mode=mode.excluding("shape_unsafe"))
     with pytest.raises(ValueError):
         fn([0, 1], [2, 3, 4]), [0, 1]
+
+
+def test_topological_fill_sink_multi_output_client():
+    x = float64("x")
+    elem_op_with_2_outputs = Elemwise(Composite([x], [x + 1, x + 2]))
+
+    x = pt.vector("x", shape=(1,))
+    z = pt.vector("z", shape=(None,))
+    bcast_x = pt.full_like(z, x)
+    out = pt.add(*elem_op_with_2_outputs(pt.exp(bcast_x)))
+
+    fg = FunctionGraph([x, z], [out], copy_inputs=False)
+    topological_fill_sink.rewrite(fg)
+    [new_out] = fg.outputs
+    expected_out = pt.full_like(z, pt.add(*elem_op_with_2_outputs(pt.exp(x))))
+    assert equal_computations([new_out], [expected_out])
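
As a side note, the stopping condition in the rewritten loop leans on `broadcasted_by`, which the new code uses from the same `pytensor/tensor/rewriting/basic.py` module. A rough sketch of its role (the import location is inferred from the diff, not confirmed by it):

```python
# Rough sketch of the broadcasted_by guard; the import path is an
# assumption based on where the diff uses the helper.
import pytensor.tensor as pt
from pytensor.tensor.rewriting.basic import broadcasted_by

x = pt.vector("x", shape=(1,))     # broadcastable in its only dim
z = pt.vector("z", shape=(None,))  # not broadcastable

# x would be broadcast by z in an elemwise op, so fill(z, x) changes
# the type of x and is worth applying:
assert broadcasted_by(x, z)
# z would not be broadcast by x, so fill(x, z) would change nothing
# and the rewrite skips that model:
assert not broadcasted_by(z, x)
```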