Bug with rFactored reductions, an op fused to the rFactor, and an Unroll of part of the rFactor #176

@kevinstephano

Description

🐛 Bug

In the presence of a reduction and an operation (like an add) that is fused into an rFactor loop of the reduction, in conjunction with an unroll, the tensor produced by the fused operation is not allocated the size of the unroll.

I suspect that when computeAt() is applied to the fused op, it is not picking up the Unroll.

Note that I also show a similar situation below, under Expected behavior, where the Unroll is properly applied between two operations when the second operation is not a reduction.

This is the important part of the generated kernel:

  for(size_t i36 = 0; i36 < ( ceilDiv(( ceilDiv(T0.size[1], 32) ), 4) ); ++i36 ) {
    float T1[1];
    if ( ( ( ( ( ( i36 * 4 ) + ( 4 - 1 ) ) * 32 ) + threadIdx.x ) < T0.size[1] ) ) {
      for(size_t i37 = 0; i37 < 4; ++i37 ) {     
        T1[ i37 ]
           = T0[ ( blockIdx.x * T0.stride[0] ) + ( ( ( ( ( i36 * 4 ) + i37 ) * 32 ) + threadIdx.x ) * T0.stride[1] ) ]
           + float(0);                           
        T3[ 0 ]
           = T3[ 0 ]                             
           + T1[ i37 ];
      }
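
Note the mismatch: T1 is declared with extent 1, but the unrolled loop writes and reads T1[ i37 ] for i37 in [0, 4), so every iteration past the first indexes past the end of T1's allocation.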

Full kernel generated:

__global__ void kernel(Tensor<float, 2> T0, Tensor<float, 1> T2){
  __shared__ float shared_mem[1024];
  T2[ ( blockIdx.x * T2.stride[0] ) ]            
     = float(0);
  float T3[1];
  if ( ( ( ( ( ( 0 * 4 ) + 0 ) * 32 ) + threadIdx.x ) < T0.size[1] ) ) {
    T3[ 0 ]
       = float(0);                               
  }
  for(size_t i36 = 0; i36 < ( ceilDiv(( ceilDiv(T0.size[1], 32) ), 4) ); ++i36 ) {
    float T1[1];
    if ( ( ( ( ( ( i36 * 4 ) + ( 4 - 1 ) ) * 32 ) + threadIdx.x ) < T0.size[1] ) ) {
      for(size_t i37 = 0; i37 < 4; ++i37 ) {     
        T1[ i37 ]
           = T0[ ( blockIdx.x * T0.stride[0] ) + ( ( ( ( ( i36 * 4 ) + i37 ) * 32 ) + threadIdx.x ) * T0.stride[1] ) ]
           + float(0);                           
        T3[ 0 ]
           = T3[ 0 ]                             
           + T1[ i37 ];
      }
    } else {
      for(size_t i37 = 0; i37 < 4; ++i37 ) {     
        if ( ( ( ( ( ( i36 * 4 ) + i37 ) * 32 ) + threadIdx.x ) < T0.size[1] ) ) {
          T1[ i37 ]                              
             = T0[ ( blockIdx.x * T0.stride[0] ) + ( ( ( ( ( i36 * 4 ) + i37 ) * 32 ) + threadIdx.x ) * T0.stride[1] ) ]
             + float(0);                         
        }                                        
        if ( ( ( ( ( ( i36 * 4 ) + i37 ) * 32 ) + threadIdx.x ) < T0.size[1] ) ) {
          T3[ 0 ]
             = T3[ 0 ]
             + T1[ i37 ];                        
        }                                        
      }                                          
    }
  }
  blockReduce< true, false, false > ( T2[ ( blockIdx.x * T2.stride[0] ) ], T3[ 0 ], reduction_add_float, threadIdx, blockDim, reinterpret_cast<float*>(shared_mem));
}
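
The predicated else branch suffers from the same under-allocation. For comparison, this is roughly what I would expect the main loop to look like once T1's allocation accounts for the Unroll (a hand-written sketch, not actual codegen output):

  for(size_t i36 = 0; i36 < ( ceilDiv(( ceilDiv(T0.size[1], 32) ), 4) ); ++i36 ) {
    float T1[4]; // extent matches the unroll factor of 4
    if ( /* hoisted predicate covering the last unrolled iteration */ ) {
      for(size_t i37 = 0; i37 < 4; ++i37 ) {
        T1[ i37 ] = T0[ /* same index expression as above */ ] + float(0);
        T3[ 0 ] = T3[ 0 ] + T1[ i37 ];
      }
    } else {
      /* per-iteration predicated version, as above */
    }
  }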

To Reproduce

void testGPU_FusionUnrollBug2() {
  const std::vector<int64_t> tensor_dims_in = {128, 128};
  torch::jit::fuser::cuda::CudaKernel prog;
  prog.setFusionPtr(std::make_unique<Fusion>());
  Fusion* fusion = prog.fusion();
  FusionGuard fg(fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeDummyTensor(tensor_dims_in.size());
  fusion->addInput(tv0);

  TensorView* tv1 = add(tv0, new Float(0));
  TensorView* tv2 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv1);
  fusion->addOutput(tv2);

  const auto options =
      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input = at::rand(tensor_dims_in, options);
  at::Tensor cg_output = at::empty({tensor_dims_in[0]}, options);

  //const at::ArrayRef<c10::IValue> inputs({input});

  // Schedule
  tv2->split(1, 32);
  tv2->split(1, 4); // unroll

  auto tv2_rf = tv2->rFactor({-3, -2});

  tv2->axis(0)->parallelize(ParallelType::BIDx);
  tv2->axis(-1)->parallelize(ParallelType::TIDx);

  tv2_rf->axis(0)->parallelize(ParallelType::BIDx);
  tv2_rf->axis(-1)->parallelize(ParallelType::TIDx);
  tv2_rf->axis(-2)->parallelize(ParallelType::Unroll);

  tv1->computeAt(tv2_rf, -1);

  prog.setDevice(0);
  fusion->setLaunchConfig(LaunchConfigType::TIDx, new Int(tensor_dims_in[0]));
  fusion->setLaunchConfig(LaunchConfigType::TIDy, new Int(1));
  fusion->setLaunchConfig(LaunchConfigType::TIDz, new Int(1));
  fusion->setLaunchConfig(LaunchConfigType::BIDx, new Int(32));
  fusion->setLaunchConfig(LaunchConfigType::BIDy, new Int(1));
  fusion->setLaunchConfig(LaunchConfigType::BIDz, new Int(1));
  fusion->setLaunchConfig(LaunchConfigType::SharedMemory, new Int(0));
  fusion->setLaunchConfig(LaunchConfigType::Compatible, new Int(1));

  torch::jit::fuser::cuda::compileKernel(&prog);
  torch::jit::fuser::cuda::runKernel(&prog, {input}, {cg_output}, c10::nullopt);
}
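
For reference, my understanding of the transformed domains after this schedule (hand-derived, so the exact extents and names are approximate):

  // tv2_rf[ iI0{128} (BIDx), rR1oo{ceilDiv(ceilDiv(128, 32), 4)},
  //         rR1oi{4} (Unroll), iR1i{32} (TIDx) ]
  // tv2   [ iI0{128} (BIDx), rR1i{32} (TIDx) ]
  //
  // tv1->computeAt(tv2_rf, -1) places tv1's computation inside all of
  // tv2_rf's loops, including the Unroll axis, so T1's allocation needs
  // to cover the 4 unrolled iterations.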

Expected behavior

void testGPU_FusionUnrollBug() {
  const std::vector<int64_t> tensor_dims_in = {128, 128};
  torch::jit::fuser::cuda::CudaKernel prog;
  prog.setFusionPtr(std::make_unique<Fusion>());
  Fusion* fusion = prog.fusion();
  FusionGuard fg(fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeDummyTensor(tensor_dims_in.size());
  fusion->addInput(tv0);

  TensorView* tv1 = add(tv0, new Float(0));
  TensorView* tv2 = add(tv1, new Float(0));
  fusion->addOutput(tv2);

  const auto options =
      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input = at::rand(tensor_dims_in, options);
  at::Tensor cg_output = at::empty(tensor_dims_in, options);

  //const at::ArrayRef<c10::IValue> inputs({input});

  // Schedule
  tv2->split(1, 32);
  tv2->split(1, 4); // unroll

  tv2->axis(0)->parallelize(ParallelType::BIDx);
  tv2->axis(-1)->parallelize(ParallelType::TIDx);
  tv2->axis(-2)->parallelize(ParallelType::Unroll);

  tv1->computeAt(tv2, -1);

  prog.setDevice(0);
  fusion->setLaunchConfig(LaunchConfigType::TIDx, new Int(tensor_dims_in[0]));
  fusion->setLaunchConfig(LaunchConfigType::TIDy, new Int(1));
  fusion->setLaunchConfig(LaunchConfigType::TIDz, new Int(1));
  fusion->setLaunchConfig(LaunchConfigType::BIDx, new Int(32));
  fusion->setLaunchConfig(LaunchConfigType::BIDy, new Int(1));
  fusion->setLaunchConfig(LaunchConfigType::BIDz, new Int(1));
  fusion->setLaunchConfig(LaunchConfigType::SharedMemory, new Int(0));
  fusion->setLaunchConfig(LaunchConfigType::Compatible, new Int(1));

  torch::jit::fuser::cuda::compileKernel(&prog);
  torch::jit::fuser::cuda::runKernel(&prog, {input}, {cg_output}, c10::nullopt);
}
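
In this version the Unroll is picked up correctly: the intermediate is allocated to the unroll extent and both ops land inside the unrolled loop, along these lines (again a sketch from memory, not actual codegen output):

  float T1[4];
  for(size_t i37 = 0; i37 < 4; ++i37 ) {
    T1[ i37 ] = T0[ /* ... */ ] + float(0);
    T2[ /* ... */ ] = T1[ i37 ] + float(0);
  }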
