Skip to content

GridReductions with a following op are failing due to Predication #244

@kevinstephano

Description

@kevinstephano

🐛 Bug

GridReductions with an operation following the reduction are failing because the thread flag to write the result is improperly scoped.

  if ( ( ( ( blockIdx.x * 32 ) + threadIdx.x ) < T3.size[0] ) ) {                                                      
    float block_result;                                                                                                
    blockReduce< false, true, false > ( block_result, T4[ 0 ], reduction_add_float, threadIdx, blockDim, reinterpret_cast<float*>(shared_mem));
    // Allocate global tensor float T7[( ( ( ceilDiv(T3.size[0], 32) ) * 32 ) * 16 )];                                 
    // Allocate global tensor int64_t T8[( ceilDiv(T3.size[0], 32) )];                                                 
    bool T2pred = reduction::gridReduce< false, true, false, true, false, true > ( T2[ 0 ], block_result, reduction_add_float, &T7[0], T8, reinterpret_cast<float*>(shared_mem));                                                                                  
  }                                                                                                                    
  if ( ( ( ( ( blockIdx.x * 32 ) + threadIdx.x ) < T3.size[0] ) && ( T2pred && ( threadIdx.y == 0 ) ) ) ) {            
    T3[ ( ( blockIdx.x * 32 ) + threadIdx.x ) ]                                                                        
       = __float2half(T2[ 0 ]);
  }                               

Error Message :

CUDA NVRTC compile error: default_program(611): error: identifier "T2pred" is undefined

To Reproduce

void testGPU_FusionReductionSchedulerDimShmoo() {
  std::vector<bool> fp16_usage = {true};
  std::vector<int> red_axis = {0};        
  std::vector<int> output_dims = {320};
  std::vector<int> red_dims = {4096};             
                                                                                                                       
  //for(int i = 1; i <= 1024*1024; i <<= 1) {      
  //  red_dims.push_back(i);                               
  //}                                             
                                                            
  for(auto fp16 : fp16_usage) {                             
    for(auto &axis : red_axis) {                                            
      for(auto &odim : output_dims) {                                                                                  
        for(auto &rdim : red_dims) {                                                                                   
          std::cout << fp16 << " " << axis << " " << odim << " " << rdim << std::endl;                                 
          Fusion fusion;                     
          FusionGuard fg(&fusion);           
                                                                     
          TensorView* tv0 = makeDummyTensor(2, (fp16 ? DataType::Half : DataType::Float));                             
          fusion.addInput(tv0);                                                                                        
                                                                            
          torch::jit::fuser::Val* tv0_cast = nullptr;       
          if (fp16) {                                                                                                  
            tv0_cast = castOp(DataType::Float, tv0);        
          }                                                            
                                                         
          TensorView* tv1 = reductionOp(BinaryOpType::Add, {axis}, new Float(0), (fp16 ? tv0_cast->as<TensorView>() : tv0));
                                  
          TensorView* tv1_cast = nullptr;                                                                              
          if (fp16) {                                                         
            tv1_cast = castOp(DataType::Half, tv1);                                                                    
          }                                                                                                            
                                                         
          fusion.addOutput((fp16 ? tv1_cast : tv1));                                                                   
                                                                                                                       
          auto options = at::TensorOptions().dtype((fp16 ? at::kHalf : at::kFloat)).device(at::kCUDA, 0);
          at::Tensor input = (axis ? at::rand({odim, rdim}, options) : at::rand({rdim, odim}, options));
                                                                  
          const at::ArrayRef<c10::IValue> inputs({input});        
                                                              
          c10::optional<cuda::ReductionParams> rparams = cuda::scheduleReduction(&fusion, inputs, tv1);
          TORCH_CHECK(rparams != c10::nullopt, "Reduction is not found!");                                             
          if(fp16) {                                                        
            if (axis == 0 ) {                   
              int tidx = rparams.value().bdimx.value;                       
              tv1_cast->split(-1, tidx);                                                                               
              tv1_cast->axis(-1)->parallelize(ParallelType::TIDx);
              tv1_cast->axis(-2)->parallelize(ParallelType::BIDx);                                                     
            } else {                                                                                                   
              if (rparams.value().mul_reds_per_blk) { 
                int tidy = rparams.value().bdimy.value;                                                                
                tv1_cast->split(0, tidy);                                     
                tv1_cast->axis(-1)->parallelize(ParallelType::TIDy);                                                   
              }                                                             
              tv1_cast->axis(0)->parallelize(ParallelType::BIDx);
            }                                                 
          }                                                              
                                                                                                                       
          torch::jit::fuser::cuda::FusionExecutor fe;
          fe.compileFusion(&fusion);            
                                                   
          auto cg_output = fe.runFusion({input});   
          auto aten_output = input.sum({axis});       
                                                
          TORCH_CHECK(                                                      
              aten_output.allclose(cg_output[0]),          
              "Error of: ",                                
              aten_output.sub(cg_output[0]).abs().max());   
        }                                        
      }                                      
    }                                      
  }                                                                         
}                       

Metadata

Metadata

Assignees

Labels

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions