
implicit thread binding propagation is inconsistent between kernel generation and intermediate allocation #69

@jjsjann123

Description

🐛 Bug

This might well not be a bug but just my misunderstanding being totally off (that has happened too many times already). But the discussion might still be educational (to me at least).

I found that computeAt propagates thread bindings from tensors past the computeAt target all the way back to the input.

Here is a simplified description to illustrate the idea.

  • Data flow: T0 -> T2 -> T1;
  • T2 is the intermediate generated by T1->rFactor;
  • I specify T0->computeAt(T2) and T1->axis(0)->parallelize(ParallelType::BIDx) (see the snippet below).
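
In fuser API terms (tv0 = T0, tv2 = T2, tv1 = T1), this corresponds to the following lines taken from the full reproduction further down:

auto tv2 = tv1->rFactor({1});                   // T2: the rFactor intermediate
tv0->computeAt(tv2, 2);                         // T0->computeAt(T2)
tv1->axis(0)->parallelize(ParallelType::BIDx);  // BIDx binding on T1's axis(0)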

The strange behavior here is:
codegen seems to be able to implicitly propagate the BIDx binding on axis(0) all the way back to T0, even though it is not explicitly specified via computeAt;
however, the BIDx binding on axis(0) is not considered for T2 during allocation, and lowering complains about a non-constant allocation.

  • Note that we are not specifying T2->computeAt(T1) here; adding it would explicitly propagate the thread bindings from T1 back to T2 and T0, and would resolve the issue (a sketch follows this note).
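
For reference, a minimal sketch of that workaround against the reproduction below; the computeAt position of 1 is my guess at a reasonable value, not something taken from the original schedule:

// Hypothetical workaround (not in the repro): insert right after the rFactor.
// The position argument 1 is an assumption on my part.
tv2->computeAt(tv1, 1);  // T2->computeAt(T1): tv1's thread bindings now
                         // propagate explicitly back to tv2 and tv0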

Judging from the observed behavior, I think we probably have an assert whose condition could be relaxed.

To Reproduce

void testGPU_FusionReductionJ() {
  torch::jit::fuser::cuda::CudaKernel prog;
  Fusion& fusion = *prog.fusion_;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeDummyTensor(2);
  fusion.addInput(tv0);

  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
  fusion.addOutput(tv1);

  TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");

  // Split the reduction axis by 128 and rFactor the outer piece into tv2,
  // the intermediate (T2) discussed above.
  tv1->split(1, 128);
  auto tv2 = tv1->rFactor({1});
  std::cout << "fusion 1: \n" << fusion << std::endl;

  tv0->computeAt(tv2, 2);
  std::cout << "fusion 2: \n" << fusion << std::endl;
  tv1->axis(0)->parallelize(ParallelType::BIDx);
  tv1->axis(-1)->parallelize(ParallelType::TIDx);

  tv0->axis(-1)->parallelize(ParallelType::TIDx);
  tv2->axis(-1)->parallelize(ParallelType::TIDx);

  // If we can figure out the thread binding for tv0 in the generated kernel
  // without specifying it here, we should be able to do the same thing for
  // the allocation of tv2 as well.
  // tv0->axis(0)->parallelize(ParallelType::BIDx);
  tv2->axis(0)->parallelize(ParallelType::BIDx);
  std::cout << "fusion 3: \n" << fusion << std::endl;

  GPULower gpulw(&fusion);
  std::stringstream cdg;
  gpulw.printKernel(cdg);
  std::cout << cdg.str() << std::endl;
}
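
For context, here is a hand-written CUDA sketch of the kernel this schedule is aiming for. This is my illustration, not actual nvFuser output; the names and shapes (input of size I0 x I1, reduce_rows) are made up for the example. It shows why honoring the BIDx binding makes T2's allocation constant-sized: with blockIdx.x bound to axis(0) and threadIdx.x bound to the inner split axis, both extents drop out of the allocation and T2 collapses to a per-thread scalar.

// Hand-written illustration, NOT actual generated code. One block per row
// (BIDx on axis 0), 128 threads per block (TIDx), mirroring tv1->split(1, 128).
__global__ void reduce_rows(const float* input, float* output, int I0, int I1) {
  // tv2's buffer: with both parallel extents bound, it is a per-thread scalar,
  // a compile-time-constant allocation. If the BIDx binding were ignored,
  // this would need to be float T2[I0], whose size is not a compile-time
  // constant -- which is what the lowering complains about.
  float T2 = 0.0f;
  int row = blockIdx.x;
  for (int i = threadIdx.x; i < I1; i += 128) {
    T2 += input[row * I1 + i];  // partial sum over the rFactored outer axis
  }
  // Final reduction of the 128 per-thread partials (tv2 -> tv1).
  __shared__ float smem[128];
  smem[threadIdx.x] = T2;
  __syncthreads();
  for (int offset = 64; offset > 0; offset /= 2) {
    if (threadIdx.x < offset) {
      smem[threadIdx.x] += smem[threadIdx.x + offset];
    }
    __syncthreads();
  }
  if (threadIdx.x == 0) {
    output[row] = smem[0];
  }
}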
