diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 01c5cce5e8..c24012705d 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -41,6 +41,7 @@ ) from pytensor.graph.rewriting.db import RewriteDatabase from pytensor.raise_op import Assert, CheckAndRaise, assert_op +from pytensor.scalar.basic import Second from pytensor.tensor.basic import ( Alloc, AllocEmpty, @@ -320,56 +321,52 @@ def dimshuffled_alloc(i): return new_outs -@register_canonicalize("shape_unsafe") @node_rewriter([Elemwise]) def local_fill_sink(fgraph, node): """ f(fill(a, b), fill(c, d), e) -> fill(c, fill(a, f(b, d, e))) f need to be an elemwise that isn't a fill. """ - if not hasattr(node, "op") or not isinstance(node.op, Elemwise) or node.op == fill: + if isinstance(node.op.scalar_op, Second): return False + models = [] inputs = [] for inp in node.inputs: if inp.owner and inp.owner.op == fill: - models.append(inp.owner.inputs[0]) - inputs.append(inp.owner.inputs[1]) + a, b = inp.owner.inputs + if b.type.dtype != inp.dtype: + # The input was implicitly casted by the fill operation + b = b.cast(inp.dtype) + models.append(a) + inputs.append(b) else: inputs.append(inp) + if not models: return False - c = node.op(*inputs) - for model in models: - if ( - model.type.dtype != c.type.dtype - or model.type.broadcastable != c.type.broadcastable - ): - c = fill(model, c) - # The newly created node c doesn't has 'clients', - # so this iteration is took place with node.outputs[0] - # TODO: This should just be a WalkingGraphRewrite! 
- replacements = {node.outputs[0]: c} - for client, cl_idx in fgraph.clients[node.outputs[0]]: - if ( - hasattr(client, "op") - and isinstance(client.op, Elemwise) - and client.op != fill - ): - client_inputs = client.inputs[:] - client_inputs[cl_idx] = c - new_client = client.op(*client_inputs) - - # Add clients to new_client - fgraph.clients[new_client.owner.outputs[0]] = fgraph.clients[ - client.outputs[0] - ] - r = local_fill_sink.transform(fgraph, new_client.owner) - if not r: - continue - replacements.update(r) - return replacements + outputs = node.op.make_node(*inputs).outputs + + # Check if we need to propagate the fill to the new outputs + # It's enough to check the first output, as Elemwise outputs must all have the same shapes + # Note: There are orderings that may require fewer fills. + old_bcast_pattern = node.outputs[0].type.broadcastable + models_iter = iter(models) + while old_bcast_pattern != outputs[0].type.broadcastable: + model = next(models_iter) + # Only apply this model if it would actually do anything + if broadcasted_by(outputs[0], model): + outputs = [fill(model, output) for output in outputs] + + return outputs + + +# The rewrite is wrapped in an in2out GraphRewriter +# so that fill can be sunk down to the terminal nodes in a single pass through the graph +# without triggering other rewrites after each local substitution +topological_fill_sink = in2out(local_fill_sink) +register_canonicalize(topological_fill_sink, "shape_unsafe") @register_specialize("shape_unsafe") diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index bacbc540c5..366e09ed4a 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -19,6 +19,7 @@ from pytensor.graph.rewriting.utils import rewrite_graph from pytensor.printing import debugprint, pprint from pytensor.raise_op import Assert, CheckAndRaise +from pytensor.scalar import Composite, float64 from pytensor.tensor.basic import ( Alloc,
Join, @@ -64,6 +65,7 @@ local_merge_alloc, local_useless_alloc, local_useless_elemwise, + topological_fill_sink, ) from pytensor.tensor.rewriting.math import local_lift_transpose_through_dot from pytensor.tensor.rewriting.shape import ShapeFeature @@ -1992,3 +1994,19 @@ def test_shape_unsafe_tag(): fn = function([x, y], out, mode=mode.excluding("shape_unsafe")) with pytest.raises(ValueError): fn([0, 1], [2, 3, 4]), [0, 1] + + +def test_topological_fill_sink_multi_output_client(): + x = float64("x") + elem_op_with_2_outputs = Elemwise(Composite([x], [x + 1, x + 2])) + + x = pt.vector("x", shape=(1,)) + z = pt.vector("z", shape=(None,)) + bcast_x = pt.full_like(z, x) + out = pt.add(*elem_op_with_2_outputs(pt.exp(bcast_x))) + + fg = FunctionGraph([x, z], [out], copy_inputs=False) + topological_fill_sink.rewrite(fg) + [new_out] = fg.outputs + expected_out = pt.full_like(z, pt.add(*elem_op_with_2_outputs(pt.exp(x)))) + assert equal_computations([new_out], [expected_out])