
Commit 9bb4cf7

fragment iteration to support fully unrolled mma ops (#1823)
1 parent a48270a commit 9bb4cf7

File tree

6 files changed: +645 -48 lines

torch/csrc/jit/codegen/cuda/codegen.cpp

Lines changed: 18 additions & 6 deletions
@@ -474,7 +474,11 @@ class CudaKernelGenerator : private OptOutConstDispatch {
     }
   }
 
-  void handle(const kir::TensorIndex* ti) final {
+  //! Returns the sum of all indices in a TensorIndex,
+  //!  or 0 if the indices vector is empty.
+  //! Used when lowering generic tensor indices and
+  //!  mma fragment indices.
+  std::string genTensorIndex(const kir::TensorIndex* ti) {
     bool first = true;
     std::stringstream index;
     for (auto* ind : ti->indices()) {
@@ -490,12 +494,17 @@ class CudaKernelGenerator : private OptOutConstDispatch {
     if (first) {
       index << "0";
     }
+
+    return index.str();
+  }
+
+  void handle(const kir::TensorIndex* ti) final {
     bool is_volatile = ti->view()->getMemoryType() == MemoryType::Global &&
         kernel_->summary().sync_map.needsRawSync(ti->view()).hasBID();
     if (is_volatile) {
       code_ << "*(volatile " << ti->getDataType().value() << "*)&";
     }
-    code_ << varName(ti->view()) << "[" << index.str() << "]";
+    code_ << varName(ti->view()) << "[" << genTensorIndex(ti) << "]";
   }
 
   void handle(const ViewAsScalar* sv) final {
@@ -1013,14 +1022,17 @@ class CudaKernelGenerator : private OptOutConstDispatch {
     auto options = mma->options();
     auto in_a = mma->inA()->as<kir::TensorIndex>()->view();
     auto dtype = in_a->getDataType().value();
-    indent() << kTab << "reinterpret_cast<Array<" << dtype << ","
+    indent() << kTab << "&(reinterpret_cast<Array<" << dtype << ","
              << getInputARegisterSize(options.macro) << ","
             << getInputARegisterSize(options.macro) << ">*>(&"
-             << gen(mma->inA()) << "),\n";
-    indent() << kTab << "reinterpret_cast<Array<" << dtype << ","
+             << varName(mma->inA()->as<kir::TensorIndex>()->view()) << ")["
+             << genTensorIndex(mma->inA()->as<kir::TensorIndex>()) << "])"
+             << ",\n";
+    indent() << kTab << "&(reinterpret_cast<Array<" << dtype << ","
              << getInputBRegisterSize(options.macro) << ","
              << getInputBRegisterSize(options.macro) << ">*>(&"
-             << gen(mma->inB()) << ")";
+             << varName(mma->inB()->as<kir::TensorIndex>()->view()) << ")["
+             << genTensorIndex(mma->inB()->as<kir::TensorIndex>()) << "])";
   }
 
   void genMmaInitialization(const MmaOp* mma, const UnaryOp* uop) {
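For reference, each operand argument emitted by the new path has roughly this shape (an illustrative fragment of generated code; the tensor name T7, the __half data type, the register size 4, and the loop index ki are hypothetical placeholders):

  &(reinterpret_cast<Array<__half, 4, 4>*>(&T7)[ki])

The previous path wrapped the element-wise index from gen(mma->inA()) inside the cast; the new form casts the register buffer itself and indexes it in whole-fragment units, which is what lets a loop over fragments select the operand for each mma issue.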

torch/csrc/jit/codegen/cuda/index_compute.cpp

Lines changed: 10 additions & 1 deletion
@@ -1125,7 +1125,16 @@ indexMapFromTV(
         // Similarly for local memory tensors, zero replacement can be
         //  only done when there's a matching domain with the same
         //  parallel type
-        (loop->iter_domain()->isThread() && is_local && same_parallel_type)) {
+        (loop->iter_domain()->isThread() && is_local && same_parallel_type) ||
+        // MMA operands are currently indexed in units of "fragments",
+        //  so each mma tensor domain would be zero-ed and the tensor index
+        //  calculated here would be the fragment index.
+        // TODO: This is a quick WAR to enable iterating over a register array
+        //  of MMA fragments, so we could generate unrolled mma loops.
+        //  Eventually we still want IdGraph to be able to analyze the
+        //  in-register layout of mma fragments for more unified indexing math
+        //  as well as more flexibility in swizzling loops.
+        (loop->iter_domain()->isMma() && !as_consumer)) {
       idx = GpuLower::current()->kernel()->zeroVal();
       zero_loops.insert(loop);
     } else {
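Together with the codegen change above, this is what lets the main loop be emitted as a fully unrolled iteration over a register array of fragments. A purely illustrative sketch of the generated-code shape this enables (not actual emitted code; tensor names, the fragment count, and the mma wrapper name are hypothetical):

  // hypothetical unrolled mma main loop, ki being the fragment index
  #pragma unroll
  for (int ki = 0; ki < 4; ++ki) {
    hypothetical_mma(
        reinterpret_cast<Array<float, 4, 4>*>(&T9),
        &(reinterpret_cast<Array<__half, 4, 4>*>(&T7)[ki]),
        &(reinterpret_cast<Array<__half, 4, 4>*>(&T8)[ki]));
  }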

torch/csrc/jit/codegen/cuda/lower_validation.cpp

Lines changed: 24 additions & 4 deletions
@@ -899,22 +899,42 @@ void validateMmaTensors(MmaOp* mma) {
   }
 
   // Note: this check will be relaxed in a follow up.
-  auto validate_operand_ids = [](const TensorView* tv) {
+  auto validate_operand = [](const TensorView* tv) {
+    TORCH_INTERNAL_ASSERT(
+        tv->getMemoryType() == MemoryType::Local,
+        "Only supporting register input for mma ops, up to sm80 all mma ops have to take register inputs.");
+
     TORCH_INTERNAL_ASSERT(
         std::all_of(
             tv->domain()->domain().begin() + tv->getComputeAtPosition(),
             tv->domain()->domain().end(),
             [](IterDomain* id) {
               return id->isMmaSwizzled() ||
-                  (id->isBroadcast() &&
+                  // MMA instructions can only take inputs from registers,
+                  //  so we always assume mma op inputs are located on
+                  //  registers.
+                  // Currently requiring that serial ids on the right of the
+                  //  CA axis are constant sized to ensure early detection of
+                  //  invalid mma schedules.
+                  ((id->isBroadcast() || id->extent()->isConstInt()) &&
                       id->getParallelType() == ParallelType::Serial);
             }),
         "All id's on the right of CA pos needs to be mma-swizzled by WarpMmaSwizzler\n",
         tv);
   };
 
-  validate_operand_ids(mma->inA()->as<TensorView>());
-  validate_operand_ids(mma->inB()->as<TensorView>());
+  validate_operand(mma->inA()->as<TensorView>());
+  validate_operand(mma->inB()->as<TensorView>());
+
+  // Additionally validate that mma is not directly taking a double buffered
+  //  register input as the double buffer indexing is currently not compatible
+  //  with fragment iteration. Would need to require a cache stage in this case.
+  TORCH_INTERNAL_ASSERT(
+      !mma->inA()->as<TensorView>()->isDoubleBuffered(),
+      "MMA op cannot directly take double buffered register input, put a set stage before.");
+  TORCH_INTERNAL_ASSERT(
+      !mma->inB()->as<TensorView>()->isDoubleBuffered(),
+      "MMA op cannot directly take double buffered register input, put a set stage before.");
 }
 
 //! Note and TODO:
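The two new asserts push the fix to the scheduler: when a double buffered register tensor feeds the mma op, an extra register copy (a "set" stage) has to sit in between, which is what scheduleMatmul below does when smem-read double buffering is enabled. A minimal sketch using the acr/bcr and ab/bb names from the matmul scheduler's naming convention:

  // acr/bcr: double buffered registers loaded from shared memory.
  // ab/bb: the tensors actually consumed by the mma op, produced by an
  //  extra cacheAfter() copy so the mma op never reads the double buffered
  //  tensors directly.
  TensorView* ab = acr->cacheAfter();
  TensorView* bb = bcr->cacheAfter();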

torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp

Lines changed: 72 additions & 26 deletions
@@ -46,8 +46,37 @@ void scheduleMatmul(
     TensorView* c,
     TensorView* a,
     TensorView* b,
-    MmaBuilder& mma_builder,
-    MatMulTileOptions& gemm_tile) {
+    MatmulParam& params) {
+  // Unpack from params.
+  auto& mma_builder = params.mma_builder;
+  auto& gemm_tile = params.tile_sizes;
+
+  // Including current tensor naming convention for reference,
+  //  this is very temporary and will change over time and
+  //  in fact the whole body of this function will
+  //  eventually be a set of utility functions for different
+  //  sections of matmul(fusion) kernels, with
+  //  each having its own build out to do.
+  //
+  // Current naming convention:
+  //
+  //  operands assumed in global memory : a, b
+  //
+  //  registers staging global load : ar, br (short for a/b read)
+  //
+  //  shared mem cache of operands : acw_smem, bcw_smem (short for a/b
+  //   cache_write smem)
+  //
+  //  registers at shared memory load output : acr, bcr (short for a/b cache
+  //   read)
+  //
+  //  register tensor input to the actual mma op: ab, bb (short for a/b
+  //   broadcasted)
+  //
+  //  accumulator register: cc (short for c cache)
+  //
+  //  result in global memory: c
+
   // Currently only support a, b, c as fusion inputs/outputs
   //  aka. no prolog and epilog fusion yet.
   TORCH_CHECK(
@@ -112,6 +141,17 @@ void scheduleMatmul(
 
     acr = acw_smem->cacheAfter();
     bcr = bcw_smem->cacheAfter();
+    if (params.double_buffer_options.double_buffer_smem_read) {
+      // Provide another copy op between the double buffered
+      //  smem load register and the actual mma ops to avoid
+      //  complication in double buffered fragment iteration.
+      ab = acr->cacheAfter();
+      bb = bcr->cacheAfter();
+    } else {
+      ab = acr;
+      bb = bcr;
+    }
+
   } else {
     acw_smem = ar->cacheAfter();
     bcw_smem = br->cacheAfter();
@@ -182,8 +222,8 @@ void scheduleMatmul(
   b->computeAt(cc, 3);
 
   // Main Loop:
-  acr->computeAt(cc, -4);
-  bcr->computeAt(cc, -4);
+  acr->computeAt(cc, -6);
+  bcr->computeAt(cc, -6);
 
   // Add mma swizzle:
   //   TODO: this section goes to a separate matmul util,
@@ -192,30 +232,26 @@ void scheduleMatmul(
   if (isTuring(mma_options.macro) || isAmpere(mma_options.macro)) {
     moveInnerBroadcastLeft(ab);
     moveInnerBroadcastLeft(bb);
-    ab->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());
-    bb->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());
-
-    // Propagate mma input swizzle up the DAG
-    //  to all the tensors before mma op and after shared mem read.
-    scheduler_utils::BoundedDirectionalTransformPropagator::backward(
-        ab,
-        -1,
-        {acw_smem},
-        scheduler_utils::BoundedDirectionalTransformPropagator::Options()
-            .propagateParallelType());
-    scheduler_utils::BoundedDirectionalTransformPropagator::backward(
-        bb,
-        -1,
-        {bcw_smem},
-        scheduler_utils::BoundedDirectionalTransformPropagator::Options()
-            .propagateParallelType());
-  } else {
-    // TODO:
-    //  Need to build out this to support balanced prolog fusion on Volta.
-    acr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());
-    bcr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());
   }
 
+  ab->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());
+  bb->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());
+
+  // Propagate mma input swizzle up the DAG
+  //  to all the tensors before mma op and after shared mem read.
+  scheduler_utils::BoundedDirectionalTransformPropagator::backward(
+      ab,
+      -1,
+      {acw_smem},
+      scheduler_utils::BoundedDirectionalTransformPropagator::Options()
+          .propagateParallelType());
+  scheduler_utils::BoundedDirectionalTransformPropagator::backward(
+      bb,
+      -1,
+      {bcw_smem},
+      scheduler_utils::BoundedDirectionalTransformPropagator::Options()
+          .propagateParallelType());
+
   cc->applyMmaSwizzle(
       mma_builder.operand(MmaOptions::Operand::Accumulator).build());
 
@@ -243,6 +279,16 @@ void scheduleMatmul(
   cc->axis(4)->parallelize(ParallelType::TIDy);
 
   // Propagate mma output swizzle and parallelization down the DAG
+  if (params.double_buffer_options.double_buffer_smem_write) {
+    acw_smem->doubleBuffer();
+    bcw_smem->doubleBuffer();
+  }
+
+  if (params.double_buffer_options.double_buffer_smem_read) {
+    acr->doubleBuffer();
+    bcr->doubleBuffer();
+  }
+
   scheduler_utils::BoundedDirectionalTransformPropagator::forward(
       cc,
       -1,

torch/csrc/jit/codegen/cuda/scheduler/matmul.h

Lines changed: 22 additions & 2 deletions
@@ -10,6 +10,27 @@ namespace jit {
 namespace fuser {
 namespace cuda {
 
+//! Starting point for matmul scheduler parameters:
+class MatmulParam {
+ public:
+  MatmulParam(MmaBuilder builder) : mma_builder(builder) {}
+
+  struct DoubleBufferOptions {
+    bool double_buffer_smem_write = false;
+    bool double_buffer_smem_read = false;
+  };
+
+  //! Specifies the tiling hierarchy on block,
+  //!  warp, and instruction levels.
+  MatMulTileOptions tile_sizes;
+
+  //! Parameters for configuring mma ops.
+  MmaBuilder mma_builder;
+
+  //! Specifies which tensors we double buffer.
+  DoubleBufferOptions double_buffer_options;
+};
+
 //! Prototype auto scheduling function.
 //! Currently only support a pure matmul with no
 //!  fused prolog or epilog.
@@ -22,8 +43,7 @@ TORCH_CUDA_CU_API void scheduleMatmul(
     TensorView* c_tv,
     TensorView* a_tv,
     TensorView* b_tv,
-    MmaBuilder& mma_builder,
-    MatMulTileOptions& gemm_tile);
+    MatmulParam& params);
 
 } // namespace cuda
 } // namespace fuser
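A minimal caller-side sketch of the new entry point, assuming an MmaBuilder and tile sizes configured the way the existing matmul tests do (the specific macro, layout, and tile values below are illustrative and not part of this commit; c_tv, a_tv, b_tv are the fusion output and inputs):

  // Hypothetical configuration values; only the MatmulParam plumbing is new.
  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(128, 128, 32);
  gemm_tile.warp_tile = GemmTile(64, 64, 32);
  gemm_tile.instruction_tile = GemmTile(16, 8, 16);

  MatmulParam params(
      MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile)
          .layout(MmaOptions::MmaInputLayout::TN));
  params.tile_sizes = gemm_tile;
  params.double_buffer_options.double_buffer_smem_write = true;
  params.double_buffer_options.double_buffer_smem_read = true;

  scheduleMatmul(c_tv, a_tv, b_tv, params);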
