[broadcast] Using an explicit split vs a computeAt(-1) of split broadcast dim fails #115

@kevinstephano

🐛 Bug

The scenario is as follows:

TV2[X, >>Y<<] = reductionOp(TV0)
TV3[X, <<Y>>] = broadcastOp(TV2, {false, true})

TV4[X, Yout, Yin] = add(TV3, TV1)   (Yout and Yin are generated by a split)

When I use TV3->computeAt(TV4, -1), this works!

If I instead do TV3->split(-1, 32) and bind blocks and threads to TV3 appropriately, compilation fails: the intermediate TV3 is allocated with a dynamic size, even though a broadcast intermediate only needs a single element.
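
For reference, the only scheduling difference between the two repros below is how TV3's inner dimension is handled (a condensed sketch, not a complete schedule):

  // Works: let computeAt propagate the output's schedule into the broadcast.
  output_tv4->split(-1, 32);
  bcast_tv3->computeAt(output_tv4, -1);

  // Fails: split and parallelize the broadcast tensor explicitly.
  bcast_tv3->split(-1, 32);
  bcast_tv3->axis(-1)->parallelize(ParallelType::TIDx);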

To Reproduce

Bad Schedule:

void testGPU_FusionBad() {
  torch::jit::fuser::cuda::CudaKernel prog;                                              
  Fusion& fusion = *prog.fusion_;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* input_tv0 = makeDummyTensor(3);
  TensorView* input_tv1 = makeDummyTensor(3);                                            
  fusion.addInput(input_tv0);
  fusion.addInput(input_tv1);                                                            

  TensorView* sum_tv2 =
      reductionOp(BinaryOpType::Add, {2}, new Float(0), input_tv0);                      
  TensorView* bcast_tv3 = broadcast(sum_tv2, {false, false, true});                      
  TensorView* output_tv4 = div(input_tv1, bcast_tv3);

  sum_tv2->split(-1, 32);                                                                
  TensorView* sum_rf_tv5 = sum_tv2->rFactor({-2});

  // Explicitly split the (size-1) broadcast dimension of bcast_tv3 as well.
  bcast_tv3->split(-1, 32);
  output_tv4->split(-1, 32);

  sum_rf_tv5->axis(0)->parallelize(ParallelType::BIDx);                                  
  sum_tv2->axis(0)->parallelize(ParallelType::BIDx);                                     
  bcast_tv3->axis(0)->parallelize(ParallelType::BIDx);                                   
  output_tv4->axis(0)->parallelize(ParallelType::BIDx);                                  

  sum_rf_tv5->axis(1)->parallelize(ParallelType::BIDy);                                  
  sum_tv2->axis(1)->parallelize(ParallelType::BIDy);                                     
  bcast_tv3->axis(1)->parallelize(ParallelType::BIDy);                                   
  output_tv4->axis(1)->parallelize(ParallelType::BIDy);                                  

  sum_rf_tv5->axis(-1)->parallelize(ParallelType::TIDx);                                 
  sum_tv2->axis(-1)->parallelize(ParallelType::TIDx);                                    
  // Binding TIDx to the split broadcast dim is what triggers the failure below.
  bcast_tv3->axis(-1)->parallelize(ParallelType::TIDx);
  output_tv4->axis(-1)->parallelize(ParallelType::TIDx);                                 
                                                                                         
  fusion.addOutput(output_tv4);

  fusion.printMath();

  prog.device_ = 0;                                                                      
  prog.grid(32, 32);
  prog.block(32);                                                                        
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);             
  at::Tensor t0 = at::randn({32, 32, 128}, options);                                     
  at::Tensor t1 = at::randn({32, 32, 128}, options);                                     
  at::Tensor cg_output = at::empty({32, 32, 128}, options);                              
  torch::jit::fuser::cuda::compileKernel(&prog);
  torch::jit::fuser::cuda::runTestKernel(&prog, {t0,t1}, {cg_output});                   
} 

Error:

CUDA NVRTC compile error: default_program(505): error: function call must have a constant value in a constant expression

Algo Exprs:

T5[ iblockIdx.x{gridDim.x}, iblockIdx.y{gridDim.y}, rS{( ceilDiv(i5, 32) )}rf, ithreadIdx.x{32}rf ] = reduction( T0[ iS{i1}, iS{i3}, iS{i5} ], op = add, initial value = float(0) )
T2[ iblockIdx.x{gridDim.x}, iblockIdx.y{gridDim.y}, rthreadIdx.x{32} ] = reduction( T5[ iblockIdx.x{gridDim.x}, iblockIdx.y{gridDim.y}, rS{( ceilDiv(i5, 32) )}rf, ithreadIdx.x{32}rf ], op = add, initial value = float(0) )
T3[ iblockIdx.x{gridDim.x}, iblockIdx.y{gridDim.y}, bS{( ceilDiv(1, 32) )}, bthreadIdx.x{32} ]        
   = T2[ iblockIdx.x{gridDim.x}, iblockIdx.y{gridDim.y}, rthreadIdx.x{32} ];                           
T4[ iblockIdx.x{gridDim.x}, iblockIdx.y{gridDim.y}, iS{( ceilDiv(i11, 32) )}, ithreadIdx.x{32} ]
   = T1[ iS{i7}, iS{i9}, iS{i11} ]                                                                     
   / T3[ iblockIdx.x{gridDim.x}, iblockIdx.y{gridDim.y}, bS{( ceilDiv(1, 32) )}, bthreadIdx.x{32} ];  

Kernel:

__global__ void kernel(Tensor<float, 3> T0, Tensor<float, 3> T1, Tensor<float, 3> T4){
  __shared__ float shared_mem[1024];
  float T3[( ceilDiv(1, 32) )];
  float T2[1];
  T2[ 0 ]
     = float(0);
  float T5[1];
  if ( ( ( ( 0 * 32 ) + threadIdx.x ) < T0.size[2] ) ) {
    T5[ 0 ]
       = float(0);
  }
  for(size_t i48 = 0; i48 < ( ceilDiv(T0.size[2], 32) ); ++i48 ) {
    if ( ( ( ( i48 * 32 ) + threadIdx.x ) < T0.size[2] ) ) {
      T5[ 0 ]
         = T5[ 0 ]
         + T0[ ( blockIdx.x * T0.stride[0] ) + ( blockIdx.y * T0.stride[1] ) + ( ( ( i48 * 32 ) + threadIdx.x ) * T0.stride[2] ) ];
    }
  }
  blockReduce< true, false, false > ( T2[ 0 ], T5[ 0 ], reduction_add_float, threadIdx, blockDim, reinterpret_cast<float*>(shared_mem));
  if ( ( threadIdx.x == 0 ) ) {
    T3[ 0 ]
       = T2[ 0 ];
  }
  for(size_t i51 = 0; i51 < ( ceilDiv(T4.size[2], 32) ); ++i51 ) {
    if ( ( ( ( i51 * 32 ) + threadIdx.x ) < T4.size[2] ) ) {
      T4[ ( blockIdx.x * T4.stride[0] ) + ( blockIdx.y * T4.stride[1] ) + ( ( ( i51 * 32 ) + threadIdx.x ) * T4.stride[2] ) ]
         = T1[ ( blockIdx.x * T1.stride[0] ) + ( blockIdx.y * T1.stride[1] ) + ( ( ( i51 * 32 ) + threadIdx.x ) * T1.stride[2] ) ]
         / T3[ 0 ];
    }
  }
}
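
The failing expression is the allocation of T3: splitting the broadcast extent of 1 by 32 leaves an outer extent of ( ceilDiv(1, 32) ), and only the inner axis is bound to threadIdx.x, so the allocation is sized by the unparallelized outer extent. ceilDiv is a runtime helper, so NVRTC rejects it as an array bound. What I would expect instead (my assumption about the intended codegen, since the broadcast only ever holds one element per thread block):

  float T3[( ceilDiv(1, 32) )];  // emitted: array size is a function call, not a constant expression
  float T3[1];                   // expected: the broadcast intermediate only needs a single element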

Good schedule for comparison:

void testGPU_FusionGood() {
  torch::jit::fuser::cuda::CudaKernel prog;
  Fusion& fusion = *prog.fusion_;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* input_tv0 = makeDummyTensor(3);
  TensorView* input_tv1 = makeDummyTensor(3);
  fusion.addInput(input_tv0);
  fusion.addInput(input_tv1);

  TensorView* sum_tv2 =
      reductionOp(BinaryOpType::Add, {2}, new Float(0), input_tv0);
  TensorView* bcast_tv3 = broadcast(sum_tv2, {false, false, true});
  TensorView* output_tv4 = div(input_tv1, bcast_tv3);

  sum_tv2->split(-1, 32);
  TensorView* sum_rf_tv5 = sum_tv2->rFactor({-2});

  output_tv4->split(-1, 32);
  // Instead of splitting bcast_tv3, inline it into output_tv4 at the innermost position.
  bcast_tv3->computeAt(output_tv4, {-1});

  sum_rf_tv5->axis(0)->parallelize(ParallelType::BIDx);
  sum_tv2->axis(0)->parallelize(ParallelType::BIDx);
  output_tv4->axis(0)->parallelize(ParallelType::BIDx);

  sum_rf_tv5->axis(1)->parallelize(ParallelType::BIDy);
  sum_tv2->axis(1)->parallelize(ParallelType::BIDy);
  output_tv4->axis(1)->parallelize(ParallelType::BIDy);

  sum_rf_tv5->axis(-1)->parallelize(ParallelType::TIDx);
  sum_tv2->axis(-1)->parallelize(ParallelType::TIDx);
  output_tv4->axis(-1)->parallelize(ParallelType::TIDx);

  fusion.addOutput(output_tv4);

  fusion.printMath();
  
  prog.device_ = 0;
  prog.grid(32, 32);
  prog.block(32);
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({32, 32, 128}, options);
  at::Tensor t1 = at::randn({32, 32, 128}, options);
  at::Tensor cg_output = at::empty({32, 32, 128}, options);
  torch::jit::fuser::cuda::compileKernel(&prog);
  torch::jit::fuser::cuda::runTestKernel(&prog, {t0,t1}, {cg_output});
} 
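
Neither repro checks the kernel's output; a quick eager-mode comparison (hypothetical, not part of the original tests) appended after runTestKernel would confirm the good schedule is also numerically correct:

  // Hypothetical check: compare the fused result against eager mode.
  at::Tensor ref = t1 / t0.sum({2}, /*keepdim=*/true);
  TORCH_CHECK(at::allclose(ref, cg_output), "fusion output does not match eager reference");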
