🐛 Bug
I am doing a reduction over the slower-changing axis: [>>X<<, Y]. I am splitting Y into two parts, [>>X<<, Ya, Yb], and binding blocks to Ya and threads to Yb. When I do this, the first output is calculated correctly but all the rest are zero. The generated kernel looks incorrect: its thread predicate uses T1.size[1] even though T1 is only 1D. A sketch of what I would expect is below the generated kernel.
Tensor<float, 1> T1
if ( ( ( ( blockIdx.x * 2 ) + threadIdx.x ) < T1.size[1] ) ) {
__global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 1> T1){
  if ( ( ( ( blockIdx.x * 2 ) + threadIdx.x ) < T1.size[1] ) ) {
    T1[ ( ( ( blockIdx.x * 2 ) + threadIdx.x ) * T1.stride[0] ) ]
      = float(0);
  }
  for(size_t i48 = 0; i48 < T0.size[0]; ++i48 ) {
    if ( ( ( ( blockIdx.x * 2 ) + threadIdx.x ) < T1.size[1] ) ) {
      T1[ ( ( ( blockIdx.x * 2 ) + threadIdx.x ) * T1.stride[0] ) ]
        = T1[ ( ( ( blockIdx.x * 2 ) + threadIdx.x ) * T1.stride[0] ) ]
        + T0[ ( i48 * T0.stride[0] ) + ( ( ( blockIdx.x * 2 ) + threadIdx.x ) * T0.stride[1] ) ];
    }
  }
}
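For reference, here is a minimal hand-written sketch of what I would expect the predicate and indexing to be for the 1D output. This is my assumption, not codegen output: the guard presumably should use T1.size[0], the output's only extent (equal to the pre-split Y extent, 6 here), rather than the nonexistent T1.size[1]. The kernel name is hypothetical and the literal 2 is tid_x, the inner split factor.
__global__ void ExpectedKernelSketch(Tensor<float, 2> T0, Tensor<float, 1> T1){
  // One output element per (block, thread) pair over the split Y axis.
  size_t out = ( blockIdx.x * 2 ) + threadIdx.x;
  if ( out < T1.size[0] ) {  // guard against the 1D output extent
    T1[ out * T1.stride[0] ] = float(0);
    for(size_t i = 0; i < T0.size[0]; ++i ) {
      T1[ out * T1.stride[0] ]
        = T1[ out * T1.stride[0] ]
        + T0[ ( i * T0.stride[0] ) + ( out * T0.stride[1] ) ];
    }
  }
}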
For tensor dimensions of [16, 3, 2] after the split (an input of [16, 6]), this yields 6 outputs.
ATen Output:
5.5575
9.6090
7.6456
9.6341
9.0592
7.0243
Codegen Output:
5.5575
0.0000
0.0000
0.0000
0.0000
0.0000
To Reproduce
I am using the 20_6_11_devel branch.
Here is the test code:
void testGPU_FusionNonRedAxisBind() {
  int bid_x = 3;
  int tid_x = 2;
  int red_dim = 0;

  torch::jit::fuser::cuda::CudaKernel prog;
  Fusion& fusion = *prog.fusion_;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeDummyTensor(2);
  fusion.addInput(tv0);

  TensorView* tv1 =
      reductionOp(BinaryOpType::Add, {red_dim}, new Float(0), tv0);
  fusion.addOutput(tv1);

  // Split the non-reduction axis; bind blocks to the outer part and threads to the inner part
  tv1->split(-1, tid_x);
  tv1->axis(-2)->parallelize(ParallelType::BIDx);
  tv1->axis(-1)->parallelize(ParallelType::TIDx);

  prog.device_ = 0;
  prog.grid(bid_x);
  prog.block(tid_x);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input = at::rand({16, bid_x * tid_x}, options);
  at::Tensor cg_output = at::empty({bid_x * tid_x}, options);

  torch::jit::fuser::cuda::compileKernel(&prog);
  torch::jit::fuser::cuda::runTestKernel(&prog, {input}, {cg_output});

  GPULower gpulw(&fusion);
  gpulw.printKernel(std::cout);

  auto aten_output = input.sum({red_dim});
  std::cout << aten_output << std::endl;
  std::cout << cg_output << std::endl;

  TORCH_CHECK(
      aten_output.allclose(cg_output),
      "Error of: ",
      aten_output.sub(cg_output).abs().max());
}
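As a sanity check on the expected numerics, independent of the codegen: splitting the non-reduction axis should not change the reduction result. Here is a minimal ATen-only sketch (my own, not part of the test above) using the same shapes, where the [16, 6] input is viewed as [16, 3, 2] and reduced over axis 0 both ways:
#include <torch/torch.h>
#include <iostream>

int main() {
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input = at::rand({16, 3 * 2}, options);

  // Reduce over axis 0 directly, and again after viewing the second axis as [3, 2].
  at::Tensor ref = input.sum({0});                     // shape [6]
  at::Tensor split = input.view({16, 3, 2}).sum({0});  // shape [3, 2]

  std::cout << ref << std::endl;
  std::cout << split.reshape({6}) << std::endl;
  TORCH_CHECK(ref.allclose(split.reshape({6})),
      "splitting the axis changed the reduction result");
  return 0;
}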