From 7019ddf4119d6f8e7ec474700f9d630c5fd681aa Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 16 May 2024 13:42:34 +0200 Subject: [PATCH] Fix local_fill_sink rewrite for multiple output Elemwise Ops The changes get rid of the eager sink at the local node rewriter level. This was actually not working because the nested replacements referenced variables that were never part of the original fgraph and those replacements were being ignored altogether. Instead we wrap the rewrite in an in2out that will safely achieve the intended behavior. --- pytensor/tensor/rewriting/basic.py | 65 +++++++++++++--------------- tests/tensor/rewriting/test_basic.py | 18 ++++++++ 2 files changed, 49 insertions(+), 34 deletions(-) diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 01c5cce5e8..c24012705d 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -41,6 +41,7 @@ ) from pytensor.graph.rewriting.db import RewriteDatabase from pytensor.raise_op import Assert, CheckAndRaise, assert_op +from pytensor.scalar.basic import Second from pytensor.tensor.basic import ( Alloc, AllocEmpty, @@ -320,56 +321,52 @@ def dimshuffled_alloc(i): return new_outs -@register_canonicalize("shape_unsafe") @node_rewriter([Elemwise]) def local_fill_sink(fgraph, node): """ f(fill(a, b), fill(c, d), e) -> fill(c, fill(a, f(b, d, e))) f need to be an elemwise that isn't a fill. """ - if not hasattr(node, "op") or not isinstance(node.op, Elemwise) or node.op == fill: + if isinstance(node.op.scalar_op, Second): return False + models = [] inputs = [] for inp in node.inputs: if inp.owner and inp.owner.op == fill: - models.append(inp.owner.inputs[0]) - inputs.append(inp.owner.inputs[1]) + a, b = inp.owner.inputs + if b.type.dtype != inp.dtype: + # The input was implicitly casted by the fill operation + b = b.cast(inp.dtype) + models.append(a) + inputs.append(b) else: inputs.append(inp) + if not models: return False - c = node.op(*inputs) - for model in models: - if ( - model.type.dtype != c.type.dtype - or model.type.broadcastable != c.type.broadcastable - ): - c = fill(model, c) - # The newly created node c doesn't has 'clients', - # so this iteration is took place with node.outputs[0] - # TODO: This should just be a WalkingGraphRewrite! - replacements = {node.outputs[0]: c} - for client, cl_idx in fgraph.clients[node.outputs[0]]: - if ( - hasattr(client, "op") - and isinstance(client.op, Elemwise) - and client.op != fill - ): - client_inputs = client.inputs[:] - client_inputs[cl_idx] = c - new_client = client.op(*client_inputs) - - # Add clients to new_client - fgraph.clients[new_client.owner.outputs[0]] = fgraph.clients[ - client.outputs[0] - ] - r = local_fill_sink.transform(fgraph, new_client.owner) - if not r: - continue - replacements.update(r) - return replacements + outputs = node.op.make_node(*inputs).outputs + + # Check if we need to propagate the fill to the new outputs + # It's enough to check the first output, as Elemwise outputs must all have the same shapes + # Note: There are orderings that may require fewer fills. + old_bcast_pattern = node.outputs[0].type.broadcastable + models_iter = iter(models) + while old_bcast_pattern != outputs[0].type.broadcastable: + model = next(models_iter) + # Only apply this model if it would actually do anything + if broadcasted_by(outputs[0], model): + outputs = [fill(model, output) for output in outputs] + + return outputs + + +# The rewrite is wrapped in an in2out GraphRewriter +# so that fill can be sinked until the terminal nodes in a single pass through the graph +# without triggering other rewrites after each local substitution +topological_fill_sink = in2out(local_fill_sink) +register_canonicalize(topological_fill_sink, "shape_unsafe") @register_specialize("shape_unsafe") diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index bacbc540c5..366e09ed4a 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -19,6 +19,7 @@ from pytensor.graph.rewriting.utils import rewrite_graph from pytensor.printing import debugprint, pprint from pytensor.raise_op import Assert, CheckAndRaise +from pytensor.scalar import Composite, float64 from pytensor.tensor.basic import ( Alloc, Join, @@ -64,6 +65,7 @@ local_merge_alloc, local_useless_alloc, local_useless_elemwise, + topological_fill_sink, ) from pytensor.tensor.rewriting.math import local_lift_transpose_through_dot from pytensor.tensor.rewriting.shape import ShapeFeature @@ -1992,3 +1994,19 @@ def test_shape_unsafe_tag(): fn = function([x, y], out, mode=mode.excluding("shape_unsafe")) with pytest.raises(ValueError): fn([0, 1], [2, 3, 4]), [0, 1] + + +def test_topological_fill_sink_multi_output_client(): + x = float64("x") + elem_op_with_2_outputs = Elemwise(Composite([x], [x + 1, x + 2])) + + x = pt.vector("x", shape=(1,)) + z = pt.vector("z", shape=(None,)) + bcast_x = pt.full_like(z, x) + out = pt.add(*elem_op_with_2_outputs(pt.exp(bcast_x))) + + fg = FunctionGraph([x, z], [out], copy_inputs=False) + topological_fill_sink.rewrite(fg) + [new_out] = fg.outputs + expected_out = pt.full_like(z, pt.add(*elem_op_with_2_outputs(pt.exp(x)))) + assert equal_computations([new_out], [expected_out])