forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 7
Closed
Labels
Description
🐛 Bug
This test generates invalid code:
(https://github.com/naoyam/pytorch/blob/compute-at-bug/test/cpp/jit/test_gpu.cpp)
// Repro for a computeAt code-generation bug: tv1 is consumed by two outputs
// (tv3 and tv4), and computeAt'ing tv1 into tv3 produces a kernel where the
// tv3 loop reads the loop-local buffer T1 outside the loop nest that defines it.
void testGPU_FusionComputeAtBug() {
  torch::jit::fuser::cuda::CudaKernel prog;
  prog.setFusionPtr(std::make_unique<Fusion>());
  Fusion* fusion = prog.fusion();
  FusionGuard fg(fusion);
  // Set up your input tensor views
  TensorView* tv0 = makeDummyTensor(2);
  fusion->addInput(tv0);
  // tv1 feeds both tv3 and tv4; tv2 depends only on the input tv0 and acts
  // as the "blocking" output between the two tv1 consumers.
  TensorView* tv1 = mul(tv0, new Float(1));
  TensorView* tv2 = add(tv0, new Float(2));
  TensorView* tv3 = add(tv1, new Float(3));
  TensorView* tv4 = add(tv1, new Float(4));
  fusion->addOutput(tv2);
  fusion->addOutput(tv3);
  fusion->addOutput(tv4);
  // NOTE: the message previously said "computeAt(tv3, 1)" but the call below
  // uses -1; keep the trace consistent with what is actually executed.
  std::cout << "tv1->computeAt(tv3, -1)\n";
  tv1->computeAt(tv3, -1);
  fusion->printMath();
  fusion->printKernel();
  fusion->printMath();
  prog.setDevice(0);
  prog.grid(1);
  prog.block(1);
  torch::jit::fuser::cuda::compileKernel(&prog);
}
Here's the generated kernel:
// Generated (invalid) kernel, quoted verbatim to show the bug: the temporary
// T1 is declared inside the first loop nest, yet the third loop nest reads it
// after that scope has ended — this does not compile / is ill-formed C++.
__global__ void kernel(Tensor<float, 2> T0, Tensor<float, 2> T2, Tensor<float, 2> T3, Tensor<float, 2> T4){
for(size_t i36 = 0; i36 < T4.size[0]; ++i36 ) {
for(size_t i37 = 0; i37 < T4.size[1]; ++i37 ) {
float T1[1];
// T1 is local to this loop nest only.
T1[ 0 ]
= T0[ ( i36 * T0.stride[0] ) + ( i37 * T0.stride[1] ) ]
* float(1);
T4[ ( i36 * T4.stride[0] ) + ( i37 * T4.stride[1] ) ]
= T1[ 0 ]
+ float(4);
}
}
// The T2 output is independent of T1, but its placement here separates the
// two consumers of T1 into different loop nests.
for(size_t i39 = 0; i39 < T4.size[0]; ++i39 ) {
for(size_t i40 = 0; i40 < T4.size[1]; ++i40 ) {
T2[ ( i39 * T2.stride[0] ) + ( i40 * T2.stride[1] ) ]
= T0[ ( i39 * T0.stride[0] ) + ( i40 * T0.stride[1] ) ]
+ float(2);
}
}
for(size_t i41 = 0; i41 < T4.size[0]; ++i41 ) {
for(size_t i42 = 0; i42 < T4.size[1]; ++i42 ) {
T3[ ( i41 * T3.stride[0] ) + ( i42 * T3.stride[1] ) ]
= T1[ 0 ]  // BUG: T1 is out of scope here — it was defined in the first loop nest.
+ float(3);
}
}
}
Notice that the loop for `T3` references `T1`, but `T1` is only defined inside the first loop nest. The `T3` loop should actually be placed in the same loop nest as the `T4` loop, but it is "blocked" by the `T2` loop sitting between them. We had a similar issue recently, which we fixed by sorting the inputs of expressions when traversing fusion expressions (see #112). This issue is analogous, but occurs with output expressions instead of inputs.