## 🐛 Bug
The issue is that back-to-back block reductions go through an intermediate result, and that intermediate may not hold the correct value when the second reduction consumes it.
## To Reproduce
```cpp
void reduction(int trials, int bidx, int bidy, int tidx, int tidy, int unroll, int elems) {
  torch::jit::fuser::cuda::CudaKernel prog;
  Fusion& fusion = *prog.fusion_;
  FusionGuard fg(&fusion);

  // Set up the input tensor view.
  TensorView* input_tv0 = makeDummyTensor(3);
  fusion.addInput(input_tv0);

  // Reduce over axis 2, split the reduction axis into
  // [serial, unroll, TIDx, TIDy], and rFactor each stage off in turn.
  TensorView* sum_val_tv1 =
      reductionOp(BinaryOpType::Add, {2}, new Float(0), input_tv0);
  sum_val_tv1->split(-1, tidy);
  sum_val_tv1->split(-2, tidx);
  sum_val_tv1->split(-3, unroll);
  TensorView* sum_val_rf_tv2 = sum_val_tv1->rFactor({-4});
  TensorView* sum_val_rf_tv3 = sum_val_tv1->rFactor({-3});
  TensorView* sum_val_rf_tv4 = sum_val_tv1->rFactor({-2});

  sum_val_rf_tv2->axis(0)->parallelize(ParallelType::BIDx);
  sum_val_rf_tv3->axis(0)->parallelize(ParallelType::BIDx);
  sum_val_rf_tv4->axis(0)->parallelize(ParallelType::BIDx);
  sum_val_tv1->axis(0)->parallelize(ParallelType::BIDx);

  sum_val_rf_tv2->axis(1)->parallelize(ParallelType::BIDy);
  sum_val_rf_tv3->axis(1)->parallelize(ParallelType::BIDy);
  sum_val_rf_tv4->axis(1)->parallelize(ParallelType::BIDy);
  sum_val_tv1->axis(1)->parallelize(ParallelType::BIDy);

  sum_val_rf_tv2->axis(-2)->parallelize(ParallelType::TIDx);
  sum_val_rf_tv3->axis(-2)->parallelize(ParallelType::TIDx);
  sum_val_rf_tv4->axis(-2)->parallelize(ParallelType::TIDx);

  sum_val_tv1->axis(-1)->parallelize(ParallelType::TIDy);
  sum_val_rf_tv2->axis(-1)->parallelize(ParallelType::TIDy);
  sum_val_rf_tv3->axis(-1)->parallelize(ParallelType::TIDy);
  sum_val_rf_tv4->axis(-1)->parallelize(ParallelType::TIDy);

  fusion.addOutput(sum_val_tv1);
}
```
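The schedule stages the sum three ways: an unrolled serial accumulation, then a block reduction across TIDx into an intermediate, then a second block reduction across TIDy into the output. With unroll = 4, tidx = 32, tidy = 16 (the values baked into the code below), this generates the following kernel: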
```cuda
__global__ void CUDAGeneratedKernel(Tensor<float, 3> T0, Tensor<float, 2> T1) {
  T1[(blockIdx.x * T1.stride[0]) + (blockIdx.y * T1.stride[1])] = float(0);
  float T4[1];
  T4[0] = float(0);
  float T3[1];
  T3[0] = float(0);
  float T2[4];
  for (size_t i37 = 0; i37 < 4; ++i37) {
    if (((((((0 * 4) + i37) * 32) + threadIdx.x) * 16) + threadIdx.y) < T0.size[2]) {
      T2[i37] = float(0);
    }
  }
  for (size_t i38 = 0; i38 < ceilDiv(ceilDiv(ceilDiv(T0.size[2], 16), 32), 4); ++i38) {
    for (size_t i39 = 0; i39 < 4; ++i39) {
      if (((((((i38 * 4) + i39) * 32) + threadIdx.x) * 16) + threadIdx.y) < T0.size[2]) {
        T2[i39] = T2[i39]
            + T0[(blockIdx.x * T0.stride[0]) + (blockIdx.y * T0.stride[1])
                + (((((((i38 * 4) + i39) * 32) + threadIdx.x) * 16) + threadIdx.y) * T0.stride[2])];
      }
    }
  }
  for (size_t i41 = 0; i41 < 4; ++i41) {
    T3[0] = T3[0] + T2[i41];
  }
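  // [annotation, not generator output] The two block reductions below run
  // back to back: the first reduces T3 across threadIdx.x into the
  // intermediate T4, and the second immediately reduces T4 across
  // threadIdx.y into T1. The reported bug is that this intermediate may
  // not hold the correct value when the second call consumes it.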
  blockReduce<true, false, false>(T4[0], T3[0], reduction_add_float);
  blockReduce<false, true, false>(
      T1[(blockIdx.x * T1.stride[0]) + (blockIdx.y * T1.stride[1])],
      T4[0],
      reduction_add_float);
}
```
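For contrast, here is a minimal sketch of the same back-to-back pattern written safely in plain CUDA. This is not nvFuser's blockReduce; the kernel, names, and memory layout are made up for illustration, and block dimensions are assumed to be powers of two. The two things the generated kernel has to get right are the same: a barrier between the two stages, and predication so that only threads holding valid stage-one partials feed stage two.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Two-stage block reduction: stage 1 reduces each row across threadIdx.x,
// stage 2 reduces the per-row partials across threadIdx.y. Assumes
// blockDim.x and blockDim.y are powers of two.
__global__ void twoStageBlockReduce(const float* in, float* out, int n) {
  extern __shared__ float smem[];  // blockDim.x * blockDim.y floats
  const int tid = threadIdx.y * blockDim.x + threadIdx.x;
  const int idx = blockIdx.x * blockDim.x * blockDim.y + tid;
  smem[tid] = (idx < n) ? in[idx] : 0.0f;
  __syncthreads();

  // Stage 1: tree-reduce each row (fixed threadIdx.y) across threadIdx.x.
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (threadIdx.x < s)
      smem[tid] += smem[tid + s];
    __syncthreads();  // all threads must see finished partials before stage 2
  }
  // Only lanes with threadIdx.x == 0 now hold a valid row partial, so
  // stage 2 is predicated on threadIdx.x == 0.

  // Stage 2: tree-reduce the column of row partials across threadIdx.y.
  for (int s = blockDim.y / 2; s > 0; s >>= 1) {
    if (threadIdx.x == 0 && threadIdx.y < s)
      smem[threadIdx.y * blockDim.x] += smem[(threadIdx.y + s) * blockDim.x];
    __syncthreads();
  }
  if (tid == 0)
    out[blockIdx.x] = smem[0];
}

int main() {
  const int blocks = 8, n = 32 * 16 * blocks;
  float *in, *out;
  cudaMallocManaged(&in, n * sizeof(float));
  cudaMallocManaged(&out, blocks * sizeof(float));
  for (int i = 0; i < n; ++i) in[i] = 1.0f;
  twoStageBlockReduce<<<blocks, dim3(32, 16), 32 * 16 * sizeof(float)>>>(in, out, n);
  cudaDeviceSynchronize();
  printf("block 0 sum = %f (expect %d)\n", out[0], 32 * 16);
  cudaFree(in);
  cudaFree(out);
  return 0;
}
```

Dropping either the inter-stage `__syncthreads()` or the `threadIdx.x == 0` predicate in this sketch reproduces exactly the class of error described above: the second stage consumes an intermediate that is not yet, or not everywhere, correct.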