Skip to content

Commit b97f555

Browse files
committed
perf: Add MOE support for dynamic cluster shapes and custom epilogue schedules
Signed-off-by: Daniel Stokes <[email protected]>
1 parent 2d2b8ba commit b97f555

File tree

7 files changed

+749
-647
lines changed

7 files changed

+749
-647
lines changed

3rdparty/cutlass

Submodule cutlass updated 168 files

cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h

Lines changed: 93 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <iostream>
2020
#include <sstream>
2121
#include <string>
22+
#include <tuple>
2223

2324
#include "cute/tensor.hpp"
2425

@@ -80,7 +81,21 @@ enum class SplitKStyle
8081
// SPLIT_K_PARALLEL // Not supported yet
8182
};
8283

83-
enum class CutlassTileConfigSM90
84+
// Packs an (m, n, k) shape triple into a single int, computed as m * 1000000 + n * 1000 + k.
// The result is used in this file as the explicit value of shape enumerators, and
// enum_to_shape_tuple() performs the reverse split with / and % on the same constants.
// No range checking is performed on the arguments.
constexpr static int shape_tuple_to_enum(int m, int n, int k)
85+
{
86+
// Base-1000 positional packing: n and k only remain distinguishable from their neighbours
// when each lies in [0, 1000).
return m * 1000000 + n * 1000 + k;
87+
}
88+
89+
// Splits a packed shape id back into its (m, n, k) components.
//
// The enumerator's integer value is interpreted as m * 1000000 + n * 1000 + k, so
//   m = id / 1000000, n = (id % 1000000) / 1000, k = id % 1000.
//
// TEnum must be an enum whose underlying type is int; anything else is rejected at compile time.
// Returns the components as std::tuple<int, int, int> in (m, n, k) order.
template <typename TEnum>
constexpr static std::tuple<int, int, int> enum_to_shape_tuple(TEnum shape_id_enum)
{
    // Check "is an enum" on its own first: std::underlying_type_t is ill-formed for non-enum types, and
    // operator&& inside a static_assert does not prevent it from being instantiated, which would replace
    // the message below with an error from inside <type_traits>.
    static_assert(std::is_enum_v<TEnum>, "TEnum must be an enum with underlying type int");
    if constexpr (std::is_enum_v<TEnum>)
    {
        // Only instantiated for real enums, so underlying_type_t is always well-formed here.
        static_assert(std::is_same_v<std::underlying_type_t<TEnum>, int>,
            "TEnum must be an enum with underlying type int");
    }

    auto const shape_id = static_cast<int>(shape_id_enum);
    return std::make_tuple(shape_id / 1000000, (shape_id % 1000000) / 1000, shape_id % 1000);
}
97+
98+
enum class CutlassTileConfigSM90 : int
8499
{
85100
// Signals that we should run heuristics to choose a config
86101
Undefined,
@@ -89,25 +104,25 @@ enum class CutlassTileConfigSM90
89104
ChooseWithHeuristic,
90105

91106
// CTA configs for M=64
92-
CtaShape64x16x128B,
93-
CtaShape64x32x128B,
94-
CtaShape64x64x128B,
95-
CtaShape64x128x128B,
96-
CtaShape64x256x128B,
107+
CtaShape64x16x128B = shape_tuple_to_enum(64, 16, 128),
108+
CtaShape64x32x128B = shape_tuple_to_enum(64, 32, 128),
109+
CtaShape64x64x128B = shape_tuple_to_enum(64, 64, 128),
110+
CtaShape64x128x128B = shape_tuple_to_enum(64, 128, 128),
111+
CtaShape64x256x128B = shape_tuple_to_enum(64, 256, 128),
97112

98113
// CTA configs for M=128
99-
CtaShape128x16x128B,
100-
CtaShape128x32x128B,
101-
CtaShape128x64x128B,
102-
CtaShape128x128x128B,
103-
CtaShape128x256x128B,
114+
CtaShape128x16x128B = shape_tuple_to_enum(128, 16, 128),
115+
CtaShape128x32x128B = shape_tuple_to_enum(128, 32, 128),
116+
CtaShape128x64x128B = shape_tuple_to_enum(128, 64, 128),
117+
CtaShape128x128x128B = shape_tuple_to_enum(128, 128, 128),
118+
CtaShape128x256x128B = shape_tuple_to_enum(128, 256, 128),
104119

105120
// CTA configs for M=256
106-
CtaShape256x128x128B,
107-
CtaShape256x256x128B,
121+
CtaShape256x128x128B = shape_tuple_to_enum(256, 128, 128),
122+
CtaShape256x256x128B = shape_tuple_to_enum(256, 256, 128),
108123
};
109124

