forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 7
Closed
Labels
Description
🐛 Bug
I found a situation where a schedule was allowed to generate a kernel where TIDx
was allowed to be bound to two different values. An error should have been generated. The question is how to check for this?
The scenario where I made this mistake was when I did an rFactor
on a reduction dimension split and didn't apply the corresponding split to an operation found after the reduction.
To Reproduce
IR View of Operations
T5[ iblockIdx.x{gridDim.x}, rS{( ceilDiv(i3, 32) )}rf, ithreadIdx.x{32}rf ] = reduction( T0[ iS{i1}, iS{i3} ], op = add, initial value = float(0) )
T2[ iblockIdx.x{gridDim.x}, rthreadIdx.x{32} ] = reduction( T5[ iblockIdx.x{gridDim.x}, rS{( ceilDiv(i3, 32) )}rf, ithreadIdx.x{32}rf ], op = add, initial value = float(0) )
T3[ iblockIdx.x{gridDim.x}, bthreadIdx.x{1} ]
= T2[ iblockIdx.x{gridDim.x}, rthreadIdx.x{32} ];
T4[ iblockIdx.x{gridDim.x}, ithreadIdx.x{blockDim.x} ]
= T3[ iblockIdx.x{gridDim.x}, bthreadIdx.x{1} ]
+ T1[ iS{i5}, iS{i7} ];
Kernel Generated
// Combine operator passed to blockReduce: accumulates b into a.
__device__ void reduction_add_float(float& a, const float b) {
  a += b;
}
// Generated kernel: row-wise reduction of T0 (partial-accumulate in T5, then
// block-wide reduce into T2), followed by broadcast (T3) and elementwise add
// with T1 into output T4.
__global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T1, Tensor<float, 2> T4){
// Scratch space consumed by blockReduce below.
__shared__ float shared_mem[1024];
float T3[1];
float T2[1];
T2[ 0 ]
= float(0);
// T5 holds this thread's partial sum (the rFactor'd axis).
float T5[1];
if ( ( ( ( 0 * 32 ) + threadIdx.x ) < T0.size[1] ) ) {
T5[ 0 ]
= float(0);
}
// Each thread strides over T0's inner dimension in steps of 32,
// accumulating into its private partial sum T5.
for(size_t i30 = 0; i30 < ( ceilDiv(T0.size[1], 32) ); ++i30 ) {
if ( ( ( ( i30 * 32 ) + threadIdx.x ) < T0.size[1] ) ) {
T5[ 0 ]
= T5[ 0 ]
+ T0[ ( blockIdx.x * T0.stride[0] ) + ( ( ( i30 * 32 ) + threadIdx.x ) * T0.stride[1] ) ];
}
}
// Collapse the 32 per-thread partials into T2[0] across the block.
blockReduce< true, false, false > ( T2[ 0 ], T5[ 0 ], reduction_add_float, threadIdx, blockDim, reinterpret_cast<float*>(shared_mem));
// Only thread 0 copies the reduced value into T3.
if ( ( threadIdx.x == 0 ) ) {
T3[ 0 ]
= T2[ 0 ];
}
// NOTE(review): this is the reported bug. T3[0] is a per-thread register and
// was written only by threadIdx.x == 0, yet every thread reads it here. TIDx
// is effectively bound to two extents (32 in the reduction, blockDim.x in
// this add) -- see the IR view above (bthreadIdx.x{1} vs ithreadIdx.x{blockDim.x}).
// Lowering should have raised an error instead of emitting this.
T4[ ( blockIdx.x * T4.stride[0] ) + ( threadIdx.x * T4.stride[1] ) ]
= T3[ 0 ]
+ T1[ ( blockIdx.x * T1.stride[0] ) + ( threadIdx.x * T1.stride[1] ) ];
}
Test
void testGPU_FusionThreadBindingError() {
torch::jit::fuser::cuda::CudaKernel prog;
Fusion& fusion = *prog.fusion_;
FusionGuard fg(&fusion);
TensorView* tv0 = makeDummyTensor(2);
TensorView* tv1 = makeDummyTensor(2);
fusion.addInput(tv0);
fusion.addInput(tv1);
// TV0 original [ 32 , 128 ]
// TV5 [ 32, 32 ] <- rFactor TV2[ 32, 128] -> [32, 4, 32] (bound to tidx(-1))
// TV2 [ 32 ] Final Reduce (bound to tidx(-1))
// TV3 [ 32, 128 ] Broadcast (bound to tidx(-1))
// TV4 [ 32, 128 ] Add (bound to tidx(-1))
TensorView* tv2 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
TensorView* tv3 = broadcast(tv2, {false, true});
TensorView* tv4 = add(tv3, tv1);
tv2->split(-1, 32);
TensorView* tv5 = tv2->rFactor({-2});
tv2->axis(0)->parallelize(ParallelType::BIDx);
tv3->axis(0)->parallelize(ParallelType::BIDx);
tv4->axis(0)->parallelize(ParallelType::BIDx);
tv5->axis(0)->parallelize(ParallelType::BIDx);
tv2->axis(-1)->parallelize(ParallelType::TIDx);
tv3->axis(-1)->parallelize(ParallelType::TIDx);
tv4->axis(-1)->parallelize(ParallelType::TIDx);
tv5->axis(-1)->parallelize(ParallelType::TIDx);
fusion.addOutput(tv4);
fusion.printMath();
GPULower gpulw(&fusion);
gpulw.printKernel(std::cout);
prog.device_ = 0;
prog.grid(32);
prog.block(32);
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor t0 = at::randn({32, 128}, options);
at::Tensor t1 = at::randn({32, 128}, options);
at::Tensor cg_output = at::empty({32, 128}, options);
torch::jit::fuser::cuda::compileKernel(&prog);
torch::jit::fuser::cuda::runTestKernel(&prog, {t0,t1}, {cg_output});
}