
Commit 1cd9451

shmsong and zasdfgbnm authored
Simplify matmul scheduling with the new transform propagator. (#1817)
Co-authored-by: Gao, Xiang <[email protected]>
1 parent bbc1fb9 commit 1cd9451

File tree: 13 files changed, +1136 -1135 lines

build_variables.bzl

Lines changed: 1 addition & 0 deletions
@@ -718,6 +718,7 @@ libtorch_cuda_core_sources = [
     "torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp",
     "torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp",
     "torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp",
+    "torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp",
     "torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp",
     "torch/csrc/jit/codegen/cuda/scheduler/registry.cpp",
     "torch/csrc/jit/codegen/cuda/scheduler/utils.cpp",

torch/csrc/jit/codegen/cuda/codegen.cpp

Lines changed: 5 additions & 5 deletions
@@ -971,13 +971,13 @@ class CudaKernelGenerator : private OptOutConstDispatch {
     }
   }
 
-  std::string genArchString(MmaOptions options) {
+  std::string genArchString(MmaOptions::MacroType macro) {
     std::stringstream ss;
-    if (isVolta(options.macro)) {
+    if (isVolta(macro)) {
       ss << "Volta";
-    } else if (isTuring(options.macro)) {
+    } else if (isTuring(macro)) {
       ss << "Turing";
-    } else if (isAmpere(options.macro)) {
+    } else if (isAmpere(macro)) {
       ss << "Ampere";
     } else {
       TORCH_INTERNAL_ASSERT(false, "mma macro unknown arch");
@@ -988,7 +988,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
   std::string genMmaOp(const MmaOp* mma, bool init = false) {
     std::stringstream ss;
     auto options = mma->options();
-    ss << genArchString(options) << "::";
+    ss << genArchString(options.macro) << "::";
     if (init) {
       ss << "init";
     }
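
Note on the refactor above: genArchString only ever reads the macro field, so it now takes just the MmaOptions::MacroType rather than the whole options struct. A minimal standalone sketch of the resulting call pattern follows; the enum values and helpers here are illustrative stand-ins, not the codebase's definitions:

#include <iostream>
#include <sstream>
#include <string>

// Stand-in for MmaOptions::MacroType; values are illustrative only.
enum class MacroType { NoMMA, Volta_16_16_4, Turing_16_8_16, Ampere_16_8_16 };

bool isVolta(MacroType m) { return m == MacroType::Volta_16_16_4; }
bool isTuring(MacroType m) { return m == MacroType::Turing_16_8_16; }
bool isAmpere(MacroType m) { return m == MacroType::Ampere_16_8_16; }

// Takes only the enum it inspects, mirroring the narrowed signature above.
std::string genArchString(MacroType macro) {
  if (isVolta(macro)) return "Volta";
  if (isTuring(macro)) return "Turing";
  if (isAmpere(macro)) return "Ampere";
  return "Unknown";
}

int main() {
  std::stringstream ss;
  ss << genArchString(MacroType::Ampere_16_8_16) << "::" << "init";
  std::cout << ss.str() << "\n";  // prints "Ampere::init"
}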

torch/csrc/jit/codegen/cuda/ir_internal_nodes.h

Lines changed: 27 additions & 3 deletions
@@ -338,6 +338,22 @@ class TORCH_CUDA_CU_API WelfordOp : public Expr {
 //! Fused Matmul operation
 class TORCH_CUDA_CU_API MmaOp : public Expr {
  public:
+  // This is a temporary data structure for the
+  // scheduling-specific parameters that we still need
+  // to store on an mma node. Eventually only the
+  // mma macro type will stay on the IR node
+  // after additional cleanups.
+  struct OptionsInMma {
+    MmaOptions::MacroType macro = MmaOptions::MacroType::NoMMA;
+    MmaOptions::MmaInputLayout operand_layout = MmaOptions::MmaInputLayout::TT;
+    int accumulator_stride = 0;
+
+    bool operator==(const OptionsInMma& other) const {
+      return macro == other.macro && operand_layout == other.operand_layout &&
+          accumulator_stride == other.accumulator_stride;
+    }
+  };
+
   MmaOp(IrBuilderPasskey, Val* out, Val* in_a, Val* in_b, Val* init);
 
   MmaOp(
@@ -346,7 +362,7 @@ class TORCH_CUDA_CU_API MmaOp : public Expr {
       Val* in_a,
       Val* in_b,
       Val* init,
-      MmaOptions options);
+      OptionsInMma options);
 
   MmaOp(const MmaOp* src, IrCloner* ir_cloner);
 
@@ -379,15 +395,23 @@ class TORCH_CUDA_CU_API MmaOp : public Expr {
   }
 
   void configureOptions(MmaOptions options) {
-    options_ = options;
+    options_ = OptionsInMma();
+    TORCH_INTERNAL_ASSERT(
+        options.macro != MmaOptions::MacroType::NoMMA,
+        "Un-configured mma type from options.");
+    TORCH_INTERNAL_ASSERT(
+        options.accumulator_stride > 0, "Un-configured accumulator stride.");
+    options_->accumulator_stride = options.accumulator_stride;
+    options_->macro = options.macro;
+    options_->operand_layout = options.operand_layout;
   }
 
  private:
   Val* const out_ = nullptr;
   Val* const in_a_ = nullptr;
   Val* const in_b_ = nullptr;
   Val* const init_ = nullptr;
-  c10::optional<MmaOptions> options_ = c10::nullopt;
+  c10::optional<OptionsInMma> options_ = c10::nullopt;
 };
 
 class TORCH_CUDA_CU_API TransposeOp : public Expr {
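
The configureOptions() contract above is: validate the builder-produced options, then copy only the scheduling fields the IR node actually keeps. Here is a minimal self-contained sketch of that pattern; Options, OptionsInNode, and Node are hypothetical stand-ins, not the real IR classes:

#include <cassert>
#include <optional>

// Hypothetical stand-in for MmaOptions::MacroType.
enum class MacroType { NoMMA, Ampere_16_8_16 };
// Hypothetical stand-in for MmaOptions::MmaInputLayout.
enum class InputLayout { TT, TN, NT };

// Wide, builder-facing options (stand-in for MmaOptions).
struct Options {
  MacroType macro = MacroType::NoMMA;
  InputLayout operand_layout = InputLayout::TT;
  int accumulator_stride = 0;
};

// Narrow options that live on the IR node (stand-in for MmaOp::OptionsInMma).
struct OptionsInNode {
  MacroType macro = MacroType::NoMMA;
  InputLayout operand_layout = InputLayout::TT;
  int accumulator_stride = 0;
};

struct Node {
  void configureOptions(Options options) {
    // Reject un-configured options before copying the fields the node keeps.
    assert(options.macro != MacroType::NoMMA && "Un-configured mma type.");
    assert(options.accumulator_stride > 0 && "Un-configured accumulator stride.");
    options_ = OptionsInNode{
        options.macro, options.operand_layout, options.accumulator_stride};
  }
  std::optional<OptionsInNode> options_;
};

int main() {
  Node node;
  node.configureOptions(Options{MacroType::Ampere_16_8_16, InputLayout::TN, 8});
  assert(node.options_.has_value());
}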

torch/csrc/jit/codegen/cuda/ir_nodes.cpp

Lines changed: 1 addition & 1 deletion
@@ -630,7 +630,7 @@ MmaOp::MmaOp(
     Val* in_a,
     Val* in_b,
     Val* init,
-    MmaOptions options)
+    OptionsInMma options)
     : MmaOp(passkey, out, in_a, in_b, init) {
   options_ = options;
 }

torch/csrc/jit/codegen/cuda/mma_type.cpp

Lines changed: 15 additions & 4 deletions
@@ -7,6 +7,16 @@ namespace jit {
 namespace fuser {
 namespace cuda {
 
+MmaOp* MmaOptions::mmaOp() const {
+  TORCH_INTERNAL_ASSERT(
+      accumulator_tv != nullptr && accumulator_tv->definition() != nullptr,
+      "Invalid accumulator_tv.");
+  auto mma_op = dynamic_cast<MmaOp*>(accumulator_tv->definition());
+  TORCH_INTERNAL_ASSERT(
+      mma_op != nullptr, "accumulator tv not an output of mma op");
+  return mma_op;
+}
+
 MmaBuilder::MmaBuilder(
     MmaOptions::MacroType macro,
     MatMulTileOptions gemm_tile) {
@@ -41,7 +51,7 @@ MmaBuilder& MmaBuilder::operand(MmaOptions::Operand a_or_b) {
 // TODO: validate op config
 MmaOptions MmaBuilder::build() const {
   TORCH_CHECK(
-      option_.mma_op != nullptr,
+      option_.accumulator_tv != nullptr,
       "Please configure accumulator tv before using swizzle options.")
   return option_;
 }
@@ -60,9 +70,10 @@ void MmaBuilder::accumulatorTv(TensorView* tv) {
   TORCH_CHECK(
       tv->getMemoryType() == MemoryType::Local, "Mma only outputs to register");
   TORCH_CHECK(tv->definition(), "Input cannot be accumulator tv");
-  auto mma = dynamic_cast<MmaOp*>(tv->definition());
-  TORCH_CHECK(mma, "Requires mma op output for reduction tv");
-  option_.mma_op = mma;
+  TORCH_CHECK(
+      tv->definition()->isA<MmaOp>(),
+      "Requires mma op output for reduction tv");
+  option_.accumulator_tv = tv;
 }
 
 namespace {
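
The change above swaps a raw MmaOp* for the accumulator TensorView as the stored handle: mutation passes can delete and replace expressions, while the value persists, so the op is recovered on demand through its definition. A sketch of that lookup on a hypothetical miniature IR (Val, Expr, MmaExpr are stand-ins, not nvFuser's classes):

#include <cassert>

struct Val;

// Hypothetical stand-in for the IR expression base class.
struct Expr {
  virtual ~Expr() = default;
  Val* out = nullptr;
};

// Hypothetical stand-in for a value node; it records its producer.
struct Val {
  Expr* definition = nullptr;
};

struct MmaExpr : Expr {};

// Recover the op from the stored value, mirroring MmaOptions::mmaOp():
// the value handle stays valid even if passes swap out the expression.
MmaExpr* mmaOpOf(const Val* accumulator) {
  assert(accumulator != nullptr && accumulator->definition != nullptr &&
         "Invalid accumulator.");
  auto* mma = dynamic_cast<MmaExpr*>(accumulator->definition);
  assert(mma != nullptr && "accumulator is not an output of an mma op");
  return mma;
}

int main() {
  Val acc;
  MmaExpr mma;
  mma.out = &acc;
  acc.definition = &mma;  // value -> producer link
  assert(mmaOpOf(&acc) == &mma);
}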

torch/csrc/jit/codegen/cuda/mma_type.h

Lines changed: 16 additions & 2 deletions
@@ -19,6 +19,10 @@ struct GemmTile {
   GemmTile operator/(const GemmTile& other) {
     return GemmTile(m / other.m, n / other.n, k / other.k);
   }
+
+  std::vector<int> toVector() {
+    return {m, n, k};
+  }
 };
 
 //! Utility data structure for recording gemm tiles
@@ -95,8 +99,18 @@ struct MmaOptions {
         accumulator_stride == other.accumulator_stride;
   }
 
-  // To be inferred by mma builder interface.
-  MmaOp* mma_op = nullptr;
+  // The accumulator tensorview register supplied by the
+  // scheduler interface. Each mma builder is responsible
+  // for the parameters of one mma op, so the options struct
+  // would need a pointer to keep track of which mma op it
+  // is describing.
+  // Tracking mma expressions would not be stable as the expression
+  // can get deleted by mutate passes.
+  TensorView* accumulator_tv = nullptr;
+
+  //! Returns the mma op that this options parameter list
+  //! is describing. See comment on accumulator_tv.
+  MmaOp* mmaOp() const;
 };
 
 //! User interface for configuring the mma and mma related
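
The new GemmTile::toVector() helper exposes the tile extents in m, n, k order, which is convenient for looping over the three tiled dimensions. A minimal usage sketch, assuming only the members visible in the diff (the real struct in mma_type.h has more):

#include <iostream>
#include <vector>

// Minimal stand-in mirroring the GemmTile additions above.
struct GemmTile {
  int m, n, k;

  GemmTile(int m, int n, int k) : m(m), n(n), k(k) {}

  GemmTile operator/(const GemmTile& other) {
    return GemmTile(m / other.m, n / other.n, k / other.k);
  }

  // New helper: the tile extents as a vector, in m, n, k order.
  std::vector<int> toVector() {
    return {m, n, k};
  }
};

int main() {
  GemmTile cta_tile(128, 128, 32);
  GemmTile warp_tile(64, 64, 32);
  // How many warp tiles fit in a CTA tile along each of m, n, k.
  for (int factor : (cta_tile / warp_tile).toVector()) {
    std::cout << factor << " ";  // prints: 2 2 1
  }
  std::cout << "\n";
}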
