test/cpp/jit/test_gpu.cpp (+374 −0)
@@ -4605,6 +4605,380 @@ void testGPU_FusionReductionScheduler() {
aten_output.sub(cg_output).abs().max());
}

void testGPU_FusionCacheBefore() {
// Analogue of TVM's cache_write
torch::jit::fuser::cuda::CudaKernel prog;
prog.setFusionPtr(std::make_unique<Fusion>());
Fusion* fusion = prog.fusion();
FusionGuard fg(fusion);

TensorView* tv0 = makeDummyTensor(2);
TensorView* tv1 = add(tv0, new Float(1.0));
TensorView* tv2 = mul(tv1, new Float(3.0));
fusion->addInput(tv0);
fusion->addOutput(tv2);
// Before: TV2 = TV1 * 3
// After: TV3 = TV1 * 3;
// TV2 = TV3;
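// In other words (a reading of the transform, not an authoritative spec):
// cache_before(tv2) re-parents the multiply onto the new tensor TV3 and
// reduces TV2 to a copy of TV3, so the arithmetic can be staged in a
// faster memory space before the final write to the output.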
// Algorithm

constexpr int BSX = 32;
tv2->split(-1, BSX);
tv0->computeAt(tv2, -1);

// cache_before automatically applies ComputeAt to the cache TensorView
TensorView* tv3 = tv2->cache_before();
// Schedule
// fusion->printMath();

tv2->axis(0)->parallelize(ParallelType::BIDx);
tv2->axis(-1)->parallelize(ParallelType::TIDx);
// Thread and Block binding
// fusion->printKernel();

constexpr int M = 32, N = 750;
prog.setDevice(0);
setupLaunchConfig(
prog.fusion(),
BSX, // tid_x
1, // tid_y
1, // tid_z
M, // gid_x
1, // gid_y
1, // gid_z
0 // shared_memory size
);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor input = at::rand({M, N}, options);
at::Tensor cg_output = at::empty({M, N}, options);

torch::jit::fuser::cuda::compileKernel(&prog);
torch::jit::fuser::cuda::runKernel(&prog, {input}, {cg_output}, c10::nullopt);

at::Tensor aten_output = (input + 1.0) * 3.0;
TORCH_CHECK(
aten_output.allclose(cg_output, 1e-5, 1e-5),
"Error of: ",
aten_output.sub(cg_output).abs().sum());
}

void testGPU_FusionCacheAfter() {
// Analogue of TVM's cache_read
torch::jit::fuser::cuda::CudaKernel prog;
prog.setFusionPtr(std::make_unique<Fusion>());
Fusion* fusion = prog.fusion();
FusionGuard fg(fusion);

TensorView* tv0 = makeDummyTensor(2);
TensorView* tv1 = add(tv0, new Float(1.0));
TensorView* tv2 = mul(tv1, new Float(3.0));
fusion->addInput(tv0);
fusion->addOutput(tv2);
// Before: TV1 = TV0 + 1
// After: TV3 = TV0;
// TV1 = TV3 + 1
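// In other words (again a reading, not a spec): cache_after(tv0) inserts
// the copy TV3 between the global input TV0 and its consumers, so kernel
// reads of TV0 go through the cache tensor rather than global memory.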
// Algorithm

constexpr int BSX = 32;
tv2->split(-1, BSX);
tv0->computeAt(tv2, -1);

// cache_after automatically applies ComputeAt to the cache TensorView
TensorView* tv3 = tv0->cache_after();
// Schedule
// fusion->printMath();

tv2->axis(0)->parallelize(ParallelType::BIDx);
tv2->axis(-1)->parallelize(ParallelType::TIDx);
// Thread and Block binding
// fusion->printKernel();

constexpr int M = 32, N = 457;
prog.setDevice(0);
setupLaunchConfig(
prog.fusion(),
BSX, // tid_x
1, // tid_y
1, // tid_z
M, // gid_x
1, // gid_y
1, // gid_z
0 // shared_memory size
);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor input = at::rand({M, N}, options);
at::Tensor cg_output = at::empty({M, N}, options);

torch::jit::fuser::cuda::compileKernel(&prog);
torch::jit::fuser::cuda::runKernel(&prog, {input}, {cg_output}, c10::nullopt);

at::Tensor aten_output = (input + 1.0) * 3.0;
TORCH_CHECK(
aten_output.allclose(cg_output, 1e-5, 1e-5),
"Error of: ",
aten_output.sub(cg_output).abs().sum());
}

void testGPU_FusionCacheIndirect() {
torch::jit::fuser::cuda::CudaKernel prog;
prog.setFusionPtr(std::make_unique<Fusion>());
Fusion* fusion = prog.fusion();
FusionGuard fg(fusion);

TensorView* tv0 = makeDummyTensor(2);
TensorView* tv1 = makeDummyTensor(2);
TensorView* tv2 = makeDummyTensor(2);
TensorView* tv3 = makeDummyTensor(2);
TensorView* tv4 = sub(tv2, tv3);
TensorView* tv5 = add(tv1, tv4);
TensorView* tv6 = sub(tv5, tv0);
fusion->addInput(tv0);
fusion->addInput(tv1);
fusion->addInput(tv2);
fusion->addInput(tv3);
fusion->addOutput(tv6);
// t6 = ((t1 + (t2 - t3)) - t0)

// cache_after / cache_before on an intermediate TensorView, placed after
// computeAt

constexpr int BSX = 32;
tv6->split(-1, BSX);
tv2->computeAt(tv6, -1);

TensorView* tv7 = tv5->cache_after();
TensorView* tv8 = tv5->cache_before();
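// Reading of the resulting graph: tv8 takes over the definition of tv5
// (t1 + t4), tv7 becomes the copy its consumer tv6 reads from, and tv5
// itself is reduced to a pass-through between the two caches.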
// Schedule
// fusion->printMath();

tv6->axis(0)->parallelize(ParallelType::BIDx);
tv6->axis(-1)->parallelize(ParallelType::TIDx);
// Thread and Block binding
// fusion->printKernel();

constexpr int M = 32, N = 810;
prog.setDevice(0);
setupLaunchConfig(
prog.fusion(),
BSX, // tid_x
1, // tid_y
1, // tid_z
M, // gid_x
1, // gid_y
1, // gid_z
0 // shared_memory size
);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor in0 = at::rand({M, N}, options);
at::Tensor in1 = at::rand({M, N}, options);
at::Tensor in2 = at::rand({M, N}, options);
at::Tensor in3 = at::rand({M, N}, options);
at::Tensor cg_output = at::empty({M, N}, options);

torch::jit::fuser::cuda::compileKernel(&prog);
torch::jit::fuser::cuda::runKernel(
&prog, {in0, in1, in2, in3}, {cg_output}, c10::nullopt);

at::Tensor aten_output = (in1 + (in2 - in3)) - in0;
TORCH_CHECK(
aten_output.allclose(cg_output, 1e-5, 1e-5),
"Error of: ",
aten_output.sub(cg_output).abs().sum());
}

void testGPU_FusionCacheBcast() {
torch::jit::fuser::cuda::CudaKernel prog;
prog.setFusionPtr(std::make_unique<Fusion>());
Fusion* fusion = prog.fusion();
FusionGuard fg(fusion);

TensorView* tv0 = makeDummyTensor(1); // (M)
TensorView* tv1 = broadcast(tv0, {false, true}); // (M, 1)
TensorView* tv2 = makeDummyTensor(1); // (N)
TensorView* tv3 = broadcast(tv2, {true, false}); // (1, N)
TensorView* tv4 = mul(tv1, tv3);
fusion->addInput(tv0);
fusion->addInput(tv2);
fusion->addOutput(tv4);
// Algorithm

constexpr int BSX = 128;
tv4->split(0, BSX);
tv4->split(-1, BSX);
tv4->reorder({{0, 0}, {1, 2}, {2, 1}, {3, 3}});
// M/BSX, N/BSX, BSX, BSX
tv0->computeAt(tv4, 2);
tv2->computeAt(tv4, 2);
// 0, 1 | 2, 3, 4

// Case 1: cache_after on a fusion input
TensorView* tv5 = tv0->cache_after();

// Case 2: cache_before on a broadcast result
TensorView* tv6 = tv1->cache_before();

// Case 3: cache_after on a broadcast result
TensorView* tv7 = tv1->cache_after();

// Case 4: cache_before on the fusion output
TensorView* tv8 = tv4->cache_before();
// Schedule
// fusion->printMath();

tv4->axis(0)->parallelize(ParallelType::BIDx);
tv4->axis(1)->parallelize(ParallelType::BIDy);
tv4->axis(-1)->parallelize(ParallelType::TIDx);
// Manual Replay on TV3
tv3->axis(-1)->parallelize(ParallelType::TIDx);
tv8->axis(-1)->parallelize(ParallelType::TIDx);
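// (Presumably needed because the computeAt replay does not propagate
// parallelization here, so tv3 and the cache tv8 are bound by hand.)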
// Thread and Block binding
// fusion->printKernel();

constexpr int M = 92, N = 500;
const int Mr = ceilDiv_(M, BSX);
const int Nr = ceilDiv_(N, BSX);
prog.setDevice(0);
setupLaunchConfig(
prog.fusion(),
BSX, // tid_x
1, // tid_y
1, // tid_z
Mr, // gid_x
Nr, // gid_y
1, // gid_z
0 // shared_memory size
);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor t0 = at::randn({M}, options);
at::Tensor t1 = at::randn({N}, options);
at::Tensor cg_output = at::empty({M, N}, options);

torch::jit::fuser::cuda::compileKernel(&prog);
torch::jit::fuser::cuda::runKernel(
&prog, {t0, t1}, {cg_output}, c10::nullopt);

at::Tensor aten_output = t0.unsqueeze(1).matmul(t1.unsqueeze(0));
TORCH_CHECK(
aten_output.allclose(cg_output, 1e-5, 1e-5),
"Error of: ",
aten_output.sub(cg_output).abs().max());
}

void testGPU_FusionCacheComplex() {
torch::jit::fuser::cuda::CudaKernel prog;
prog.setFusionPtr(std::make_unique<Fusion>());
Fusion* fusion = prog.fusion();
FusionGuard fg(fusion);

TensorView* tv0 = makeDummyTensor(2); // (N, N)
TensorView* tv1 = makeDummyTensor(1); // (N)
TensorView* tv2 = sum(tv0, {1}); // (N)
TensorView* tv3 = broadcast(tv2, {false, true}); // (N, 1)
TensorView* tv4 = broadcast(tv1, {true, false}); // (1, N)
TensorView* tv5 = mul(tv3, tv4); // (N, N)
fusion->addInput(tv0);
fusion->addInput(tv1);
fusion->addOutput(tv5);
// Algorithm

// Exception: cache_before on a reduction output is not supported and
// would throw here:
// TensorView* tv9 = tv2->cache_before();

constexpr int BSX = 128;
tv5->split(0, BSX);
tv5->split(-1, BSX);
// M/BSX, BSX, N/BSX, BSX
tv5->reorder({{0, 0}, {1, 2}, {2, 1}, {3, 3}});
// M/BSX, N/BSX, BSX, BSX
tv0->computeAt(tv5, 2);
tv1->computeAt(tv5, 2);
// 0, 1 | 2, 3, 4

TensorView* tv6 = tv2->cache_after();
TensorView* tv7 = tv5->cache_before();
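// Reading of the combination: tv6 caches the reduction result tv2 for its
// broadcast consumer, while tv7 takes over the definition of the output
// tv5, mirroring the cache_after / cache_before rewrites tested above.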
// Schedule
// fusion->printMath();

tv5->axis(0)->parallelize(ParallelType::BIDx);
tv5->axis(1)->parallelize(ParallelType::BIDy);
tv5->axis(-1)->parallelize(ParallelType::TIDx);

tv4->axis(-1)->parallelize(ParallelType::TIDx);
tv7->axis(-1)->parallelize(ParallelType::TIDx);
// Thread and Block binding
// fusion->printKernel();

constexpr int N = 800;
const int Nr = ceilDiv_(N, BSX);
prog.setDevice(0);
setupLaunchConfig(
prog.fusion(),
BSX, // tid_x
1, // tid_y
1, // tid_z
Nr, // gid_x
Nr, // gid_y
1, // gid_z
0 // shared_memory size
);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor input1 = at::rand({N, N}, options);
at::Tensor input2 = at::rand({N}, options);
at::Tensor cg_output = at::empty({N, N}, options);

torch::jit::fuser::cuda::compileKernel(&prog);
torch::jit::fuser::cuda::runKernel(
&prog, {input1, input2}, {cg_output}, c10::nullopt);

at::Tensor aten_output =
matmul(sum(input1, 1).unsqueeze(1), input2.unsqueeze(0));
TORCH_CHECK(
aten_output.allclose(cg_output, 1e-5, 1e-5),
"Error of: ",
aten_output.sub(cg_output).abs().sum());
}

void testGPU_FusionCacheMultiConsumer() {
torch::jit::fuser::cuda::CudaKernel prog;
prog.setFusionPtr(std::make_unique<Fusion>());
Fusion* fusion = prog.fusion();
FusionGuard fg(fusion);

TensorView* tv0 = makeDummyTensor(1);
TensorView* tv1 = add(tv0, new Float(1));
TensorView* tv2 = add(tv1, new Float(2));
TensorView* tv3 = add(tv0, new Float(1));
TensorView* tv4 = add(tv3, new Float(2));

fusion->addInput(tv0);
fusion->addOutput(tv2);
fusion->addOutput(tv4);

tv1->computeAt(tv2, -1);
tv3->computeAt(tv4, -1);

// std::cout << "Before caching\n";
// fusion->printKernel();

// Passes
auto tv5 = tv1->cache_before();
auto tv6 = tv3->cache_before();

// Fails because the cached tensor would have to be recomputed twice
// (tv0 feeds two independently scheduled consumer chains)
// auto tv7 = tv0->cache_after();

// std::cout << "After caching\n";
// fusion->printKernel();

prog.setDevice(0);
torch::jit::fuser::cuda::compileKernel(&prog);
}

} // namespace jit
} // namespace torch

test/cpp/jit/tests.h (+7 −1)
@@ -171,7 +171,13 @@ namespace jit {
_(GPU_FusionZeroDimReduction) \
_(GPU_FusionReductionMultiConsumer) \
_(GPU_FusionBCastAfterReduce) \
_(GPU_FusionReductionScheduler)
_(GPU_FusionReductionScheduler) \
_(GPU_FusionCacheBefore) \
_(GPU_FusionCacheAfter) \
_(GPU_FusionCacheIndirect) \
_(GPU_FusionCacheBcast) \
_(GPU_FusionCacheComplex) \
_(GPU_FusionCacheMultiConsumer)
#else
#define TH_FORALL_TESTS_CUDA(_) \
_(ArgumentSpec) \