Commit 0b57581

guangy10 authored and pytorchmergebot committed
[pytorch] Disable fast path in MultiheadAttention in Export (pytorch#106824)
Summary: We are seeing the `aten._native_multi_head_attention` op (not in the core ATen op set) left in the exported graph, which causes problems downstream at runtime. Two proposed solutions:

1. Disable the fast path while tracing, so the non-optimized path is taken and decomposed; that way the offending op won't show up in the exported graph.
2. Add a decomp rule for `aten._native_multi_head_attention`.

After discussing with kimishpatel and bdhirsh, option 1 is preferred, and it was verified to immediately unblock the critical model enablement work for PP.

Test Plan: CI

Differential Revision: D48169806

Pull Request resolved: pytorch#106824
Approved by: https://github.com/kimishpatel
1 parent 7f9d1ca commit 0b57581

File tree

1 file changed (+10, -0 lines)


torch/nn/modules/activation.py

Lines changed: 10 additions & 0 deletions
@@ -895,6 +895,14 @@ def _arg_requires_grad(x: Optional[torch.Tensor]) -> bool:
     return False


+def _is_make_fx_tracing():
+    if not torch.jit.is_scripting():
+        torch_dispatch_mode_stack = torch.utils._python_dispatch._get_current_dispatch_mode_stack()
+        return any(type(x) == torch.fx.experimental.proxy_tensor.ProxyTorchDispatchMode for x in torch_dispatch_mode_stack)
+    else:
+        return False
+
+
 class MultiheadAttention(Module):
     r"""Allows the model to jointly attend to information
     from different representation subspaces as described in the paper:
@@ -1169,6 +1177,8 @@ def forward(
             # generator expressions.
             if torch.overrides.has_torch_function(tensor_args):
                 why_not_fast_path = "some Tensor argument has_torch_function"
+            elif _is_make_fx_tracing():
+                why_not_fast_path = "we are running make_fx tracing"
             elif not all(_check_arg_device(x) for x in tensor_args):
                 why_not_fast_path = ("some Tensor argument's device is neither one of "
                                     f"cpu, cuda or {torch.utils.backend_registration._privateuse1_backend_name}")
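The detection logic in the diff above can be exercised on its own. Below is a minimal sketch of the same check, run both outside and inside a `make_fx` trace. It assumes the private helpers `_get_current_dispatch_mode_stack` and `ProxyTorchDispatchMode` behave as shown in this commit; they are internal PyTorch APIs and may move or change between releases.

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx, ProxyTorchDispatchMode
from torch.utils._python_dispatch import _get_current_dispatch_mode_stack


def is_make_fx_tracing() -> bool:
    # Returns True only when a ProxyTorchDispatchMode is on the dispatch-mode
    # stack, which is the case while make_fx is tracing. (Internal API; sketch
    # mirrors the helper added in this commit, minus the TorchScript guard.)
    stack = _get_current_dispatch_mode_stack()
    return any(isinstance(m, ProxyTorchDispatchMode) for m in stack)


observed = []


def f(x):
    # Record what the check returns while this function is traced by make_fx.
    observed.append(is_make_fx_tracing())
    return x + 1


# Outside tracing, no proxy mode is pushed, so the check is False.
outside = is_make_fx_tracing()

# During make_fx tracing, ProxyTorchDispatchMode is active, so the check
# fires and MultiheadAttention would fall back to the slow (decomposable) path.
gm = make_fx(f)(torch.ones(2))
```

This is why the commit only needs one extra `elif` in `forward`: the fast path is gated on `why_not_fast_path` being empty, so setting it during tracing routes execution through the non-optimized path, whose ops decompose into core ATen.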

0 commit comments
