🐛 Describe the bug
The same assert fires on a few Hugging Face benchmarks:
RuntimeError: h.has_value() INTERNAL ASSERT FAILED at "/opt/pytorch/pytorch/torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp":2613, please report a bug to PyTorch.
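The assert comes from nvFuser's fusion segmenter (fusion_segmenter.cpp). Assuming the dump facility on this devel build still supports the segmented_fusion option, printing the segmenter output right before the assert fires may help triage; a minimal invocation sketch, with the repro below saved as repro.py:

# Assumption: PYTORCH_NVFUSER_DUMP and its segmented_fusion option exist on this build.
PYTORCH_NVFUSER_DUMP=segmented_fusion python repro.py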
Repro on devel
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.executor import execute
t0 = torch.randn(2, 512, 128, device="cuda")
t1 = torch.randn(128, device="cuda")
t2 = torch.randn(2, 512, 1, device="cuda") # alternatively (2, 512, 128)
t3 = torch.randn(2, 512, 128, device="cuda")
t4 = torch.randn(2, 512, 1, device="cuda")
t5 = torch.randn(2, 512, 128, device="cuda")
t6 = torch.randn(2, 512, 128, device="cuda")
t7 = torch.randn(2, 512, 128, device="cuda")
t8 = torch.randn(2, 512, 128, device="cuda")
def forward(arg331_1, arg329_1, arg213_1, arg346_1, arg211_1, _reshape_alias_default_2, arg199_1, arg348_1, arg226_1):
    mul_default_10 = torch.ops.nvprims.mul.default(arg331_1, arg331_1); arg331_1 = None
    broadcast_in_dim_default_5 = torch.ops.nvprims.broadcast_in_dim.default(arg329_1, [2, 512, 128], [2]); arg329_1 = None
    broadcast_in_dim_default_4 = torch.ops.nvprims.broadcast_in_dim.default(arg213_1, [2, 512, 128], [0, 1, 2])
    mul_default_14 = torch.ops.nvprims.mul.default(arg346_1, arg346_1); arg346_1 = None
    broadcast_in_dim_default_3 = torch.ops.nvprims.broadcast_in_dim.default(arg211_1, [2, 512, 128], [0, 1, 2]); arg211_1 = None
    div_default = torch.ops.nvprims.div.default(arg213_1, 128.0); arg213_1 = None
    sub_default_4 = torch.ops.nvprims.sub.default(1.0, mul_default_10); mul_default_10 = None
    mul_default_2 = torch.ops.nvprims.mul.default(_reshape_alias_default_2, broadcast_in_dim_default_5); broadcast_in_dim_default_5 = None
    mul_default_15 = torch.ops.nvprims.mul.default(mul_default_14, 3.0); mul_default_14 = None
    sub_default_1 = torch.ops.nvprims.sub.default(arg199_1, broadcast_in_dim_default_3); arg199_1 = broadcast_in_dim_default_3 = None
    broadcast_in_dim_default_10 = torch.ops.nvprims.broadcast_in_dim.default(div_default, [2, 512, 128], [0, 1, 2]); div_default = None
    mul_default_3 = torch.ops.nvprims.mul.default(mul_default_2, 128.0)
    convert_element_type_default_2 = torch.ops.nvprims.convert_element_type.default(mul_default_2, torch.float32)
    mul_default_1 = torch.ops.nvprims.mul.default(sub_default_1, broadcast_in_dim_default_4); sub_default_1 = broadcast_in_dim_default_4 = None
    sum_default_2 = torch.ops.nvprims.sum.default(convert_element_type_default_2, [2]); convert_element_type_default_2 = None
    mul_default_4 = torch.ops.nvprims.mul.default(mul_default_2, mul_default_1); mul_default_2 = None
    mul_default_7 = torch.ops.nvprims.mul.default(_reshape_alias_default_2, mul_default_1); _reshape_alias_default_2 = None
    broadcast_in_dim_default_6 = torch.ops.nvprims.broadcast_in_dim.default(sum_default_2, [2, 512, 1], [0, 1]); sum_default_2 = None
    convert_element_type_default_3 = torch.ops.nvprims.convert_element_type.default(mul_default_4, torch.float32); mul_default_4 = None
    convert_element_type_default_4 = torch.ops.nvprims.convert_element_type.default(mul_default_7, torch.float32); mul_default_7 = None
    broadcast_in_dim_default_9 = torch.ops.nvprims.broadcast_in_dim.default(broadcast_in_dim_default_6, [2, 512, 128], [0, 1, 2]); broadcast_in_dim_default_6 = None
    sum_default_3 = torch.ops.nvprims.sum.default(convert_element_type_default_3, [2]); convert_element_type_default_3 = None
    sum_default_4 = torch.ops.nvprims.sum.default(convert_element_type_default_4, [0, 1]); convert_element_type_default_4 = None
    sub_default_2 = torch.ops.nvprims.sub.default(mul_default_3, broadcast_in_dim_default_9); mul_default_3 = broadcast_in_dim_default_9 = None
    broadcast_in_dim_default_7 = torch.ops.nvprims.broadcast_in_dim.default(sum_default_3, [2, 512, 1], [0, 1]); sum_default_3 = None
    broadcast_in_dim_default_8 = torch.ops.nvprims.broadcast_in_dim.default(broadcast_in_dim_default_7, [2, 512, 128], [0, 1, 2]); broadcast_in_dim_default_7 = None
    mul_default_5 = torch.ops.nvprims.mul.default(mul_default_1, broadcast_in_dim_default_8); mul_default_1 = broadcast_in_dim_default_8 = None
    sub_default_3 = torch.ops.nvprims.sub.default(sub_default_2, mul_default_5); sub_default_2 = mul_default_5 = None
    mul_default_6 = torch.ops.nvprims.mul.default(broadcast_in_dim_default_10, sub_default_3); broadcast_in_dim_default_10 = sub_default_3 = None
    mul_default_8 = torch.ops.nvprims.mul.default(mul_default_6, arg348_1); arg348_1 = None
    mul_default_9 = torch.ops.nvprims.mul.default(mul_default_6, arg226_1); mul_default_6 = arg226_1 = None
    mul_default_11 = torch.ops.nvprims.mul.default(mul_default_8, sub_default_4); mul_default_8 = sub_default_4 = None
    mul_default_17 = torch.ops.nvprims.mul.default(mul_default_9, 0.5); mul_default_9 = None
    mul_default_12 = torch.ops.nvprims.mul.default(mul_default_11, 0.7978845608028654); mul_default_11 = None
    mul_default_13 = torch.ops.nvprims.mul.default(mul_default_12, 0.044715)
    mul_default_16 = torch.ops.nvprims.mul.default(mul_default_13, mul_default_15); mul_default_13 = mul_default_15 = None
    add_default_1 = torch.ops.nvprims.add.default(mul_default_12, mul_default_16); mul_default_12 = mul_default_16 = None
    add_default_2 = torch.ops.nvprims.add.default(add_default_1, mul_default_17); add_default_1 = mul_default_17 = None
    return (sum_default_4, add_default_2)
gm = make_fx(forward)(t0, t1, t2, t3, t4, t5, t6, t7, t8)
print(gm.graph)
execute(gm, t0, t1, t2, t3, t4, t5, t6, t7, t8, executor="nvfuser")
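A cross-check that may help narrow this down, assuming the "aten" executor path in torch._prims.executor works the same on devel as on master: running the identical graph through the reference executor. If that succeeds, the failure is isolated to the nvFuser lowering rather than the traced graph itself.

# Cross-check sketch (assumption: executor="aten" is available in this build);
# success here would point at the nvFuser segmenter, not the graph.
execute(gm, t0, t1, t2, t3, t4, t5, t6, t7, t8, executor="aten")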
Versions
You'll need ToT devel, since some nvprims changes were merged in the last master pull. Here's the commit where I got the repro:
commit ac4de38c6ee53b366e85fdfe408c3642d32b57df (HEAD, origin/devel, origin/HEAD)
Merge: 631094891a aab10bce45
Author: Christian Sarofeen <[email protected]>
Date: Tue Aug 30 15:44:39 2022 -0400
Merge pull request #1945 from csarofeen/master_merge_0828
Master merge 0828