Skip to content

[Schedule Validation] Schedule is allowed to bind a thread to two different dim values #108

@kevinstephano

Description

@kevinstephano

🐛 Bug

I found a situation where a schedule was allowed to generate a kernel where TIDx was bound to two different extents. An error should have been generated. The question is how to check for this?

The scenario where I made this mistake was when I did an rFactor on a reduction dimension split and didn't apply the corresponding split to an operation found after the reduction.

To Reproduce

IR View of Operations

T5[ iblockIdx.x{gridDim.x}, rS{( ceilDiv(i3, 32) )}rf, ithreadIdx.x{32}rf ] = reduction( T0[ iS{i1}, iS{i3} ], op = add, initial value = float(0) )                                                                                                                  
T2[ iblockIdx.x{gridDim.x}, rthreadIdx.x{32} ] = reduction( T5[ iblockIdx.x{gridDim.x}, rS{( ceilDiv(i3, 32) )}rf, ithreadIdx.x{32}rf ], op = add, initial value = float(0) )                                                                                        
T3[ iblockIdx.x{gridDim.x}, bthreadIdx.x{1} ]
   = T2[ iblockIdx.x{gridDim.x}, rthreadIdx.x{32} ];
T4[ iblockIdx.x{gridDim.x}, ithreadIdx.x{blockDim.x} ]                                                           
   = T3[ iblockIdx.x{gridDim.x}, bthreadIdx.x{1} ]                                                               
   + T1[ iS{i5}, iS{i7} ];

Kernel Generated

// Binary combine function passed to blockReduce: accumulates b into a.
__device__ void reduction_add_float(float& a, const float b) {
  a += b;
}
// Generated kernel demonstrating the reported bug: threadIdx.x (TIDx) is used
// with two different extents in the same kernel — extent 32 for the rFactored
// reduction of T5/T2, but blockDim.x for the final T4 store — and lowering
// emitted no error.
__global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T1, Tensor<float, 2> T4){
  __shared__ float shared_mem[1024];
  // NOTE: T3 and T2 are per-thread locals (registers), not shared — each
  // thread has its own private copy.
  float T3[1];
  float T2[1];
  T2[ 0 ]
     = float(0);
  float T5[1];
  // Predicate for the i30 == 0 iteration; T5 is left uninitialized in threads
  // where the predicate is false (those threads also skip the accumulation
  // loop body below, but still pass T5[0] to blockReduce — NOTE(review):
  // whether blockReduce reads unparticipating lanes depends on its template
  // flags; can't tell from here).
  if ( ( ( ( 0 * 32 ) + threadIdx.x ) < T0.size[1] ) ) {
    T5[ 0 ]
       = float(0);
  }
  // Per-thread partial sum over the rFactored serial dimension
  // rS{ceilDiv(i3, 32)}; threadIdx.x covers the inner split of extent 32.
  for(size_t i30 = 0; i30 < ( ceilDiv(T0.size[1], 32) ); ++i30 ) {
    if ( ( ( ( i30 * 32 ) + threadIdx.x ) < T0.size[1] ) ) {
      T5[ 0 ]
         = T5[ 0 ]
         + T0[ ( blockIdx.x * T0.stride[0] ) + ( ( ( i30 * 32 ) + threadIdx.x ) * T0.stride[1] ) ];
    }
  }
  // Block-wide reduction across threadIdx.x; result lands in thread 0's T2.
  blockReduce< true, false, false > ( T2[ 0 ], T5[ 0 ], reduction_add_float, threadIdx, blockDim, reinterpret_cast<float*>(shared_mem));
  // BUG MANIFESTATION: only threadIdx.x == 0 writes its private T3, matching
  // the bthreadIdx.x{1} binding of T3 in the IR...
  if ( ( threadIdx.x == 0 ) ) {
    T3[ 0 ]
       = T2[ 0 ];
  }
  // ...but the T4 expression below runs for ALL threadIdx.x (bound to
  // blockDim.x here), so every thread other than 0 reads its own
  // UNINITIALIZED local T3[0]. There is no broadcast of T2/T3 across the
  // block, which is the consequence of binding TIDx to two different
  // extents/values without a validation error.
  T4[ ( blockIdx.x * T4.stride[0] ) + ( threadIdx.x * T4.stride[1] ) ]
     = T3[ 0 ]
     + T1[ ( blockIdx.x * T1.stride[0] ) + ( threadIdx.x * T1.stride[1] ) ];
}

Test

// Repro for issue #108: a schedule that binds TIDx to two different extents
// (32 via the rFactored reduction split, blockDim.x via the unsplit consumer
// tv4) is lowered without any validation error. The exact order of the
// schedule transformations below is what triggers the bug — do not reorder.
void testGPU_FusionThreadBindingError() {
  torch::jit::fuser::cuda::CudaKernel prog;
  Fusion& fusion = *prog.fusion_;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeDummyTensor(2);
  TensorView* tv1 = makeDummyTensor(2);
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  // TV0 original [ 32 , 128 ]
  // TV5 [ 32, 32 ] <- rFactor TV2[ 32, 128] -> [32, 4, 32] (bound to tidx(-1))
  // TV2 [ 32 ] Final Reduce (bound to tidx(-1))
  // TV3 [ 32, 128 ] Broadcast (bound to tidx(-1))
  // TV4 [ 32, 128 ] Add (bound to tidx(-1))

  // Reduce dim 1 of tv0, broadcast the result back to rank 2, add tv1.
  TensorView* tv2 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
  TensorView* tv3 = broadcast(tv2, {false, true});
  TensorView* tv4 = add(tv3, tv1);

  // Split the reduction axis by 32 and rFactor the outer (serial) part into
  // tv5. The MISTAKE (deliberate, per the issue): the corresponding split is
  // never applied to tv3/tv4 downstream of the reduction.
  tv2->split(-1, 32);
  TensorView* tv5 = tv2->rFactor({-2});

  tv2->axis(0)->parallelize(ParallelType::BIDx);
  tv3->axis(0)->parallelize(ParallelType::BIDx);
  tv4->axis(0)->parallelize(ParallelType::BIDx);
  tv5->axis(0)->parallelize(ParallelType::BIDx);

  // TIDx ends up bound to extent 32 on tv2/tv5 but to blockDim.x (128 in the
  // data below) on tv4 — this inconsistency is what should raise an error.
  tv2->axis(-1)->parallelize(ParallelType::TIDx);
  tv3->axis(-1)->parallelize(ParallelType::TIDx);
  tv4->axis(-1)->parallelize(ParallelType::TIDx);
  tv5->axis(-1)->parallelize(ParallelType::TIDx);

  fusion.addOutput(tv4);

  fusion.printMath();

  // Lowering currently succeeds (and prints the bad kernel) instead of
  // rejecting the inconsistent TIDx binding.
  GPULower gpulw(&fusion);
  gpulw.printKernel(std::cout);

  prog.device_ = 0;
  prog.grid(32);
  prog.block(32);  // NOTE(review): block(32) here vs. 128-wide tv4 axis — another facet of the mismatch; confirm intended launch config
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({32, 128}, options);
  at::Tensor t1 = at::randn({32, 128}, options);
  at::Tensor cg_output = at::empty({32, 128}, options);
  torch::jit::fuser::cuda::compileKernel(&prog);
  torch::jit::fuser::cuda::runTestKernel(&prog, {t0,t1}, {cg_output});
}

Metadata

Metadata

Assignees

No one assigned

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions