From f5d3b260843635a6d395e8e14e42c386bb7ea98d Mon Sep 17 00:00:00 2001
From: tjtanaa
Date: Fri, 4 Jul 2025 16:32:35 +0000
Subject: [PATCH 1/3] enable MTP V1 ROCm

Signed-off-by: tjtanaa
---
 vllm/platforms/rocm.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 4550ef570684..ff95d57a7dce 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -319,9 +319,8 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
                     "vllm.worker.multi_step_worker.MultiStepWorker"
         elif vllm_config.speculative_config:
             if envs.VLLM_USE_V1:
-                raise NotImplementedError(
-                    "Speculative decoding is not yet supported on vLLM V1."
-                )
+                parallel_config.worker_cls = \
+                    "vllm.v1.worker.gpu_worker.Worker"
             else:
                 parallel_config.worker_cls = \
                     "vllm.spec_decode.spec_decode_worker.create_spec_worker"

From 161e90978bf89f68ef9131a5d6f6a70dac00f19f Mon Sep 17 00:00:00 2001
From: tjtanaa
Date: Sun, 13 Jul 2025 16:38:53 +0000
Subject: [PATCH 2/3] fix mamba_ssm compilation bug on ROCm

Signed-off-by: tjtanaa
---
 csrc/mamba/mamba_ssm/selective_scan_fwd.cu | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu
index 5f9209979341..5766fbab4e87 100644
--- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu
+++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu
@@ -7,7 +7,11 @@
 #include <c10/util/BFloat16.h>
 #include <c10/util/Half.h>
-#include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
+#ifdef USE_ROCM
+  #include <c10/hip/HIPException.h>  // For C10_HIP_CHECK and C10_HIP_KERNEL_LAUNCH_CHECK
+#else
+  #include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
+#endif
 
 #ifndef USE_ROCM
   #include <cub/block/block_load.cuh>
@@ -320,8 +324,13 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
     dim3 grid(params.batch, params.dim / kNRows);
     auto kernel = &selective_scan_fwd_kernel<Ktraits>;
     if (kSmemSize >= 48 * 1024) {
+#ifdef USE_ROCM
+      C10_HIP_CHECK(hipFuncSetAttribute(
+          reinterpret_cast<const void*>(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
+#else
       C10_CUDA_CHECK(cudaFuncSetAttribute(
           kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
+#endif
     }
     kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
     C10_CUDA_KERNEL_LAUNCH_CHECK();

From fd70bc99f937bf02956213e3bc05c8459c623fa8 Mon Sep 17 00:00:00 2001
From: tjtanaa
Date: Sun, 13 Jul 2025 17:45:18 +0000
Subject: [PATCH 3/3] revert changes to platforms/rocm.py

Signed-off-by: tjtanaa
---
 vllm/platforms/rocm.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 6ee878146034..04637f5c7aa6 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -327,8 +327,9 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
                     "vllm.worker.multi_step_worker.MultiStepWorker"
         elif vllm_config.speculative_config:
             if envs.VLLM_USE_V1:
-                parallel_config.worker_cls = \
-                    "vllm.v1.worker.gpu_worker.Worker"
+                raise NotImplementedError(
+                    "Speculative decoding is not yet supported on vLLM V1."
+                )
             else:
                 parallel_config.worker_cls = \
                     "vllm.spec_decode.spec_decode_worker.create_spec_worker"