🐛 Describe the bug
The bug occurs during scheduling:
Prescheduled Fusion IR:
Inputs:
  T0_g[ iS0{i0}, iS1{i1}, bS2{1}, bS3{1} ], float
  T1_g[ iS4{i6}, iS5{i7}, bS6{1}, bS7{1} ], float
Outputs:
  T4_g[ iS16{i0}, iS17{i1}, bS18{1}, bS19{1} ], float
  T8_g[ iS32{i0}, iS33{i1}, bS34{1}, bS35{1} ], float

%kernel_math {
T2_l[ iS8{i0}, iS9{i1}, bS10{1}, bS11{1} ]
   = T0_g[ iS0{i0}, iS1{i1}, bS2{1}, bS3{1} ]
   + T1_g[ iS4{i6}, iS5{i7}, bS6{1}, bS7{1} ];
T3_l[ iS12{i0}, iS13{i1}, bS14{1}, bS15{1} ]
   = T2_l[ iS8{i0}, iS9{i1}, bS10{1}, bS11{1} ]
   <= double(0.0000000000000000);
T4_g[ iS16{i0}, iS17{i1}, bS18{1}, bS19{1} ]
   = where(T3_l[ iS12{i0}, iS13{i1}, bS14{1}, bS15{1} ]
   , double(0.0000000000000000)
   , T2_l[ iS8{i0}, iS9{i1}, bS10{1}, bS11{1} ]);
T5_l[ iS20{i0}, iS21{i1}, bS22{1}, bS23{1} ]
   = T4_g[ iS16{i0}, iS17{i1}, bS18{1}, bS19{1} ];
T6_l[ iS24{i0}, iS25{i1}, rS26{1}, rS27{1} ]
   = reduction( T5_l[ iS20{i0}, iS21{i1}, bS22{1}, bS23{1} ], op = add, initial value = double(0.0000000000000000), allreduce = false )
T7_l[ iS28{i0}, iS29{i1}, bS30{1}, bS31{1} ]
   = broadcast( T6_l[ iS24{i0}, iS25{i1}, rS26{1}, rS27{1} ] )
T8_g[ iS32{i0}, iS33{i1}, bS34{1}, bS35{1} ]
   = T7_l[ iS28{i0}, iS29{i1}, bS30{1}, bS31{1} ]
   / double(1.0000000000000000);
}
Error:
RuntimeError: Merging IterDomains requires that their iteration types match.
The repro requires bumping the devel fork from upstream to pick up the Python frontend changes.
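In the dump above, iS marks iteration domains, bS broadcast domains, and rS reduction domains: summing over the two trailing bS{1} axes produces rS26{1}/rS27{1}, and the scheduler then appears to merge a domain of one iteration type with a domain of another. Below is a rough guess at a reduced trigger, built from the same frontend calls as the full repro; this is hypothetical and I have not verified it fails on its own:

import torch
from torch._C._nvfuser import FusionDefinition, Fusion, DataType

def reduced_fusion(fd : FusionDefinition) -> None :
    # Trailing size-1 dims become broadcast IterDomains (bS{1}).
    t0 = fd.define_tensor(symbolic_sizes=[-1, -1, 1, 1], contiguous=[True, True, True, True], dtype=DataType.Float)
    # Reducing over the broadcast axes is what yields the rS{1} domains.
    t1 = fd.ops.sum(t0, axes=[3, 2], keepdim=False, dtype=DataType.Null)
    fd.add_output(t1)

fs_min = Fusion()
with FusionDefinition(fs_min) as fd:
    reduced_fusion(fd)
out_min = fs_min.execute([torch.randn(64, 512, 1, 1, device='cuda', dtype=torch.float32)])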
Repro:
import torch
from torch._C._nvfuser import FusionDefinition, Fusion, DataType
def nvfuser_fusion_id5(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(symbolic_sizes=[-1, -1, 1, 1], contiguous=[True, True, True, True], dtype=DataType.Float)
    T1 = fd.define_tensor(symbolic_sizes=[-1, -1, 1, 1], contiguous=[True, True, True, True], dtype=DataType.Float)
    T2 = fd.ops.add(T0, T1)
    S3 = fd.define_constant(0.00000)
    T4 = fd.ops.le(T2, S3)
    S5 = fd.define_constant(0.00000)
    T6 = fd.ops.where(T4, S5, T2)
    T7 = fd.ops.cast(T6, dtype=DataType.Float)
    T8 = fd.ops.sum(T7, axes=[3, 2], keepdim=False, dtype=DataType.Null)
    T9 = fd.ops.broadcast_in_dim(T8, output_shape=[64, 512, 1, 1], broadcast_dims=[0, 1])
    S10 = fd.define_constant(1.00000)
    T11 = fd.ops.div(T9, S10)
    fd.add_output(T6)
    fd.add_output(T11)
fs = Fusion()
with FusionDefinition(fs) as fd:
    nvfuser_fusion_id5(fd)
input1 = torch.randn(64, 512, 1, 1, device='cuda', dtype=torch.float32)
input2 = torch.randn(64, 512, 1, 1, device='cuda', dtype=torch.float32)
out = fs.execute([input1, input2])
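For reference, this is my eager-mode reading of what the fusion computes (a paraphrase, not from the original report); the allclose checks can only run once the scheduling bug is fixed:

t2 = input1 + input2
t6 = torch.where(t2 <= 0.0, torch.zeros_like(t2), t2)  # where(le(T2, 0), 0, T2)
t11 = t6.sum(dim=(3, 2), keepdim=True) / 1.0           # keepdim=True stands in for broadcast_in_dim
# Expected once fixed: torch.allclose(out[0], t6) and torch.allclose(out[1], t11)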
Versions
Upstream TOT?
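For the record, the standard way to fill in this section is PyTorch's environment collection script, e.g.:

from torch.utils import collect_env
collect_env.main()  # prints the environment report normally pasted here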