File tree 4 files changed +406
-136
lines changed
torch/csrc/jit/codegen/cuda 4 files changed +406
-136
lines changed Original file line number Diff line number Diff line change @@ -40,6 +40,9 @@ MmaBuilder& MmaBuilder::operand(MmaOptions::Operand a_or_b) {
40
40
41
41
// TODO: validate op config
42
42
MmaOptions MmaBuilder::build () const {
43
+ TORCH_CHECK (
44
+ option_.mma_op != nullptr ,
45
+ " Please configure accumulator tv before using swizzle options." )
43
46
return option_;
44
47
}
45
48
@@ -53,6 +56,15 @@ void MmaBuilder::configureMma(TensorView* mma_output) const {
53
56
mma->configureOptions (option_);
54
57
}
55
58
59
+ void MmaBuilder::accumulatorTv (TensorView* tv) {
60
+ TORCH_CHECK (
61
+ tv->getMemoryType () == MemoryType::Local, " Mma only outputs to register" );
62
+ TORCH_CHECK (tv->definition (), " Input cannot be accumulator tv" );
63
+ auto mma = dynamic_cast <MmaOp*>(tv->definition ());
64
+ TORCH_CHECK (mma, " Requires mma op output for reduction tv" );
65
+ option_.mma_op = mma;
66
+ }
67
+
56
68
namespace {
57
69
58
70
// Utility to get the ldmatrix direction for a given mma layout and operand
Original file line number Diff line number Diff line change @@ -94,6 +94,9 @@ struct MmaOptions {
94
94
operand == other.operand &&
95
95
accumulator_stride == other.accumulator_stride ;
96
96
}
97
+
98
+ // To be inferred by mma builder interface.
99
+ MmaOp* mma_op = nullptr ;
97
100
};
98
101
99
102
// ! User interface for configuring the mma and mma related
@@ -127,6 +130,10 @@ class TORCH_CUDA_CU_API MmaBuilder {
127
130
// ! specified mma option.
128
131
LoadStoreOpType ldMatrix () const ;
129
132
133
+ // ! Store a reference to the accumulator tv (a register tv) in the mma builder
134
+ // ! so that the mma op does not have to be matched automatically.
135
+ void accumulatorTv (TensorView* tv);
136
+
130
137
// ! Fill in mma options in scheduling time.
131
138
// ! Each mma op in Fusion IR must be configured once before lowering.
132
139
// ! Mma options are configuration parameters used in lowering to mma
You can’t perform that action at this time.
0 commit comments