diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py
index 2b120349ea01a..0ff58e49008cd 100644
--- a/test/dynamo/test_logging.py
+++ b/test/dynamo/test_logging.py
@@ -959,6 +959,7 @@ def bar():
     "autotuning",
     "graph_region_expansion",
     "hierarchical_compile",
+    "compute_dependencies",
 }
 for name in torch._logging._internal.log_registry.artifact_names:
     if name not in exclusions:
diff --git a/test/inductor/test_auto_functionalize.py b/test/inductor/test_auto_functionalize.py
index 6f15b493ec1bd..0cc2c9e3a7836 100644
--- a/test/inductor/test_auto_functionalize.py
+++ b/test/inductor/test_auto_functionalize.py
@@ -445,12 +445,17 @@ def run_aot_eager(self, f, orig_args, _dynamic=False):
         graph = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
         return [aot_eager_args, result, graph]
 
-    def run_inductor(self, f, orig_args, _dynamic=False):
+    def run_inductor(
+        self,
+        f,
+        orig_args,
+        _dynamic=False,
+        log_module="torch._inductor.compile_fx",
+        log_function="post_grad_graphs",
+    ):
         compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
-        log_stream, ctx = logs_to_string(
-            "torch._inductor.compile_fx", "post_grad_graphs"
-        )
+        log_stream, ctx = logs_to_string(log_module, log_function)
         result = None
         with ctx():
             result = torch.compile(
@@ -1733,6 +1738,41 @@ def f(x, w):
         y = f(x, w)
         self.assertEqual(y, x.sin())
 
+    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
+    def test_scheduling_with_multiple_mutates(self):
+        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
+            torch.library.define(
+                "mylib::foo",
+                "(Tensor! x, Tensor! y, Tensor z) -> ()",
+                tags=torch.Tag.pt2_compliant_tag,
+                lib=lib,
+            )
+
+            @torch.library.impl("mylib::foo", "cpu", lib=lib)
+            @torch._dynamo.disable
+            def foo(x, y, z):
+                pass
+
+            def func(x, w):
+                a = torch.empty_like(x)  # buf0
+                b = torch.empty_like(x)  # buf1
+                torch.ops.mylib.foo(a, b, x)  # buf2, buf3, buf4
+                c = torch.mm(a, w)  # buf5
+                torch.ops.mylib.foo(c, b, x)  # buf6, buf7, buf8
+                return c
+
+            input = torch.rand(2, 2)
+            weight = torch.rand(2, 2)
+            [inductor_args, output, graph_inductor] = self.run_inductor(
+                func,
+                [input, weight],
+                False,
+                "torch._inductor.scheduler",
+                "compute_dependencies",
+            )
+            name_to_users = eval(graph_inductor)
+            self.assertNotEqual(name_to_users["buf1"], name_to_users["buf5"])
+
 
 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests
diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
index 687ba95e1dd1d..f855cc1de922d 100644
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -74,6 +74,9 @@
 log = logging.getLogger(__name__)
 fusion_log = torch._logging.getArtifactLogger(__name__, "fusion")
 loop_ordering_log = torch._logging.getArtifactLogger(__name__, "loop_ordering")
+compute_dependencies_log = torch._logging.getArtifactLogger(
+    __name__, "compute_dependencies"
+)
 
 PartitionType = list["BaseSchedulerNode"]
 
@@ -2278,6 +2281,15 @@ def __add__(self, other: DedupList[T]) -> DedupList[T]:
         for node in self.nodes:
             for buf1 in node.get_outputs():
                 buf1_name = buf1.get_name()
+                # Handle auto-functionalized ops that return None and mutate
+                # more than one input: don't let all the mutated buffers share
+                # a single user list, since buffers in the aliases list might
+                # not alias each other.
+                if (
+                    isinstance(buf1.node.layout, ir.NoneLayout)
+                    and len(buf1.get_aliases()) > 1
+                ):
+                    continue
                 for buf2_name in buf1.get_aliases():
                     if buf1_name in name_to_users and buf2_name in name_to_users:
                         # merge the two
@@ -2445,6 +2457,18 @@ def add_user(
         for name in self.name_to_donated_buffer:
             self.name_to_donated_buffer[name].set_users(name_to_users[name].items)
 
+        # For debug logging: dump the buffer-to-users mapping as a dict literal.
+        logbuf = IndentedBuffer()
+        logbuf.splice("{")
+        for key, value in name_to_users.items():
+            with logbuf.indent():
+                users = [v.get_name() for v in value.items]
+                logbuf.splice(f"'{key}': {users},")
+        logbuf.splice("}")
+        buf_str = logbuf.getrawvalue().rstrip()
+        compute_dependencies_log.debug("BUFFER USER LIST\n")
+        compute_dependencies_log.debug("===== AFTER SCHEDULING =====\n%s", buf_str)
+
     def dead_node_elimination(self) -> None:
         """
         Remove any nodes without users
diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py
index 3821218cefec9..f56f0165b206f 100644
--- a/torch/_logging/_internal.py
+++ b/torch/_logging/_internal.py
@@ -252,6 +252,7 @@ def set_logs(
     graph_region_expansion: bool = False,
     inductor_metrics: bool = False,
     hierarchical_compile: bool = False,
+    compute_dependencies: bool = False,
 ) -> None:
     """
     Sets the log level for individual components and toggles individual log
@@ -565,6 +566,7 @@ def _set_logs(**kwargs) -> None:
         graph_region_expansion=graph_region_expansion,
         inductor_metrics=inductor_metrics,
         hierarchical_compile=hierarchical_compile,
+        compute_dependencies=compute_dependencies,
     )
diff --git a/torch/_logging/_registrations.py b/torch/_logging/_registrations.py
index 62e5d9b7064ca..3c6f092ed4d24 100644
--- a/torch/_logging/_registrations.py
+++ b/torch/_logging/_registrations.py
@@ -183,6 +183,7 @@
 )
 register_artifact("perf_hints", "", off_by_default=True)
 register_artifact("onnx_diagnostics", "", off_by_default=True)
+register_artifact("compute_dependencies", "", off_by_default=True)
 register_artifact(
     "fusion",
     "Detailed Inductor fusion decisions. More detailed than 'schedule'",
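
Usage note: with this patch applied, the new compute_dependencies artifact can be toggled like any other torch._logging artifact. A minimal sketch (the script name is hypothetical):

    # Python API -- the keyword argument is added to set_logs() by this patch
    import torch._logging
    torch._logging.set_logs(compute_dependencies=True)

    # Shell equivalent, via the TORCH_LOGS environment variable:
    #   TORCH_LOGS="compute_dependencies" python my_script.py

The payload logged after "===== AFTER SCHEDULING =====" is a dict literal mapping each buffer name to the names of its users, which is why test_scheduling_with_multiple_mutates can recover the mapping with a plain eval(graph_inductor).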