This repository was archived by the owner on May 11, 2025. It is now read-only.

Conversation

seungwoos
Contributor

@seungwoos seungwoos commented Feb 6, 2025

  1. Add Qwen2.5-VL model with updated util functions.

  2. Add position_embeddings to module_kwargs, since the latest Hugging Face transformers release requires pre-computed positional embeddings as a forward-pass argument (see the difference between transformers<4.48.0 and transformers>=4.48.0).
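
For context, the second point corresponds roughly to the sketch below. This is only a conceptual illustration, not the actual diff of this PR: the attribute path to the rotary embedding and the way position ids are built are assumptions, and Qwen2.5-VL's text model uses multimodal (3-axis) rotary position ids.

import torch

# Sketch: precompute the (cos, sin) pair once and pass it to every decoder layer
# via module_kwargs, because transformers>=4.48 decoder layers unpack
# position_embeddings in forward() instead of computing rotary embeddings themselves.
def prepare_module_kwargs(model, inps, module_kwargs):
    # inps: cached calibration activations, shape (batch, seq_len, hidden_size)
    batch, seq_len = inps.shape[0], inps.shape[1]
    # Qwen2.5-VL's mrope expects position ids shaped (3, batch, seq_len)
    position_ids = torch.arange(seq_len, device=inps.device)
    position_ids = position_ids.view(1, 1, -1).expand(3, batch, -1)
    # the attribute path below is illustrative; the real model layout may differ
    cos, sin = model.model.rotary_emb(inps, position_ids)
    module_kwargs["position_embeddings"] = (cos, sin)
    return module_kwargs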

@seungwoos seungwoos changed the title Add qwen2 5 vl Add qwen2.5-vl Feb 6, 2025
@seungwoos seungwoos changed the title Add qwen2.5-vl Add Qwen2.5-VL Feb 6, 2025
@BenasdTW

BenasdTW commented Feb 7, 2025

@seungwoos Is this branch usable? Can you provide some instructions on how to get it to work?
I can't get it to work, and it also causes issues with older models like Qwen2-VL.
This is how I install it:

pip install git+https://github.com/seungwoos/AutoAWQ.git@add-qwen2_5_vl --no-deps
pip install git+https://github.com/huggingface/transformers

Code:

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
quant_path = "test_awq"
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }

# Load model
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Quantize
model.quantize(tokenizer, quant_config=quant_config)

# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)

print(f'Model is quantized and saved at "{quant_path}"')

Got the following error:

root@0455e7995f18:/workspaces/SpecsML# /opt/conda/bin/python /workspaces/SpecsML/quant.py
Fetching 14 files: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 190033.19it/s]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 10.05it/s]
Repo card metadata block was not found. Setting CardData to empty.
AWQ:   0%|                                                                                                                         | 0/36 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/workspaces/SpecsML/quant.py", line 13, in <module>
    model = AutoAWQForCausalLM.from_pretrained(model_path)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/awq/models/base.py", line 242, in quantize
    self.quantizer.quantize()
  File "/opt/conda/lib/python3.11/site-packages/awq/quantize/quantizer.py", line 172, in quantize
    input_feat = self._get_input_feat(self.modules[i], named_linears)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/awq/quantize/quantizer.py", line 648, in _get_input_feat
    self.inps = self._module_forward(self.inps, layer, module_kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/awq/quantize/quantizer.py", line 260, in _module_forward
    module_output = module(x, **module_kwargs)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py", line 1017, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
                                                          ^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py", line 910, in forward
    cos, sin = position_embeddings
    ^^^^^^^^
TypeError: cannot unpack non-iterable NoneType object

@BenasdTW

BenasdTW commented Feb 7, 2025

@seungwoos Is this branch usable? Can you provide some instructions on how to get it to work? I can't get it to work, and it also causes issues with older models like Qwen2-VL.

The error is likely caused by a dependency issue: Qwen2.5-VL depends on the main branch of transformers, but AutoAWQ requires transformers<=4.47.1,>=4.45.0:

autoawq 0.2.8 requires transformers<=4.47.1,>=4.45.0, but you have transformers 4.49.0.dev0 which is incompatible.
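
A quick way to confirm that the installed transformers build actually ships the Qwen2.5-VL classes (an illustrative check, not part of this PR):

import transformers

print(transformers.__version__)  # needs a build newer than 4.48.0, e.g. 4.49.0.dev0 from main
# raises ImportError on releases that predate Qwen2.5-VL support
from transformers import Qwen2_5_VLForConditionalGeneration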

@seungwoos
Contributor Author

Hi, @BenasdTW

You should compute positional embeddings beforehand. I actually made another PR to handle this issue. There's room for enhancement since rotary embedding only requires the input device.

The current AutoAWQ release doesn't support the latest transformers version. Installing the latest transformers after installing AutoAWQ's required packages worked for me.

@seungwoos
Contributor Author

seungwoos commented Feb 7, 2025

I guess we shouldn't use pile-val-backup as a calibration dataset, but the Qwen2-VL example code doesn't seem to work properly. I'm currently working on fixing this issue.

You should add padding_side='left' to the processor.

@BenasdTW

BenasdTW commented Feb 7, 2025

You should compute positional embeddings beforehand. I actually made another PR to handle this issue. There's room for enhancement since rotary embedding only requires the input device.

Thanks for the clarification! After manually applying the patch from #705, it works as expected.

I think it would be useful to mention that this PR depends on #705.

@BenasdTW

BenasdTW commented Feb 7, 2025

@seungwoos Would you mind creating a branch that merges add-computed-position-embedding and add-qwen2_5_vl in your fork? This would make it easier for people to install and use.

@seungwoos
Contributor Author

Thanks for your comment @BenasdTW !
I just merged the previous PR into this one.

@jlia0

jlia0 commented Feb 9, 2025

The following config works for me.

import modal

# Modal image definition (the variable name is assumed)
image = (
    modal.Image.debian_slim(python_version="3.12")
    .apt_install("git")
    .pip_install("torch")
    .pip_install("git+https://github.com/seungwoos/AutoAWQ.git@add-qwen2_5_vl")
    .pip_install(
        "git+https://github.com/huggingface/transformers",
        "accelerate",
    )
    .pip_install("pillow")
)

@BenasdTW

BenasdTW commented Feb 9, 2025

@jlia0 I saw your comment on Hugging Face. Would you mind sharing the 72B model on Hugging Face if you manage to quantize it? I don't have a PC powerful enough to quantize the 72B model.

Here are the 3B and 7B AWQ quantized versions in case someone needs them.
https://huggingface.co/Benasd/Qwen2.5-VL-7B-Instruct-AWQ
https://huggingface.co/Benasd/Qwen2.5-VL-3B-Instruct-AWQ

@jlia0

jlia0 commented Feb 9, 2025

@jlia0 I saw your comment on Hugging Face. Would you mind sharing the 72B model on Hugging Face if you manage to quantize it? I don't have a PC powerful enough to quantize the 72B model.

Here are the 3B and 7B AWQ quantized versions in case someone needs them. https://huggingface.co/Benasd/Qwen2.5-VL-7B-Instruct-AWQ https://huggingface.co/Benasd/Qwen2.5-VL-3B-Instruct-AWQ

sure - there you go

https://huggingface.co/PointerHQ/Qwen2.5-VL-72B-Instruct-Pointer-AWQ

@jlia0

jlia0 commented Feb 10, 2025

@jlia0 I saw your comment on Hugging Face. Would you mind sharing the 72B model on Hugging Face if you manage to quantize it? I don't have a PC powerful enough to quantize the 72B model.

Here are the 3B and 7B AWQ quantized versions in case someone needs them. https://huggingface.co/Benasd/Qwen2.5-VL-7B-Instruct-AWQ https://huggingface.co/Benasd/Qwen2.5-VL-3B-Instruct-AWQ

Hi could you please share your AutoAWQ quantization code for Qwen2.5-VL?

There's something wrong with my 72B-AWQ model when serving it using vLLM with --tensor-parallel-size=2.
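
For reference, serving with --tensor-parallel-size maps to the tensor_parallel_size argument of vLLM's offline Python API; a minimal smoke-test sketch, assuming a vLLM build with Qwen2.5-VL support and using the 7B checkpoint shared above:

from vllm import LLM, SamplingParams

# Load an AWQ-quantized Qwen2.5-VL checkpoint across 2 GPUs.
llm = LLM(
    model="Benasd/Qwen2.5-VL-7B-Instruct-AWQ",
    quantization="awq",
    tensor_parallel_size=2,
    max_model_len=8192,  # keep the KV cache small for a quick test
)
out = llm.generate(["Describe what AWQ quantization does."], SamplingParams(max_tokens=64))
print(out[0].outputs[0].text)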

@BenasdTW

Hi could you please share your AutoAWQ quantization code for Qwen2.5-VL?
Sure.

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "Qwen/Qwen2.5-VL-7B-Instruct"
quant_path = "Qwen2.5-VL-7B-Instruct-AWQ"
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }

# Load model
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Quantize
model.quantize(tokenizer, quant_config=quant_config)

# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)

print(f'Model is quantized and saved at "{quant_path}"')

I haven't tried --tensor-parallel-size=2, so I'm not sure if it will work.

@jlia0

jlia0 commented Feb 10, 2025

Hi could you please share your AutoAWQ quantization code for Qwen2.5-VL?

Sure.

[quantization script quoted above]

I haven't tried --tensor-parallel-size=2, so I'm not sure if it will work.

What's your setup/environment?

I have tried TP=2 with your 7B-AWQ model and it works.

However, the 72B didn't work, failing with the following error:

ValueError: The input size is not aligned with the quantized weight shape. This can be caused by too large tensor parallel size.

@BenasdTW

BenasdTW commented Feb 10, 2025

What's your setup/environment?

I ran it in a VS Code devcontainer with this Dockerfile:

FROM pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel
# Install Python and other necessary packages
RUN apt-get update && \
    apt-get install -y git libgl1-mesa-glx libglib2.0-0 && \
    rm -rf /var/lib/apt/lists/*

# Upgrade pip
RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install torch torchvision torchaudio
RUN python3 -m pip install git+https://github.com/huggingface/transformers
RUN python3 -m pip install git+https://github.com/huggingface/accelerate
RUN python3 -m pip install git+https://github.com/huggingface/peft
RUN python3 -m pip install git+https://github.com/huggingface/trl
RUN python3 -m pip install flash-attn --no-build-isolation
RUN python3 -m pip install datasets numpy sentencepiece gguf protobuf matplotlib
RUN python3 -m pip install bitsandbytes
RUN python3 -m pip install tensorboard
RUN python3 -m pip install qwen-vl-utils[decord]
RUN python3 -m pip install git+https://github.com/seungwoos/AutoAWQ.git@add-qwen2_5_vl --no-deps
RUN python3 -m pip install vllm --pre --extra-index-url https://wheels.vllm.ai/nightly

Hardware: i9-12900K, RTX 3080 Ti
Host OS: Windows 10

I'm not sure, but I think it could be because TP=2 doesn't split the 7B-AWQ model; instead, it just duplicates the small model.

@BenasdTW

BenasdTW commented Feb 13, 2025

However, the 72B didn't work, failing with the following error:

ValueError: The input size is not aligned with the quantized weight shape. This can be caused by too large tensor parallel size.

@jlia0 Have you found a solution to this problem?

I was able to run a quantized model with -tp 2 using a workaround: setting q_group_size to 64 during quantization. As shown here.
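
For reference, the arithmetic behind this workaround is a simple divisibility check; a sketch with assumed sizes (29568 is the published Qwen2.5-72B intermediate size):

# Under tensor parallelism each GPU holds intermediate_size / tp columns of the MLP
# weights, and that shard width must be divisible by the AWQ group size.
intermediate_size = 29568  # assumed value for the Qwen2.5-VL-72B language model
tp = 2
print((intermediate_size // tp) % 128)  # 64 -> not aligned with the default group size of 128
print((intermediate_size // tp) % 64)   # 0  -> aligned when q_group_size is 64

# The only change to the quantization script is the group size:
quant_config = {"zero_point": True, "q_group_size": 64, "w_bit": 4, "version": "GEMM"}
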
However, I'm not sure if it's working correctly, because my quantized 72B model just outputs gibberish.

Would you mind sharing your quantization code?

@casper-hansen
Owner

@seungwoos Thanks for this PR. I hope to review it soon and merge it!

@BenasdTW There is a bug in vLLM. Try inference in AutoAWQ first to see if it works. vllm-project/vllm#13227

@BenasdTW

BenasdTW commented Feb 13, 2025

Nevermind. Everything worked again after reboot.

@Cescfangs

However, the 72B didn't work, failing with the following error:

ValueError: The input size is not aligned with the quantized weight shape. This can be caused by too large tensor parallel size.

@jlia0 Have you found a solution to this problem?

I was able to run a quantized model with -tp 2 using a workaround: setting q_group_size to 64 during quantization. As shown here. However, I'm not sure if it's working correctly, because my quantized 72B model just outputs gibberish.

Would you mind sharing your quantization code?

Hey @BenasdTW, I encountered the same issue: the quantized 72B model outputs gibberish. How did you solve it?

@BenasdTW

Hey @BenasdTW, I encountered the same issue: the quantized 72B model outputs gibberish. How did you solve it?

I actually just restarted the server, rebuilt the container and re-ran the exact same code. Make sure no other program is using the GPUs.

@Cescfangs

Hey @BenasdTW, I encountered the same issue: the quantized 72B model outputs gibberish. How did you solve it?

I actually just restarted the server, rebuilt the container and re-ran the exact same code. Make sure no other program is using the GPUs.

Actually, the quantized model is fine under AutoAWQ, but the inference results with the vLLM server are completely different. I was using vLLM 0.7.2; any further advice?

@BenasdTW

Actually, the quantized model is fine under AutoAWQ, but the inference results with the vLLM server are completely different. I was using vLLM 0.7.2; any further advice?

Are you using vLLM v1? I think v1 is bugged; its inference results differ from v0's.

@seungwoos
Contributor Author

seungwoos commented Feb 20, 2025

If you want to use a vision and text dataset as a calibration set, you should use processor = Qwen2_5_VLProcessor.from_pretrained(model_path, padding_side='left') instead of model.processor in this example.

@BenasdTW

If you want to use a vision and text dataset as a calibration set, you should change to processor = Qwen2_5_VLProcessor.from_pretrained(model_path, padding_side='left') in this example.

There is no processor in the example.
Did you mean replacing model.processor with Qwen2_5_VLProcessor.from_pretrained(model_path, padding_side='left')?

@BenasdTW

BenasdTW commented Feb 20, 2025

The Qwen team just released their official AWQ-quantized model.
Qwen/Qwen2.5-VL-72B-Instruct-AWQ

BTW, the official quantized version doesn't work with -tp 2 for now.

@seungwoos
Contributor Author

There is no processor in the example. Did you mean replacing model.processor with Qwen2_5_VLProcessor.from_pretrained(model_path, padding_side='left')?

Oh yes, we should import Qwen2_5_VLProcessor first, then create the processor with padding_side='left'.
Or we can just use AutoProcessor. The key point is padding_side='left'; otherwise, it does not work.
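
In code, the suggestion amounts to something like this minimal sketch (the surrounding calibration setup from the existing Qwen2-VL example is assumed):

from transformers import AutoProcessor  # or Qwen2_5_VLProcessor

model_path = "Qwen/Qwen2.5-VL-7B-Instruct"
# padding_side="left" is the key point; without it the calibration run does not work.
processor = AutoProcessor.from_pretrained(model_path, padding_side="left")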

@jlia0

jlia0 commented Feb 21, 2025

@BenasdTW @seungwoos

I have updated the previously uploaded weights.

Try PointerHQ/Qwen2.5-VL-72B-Instruct-Pointer-AWQ which supports --tensor-parallel on 2, 4 and 8 GPUs.

@BenasdTW

@BenasdTW @seungwoos

I have updated the previously uploaded weights.

Try PointerHQ/Qwen2.5-VL-72B-Instruct-Pointer-AWQ which supports --tensor-parallel on 2, 4 and 8 GPUs.

Thanks! Good Work! This is definitely better than changing the group_size!
I've noticed that you padded the intermediate_size of the model. Would you mind telling me how to pad the model? Is fine-tuning required? I would also like to know which calibration dataset you used for AWQ.

@casper-hansen casper-hansen merged commit b6719dc into casper-hansen:main Mar 6, 2025