From 2e60932dbd3e193f84b683cfb287becc3c10ad7b Mon Sep 17 00:00:00 2001 From: mgoin Date: Tue, 4 Mar 2025 00:57:59 +0000 Subject: [PATCH] Fix benchmark_moe.py tuning for CUDA devices Signed-off-by: mgoin --- benchmarks/kernels/benchmark_moe.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index c862dec81fcc..3a5a2134a9a4 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -2,6 +2,7 @@ import argparse import time +from contextlib import nullcontext from datetime import datetime from itertools import product from typing import Any, TypedDict @@ -382,7 +383,8 @@ def tune( hidden_size, search_space, is_fp16) - with torch.cuda.device(self.device_id): + with torch.cuda.device(self.device_id) if current_platform.is_rocm( + ) else nullcontext(): for config in tqdm(search_space): try: kernel_time = benchmark_config(config,