Skip to content

Commit b86aede

Browse files
committed
Cleanup gemm_configs.h changes
Signed-off-by: djns99 <[email protected]>
1 parent a793106 commit b86aede

File tree

1 file changed

+8
-0
lines changed
  • cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions

1 file changed

+8
-0
lines changed

cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ enum class SplitKStyle
8484

8585
constexpr static int shape_tuple_to_enum(int m, int n, int k)
8686
{
87+
assert(m >= 0 && n >= 0 && k >= 0);
88+
assert(m < 1000 && n < 1000 && k < 1000);
8789
return m * 1000000 + n * 1000 + k;
8890
}
8991

@@ -93,6 +95,8 @@ constexpr static std::tuple<int, int, int> enum_to_shape_tuple(TEnum shape_id_en
9395
static_assert(std::is_enum_v<TEnum> && std::is_same_v<std::underlying_type_t<TEnum>, int>,
9496
"TEnum must be an enum with underlying type int");
9597
auto shape_id = static_cast<int>(shape_id_enum);
98+
assert(shape_id >= 0);
99+
assert(shape_id < (int) 1e9);
96100
return std::make_tuple(shape_id / 1000000, (shape_id % 1000000) / 1000, shape_id % 1000);
97101
}
98102

@@ -300,6 +304,10 @@ static std::string get_tile_shape_name(TEnum Shape_MNK)
300304
{
301305
return "undefined";
302306
}
307+
else if ((int) Shape_MNK == 1)
308+
{
309+
return "heuristic";
310+
}
303311
else
304312
{
305313
auto [m, n, k] = enum_to_shape_tuple(Shape_MNK);

0 commit comments

Comments
 (0)