#include <iostream>
#include <sstream>
#include <string>
+#include <tuple>

#include "cute/tensor.hpp"
@@ -80,7 +81,21 @@ enum class SplitKStyle
    // SPLIT_K_PARALLEL // Not supported yet
};

-enum class CutlassTileConfigSM90
+constexpr static int shape_tuple_to_enum(int m, int n, int k)
+{
+    return m * 1000000 + n * 1000 + k;
+}
+
+template <typename TEnum>
+constexpr static std::tuple<int, int, int> enum_to_shape_tuple(TEnum shape_id_enum)
+{
+    static_assert(std::is_enum_v<TEnum> && std::is_same_v<std::underlying_type_t<TEnum>, int>,
+        "TEnum must be an enum with underlying type int");
+    auto shape_id = static_cast<int>(shape_id_enum);
+    return std::make_tuple(shape_id / 1000000, (shape_id % 1000000) / 1000, shape_id % 1000);
+}
+
+enum class CutlassTileConfigSM90 : int
{
    // Signals that we should run heuristics do choose a config
    Undefined,
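The two helpers above define the encoding used throughout the rest of this diff: an (M, N, K) shape is packed into a single `int` so that each enumerator's value is self-describing and can be decoded without a lookup table. A minimal standalone sketch of the round trip (`pack_shape`/`unpack_shape` are illustrative stand-ins for `shape_tuple_to_enum`/`enum_to_shape_tuple`; the scheme assumes N and K stay below 1000, which holds for every shape listed here):

```cpp
#include <tuple>

// M occupies the millions place, N the thousands place, and K the ones place of one int.
constexpr int pack_shape(int m, int n, int k)
{
    return m * 1000000 + n * 1000 + k;
}

constexpr std::tuple<int, int, int> unpack_shape(int id)
{
    return std::make_tuple(id / 1000000, (id % 1000000) / 1000, id % 1000);
}

// Round trip: an enumerator such as CtaShape64x256x128B would carry the value 64'256'128.
static_assert(pack_shape(64, 256, 128) == 64256128, "packed value reads as M|N|K");
static_assert(unpack_shape(64256128) == std::make_tuple(64, 256, 128), "decoding recovers (M, N, K)");
```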
@@ -89,25 +104,25 @@ enum class CutlassTileConfigSM90
    ChooseWithHeuristic,

    // CTA configs for M=64
-    CtaShape64x16x128B,
-    CtaShape64x32x128B,
-    CtaShape64x64x128B,
-    CtaShape64x128x128B,
-    CtaShape64x256x128B,
+    CtaShape64x16x128B = shape_tuple_to_enum(64, 16, 128),
+    CtaShape64x32x128B = shape_tuple_to_enum(64, 32, 128),
+    CtaShape64x64x128B = shape_tuple_to_enum(64, 64, 128),
+    CtaShape64x128x128B = shape_tuple_to_enum(64, 128, 128),
+    CtaShape64x256x128B = shape_tuple_to_enum(64, 256, 128),

    // CTA configs for M=128
-    CtaShape128x16x128B,
-    CtaShape128x32x128B,
-    CtaShape128x64x128B,
-    CtaShape128x128x128B,
-    CtaShape128x256x128B,
+    CtaShape128x16x128B = shape_tuple_to_enum(128, 16, 128),
+    CtaShape128x32x128B = shape_tuple_to_enum(128, 32, 128),
+    CtaShape128x64x128B = shape_tuple_to_enum(128, 64, 128),
+    CtaShape128x128x128B = shape_tuple_to_enum(128, 128, 128),
+    CtaShape128x256x128B = shape_tuple_to_enum(128, 256, 128),

    // CTA configs for M=256
-    CtaShape256x128x128B,
-    CtaShape256x256x128B,
+    CtaShape256x128x128B = shape_tuple_to_enum(256, 128, 128),
+    CtaShape256x256x128B = shape_tuple_to_enum(256, 256, 128),
};

-enum class CutlassTileConfigSM100
+enum class CutlassTileConfigSM100 : int
{
    // Signals that we should run heuristics do choose a config
    Undefined,
@@ -119,41 +134,41 @@ enum class CutlassTileConfigSM100
     * Grouped GEMM
     */
    // M=64
-    CtaShape64x32x128B,
-    CtaShape64x64x128B,
-    CtaShape64x128x128B,
-    CtaShape64x256x128B,
+    CtaShape64x32x128B = shape_tuple_to_enum(64, 32, 128),
+    CtaShape64x64x128B = shape_tuple_to_enum(64, 64, 128),
+    CtaShape64x128x128B = shape_tuple_to_enum(64, 128, 128),
+    CtaShape64x256x128B = shape_tuple_to_enum(64, 256, 128),

    // M=128
-    CtaShape128x8x256B,
-    CtaShape128x16x128B,
-    CtaShape128x32x128B,
-    CtaShape128x64x128B,
-    CtaShape128x128x128B,
-    CtaShape128x256x128B,
-    CtaShape128x128x256B,
-    CtaShape128x256x256B,
+    CtaShape128x8x256B = shape_tuple_to_enum(128, 8, 256),
+    CtaShape128x16x128B = shape_tuple_to_enum(128, 16, 128),
+    CtaShape128x32x128B = shape_tuple_to_enum(128, 32, 128),
+    CtaShape128x64x128B = shape_tuple_to_enum(128, 64, 128),
+    CtaShape128x128x128B = shape_tuple_to_enum(128, 128, 128),
+    CtaShape128x256x128B = shape_tuple_to_enum(128, 256, 128),
+    CtaShape128x128x256B = shape_tuple_to_enum(128, 128, 256),
+    CtaShape128x256x256B = shape_tuple_to_enum(128, 256, 256),

    // M=256
-    CtaShape256x64x128B,
-    CtaShape256x128x128B,
-    CtaShape256x256x128B,
+    CtaShape256x64x128B = shape_tuple_to_enum(256, 64, 128),
+    CtaShape256x128x128B = shape_tuple_to_enum(256, 128, 128),
+    CtaShape256x256x128B = shape_tuple_to_enum(256, 256, 128),
};

-enum class CutlassTileConfigSM120
+enum class CutlassTileConfigSM120 : int
{
    // Signals that we should run heuristics do choose a config
    Undefined,

    // Signals that we should run heuristics do choose a config
    ChooseWithHeuristic,

-    CtaShape128x128x128B,
-    CtaShape128x128x64B,
-    CtaShape256x128x64B,
-    CtaShape128x256x64B,
-    CtaShape128x128x256B,
-    CtaShape256x128x128B,
+    CtaShape128x128x128B = shape_tuple_to_enum(128, 128, 128),
+    CtaShape128x128x64B = shape_tuple_to_enum(128, 128, 64),
+    CtaShape256x128x64B = shape_tuple_to_enum(256, 128, 64),
+    CtaShape128x256x64B = shape_tuple_to_enum(128, 256, 64),
+    CtaShape128x128x256B = shape_tuple_to_enum(128, 128, 256),
+    CtaShape256x128x128B = shape_tuple_to_enum(256, 128, 128),
};

enum class MainloopScheduleType
@@ -191,23 +206,25 @@ enum class EpilogueScheduleType
    AUTO, // Automatically chooses an epilogue schedule compatible with the selected main loop schedule for Hopper. For
          // architectures older than hopper, the epilogue is always performed by the same thread block as the main
          // loop.
+    NO_SMEM,
+    TMA
};

-enum class TileShape
+enum class TileShape : int
{
-    TileShape_64x16x128,
-    TileShape_64x32x128,
-    TileShape_64x64x128,
-    TileShape_64x128x128,
-    TileShape_64x256x128,
-    TileShape_64x512x128,
-    TileShape_128x16x128,
-    TileShape_128x32x128,
-    TileShape_128x64x128,
-    TileShape_128x128x128,
-    TileShape_128x256x128,
-    TileShape_256x128x128,
-    TileShape_256x256x128
+    TileShape_64x16x128 = shape_tuple_to_enum(64, 16, 128),
+    TileShape_64x32x128 = shape_tuple_to_enum(64, 32, 128),
+    TileShape_64x64x128 = shape_tuple_to_enum(64, 64, 128),
+    TileShape_64x128x128 = shape_tuple_to_enum(64, 128, 128),
+    TileShape_64x256x128 = shape_tuple_to_enum(64, 256, 128),
+    TileShape_64x512x128 = shape_tuple_to_enum(64, 512, 128),
+    TileShape_128x16x128 = shape_tuple_to_enum(128, 16, 128),
+    TileShape_128x32x128 = shape_tuple_to_enum(128, 32, 128),
+    TileShape_128x64x128 = shape_tuple_to_enum(128, 64, 128),
+    TileShape_128x128x128 = shape_tuple_to_enum(128, 128, 128),
+    TileShape_128x256x128 = shape_tuple_to_enum(128, 256, 128),
+    TileShape_256x128x128 = shape_tuple_to_enum(256, 128, 128),
+    TileShape_256x256x128 = shape_tuple_to_enum(256, 256, 128)
};

template <TileShape Shape_MNK>
@@ -325,19 +342,20 @@ static auto get_tile_shape_name(TileShape Shape_MNK)
    return "Unknown shape";
}

-enum class ClusterShape
+enum class ClusterShape : int
{
-    ClusterShape_1x1x1,
-    ClusterShape_2x1x1,
-    ClusterShape_1x2x1,
-    ClusterShape_2x2x1,
-    ClusterShape_1x4x1,
-    ClusterShape_4x1x1,
-    ClusterShape_4x2x1,
-    ClusterShape_2x4x1,
-    ClusterShape_4x4x1,
-    ClusterShape_1x8x1,
-    ClusterShape_8x1x1
+    Undefined,
+    ClusterShape_1x1x1 = shape_tuple_to_enum(1, 1, 1),
+    ClusterShape_2x1x1 = shape_tuple_to_enum(2, 1, 1),
+    ClusterShape_1x2x1 = shape_tuple_to_enum(1, 2, 1),
+    ClusterShape_2x2x1 = shape_tuple_to_enum(2, 2, 1),
+    ClusterShape_1x4x1 = shape_tuple_to_enum(1, 4, 1),
+    ClusterShape_4x1x1 = shape_tuple_to_enum(4, 1, 1),
+    ClusterShape_4x2x1 = shape_tuple_to_enum(4, 2, 1),
+    ClusterShape_2x4x1 = shape_tuple_to_enum(2, 4, 1),
+    ClusterShape_4x4x1 = shape_tuple_to_enum(4, 4, 1),
+    ClusterShape_1x8x1 = shape_tuple_to_enum(1, 8, 1),
+    ClusterShape_8x1x1 = shape_tuple_to_enum(8, 1, 1)
};

static auto get_cluster_shape_name(ClusterShape Shape_MNK)
@@ -434,6 +452,8 @@ struct CutlassGemmConfig
    MainloopScheduleType mainloop_schedule = MainloopScheduleType::AUTO;
    EpilogueScheduleType epilogue_schedule = EpilogueScheduleType::AUTO;
    ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1;
+    ClusterShape dynamic_cluster_shape = ClusterShape::Undefined;
+    ClusterShape fallback_cluster_shape = ClusterShape::Undefined;
    bool enableCudaKernel = false;
    int sm_version = 80; // Use 80 as a catch all for <90
    bool is_tma_warp_specialized = false;
@@ -460,12 +480,18 @@ struct CutlassGemmConfig
    {
    }

+    // If dynamic_cluster_shape is provided, dynamic CGA will be enabled and cluster_shape will be interpreted as
+    // whether to use 1 or 2 SM mode, otherwise static cluster shape is used.
    CutlassGemmConfig(CutlassTileConfigSM100 tile_config_sm100, MainloopScheduleType mainloop_schedule,
-        EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape)
+        EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape,
+        ClusterShape dynamic_cluster_shape = ClusterShape::Undefined,
+        ClusterShape fallback_cluster_shape = ClusterShape::Undefined)
        : tile_config_sm100(tile_config_sm100)
        , mainloop_schedule(mainloop_schedule)
        , epilogue_schedule(epilogue_schedule)
        , cluster_shape(cluster_shape)
+        , dynamic_cluster_shape(dynamic_cluster_shape)
+        , fallback_cluster_shape(fallback_cluster_shape)
        , sm_version(100)
        , is_tma_warp_specialized(true)
    {
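For illustration only, a hypothetical call site (not part of this change) showing how the extended SM100 constructor is meant to be used, based on the comment above: omit the two defaulted arguments to keep the existing static-cluster behavior, or pass a dynamic and a fallback cluster shape to enable dynamic CGA. The specific shape values chosen here are arbitrary, and it assumes the header from this diff is included and its types are visible.

```cpp
void build_example_configs() // hypothetical helper, not part of the change
{
    // Static cluster shape, as before: the two new defaulted arguments stay Undefined.
    CutlassGemmConfig static_cfg(CutlassTileConfigSM100::CtaShape128x128x128B,
        MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1);

    // Dynamic CGA: per the comment above, cluster_shape now selects 1-SM vs 2-SM mode, while the
    // preferred and fallback cluster shapes travel in the two new fields.
    CutlassGemmConfig dynamic_cfg(CutlassTileConfigSM100::CtaShape128x128x128B,
        MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
        /*cluster_shape=*/ClusterShape::ClusterShape_1x1x1,
        /*dynamic_cluster_shape=*/ClusterShape::ClusterShape_2x2x1,
        /*fallback_cluster_shape=*/ClusterShape::ClusterShape_1x1x1);

    (void) static_cfg;
    (void) dynamic_cfg;
}
```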
@@ -506,6 +532,8 @@ struct CutlassGemmConfig
        tactic << "\n\tstyle=TMA Warp Specialized"
               << "\n\tsm: " << sm_version << "\n\ttile shape ID: " << getTileConfigAsInt()
               << "\n\tcluster shape ID: " << (int) cluster_shape
+              << "\n\tdynamic cluster shape ID: " << (int) dynamic_cluster_shape
+              << "\n\tfallback cluster shape ID: " << (int) fallback_cluster_shape
               << "\n\tmainloop sched: " << (int) mainloop_schedule << "\n\tepi sched: " << (int) epilogue_schedule
               << "\n\tenable cuda kernel: " << (enableCudaKernel ? "true" : "false");
    }
@@ -539,6 +567,8 @@ inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& config)
        << ", mainloop_schedule_enum: " << int(config.mainloop_schedule)
        << ", epilogue_schedule_enum: " << int(config.epilogue_schedule)
        << ", cluster_shape_enum: " << int(config.cluster_shape)
+       << ", dynamic_cluster_shape_enum: " << int(config.dynamic_cluster_shape)
+       << ", fallback_cluster_shape_enum: " << int(config.fallback_cluster_shape)
        << ", enable_cuda_kernel: " << (config.enableCudaKernel ? "true" : "false");
    }
    else