
[Bug]: Try-catch conditions are incorrect to import correct ROCm Flash Attention Backend in Draft Model #9100

@tjtanaa

Description


Your current environment

The output of `python collect_env.py`

Model Input Dumps

No response

🐛 Describe the bug

I found an issue when running draft-model speculative decoding on the AMD platform. The issue arises in vllm/spec_decode/draft_model_runner.py:

try:
    from vllm.attention.backends.flash_attn import FlashAttentionMetadata # this is throwing ImportError rather than ModuleNotFoundError
except ModuleNotFoundError:
    # vllm_flash_attn is not installed, use the identical ROCm FA metadata
    from vllm.attention.backends.rocm_flash_attn import (
        ROCmFlashAttentionMetadata as FlashAttentionMetadata)

Within the try-except block an ImportError is raised rather than a ModuleNotFoundError, so the except clause never matches and the ROCm fallback import is skipped:

  File "/home/aac/apps/rocm611-0929/vllm-fix-spec-amd/vllm/engine/multiprocessing/engine.py", line 78, in __init__                          
    self.engine = LLMEngine(*args,                                                                                                          
  File "/home/aac/apps/rocm611-0929/vllm-fix-spec-amd/vllm/engine/llm_engine.py", line 335, in __init__                                     
    self.model_executor = executor_class(                                                                                                   
  File "/home/aac/apps/rocm611-0929/vllm-fix-spec-amd/vllm/executor/distributed_gpu_executor.py", line 26, in __init__                      
    super().__init__(*args, **kwargs)                                                                                                       
  File "/home/aac/apps/rocm611-0929/vllm-fix-spec-amd/vllm/executor/executor_base.py", line 47, in __init__                                 
    self._init_executor()                                                                                                                   
  File "/home/aac/apps/rocm611-0929/vllm-fix-spec-amd/vllm/executor/multiproc_gpu_executor.py", line 108, in _init_executor                 
    self.driver_worker = self._create_worker(                                                                                               
  File "/home/aac/apps/rocm611-0929/vllm-fix-spec-amd/vllm/executor/gpu_executor.py", line 105, in _create_worker                           
    return create_worker(**self._get_create_worker_kwargs(                                                                                  
  File "/home/aac/apps/rocm611-0929/vllm-fix-spec-amd/vllm/executor/gpu_executor.py", line 24, in create_worker                             
    wrapper.init_worker(**kwargs)                                                                                                           
  File "/home/aac/apps/rocm611-0929/vllm-fix-spec-amd/vllm/worker/worker_base.py", line 446, in init_worker                                 
    mod = importlib.import_module(self.worker_module_name)                                                                                  
  File "/home/aac/anaconda3/envs/rocm611-0929/lib/python3.9/importlib/__init__.py", line 127, in import_module                              
    return _bootstrap._gcd_import(name[level:], package, level)                                                                             
  File "<frozen importlib._bootstrap>", line 1030, in _gcd_import                                                                           
  File "<frozen importlib._bootstrap>", line 1007, in _find_and_load                                                                        
  File "<frozen importlib._bootstrap>", line 986, in _find_and_load_unlocked                                                                
  File "<frozen importlib._bootstrap>", line 680, in _load_unlocked                                                                         
  File "<frozen importlib._bootstrap_external>", line 850, in exec_module                                                                   
  File "<frozen importlib._bootstrap>", line 228, in _call_with_frames_removed                                                              
  File "/home/aac/apps/rocm611-0929/vllm-fix-spec-amd/vllm/spec_decode/spec_decode_worker.py", line 21, in <module>                         
    from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
  File "/home/aac/apps/rocm611-0929/vllm-fix-spec-amd/vllm/spec_decode/draft_model_runner.py", line 9, in <module>
    from vllm.attention.backends.flash_attn import FlashAttentionMetadata
  File "/home/aac/apps/rocm611-0929/vllm-fix-spec-amd/vllm/attention/backends/flash_attn.py", line 23, in <module>
    from vllm.vllm_flash_attn import (flash_attn_varlen_func,                                                                               
ImportError: cannot import name 'flash_attn_varlen_func' from 'vllm.vllm_flash_attn' (unknown location)
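Since ModuleNotFoundError is a subclass of ImportError (not the other way around), broadening the except clause to `except ImportError` would catch both failure modes and take the ROCm fallback. Below is a minimal, self-contained sketch of that pattern; it uses the real `os` module with a deliberately missing name to reproduce the same "cannot import name ..." ImportError, since vllm itself may not be importable in every environment:

```python
# ModuleNotFoundError subclasses ImportError, not the other way around,
# so `except ModuleNotFoundError` does not catch a plain ImportError
# such as "cannot import name ...".
assert issubclass(ModuleNotFoundError, ImportError)
assert not issubclass(ImportError, ModuleNotFoundError)

def import_metadata():
    """Mimics the draft_model_runner import with a broadened except.

    In vllm the try body would be:
        from vllm.attention.backends.flash_attn import FlashAttentionMetadata
    Here, importing a missing name from a real module reproduces the
    same ImportError ("cannot import name ...").
    """
    try:
        from os import flash_attn_varlen_func  # noqa: F401 - raises ImportError
        return "flash_attn"
    except ImportError:  # was: except ModuleNotFoundError (too narrow)
        # Fall back, as the ROCm branch should:
        # from vllm.attention.backends.rocm_flash_attn import (
        #     ROCmFlashAttentionMetadata as FlashAttentionMetadata)
        return "rocm_fallback"

print(import_metadata())  # -> rocm_fallback
```

With the narrow `except ModuleNotFoundError` clause from the issue, the ImportError above would propagate exactly as shown in the traceback.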

