Fusion Segmenter asserts #1947

@jjsjann123

Description


🐛 Describe the bug

The same assert fires on a few Hugging Face benchmarks:

RuntimeError: h.has_value() INTERNAL ASSERT FAILED at "/opt/pytorch/pytorch/torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp":2613, please report a bug to PyTorch.
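
This is an internal nvFuser assert with no Python-side context. As an optional debugging aid (not part of the original report), PyTorch can print the C++ frames around the assert when TORCH_SHOW_CPP_STACKTRACES is set; setting it before torch is imported is the safe ordering:

import os
os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1"  # read by PyTorch's error handling; set before torch initializes
import torch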

Repro on devel

import torch

from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.executor import execute

# Input tensors matching the shapes the traced graph below expects.
t0 = torch.randn(2, 512, 128, device="cuda")
t1 = torch.randn(128, device="cuda")
t2 = torch.randn(2, 512, 1, device="cuda")  # alternatively (2, 512, 128)
t3 = torch.randn(2, 512, 128, device="cuda")
t4 = torch.randn(2, 512, 1, device="cuda")
t5 = torch.randn(2, 512, 128, device="cuda")
t6 = torch.randn(2, 512, 128, device="cuda")
t7 = torch.randn(2, 512, 128, device="cuda")
t8 = torch.randn(2, 512, 128, device="cuda")

# nvprims graph dumped from the failing benchmark.
def forward(arg331_1, arg329_1, arg213_1, arg346_1, arg211_1, _reshape_alias_default_2, arg199_1, arg348_1, arg226_1):
    mul_default_10 = torch.ops.nvprims.mul.default(arg331_1, arg331_1);  arg331_1 = None
    broadcast_in_dim_default_5 = torch.ops.nvprims.broadcast_in_dim.default(arg329_1, [2, 512, 128], [2]);  arg329_1 = None
    broadcast_in_dim_default_4 = torch.ops.nvprims.broadcast_in_dim.default(arg213_1, [2, 512, 128], [0, 1, 2])
    mul_default_14 = torch.ops.nvprims.mul.default(arg346_1, arg346_1);  arg346_1 = None
    broadcast_in_dim_default_3 = torch.ops.nvprims.broadcast_in_dim.default(arg211_1, [2, 512, 128], [0, 1, 2]);  arg211_1 = None
    div_default = torch.ops.nvprims.div.default(arg213_1, 128.0);  arg213_1 = None
    sub_default_4 = torch.ops.nvprims.sub.default(1.0, mul_default_10);  mul_default_10 = None
    mul_default_2 = torch.ops.nvprims.mul.default(_reshape_alias_default_2, broadcast_in_dim_default_5);  broadcast_in_dim_default_5 = None
    mul_default_15 = torch.ops.nvprims.mul.default(mul_default_14, 3.0);  mul_default_14 = None
    sub_default_1 = torch.ops.nvprims.sub.default(arg199_1, broadcast_in_dim_default_3);  arg199_1 = broadcast_in_dim_default_3 = None
    broadcast_in_dim_default_10 = torch.ops.nvprims.broadcast_in_dim.default(div_default, [2, 512, 128], [0, 1, 2]);  div_default = None
    mul_default_3 = torch.ops.nvprims.mul.default(mul_default_2, 128.0)
    convert_element_type_default_2 = torch.ops.nvprims.convert_element_type.default(mul_default_2, torch.float32)
    mul_default_1 = torch.ops.nvprims.mul.default(sub_default_1, broadcast_in_dim_default_4);  sub_default_1 = broadcast_in_dim_default_4 = None
    sum_default_2 = torch.ops.nvprims.sum.default(convert_element_type_default_2, [2]);  convert_element_type_default_2 = None
    mul_default_4 = torch.ops.nvprims.mul.default(mul_default_2, mul_default_1);  mul_default_2 = None
    mul_default_7 = torch.ops.nvprims.mul.default(_reshape_alias_default_2, mul_default_1);  _reshape_alias_default_2 = None
    broadcast_in_dim_default_6 = torch.ops.nvprims.broadcast_in_dim.default(sum_default_2, [2, 512, 1], [0, 1]);  sum_default_2 = None
    convert_element_type_default_3 = torch.ops.nvprims.convert_element_type.default(mul_default_4, torch.float32);  mul_default_4 = None
    convert_element_type_default_4 = torch.ops.nvprims.convert_element_type.default(mul_default_7, torch.float32);  mul_default_7 = None
    broadcast_in_dim_default_9 = torch.ops.nvprims.broadcast_in_dim.default(broadcast_in_dim_default_6, [2, 512, 128], [0, 1, 2]);  broadcast_in_dim_default_6 = None
    sum_default_3 = torch.ops.nvprims.sum.default(convert_element_type_default_3, [2]);  convert_element_type_default_3 = None
    sum_default_4 = torch.ops.nvprims.sum.default(convert_element_type_default_4, [0, 1]);  convert_element_type_default_4 = None
    sub_default_2 = torch.ops.nvprims.sub.default(mul_default_3, broadcast_in_dim_default_9);  mul_default_3 = broadcast_in_dim_default_9 = None
    broadcast_in_dim_default_7 = torch.ops.nvprims.broadcast_in_dim.default(sum_default_3, [2, 512, 1], [0, 1]);  sum_default_3 = None
    broadcast_in_dim_default_8 = torch.ops.nvprims.broadcast_in_dim.default(broadcast_in_dim_default_7, [2, 512, 128], [0, 1, 2]);  broadcast_in_dim_default_7 = None
    mul_default_5 = torch.ops.nvprims.mul.default(mul_default_1, broadcast_in_dim_default_8);  mul_default_1 = broadcast_in_dim_default_8 = None
    sub_default_3 = torch.ops.nvprims.sub.default(sub_default_2, mul_default_5);  sub_default_2 = mul_default_5 = None
    mul_default_6 = torch.ops.nvprims.mul.default(broadcast_in_dim_default_10, sub_default_3);  broadcast_in_dim_default_10 = sub_default_3 = None
    mul_default_8 = torch.ops.nvprims.mul.default(mul_default_6, arg348_1);  arg348_1 = None
    mul_default_9 = torch.ops.nvprims.mul.default(mul_default_6, arg226_1);  mul_default_6 = arg226_1 = None
    mul_default_11 = torch.ops.nvprims.mul.default(mul_default_8, sub_default_4);  mul_default_8 = sub_default_4 = None
    mul_default_17 = torch.ops.nvprims.mul.default(mul_default_9, 0.5);  mul_default_9 = None
    mul_default_12 = torch.ops.nvprims.mul.default(mul_default_11, 0.7978845608028654);  mul_default_11 = None
    mul_default_13 = torch.ops.nvprims.mul.default(mul_default_12, 0.044715)
    mul_default_16 = torch.ops.nvprims.mul.default(mul_default_13, mul_default_15);  mul_default_13 = mul_default_15 = None
    add_default_1 = torch.ops.nvprims.add.default(mul_default_12, mul_default_16);  mul_default_12 = mul_default_16 = None
    add_default_2 = torch.ops.nvprims.add.default(add_default_1, mul_default_17);  add_default_1 = mul_default_17 = None
    return (sum_default_4, add_default_2)

gm = make_fx(forward)(t0, t1, t2, t3, t4, t5, t6, t7, t8)
print(gm.graph)
execute(gm, t0, t1, t2, t3, t4, t5, t6, t7, t8, executor="nvfuser")
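
As a sanity check (a sketch on top of the repro, not from the original report), the same traced module can be run through the reference ATen executor, if it is available in this build; if that succeeds while executor="nvfuser" hits the assert, the failure is localized to nvFuser's fusion segmenter rather than to the trace itself:

# Reference run: should succeed, pointing the bug at nvFuser segmentation.
execute(gm, t0, t1, t2, t3, t4, t5, t6, t7, t8, executor="aten")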

Versions

You'll need top-of-tree (ToT) devel, since some nvprims changes were merged from the last master pull. Here's the commit where I got the repro:

commit ac4de38c6ee53b366e85fdfe408c3642d32b57df (HEAD, origin/devel, origin/HEAD)
Merge: 631094891a aab10bce45
Author: Christian Sarofeen <[email protected]>
Date:   Tue Aug 30 15:44:39 2022 -0400

    Merge pull request #1945 from csarofeen/master_merge_0828
    
    Master merge 0828