110-
enum class CutlassTileConfigSM100
125+
enum class CutlassTileConfigSM100 : int
111126
{
112127
// Signals that we should run heuristics to choose a config
113128
Undefined,
@@ -119,41 +134,41 @@ enum class CutlassTileConfigSM100
119134
* Grouped GEMM
120135
*/
121136
// M=64
122-
CtaShape64x32x128B,
123-
CtaShape64x64x128B,
124-
CtaShape64x128x128B,
125-
CtaShape64x256x128B,
137+
CtaShape64x32x128B = shape_tuple_to_enum(64, 32, 128),
138+
CtaShape64x64x128B = shape_tuple_to_enum(64, 64, 128),
139+
CtaShape64x128x128B = shape_tuple_to_enum(64, 128, 128),
140+
CtaShape64x256x128B = shape_tuple_to_enum(64, 256, 128),
126141

127142
// M=128
128-
CtaShape128x8x256B,
129-
CtaShape128x16x128B,
130-
CtaShape128x32x128B,
131-
CtaShape128x64x128B,
132-
CtaShape128x128x128B,
133-
CtaShape128x256x128B,
134-
CtaShape128x128x256B,
135-
CtaShape128x256x256B,
143+
CtaShape128x8x256B = shape_tuple_to_enum(128, 8, 256),
144+
CtaShape128x16x128B = shape_tuple_to_enum(128, 16, 128),
145+
CtaShape128x32x128B = shape_tuple_to_enum(128, 32, 128),
146+
CtaShape128x64x128B = shape_tuple_to_enum(128, 64, 128),
147+
CtaShape128x128x128B = shape_tuple_to_enum(128, 128, 128),
148+
CtaShape128x256x128B = shape_tuple_to_enum(128, 256, 128),
149+
CtaShape128x128x256B = shape_tuple_to_enum(128, 128, 256),
150+
CtaShape128x256x256B = shape_tuple_to_enum(128, 256, 256),
136151

137152
// M=256
138-
CtaShape256x64x128B,
139-
CtaShape256x128x128B,
140-
CtaShape256x256x128B,
153+
CtaShape256x64x128B = shape_tuple_to_enum(256, 64, 128),
154+
CtaShape256x128x128B = shape_tuple_to_enum(256, 128, 128),
155+
CtaShape256x256x128B = shape_tuple_to_enum(256, 256, 128),
141156
};
142157

143-
enum class CutlassTileConfigSM120
158+
enum class CutlassTileConfigSM120 : int
144159
{
145160
// Signals that we should run heuristics to choose a config
146161
Undefined,
147162

148163
// Signals that we should run heuristics to choose a config
149164
ChooseWithHeuristic,
150165

151-
CtaShape128x128x128B,
152-
CtaShape128x128x64B,
153-
CtaShape256x128x64B,
154-
CtaShape128x256x64B,
155-
CtaShape128x128x256B,
156-
CtaShape256x128x128B,
166+
CtaShape128x128x128B = shape_tuple_to_enum(128, 128, 128),
167+
CtaShape128x128x64B = shape_tuple_to_enum(128, 128, 64),
168+
CtaShape256x128x64B = shape_tuple_to_enum(256, 128, 64),
169+
CtaShape128x256x64B = shape_tuple_to_enum(128, 256, 64),
170+
CtaShape128x128x256B = shape_tuple_to_enum(128, 128, 256),
171+
CtaShape256x128x128B = shape_tuple_to_enum(256, 128, 128),
157172
};
158173

159174
enum class MainloopScheduleType
@@ -191,23 +206,25 @@ enum class EpilogueScheduleType
191206
AUTO, // Automatically chooses an epilogue schedule compatible with the selected main loop schedule for Hopper. For
192207
// architectures older than hopper, the epilogue is always performed by the same thread block as the main
193208
// loop.
209+
NO_SMEM,
210+
TMA
194211
};
195212

196-
enum class TileShape
213+
enum class TileShape : int
197214
{
198-
TileShape_64x16x128,
199-
TileShape_64x32x128,
200-
TileShape_64x64x128,
201-
TileShape_64x128x128,
202-
TileShape_64x256x128,
203-
TileShape_64x512x128,
204-
TileShape_128x16x128,
205-
TileShape_128x32x128,
206-
TileShape_128x64x128,
207-
TileShape_128x128x128,
208-
TileShape_128x256x128,
209-
TileShape_256x128x128,
210-
TileShape_256x256x128
215+
TileShape_64x16x128 = shape_tuple_to_enum(64, 16, 128),
216+
TileShape_64x32x128 = shape_tuple_to_enum(64, 32, 128),
217+
TileShape_64x64x128 = shape_tuple_to_enum(64, 64, 128),
218+
TileShape_64x128x128 = shape_tuple_to_enum(64, 128, 128),
219+
TileShape_64x256x128 = shape_tuple_to_enum(64, 256, 128),
220+
TileShape_64x512x128 = shape_tuple_to_enum(64, 512, 128),
221+
TileShape_128x16x128 = shape_tuple_to_enum(128, 16, 128),
222+
TileShape_128x32x128 = shape_tuple_to_enum(128, 32, 128),
223+
TileShape_128x64x128 = shape_tuple_to_enum(128, 64, 128),
224+
TileShape_128x128x128 = shape_tuple_to_enum(128, 128, 128),
225+
TileShape_128x256x128 = shape_tuple_to_enum(128, 256, 128),
226+
TileShape_256x128x128 = shape_tuple_to_enum(256, 128, 128),
227+
TileShape_256x256x128 = shape_tuple_to_enum(256, 256, 128)
211228
};
212229

