Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions test/cpp/jit/test_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,17 @@ TensorView* makeDummyTensor(int nDims, DataType dtype = DataType::Float) {
return new TensorView(new TensorDomain(dom), dtype);
}

// Builds a TensorView whose root domain has one IterDomain per entry of
// `sizes`, each with start 0 and a fixed (concrete) extent. This is the
// static-shape counterpart of makeDummyTensor, for tests that need tensors
// of known sizes.
//
// sizes: concrete extent of each dimension, in order.
// dtype: element type of the resulting tensor (defaults to Float).
TensorView* makeConcreteTensor(
    const std::vector<int>& sizes,
    DataType dtype = DataType::Float) {
  std::vector<IterDomain*> dom;
  dom.reserve(sizes.size());
  // Range-for avoids the signed/unsigned comparison of `int i < sizes.size()`.
  for (int size : sizes)
    dom.push_back(new IterDomain(new Int(0), new Int(size)));
  return new TensorView(new TensorDomain(dom), dtype);
}

TensorView* makeTensorWithContig(
int nDims,
std::vector<bool> contig_info,
Expand Down Expand Up @@ -3038,6 +3049,53 @@ void testGPU_FusionSimpleBCast() {
#endif
}

// Exercises chained broadcasts: a 1-D tensor is broadcast to 2-D and
// multiplied, then that product is broadcast to 3-D and added. This checks
// the fusion executor on sequential/nested broadcasts rather than the single
// broadcast covered by testGPU_FusionSimpleBCast.
void testGPU_FusionComplexBCast() {
Fusion fusion;
FusionGuard fg(&fusion);

// Concrete extents for the three axes of the final result.
int x = 2, y = 3, z = 4;

auto tv0 = makeConcreteTensor({y});
// In the broadcast flag vectors below, `true` positions appear as the `b*`
// (broadcast) axes in the shape diagram that follows.
auto tv1 = broadcast(tv0, {false, true});
auto tv2 = makeConcreteTensor({y, z});
auto tv3 = mul(tv1, tv2);
auto tv4 = broadcast(tv3, {true, false, false});
auto tv5 = makeConcreteTensor({x, y, z});
auto tv6 = add(tv4, tv5);

// tv0[ i1 ]
// tv1[ i1, b2]
// tv2[ i1, i2]
// tv3[ i1, i2]
// tv4[b0, i1, i2]
// tv5[i0, i1, i2]
// tv6[i0, i1, i2]

// tv3 = bcast(tv0) * tv2
// tv6 = bcast(tv3) + tv5

// Only the leaf tensors are fusion inputs; tv1/tv3/tv4 are intermediates
// produced inside the fusion. Input order here must match the argument
// order passed to runFusion below.
fusion.addInput(tv0);
fusion.addInput(tv2);
fusion.addInput(tv5);

fusion.addOutput(tv6);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

at::Tensor t0 = at::randn({y}, options);
at::Tensor t2 = at::randn({y, z}, options);
at::Tensor t5 = at::randn({x, y, z}, options);

// Eager-mode reference: unsqueeze+expand mirrors each broadcast op above.
auto t3 = t0.unsqueeze(-1).expand({y, z}) * t2;
auto t6 = t3.unsqueeze(0).expand({x, y, z}) + t5;

torch::jit::fuser::cuda::FusionExecutor fe;
fe.compileFusion(&fusion);
auto outputs = fe.runFusion({t0, t2, t5});

// Fused kernel output must match the eager reference.
TORCH_CHECK(t6.allclose(outputs[0]));
}

// Test a simple Gemm but also play around with fusion executor features
void testGPU_FusionSimpleGemm() {
Fusion fusion;
Expand Down
1 change: 1 addition & 0 deletions test/cpp/jit/tests.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ namespace jit {
_(GPU_FusionReduction5) \
_(GPU_FusionReductionTFT) \
_(GPU_FusionSimpleBCast) \
_(GPU_FusionComplexBCast) \
_(GPU_FusionSimpleGemm) \
_(GPU_FusionSoftmax1D) \
_(GPU_FusionSoftmax1DNormalized) \
Expand Down
Loading