Skip to content

Commit a656b6d

Browse files
committed
Format and add comments to GEMM test.
1 parent b69006c commit a656b6d

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

test/cpp/jit/test_gpu.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2674,23 +2674,44 @@ void testGPU_FusionSimpleGemm() {
26742674
fusion.addInput(tv1);
26752675

26762676
TensorView* tv2 = broadcast(tv0, {false, false, true});
2677+
// tv2[I0, I1, B] = tv0[I0, I1]
2678+
26772679
TensorView* tv3 = broadcast(tv1, {true, false, false});
2680+
// tv3[B, I1, I2] = tv1[I1, I2]
26782681

2682+
// tv4[I0, I1, I2] = tv2[I0, I1, B] * tv3[B, I1, I2]
26792683
TensorView* tv4 = mul(tv2, tv3);
2684+
// tv5[I0, R1, I2] = tv4[I0, I1, I2]
26802685
TensorView* tv5 = sum(tv4, {1});
26812686
fusion.addOutput(tv5);
26822687

26832688
tv5->split(1, 32);
2689+
// tv5[I0, R1o, R1i{32}, I2]
2690+
26842691
auto tv6 = tv5->rFactor({1});
2692+
// tv6[I0, R1o, I1i{32}, I2] = tv4[I0, I1, I2]
2693+
// tv5[I0,    , R1i{32}, I2] = tv6[I0, R1o, I1i{32}, I2]
26852694

26862695
tv5->split(0, 4);
26872696
tv5->split(-1, 4);
2697+
// tv5[I0o, I0i{4}, R1i{32}, I2] (after split(0, 4))
2698+
// tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}]
26882699

26892700
tv0->computeAt(tv5, -1);
26902701
tv1->computeAt(tv5, -1);
26912702

2703+
// tv6[I0o, I0i{4}, R1o, I1i{32}, I2o, I2i{4}]
2704+
// tv5[I0o, I0i{4},    , R1i{32}, I2o, I2i{4}]
2705+
//--> (the "|" symbolizes the compute-at location)
2706+
// tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, I1o]
2707+
// tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, R1o]
2708+
// tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|]
2709+
26922710
tv0->computeAt(tv6, -1);
26932711
tv1->computeAt(tv6, -1);
2712+
// tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, I1o |]
2713+
// tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, R1o |]
2714+
// tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|]
26942715

26952716
tv5->axis(0)->parallelize(ParallelType::BIDz);
26962717
tv5->axis(1)->parallelize(ParallelType::TIDz);

torch/csrc/jit/codegen/cuda/iter_visitor.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ void IterVisitor::traverseFrom(
5656
FusionGuard fg(fusion);
5757
std::unordered_set<Statement*> visited;
5858
stmt_stack.clear();
59-
if(!from.empty())
59+
if (!from.empty())
6060
stmt_stack.emplace_back(from.rbegin(), from.rend());
6161

6262
while (!stmt_stack.empty()) {
@@ -191,7 +191,7 @@ std::unordered_set<Val*> IterVisitor::getTerminatingOutputs(
191191
auto exprs = Exprs::getExprs(
192192
fusion,
193193
std::vector<Val*>(fusion->outputs().begin(), fusion->outputs().end()));
194-
194+
195195
for (auto expr : exprs) {
196196
for (auto inp : expr->inputs())
197197
used_vals.emplace(inp);

0 commit comments

Comments
 (0)