Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions test/dynamo/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,6 +959,7 @@ def bar():
"autotuning",
"graph_region_expansion",
"hierarchical_compile",
"compute_dependencies",
}
for name in torch._logging._internal.log_registry.artifact_names:
if name not in exclusions:
Expand Down
48 changes: 44 additions & 4 deletions test/inductor/test_auto_functionalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,12 +445,17 @@ def run_aot_eager(self, f, orig_args, _dynamic=False):
graph = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
return [aot_eager_args, result, graph]

def run_inductor(self, f, orig_args, _dynamic=False):
def run_inductor(
self,
f,
orig_args,
_dynamic=False,
log_module="torch._inductor.compile_fx",
log_function="post_grad_graphs",
):
compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)

log_stream, ctx = logs_to_string(
"torch._inductor.compile_fx", "post_grad_graphs"
)
log_stream, ctx = logs_to_string(log_module, log_function)
result = None
with ctx():
result = torch.compile(
Expand Down Expand Up @@ -1733,6 +1738,41 @@ def f(x, w):
y = f(x, w)
self.assertEqual(y, x.sin())

@torch._inductor.config.patch(enable_auto_functionalized_v2=True)
def test_scheduling_with_multiple_mutates(self):
    """Check scheduler user lists for an op that mutates several inputs.

    Defines a custom op ``mylib::foo`` that mutates its first two tensor
    arguments and returns nothing, compiles a function that calls it twice,
    and captures the ``compute_dependencies`` log artifact emitted by
    ``torch._inductor.scheduler``. The logged buffer -> users mapping must
    not report the same user list for ``buf1`` and ``buf5``.
    """
    # Local import keeps this test self-contained; only needed to parse the log.
    import ast

    with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
        torch.library.define(
            "mylib::foo",
            "(Tensor! x, Tensor! y, Tensor z) -> ()",
            tags=torch.Tag.pt2_compliant_tag,
            lib=lib,
        )

        @torch.library.impl("mylib::foo", "cpu", lib=lib)
        @torch._dynamo.disable
        def foo(x, y, z):
            pass

        def func(x, w):
            a = torch.empty_like(x)  # buf0
            b = torch.empty_like(x)  # buf1
            torch.ops.mylib.foo(a, b, x)  # buf2, buf3, buf4
            c = torch.mm(a, w)  # buf5
            torch.ops.mylib.foo(c, b, x)  # buf6, buf7, buf8
            return c

        # Named `inp` rather than `input` to avoid shadowing the builtin.
        inp = torch.rand(2, 2)
        weight = torch.rand(2, 2)
        _, _, graph_inductor = self.run_inductor(
            func,
            [inp, weight],
            _dynamic=False,
            log_module="torch._inductor.scheduler",
            log_function="compute_dependencies",
        )
        # The captured log is a dict literal mapping buffer names to lists
        # of user names, so parse it as a literal instead of running
        # eval() on log output.
        name_to_users = ast.literal_eval(graph_inductor)
        self.assertNotEqual(name_to_users["buf1"], name_to_users["buf5"])


if __name__ == "__main__":
from torch._inductor.test_case import run_tests
Expand Down
24 changes: 24 additions & 0 deletions torch/_inductor/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@
log = logging.getLogger(__name__)
fusion_log = torch._logging.getArtifactLogger(__name__, "fusion")
loop_ordering_log = torch._logging.getArtifactLogger(__name__, "loop_ordering")
# Artifact logger for the "compute_dependencies" artifact (registered as
# off-by-default); used to dump the buffer -> users mapping for debugging.
compute_dependencies_log = torch._logging.getArtifactLogger(
__name__, "compute_dependencies"
)

PartitionType = list["BaseSchedulerNode"]

Expand Down Expand Up @@ -2278,6 +2281,15 @@ def __add__(self, other: DedupList[T]) -> DedupList[T]:
for node in self.nodes:
for buf1 in node.get_outputs():
buf1_name = buf1.get_name()
# This handles auto-functionalized ops that return None and mutate
# more than one input. We shouldn't let all of their aliases point
# to the same user list, since the buffers in the alias list might
# not alias each other.
if (
isinstance(buf1.node.layout, ir.NoneLayout)
and len(buf1.get_aliases()) > 1
):
continue
for buf2_name in buf1.get_aliases():
if buf1_name in name_to_users and buf2_name in name_to_users:
# merge the two
Expand Down Expand Up @@ -2445,6 +2457,18 @@ def add_user(
for name in self.name_to_donated_buffer:
self.name_to_donated_buffer[name].set_users(name_to_users[name].items)

# For debug logging
logbuf = IndentedBuffer()
logbuf.splice("{")
for key, value in name_to_users.items():
with logbuf.indent():
users = [v.get_name() for v in value.items]
logbuf.splice(f"'{key}': {users},")
logbuf.splice("}")
str = logbuf.getrawvalue().rstrip()
compute_dependencies_log.debug("BUFFER USER LIST\n")
compute_dependencies_log.debug("===== AFTER SCHEDULING =====\n%s", str)

def dead_node_elimination(self) -> None:
"""
Remove any nodes without users
Expand Down
2 changes: 2 additions & 0 deletions torch/_logging/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ def set_logs(
graph_region_expansion: bool = False,
inductor_metrics: bool = False,
hierarchical_compile: bool = False,
compute_dependencies: bool = False,
) -> None:
"""
Sets the log level for individual components and toggles individual log
Expand Down Expand Up @@ -565,6 +566,7 @@ def _set_logs(**kwargs) -> None:
graph_region_expansion=graph_region_expansion,
inductor_metrics=inductor_metrics,
hierarchical_compile=hierarchical_compile,
compute_dependencies=compute_dependencies,
)


Expand Down
1 change: 1 addition & 0 deletions torch/_logging/_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@
)
register_artifact("perf_hints", "", off_by_default=True)
register_artifact("onnx_diagnostics", "", off_by_default=True)
register_artifact("compute_dependencies", "", off_by_default=True)
register_artifact(
"fusion",
"Detailed Inductor fusion decisions. More detailed than 'schedule'",
Expand Down