Failing testGPU_FusionSmemBlockGemm #253

@jjsjann123

🐛 Bug

The CI test testGPU_FusionSmemBlockGemm is producing wrong results from the generated kernel.

I looked at the scheduling and the launch parameters; they seem consistent.
Forwarding Naoya's comment from Slack:

For example, here's the final loop nest writing into the output tensor:

    for(size_t i81 = 0; i81 < 16; ++i81 ) {
      if ( ( ( ( ( blockIdx.x * 16 ) + i81 ) < T5.size[0] ) && ( ( ( blockIdx.y * 16 ) + threadIdx.x ) < T5.size[1] ) ) ) {
        T5[ ( ( ( blockIdx.x * 16 ) + i81 ) * T5.stride[0] ) + ( ( ( blockIdx.y * 16 ) + threadIdx.x ) * T5.stride[1] ) ]
           = T5[ ( ( ( blockIdx.x * 16 ) + i81 ) * T5.stride[0] ) + ( ( ( blockIdx.y * 16 ) + threadIdx.x ) * T5.stride[1] ) ]
           + T6[ ( i81 * 16 ) + threadIdx.x ];
      }
    }

Each thread writes into T5[ ( ( ( blockIdx.x * 16 ) + i81 ) * T5.stride[0] ) + ( ( ( blockIdx.y * 16 ) + threadIdx.x ) * T5.stride[1] ) ]. Notice that threadIdx.y is not used for indexing, although there is more than one thread along TIDy: the thread block size is 16x16x1.
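
If that is the case, all 16 TIDy lanes execute the same read-modify-write into the same T5 element, which would explain the wrong results. For comparison, here is a hand-written sketch (not generated output) of what the store would presumably look like if tv5's MSX axis were also bound to TIDy:

    // Hypothetical: threadIdx.y replaces the serial i81 loop, so each
    // (threadIdx.y, threadIdx.x) pair owns a distinct output element.
    if ( ( ( ( blockIdx.x * 16 ) + threadIdx.y ) < T5.size[0] ) && ( ( ( blockIdx.y * 16 ) + threadIdx.x ) < T5.size[1] ) ) {
      T5[ ( ( ( blockIdx.x * 16 ) + threadIdx.y ) * T5.stride[0] ) + ( ( ( blockIdx.y * 16 ) + threadIdx.x ) * T5.stride[1] ) ]
         = T5[ ( ( ( blockIdx.x * 16 ) + threadIdx.y ) * T5.stride[0] ) + ( ( ( blockIdx.y * 16 ) + threadIdx.x ) * T5.stride[1] ) ]
         + T6[ ( threadIdx.y * 16 ) + threadIdx.x ];
    }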

To Reproduce

Test copied out:

void testGPU_FusionSmemBlockGemm() {                                                     
  Fusion fusion;                                                                         
  FusionGuard fg(&fusion);                                                               
                                                                                         
  // Algorithm                                                                           
  TensorView* tv0 = makeDummyTensor(2); // (M, K)                                        
  TensorView* tv1 = makeDummyTensor(2); // (K, N)                                        
  TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B)                   
  TensorView* tv3 = broadcast(tv1, {true, false, false}); // (B, K, N)                   
  TensorView* tv4 = mul(tv2, tv3); // M, K, N                                            
  TensorView* tv5 = sum(tv4, {1}); // M, R, N                                        
  fusion.addInput(tv0);                                                              
  fusion.addInput(tv1);                                                              
  fusion.addOutput(tv5);                                                             
                                                                                     
  // Schedule                                                                        
  constexpr int BSX = 16;                                                            
  tv5->split(2, BSX);                                                                
  tv5->split(1, BSX);                                                                
  tv5->split(0, BSX);                                                                
  // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX                                              
  tv5->reorder({{0, 0}, {1, 3}, {2, 2}, {3, 5}, {4, 1}, {5, 4}});                    
  // M/BSX, N/BSX, K/BSX, MSX, NSX, KSX                                              
  TensorView* tv6 = tv5->rFactor({-1});                                              
                                                                                     
  tv2->setMemoryType(MemoryType::Shared);                                            
  tv3->setMemoryType(MemoryType::Shared);                                            
  tv4->setMemoryType(MemoryType::Shared);                                            
  tv6->setMemoryType(MemoryType::Shared);                                            
                                                                                     
  tv0->computeAt(tv5, 3);                                                            
  tv1->computeAt(tv5, 3);                                                            
                                                                                     
  tv0->computeAt(tv6, 3);                                                            
  tv1->computeAt(tv6, 3);                                                            
                                                                                         
  // Thread and Block binding                                                        
  tv5->axis(0)->parallelize(ParallelType::BIDx);                                     
  tv5->axis(1)->parallelize(ParallelType::BIDy);                                     
  tv5->axis(-1)->parallelize(ParallelType::TIDx);                                    
  // Manual Binding                                                                  
  tv2->axis(-1)->parallelize(ParallelType::TIDx);                                    
  tv3->axis(-1)->parallelize(ParallelType::TIDx);                                    
  tv4->axis(-1)->parallelize(ParallelType::TIDx);                                    
  tv6->axis(-3)->parallelize(ParallelType::TIDy);                                    
  tv6->axis(-2)->parallelize(ParallelType::TIDx);                                    
                                                                                     
  constexpr int M = 154, K = 45, N = 1524;                                           
                                                                                     
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);         
  at::Tensor t0 = at::randn({M, K}, options);                                        
  at::Tensor t1 = at::randn({K, N}, options);                                        
                                                                                     
  torch::jit::fuser::cuda::FusionExecutor fe;                                        
  fe.compileFusion(&fusion);                                                         
  auto outputs = fe.runFusion({t0, t1});                                             
                                                                                     
  at::Tensor aten_output = matmul(t0, t1);                                           
  TORCH_CHECK(                                                                       
      aten_output.allclose(outputs[0], 1e-5, 1e-5),                                      
      "Error of: ",                                                                      
      aten_output.sub(outputs[0]).abs().max());                                          
} 
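
For reference, here is my reading of the tensor domains after the schedule above (hand-derived from the calls, not printed by the tool):

    // tv5 after split + reorder: [M/BSX, N/BSX, rK/BSX, MSX, NSX, rKSX]
    // tv6 = tv5->rFactor({-1}):  [M/BSX, N/BSX,  K/BSX, MSX, NSX, rKSX]
    // tv5 after rFactor:         [M/BSX, N/BSX, rK/BSX, MSX, NSX]
    //
    // Bindings: tv5 -> BIDx, BIDy, serial (i67), serial (i81), TIDx
    //           tv6 -> ..., TIDy on MSX, TIDx on NSX, serial (i79) on rKSX

If this reading is right, MSX is bound to TIDy on the producer tv6 but left as a serial loop on the consumer tv5, which matches Naoya's observation that threadIdx.y never shows up in the final store's index or predicate.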

The generated kernel looks like:

__device__ void reduction_add_float(float& a, const float b) {
  a = a + b;
}
__global__ void kernel1(Tensor<float, 2> T0, Tensor<float, 2> T1, Tensor<float, 2> T5){
  for(size_t i80 = 0; i80 < 16; ++i80 ) {
    if ( ( ( ( ( blockIdx.x * 16 ) + i80 ) < T5.size[0] ) && ( ( ( blockIdx.y * 16 ) + threadIdx.x ) < T5.size[1] ) ) ) {
      T5[ ( ( ( blockIdx.x * 16 ) + i80 ) * T5.stride[0] ) + ( ( ( blockIdx.y * 16 ) + threadIdx.x ) * T5.stride[1] ) ]
         = float(0);
    }
  }
  for(size_t i67 = 0; i67 < ( ceilDiv(T0.size[1], 16) ); ++i67 ) {
    __shared__ float T2[( 16 * 16 )];
    for(size_t i69 = 0; i69 < 16; ++i69 ) {
      if ( ( ( ( ( blockIdx.x * 16 ) + i69 ) < T5.size[0] ) && ( ( ( i67 * 16 ) + threadIdx.x ) < T0.size[1] ) ) ) {
        T2[ ( i69 * 16 ) + threadIdx.x ]
           = T0[ ( ( ( blockIdx.x * 16 ) + i69 ) * T0.stride[0] ) + ( ( ( i67 * 16 ) + threadIdx.x ) * T0.stride[1] ) ];
      }
    }
    __shared__ float T3[( 16 * 16 )];
    for(size_t i73 = 0; i73 < 16; ++i73 ) {
      if ( ( ( ( ( i67 * 16 ) + threadIdx.x ) < T1.size[0] ) && ( ( ( blockIdx.y * 16 ) + i73 ) < T5.size[1] ) ) ) {
        T3[ ( i73 * 16 ) + threadIdx.x ]
           = T1[ ( ( ( i67 * 16 ) + threadIdx.x ) * T1.stride[0] ) + ( ( ( blockIdx.y * 16 ) + i73 ) * T1.stride[1] ) ];
      }
    }
    __syncthreads();
    __shared__ float T4[( ( 16 * 16 ) * 16 )];
    for(size_t i76 = 0; i76 < 16; ++i76 ) {
      for(size_t i77 = 0; i77 < 16; ++i77 ) {
        if ( ( ( ( ( ( blockIdx.x * 16 ) + i76 ) < T5.size[0] ) && ( ( ( i67 * 16 ) + threadIdx.x ) < T0.size[1] ) ) && ( ( ( blockIdx.y * 16 ) + i77 ) < T5.size[1] ) ) ) {
          T4[ ( ( i76 * 16 ) * 16 ) + ( i77 * 16 ) + threadIdx.x ]
             = T2[ ( i76 * 16 ) + threadIdx.x ]
             * T3[ ( i77 * 16 ) + threadIdx.x ];
        }
      }
    }
    __syncthreads();
    __shared__ float T6[( 16 * 16 )];
    if ( ( ( ( ( ( blockIdx.x * 16 ) + threadIdx.y ) < T5.size[0] ) && ( ( ( i67 * 16 ) + 0 ) < T0.size[1] ) ) && ( ( ( blockIdx.y * 16 ) + threadIdx.x ) < T5.size[1] ) ) ) {
      T6[ ( threadIdx.y * 16 ) + threadIdx.x ]
         = float(0);
    }
    for(size_t i79 = 0; i79 < 16; ++i79 ) {
      if ( ( ( ( ( ( blockIdx.x * 16 ) + threadIdx.y ) < T5.size[0] ) && ( ( ( i67 * 16 ) + i79 ) < T0.size[1] ) ) && ( ( ( blockIdx.y * 16 ) + threadIdx.x ) < T5.size[1] ) ) ) {
        T6[ ( threadIdx.y * 16 ) + threadIdx.x ]
           = T6[ ( threadIdx.y * 16 ) + threadIdx.x ]
           + T4[ ( ( threadIdx.y * 16 ) * 16 ) + ( threadIdx.x * 16 ) + i79 ];
      }
    }
    __syncthreads();
    for(size_t i81 = 0; i81 < 16; ++i81 ) {
      if ( ( ( ( ( blockIdx.x * 16 ) + i81 ) < T5.size[0] ) && ( ( ( blockIdx.y * 16 ) + threadIdx.x ) < T5.size[1] ) ) ) {
        T5[ ( ( ( blockIdx.x * 16 ) + i81 ) * T5.stride[0] ) + ( ( ( blockIdx.y * 16 ) + threadIdx.x ) * T5.stride[1] ) ]
           = T5[ ( ( ( blockIdx.x * 16 ) + i81 ) * T5.stride[0] ) + ( ( ( blockIdx.y * 16 ) + threadIdx.x ) * T5.stride[1] ) ]
           + T6[ ( i81 * 16 ) + threadIdx.x ];
      }
    }
  }
}
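
The predicate of that last loop never mentions threadIdx.y, so presumably all 16 TIDy lanes perform the same unsynchronized read-modify-write on the same T5 element. The hazard in isolation (a minimal standalone CUDA sketch, not fuser output; rmw_race is a made-up name):

    // Every blockDim.y lane with the same threadIdx.x does a non-atomic
    // read-modify-write of out[threadIdx.x]; lost updates mean the element
    // ends up incremented anywhere from 1 to blockDim.y times.
    __global__ void rmw_race(float* out, const float* addend) {
      out[threadIdx.x] = out[threadIdx.x] + addend[threadIdx.x];
    }
    // e.g. rmw_race<<<dim3(1, 1, 1), dim3(16, 16, 1)>>>(out, addend);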

It launches with grid (10, 96, 1) and block (16, 16, 1). That is consistent with the schedule: gridDim.x = ceilDiv(M, BSX) = ceilDiv(154, 16) = 10, gridDim.y = ceilDiv(N, BSX) = ceilDiv(1524, 16) = 96, and the 16x16 block comes from the TIDx/TIDy bindings.

Additional context

Commit that introduced this test and the breakage: cbd6f67.
The test is disabled in PR #252; please re-enable it after the fix.