213230
template <TileShape Shape_MNK>
@@ -325,19 +342,20 @@ static auto get_tile_shape_name(TileShape Shape_MNK)
325342
return "Unknown shape";
326343
}
327344

328-
enum class ClusterShape
345+
enum class ClusterShape : int
329346
{
330-
ClusterShape_1x1x1,
331-
ClusterShape_2x1x1,
332-
ClusterShape_1x2x1,
333-
ClusterShape_2x2x1,
334-
ClusterShape_1x4x1,
335-
ClusterShape_4x1x1,
336-
ClusterShape_4x2x1,
337-
ClusterShape_2x4x1,
338-
ClusterShape_4x4x1,
339-
ClusterShape_1x8x1,
340-
ClusterShape_8x1x1
347+
Undefined,
348+
ClusterShape_1x1x1 = shape_tuple_to_enum(1, 1, 1),
349+
ClusterShape_2x1x1 = shape_tuple_to_enum(2, 1, 1),
350+
ClusterShape_1x2x1 = shape_tuple_to_enum(1, 2, 1),
351+
ClusterShape_2x2x1 = shape_tuple_to_enum(2, 2, 1),
352+
ClusterShape_1x4x1 = shape_tuple_to_enum(1, 4, 1),
353+
ClusterShape_4x1x1 = shape_tuple_to_enum(4, 1, 1),
354+
ClusterShape_4x2x1 = shape_tuple_to_enum(4, 2, 1),
355+
ClusterShape_2x4x1 = shape_tuple_to_enum(2, 4, 1),
356+
ClusterShape_4x4x1 = shape_tuple_to_enum(4, 4, 1),
357+
ClusterShape_1x8x1 = shape_tuple_to_enum(1, 8, 1),
358+
ClusterShape_8x1x1 = shape_tuple_to_enum(8, 1, 1)
341359
};
342360

343361
static auto get_cluster_shape_name(ClusterShape Shape_MNK)
@@ -434,6 +452,8 @@ struct CutlassGemmConfig
434452
MainloopScheduleType mainloop_schedule = MainloopScheduleType::AUTO;
435453
EpilogueScheduleType epilogue_schedule = EpilogueScheduleType::AUTO;
436454
ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1;
455+
ClusterShape dynamic_cluster_shape = ClusterShape::Undefined;
456+
ClusterShape fallback_cluster_shape = ClusterShape::Undefined;
437457
bool enableCudaKernel = false;
438458
int sm_version = 80; // Use 80 as a catch all for <90
439459
bool is_tma_warp_specialized = false;
@@ -460,12 +480,18 @@ struct CutlassGemmConfig
460480
{
461481
}
462482

483+
// If dynamic_cluster_shape is provided, dynamic CGA will be enabled and cluster_shape will be interpreted as
484+
// whether to use 1 or 2 SM mode, otherwise static cluster shape is used.
463485
CutlassGemmConfig(CutlassTileConfigSM100 tile_config_sm100, MainloopScheduleType mainloop_schedule,
464-
EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape)
486+
EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape,
487+
ClusterShape dynamic_cluster_shape = ClusterShape::Undefined,
488+
ClusterShape fallback_cluster_shape = ClusterShape::Undefined)
465489
: tile_config_sm100(tile_config_sm100)
466490
, mainloop_schedule(mainloop_schedule)
467491
, epilogue_schedule(epilogue_schedule)
468492
, cluster_shape(cluster_shape)
493+
, dynamic_cluster_shape(dynamic_cluster_shape)
494+
, fallback_cluster_shape(fallback_cluster_shape)
469495
, sm_version(100)
470496
, is_tma_warp_specialized(true)
471497
{
@@ -506,6 +532,8 @@ struct CutlassGemmConfig
506532
tactic << "\n\tstyle=TMA Warp Specialized"
507533
<< "\n\tsm: " << sm_version << "\n\ttile shape ID: " << getTileConfigAsInt()
508534
<< "\n\tcluster shape ID: " << (int) cluster_shape
535+
<< "\n\tdynamic cluster shape ID: " << (int) dynamic_cluster_shape
536+
<< "\n\tfallback cluster shape ID: " << (int) fallback_cluster_shape
509537
<< "\n\tmainloop sched: " << (int) mainloop_schedule << "\n\tepi sched: " << (int) epilogue_schedule
510538
<< "\n\tenable cuda kernel: " << (enableCudaKernel ? "true" : "false");
511539
}
@@ -539,6 +567,8 @@ inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& conf
539567
<< ", mainloop_schedule_enum: " << int(config.mainloop_schedule)
540568
<< ", epilogue_schedule_enum: " << int(config.epilogue_schedule)
541569
<< ", cluster_shape_enum: " << int(config.cluster_shape)
570+
<< ", dynamic_cluster_shape_enum: " << int(config.dynamic_cluster_shape)
571+
<< ", fallback_cluster_shape_enum: " << int(config.fallback_cluster_shape)
542572
<< ", enable_cuda_kernel: " << (config.enableCudaKernel ? "true" : "false");
543573
}
544574
else

0 commit comments

Comments
 (0)