7 changes: 3 additions & 4 deletions third_party/nvfuser/csrc/index_compute.cpp
@@ -2870,10 +2870,9 @@ std::vector<RootPredicateInfo> Index::getReferenceRootPredicates(
// parameter. Predicates involving vectorized loops are separately
// generated in lower_misaligned_vectorization.
//
// Second condition is simply to avoid predication on broadcasting axes as
// it's not required.
if (consumer_stop_indexing_it == consumer_stop_index_map.end() ||
consumer_stop_indexing_it->second->isZeroInt()) {
Collaborator: Is it because the extent can be zero?

Collaborator (Author): Yes!

Collaborator: Does this mean we end up seeing more predicates like 0 < T.size[0]?

Collaborator: Just in case, can you leave a comment mentioning that we can't ignore the consumer stop index even if it's zero, since the extent can be zero?

Collaborator (Author): I think other than loop rotation, it will mostly be 0 < 1? I cannot think of a case where the stop index is 0.

Collaborator: I wonder if we would want to specialize a fusion for the case where there's no zero-dim tensor. Assuming that's the common case, having a compilation option flag to guarantee there's no zero-dim tensor seems reasonable. It's not clear to me how much perf overhead we would have due to the conservative assumption that there can be zero-dim tensors, but if that's something we want to optimize, it seems to make sense to have two compiled kernels, one for the common case and another as backup.

Collaborator (Author): Yeah, I agree. I don't know how much effort it would be to add that option. I would wait until some user asks us to improve empty tensor support. :)
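
Illustrative sketch of the two-kernel idea discussed above (editor's illustration, not part of this PR and not nvfuser API; CompiledKernel, hasEmptyInput, and selectKernel are hypothetical names): the host checks whether any input has a zero-extent dimension and only then falls back to the conservatively predicated kernel.

#include <cstdint>
#include <vector>

struct CompiledKernel {
  // Placeholder for a compiled kernel handle.
};

// True if any input has a dimension of extent zero (an empty tensor).
bool hasEmptyInput(const std::vector<std::vector<int64_t>>& input_sizes) {
  for (const auto& sizes : input_sizes) {
    for (int64_t extent : sizes) {
      if (extent == 0) {
        return true;
      }
    }
  }
  return false;
}

// Use the kernel specialized for "no zero-extent dims" in the common case,
// and fall back to the conservatively predicated kernel otherwise.
const CompiledKernel& selectKernel(
    const CompiledKernel& specialized_nonempty,
    const CompiledKernel& general_fallback,
    const std::vector<std::vector<int64_t>>& input_sizes) {
  return hasEmptyInput(input_sizes) ? general_fallback : specialized_nonempty;
}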

// Cannot omit the stop index even if it is zero. This is important for empty
// tensor support, because in an empty tensor the extent of an ID can be zero
if (consumer_stop_indexing_it == consumer_stop_index_map.end()) {
continue;
}
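
For intuition, a minimal stand-alone analogy of why the 0 < extent guard has to stay even when the stop index is zero (plain C++ written for this illustration, not generated nvfuser code): with an empty tensor, even element 0 is out of bounds, so the access must remain predicated, which is exactly what the generated kernels below do with checks like 0 < T0.size[0].

#include <cstdint>
#include <vector>

float readFirstGuarded(const std::vector<float>& t0) {
  float t1 = 0.0f;                            // mirrors `T1[0] = 0;` in the kernels below
  if (0 < static_cast<int64_t>(t0.size())) {  // analogous to `if (0 < T0.size[0])`
    t1 = t0[0];                               // safe only because of the guard
  }
  return t1;
}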

80 changes: 39 additions & 41 deletions third_party/nvfuser/test/test_gpu_loop_rotation.cpp
@@ -36,12 +36,12 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T4) {
i30 = T0.stride[0] * i21;
int64_t i44;
i44 = 3 * i21;
bool b76;
b76 = 0 < (T0.size[0] - i21);
bool b82;
b82 = 0 < (T0.size[0] - i21);
float T1[1];
float T2[1];
T1[0] = 0;
if (b76) {
if (b82) {
T1[0]
= T0[i30];
}
@@ -57,12 +57,12 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T4) {
float T3[1];
T3[0]
= T2[0];
if ((b76 && (i37 < 3))) {
if ((b82 && (i37 < 3))) {
T4[(i44 + i37)]
= T3[0];
}
T1[0] = 0;
if ((b76 && (i61 < 3))) {
if ((b82 && (i61 < 3))) {
T1[0]
= T0[(i30 + (T0.stride[1] * i61))];
}
@@ -75,7 +75,7 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T4) {
)";
assertCUDAKernel(&fusion, expected_kernel);

for (auto n : {1, 99}) {
for (auto n : {0, 1, 99}) {
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto t0 = at::randn({n, 3}, options);
FusionExecutor fe;
@@ -103,17 +103,13 @@ TEST_F(NVFuserTest, FusionLoopRotation1Outer_CUDA) {
inlineAllAt(tv4, 1);
scheduler_utils::rotateLoop(tv4, 0, {tv1, tv2});

// TODO: the predicate 0 < T0.size[0] in
// T1[i21]
// = T0[(T0.stride[1] * i29)];
// is optimized to `true` by expr simplifier, due to
// https://github.com/csarofeen/pytorch/blob/167718b6d06558395f86b6d25a68352168b86da2/third_party/nvfuser/csrc/expr_simplifier.cpp#L1115-L1126
// This doesn't look very safe.
const std::string expected_kernel = R"(
__global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T4) {
NVFUSER_DEFINE_MAGIC_ZERO
int64_t i118;
i118 = -T0.size[0];
bool b79;
b79 = 0 < T0.size[0];
int64_t i128;
i128 = -T0.size[0];
float T1[3];
float T2[3];
#pragma unroll
@@ -125,7 +121,7 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T4) {
for(nvfuser_index_t i21 = 0; i21 < 3; ++i21) {
int64_t i29;
i29 = i21 + nvfuser_zero;
if ((i29 < 3)) {
if ((b79 && (i29 < 3))) {
T1[i21]
= T0[(T0.stride[1] * i29)];
}
@@ -143,10 +139,10 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T4) {
i48 = 3 * i24;
int64_t i69;
i69 = T0.stride[0] + (T0.stride[0] * i24);
bool b97;
b97 = 0 < (T0.size[0] - i24);
bool b126;
b126 = (i118 + i24) < -1;
bool b107;
b107 = 0 < (T0.size[0] - i24);
bool b136;
b136 = (i128 + i24) < -1;
// Alias Allocation - register
auto& T3 = T1;
#pragma unroll
Expand All @@ -159,7 +155,7 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T4) {
for(nvfuser_index_t i25 = 0; i25 < 3; ++i25) {
int64_t i41;
i41 = i25 + nvfuser_zero;
if ((b97 && (i41 < 3))) {
if ((b107 && (i41 < 3))) {
T4[(i48 + i41)]
= T3[i25];
}
@@ -174,7 +170,7 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T4) {
for(nvfuser_index_t i21 = 0; i21 < 3; ++i21) {
int64_t i52;
i52 = i21 + nvfuser_zero;
if ((b126 && (i52 < 3))) {
if ((b136 && (i52 < 3))) {
T1[i21]
= T0[(i69 + (T0.stride[1] * i52))];
}
@@ -191,7 +187,7 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T4) {
)";
assertCUDAKernel(&fusion, expected_kernel);

for (auto n : {1, 99}) {
for (auto n : {0, 1, 99}) {
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto t0 = at::randn({n, 3}, options);
FusionExecutor fe;
@@ -314,7 +310,7 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T4) {
)";
assertCUDAKernel(&fusion, expected_kernel);

for (auto n : {1, 99}) {
for (auto n : {0, 1, 99}) {
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto t0 = at::randn({n, 3}, options);
FusionExecutor fe;
@@ -438,7 +434,7 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T4) {
)";
assertCUDAKernel(&fusion, expected_kernel);

for (auto n : {1, 99}) {
for (auto n : {0, 1, 99}) {
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto t0 = at::randn({n, 3}, options);
FusionExecutor fe;
@@ -474,10 +470,12 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T4) {
i119 = 4 * T0.stride[0];
int64_t i220;
i220 = T0.stride[0] * 5;
int64_t i347;
i347 = -T0.size[0];
bool b351;
b351 = i347 < -4;
bool b295;
b295 = 0 < T0.size[0];
int64_t i357;
i357 = -T0.size[0];
bool b361;
b361 = i357 < -4;
float T1[15];
#pragma unroll
for(nvfuser_index_t i21 = 0; i21 < 3; ++i21) {
@@ -488,7 +486,7 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T4) {
for(nvfuser_index_t i21 = 0; i21 < 3; ++i21) {
int64_t i35;
i35 = i21 + nvfuser_zero;
if ((i35 < 3)) {
if ((b295 && (i35 < 3))) {
T1[i21]
= T0[(T0.stride[1] * i35)];
}
@@ -500,8 +498,8 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T4) {
i57 = 3 + (3 * i24);
int64_t i78;
i78 = T0.stride[0] + (T0.stride[0] * i24);
bool b323;
b323 = 0 < (T0.size[0] - ((1 + i24) + nvfuser_zero));
bool b333;
b333 = 0 < (T0.size[0] - ((1 + i24) + nvfuser_zero));
#pragma unroll
for(nvfuser_index_t i21 = 0; i21 < 3; ++i21) {
T1[(i57 + i21)] = 0;
@@ -510,7 +508,7 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T4) {
for(nvfuser_index_t i21 = 0; i21 < 3; ++i21) {
int64_t i61;
i61 = i21 + nvfuser_zero;
if ((b323 && (i61 < 3))) {
if ((b333 && (i61 < 3))) {
T1[(i57 + i21)]
= T0[(i78 + (T0.stride[1] * i61))];
}
@@ -527,7 +525,7 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T4) {
for(nvfuser_index_t i21 = 0; i21 < 3; ++i21) {
int64_t i109;
i109 = i21 + nvfuser_zero;
if ((b351 && (i109 < 3))) {
if ((b361 && (i109 < 3))) {
T1[(12 + i21)]
= T0[(i119 + (T0.stride[1] * i109))];
}
@@ -549,10 +547,10 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T4) {
i222 = i220 + (T0.stride[0] * i25);
int64_t i288;
i288 = 3 * ((1 + i25) % 5);
bool b373;
b373 = 0 < (T0.size[0] - i25);
bool b416;
b416 = (i347 + i25) < -5;
bool b383;
b383 = 0 < (T0.size[0] - i25);
bool b426;
b426 = (i357 + i25) < -5;
float T3[3];
#pragma unroll
for(nvfuser_index_t i23 = 0; i23 < 3; ++i23) {
@@ -564,7 +562,7 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T4) {
for(nvfuser_index_t i27 = 0; i27 < 3; ++i27) {
int64_t i144;
i144 = i27 + nvfuser_zero;
if ((b373 && (i144 < 3))) {
if ((b383 && (i144 < 3))) {
T4[(i151 + i144)]
= T3[i27];
}
@@ -579,7 +577,7 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T4) {
for(nvfuser_index_t i21 = 0; i21 < 3; ++i21) {
int64_t i196;
i196 = i21 + nvfuser_zero;
if ((b416 && (i196 < 3))) {
if ((b426 && (i196 < 3))) {
T1[(i192 + i21)]
= T0[(i222 + (T0.stride[1] * i196))];
}
@@ -596,7 +594,7 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T4) {
)";
assertCUDAKernel(&fusion, expected_kernel);

for (auto n : {1, 99}) {
for (auto n : {0, 1, 99}) {
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto t0 = at::randn({n, 3}, options);
FusionExecutor fe;
@@ -731,7 +729,7 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T4) {
)";
assertCUDAKernel(&fusion, expected_kernel);

for (auto n : {1, 99}) {
for (auto n : {0, 1, 99}) {
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto t0 = at::randn({n, 3}, options);
FusionExecutor fe;