
Commit 7093e39

Mma op integration on ampere (#1440)
1 parent fade8da commit 7093e39

37 files changed, +2503 -103 lines changed

tools/build_variables.bzl

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ libtorch_nvfuser_runtime_sources = [
     "torch/csrc/jit/codegen/cuda/runtime/helpers.cu",
     "torch/csrc/jit/codegen/cuda/runtime/index_utils.cu",
     "torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu",
+    "torch/csrc/jit/codegen/cuda/runtime/memory.cu",
     "torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu",
     "torch/csrc/jit/codegen/cuda/runtime/tensor.cu",
     "torch/csrc/jit/codegen/cuda/runtime/tuple.cu",

torch/csrc/jit/codegen/cuda/codegen.cpp

Lines changed: 80 additions & 2 deletions
@@ -477,6 +477,33 @@ class CudaKernelGenerator : private OptOutConstDispatch {
     TORCH_INTERNAL_ASSERT(false, "Unreachable");
   }
 
+  //! Utility for generating vectorized pointer access in ldsm and
+  //! cpasync.
+  //! TODO: this access pattern could be merged with the existing
+  //! vectorization handling logic, but this path will be updated in
+  //! follow-ups to optimize the generated assembly, so it is kept as a
+  //! separate path for now.
+  std::string genVectorPointer(Val* val, DataType dtype, int vec_size) {
+    std::stringstream ss;
+
+    ss << "reinterpret_cast<Array<" << dtype << "," << vec_size << ","
+       << vec_size << ">*>(&" << gen(val) << ")";
+
+    return ss.str();
+  }
+
+  void genLdMatrix(const LoadStoreOp* ldst, int vector_word_size) {
+    auto dtype = ldst->in()->getDataType().value();
+    indent() << "Turing::ldMatrix";
+    if (ldst->opType() == LoadStoreOpType::LdMatrixTranspose) {
+      code_ << "T";
+    }
+    code_ << " (";
+    code_ << "*" << genVectorPointer(ldst->out(), dtype, vector_word_size)
+          << ","
+          << "&" << gen(ldst->in()) << ");\n";
+  }
+
   void handle(const UnaryOp* uop) final {
     bool is_vector_op = false;
     size_t vector_word_size = 1;
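For orientation, a hedged sketch of the kernel code this path emits for an 8-wide fp16 ldmatrix load; the tensor names T1/T2 and the indexing are hypothetical placeholders, and the exact element-type spelling depends on DataType's stream operator:

    // Hypothetical genLdMatrix output (LoadStoreOpType::LdMatrix, vec size 8):
    Turing::ldMatrix (*reinterpret_cast<Array<__half,8,8>*>(&T2[0]),&T1[smem_offset]);
    // The LdMatrixTranspose variant would emit Turing::ldMatrixT instead.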
@@ -918,7 +945,15 @@ class CudaKernelGenerator : private OptOutConstDispatch {
     if (init) {
       ss << "init";
     }
-    ss << toString(options.macro) << toString(options.operand_layout);
+    ss << toString(options.macro);
+
+    if (isVolta(options.macro)) {
+      ss << toString(options.operand_layout);
+    } else if (isTuring(options.macro) || isAmpere(options.macro)) {
+      // mma's on Turing and Ampere are TN only; transpose is handled either
+      // via ldmatrix for fp16 or explicitly for other types.
+      ss << "TN";
+    }
     // TODO: additional parameter could be removed by swizzling iterdomain
     auto acc_stride = mma->accStride();
     TORCH_INTERNAL_ASSERT(acc_stride > 0);
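This branch decides the suffix of the runtime mma function name. Below is a minimal standalone sketch of the resulting naming scheme; the Macro/Layout enums and the toString spellings are illustrative stand-ins, not nvfuser's actual definitions:

    #include <iostream>
    #include <sstream>
    #include <string>

    // Illustrative stand-ins for nvfuser's MmaOptions types.
    enum class Macro { Volta_16_16_4, Turing_16_8_16, Ampere_16_8_16 };
    enum class Layout { TT, TN, NT };

    bool isVolta(Macro m) { return m == Macro::Volta_16_16_4; }
    bool isTuring(Macro m) { return m == Macro::Turing_16_8_16; }
    bool isAmpere(Macro m) { return m == Macro::Ampere_16_8_16; }

    std::string toString(Macro m) {
      switch (m) {
        case Macro::Volta_16_16_4: return "M16N16K4";
        case Macro::Turing_16_8_16: return "M16N8K16";
        case Macro::Ampere_16_8_16: return "M16N8K16";
      }
      return "";
    }

    std::string toString(Layout l) {
      return l == Layout::TT ? "TT" : l == Layout::TN ? "TN" : "NT";
    }

    // Mirrors the branch above: Volta encodes the operand layout in the
    // runtime function name; Turing/Ampere hardware mma is TN only.
    std::string mmaFuncName(Macro macro, Layout layout, bool init) {
      std::stringstream ss;
      if (init) ss << "init";
      ss << toString(macro);
      if (isVolta(macro)) {
        ss << toString(layout);
      } else if (isTuring(macro) || isAmpere(macro)) {
        ss << "TN";
      }
      return ss.str();
    }

    int main() {
      std::cout << mmaFuncName(Macro::Volta_16_16_4, Layout::TT, false) << "\n";  // M16N16K4TT
      std::cout << mmaFuncName(Macro::Ampere_16_8_16, Layout::NT, true) << "\n";  // initM16N8K16TN
    }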
@@ -1123,6 +1158,49 @@ class CudaKernelGenerator : private OptOutConstDispatch {
     }
   }
 
+  void handle(const LoadStoreOp* ldst) {
+    // TODO: Need to gradually merge this code path with UnaryOp::Set
+    // for vectorization. There is quite a bit of possible clean-up.
+    bool vectorize_op = false;
+    size_t vector_word_size = 1;
+    auto ti = ldst->out()->as<kir::TensorIndex>();
+
+    // Check vectorization and set vector word size
+    for (auto id : ti->view()->domain()->domain()) {
+      if (!isParallelTypeVectorize(id->getParallelType())) {
+        continue;
+      }
+
+      ExpressionEvaluator expr_eval(id->fusion());
+      auto vector_size_optional = expr_eval.evaluate(id->extent());
+
+      TORCH_INTERNAL_ASSERT(
+          vector_size_optional.has_value(),
+          "Could not evaluate constant value bound to vectorized dim.");
+
+      TORCH_INTERNAL_ASSERT(
+          id->getParallelType() != ParallelType::MisalignedVectorize,
+          "LoadStoreOp: no support yet for mis-aligned vectorization");
+      vector_word_size = vector_size_optional.value();
+      vectorize_op = true;
+      break;
+    }
+
+    // Dispatch instruction generation:
+    switch (ldst->opType()) {
+      case LoadStoreOpType::LdMatrix:
+      case LoadStoreOpType::LdMatrixTranspose:
+        TORCH_INTERNAL_ASSERT(
+            vectorize_op, "LdMatrix: Vectorization required: ", ldst);
+        genLdMatrix(ldst, vector_word_size);
+        break;
+      default:
+        TORCH_INTERNAL_ASSERT(false, "LoadStoreOp: Unknown op type");
+    }
+  }
+
   void handle(const WelfordOp* wop) final {
     TORCH_INTERNAL_ASSERT(wop->out()->isA<kir::TensorIndex>());

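The Vectorize parallel type this loop searches for is set at scheduling time. A hedged sketch of scheduling code that would satisfy the LdMatrix vectorization requirement; the tensor names and the split factor are hypothetical:

    // Stage the operand through a cache loaded by ldmatrix, then mark the
    // innermost 8 elements as vectorized so vector_word_size resolves to 8.
    auto tv_ld = tv_smem->cacheAfter(LoadStoreOpType::LdMatrix);
    tv_ld->split(-1, 8);
    tv_ld->axis(-1)->parallelize(ParallelType::Vectorize);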
@@ -2033,7 +2111,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
     }
   }
 
-  void handle(const kir::BlockSync*) final {
+  void handle(const kir::BlockSync* sync) final {
     // Use a custom synchronization method if enabled
     if (std::getenv("PYTORCH_NVFUSER_USE_BLOCK_SYNC_ATOMIC")) {
       indent() << "block_sync::sync();\n";

torch/csrc/jit/codegen/cuda/dispatch.cpp

Lines changed: 15 additions & 0 deletions
@@ -110,6 +110,9 @@ void Expr::dispatch(T handler, Expr* expr) {
     case ExprType::WelfordOp:
       ptr(handler)->handle(expr->as<WelfordOp>());
       return;
+    case ExprType::LoadStoreOp:
+      ptr(handler)->handle(expr->as<LoadStoreOp>());
+      return;
     case ExprType::MmaOp:
       ptr(handler)->handle(expr->as<MmaOp>());
       return;
@@ -260,6 +263,9 @@ void Expr::constDispatch(T handler, const Expr* expr) {
     case ExprType::WelfordOp:
       ptr(handler)->handle(expr->as<WelfordOp>());
       return;
+    case ExprType::LoadStoreOp:
+      ptr(handler)->handle(expr->as<LoadStoreOp>());
+      return;
     case ExprType::MmaOp:
       ptr(handler)->handle(expr->as<MmaOp>());
       return;
@@ -418,6 +424,9 @@ void Expr::mutatorDispatch(T mutator, Expr* expr) {
     case ExprType::WelfordOp:
       ptr(mutator)->mutate(expr->as<WelfordOp>());
       return;
+    case ExprType::LoadStoreOp:
+      ptr(mutator)->mutate(expr->as<LoadStoreOp>());
+      return;
     case ExprType::MmaOp:
       ptr(mutator)->mutate(expr->as<MmaOp>());
       return;
@@ -641,6 +650,9 @@ void OptOutConstDispatch::handle(const GroupedReductionOp* stmt) {
 void OptOutConstDispatch::handle(const WelfordOp* stmt) {
   unhandled(stmt);
 }
+void OptOutConstDispatch::handle(const LoadStoreOp* stmt) {
+  unhandled(stmt);
+}
 void OptOutConstDispatch::handle(const MmaOp* stmt) {
   unhandled(stmt);
 }
@@ -761,6 +773,9 @@ void OptOutDispatch::handle(GroupedReductionOp* stmt) {
 void OptOutDispatch::handle(WelfordOp* stmt) {
   unhandled(stmt);
 }
+void OptOutDispatch::handle(LoadStoreOp* stmt) {
+  unhandled(stmt);
+}
 void OptOutDispatch::handle(MmaOp* stmt) {
   unhandled(stmt);
 }
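These tables implement a hand-rolled visitor: Expr::dispatch switches on the runtime ExprType tag and forwards the downcast node, so every new IR node type must be registered in each dispatch function, with unhandled() as the default. A minimal standalone sketch of the pattern (types are illustrative, not nvfuser's):

    #include <iostream>

    struct LoadStoreOp;
    struct MmaOp;

    // Illustrative visitor base: unhandled() is the default for every node
    // type, mirroring OptOutDispatch.
    struct Visitor {
      virtual void unhandled(const void*) {}
      virtual void handle(const LoadStoreOp* op) { unhandled(op); }
      virtual void handle(const MmaOp* op) { unhandled(op); }
      virtual ~Visitor() = default;
    };

    enum class ExprType { LoadStoreOp, MmaOp };

    struct Expr {
      ExprType type;
      // Mirrors Expr::dispatch: switch on the tag, downcast, forward.
      void dispatch(Visitor& v);
    };

    struct LoadStoreOp : Expr {};
    struct MmaOp : Expr {};

    void Expr::dispatch(Visitor& v) {
      switch (type) {
        case ExprType::LoadStoreOp:
          v.handle(static_cast<const LoadStoreOp*>(this));
          return;
        case ExprType::MmaOp:
          v.handle(static_cast<const MmaOp*>(this));
          return;
      }
    }

    struct Printer : Visitor {
      void handle(const LoadStoreOp*) override { std::cout << "LoadStoreOp\n"; }
    };

    int main() {
      LoadStoreOp op;
      op.type = ExprType::LoadStoreOp;
      Printer p;
      op.dispatch(p);  // prints "LoadStoreOp"
    }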

torch/csrc/jit/codegen/cuda/dispatch.h

Lines changed: 4 additions & 0 deletions
@@ -74,6 +74,7 @@ class TernaryOp;
 class ReductionOp;
 class GroupedReductionOp;
 class WelfordOp;
+class LoadStoreOp;
 class MmaOp;
 class BroadcastOp;
 class TransposeOp;
@@ -136,6 +137,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
   virtual void handle(const ReductionOp* stmt);
   virtual void handle(const GroupedReductionOp* stmt);
   virtual void handle(const WelfordOp* stmt);
+  virtual void handle(const LoadStoreOp* stmt);
   virtual void handle(const MmaOp* stmt);
   virtual void handle(const BroadcastOp* stmt);
@@ -191,6 +193,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
   virtual void handle(ReductionOp* stmt);
   virtual void handle(GroupedReductionOp* stmt);
   virtual void handle(WelfordOp* stmt);
+  virtual void handle(LoadStoreOp* stmt);
   virtual void handle(MmaOp* stmt);
   virtual void handle(BroadcastOp* stmt);
@@ -287,6 +290,7 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
   virtual void mutate(ReductionOp*);
   virtual void mutate(GroupedReductionOp*);
   virtual void mutate(WelfordOp*);
+  virtual void mutate(LoadStoreOp*);
   virtual void mutate(MmaOp*);
   virtual void mutate(BroadcastOp*);

torch/csrc/jit/codegen/cuda/executor_utils.cpp

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,7 @@
 #include <nvfuser_resources/grid_sync.h>
 #include <nvfuser_resources/helpers.h>
 #include <nvfuser_resources/index_utils.h>
+#include <nvfuser_resources/memory.h>
 #include <nvfuser_resources/random_numbers.h>
 #include <nvfuser_resources/tensor.h>
 #include <nvfuser_resources/tensorcore.h>
@@ -98,6 +99,7 @@ std::string kernelPreamble() {
   ss << nvfuser_resources::welford_cu;
   ss << nvfuser_resources::warp_cu;
   ss << nvfuser_resources::tensorcore_cu;
+  ss << nvfuser_resources::memory_cu;
   ss << nvfuser_resources::fused_reduction_cu;
 
   // Random utilities

torch/csrc/jit/codegen/cuda/ir_builder.cpp

Lines changed: 1 addition & 0 deletions
@@ -64,6 +64,7 @@ IR_BUILDER_INSTANTIATE(TernaryOp)
 IR_BUILDER_INSTANTIATE(ReductionOp)
 IR_BUILDER_INSTANTIATE(GroupedReductionOp)
 IR_BUILDER_INSTANTIATE(WelfordOp)
+IR_BUILDER_INSTANTIATE(LoadStoreOp)
 IR_BUILDER_INSTANTIATE(MmaOp)
 IR_BUILDER_INSTANTIATE(BroadcastOp)

torch/csrc/jit/codegen/cuda/ir_cloner.cpp

Lines changed: 4 additions & 0 deletions
@@ -116,6 +116,10 @@ void IrCloner::handle(const WelfordOp* op) {
   clone_ = IrBuilder::clone(op, this);
 }
 
+void IrCloner::handle(const LoadStoreOp* op) {
+  clone_ = IrBuilder::clone(op, this);
+}
+
 void IrCloner::handle(const MmaOp* op) {
   clone_ = IrBuilder::clone(op, this);
 }

torch/csrc/jit/codegen/cuda/ir_cloner.h

Lines changed: 1 addition & 0 deletions
@@ -75,6 +75,7 @@ class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch {
   void handle(const ReductionOp*) override;
   void handle(const GroupedReductionOp*) override;
   void handle(const WelfordOp*) override;
+  void handle(const LoadStoreOp*) override;
   void handle(const MmaOp*) override;
   void handle(const TransposeOp*) override;
   void handle(const ShiftOp*) override;

torch/csrc/jit/codegen/cuda/ir_interface_nodes.h

Lines changed: 15 additions & 18 deletions
@@ -402,14 +402,22 @@ class TORCH_CUDA_CU_API TensorView : public Val {
       const std::vector<int>& axes,
       const std::vector<TensorView*>& tvs);
 
-  // Create a TensorView before the original tensor. A common use case is to
-  // write results into shared memory or registers before moving to global
-  // memory. Analogous to TVM Cache_Write
-  TensorView* cacheBefore();
+  //! Create a TensorView before the original tensor. A common use case is to
+  //! write results into shared memory or registers before moving to global
+  //! memory. Analogous to TVM Cache_Write
+  //!
+  //! @param cache_op: memory operator to use for the inserted op between
+  //!   the data tensor and the cache tensor
+  TensorView* cacheBefore(
+      c10::optional<LoadStoreOpType> cache_op = c10::nullopt);
 
-  // Create a TensorView after the original tensor. A common use case is to
-  // read tensor into shared memory or registers. Analogous to TVM Cache_Read
-  TensorView* cacheAfter();
+  //! Create a TensorView after the original tensor. A common use case is to
+  //! read the tensor into shared memory or registers. Analogous to TVM
+  //! Cache_Read
+  //!
+  //! @param cache_op: memory operator to use for the inserted op between
+  //!   the data tensor and the cache tensor
+  TensorView* cacheAfter(
+      c10::optional<LoadStoreOpType> cache_op = c10::nullopt);
 
   // For a fusion output with other uses, we want to avoid writing to global
   // memory and then reading the output again. We write to global memory
@@ -438,17 +446,6 @@ class TORCH_CUDA_CU_API TensorView : public Val {
     return is_double_buffered_;
   }
 
-  //! Fill in mma options in scheduling time.
-  //! Each mma op in Fusion IR must be configured once before lowering.
-  //! Mma options are configuration parameters used in lowering to mma
-  //! instrinsics, mainly the type of mma macro to use and input data layout
-  //! etc.
-  //!
-  //! TODO: This step will very likely be removed in a follow up PR. All of
-  //! the options configured here could actually be inferred from fusion IR
-  //! once we are feature complete.
-  void configureMma(MmaOptions options);
-
   //! Transforms the innermost iterdomains according to the given mma swizzle,
   //! this should be used on the tvs that are either inputs/outputs of an
   //! MmaOp, or any tv's that are involved in prolog/epilog fusions and need to
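A hedged usage sketch of the new parameter from hypothetical scheduling code; passing a LoadStoreOpType makes the inserted cache op a LoadStoreOp rather than the default set operation (tensor names are placeholders):

    // Read operand a through shared memory, then into registers via ldmatrix.
    TensorView* a_smem = a->cacheAfter();
    a_smem->setMemoryType(MemoryType::Shared);
    TensorView* a_reg = a_smem->cacheAfter(LoadStoreOpType::LdMatrix);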

torch/csrc/jit/codegen/cuda/ir_internal_nodes.h

Lines changed: 30 additions & 0 deletions
@@ -611,6 +611,36 @@ class TORCH_CUDA_CU_API ViewOp : public Expr {
   TensorView* const in_ = nullptr;
 };
 
+//! This operator explicitly models data movement between
+//! state spaces on the GPU. Currently the modeled state spaces include
+//! global memory, shared memory, and registers.
+//!
+//! The main usage of this op is to facilitate generation of hardware
+//! accelerated memory ops, e.g. ldmatrix, cp.async, and more to come.
+class TORCH_CUDA_CU_API LoadStoreOp : public Expr {
+ public:
+  LoadStoreOp(IrBuilderPasskey, LoadStoreOpType op_type, Val* out, Val* in);
+
+  LoadStoreOp(const LoadStoreOp* src, IrCloner* ir_cloner);
+
+  Val* out() const {
+    return out_;
+  }
+
+  Val* in() const {
+    return in_;
+  }
+
+  LoadStoreOpType opType() const {
+    return load_store_type_;
+  }
+
+ private:
+  LoadStoreOpType load_store_type_ = LoadStoreOpType::LdMatrix;
+  Val* const out_ = nullptr;
+  Val* const in_ = nullptr;
+};
+
 // Friends for direct access to split
 class TensorDomain;
 class ReplayTransformations;

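A hedged sketch of how such a node might be created during lowering, using the IR_BUILDER_INSTANTIATE(LoadStoreOp) registration added above; reg_out and smem_in are hypothetical placeholder values:

    // Hypothetical lowering snippet: an ldmatrix move from shared memory
    // into registers, modeled explicitly as a LoadStoreOp.
    auto ldst = IrBuilder::create<LoadStoreOp>(
        LoadStoreOpType::LdMatrix, reg_out, smem_in);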