2
2
3
3
import torch
4
4
5
- from vllm .model_executor .layers .utils import (
6
- apply_penalties as _apply_penalties )
5
+ from vllm .model_executor .layers .utils import apply_penalties
7
6
from vllm .utils import is_pin_memory_available , make_tensor_with_pad
8
7
9
8
@@ -17,27 +16,30 @@ def apply_min_token_penalties(logits: torch.Tensor,
17
16
"""
18
17
min_tokens_logits_to_penalize : List [Tuple [int , int ]] = []
19
18
for index , min_token in enumerate (min_tokens ):
20
- if ( len (output_token_ids [index ]) < min_token ) :
19
+ if len (output_token_ids [index ]) < min_token :
21
20
for stop_token_id in stop_token_ids [index ]:
22
21
min_tokens_logits_to_penalize .append ((index , stop_token_id ))
23
22
if min_tokens_logits_to_penalize :
24
23
logits [tuple (zip (* min_tokens_logits_to_penalize ))] = - float ("inf" )
25
24
26
25
27
def apply_all_penalties(
    logits: torch.Tensor,
    prompt_token_ids: torch.Tensor,
    presence_penalties: torch.Tensor,
    frequency_penalties: torch.Tensor,
    repetition_penalties: torch.Tensor,
    output_token_ids: List[List[int]],
) -> torch.Tensor:
    """
    Apply presence, frequency and repetition penalties to ``logits``.

    Thin wrapper around the shared ``apply_penalties`` helper: it first
    packs the per-request output token id lists into a padded tensor on
    the same device as ``logits``, then delegates the actual penalty
    arithmetic to the imported helper.
    """
    # The vocab size is the trailing dimension of the logits matrix;
    # it bounds the padded output-token tensor built below.
    num_vocab = logits.shape[1]
    padded_output_tokens = _convert_to_tensors(output_token_ids, num_vocab,
                                               logits.device)
    return apply_penalties(
        logits,
        prompt_token_ids,
        padded_output_tokens,
        presence_penalties,
        frequency_penalties,
        repetition_penalties,
    )
41
43
42
44
43
45
def _convert_to_tensors (output_token_ids : List [List [int ]], vocab_size : int ,
0 commit comments