Skip to content

Commit ecc7a87

Browse files
shmsongcsarofeen
andauthored
Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (csarofeen#1761)
Co-authored-by: Christian Sarofeen <[email protected]>
1 parent a054b3e commit ecc7a87

File tree

4 files changed

+406
-136
lines changed

4 files changed

+406
-136
lines changed

torch/csrc/jit/codegen/cuda/mma_type.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ MmaBuilder& MmaBuilder::operand(MmaOptions::Operand a_or_b) {
4040

4141
// TODO: validate op config
4242
MmaOptions MmaBuilder::build() const {
43+
TORCH_CHECK(
44+
option_.mma_op != nullptr,
45+
"Please configure accumulator tv before using swizzle options.")
4346
return option_;
4447
}
4548

@@ -53,6 +56,15 @@ void MmaBuilder::configureMma(TensorView* mma_output) const {
5356
mma->configureOptions(option_);
5457
}
5558

59+
void MmaBuilder::accumulatorTv(TensorView* tv) {
60+
TORCH_CHECK(
61+
tv->getMemoryType() == MemoryType::Local, "Mma only outputs to register");
62+
TORCH_CHECK(tv->definition(), "Input cannot be accumulator tv");
63+
auto mma = dynamic_cast<MmaOp*>(tv->definition());
64+
TORCH_CHECK(mma, "Requires mma op output for reduction tv");
65+
option_.mma_op = mma;
66+
}
67+
5668
namespace {
5769

5870
// Utility to get ldmatrix direction a mma layout and operand

torch/csrc/jit/codegen/cuda/mma_type.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ struct MmaOptions {
9494
operand == other.operand &&
9595
accumulator_stride == other.accumulator_stride;
9696
}
97+
98+
// To be inferred by mma builder interface.
99+
MmaOp* mma_op = nullptr;
97100
};
98101

99102
//! User interface for configuring the mma and mma related
@@ -127,6 +130,10 @@ class TORCH_CUDA_CU_API MmaBuilder {
127130
//! specified mma option.
128131
LoadStoreOpType ldMatrix() const;
129132

133+
//! Store the accumulator tv register reference in mma builder
134+
//! to avoid automatic matching of which mma ops.
135+
void accumulatorTv(TensorView* tv);
136+
130137
//! Fill in mma options in scheduling time.
131138
//! Each mma op in Fusion IR must be configured once before lowering.
132139
//! Mma options are configuration parameters used in lowering to mma

0 commit comments

Comments
 (0)