
[Bug]: OpenGVLab/InternVL2-Llama3-76B: view size is not compatible with input tensor's size and stride #8630

@erkintelnyx

Description


Your current environment

The output of `python collect_env.py`
PyTorch version: 2.5.0.dev20240726+rocm6.1
Is debug build: False
CUDA used to build PyTorch: N/A
ROCM used to build PyTorch: 6.1.40091-a8dbc0c19

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: 17.0.0 (https://github.com/RadeonOpenCompute/llvm-project roc-6.1.2 24193 669db884972e769450470020c06a6f132a8a065b)
CMake version: version 3.26.4
Libc version: glibc-2.31

Python version: 3.9.19 (main, May  6 2024, 19:43:03)  [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-117-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: AMD Instinct MI100 (gfx908:sramecc+:xnack-)
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: 6.1.40093
MIOpen runtime version: 3.1.0
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Byte Order:                           Little Endian
Address sizes:                        48 bits physical, 48 bits virtual
CPU(s):                               16
On-line CPU(s) list:                  0-15
Thread(s) per core:                   1
Core(s) per socket:                   16
Socket(s):                            1
NUMA node(s):                         1
Vendor ID:                            AuthenticAMD
CPU family:                           25
Model:                                1
Model name:                           AMD EPYC 7713 64-Core Processor
Stepping:                             1
CPU MHz:                              2000.000
BogoMIPS:                             4000.00
Virtualization:                       AMD-V
Hypervisor vendor:                    KVM
Virtualization type:                  full
L1d cache:                            1 MiB
L1i cache:                            1 MiB
L2 cache:                             8 MiB
L3 cache:                             16 MiB
NUMA node0 CPU(s):                    0-15
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Mitigation; safe RET
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP disabled; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm rep_good nopl cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw perfctr_core invpcid_single ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves clzero xsaveerptr wbnoinvd arat npt lbrv nrip_save tsc_scale vmcb_clean flushbyasid pausefilter pfthreshold v_vmsave_vmload vgif umip pku ospke vaes vpclmulqdq rdpid fsrm arch_capabilities

Versions of relevant libraries:
[pip3] mypy==1.7.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] optree==0.9.1
[pip3] pytorch-triton-rocm==3.0.0+21eae954ef
[pip3] pyzmq==26.2.0
[pip3] torch==2.5.0.dev20240726+rocm6.1
[pip3] torchvision==0.20.0.dev20240726+rocm6.1
[pip3] transformers==4.43.2
[pip3] triton==3.0.0
[conda] No relevant packages
ROCM Version: 6.1.40093-bd86f1708
Neuron SDK Version: N/A
vLLM Version: 0.6.1.post2@a8c1d161a7d87dbc6c7cccfce303dcbe2e4ed6be
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
Could not collect
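The report shows `Is CUDA available: True` next to `CUDA used to build PyTorch: N/A`. That happens because PyTorch's ROCm builds expose the HIP runtime through the `torch.cuda` namespace. The minimal sketch below tells the two toolkits apart by checking `torch.version.hip` and `torch.version.cuda`. The helper name `describe_backend` is hypothetical.

```python
import torch


def describe_backend() -> str:
    """Summarize which GPU toolkit this PyTorch build targets.

    On ROCm builds, torch.version.hip holds a version string and
    torch.version.cuda is None. torch.cuda.is_available() can still
    return True because the torch.cuda API is backed by HIP there.
    """
    if torch.version.hip is not None:
        return f"ROCm/HIP {torch.version.hip}"
    if torch.version.cuda is not None:
        return f"CUDA {torch.version.cuda}"
    return "CPU-only"


print(describe_backend(), "| GPU available:", torch.cuda.is_available())
```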

Model Input Dumps

err_execute_model_input_20240919-094504.pkl.zip

🐛 Describe the bug

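When I start the model with:
vllm serve OpenGVLab/InternVL2-Llama3-76B --tensor-parallel-size 8 --max-model-len 8000

the engine crashes during start-up. The failure happens in the memory-profiling run (`profile_run`) that vLLM performs before it serves any requests: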

  File "/vllm-workspace/vllm/worker/model_runner_base.py", line 116, in _wrapper
    return func(*args, **kwargs)
  File "/vllm-workspace/vllm/worker/model_runner.py", line 1590, in execute_model
    hidden_or_intermediate_states = model_executable(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1735, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1746, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vllm-workspace/vllm/model_executor/models/internvl.py", line 488, in forward
    vision_embeddings = self._process_image_input(image_input)
  File "/vllm-workspace/vllm/model_executor/models/internvl.py", line 471, in _process_image_input
    image_embeds = self.extract_feature(image_input["data"])
  File "/vllm-workspace/vllm/model_executor/models/internvl.py", line 395, in extract_feature
    vit_embeds = self.vision_model(pixel_values=pixel_values)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1735, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1746, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vllm-workspace/vllm/model_executor/models/intern_vit.py", line 356, in forward
    encoder_outputs = self.encoder(inputs_embeds=hidden_states)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1735, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1746, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vllm-workspace/vllm/model_executor/models/intern_vit.py", line 298, in forward
    hidden_states = encoder_layer(hidden_states)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1735, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1746, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vllm-workspace/vllm/model_executor/models/intern_vit.py", line 267, in forward
    hidden_states = hidden_states + self.attn(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1735, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1746, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vllm-workspace/vllm/model_executor/models/intern_vit.py", line 203, in forward
    x = x.transpose(1, 2).view(B, N, -1)
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/opt/conda/envs/py_3.9/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/opt/conda/envs/py_3.9/lib/python3.9/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/vllm-workspace/vllm/engine/multiprocessing/engine.py", line 318, in run_mp_engine
    engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
  File "/vllm-workspace/vllm/engine/multiprocessing/engine.py", line 113, in from_engine_args
    return cls(
  File "/vllm-workspace/vllm/engine/multiprocessing/engine.py", line 69, in __init__
    self.engine = LLMEngine(*args, **kwargs)
  File "/vllm-workspace/vllm/engine/llm_engine.py", line 331, in __init__
    self._initialize_kv_caches()
  File "/vllm-workspace/vllm/engine/llm_engine.py", line 465, in _initialize_kv_caches
    self.model_executor.determine_num_available_blocks())
  File "/vllm-workspace/vllm/executor/distributed_gpu_executor.py", line 39, in determine_num_available_blocks
    num_blocks = self._run_workers("determine_num_available_blocks", )
  File "/vllm-workspace/vllm/executor/multiproc_gpu_executor.py", line 185, in _run_workers
    driver_worker_output = driver_worker_method(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/vllm-workspace/vllm/worker/worker.py", line 223, in determine_num_available_blocks
    self.model_runner.profile_run()
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/vllm-workspace/vllm/worker/model_runner.py", line 1236, in profile_run
    self.execute_model(model_input, kv_caches, intermediate_tensors)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/vllm-workspace/vllm/worker/model_runner_base.py", line 144, in _wrapper
    raise type(err)(
RuntimeError: Error in model execution (input dumped to /tmp/err_execute_model_input_20240919-094504.pkl): view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
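The failing frame is `x = x.transpose(1, 2).view(B, N, -1)` in `intern_vit.py`. The CPU-only sketch below reproduces the same `RuntimeError`, assuming the tensor that reaches that line is a contiguous `(B, heads, N, head_dim)` buffer. This report does not confirm that the attention path on this ROCm setup actually produces that layout. The helper names `merge_heads` and `view_merge_ok` are hypothetical. `merge_heads` applies the `.reshape(...)` change that the error message itself suggests.

```python
import torch


def merge_heads(x: torch.Tensor) -> torch.Tensor:
    """Turn a (B, heads, N, head_dim) tensor into (B, N, heads * head_dim).

    reshape() returns a view when the strides allow it and copies otherwise,
    so it works whatever memory layout the attention step produced.
    """
    B, _, N, _ = x.shape
    return x.transpose(1, 2).reshape(B, N, -1)


def view_merge_ok(x: torch.Tensor) -> bool:
    """Return True if the original transpose(1, 2).view(...) form succeeds."""
    B, _, N, _ = x.shape
    try:
        x.transpose(1, 2).view(B, N, -1)
        return True
    except RuntimeError:
        return False


B, H, N, D = 2, 4, 5, 8  # illustrative sizes only

# Layout 1: head-major contiguous buffer, e.g. the result of a plain
# softmax(q @ k^T) @ v computed on (B, H, N, D) inputs.
head_major = torch.randn(B, H, N, D)

# Layout 2: same logical shape, but backed by a token-major (B, N, H, D) buffer.
token_major = torch.randn(B, N, H, D).transpose(1, 2)

print(view_merge_ok(head_major))      # False: same error as in the traceback
print(view_merge_ok(token_major))     # True: transposing back gives a contiguous tensor
print(merge_heads(head_major).shape)  # torch.Size([2, 5, 32])
```

The `.view` call only succeeds when the tensor happens to be backed by a token-major buffer. That is one plausible reason the same code could behave differently across attention backends or platforms. `.reshape` works for any layout, at the cost of a possible copy.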

Labels: bug (Something isn't working), rocm (Related to AMD ROCm), stale (Over 90 days of inactivity)
