Closed
Description
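🐛 Describe the bug
Scheduling the test below with `scheduleTranspose` and then calling `FusionExecutor::compileFusion` fails during index lowering with `root_ind != nullptr INTERNAL ASSERT FAILED` ("Couldn't find root mapping for T1_g ..."). The repro, the printed fusion, and the full error follow.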
// Repro: apply the transpose scheduler to a pointwise add of two 4-D inputs,
// where the second input's last dimension has concrete extent 1, then
// compile, run, and validate the result against ATen.
TEST_F(NVFuserTest, FusionScheduleTransposeRepro1_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Fusion definition: sum = lhs + rhs.
  auto lhs = makeSymbolicTensor(4);
  auto rhs = makeConcreteTensor({-1, -1, -1, 1});
  fusion.addInput(lhs);
  fusion.addInput(rhs);
  auto sum = add(lhs, rhs);
  fusion.addOutput(sum);

  // Concrete float inputs on CUDA device 0; both have shape {1, 1, 333, 1}.
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({1, 1, 333, 1}, options);
  at::Tensor t1 = at::randn({1, 1, 333, 1}, options);

  // Schedule with the transpose scheduler and dump the scheduled fusion.
  auto lparams = scheduleTranspose(&fusion, {t0, t1});
  fusion.print();

  // Compile and execute the scheduled fusion.
  FusionExecutor fe;
  fe.compileFusion(&fusion, {t0, t1}, lparams);
  auto outputs = fe.runFusion({t0, t1}, lparams);

  // Compare against the eager ATen result.
  auto expected = t0 + t1;
  testValidate(&fusion, outputs, {t0, t1}, {expected}, __LINE__, __FILE__);
}
%kernel {
T3_l[ iblockIdx.x67{( ceilDiv(( ( ceilDiv(( ( i1 * i2 ) * i4 ), 8) ) * ( ceilDiv(i3, 8) ) ), 1) )}, iUS68{1}, iUR111{( ceilDiv(( ceilDiv(( 8 * 8 ), 1) ), 64) )}, ithreadIdx.x112{64}, iS110{1} ] ca_pos( 2 )
= T0_g[ iS76{( ceilDiv(( ( ceilDiv(( ( i1 * i2 ) * i4 ), 8) ) * ( ceilDiv(i3, 8) ) ), 1) )}, iS77{1}, iS116{( ceilDiv(( ceilDiv(( 8 * 8 ), 1) ), 64) )}, iS117{64}, iS115{1} ];
T4_s[ iblockIdx.x49{( ceilDiv(( ( ceilDiv(( ( i5 * i6 ) * 1 ), 8) ) * ( ceilDiv(i7, 8) ) ), 1) )}, iUS50{1}, iUR86{( ceilDiv(( ceilDiv(( 8 * 8 ), 1) ), 64) )}, ithreadIdx.x87{64}, iS85{1} ] ca_pos( 2 )
= T1_g[ iS58{( ceilDiv(( ( ceilDiv(( ( i5 * i6 ) * 1 ), 8) ) * ( ceilDiv(i7, 8) ) ), 1) )}, iS59{1}, iS81{( ceilDiv(( ceilDiv(( 8 * 8 ), 1) ), 64) )}, iS82{64}, iS80{1} ];
T5_l[ iblockIdx.x40{( ceilDiv(( ( ceilDiv(( ( i1 * i2 ) * i4 ), 8) ) * ( ceilDiv(i3, 8) ) ), 1) )}, iUS41{1}, iS106{( ceilDiv(( ceilDiv(( 8 * 8 ), 1) ), 64) )}, ithreadIdx.x107{64}, iS105{1} ] ca_pos( 2 ) produce_pos( 2)
= T3_l[ iblockIdx.x67{( ceilDiv(( ( ceilDiv(( ( i1 * i2 ) * i4 ), 8) ) * ( ceilDiv(i3, 8) ) ), 1) )}, iUS68{1}, iUR111{( ceilDiv(( ceilDiv(( 8 * 8 ), 1) ), 64) )}, ithreadIdx.x112{64}, iS110{1} ] ca_pos( 2 )
+ T4_s[ iblockIdx.x49{( ceilDiv(( ( ceilDiv(( ( i5 * i6 ) * 1 ), 8) ) * ( ceilDiv(i7, 8) ) ), 1) )}, iUS50{1}, iUR86{( ceilDiv(( ceilDiv(( 8 * 8 ), 1) ), 64) )}, ithreadIdx.x87{64}, iS85{1} ] ca_pos( 2 );
T2_g[ iblockIdx.x31{( ceilDiv(( ( ceilDiv(( ( i1 * i2 ) * i4 ), 8) ) * ( ceilDiv(i3, 8) ) ), 1) )}, iUS32{1}, iUR101{( ceilDiv(( ceilDiv(( 8 * 8 ), 1) ), 64) )}, ithreadIdx.x102{64}, iS100{1} ] ca_pos( 2 ) produce_pos( 2)
= T5_l[ iblockIdx.x40{( ceilDiv(( ( ceilDiv(( ( i1 * i2 ) * i4 ), 8) ) * ( ceilDiv(i3, 8) ) ), 1) )}, iUS41{1}, iS106{( ceilDiv(( ceilDiv(( 8 * 8 ), 1) ), 64) )}, ithreadIdx.x107{64}, iS105{1} ] ca_pos( 2 ) produce_pos( 2);
TransformPrinter :
T0_g[ iS76{( ceilDiv(( ( ceilDiv(( ( i1 * i2 ) * i4 ), 8) ) * ( ceilDiv(i3, 8) ) ), 1) )}, iS77{1}, iS116{( ceilDiv(( ceilDiv(( 8 * 8 ), 1) ), 64) )}, iS117{64}, iS115{1} ]
root domain : (iS0{i1},iS1{i2},iS2{i3},iS3{i4})
Merge: iS0{i1} and iS1{i2} -> iS69{( i1 * i2 )}
Merge: iS69{( i1 * i2 )} and iS3{i4} -> iS70{( ( i1 * i2 ) * i4 )}
Split: iS70{( ( i1 * i2 ) * i4 )} by factor 8 -> iS71{( ceilDiv(( ( i1 * i2 ) * i4 ), 8) )}, iS72{8}, start offset: 0, stop offset: 0
Split: iS2{i3} by factor 8 -> iS73{( ceilDiv(i3, 8) )}, iS74{8}, start offset: 0, stop offset: 0
Merge: iS71{( ceilDiv(( ( i1 * i2 ) * i4 ), 8) )} and iS73{( ceilDiv(i3, 8) )} -> iS75{( ( ceilDiv(( ( i1 * i2 ) * i4 ), 8) ) * ( ceilDiv(i3, 8) ) )}
Split: iS75{( ( ceilDiv(( ( i1 * i2 ) * i4 ), 8) ) * ( ceilDiv(i3, 8) ) )} by factor 1 -> iS76{( ceilDiv(( ( ceilDiv(( ( i1 * i2 ) * i4 ), 8) ) * ( ceilDiv(i3, 8) ) ), 1) )}, iS77{1}, start offset: 0, stop offset: 0
Merge: iS74{8} and iS72{8} -> iS113{( 8 * 8 )}
Split: iS113{( 8 * 8 )} by factor 1 -> iS114{( ceilDiv(( 8 * 8 ), 1) )}, iS115{1}, start offset: 0, stop offset: 0
Split: iS114{( ceilDiv(( 8 * 8 ), 1) )} by factor 64 -> iS116{( ceilDiv(( ceilDiv(( 8 * 8 ), 1) ), 64) )}, iS117{64}, start offset: 0, stop offset: 0
T3_l[ iblockIdx.x67{( ceilDiv(( ( ceilDiv(( ( i1 * i2 ) * i4 ), 8) ) * ( ceilDiv(i3, 8) ) ), 1) )}, iUS68{1}, iUR111{( ceilDiv(( ceilDiv(( 8 * 8 ), 1) ), 64) )}, ithreadIdx.x112{64}, iS110{1} ] ca_pos( 2 )
root domain : (iS12{i1},iS13{i2},iS14{i3},iS15{i4})
Merge: iS12{i1} and iS13{i2} -> iS60{( i1 * i2 )}
Merge: iS60{( i1 * i2 )} and iS15{i4} -> iS61{( ( i1 * i2 ) * i4 )}
Split: iS61{( ( i1 * i2 ) * i4 )} by factor 8 -> iS62{( ceilDiv(( ( i1 * i2 ) * i4 ), 8) )}, iS63{8}, start offset: 0, stop offset: 0
Split: iS14{i3} by factor 8 -> iS64{( ceilDiv(i3, 8) )}, iS65{8}, start offset: 0, stop offset: 0
Merge: iS62{( ceilDiv(( ( i1 * i2 ) * i4 ), 8) )} and iS64{( ceilDiv(i3, 8) )} -> iS66{( ( ceilDiv(( ( i1 * i2 ) * i4 ), 8) ) * ( ceilDiv(i3, 8) ) )}
Split: iS66{( ( ceilDiv(( ( i1 * i2 ) * i4 ), 8) ) * ( ceilDiv(i3, 8) ) )} by factor 1 -> iblockIdx.x67{( ceilDiv(( ( ceilDiv(( ( i1 * i2 ) * i4 ), 8) ) * ( ceilDiv(i3, 8) ) ), 1) )}, iUS68{1}, start offset: 0, stop offset: 0
Merge: iS65{8} and iS63{8} -> iS108{( 8 * 8 )}
Split: iS108{( 8 * 8 )} by factor 1 -> iS109{( ceilDiv(( 8 * 8 ), 1) )}, iS110{1}, start offset: 0, stop offset: 0
Split: iS109{( ceilDiv(( 8 * 8 ), 1) )} by factor 64 -> iUR111{( ceilDiv(( ceilDiv(( 8 * 8 ), 1) ), 64) )}, ithreadIdx.x112{64}, start offset: 0, stop offset: 0
T1_g[ iS58{( ceilDiv(( ( ceilDiv(( ( i5 * i6 ) * 1 ), 8) ) * ( ceilDiv(i7, 8) ) ), 1) )}, iS59{1}, iS81{( ceilDiv(( ceilDiv(( 8 * 8 ), 1) ), 64) )}, iS82{64}, iS80{1} ]
root domain : (iS4{i5},iS5{i6},iS6{i7},bS7{1})
Merge: iS4{i5} and iS5{i6} -> iS51{( i5 * i6 )}
Merge: iS51{( i5 * i6 )} and bS7{1} -> iS52{( ( i5 * i6 ) * 1 )}
Split: iS52{( ( i5 * i6 ) * 1 )} by factor 8 -> iS53{( ceilDiv(( ( i5 * i6 ) * 1 ), 8) )}, iS54{8}, start offset: 0, stop offset: 0
Split: iS6{i7} by factor 8 -> iS55{( ceilDiv(i7, 8) )}, iS56{8}, start offset: 0, stop offset: 0
Merge: iS53{( ceilDiv(( ( i5 * i6 ) * 1 ), 8) )} and iS55{( ceilDiv(i7, 8) )} -> iS57{( ( ceilDiv(( ( i5 * i6 ) * 1 ), 8) ) * ( ceilDiv(i7, 8) ) )}
Split: iS57{( ( ceilDiv(( ( i5 * i6 ) * 1 ), 8) ) * ( ceilDiv(i7, 8) ) )} by factor 1 -> iS58{( ceilDiv(( ( ceilDiv(( ( i5 * i6 ) * 1 ), 8) ) * ( ceilDiv(i7, 8) ) ), 1) )}, iS59{1}, start offset: 0, stop offset: 0
Merge: iS54{8} and iS56{8} -> iS78{( 8 * 8 )}
Split: iS78{( 8 * 8 )} by factor 1 -> iS79{( ceilDiv(( 8 * 8 ), 1) )}, iS80{1}, start offset: 0, stop offset: 0
Split: iS79{( ceilDiv(( 8 * 8 ), 1) )} by factor 64 -> iS81{( ceilDiv(( ceilDiv(( 8 * 8 ), 1) ), 64) )}, iS82{64}, start offset: 0, stop offset: 0
T4_s[ iblockIdx.x49{( ceilDiv(( ( ceilDiv(( ( i5 * i6 ) * 1 ), 8) ) * ( ceilDiv(i7, 8) ) ), 1) )}, iUS50{1}, iUR86{( ceilDiv(( ceilDiv(( 8 * 8 ), 1) ), 64) )}, ithreadIdx.x87{64}, iS85{1} ] ca_pos( 2 )
root domain : (iS16{i5},iS17{i6},iS18{i7},bS19{1})
Merge: iS16{i5} and iS17{i6} -> iS42{( i5 * i6 )}
Merge: iS42{( i5 * i6 )} and bS19{1} -> iS43{( ( i5 * i6 ) * 1 )}
Split: iS43{( ( i5 * i6 ) * 1 )} by factor 8 -> iS44{( ceilDiv(( ( i5 * i6 ) * 1 ), 8) )}, iS45{8}, start offset: 0, stop offset: 0
Split: iS18{i7} by factor 8 -> iS46{( ceilDiv(i7, 8) )}, iS47{8}, start offset: 0, stop offset: 0
Merge: iS44{( ceilDiv(( ( i5 * i6 ) * 1 ), 8) )} and iS46{( ceilDiv(i7, 8) )} -> iS48{( ( ceilDiv(( ( i5 * i6 ) * 1 ), 8) ) * ( ceilDiv(i7, 8) ) )}
Split: iS48{( ( ceilDiv(( ( i5 * i6 ) * 1 ), 8) ) * ( ceilDiv(i7, 8) ) )} by factor 1 -> iblockIdx.x49{( ceilDiv(( ( ceilDiv(( ( i5 * i6 ) * 1 ), 8) ) * ( ceilDiv(i7, 8) ) ), 1) )}, iUS50{1}, start offset: 0, stop offset: 0
Merge: iS45{8} and iS47{8} -> iS83{( 8 * 8 )}
Split: iS83{( 8 * 8 )} by factor 1 -> iS84{( ceilDiv(( 8 * 8 ), 1) )}, iS85{1}, start offset: 0, stop offset: 0
Split: iS84{( ceilDiv(( 8 * 8 ), 1) )} by factor 64 -> iUR86{( ceilDiv(( ceilDiv(( 8 * 8 ), 1) ), 64) )}, ithreadIdx.x87{64}, start offset: 0, stop offset: 0
T5_l[ iblockIdx.x40{( ceilDiv(( ( ceilDiv(( ( i1 * i2 ) * i4 ), 8) ) * ( ceilDiv(i3, 8) ) ), 1) )}, iUS41{1}, iS106{( ceilDiv(( ceilDiv(( 8 * 8 ), 1) ), 64) )}, ithreadIdx.x107{64}, iS105{1} ] ca_pos( 2 ) produce_pos( 2)
root domain : (iS8{i1},iS9{i2},iS10{i3},iS11{i4})
Merge: iS8{i1} and iS9{i2} -> iS33{( i1 * i2 )}
Merge: iS33{( i1 * i2 )} and iS11{i4} -> iS34{( ( i1 * i2 ) * i4 )}
Split: iS34{( ( i1 * i2 ) * i4 )} by factor 8 -> iS35{( ceilDiv(( ( i1 * i2 ) * i4 ), 8) )}, iS36{8}, start offset: 0, stop offset: 0
Split: iS10{i3} by factor 8 -> iS37{( ceilDiv(i3, 8) )}, iS38{8}, start offset: 0, stop offset: 0
Merge: iS35{( ceilDiv(( ( i1 * i2 ) * i4 ), 8) )} and iS37{( ceilDiv(i3, 8) )} -> iS39{( ( ceilDiv(( ( i1 * i2 ) * i4 ), 8) ) * ( ceilDiv(i3, 8) ) )}
Split: iS39{( ( ceilDiv(( ( i1 * i2 ) * i4 ), 8) ) * ( ceilDiv(i3, 8) ) )} by factor 1 -> iblockIdx.x40{( ceilDiv(( ( ceilDiv(( ( i1 * i2 ) * i4 ), 8) ) * ( ceilDiv(i3, 8) ) ), 1) )}, iUS41{1}, start offset: 0, stop offset: 0
Merge: iS38{8} and iS36{8} -> iS103{( 8 * 8 )}
Split: iS103{( 8 * 8 )} by factor 1 -> iS104{( ceilDiv(( 8 * 8 ), 1) )}, iS105{1}, start offset: 0, stop offset: 0
Split: iS104{( ceilDiv(( 8 * 8 ), 1) )} by factor 64 -> iS106{( ceilDiv(( ceilDiv(( 8 * 8 ), 1) ), 64) )}, ithreadIdx.x107{64}, start offset: 0, stop offset: 0
T2_g[ iblockIdx.x31{( ceilDiv(( ( ceilDiv(( ( i1 * i2 ) * i4 ), 8) ) * ( ceilDiv(i3, 8) ) ), 1) )}, iUS32{1}, iUR101{( ceilDiv(( ceilDiv(( 8 * 8 ), 1) ), 64) )}, ithreadIdx.x102{64}, iS100{1} ] ca_pos( 2 ) produce_pos( 2)
root domain : (iS20{i1},iS21{i2},iS22{i3},iS23{i4})
Merge: iS20{i1} and iS21{i2} -> iS24{( i1 * i2 )}
Merge: iS24{( i1 * i2 )} and iS23{i4} -> iS25{( ( i1 * i2 ) * i4 )}
Split: iS25{( ( i1 * i2 ) * i4 )} by factor 8 -> iS26{( ceilDiv(( ( i1 * i2 ) * i4 ), 8) )}, iS27{8}, start offset: 0, stop offset: 0
Split: iS22{i3} by factor 8 -> iS28{( ceilDiv(i3, 8) )}, iS29{8}, start offset: 0, stop offset: 0
Merge: iS26{( ceilDiv(( ( i1 * i2 ) * i4 ), 8) )} and iS28{( ceilDiv(i3, 8) )} -> iS30{( ( ceilDiv(( ( i1 * i2 ) * i4 ), 8) ) * ( ceilDiv(i3, 8) ) )}
Split: iS30{( ( ceilDiv(( ( i1 * i2 ) * i4 ), 8) ) * ( ceilDiv(i3, 8) ) )} by factor 1 -> iblockIdx.x31{( ceilDiv(( ( ceilDiv(( ( i1 * i2 ) * i4 ), 8) ) * ( ceilDiv(i3, 8) ) ), 1) )}, iUS32{1}, start offset: 0, stop offset: 0
Merge: iS29{8} and iS27{8} -> iS98{( 8 * 8 )}
Split: iS98{( 8 * 8 )} by factor 1 -> iS99{( ceilDiv(( 8 * 8 ), 1) )}, iS100{1}, start offset: 0, stop offset: 0
Split: iS99{( ceilDiv(( 8 * 8 ), 1) )} by factor 64 -> iUR101{( ceilDiv(( ceilDiv(( 8 * 8 ), 1) ), 64) )}, ithreadIdx.x102{64}, start offset: 0, stop offset: 0
}
C++ exception with description "root_ind != nullptr INTERNAL ASSERT FAILED at "/home/gaoxiang/nvfuser3/torch/csrc/jit/codegen/cuda/index_compute.cpp":1503, please report a bug to PyTorch. Couldn't find root mapping for T1_g[ iS58{( ceilDiv(( ( ceilDiv(( ( T0.size[0] * T0.size[1] ) * 1 ), 8) ) * ( ceilDiv(T0.size[2], 8) ) ), 1) )}, iS59{1}, iS81{( ceilDiv(( ceilDiv(( 8 * 8 ), 1) ), 64) )}, iS82{64}, iS80{1} ] dim: 1 id: iS138{T0.size[1]}
Exception raised from getGlobalProducerStridedIndices at /home/gaoxiang/nvfuser3/torch/csrc/jit/codegen/cuda/index_compute.cpp:1503 (most recent call first):
frame #0: <unknown function> + 0x87810 (0x7efd5cefc810 in /home/gaoxiang/nvfuser3/build/lib/libc10.so)
frame #1: <unknown function> + 0x877a0 (0x7efd5cefc7a0 in /home/gaoxiang/nvfuser3/build/lib/libc10.so)
frame #2: <unknown function> + 0x876a0 (0x7efd5cefc6a0 in /home/gaoxiang/nvfuser3/build/lib/libc10.so)
frame #3: <unknown function> + 0x89908 (0x7efd5cefe908 in /home/gaoxiang/nvfuser3/build/lib/libc10.so)
frame #4: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x65 (0x7efd5cefcf25 in /home/gaoxiang/nvfuser3/build/lib/libc10.so)
frame #5: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x7a (0x7efd5cefaa5a in /home/gaoxiang/nvfuser3/build/lib/libc10.so)
frame #6: c10::detail::torchInternalAssertFail(char const*, char const*, unsigned int, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x5d (0x7efd5cefaccd in /home/gaoxiang/nvfuser3/build/lib/libc10.so)
frame #7: <unknown function> + 0x6446769 (0x7efd72fa4769 in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #8: <unknown function> + 0x644d738 (0x7efd72fab738 in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #9: <unknown function> + 0x644d92a (0x7efd72fab92a in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #10: torch::jit::fuser::cuda::IndexLowering::lowerSrcIndex(torch::jit::fuser::cuda::Val*, torch::jit::fuser::cuda::Val*) const + 0xd1 (0x7efd730fc591 in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #11: torch::jit::fuser::cuda::IndexLowering::handle(torch::jit::fuser::cuda::UnaryOp const*) + 0x4b (0x7efd730fa18b in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #12: void torch::jit::fuser::cuda::Expr::constDispatch<torch::jit::fuser::cuda::OptOutConstDispatch*>(torch::jit::fuser::cuda::OptOutConstDispatch*, torch::jit::fuser::cuda::Expr const*) + 0x88 (0x7efd72e6ef98 in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #13: torch::jit::fuser::cuda::OptOutConstDispatch::handle(torch::jit::fuser::cuda::Expr const*) + 0x1d (0x7efd72e6627d in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #14: torch::jit::fuser::cuda::IndexLowering::handle(torch::jit::fuser::cuda::kir::IfThenElse const*) + 0xd1 (0x7efd730fc3f1 in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #15: void torch::jit::fuser::cuda::Expr::constDispatch<torch::jit::fuser::cuda::OptOutConstDispatch*>(torch::jit::fuser::cuda::OptOutConstDispatch*, torch::jit::fuser::cuda::Expr const*) + 0x5a8 (0x7efd72e6f4b8 in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #16: torch::jit::fuser::cuda::OptOutConstDispatch::handle(torch::jit::fuser::cuda::Expr const*) + 0x1d (0x7efd72e6627d in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #17: torch::jit::fuser::cuda::IndexLowering::handle(torch::jit::fuser::cuda::kir::ForLoop const*) + 0xcb (0x7efd730fc2cb in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #18: void torch::jit::fuser::cuda::Expr::constDispatch<torch::jit::fuser::cuda::OptOutConstDispatch*>(torch::jit::fuser::cuda::OptOutConstDispatch*, torch::jit::fuser::cuda::Expr const*) + 0x577 (0x7efd72e6f487 in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #19: torch::jit::fuser::cuda::OptOutConstDispatch::handle(torch::jit::fuser::cuda::Expr const*) + 0x1d (0x7efd72e6627d in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #20: torch::jit::fuser::cuda::IndexLowering::handle(torch::jit::fuser::cuda::kir::ForLoop const*) + 0xcb (0x7efd730fc2cb in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #21: void torch::jit::fuser::cuda::Expr::constDispatch<torch::jit::fuser::cuda::OptOutConstDispatch*>(torch::jit::fuser::cuda::OptOutConstDispatch*, torch::jit::fuser::cuda::Expr const*) + 0x577 (0x7efd72e6f487 in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #22: torch::jit::fuser::cuda::OptOutConstDispatch::handle(torch::jit::fuser::cuda::Expr const*) + 0x1d (0x7efd72e6627d in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #23: torch::jit::fuser::cuda::IndexLowering::handle(torch::jit::fuser::cuda::kir::ForLoop const*) + 0xcb (0x7efd730fc2cb in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #24: void torch::jit::fuser::cuda::Expr::constDispatch<torch::jit::fuser::cuda::OptOutConstDispatch*>(torch::jit::fuser::cuda::OptOutConstDispatch*, torch::jit::fuser::cuda::Expr const*) + 0x577 (0x7efd72e6f487 in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #25: torch::jit::fuser::cuda::OptOutConstDispatch::handle(torch::jit::fuser::cuda::Expr const*) + 0x1d (0x7efd72e6627d in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #26: torch::jit::fuser::cuda::IndexLowering::handle(torch::jit::fuser::cuda::kir::ForLoop const*) + 0xcb (0x7efd730fc2cb in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #27: void torch::jit::fuser::cuda::Expr::constDispatch<torch::jit::fuser::cuda::OptOutConstDispatch*>(torch::jit::fuser::cuda::OptOutConstDispatch*, torch::jit::fuser::cuda::Expr const*) + 0x577 (0x7efd72e6f487 in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #28: torch::jit::fuser::cuda::OptOutConstDispatch::handle(torch::jit::fuser::cuda::Expr const*) + 0x1d (0x7efd72e6627d in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #29: torch::jit::fuser::cuda::IndexLowering::handle(torch::jit::fuser::cuda::kir::IfThenElse const*) + 0xd1 (0x7efd730fc3f1 in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #30: void torch::jit::fuser::cuda::Expr::constDispatch<torch::jit::fuser::cuda::OptOutConstDispatch*>(torch::jit::fuser::cuda::OptOutConstDispatch*, torch::jit::fuser::cuda::Expr const*) + 0x5a8 (0x7efd72e6f4b8 in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #31: torch::jit::fuser::cuda::OptOutConstDispatch::handle(torch::jit::fuser::cuda::Expr const*) + 0x1d (0x7efd72e6627d in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #32: torch::jit::fuser::cuda::IndexLowering::handle(torch::jit::fuser::cuda::kir::ForLoop const*) + 0xcb (0x7efd730fc2cb in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #33: void torch::jit::fuser::cuda::Expr::constDispatch<torch::jit::fuser::cuda::OptOutConstDispatch*>(torch::jit::fuser::cuda::OptOutConstDispatch*, torch::jit::fuser::cuda::Expr const*) + 0x577 (0x7efd72e6f487 in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #34: torch::jit::fuser::cuda::OptOutConstDispatch::handle(torch::jit::fuser::cuda::Expr const*) + 0x1d (0x7efd72e6627d in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #35: torch::jit::fuser::cuda::IndexLowering::generate(std::vector<torch::jit::fuser::cuda::Expr*, std::allocator<torch::jit::fuser::cuda::Expr*> > const&) + 0x7e (0x7efd7310179e in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #36: <unknown function> + 0x663266d (0x7efd7319066d in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #37: torch::jit::fuser::cuda::GpuLower::lower(torch::jit::fuser::cuda::Fusion*, torch::jit::fuser::cuda::DataType) + 0x8ef (0x7efd7318f2af in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #38: <unknown function> + 0x6326a94 (0x7efd72e84a94 in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #39: <unknown function> + 0x6334220 (0x7efd72e92220 in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #40: torch::jit::fuser::cuda::FusionExecutor::compileFusion(torch::jit::fuser::cuda::Fusion*, c10::ArrayRef<c10::IValue> const&, torch::jit::fuser::cuda::LaunchParams const&, torch::jit::fuser::cuda::CompileOptions) + 0x3e4 (0x7efd72e79ba4 in /home/gaoxiang/nvfuser3/build/lib/libtorch_cuda.so)
frame #41: torch::jit::NVFuserTest_FusionScheduleTransposeRepro1_CUDA_Test::TestBody() + 0x573 (0x562c6cc3fe53 in ./build/bin/test_jit)
frame #42: void testing::internal::HandleSehExceptionsInMethodIfSupported<testing::Test, void>(testing::Test*, void (testing::Test::*)(), char const*) + 0x7b (0x562c6cca2bab in ./build/bin/test_jit)
frame #43: void testing::internal::HandleExceptionsInMethodIfSupported<testing::Test, void>(testing::Test*, void (testing::Test::*)(), char const*) + 0x7d (0x562c6cc8b8dd in ./build/bin/test_jit)
frame #44: testing::Test::Run() + 0xc3 (0x562c6cc63d83 in ./build/bin/test_jit)
frame #45: testing::TestInfo::Run() + 0x106 (0x562c6cc64c16 in ./build/bin/test_jit)
frame #46: testing::TestSuite::Run() + 0x111 (0x562c6cc654c1 in ./build/bin/test_jit)
frame #47: testing::internal::UnitTestImpl::RunAllTests() + 0x45b (0x562c6cc7768b in ./build/bin/test_jit)
frame #48: bool testing::internal::HandleSehExceptionsInMethodIfSupported<testing::internal::UnitTestImpl, bool>(testing::internal::UnitTestImpl*, bool (testing::internal::UnitTestImpl::*)(), char const*) + 0x7b (0x562c6cca561b in ./build/bin/test_jit)
frame #49: bool testing::internal::HandleExceptionsInMethodIfSupported<testing::internal::UnitTestImpl, bool>(testing::internal::UnitTestImpl*, bool (testing::internal::UnitTestImpl::*)(), char const*) + 0x83 (0x562c6cc8e083 in ./build/bin/test_jit)
frame #50: testing::UnitTest::Run() + 0xd5 (0x562c6cc771c5 in ./build/bin/test_jit)
frame #51: <unknown function> + 0x3b9531 (0x562c6c43c531 in ./build/bin/test_jit)
frame #52: main + 0x226 (0x562c6c43c4d6 in ./build/bin/test_jit)
frame #53: <unknown function> + 0x232d0 (0x7efd5ca7a2d0 in /usr/lib/libc.so.6)
frame #54: __libc_start_main + 0x8a (0x7efd5ca7a38a in /usr/lib/libc.so.6)
frame #55: _start + 0x25 (0x562c6c43c105 in ./build/bin/test_jit)
" thrown in the test body.
[ FAILED ] NVFuserTest.FusionScheduleTransposeRepro1_CUDA (1107 ms)
Versions
Top of tree (TOT) of the `devel` branch.
Metadata
Metadata
Assignees
Labels
No labels