 import os
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field, fields
-from typing import List, NamedTuple, Optional, Tuple, Union
+from typing import Dict, List, NamedTuple, Optional, Tuple, Union
 
 import torch
 from pydantic import BaseModel
@@ -108,6 +108,55 @@ def __call__(
         pass  # noqa
 
 
+class LogitBiasLogitsProcessor(LogitsProcessor):
+    def __init__(self, logit_bias: Dict[str, float]) -> None:
+        super().__init__()
+        self.logit_bias = logit_bias
+        self.tokens_to_adjust = self.process_logit_bias(logit_bias)
+        if not self.tokens_to_adjust:
+            raise ValueError("Empty logit_bias provided - no tokens to adjust")
+
+    def process_logit_bias(self, logit_bias: Dict[str, float]) -> Dict[int, float]:
+        valid = {}
+        invalid = {}
+
+        for k, v in logit_bias.items():
+            try:
+                token_id = int(k)
+                valid[token_id] = v
+            except (ValueError, TypeError):
+                invalid[k] = v
+
+        if invalid:
+            raise ValueError(
+                f"Invalid token_ids in logit_bias: {list(invalid.keys())}. "
+                f"All keys must be integers."
+            )
+        return valid
+
+    def __call__(
+        self,
+        req_id: int,
+        logits: torch.Tensor,
+        token_ids: List[List[int]],
+        stream_ptr: Optional[int],
+        client_id: Optional[int],
+    ) -> None:
+        vocab_size = logits.size(-1)
+        token_ids_list = list(self.tokens_to_adjust.keys())
+        bias_values = torch.tensor(list(self.tokens_to_adjust.values()), device=logits.device)
+
+        invalid_token_ids = [tid for tid in token_ids_list if tid >= vocab_size]
+        if invalid_token_ids:
+            raise ValueError(
+                f"Token ID(s) {invalid_token_ids} exceed vocabulary size (vocab_size={vocab_size})"
+            )
+
+        stream = None if stream_ptr is None else torch.cuda.ExternalStream(stream_ptr)
+        with torch.cuda.stream(stream):
+            logits[:, :, token_ids_list] += bias_values
+
+
 @dataclass(slots=True, kw_only=True)
 class AdditionalModelOutput:
     """An additional output to gather from the model.
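For reference, a minimal usage sketch of the new processor (not part of the change itself): the bias map is keyed by token-id strings, converted to integer ids in `process_logit_bias`, and added onto the last dimension of the logits in `__call__`. The `[batch, beams, vocab_size]` tensor layout and the values used below are assumptions for illustration, not something this hunk establishes.

```python
# Hypothetical sketch: assumes LogitBiasLogitsProcessor (added above) is importable
# and that logits arrive as a [batch, beams, vocab_size] tensor.
import torch

bias = {"42": 5.0, "7": -100.0}            # boost token 42, effectively suppress token 7
processor = LogitBiasLogitsProcessor(bias)  # keys are validated and cast to ints here

logits = torch.zeros(1, 1, 32000)           # assumed [batch, beams, vocab_size] layout
processor(req_id=0, logits=logits, token_ids=[[1, 2, 3]],
          stream_ptr=None, client_id=None)  # stream_ptr=None -> default stream path

print(logits[0, 0, 42], logits[0, 0, 7])    # tensor(5.) tensor(-100.)
```

How the processor is attached to an individual request (e.g. through the sampling parameters' logits-processor hook) is outside this hunk; the sketch only exercises the class as defined here.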