Mismatch in IterDomain Iteration Types in TorchBench functorch_dp_cifar10 #2008

@kevinstephano

🐛 Describe the bug

The error is raised during scheduling.

Prescheduled Fusion IR:

Inputs:
  T0_g[ iS0{i0}, iS1{i1}, bS2{1}, bS3{1} ], float
  T1_g[ iS4{i6}, iS5{i7}, bS6{1}, bS7{1} ], float
Outputs:
  T4_g[ iS16{i0}, iS17{i1}, bS18{1}, bS19{1} ], float
  T8_g[ iS32{i0}, iS33{i1}, bS34{1}, bS35{1} ], float

%kernel_math {
T2_l[ iS8{i0}, iS9{i1}, bS10{1}, bS11{1} ]
   = T0_g[ iS0{i0}, iS1{i1}, bS2{1}, bS3{1} ]
   + T1_g[ iS4{i6}, iS5{i7}, bS6{1}, bS7{1} ];
T3_l[ iS12{i0}, iS13{i1}, bS14{1}, bS15{1} ]
   = T2_l[ iS8{i0}, iS9{i1}, bS10{1}, bS11{1} ]
   <= double(0.0000000000000000);
T4_g[ iS16{i0}, iS17{i1}, bS18{1}, bS19{1} ]
   = where(T3_l[ iS12{i0}, iS13{i1}, bS14{1}, bS15{1} ]
  , double(0.0000000000000000)
  , T2_l[ iS8{i0}, iS9{i1}, bS10{1}, bS11{1} ]);
T5_l[ iS20{i0}, iS21{i1}, bS22{1}, bS23{1} ]
   = T4_g[ iS16{i0}, iS17{i1}, bS18{1}, bS19{1} ];
T6_l[ iS24{i0}, iS25{i1}, rS26{1}, rS27{1} ]
   = reduction( T5_l[ iS20{i0}, iS21{i1}, bS22{1}, bS23{1} ], op = add, initial value = double(0.0000000000000000), allreduce = false )
T7_l[ iS28{i0}, iS29{i1}, bS30{1}, bS31{1} ]
   = broadcast( T6_l[ iS24{i0}, iS25{i1}, rS26{1}, rS27{1} ] )
T8_g[ iS32{i0}, iS33{i1}, bS34{1}, bS35{1} ]
   = T7_l[ iS28{i0}, iS29{i1}, bS30{1}, bS31{1} ]
   / double(1.0000000000000000);
}

Error:

RuntimeError: Merging IterDomains requires that their iteration types match.
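
For readers unfamiliar with the message, here is a toy, pure-Python sketch of what an iteration-type check on a merge could look like. Nothing in it is nvFuser code: `IterType`, `ToyIterDomain`, and `toy_merge` are hypothetical names, and the strict-equality rule is an assumption, since the exact condition nvFuser checks is not shown in this report. The axes are copied from T6_l in the IR dump above only as an illustration of mixed iteration (iS) and reduction (rS) domains. This is not a diagnosis of where the scheduler actually fails.

```python
from dataclasses import dataclass
from enum import Enum


class IterType(Enum):
    """Hypothetical stand-in for the iS / bS / rS prefixes in the IR dump."""
    ITERATION = "iS"
    BROADCAST = "bS"
    REDUCTION = "rS"


@dataclass(frozen=True)
class ToyIterDomain:
    """Toy model of one axis: a name, a symbolic extent, and an iteration type."""
    name: str
    extent: str
    iter_type: IterType


def toy_merge(outer: ToyIterDomain, inner: ToyIterDomain) -> ToyIterDomain:
    """Merge two axes into one; ASSUMED rule: the iteration types must be equal."""
    if outer.iter_type is not inner.iter_type:
        raise RuntimeError(
            "Merging IterDomains requires that their iteration types match."
        )
    return ToyIterDomain(
        name=f"merge({outer.name},{inner.name})",
        extent=f"{outer.extent}*{inner.extent}",
        iter_type=outer.iter_type,
    )


# Axes of T6_l as printed in the prescheduled IR: [iS24{i0}, iS25{i1}, rS26{1}, rS27{1}]
t6 = [
    ToyIterDomain("iS24", "i0", IterType.ITERATION),
    ToyIterDomain("iS25", "i1", IterType.ITERATION),
    ToyIterDomain("rS26", "1", IterType.REDUCTION),
    ToyIterDomain("rS27", "1", IterType.REDUCTION),
]

# Merging the two reduction axes satisfies the toy rule.
merged_reduction = toy_merge(t6[2], t6[3])

# Merging an iteration axis with a reduction axis violates it.
try:
    toy_merge(t6[1], t6[2])
except RuntimeError as err:
    print(f"toy merge rejected: {err}")
```

Under this toy rule, any transformation that flattens all four axes of T6_l into one would hit the check, which is one hypothetical way such a message could surface.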

The repro requires bumping the devel fork from upstream to pick up the Python frontend changes.

Repro:

import torch
from torch._C._nvfuser import FusionDefinition, Fusion, DataType

def nvfuser_fusion_id5(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(symbolic_sizes=[-1, -1, 1, 1], contiguous=[True, True, True, True], dtype=DataType.Float)
    T1 = fd.define_tensor(symbolic_sizes=[-1, -1, 1, 1], contiguous=[True, True, True, True], dtype=DataType.Float)
    T2 = fd.ops.add(T0, T1)
    S3 = fd.define_constant(0.00000)
    T4 = fd.ops.le(T2, S3)
    S5 = fd.define_constant(0.00000)
    T6 = fd.ops.where(T4, S5, T2)
    T7 = fd.ops.cast(T6, dtype=DataType.Float)
    T8 = fd.ops.sum(T7, axes=[3, 2], keepdim=False, dtype=DataType.Null)
    T9 = fd.ops.broadcast_in_dim(T8, output_shape=[64, 512, 1, 1], broadcast_dims=[0, 1])
    S10 = fd.define_constant(1.00000)
    T11 = fd.ops.div(T9, S10)
    fd.add_output(T6)
    fd.add_output(T11)

fs = Fusion()
with FusionDefinition(fs) as fd:
    nvfuser_fusion_id5(fd)

input1 = torch.randn(64, 512, 1, 1, device='cuda', dtype=torch.float32)
input2 = torch.randn(64, 512, 1, 1, device='cuda', dtype=torch.float32)

out = fs.execute([input1, input2])

Versions

Upstream TOT?
