
Commit a0ffd4f

pragupta and charlifu committed
[release/2.8] [Bugfix][Inductor] Fix dependency list merged incorrectly for a custom op with multiple mutated inputs and None return type. (pytorch#157133) (#2419)

This is an attempt to fix a memory allocation issue when using `torch.compile` with a custom layernorm kernel in vllm:

```C++
// In-place fused Add and RMS Normalization.
ops.def(
    "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
    "float epsilon) -> ()");
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
```

We observed abnormal extra memory allocations with this op enabled under `torch.compile`:

[Memory profile with the op enabled: https://github.com/user-attachments/assets/6c45e1aa-ccde-4c56-99dc-bf4776d699d5]

and without this op:

[Memory profile without the op: https://github.com/user-attachments/assets/56e2ee43-ab87-492d-834c-69e9cafbb0df]

After investigation, we found that the compiler considers that the buffers for the two mutated inputs `Tensor input` and `Tensor residual` should share the same dependency list, which prevents it from reusing the buffer of `Tensor input`:

```
buf1.users = [
    NodeUser(node=ExternKernelSchedulerNode(name='op2'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op9'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op13'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=False),
]
buf16.users = [
    NodeUser(node=ExternKernelSchedulerNode(name='op2'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op9'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op13'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=False),
]
```

```
op13: ExternKernelSchedulerNode(FallbackKernel)
op13.writes = [
    StarDep(name='buf17', mode=None),
    StarDep(name='buf18', mode=None),
    StarDep(name='buf19', mode=None)]
op13.unmet_dependencies = [
    StarDep(name='buf13', mode=None),
    StarDep(name='buf16', mode=None),
    WeakDep(name='buf11', mutating_buf='buf18'),
    WeakDep(name='buf12', mutating_buf='buf18'),
    WeakDep(name='buf13', mutating_buf='buf18'),
    WeakDep(name='buf2', mutating_buf='buf18'),
    WeakDep(name='buf3', mutating_buf='buf18')]
op13.met_dependencies = [StarDep(name='arg11_1', mode=None)]
op13.outputs = [
    buf17: FallbackKernel
    buf17.layout = NoneLayout(device=device(type='cuda', index=0), size=[0], stride=[0])
    buf17.aliases = ['buf16', 'buf1']
    buf17.users = [
        NodeUser(node=ExternKernelSchedulerNode(name='op2'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op9'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op13'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=False),
    ]
    buf18: MutationOutput
    buf18.layout = NoneLayout(device=device(type='cuda', index=0), size=[0], stride=[0])
    buf18.mutations = ['buf16']
    buf18.users = [
        NodeUser(node=ExternKernelSchedulerNode(name='op14'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=True),
        NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=True),
        NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=True),
        NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=True),
        NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=True),
        NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=True),
        NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=True),
    ]
    buf19: MutationOutput
    buf19.layout = NoneLayout(device=device(type='cuda', index=0), size=[0], stride=[0])
    buf19.mutations = ['buf1']
    buf19.users = [NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False)]
]
op13.node.kernel = torch.ops._C.fused_add_rms_norm.default
```

Here we can see that `buf16` shares the same dependency list as `buf1` because `buf16` and `buf1` are both in the aliases list of `buf17`. This is incorrect, since they are two separate tensors, and it prevents the compiler from reusing `buf16` for subsequent ops.

Pull Request resolved: pytorch#157133
Approved by: https://github.com/jansel

(cherry picked from commit 02724b5)

Fixes #ISSUE_NUMBER

Co-authored-by: charlifu <[email protected]>
1 parent 2d72fcd commit a0ffd4f
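
For context, the problematic pattern can be written down in a few lines of Python; the sketch below mirrors the regression test added in this commit (which additionally patches `enable_auto_functionalized_v2=True`). The namespace `mylib`, the op name `foo`, and the 2x2 shapes are illustrative, and the CPU impl is a no-op stand-in for a real mutating kernel such as vllm's `fused_add_rms_norm`:

```python
import torch

with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
    torch.library.define(
        "mylib::foo",
        "(Tensor! x, Tensor! y, Tensor z) -> ()",  # two mutated inputs, None return
        tags=torch.Tag.pt2_compliant_tag,
        lib=lib,
    )

    @torch.library.impl("mylib::foo", "cpu", lib=lib)
    @torch._dynamo.disable
    def foo(x, y, z):
        pass  # stand-in; a real kernel would mutate x and y in place

    def func(x, w):
        a = torch.empty_like(x)
        b = torch.empty_like(x)
        torch.ops.mylib.foo(a, b, x)  # mutates a and b, returns None
        c = torch.mm(a, w)
        torch.ops.mylib.foo(c, b, x)  # b is mutated again; a should be reusable here
        return c

    out = torch.compile(func)(torch.rand(2, 2), torch.rand(2, 2))
```

Before the fix, Inductor gave the two mutated arguments of each `foo` call a single shared user list (both appear in the alias list of the call's NoneLayout output), so a buffer such as `a` stayed artificially live and could not be reused.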

File tree

5 files changed: +72 −4 lines

test/dynamo/test_logging.py

Lines changed: 1 addition & 0 deletions
@@ -959,6 +959,7 @@ def bar():
             "autotuning",
             "graph_region_expansion",
             "hierarchical_compile",
+            "compute_dependencies",
         }
         for name in torch._logging._internal.log_registry.artifact_names:
             if name not in exclusions:

test/inductor/test_auto_functionalize.py

Lines changed: 44 additions & 4 deletions
@@ -445,12 +445,17 @@ def run_aot_eager(self, f, orig_args, _dynamic=False):
         graph = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
         return [aot_eager_args, result, graph]

-    def run_inductor(self, f, orig_args, _dynamic=False):
+    def run_inductor(
+        self,
+        f,
+        orig_args,
+        _dynamic=False,
+        log_module="torch._inductor.compile_fx",
+        log_function="post_grad_graphs",
+    ):
         compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)

-        log_stream, ctx = logs_to_string(
-            "torch._inductor.compile_fx", "post_grad_graphs"
-        )
+        log_stream, ctx = logs_to_string(log_module, log_function)
         result = None
         with ctx():
             result = torch.compile(
@@ -1733,6 +1738,41 @@ def f(x, w):
         y = f(x, w)
         self.assertEqual(y, x.sin())

+    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
+    def test_scheduling_with_multiple_mutates(self):
+        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
+            torch.library.define(
+                "mylib::foo",
+                "(Tensor! x, Tensor! y, Tensor z) -> ()",
+                tags=torch.Tag.pt2_compliant_tag,
+                lib=lib,
+            )
+
+            @torch.library.impl("mylib::foo", "cpu", lib=lib)
+            @torch._dynamo.disable
+            def foo(x, y, z):
+                pass
+
+            def func(x, w):
+                a = torch.empty_like(x)  # buf0
+                b = torch.empty_like(x)  # buf1
+                torch.ops.mylib.foo(a, b, x)  # buf2, buf3, buf4
+                c = torch.mm(a, w)  # buf5
+                torch.ops.mylib.foo(c, b, x)  # buf6, buf7, buf8
+                return c
+
+            input = torch.rand(2, 2)
+            weight = torch.rand(2, 2)
+            [inductor_args, output, graph_inductor] = self.run_inductor(
+                func,
+                [input, weight],
+                False,
+                "torch._inductor.scheduler",
+                "compute_dependencies",
+            )
+            name_to_users = eval(graph_inductor)
+            self.assertNotEqual(name_to_users["buf1"], name_to_users["buf5"])
+

 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests

torch/_inductor/scheduler.py

Lines changed: 24 additions & 0 deletions
@@ -74,6 +74,9 @@
 log = logging.getLogger(__name__)
 fusion_log = torch._logging.getArtifactLogger(__name__, "fusion")
 loop_ordering_log = torch._logging.getArtifactLogger(__name__, "loop_ordering")
+compute_dependencies_log = torch._logging.getArtifactLogger(
+    __name__, "compute_dependencies"
+)

 PartitionType = list["BaseSchedulerNode"]

@@ -2278,6 +2281,15 @@ def __add__(self, other: DedupList[T]) -> DedupList[T]:
         for node in self.nodes:
             for buf1 in node.get_outputs():
                 buf1_name = buf1.get_name()
+                # This is for handling auto functionalized ops which return None
+                # and mutate more than one input; we shouldn't let them all
+                # point to the same user list, since buffers in the aliases
+                # list might not be aliases of each other.
+                if (
+                    isinstance(buf1.node.layout, ir.NoneLayout)
+                    and len(buf1.get_aliases()) > 1
+                ):
+                    continue
                 for buf2_name in buf1.get_aliases():
                     if buf1_name in name_to_users and buf2_name in name_to_users:
                         # merge the two
@@ -2445,6 +2457,18 @@ def add_user(
         for name in self.name_to_donated_buffer:
             self.name_to_donated_buffer[name].set_users(name_to_users[name].items)

+        # For debug logging
+        logbuf = IndentedBuffer()
+        logbuf.splice("{")
+        for key, value in name_to_users.items():
+            with logbuf.indent():
+                users = [v.get_name() for v in value.items]
+                logbuf.splice(f"'{key}': {users},")
+        logbuf.splice("}")
+        str = logbuf.getrawvalue().rstrip()
+        compute_dependencies_log.debug("BUFFER USER LIST\n")
+        compute_dependencies_log.debug("===== AFTER SCHEDULING =====\n%s", str)
+
     def dead_node_elimination(self) -> None:
         """
         Remove any nodes without users
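
The first hunk above is the actual fix; the second only adds debug logging. As a rough intuition for why the guard matters, here is a toy sketch in plain Python (plain dicts and lists, not Inductor's `DedupList`/`NodeUser` classes) of how sharing one user list between two buffers that are not true aliases keeps both artificially live:

```python
# Buggy behaviour: both mutated inputs point at the *same* list object,
# so a user recorded for one buffer also appears to keep the other alive.
shared_users = []
name_to_users = {"buf1": shared_users, "buf16": shared_users}
name_to_users["buf16"].append("op13")
assert name_to_users["buf1"] == ["op13"]  # buf1 now looks used by op13 too

# Fixed behaviour: distinct buffers keep distinct user lists, so buf1 can be
# freed or reused once its own users have run.
name_to_users = {"buf1": [], "buf16": []}
name_to_users["buf16"].append("op13")
assert name_to_users["buf1"] == []
```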

torch/_logging/_internal.py

Lines changed: 2 additions & 0 deletions
@@ -252,6 +252,7 @@ def set_logs(
     graph_region_expansion: bool = False,
     inductor_metrics: bool = False,
     hierarchical_compile: bool = False,
+    compute_dependencies: bool = False,
 ) -> None:
     """
     Sets the log level for individual components and toggles individual log
@@ -565,6 +566,7 @@ def _set_logs(**kwargs) -> None:
         graph_region_expansion=graph_region_expansion,
         inductor_metrics=inductor_metrics,
         hierarchical_compile=hierarchical_compile,
+        compute_dependencies=compute_dependencies,
     )

torch/_logging/_registrations.py

Lines changed: 1 addition & 0 deletions
@@ -183,6 +183,7 @@
 )
 register_artifact("perf_hints", "", off_by_default=True)
 register_artifact("onnx_diagnostics", "", off_by_default=True)
+register_artifact("compute_dependencies", "", off_by_default=True)
 register_artifact(
     "fusion",
     "Detailed Inductor fusion decisions. More detailed than 'schedule'",
