53 changes: 53 additions & 0 deletions vllm/README.md
@@ -2218,6 +2218,59 @@ curl http://localhost:8000/v1/audio/transcriptions \
```
---

### 2.4.2 dots.ocr Support

Clone the repository:

```bash
git clone https://github.com/rednote-hilab/dots.ocr.git
cd dots.ocr
```

Next, comment out the following two entries in `requirements.txt`, since both pull in CUDA dependencies:

> flash-attn==2.8.0.post2
> accelerate

After commenting them out, install the dependencies:

```bash
# Assuming you have installed torch/ipex etc.
pip install --no-deps accelerate
pip install -e .
```
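
Optionally, verify that PyTorch can see the XPU stack before moving on. This is just a sanity check and assumes `torch` and `intel_extension_for_pytorch` are already installed:

```bash
# IPEX should import cleanly and at least one XPU device should be visible
python3 -c "import torch; import intel_extension_for_pytorch as ipex; print(ipex.__version__, torch.xpu.is_available())"
```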

Download the model weights:
```bash
# From the dots.ocr repo root
python3 tools/download_model.py

# Or, to download from ModelScope
python3 tools/download_model.py --type modelscope
```

To run dots.ocr, apply the provided `dots_ocr.patch` to the code in `./weights/DotsOCR`:

```bash
cd ./weights/DotsOCR
patch -p1 < YOUR_PATH/dots_ocr.patch
```
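
To confirm the patch applied cleanly, you can check that the IPEX attention path is now present in the vision code (a quick sanity check, run from inside `./weights/DotsOCR`):

```bash
# Should print the line that now calls ipex.llm.functional.varlen_attention
grep -n "varlen_attention" modeling_dots_vision.py
```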

Then, you're ready to start:

```bash
# Path to your downloaded model weights. Use a directory name without periods
# (e.g., `DotsOCR` instead of `dots.ocr`); this is a temporary workaround pending
# integration with Transformers.
export hf_model_path=./weights/DotsOCR
export PYTHONPATH=$(dirname "$hf_model_path"):$PYTHONPATH

# Make the installed api_server import the DotsOCR modeling code at startup.
# If you downloaded the weights yourself, replace `DotsOCR` below with your model
# directory name (again, one without periods).
sed -i '/^from vllm\.version import __version__ as VLLM_VERSION$/a\
from DotsOCR import modeling_dots_ocr_vllm' /usr/local/lib/python3.12/dist-packages/vllm-0.10.1.dev0+g6d8d0a24c.d20250825.xpu-py3.12-linux-x86_64.egg/vllm/entrypoints/openai/api_server.py

# Start the service:
TORCH_LLM_ALLREDUCE=1 VLLM_USE_V1=1 CCL_ZE_IPC_EXCHANGE=pidfd VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 VLLM_WORKER_MULTIPROC_METHOD=spawn \
python3 -m vllm.entrypoints.openai.api_server --model YOUR_DOTSOCR_PATH --enforce-eager --host 0.0.0.0 --trust-remote-code \
    --disable-sliding-window --gpu-memory-util=0.8 --no-enable-prefix-caching --max-num-batched-tokens=8192 --disable-log-requests \
    --max-model-len=40000 --block-size 64 -tp=1 --port 8000 --served-model-name DotsOCR --chat-template-content-format string --dtype bfloat16
```
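
Once the server is up, you can send an OCR request through the OpenAI-compatible chat completions API. The sketch below is illustrative only: the image URL and prompt text are placeholders, and dots.ocr defines its own task prompts (see the upstream repo) that you may want to use instead:

```bash
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "DotsOCR",
    "messages": [{
      "role": "user",
      "content": [
        {"type": "image_url", "image_url": {"url": "https://example.com/sample_page.png"}},
        {"type": "text", "text": "Extract the text content from this image."}
      ]
    }]
  }'
```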


---

### 2.5 Omni Model Support

#### Install audio dependencies
115 changes: 115 additions & 0 deletions vllm/dots_ocr.patch
@@ -0,0 +1,115 @@
diff --git a/modeling_dots_ocr_vllm.py b/modeling_dots_ocr_vllm.py
index a8ba8d0..eb84b0d 100644
--- a/modeling_dots_ocr_vllm.py
+++ b/modeling_dots_ocr_vllm.py
@@ -178,11 +178,6 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal):
)
_tp_plan = {}

- @classmethod
- def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
- if modality in ("image",):
- return "<|img|><|imgpad|><|endofimg|>"
-
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

@@ -424,12 +419,20 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal):
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

+ @classmethod
+ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
+ if modality.startswith("image"):
+ return "<|img|><|imgpad|><|endofimg|>"
+
+ raise ValueError("Only image modality is supported")
+

def patch_vllm_chat_placeholder():
import vllm
# return when vllm version > 0.9.1
- if not (vllm.__version_tuple__[0]==0 and vllm.__version_tuple__[1] <= 9 and vllm.__version_tuple__[2] <= 1):
- return
+ # our version is 0.9.0.dev, ignore the following version check.
+ # if not (vllm.__version_tuple__[0]==0 and vllm.__version_tuple__[1] <= 9 and vllm.__version_tuple__[2] <= 1):
+ # return
from vllm.entrypoints.chat_utils import BaseMultiModalItemTracker

ori = BaseMultiModalItemTracker._placeholder_str
@@ -448,4 +451,4 @@ ModelRegistry.register_model(
)


-patch_vllm_chat_placeholder()
\ No newline at end of file
+# patch_vllm_chat_placeholder()
diff --git a/modeling_dots_vision.py b/modeling_dots_vision.py
index 1046513..56009a8 100644
--- a/modeling_dots_vision.py
+++ b/modeling_dots_vision.py
@@ -8,10 +8,16 @@ import torch.utils.checkpoint
flash_attn_available = True
npu_available = True

+# try:
+# from flash_attn import flash_attn_varlen_func
+# except ImportError:
+# flash_attn_available = False
+
try:
- from flash_attn import flash_attn_varlen_func
-except ImportError:
- flash_attn_available = False
+ import intel_extension_for_pytorch as ipex
+except ImportError as e:
+ raise ValueError("IPEX is not installed but required for XPU build")
+

from torch.nn import LayerNorm
from transformers.modeling_utils import PreTrainedModel
@@ -159,9 +165,41 @@ class VisionFlashAttention2(nn.Module):
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
- attn_output = flash_attn_varlen_func(
- q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, causal=self.is_causal
- ).reshape(seq_length, -1)
+
+ # Original code with flash_attn_varlen_func
+ # attn_output = flash_attn_varlen_func(
+ # q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, causal=self.is_causal
+ # ).reshape(seq_length, -1)
+ # Original code ends
+
+ # Changes start for XPU
+ attn_output = torch.empty(
+ q.shape,
+ dtype=q.dtype,
+ device=q.device)
+ ipex.llm.functional.varlen_attention(
+ q.contiguous(), # query
+ k.contiguous(), # key
+ v.contiguous(), # value
+ attn_output, # out
+ cu_seqlens.int(), # seqlen_q
+ cu_seqlens.int(), # seqlen_k
+ None, # alibi_slopes
+ max_seqlen, # max_seqlen_q
+ max_seqlen, # max_seqlen_k
+ 0.0, # pdropout
+ 1.0 / (q.shape[-1] ** 0.5), # softmax_scale
+ False, # zero_tensors
+ self.is_causal, # is_causal
+ False, # return_softmax
+ None, # gen_
+ -1, # window_size_left
+ -1, # window_size_right
+ -1, # logits_soft_cap
+ )
+ attn_output = attn_output.reshape(seq_length, -1)
+ # Changes end for XPU
+
attn_output = self.proj(attn_output)

return attn_output