Commit 09e7d2d

add comment and fix lint
1 parent 6779f40 commit 09e7d2d

File tree

1 file changed: +10 -7 lines changed

vllm/model_executor/models/qwen2_rm.py

Lines changed: 10 additions & 7 deletions
@@ -14,21 +14,20 @@
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.pooler import Pooler, PoolingType
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
+from vllm.model_executor.models.qwen2 import Qwen2Model
 from vllm.model_executor.pooling_metadata import PoolingMetadata
-from vllm.sequence import IntermediateTensors
-from vllm.sequence import PoolerOutput
-
+from vllm.sequence import IntermediateTensors, PoolerOutput
 
 from .utils import is_pp_missing_parameter
-from vllm.model_executor.models.qwen2 import Qwen2Model
 
 
 class ReLU(nn.Module):
+
     def __init__(self):
         super().__init__()
         self.activation = nn.ReLU()
@@ -89,9 +88,12 @@ def __init__(
         self.model = Qwen2Model(config, cache_config, quant_config)
 
         self.score = nn.Sequential(
-            ColumnParallelLinear(config.hidden_size, config.hidden_size, quant_config=quant_config),
+            ColumnParallelLinear(config.hidden_size,
+                                 config.hidden_size,
+                                 quant_config=quant_config),
             ReLU(),
-            RowParallelLinear(config.hidden_size, 1, quant_config=quant_config),
+            RowParallelLinear(config.hidden_size, 1,
+                              quant_config=quant_config),
         )
         self._pooler = Pooler(pooling_type=PoolingType.ALL, normalize=False)
 
@@ -126,6 +128,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         ]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         for name, loaded_weight in weights:
+            # Skip loading lm_head for embedding model
             if name == "lm_head.weight":
                 continue
             if "rotary_emb.inv_freq" in name:

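Likewise, the comment added in the final hunk documents an existing skip: load_weights consumes a Qwen2 causal-LM checkpoint, so tensors with no counterpart in the reward model are filtered out before names are matched against params_dict. A simplified sketch of that pattern, with a hypothetical helper name:

from typing import Iterable, Tuple

import torch


def filter_causal_lm_weights(weights: Iterable[Tuple[str, torch.Tensor]]):
    # Hypothetical helper mirroring the skips in load_weights above.
    for name, loaded_weight in weights:
        # Skip loading lm_head for embedding model: the "score" head
        # replaces the checkpoint's next-token prediction head.
        if name == "lm_head.weight":
            continue
        # Rotary embedding inverse-frequency buffers are recomputed
        # at runtime, so checkpoint copies are ignored.
        if "rotary_emb.inv_freq" in name:
            continue
        yield name, loaded_weight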