🐛 Bug
The CI test testGPU_FusionSmemBlockGemm is reporting a wrong result from codegen.
I looked at the scheduling & launch params; things seem to be consistent.
Forwarding Naoya's comment from Slack:
For example, here's the final loop nest writing into the output tensor:
for(size_t i81 = 0; i81 < 16; ++i81 ) {
  if ( ( ( ( ( blockIdx.x * 16 ) + i81 ) < T5.size[0] ) && ( ( ( blockIdx.y * 16 ) + threadIdx.x ) < T5.size[1] ) ) ) {
    T5[ ( ( ( blockIdx.x * 16 ) + i81 ) * T5.stride[0] ) + ( ( ( blockIdx.y * 16 ) + threadIdx.x ) * T5.stride[1] ) ]
      = T5[ ( ( ( blockIdx.x * 16 ) + i81 ) * T5.stride[0] ) + ( ( ( blockIdx.y * 16 ) + threadIdx.x ) * T5.stride[1] ) ]
      + T6[ ( i81 * 16 ) + threadIdx.x ];
  }
}
Each thread writes into T5[ ( ( ( blockIdx.x * 16 ) + i81 ) * T5.stride[0] ) + ( ( ( blockIdx.y * 16 ) + threadIdx.x ) * T5.stride[1] ) ]. Notice that threadIdx.y is not used for indexing, although there seems to be more than one thread along TIDy. The thread block size is 16x16x1.
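Since the block is 16x16x1 but the output index only involves threadIdx.x and i81, all 16 threads along TIDy execute the same T5 = T5 + T6 read-modify-write on the same address, with no synchronization between the read and the write. Below is a minimal standalone CUDA sketch of that hazard pattern; it is an illustration only (hypothetical kernel and variable names), not the nvfuser-generated kernel.

// Standalone illustration of the hazard (hypothetical names, not nvfuser codegen):
// 16 threads along threadIdx.y all update out[threadIdx.x] with a plain
// read-modify-write; the warps are not synchronized between the read and the
// write, so an element can absorb anywhere from 1 to 16 of the increments.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void redundant_rmw(float* out, const float* partial) {
  // Mirrors the final loop above: the output index uses only threadIdx.x even
  // though blockDim.y == 16, and every y-thread adds the same partial value.
  out[threadIdx.x] = out[threadIdx.x] + partial[threadIdx.x];
}

int main() {
  float *out = nullptr, *partial = nullptr;
  cudaMallocManaged(&out, 16 * sizeof(float));
  cudaMallocManaged(&partial, 16 * sizeof(float));
  for (int i = 0; i < 16; ++i) { out[i] = 0.f; partial[i] = 1.f; }

  redundant_rmw<<<1, dim3(16, 16, 1)>>>(out, partial);
  cudaDeviceSynchronize();

  // If a single thread per column performed the update, out[0] would be 1.
  // With 16 racing writers the result depends on how the warps interleave.
  printf("out[0] = %f\n", out[0]);
  cudaFree(out);
  cudaFree(partial);
  return 0;
}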
To Reproduce
Script copied out:
void testGPU_FusionSmemBlockGemm() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Algorithm
  TensorView* tv0 = makeDummyTensor(2); // (M, K)
  TensorView* tv1 = makeDummyTensor(2); // (K, N)
  TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B)
  TensorView* tv3 = broadcast(tv1, {true, false, false}); // (B, K, N)
  TensorView* tv4 = mul(tv2, tv3); // M, K, N
  TensorView* tv5 = sum(tv4, {1}); // M, R, N
  fusion.addInput(tv0);
  fusion.addInput(tv1);
  fusion.addOutput(tv5);

  // Schedule
  constexpr int BSX = 16;
  tv5->split(2, BSX);
  tv5->split(1, BSX);
  tv5->split(0, BSX);
  // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX
  tv5->reorder({{0, 0}, {1, 3}, {2, 2}, {3, 5}, {4, 1}, {5, 4}});
  // M/BSX, N/BSX, K/BSX, MSX, NSX, KSX
  TensorView* tv6 = tv5->rFactor({-1});

  tv2->setMemoryType(MemoryType::Shared);
  tv3->setMemoryType(MemoryType::Shared);
  tv4->setMemoryType(MemoryType::Shared);
  tv6->setMemoryType(MemoryType::Shared);

  tv0->computeAt(tv5, 3);
  tv1->computeAt(tv5, 3);
  tv0->computeAt(tv6, 3);
  tv1->computeAt(tv6, 3);

  // Thread and Block binding
  tv5->axis(0)->parallelize(ParallelType::BIDx);
  tv5->axis(1)->parallelize(ParallelType::BIDy);
  tv5->axis(-1)->parallelize(ParallelType::TIDx);
  // Manual Binding
  tv2->axis(-1)->parallelize(ParallelType::TIDx);
  tv3->axis(-1)->parallelize(ParallelType::TIDx);
  tv4->axis(-1)->parallelize(ParallelType::TIDx);
  tv6->axis(-3)->parallelize(ParallelType::TIDy);
  tv6->axis(-2)->parallelize(ParallelType::TIDx);

  constexpr int M = 154, K = 45, N = 1524;
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({M, K}, options);
  at::Tensor t1 = at::randn({K, N}, options);

  torch::jit::fuser::cuda::FusionExecutor fe;
  fe.compileFusion(&fusion);
  auto outputs = fe.runFusion({t0, t1});

  at::Tensor aten_output = matmul(t0, t1);
  TORCH_CHECK(
      aten_output.allclose(outputs[0], 1e-5, 1e-5),
      "Error of: ",
      aten_output.sub(outputs[0]).abs().max());
}
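One way to read Naoya's observation against this schedule: tv6 binds MSX to TIDy (axis(-3)) and NSX to TIDx (axis(-2)), while tv5 binds only NSX to TIDx, so tv5's MSX axis stays serial and becomes the i81 loop that all 16 TIDy threads execute redundantly. Purely as an illustration of that mismatch, and not the committed fix, additionally binding tv5's MSX axis with the same API used in the test would look like:

// Illustration only (assumption, not the actual fix): give tv5 the same
// (TIDy, TIDx) layout as tv6, so each thread owns one element of the 16x16
// output tile instead of 16 y-threads sharing a serial loop over it.
tv5->axis(-2)->parallelize(ParallelType::TIDy); // MSX, matching tv6->axis(-3)
// tv5->axis(-1) (NSX) is already bound to TIDx in the test above.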
The output kernel looks like:
__device__ void reduction_add_float(float& a, const float b) {
  a = a + b;
}

__global__ void kernel1(Tensor<float, 2> T0, Tensor<float, 2> T1, Tensor<float, 2> T5) {
  for(size_t i80 = 0; i80 < 16; ++i80 ) {
    if ( ( ( ( ( blockIdx.x * 16 ) + i80 ) < T5.size[0] ) && ( ( ( blockIdx.y * 16 ) + threadIdx.x ) < T5.size[1] ) ) ) {
      T5[ ( ( ( blockIdx.x * 16 ) + i80 ) * T5.stride[0] ) + ( ( ( blockIdx.y * 16 ) + threadIdx.x ) * T5.stride[1] ) ]
        = float(0);
    }
  }
  for(size_t i67 = 0; i67 < ( ceilDiv(T0.size[1], 16) ); ++i67 ) {
    __shared__ float T2[( 16 * 16 )];
    for(size_t i69 = 0; i69 < 16; ++i69 ) {
      if ( ( ( ( ( blockIdx.x * 16 ) + i69 ) < T5.size[0] ) && ( ( ( i67 * 16 ) + threadIdx.x ) < T0.size[1] ) ) ) {
        T2[ ( i69 * 16 ) + threadIdx.x ]
          = T0[ ( ( ( blockIdx.x * 16 ) + i69 ) * T0.stride[0] ) + ( ( ( i67 * 16 ) + threadIdx.x ) * T0.stride[1] ) ];
      }
    }
    __shared__ float T3[( 16 * 16 )];
    for(size_t i73 = 0; i73 < 16; ++i73 ) {
      if ( ( ( ( ( i67 * 16 ) + threadIdx.x ) < T1.size[0] ) && ( ( ( blockIdx.y * 16 ) + i73 ) < T5.size[1] ) ) ) {
        T3[ ( i73 * 16 ) + threadIdx.x ]
          = T1[ ( ( ( i67 * 16 ) + threadIdx.x ) * T1.stride[0] ) + ( ( ( blockIdx.y * 16 ) + i73 ) * T1.stride[1] ) ];
      }
    }
    __syncthreads();
    __shared__ float T4[( ( 16 * 16 ) * 16 )];
    for(size_t i76 = 0; i76 < 16; ++i76 ) {
      for(size_t i77 = 0; i77 < 16; ++i77 ) {
        if ( ( ( ( ( ( blockIdx.x * 16 ) + i76 ) < T5.size[0] ) && ( ( ( i67 * 16 ) + threadIdx.x ) < T0.size[1] ) ) && ( ( ( blockIdx.y * 16 ) + i77 ) < T5.size[1] ) ) ) {
          T4[ ( ( i76 * 16 ) * 16 ) + ( i77 * 16 ) + threadIdx.x ]
            = T2[ ( i76 * 16 ) + threadIdx.x ]
            * T3[ ( i77 * 16 ) + threadIdx.x ];
        }
      }
    }
    __syncthreads();
    __shared__ float T6[( 16 * 16 )];
    if ( ( ( ( ( ( blockIdx.x * 16 ) + threadIdx.y ) < T5.size[0] ) && ( ( ( i67 * 16 ) + 0 ) < T0.size[1] ) ) && ( ( ( blockIdx.y * 16 ) + threadIdx.x ) < T5.size[1] ) ) ) {
      T6[ ( threadIdx.y * 16 ) + threadIdx.x ]
        = float(0);
    }
    for(size_t i79 = 0; i79 < 16; ++i79 ) {
      if ( ( ( ( ( ( blockIdx.x * 16 ) + threadIdx.y ) < T5.size[0] ) && ( ( ( i67 * 16 ) + i79 ) < T0.size[1] ) ) && ( ( ( blockIdx.y * 16 ) + threadIdx.x ) < T5.size[1] ) ) ) {
        T6[ ( threadIdx.y * 16 ) + threadIdx.x ]
          = T6[ ( threadIdx.y * 16 ) + threadIdx.x ]
          + T4[ ( ( threadIdx.y * 16 ) * 16 ) + ( threadIdx.x * 16 ) + i79 ];
      }
    }
    __syncthreads();
    for(size_t i81 = 0; i81 < 16; ++i81 ) {
      if ( ( ( ( ( blockIdx.x * 16 ) + i81 ) < T5.size[0] ) && ( ( ( blockIdx.y * 16 ) + threadIdx.x ) < T5.size[1] ) ) ) {
        T5[ ( ( ( blockIdx.x * 16 ) + i81 ) * T5.stride[0] ) + ( ( ( blockIdx.y * 16 ) + threadIdx.x ) * T5.stride[1] ) ]
          = T5[ ( ( ( blockIdx.x * 16 ) + i81 ) * T5.stride[0] ) + ( ( ( blockIdx.y * 16 ) + threadIdx.x ) * T5.stride[1] ) ]
          + T6[ ( i81 * 16 ) + threadIdx.x ];
      }
    }
  }
}
It launches on grid (10, 96, 1) and block (16, 16, 1).
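Those launch parameters match the sizes in the test (M = 154, N = 1524, BSX = 16): ceilDiv(154, 16) = 10 M-tiles on BIDx and ceilDiv(1524, 16) = 96 N-tiles on BIDy, with a BSX x BSX x 1 block. A quick standalone check, assuming ceilDiv is the usual rounding-up division that the generated kernel also relies on:

#include <cstdio>

// Rounding-up integer division, same semantics as the ceilDiv in the kernel.
constexpr int ceilDiv(int a, int b) { return (a + b - 1) / b; }

int main() {
  constexpr int M = 154, N = 1524, BSX = 16;
  // BIDx covers the M tiles, BIDy the N tiles; the block is BSX x BSX x 1.
  printf("grid (%d %d 1) block (%d %d 1)\n",
         ceilDiv(M, BSX), ceilDiv(N, BSX), BSX, BSX);  // grid (10 96 1) block (16 16 1)
  return 0;
}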
Additional context
Commit that introduces this test & breakage: cbd6f67
The test is disabled in PR #252; please re-enable it after the fix.