Skip to content

Commit eb1fdbe

Browse files
authored
Merge branch 'main' into user/dongfengy/fix_ref
2 parents aa26b7d + 7c73c2f commit eb1fdbe

39 files changed: +1085 / -838 lines changed

cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h

Lines changed: 16 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1074,10 +1074,10 @@ void MixtureOfExpertsBenchmark<TypeTuple_>::runBenchmark(benchmark::State& state
10741074
state.SkipWithMessage("Out of range tactic");
10751075
return;
10761076
}
1077+
auto tactics1 = mMoERunner.getTactics(MoeGemmId::GEMM_1);
1078+
auto tactics2 = mMoERunner.getTactics(MoeGemmId::GEMM_2);
10771079
if (LOG_LEVEL >= INFO)
10781080
{
1079-
auto tactics1 = mMoERunner.getTactics(MoeGemmId::GEMM_1);
1080-
auto tactics2 = mMoERunner.getTactics(MoeGemmId::GEMM_2);
10811081
std::cout << "Selected tactic #1: " << tactic_idx1 << "/" << tactics1.size() << "\n"
10821082
<< tactics1[tactic_idx1].toString() << std::endl;
10831083
std::cout << "Selected tactic #2: " << tactic_idx2 << "/" << tactics2.size() << "\n"
@@ -1086,6 +1086,20 @@ void MixtureOfExpertsBenchmark<TypeTuple_>::runBenchmark(benchmark::State& state
10861086
state.counters["tactic_idx1"] = tactic_idx1;
10871087
state.counters["tactic_idx2"] = tactic_idx2;
10881088

1089+
state.counters["t1_sm_version"] = tactics1[tactic_idx1].sm_version;
1090+
state.counters["t1_tile_shape"] = tactics1[tactic_idx1].getTileConfigAsInt();
1091+
state.counters["t1_cluster_shape"] = (int) tactics1[tactic_idx1].cluster_shape;
1092+
state.counters["t1_dynamic_cluster_shape"] = (int) tactics1[tactic_idx1].dynamic_cluster_shape;
1093+
state.counters["t1_fallback_cluster_shape"] = (int) tactics1[tactic_idx1].fallback_cluster_shape;
1094+
state.counters["t1_epilogue_schedule"] = (int) tactics1[tactic_idx1].epilogue_schedule;
1095+
1096+
state.counters["t2_sm_version"] = tactics2[tactic_idx2].sm_version;
1097+
state.counters["t2_tile_shape"] = tactics2[tactic_idx2].getTileConfigAsInt();
1098+
state.counters["t2_cluster_shape"] = (int) tactics2[tactic_idx2].cluster_shape;
1099+
state.counters["t2_dynamic_cluster_shape"] = (int) tactics2[tactic_idx2].dynamic_cluster_shape;
1100+
state.counters["t2_fallback_cluster_shape"] = (int) tactics2[tactic_idx2].fallback_cluster_shape;
1101+
state.counters["t2_epilogue_schedule"] = (int) tactics2[tactic_idx2].epilogue_schedule;
1102+
10891103
createGraph(parallelism_config, gemm_to_profile);
10901104

10911105
{

cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkLauncher.cu

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -160,6 +160,10 @@ void argGenLoadFile(benchmark::internal::Benchmark* benchmark)
160160
*/
161161

162162
std::ifstream file{workloadFile};
163+
if (!file.is_open())
164+
{
165+
throw std::invalid_argument("Failed to open benchmark file: " + std::string(workloadFile));
166+
}
163167
std::stringstream buffer;
164168
buffer << file.rdbuf();
165169
auto file_contents = buffer.str();
@@ -294,7 +298,7 @@ void argGenLoadFile(benchmark::internal::Benchmark* benchmark)
294298
int gemm_to_profile = get_or("gemm_to_profile", (int) GemmToProfile::LAYER);
295299
TLLM_CHECK_WITH_INFO(world_rank < tp_size * ep_size, "Rank is out of bounds of tp*ep");
296300

297-
if (gemm_to_profile != (int) GemmToProfile::LAYER && routing_config != UNIFORM_ROUTING_CONFIG)
301+
if (gemm_to_profile != (int) GemmToProfile::LAYER)
298302
{
299303
static bool info_printed = false;
300304
if (!info_printed && LOG_LEVEL >= INFO)
@@ -304,13 +308,14 @@ void argGenLoadFile(benchmark::internal::Benchmark* benchmark)
304308
}
305309

306310
static bool printed = false;
307-
if (LOG_LEVEL >= ERROR && !printed)
311+
if (routing_config != UNIFORM_ROUTING_CONFIG && LOG_LEVEL >= ERROR && !printed)
308312
{
309313
std::cerr << "Warning: Profiling a specific GEMM will always use uniform random token distribution"
310314
<< std::endl;
311315
printed = true;
312316
}
313317
routing_config = UNIFORM_ROUTING_CONFIG;
318+
314319
if (gemm_to_profile == (int) GemmToProfile::GEMM_1)
315320
{
316321
tactic_ids2 = {-1};

0 commit comments

Comments
 (0)