From a2509ede59bf403ef39c682722bd56cd430d602a Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 13 Nov 2024 22:43:41 +0800 Subject: [PATCH 1/3] fix qwen2cls tp Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/qwen2_cls.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/qwen2_cls.py b/vllm/model_executor/models/qwen2_cls.py index 020af88aadd9..b65d17d400bd 100644 --- a/vllm/model_executor/models/qwen2_cls.py +++ b/vllm/model_executor/models/qwen2_cls.py @@ -69,9 +69,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = Qwen2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) + # hidden_states from Qwen2Model has been reduced, + # the input of score layer is not parallelized. self.score = RowParallelLinear(config.hidden_size, config.num_labels, - quant_config=quant_config) + quant_config=quant_config, + input_is_parallel=False) self._pooler = Pooler.from_config_with_defaults( pooler_config, pooling_type=PoolingType.LAST, From f443280286691790ffe7f10c5293b5c7a5592fb0 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 13 Nov 2024 23:40:28 +0800 Subject: [PATCH 2/3] address comment Signed-off-by: Isotr0py <2037008807@qq.com> --- tests/models/embedding/language/test_cls_models.py | 6 +++--- vllm/model_executor/models/qwen2_cls.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/models/embedding/language/test_cls_models.py b/tests/models/embedding/language/test_cls_models.py index d8ca6d361f0e..40ee49cf6074 100644 --- a/tests/models/embedding/language/test_cls_models.py +++ b/tests/models/embedding/language/test_cls_models.py @@ -21,14 +21,14 @@ def test_classification_models( model: str, dtype: str, ) -> None: + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.classify(example_prompts) + with hf_runner(model, dtype=dtype, auto_cls=AutoModelForSequenceClassification) as hf_model: hf_outputs = hf_model.classify(example_prompts) - with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.classify(example_prompts) - print(hf_outputs, vllm_outputs) # check logits difference diff --git a/vllm/model_executor/models/qwen2_cls.py b/vllm/model_executor/models/qwen2_cls.py index b65d17d400bd..926691033089 100644 --- a/vllm/model_executor/models/qwen2_cls.py +++ b/vllm/model_executor/models/qwen2_cls.py @@ -74,7 +74,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.score = RowParallelLinear(config.hidden_size, config.num_labels, quant_config=quant_config, - input_is_parallel=False) + input_is_parallel=False, + prefix=maybe_prefix(prefix, "score")) self._pooler = Pooler.from_config_with_defaults( pooler_config, pooling_type=PoolingType.LAST, From 5cb22f52d65565c5a78ab3e200e7962dfd3cbad7 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Thu, 14 Nov 2024 02:38:37 +0800 Subject: [PATCH 3/3] score bias=False Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/qwen2_cls.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/qwen2_cls.py b/vllm/model_executor/models/qwen2_cls.py index 926691033089..27eb7e8a9397 100644 --- a/vllm/model_executor/models/qwen2_cls.py +++ b/vllm/model_executor/models/qwen2_cls.py @@ -75,6 +75,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.num_labels, quant_config=quant_config, input_is_parallel=False, + bias=False, prefix=maybe_prefix(prefix, "score")) self._pooler = Pooler.from_config_with_defaults( pooler_config,