[Bug]: AttributeError: 'tuple' object has no attribute 'seq_data' when loading Molmo7B on two Nvidia L4 #10042

@OsaCode

Your current environment

The output of `python collect_env.py`
PyTorch version: 2.5.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 11 (bullseye) (x86_64)
GCC version: (Debian 10.2.1-6) 10.2.1 20210110
Clang version: Could not collect
CMake version: version 3.18.4
Libc version: glibc-2.31

Python version: 3.11.10 | packaged by conda-forge | (main, Oct 16 2024, 01:27:36) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-5.10.0-33-cloud-amd64-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.5.82
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA L4
GPU 1: NVIDIA L4

Nvidia driver version: 550.90.07
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Byte Order:                           Little Endian
Address sizes:                        46 bits physical, 48 bits virtual
CPU(s):                               24
On-line CPU(s) list:                  0-23
Thread(s) per core:                   2
Core(s) per socket:                   12
Socket(s):                            1
NUMA node(s):                         1
Vendor ID:                            GenuineIntel
CPU family:                           6
Model:                                85
Model name:                           Intel(R) Xeon(R) CPU @ 2.20GHz
Stepping:                             7
CPU MHz:                              2200.154
BogoMIPS:                             4400.30
Hypervisor vendor:                    KVM
Virtualization type:                  full
L1d cache:                            384 KiB
L1i cache:                            384 KiB
L2 cache:                             12 MiB
L3 cache:                             38.5 MiB
NUMA node0 CPU(s):                    0-23
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Enhanced / Automatic IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat avx512_vnni md_clear arch_capabilities

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-ml-py==12.560.30
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] pyzmq==25.1.2
[pip3] torch==2.5.1
[pip3] torchvision==0.20.1
[pip3] transformers==4.45.2
[pip3] triton==3.1.0
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] nvidia-cublas-cu12        12.4.5.8                 pypi_0    pypi
[conda] nvidia-cuda-cupti-cu12    12.4.127                 pypi_0    pypi
[conda] nvidia-cuda-nvrtc-cu12    12.4.127                 pypi_0    pypi
[conda] nvidia-cuda-runtime-cu12  12.4.127                 pypi_0    pypi
[conda] nvidia-cudnn-cu12         9.1.0.70                 pypi_0    pypi
[conda] nvidia-cufft-cu12         11.2.1.3                 pypi_0    pypi
[conda] nvidia-curand-cu12        10.3.5.147               pypi_0    pypi
[conda] nvidia-cusolver-cu12      11.6.1.9                 pypi_0    pypi
[conda] nvidia-cusparse-cu12      12.3.1.170               pypi_0    pypi
[conda] nvidia-ml-py              12.560.30                pypi_0    pypi
[conda] nvidia-nccl-cu12          2.21.5                   pypi_0    pypi
[conda] nvidia-nvjitlink-cu12     12.4.127                 pypi_0    pypi
[conda] nvidia-nvtx-cu12          12.4.127                 pypi_0    pypi
[conda] pyzmq                     25.1.2          py311h6a678d5_0    anaconda
[conda] torch                     2.5.1                    pypi_0    pypi
[conda] torchvision               0.20.1                   pypi_0    pypi
[conda] transformers              4.45.2                   pypi_0    pypi
[conda] triton                    3.1.0                    pypi_0    pypi
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.6.3.post2.dev235+g93dee88f
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    GPU1    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      PHB     0-23    0               N/A
GPU1    PHB      X      0-23    0               N/A

Model Input Dumps

No response

🐛 Describe the bug

Ran the following:

from vllm import LLM

model_name = "allenai/Molmo-7B-D-0924"

llm = LLM(
    model=model_name,
    trust_remote_code=True,
    dtype="bfloat16",
    tensor_parallel_size=2
)

Got the following output:

INFO 11-05 15:45:50 config.py:1764] Downcasting torch.float32 to torch.bfloat16.
INFO 11-05 15:45:55 config.py:991] Defaulting to use mp for distributed inference
INFO 11-05 15:45:55 llm_engine.py:247] Initializing an LLM engine (v0.6.3.post2.dev235+g93dee88f) with config: model='allenai/Molmo-7B-D-0924', speculative_config=None, tokenizer='allenai/Molmo-7B-D-0924', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=2, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=allenai/Molmo-7B-D-0924, num_scheduler_steps=1, chunked_prefill_enabled=False multi_step_stream_outputs=True, enable_prefix_caching=False, use_async_output_proc=True, use_cached_outputs=False, chat_template_text_format=string, mm_processor_kwargs=None, pooler_config=None)
WARNING 11-05 15:45:56 multiproc_gpu_executor.py:53] Reducing Torch parallelism from 12 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
INFO 11-05 15:45:56 custom_cache_manager.py:17] Setting Triton cache manager to: vllm.triton_utils.custom_cache_manager:CustomCacheManager
INFO 11-05 15:45:56 selector.py:110] Using Flash Attention backend.
(VllmWorkerProcess pid=11861) INFO 11-05 15:45:56 selector.py:110] Using Flash Attention backend.
(VllmWorkerProcess pid=11861) INFO 11-05 15:45:56 multiproc_worker_utils.py:215] Worker ready; awaiting tasks
INFO 11-05 15:45:57 utils.py:960] Found nccl from library libnccl.so.2
INFO 11-05 15:45:57 pynccl.py:63] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=11861) INFO 11-05 15:45:57 utils.py:960] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=11861) INFO 11-05 15:45:57 pynccl.py:63] vLLM is using nccl==2.21.5
INFO 11-05 15:45:57 custom_all_reduce_utils.py:242] reading GPU P2P access cache from /home/jupyter/.cache/vllm/gpu_p2p_access_cache_for_0,1.json
(VllmWorkerProcess pid=11861) INFO 11-05 15:45:57 custom_all_reduce_utils.py:242] reading GPU P2P access cache from /home/jupyter/.cache/vllm/gpu_p2p_access_cache_for_0,1.json
INFO 11-05 15:45:57 shm_broadcast.py:236] vLLM message queue communication handle: Handle(connect_ip='127.0.0.1', local_reader_ranks=[1], buffer=<vllm.distributed.device_communicators.shm_broadcast.ShmRingBuffer object at 0x7f2cc2d34fd0>, local_subscribe_port=36563, remote_subscribe_port=None)
INFO 11-05 15:45:57 model_runner.py:1052] Starting to load model allenai/Molmo-7B-D-0924...
(VllmWorkerProcess pid=11861) INFO 11-05 15:45:57 model_runner.py:1052] Starting to load model allenai/Molmo-7B-D-0924...
WARNING 11-05 15:45:57 utils.py:622] Current `vllm-flash-attn` has a bug inside vision module, so we use xformers backend instead. You can run `pip install flash-attn` to use flash-attention backend.
(VllmWorkerProcess pid=11861) WARNING 11-05 15:45:57 utils.py:622] Current `vllm-flash-attn` has a bug inside vision module, so we use xformers backend instead. You can run `pip install flash-attn` to use flash-attention backend.
[... the two warnings above are repeated many more times by both the driver and the worker process ...]
INFO 11-05 15:45:57 selector.py:110] Using Flash Attention backend.
(VllmWorkerProcess pid=11861) INFO 11-05 15:45:57 selector.py:110] Using Flash Attention backend.
INFO 11-05 15:45:57 weight_utils.py:243] Using model weights format ['*.safetensors']
(VllmWorkerProcess pid=11861) INFO 11-05 15:45:57 weight_utils.py:243] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards: 100% Completed | 7/7 [00:14<00:00,  1.92s/it]
INFO 11-05 15:46:14 model_runner.py:1057] Loading model weights took 7.4949 GB
(VllmWorkerProcess pid=11861) INFO 11-05 15:46:14 model_runner.py:1057] Loading model weights took 7.4949 GB
(VllmWorkerProcess pid=11861) ERROR 11-05 15:46:16 multiproc_worker_utils.py:229] Exception in worker VllmWorkerProcess while processing method determine_num_available_blocks.
(VllmWorkerProcess pid=11861) ERROR 11-05 15:46:16 multiproc_worker_utils.py:229] Traceback (most recent call last):
(VllmWorkerProcess pid=11861) ERROR 11-05 15:46:16 multiproc_worker_utils.py:229]   File "/opt/conda/envs/molmo/lib/python3.11/site-packages/vllm/executor/multiproc_worker_utils.py", line 223, in _run_worker_process
(VllmWorkerProcess pid=11861) ERROR 11-05 15:46:16 multiproc_worker_utils.py:229]     output = executor(*args, **kwargs)
(VllmWorkerProcess pid=11861) ERROR 11-05 15:46:16 multiproc_worker_utils.py:229]              ^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=11861) ERROR 11-05 15:46:16 multiproc_worker_utils.py:229]   File "/opt/conda/envs/molmo/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=11861) ERROR 11-05 15:46:16 multiproc_worker_utils.py:229]     return func(*args, **kwargs)
(VllmWorkerProcess pid=11861) ERROR 11-05 15:46:16 multiproc_worker_utils.py:229]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=11861) ERROR 11-05 15:46:16 multiproc_worker_utils.py:229]   File "/opt/conda/envs/molmo/lib/python3.11/site-packages/vllm/worker/worker.py", line 198, in determine_num_available_blocks
(VllmWorkerProcess pid=11861) ERROR 11-05 15:46:16 multiproc_worker_utils.py:229]     self.model_runner.profile_run()
(VllmWorkerProcess pid=11861) ERROR 11-05 15:46:16 multiproc_worker_utils.py:229]   File "/opt/conda/envs/molmo/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=11861) ERROR 11-05 15:46:16 multiproc_worker_utils.py:229]     return func(*args, **kwargs)
(VllmWorkerProcess pid=11861) ERROR 11-05 15:46:16 multiproc_worker_utils.py:229]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=11861) ERROR 11-05 15:46:16 multiproc_worker_utils.py:229]   File "/opt/conda/envs/molmo/lib/python3.11/site-packages/vllm/worker/model_runner.py", line 1249, in profile_run
(VllmWorkerProcess pid=11861) ERROR 11-05 15:46:16 multiproc_worker_utils.py:229]     .dummy_data_for_profiling(self.model_config,
(VllmWorkerProcess pid=11861) ERROR 11-05 15:46:16 multiproc_worker_utils.py:229]      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=11861) ERROR 11-05 15:46:16 multiproc_worker_utils.py:229]   File "/opt/conda/envs/molmo/lib/python3.11/site-packages/vllm/inputs/registry.py", line 234, in dummy_data_for_profiling
(VllmWorkerProcess pid=11861) ERROR 11-05 15:46:16 multiproc_worker_utils.py:229]     num_tokens = dummy_data.seq_data.prompt_token_ids
(VllmWorkerProcess pid=11861) ERROR 11-05 15:46:16 multiproc_worker_utils.py:229]                  ^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=11861) ERROR 11-05 15:46:16 multiproc_worker_utils.py:229] AttributeError: 'tuple' object has no attribute 'seq_data'
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[2], line 3
      1 model_name = "allenai/Molmo-7B-D-0924"
----> 3 llm = LLM(
      4     model=model_name,
      5     trust_remote_code=True,
      6     dtype="bfloat16",
      7     tensor_parallel_size=2
      8 )

File /opt/conda/envs/molmo/lib/python3.11/site-packages/vllm/utils.py:1025, in deprecate_args.<locals>.wrapper.<locals>.inner(*args, **kwargs)
   1018             msg += f" {additional_message}"
   1020         warnings.warn(
   1021             DeprecationWarning(msg),
   1022             stacklevel=3,  # The inner function takes up one level
   1023         )
-> 1025 return fn(*args, **kwargs)

File /opt/conda/envs/molmo/lib/python3.11/site-packages/vllm/entrypoints/llm.py:209, in LLM.__init__(self, model, tokenizer, tokenizer_mode, skip_tokenizer_init, trust_remote_code, allowed_local_media_path, tensor_parallel_size, dtype, quantization, revision, tokenizer_revision, seed, gpu_memory_utilization, swap_space, cpu_offload_gb, enforce_eager, max_seq_len_to_capture, disable_custom_all_reduce, disable_async_output_proc, mm_processor_kwargs, task, pooling_type, pooling_norm, pooling_softmax, pooling_step_tag_id, pooling_returned_token_ids, **kwargs)
    178     kwargs["disable_log_stats"] = True
    180 engine_args = EngineArgs(
    181     model=model,
    182     task=task,
   (...)
    207     **kwargs,
    208 )
--> 209 self.llm_engine = LLMEngine.from_engine_args(
    210     engine_args, usage_context=UsageContext.LLM_CLASS)
    211 self.request_counter = Counter()

File /opt/conda/envs/molmo/lib/python3.11/site-packages/vllm/engine/llm_engine.py:577, in LLMEngine.from_engine_args(cls, engine_args, usage_context, stat_loggers)
    575 executor_class = cls._get_executor_cls(engine_config)
    576 # Create the LLM engine.
--> 577 engine = cls(
    578     vllm_config=engine_config,
    579     executor_class=executor_class,
    580     log_stats=not engine_args.disable_log_stats,
    581     usage_context=usage_context,
    582     stat_loggers=stat_loggers,
    583 )
    585 return engine

File /opt/conda/envs/molmo/lib/python3.11/site-packages/vllm/engine/llm_engine.py:350, in LLMEngine.__init__(self, vllm_config, executor_class, log_stats, usage_context, stat_loggers, input_registry, use_cached_outputs)
    347 self.model_executor = executor_class(vllm_config=vllm_config, )
    349 if self.model_config.task != "embedding":
--> 350     self._initialize_kv_caches()
    352 # If usage stat is enabled, collect relevant info.
    353 if is_usage_stats_enabled():

File /opt/conda/envs/molmo/lib/python3.11/site-packages/vllm/engine/llm_engine.py:487, in LLMEngine._initialize_kv_caches(self)
    480 def _initialize_kv_caches(self) -> None:
    481     """Initialize the KV cache in the worker(s).
    482 
    483     The workers will determine the number of blocks in both the GPU cache
    484     and the swap CPU cache.
    485     """
    486     num_gpu_blocks, num_cpu_blocks = (
--> 487         self.model_executor.determine_num_available_blocks())
    489     if self.cache_config.num_gpu_blocks_override is not None:
    490         num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override

File /opt/conda/envs/molmo/lib/python3.11/site-packages/vllm/executor/distributed_gpu_executor.py:39, in DistributedGPUExecutor.determine_num_available_blocks(self)
     29 """Determine the number of available KV blocks.
     30 
     31 This invokes `determine_num_available_blocks` on each worker and takes
   (...)
     36     - tuple[num_gpu_blocks, num_cpu_blocks]
     37 """
     38 # Get the maximum number of blocks that can be allocated on GPU and CPU.
---> 39 num_blocks = self._run_workers("determine_num_available_blocks", )
     41 # Since we use a shared centralized controller, we take the minimum
     42 # number of blocks across all workers to make sure all the memory
     43 # operators can be applied to all workers.
     44 num_gpu_blocks = min(b[0] for b in num_blocks)

File /opt/conda/envs/molmo/lib/python3.11/site-packages/vllm/executor/multiproc_gpu_executor.py:192, in MultiprocessingGPUExecutor._run_workers(self, method, async_run_tensor_parallel_workers_only, max_concurrent_workers, *args, **kwargs)
    186 worker_outputs = [
    187     worker.execute_method(method, *args, **kwargs)
    188     for worker in self.workers
    189 ]
    191 driver_worker_method = getattr(self.driver_worker, method)
--> 192 driver_worker_output = driver_worker_method(*args, **kwargs)
    194 # Get the results of the workers.
    195 return [driver_worker_output
    196         ] + [output.get() for output in worker_outputs]

File /opt/conda/envs/molmo/lib/python3.11/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File /opt/conda/envs/molmo/lib/python3.11/site-packages/vllm/worker/worker.py:198, in Worker.determine_num_available_blocks(self)
    194 free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()
    196 # Execute a forward pass with dummy inputs to profile the memory usage
    197 # of the model.
--> 198 self.model_runner.profile_run()
    199 torch.cuda.synchronize()
    201 self._assert_memory_footprint_increased_during_profiling()

File /opt/conda/envs/molmo/lib/python3.11/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File /opt/conda/envs/molmo/lib/python3.11/site-packages/vllm/worker/model_runner.py:1249, in GPUModelRunnerBase.profile_run(self)
   1244 seq_len = (max_num_batched_tokens // max_num_seqs +
   1245            (group_id < max_num_batched_tokens % max_num_seqs))
   1246 batch_size += seq_len
   1248 dummy_data = self.input_registry \
-> 1249     .dummy_data_for_profiling(self.model_config,
   1250                               seq_len,
   1251                               self.mm_registry)
   1253 seq = SequenceGroupMetadata(
   1254     request_id=str(group_id),
   1255     is_prompt=True,
   (...)
   1262     multi_modal_placeholders=dummy_data.multi_modal_placeholders,
   1263 )
   1264 seqs.append(seq)

File /opt/conda/envs/molmo/lib/python3.11/site-packages/vllm/inputs/registry.py:234, in InputRegistry.dummy_data_for_profiling(self, model_config, seq_len, mm_registry, is_encoder_data)
    229 dummy_data = dummy_factory(InputContext(model_config), seq_len,
    230                            _MultiModalCounts(mm_counts),
    231                            **mm_processor_kwargs)
    233 # Having more tokens is over-conservative but otherwise fine
--> 234 num_tokens = dummy_data.seq_data.prompt_token_ids
    235 if len(num_tokens) < seq_len:
    236     if is_encoder_data:

AttributeError: 'tuple' object has no attribute 'seq_data'
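
The last frame shows `InputRegistry.dummy_data_for_profiling` accessing `dummy_data.seq_data`, so the dummy-data factory registered for Molmo appears to still return a plain `(seq_data, multi_modal_data)` tuple rather than the structure the registry now expects. Below is a minimal sketch of that mismatch; the `DummyData` field names are assumptions inferred from this traceback (`seq_data` in registry.py, `multi_modal_placeholders` in model_runner.py), not copied from the actual vLLM source.

from typing import NamedTuple, Optional

# Assumed shape of what the registry expects; field names inferred from the
# traceback, not from vLLM itself.
class DummyData(NamedTuple):
    seq_data: object
    multi_modal_data: Optional[dict] = None
    multi_modal_placeholders: Optional[dict] = None

# What the Molmo dummy factory apparently returns today: a bare tuple.
legacy = ("<SequenceData>", {"image": "<dummy image>"})

# registry.py line 234 then evaluates `dummy_data.seq_data.prompt_token_ids`,
# which fails on a tuple with exactly the error above:
#   AttributeError: 'tuple' object has no attribute 'seq_data'

# Wrapping the tuple in the expected structure is what makes the attribute
# access succeed:
wrapped = DummyData(seq_data=legacy[0], multi_modal_data=legacy[1])
print(wrapped.seq_data)

If that reading is right, the fix belongs on the vLLM side (updating the Molmo dummy-data path to return the new structure), not in the user code above.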

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
