
Invalid code generation with computeAt #164

@naoyam


🐛 Bug

This test generates invalid code:
(https://github.com/naoyam/pytorch/blob/compute-at-bug/test/cpp/jit/test_gpu.cpp)

void testGPU_FusionComputeAtBug() {
  torch::jit::fuser::cuda::CudaKernel prog;
  prog.setFusionPtr(std::make_unique<Fusion>());
  Fusion* fusion = prog.fusion();
  FusionGuard fg(fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeDummyTensor(2);
  fusion->addInput(tv0);

  TensorView* tv1 = mul(tv0, new Float(1));
  TensorView* tv2 = add(tv0, new Float(2));
  TensorView* tv3 = add(tv1, new Float(3));
  TensorView* tv4 = add(tv1, new Float(4));

  fusion->addOutput(tv2);
  fusion->addOutput(tv3);
  fusion->addOutput(tv4);

  std::cout << "tv1->computeAt(tv3, -1)\n";
  tv1->computeAt(tv3, -1);
  fusion->printMath();
  fusion->printKernel();
  fusion->printMath();

  prog.setDevice(0);
  prog.grid(1);
  prog.block(1);
  torch::jit::fuser::cuda::compileKernel(&prog);
}

Here's the generated kernel:

__global__ void kernel(Tensor<float, 2> T0, Tensor<float, 2> T2, Tensor<float, 2> T3, Tensor<float, 2> T4){
  for(size_t i36 = 0; i36 < T4.size[0]; ++i36 ) {
    for(size_t i37 = 0; i37 < T4.size[1]; ++i37 ) {
      float T1[1];
      T1[ 0 ]
         = T0[ ( i36 * T0.stride[0] ) + ( i37 * T0.stride[1] ) ]
         * float(1);
      T4[ ( i36 * T4.stride[0] ) + ( i37 * T4.stride[1] ) ]
         = T1[ 0 ]
         + float(4);
    }
  }
  for(size_t i39 = 0; i39 < T4.size[0]; ++i39 ) {
    for(size_t i40 = 0; i40 < T4.size[1]; ++i40 ) {
      T2[ ( i39 * T2.stride[0] ) + ( i40 * T2.stride[1] ) ]
         = T0[ ( i39 * T0.stride[0] ) + ( i40 * T0.stride[1] ) ]
         + float(2);
    }
  }
  for(size_t i41 = 0; i41 < T4.size[0]; ++i41 ) {
    for(size_t i42 = 0; i42 < T4.size[1]; ++i42 ) {
      T3[ ( i41 * T3.stride[0] ) + ( i42 * T3.stride[1] ) ]
         = T1[ 0 ]
         + float(3);
    }
  }
}

Notice that the T3 loop nest reads T1, but T1 is declared only inside the first loop nest, so it is out of scope where T3 uses it. The T3 computation should actually be placed in the same loop nest as T4, but it is "blocked" by the T2 loop, which is emitted between them. We had a similar issue recently, which we fixed by sorting the inputs of expressions when traversing fusion expressions (see #112). This issue is similar, but it happens with output expressions.
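For reference, here is a minimal CPU sketch of the loop structure I would expect, with T3 and T4 sharing the loop nest in which T1 is live and T2 in its own nest. It is not fuser output: `Outputs` and `referenceFusion` are hypothetical names made up for illustration, and contiguous row-major storage is assumed instead of the `Tensor` strides used in the generated kernel.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical holder for the three fusion outputs (not part of the fuser API).
struct Outputs {
  std::vector<float> t2;
  std::vector<float> t3;
  std::vector<float> t4;
};

// CPU reference for the fusion in the test above.
// Assumption: t0 is a contiguous row-major rows x cols buffer.
Outputs referenceFusion(
    const std::vector<float>& t0,
    std::size_t rows,
    std::size_t cols) {
  assert(t0.size() == rows * cols);

  Outputs out;
  out.t2.resize(t0.size());
  out.t3.resize(t0.size());
  out.t4.resize(t0.size());

  // T1 is a per-element temporary, so both of its consumers (T3 and T4)
  // have to live in the loop nest where T1 is in scope.
  for (std::size_t i = 0; i < rows; ++i) {
    for (std::size_t j = 0; j < cols; ++j) {
      const std::size_t idx = i * cols + j;
      const float t1 = t0[idx] * 1.0f;
      out.t3[idx] = t1 + 3.0f;
      out.t4[idx] = t1 + 4.0f;
    }
  }

  // T2 depends only on T0, so it can keep a separate loop nest.
  for (std::size_t i = 0; i < rows; ++i) {
    for (std::size_t j = 0; j < cols; ++j) {
      const std::size_t idx = i * cols + j;
      out.t2[idx] = t0[idx] + 2.0f;
    }
  }

  return out;
}
```

Whether the T2 nest comes before or after the shared T3/T4 nest should not matter for correctness; the point is only that no other loop nest separates the two consumers of T1.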
