Invalid code generation #112

@naoyam

🐛 Bug

The testGPU_FusionSoftmax test in this branch generates invalid code:

https://github.com/naoyam/pytorch/tree/codegen-bug

Here's the generated kernel:

__global__ void CUDAGeneratedKernel(Tensor<float, 3> T0, Tensor<float, 3> T9){
  __shared__ float shared_mem[1024];
  float T6[1];
  float T5[1];
  T5[ 0 ]
     = float(0);
  float T11[1];
  if ( ( ( ( 0 * 32 ) + threadIdx.x ) < T9.size[2] ) ) {
    T11[ 0 ]
       = float(0);
  }
  float T2[1];
  float T1[1];
  T1[ 0 ]
     = float(0);
  float T10[1];
  if ( ( ( ( 0 * 32 ) + threadIdx.x ) < T9.size[2] ) ) {
    T10[ 0 ]
       = float(0);
  }
  for(size_t i84 = 0; i84 < ( ceilDiv(T9.size[2], 32) ); ++i84 ) {
    if ( ( ( ( i84 * 32 ) + threadIdx.x ) < T9.size[2] ) ) {
      T10[ 0 ]
         = fmaxf(T10[ 0 ]
        , T0[ ( blockIdx.x * T0.stride[0] ) + ( blockIdx.y * T0.stride[1] ) + ( ( ( i84 * 32 ) + threadIdx.x ) * T0.stride[2] ) ]);
    }
  }
  blockReduce< true, false, false > ( T1[ 0 ], T10[ 0 ], reduction_fmaxf_float, threadIdx, blockDim, reinterpret_cast<float*>(shared_mem));
  if ( ( threadIdx.x == 0 ) ) {
    T2[ 0 ]
       = T1[ 0 ];
  }
  for(size_t i87 = 0; i87 < ( ceilDiv(T9.size[2], 32) ); ++i87 ) {
    float T7[1];
    if ( ( ( ( i87 * 32 ) + threadIdx.x ) < T9.size[2] ) ) {
      T7[ 0 ]
         = T0[ ( blockIdx.x * T0.stride[0] ) + ( blockIdx.y * T0.stride[1] ) + ( ( ( i87 * 32 ) + threadIdx.x ) * T0.stride[2] ) ]
         - T2[ 0 ];
    }
    float T8[1];
    if ( ( ( ( i87 * 32 ) + threadIdx.x ) < T9.size[2] ) ) {
      T8[ 0 ]
         = expf(T7[ 0 ]);
    }
    float T3[1];
    if ( ( ( ( i87 * 32 ) + threadIdx.x ) < T9.size[2] ) ) {
      T3[ 0 ]
         = T0[ ( blockIdx.x * T0.stride[0] ) + ( blockIdx.y * T0.stride[1] ) + ( ( ( i87 * 32 ) + threadIdx.x ) * T0.stride[2] ) ]
         - T2[ 0 ];
    }
    float T4[1];
    if ( ( ( ( i87 * 32 ) + threadIdx.x ) < T9.size[2] ) ) {
      T4[ 0 ]
         = expf(T3[ 0 ]);
    }
  }
  for(size_t i93 = 0; i93 < ( ceilDiv(T9.size[2], 32) ); ++i93 ) {
    if ( ( ( ( i93 * 32 ) + threadIdx.x ) < T9.size[2] ) ) {
      T11[ 0 ]
         = T11[ 0 ]
         + T4[ 0 ];
    }
  }
  blockReduce< true, false, false > ( T5[ 0 ], T11[ 0 ], reduction_add_float, threadIdx, blockDim, reinterpret_cast<float*>(shared_mem));
  if ( ( threadIdx.x == 0 ) ) {
    T6[ 0 ]
       = T5[ 0 ];
  }
  for(size_t i96 = 0; i96 < ( ceilDiv(T9.size[2], 32) ); ++i96 ) {
    if ( ( ( ( i96 * 32 ) + threadIdx.x ) < T9.size[2] ) ) {
      T9[ ( blockIdx.x * T9.stride[0] ) + ( blockIdx.y * T9.stride[1] ) + ( ( ( i96 * 32 ) + threadIdx.x ) * T9.stride[2] ) ]
         = T8[ 0 ]
         / T6[ 0 ];
    }
  }
}

T8 is read in the final loop nest, but it is declared and computed inside an earlier loop nest, so the read refers to a variable that is out of scope and the kernel cannot compile. (T4 has the same problem: it is declared in the second loop nest but read in the third.)
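
Reduced to plain C++ scoping rules, the failure mode looks like the following hand-written sketch. This is not fuser output; the kernel name and tmp are made up, with tmp standing in for T8:

__global__ void scope_bug(float* out, const float* in, int n) {
  for (int i = 0; i < n; ++i) {
    float tmp[1];           // declared inside the first loop body, like T8
    tmp[0] = expf(in[i]);   // plays the role of T8 = expf(T7)
  }
  for (int i = 0; i < n; ++i) {
    out[i] = tmp[0];        // error: 'tmp' was not declared in this scope
  }
}

Presumably the fix is either to sink the computation of T8 into the consuming loop nest or to hoist its allocation and size it for the full loop extent; keeping the single-element buffer in an outer scope would still only retain the last iteration's value.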

The final loop nest corresponds to this expression:

TensorView* output_tv6 = div(exp_tv4_2, bcast_sum_tv6);

Interestingly, swapping the two operands, i.e., div(bcast_sum_tv6, exp_tv4_2), seems to be fine: the fuser generates code as expected (see https://github.com/naoyam/pytorch/blob/d2913c9fb2c7e94c786ce3b73a6b5ab87b4eee3b/test/cpp/jit/test_gpu.cpp#L2798), although the two expressions are of course no longer the same computation.
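
For concreteness, here is how the two variants differ, written out as a fragment (the identifiers are taken from the issue; the surrounding fusion definition is omitted, and swapped_tv is a made-up name):

// As in the test: generates the invalid kernel above.
TensorView* output_tv6 = div(exp_tv4_2, bcast_sum_tv6);
// Operands swapped: generates valid code, but computes
// sum / exp(x - max) rather than exp(x - max) / sum.
TensorView* swapped_tv = div(bcast_sum_tv6, exp_tv4_2);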

Note that this issue was first encountered in the block broadcast PR (#100), but it reproduces on the dev branch as well.
