Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions test/cpp/jit/test_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4325,6 +4325,91 @@ void testGPU_FusionReductionSchedulerMultiDimFastest() {
aten_output.sub(outputs[0]).abs().max());
}

void testGPU_FusionReductionSchedulerDimShmoo() {
std::vector<bool> fp16_usage = {false};
std::vector<int> red_axis = {1, 0};
std::vector<int> output_dims = {320, 640};
std::vector<int> red_dims;

// Tried to cut down the number iterations with just
// doing every other power of 2.
for (int i = 1; i <= 1024 * 1024; i <<= 2) {
red_dims.push_back(i);
}

for (auto fp16 : fp16_usage) {
for (auto& axis : red_axis) {
for (auto& odim : output_dims) {
for (auto& rdim : red_dims) {
Fusion fusion;
FusionGuard fg(&fusion);

TensorView* tv0 =
makeDummyTensor(2, (fp16 ? DataType::Half : DataType::Float));
fusion.addInput(tv0);

torch::jit::fuser::Val* tv0_cast = nullptr;
if (fp16) {
tv0_cast = castOp(DataType::Float, tv0);
}

TensorView* tv1 = reductionOp(
BinaryOpType::Add,
{axis},
new Float(0),
(fp16 ? tv0_cast->as<TensorView>() : tv0));

TensorView* tv1_cast = nullptr;
if (fp16) {
tv1_cast = castOp(DataType::Half, tv1);
}

fusion.addOutput((fp16 ? tv1_cast : tv1));

auto options = at::TensorOptions()
.dtype((fp16 ? at::kHalf : at::kFloat))
.device(at::kCUDA, 0);
at::Tensor input =
(axis ? at::rand({odim, rdim}, options)
: at::rand({rdim, odim}, options));

const at::ArrayRef<c10::IValue> inputs({input});

c10::optional<cuda::ReductionParams> rparams =
cuda::scheduleReduction(&fusion, inputs, tv1);
TORCH_CHECK(rparams != c10::nullopt, "Reduction is not found!");
if (fp16) {
if (axis == 0) {
int tidx = rparams.value().bdimx.value;
tv1_cast->split(-1, tidx);
tv1_cast->axis(-1)->parallelize(ParallelType::TIDx);
tv1_cast->axis(-2)->parallelize(ParallelType::BIDx);
} else {
if (rparams.value().mul_reds_per_blk) {
int tidy = rparams.value().bdimy.value;
tv1_cast->split(0, tidy);
tv1_cast->axis(-1)->parallelize(ParallelType::TIDy);
}
tv1_cast->axis(0)->parallelize(ParallelType::BIDx);
}
}

torch::jit::fuser::cuda::FusionExecutor fe;
fe.compileFusion(&fusion);

auto cg_output = fe.runFusion({input});
auto aten_output = input.sum({axis});

TORCH_CHECK(
aten_output.allclose(cg_output[0]),
"Error of: ",
aten_output.sub(cg_output[0]).abs().max());
}
}
}
}
}

void testGPU_FusionCacheBefore() {
// TVM Cache Write
Fusion fusion;
Expand Down
1 change: 1 addition & 0 deletions test/cpp/jit/tests.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ namespace jit {
_(GPU_FusionReductionScheduler) \
_(GPU_FusionReductionSchedulerMultiDimNonFastest) \
_(GPU_FusionReductionSchedulerMultiDimFastest) \
_(GPU_FusionReductionSchedulerDimShmoo) \
_(GPU_FusionCacheBefore) \
_(GPU_FusionCacheAfter) \
_(GPU_FusionCacheIndirect) \
Expand Down