Back To Back blockReduces produce wrong result #71

@kevinstephano

Description

🐛 Bug

The issue is that back-to-back blockReduce calls write to an intermediate result that may not hold the correct value in every thread, so the second reduce can consume stale or initialization data.

To Reproduce

void reduction(int trials, int bidx, int bidy, int tidx, int tidy, int unroll, int elems) {
  torch::jit::fuser::cuda::CudaKernel prog;
  Fusion& fusion = *prog.fusion_;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* input_tv0 = makeDummyTensor(3);
  fusion.addInput(input_tv0);

  TensorView* sum_val_tv1   = reductionOp(BinaryOpType::Add, {2}, new Float(0), input_tv0);
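  // sum_val_tv1 sum-reduces input_tv0 over its innermost axis (2); the 2D
  // result is why the generated kernel's output below is Tensor<float, 2>.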

  sum_val_tv1->split(-1, tidy);
  sum_val_tv1->split(-2, tidx);
  sum_val_tv1->split(-3, unroll);

  TensorView* sum_val_rf_tv2 = sum_val_tv1->rFactor({-4});
  TensorView* sum_val_rf_tv3 = sum_val_tv1->rFactor({-3});
  TensorView* sum_val_rf_tv4 = sum_val_tv1->rFactor({-2});
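
  // The three rFactors stage the reduction; the names line up with the
  // generated kernel below: tv2 -> T2 (serial accumulation over the outer
  // loop), tv3 -> T3 (folds the 4-wide unroll), tv4 -> T4 (TIDx blockReduce),
  // and sum_val_tv1 -> T1 (final TIDy blockReduce).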

  sum_val_rf_tv2->axis(0)->parallelize(ParallelType::BIDx);
  sum_val_rf_tv3->axis(0)->parallelize(ParallelType::BIDx);
  sum_val_rf_tv4->axis(0)->parallelize(ParallelType::BIDx);
  sum_val_tv1->axis(0)->parallelize(ParallelType::BIDx);

  sum_val_rf_tv2->axis(1)->parallelize(ParallelType::BIDy);
  sum_val_rf_tv3->axis(1)->parallelize(ParallelType::BIDy);
  sum_val_rf_tv4->axis(1)->parallelize(ParallelType::BIDy);
  sum_val_tv1->axis(1)->parallelize(ParallelType::BIDy);

  sum_val_rf_tv2->axis(-2)->parallelize(ParallelType::TIDx);
  sum_val_rf_tv3->axis(-2)->parallelize(ParallelType::TIDx);
  sum_val_rf_tv4->axis(-2)->parallelize(ParallelType::TIDx);

  sum_val_tv1->axis(-1)->parallelize(ParallelType::TIDy);
  sum_val_rf_tv2->axis(-1)->parallelize(ParallelType::TIDy);
  sum_val_rf_tv3->axis(-1)->parallelize(ParallelType::TIDy);
  sum_val_rf_tv4->axis(-1)->parallelize(ParallelType::TIDy);
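
  // Net schedule: the two outer iteration axes map to the grid (BIDx, BIDy);
  // TIDx and TIDy cover the two innermost reduction axes, while the remaining
  // reduction axes run as a serial outer loop plus a 4-wide unroll.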

  fusion.addOutput(sum_val_tv1);
}
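
A host-side check along these lines should expose the mismatch (this part is an assumption on my side, since the repro above stops at the schedule; bidx, bidy, elems, and prog are the parameters and objects from the snippet, and the compile-and-launch step is elided):

// Hypothetical validation harness; the compile-and-launch call is elided
// because it is not part of the snippet above.
at::Tensor input = at::randn({bidx, bidy, elems}, at::kCUDA);
at::Tensor cg_output = at::empty({bidx, bidy}, at::kCUDA);
// ... compile prog, then launch with grid (bidx, bidy) and block (tidx, tidy) ...
at::Tensor ref = input.sum({2});
TORCH_CHECK(ref.allclose(cg_output), "back-to-back blockReduce mismatch");

The generated kernel comes out as: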
__global__ void CUDAGeneratedKernel(Tensor<float, 3> T0, Tensor<float, 2> T1){
  T1[ ( blockIdx.x * T1.stride[0] ) + ( blockIdx.y * T1.stride[1] ) ]
     = float(0);
  float T4[1];
  T4[ 0 ]
     = float(0);
  float T3[1];
  T3[ 0 ]
     = float(0);
  float T2[4];
  for(size_t i37 = 0; i37 < 4; ++i37 ) {
    if ( ( ( ( ( ( ( ( 0 * 4 ) + i37 ) * 32 ) + threadIdx.x ) * 16 ) + threadIdx.y ) < T0.size[2] ) ) {
      T2[ i37 ]
         = float(0);
    }
  }
  for(size_t i38 = 0; i38 < ( ceilDiv(( ceilDiv(( ceilDiv(T0.size[2], 16) ), 32) ), 4) ); ++i38 ) {
    for(size_t i39 = 0; i39 < 4; ++i39 ) {
      if ( ( ( ( ( ( ( ( i38 * 4 ) + i39 ) * 32 ) + threadIdx.x ) * 16 ) + threadIdx.y ) < T0.size[2] ) ) {
        T2[ i39 ]
           = T2[ i39 ]
           + T0[ ( blockIdx.x * T0.stride[0] ) + ( blockIdx.y * T0.stride[1] ) + ( ( ( ( ( ( ( i38 * 4 ) + i39 ) * 32 ) + threadIdx.x ) * 16 ) + threadIdx.y ) * T0.stride[2] ) ];
      }
    }
  }
  for(size_t i41 = 0; i41 < 4; ++i41 ) {
    T3[ 0 ]
       = T3[ 0 ]
       + T2[ i41 ];
  }
  blockReduce< true, false, false > ( T4[ 0 ], T3[ 0 ], reduction_add_float);
  blockReduce< false, true, false > ( T1[ ( blockIdx.x * T1.stride[0] ) + ( blockIdx.y * T1.stride[1] ) ], T4[ 0 ], reduction_add_float);
}
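
The two back-to-back blockReduce calls at the end are where things go wrong. Below is a minimal standalone sketch of the suspected failure mode; blockReduceSketch is a hypothetical, simplified stand-in for the runtime's blockReduce (it keeps only the X/Y reduction-dimension template flags), written under the assumption that the reduced value is stored only to threads at index 0 of the reduced dimension:

#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical, simplified stand-in for the runtime blockReduce (2D blocks
// only): reduces val across threadIdx.x when X_REDUCE is true, otherwise
// across threadIdx.y, and stores the sum into out ONLY on threads whose
// index in the reduced dimension is 0; all other threads keep whatever
// out already held.
template <bool X_REDUCE>
__device__ void blockReduceSketch(float& out, float val, float* smem) {
  unsigned tid = threadIdx.y * blockDim.x + threadIdx.x;
  smem[tid] = val;
  __syncthreads();
  if (X_REDUCE && threadIdx.x == 0) {
    float sum = 0.f;
    for (unsigned i = 0; i < blockDim.x; ++i)
      sum += smem[threadIdx.y * blockDim.x + i];
    out = sum; // valid only where threadIdx.x == 0
  } else if (!X_REDUCE && threadIdx.y == 0) {
    float sum = 0.f;
    for (unsigned i = 0; i < blockDim.y; ++i)
      sum += smem[i * blockDim.x + threadIdx.x];
    out = sum; // valid only where threadIdx.y == 0
  }
  __syncthreads();
}

// Mirrors the tail of the generated kernel: T3 -> T4 (TIDx reduce), then
// T4 -> T1 (TIDy reduce). Every thread contributes 1, so the block-wide
// sum should be blockDim.x * blockDim.y.
__global__ void backToBack(float* T1) {
  extern __shared__ float smem[];
  float T3 = 1.f; // stand-in for the serial/unrolled partial sum
  float T4 = 0.f; // intermediate, zero-initialized like T4 above
  blockReduceSketch<true>(T4, T3, smem);
  // T4 is now correct only where threadIdx.x == 0; every other lane still
  // holds the init value 0. The second reduce runs per threadIdx.x lane,
  // and each lane's threadIdx.y == 0 leader writes the SAME output element,
  // so 31 zeros race with the one correct sum.
  blockReduceSketch<false>(T1[0], T4, smem);
}

int main() {
  float* d_out;
  cudaMalloc(&d_out, sizeof(float));
  dim3 block(32, 16); // tidx = 32, tidy = 16, as in the kernel above
  backToBack<<<1, block, block.x * block.y * sizeof(float)>>>(d_out);
  float h_out = 0.f;
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("got %f, expected %d\n", h_out, 32 * 16); // usually prints 0
  cudaFree(d_out);
  return 0;
}

Under that assumption, after the first (TIDx) reduce only the threadIdx.x == 0 lane of T4 holds a valid partial sum; the second (TIDy) reduce then runs independently in every threadIdx.x lane, and each lane's threadIdx.y == 0 leader stores its column sum to the same T1 element, so stale zeros race with the single correct result.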
