Commit 6650e6a

[Model] Add classification Task with Qwen2ForSequenceClassification (#9704)

Signed-off-by: Kevin-Yang <[email protected]>
Co-authored-by: Kevin-Yang <[email protected]>

1 parent 07e981f commit 6650e6a
6 files changed: +211 −1 lines changed

docs/source/models/supported_models.rst

Lines changed: 22 additions & 0 deletions

@@ -361,6 +361,28 @@ Reward Modeling
 .. note::
     As an interim measure, these models are supported via Embeddings API. See `this RFC <https://github.com/vllm-project/vllm/issues/8967>`_ for upcoming changes.
 
+Classification
+---------------
+
+.. list-table::
+  :widths: 25 25 50 5 5
+  :header-rows: 1
+
+  * - Architecture
+    - Models
+    - Example HF Models
+    - :ref:`LoRA <lora>`
+    - :ref:`PP <distributed_serving>`
+  * - :code:`Qwen2ForSequenceClassification`
+    - Qwen2-based
+    - :code:`jason9693/Qwen2.5-1.5B-apeach`, etc.
+    -
+    - ✅︎
+
+.. note::
+    As an interim measure, these models are supported via Embeddings API. They will be supported via a Classification API in the future (no reference APIs exist now).
+
+
 Multimodal Language Models
 ^^^^^^^^^^^^^^^^^^^^^^^^^^
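
For context, a minimal usage sketch of the interim Embeddings API path described in the note above. This example is illustrative and not part of the commit; it assumes the jason9693/Qwen2.5-1.5B-apeach checkpoint listed in the table and the LLM.encode entry point that the test harness below also uses:

    from vllm import LLM

    # Illustrative only: load the classification checkpoint; the registry
    # entry added in this commit resolves it to Qwen2ForSequenceClassification.
    llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach")

    # Under the interim Embeddings API, the pooled, softmaxed class
    # probabilities come back in the `embedding` field.
    (output,) = llm.encode(["vLLM is a high-throughput inference engine."])
    probs = output.outputs.embedding  # one float per class label
    print(probs)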

tests/conftest.py

Lines changed: 19 additions & 0 deletions

@@ -343,6 +343,17 @@ def get_inputs(
 
         return all_inputs
 
+    def classify(self, prompts: List[str]) -> List[List[float]]:
+        # output is the final softmaxed class probabilities
+        all_inputs = self.get_inputs(prompts)
+        outputs = []
+        for inputs in all_inputs:
+            output = self.model(**self.wrap_device(inputs))
+            logits = output.logits.softmax(dim=-1)[0].tolist()
+            outputs.append(logits)
+
+        return outputs
+
     def generate(
         self,
         prompts: List[str],
@@ -688,6 +699,14 @@ def get_inputs(
 
         return inputs
 
+    def classify(self, prompts: List[str]) -> List[List[float]]:
+        req_outputs = self.model.encode(prompts)
+        outputs = []
+        for req_output in req_outputs:
+            embedding = req_output.outputs.embedding
+            outputs.append(embedding)
+        return outputs
+
     def generate(
         self,
         prompts: List[str],
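
The HfRunner.classify helper above depends on the test harness's wrappers; as a standalone reference, here is a minimal sketch of the same computation with plain transformers (the checkpoint name comes from this PR's tests; the prompt is made up):

    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    name = "jason9693/Qwen2.5-1.5B-apeach"
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForSequenceClassification.from_pretrained(name)

    inputs = tokenizer("This is a test prompt.", return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits          # shape: (1, num_labels)
    probs = logits.softmax(dim=-1)[0].tolist()   # per-class probabilities
    print(probs)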
tests/models/test_cls_models.py

Lines changed: 53 additions & 0 deletions

@@ -0,0 +1,53 @@
+"""Compare the classification outputs of HF and vLLM models.
+
+This test only covers small models. Big models such as 7B should be tested
+in test_big_models.py, which can use a larger instance to run them.
+
+Run `pytest tests/models/test_cls_models.py`.
+"""
+import pytest
+import torch
+from transformers import AutoModelForSequenceClassification
+
+CLASSIFICATION_MODELS = ["jason9693/Qwen2.5-1.5B-apeach"]
+
+
+@pytest.mark.parametrize("model", CLASSIFICATION_MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+def test_classification_models(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+) -> None:
+    with hf_runner(model,
+                   dtype=dtype,
+                   auto_cls=AutoModelForSequenceClassification) as hf_model:
+        hf_outputs = hf_model.classify(example_prompts)
+
+    with vllm_runner(model, dtype=dtype) as vllm_model:
+        vllm_outputs = vllm_model.classify(example_prompts)
+
+    print(hf_outputs, vllm_outputs)
+
+    # check logits difference
+    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
+        hf_output = torch.tensor(hf_output)
+        vllm_output = torch.tensor(vllm_output)
+
+        assert torch.allclose(hf_output, vllm_output, 1e-3)
+
+
+@pytest.mark.parametrize("model", CLASSIFICATION_MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
def test_classification_model_print(
+    vllm_runner,
+    model: str,
+    dtype: str,
+) -> None:
+    with vllm_runner(model, dtype=dtype) as vllm_model:
+        # This test is for verifying whether the model's extra_repr
+        # can be printed correctly.
+        print(vllm_model.model.llm_engine.model_executor.driver_worker.
+              model_runner.model)

vllm/model_executor/layers/pooler.py

Lines changed: 8 additions & 1 deletion

@@ -28,11 +28,15 @@ class Pooler(nn.Module):
         normalize: Whether to normalize the pooled data.
     """
 
-    def __init__(self, pooling_type: PoolingType, normalize: bool):
+    def __init__(self,
+                 pooling_type: PoolingType,
+                 normalize: bool,
+                 softmax: bool = False):
         super().__init__()
 
         self.pooling_type = pooling_type
         self.normalize = normalize
+        self.softmax = softmax
 
     def forward(
         self,
@@ -64,6 +68,9 @@ def forward(
         if self.normalize:
             pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
 
+        if self.softmax:
+            pooled_data = nn.functional.softmax(pooled_data, dim=-1)
+
         pooled_outputs = [
             EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data
         ]
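
The net effect of the new flag, as a self-contained sketch: with PoolingType.LAST, normalize=False, softmax=True (the configuration Qwen2ForSequenceClassification uses below), the pooler keeps the last token's values and converts the label logits into probabilities. Shapes here are illustrative; the real Pooler selects tokens via PoolingMetadata:

    import torch
    import torch.nn as nn

    # Per-token label logits for one sequence (7 tokens, num_labels = 2).
    hidden_states = torch.randn(7, 2)

    last_token_logits = hidden_states[-1]                     # PoolingType.LAST
    probs = nn.functional.softmax(last_token_logits, dim=-1)  # softmax=True

    assert torch.isclose(probs.sum(), torch.tensor(1.0))
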
vllm/model_executor/models/qwen2_cls.py

Lines changed: 107 additions & 0 deletions

@@ -0,0 +1,107 @@
+# coding=utf-8
+# Adapted from
+# https://huggingface.co/Qwen/Qwen2.5-Math-RM-72B/blob/main/modeling_qwen2_rm.py
+# Copyright 2024 Kakao Corp. (Kanana-X Team)
+# Copyright 2024 The Qwen team.
+# Copyright 2023 The vLLM team.
+"""Inference-only Qwen2-Classification model compatible with HF weights."""
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import Qwen2Config
+
+from vllm.attention import AttentionMetadata
+from vllm.config import CacheConfig, LoRAConfig
+from vllm.model_executor.layers.linear import RowParallelLinear
+from vllm.model_executor.layers.pooler import Pooler, PoolingType
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
+from vllm.model_executor.models.qwen2 import Qwen2Model
+from vllm.model_executor.pooling_metadata import PoolingMetadata
+from vllm.sequence import IntermediateTensors, PoolerOutput
+
+from .utils import AutoWeightsLoader
+
+
+class Qwen2ForSequenceClassification(nn.Module):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+    ]
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def __init__(
+        self,
+        config: Qwen2Config,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
+    ) -> None:
+        # TODO (@robertgshaw2): see if this can be moved out
+        if (cache_config.sliding_window is not None
+                and hasattr(config, "max_window_layers")):
+            raise ValueError("Sliding window for some but not all layers is "
+                             "not supported. This model uses sliding window "
+                             "but `max_window_layers` = %s is less than "
+                             "`num_hidden_layers` = %s. Please open an issue "
+                             "to discuss this feature." % (
+                                 config.max_window_layers,
+                                 config.num_hidden_layers,
+                             ))
+
+        super().__init__()
+
+        self.config = config
+        self.lora_config = lora_config
+
+        self.quant_config = quant_config
+        self.model = Qwen2Model(config, cache_config, quant_config)
+
+        self.score = RowParallelLinear(config.hidden_size,
+                                       config.num_labels,
+                                       quant_config=quant_config)
+        self._pooler = Pooler(pooling_type=PoolingType.LAST,
+                              normalize=False,
+                              softmax=True)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   attn_metadata, intermediate_tensors)
+        logits, _ = self.score(hidden_states)
+        return logits
+
+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+        return self._pooler(hidden_states, pooling_metadata)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        loader = AutoWeightsLoader(self,
+                                   ignore_unexpected_prefixes=["lm_head."])
+        loader.load_weights(weights)
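
How the pieces fit together: forward() produces per-token label logits (the score head projects hidden_size down to num_labels at every position), and the Pooler configured in __init__ keeps only the last token and applies softmax. A toy shape walkthrough with made-up dimensions and nn.Linear standing in for RowParallelLinear:

    import torch
    import torch.nn as nn

    hidden_size, num_labels, num_tokens = 1536, 2, 7

    # Stand-in for Qwen2Model output followed by the score head.
    hidden_states = torch.randn(num_tokens, hidden_size)
    score = nn.Linear(hidden_size, num_labels, bias=False)

    logits = score(hidden_states)       # (num_tokens, num_labels)
    probs = logits[-1].softmax(dim=-1)  # last-token pooling + softmax
    print(probs.shape)                  # torch.Size([2])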

vllm/model_executor/models/registry.py

Lines changed: 2 additions & 0 deletions

@@ -96,6 +96,8 @@
     "Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
     "MistralModel": ("llama", "LlamaEmbeddingModel"),
     "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
+    "Qwen2ForSequenceClassification": (
+        "qwen2_cls", "Qwen2ForSequenceClassification"),
     # [Multimodal]
     "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
     "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
