-
-
Notifications
You must be signed in to change notification settings - Fork 10.6k
Support logit_bias in v1 Sampler #13079
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7b17c04
c48b194
17b50c5
3b84436
d105842
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,6 +37,8 @@ def forward( | |
|
||
# Use float32 for the logits. | ||
logits = logits.to(torch.float32) | ||
# Apply logits bias. | ||
logits = self.apply_logits_bias(logits, sampling_metadata) | ||
# Apply penalties (e.g., min_tokens, freq_penalties). | ||
logits = self.apply_penalties(logits, sampling_metadata) | ||
# Apply temperature. | ||
|
@@ -166,3 +168,17 @@ def apply_penalties( | |
sampling_metadata.repetition_penalties, | ||
sampling_metadata.output_token_ids) | ||
return logits | ||
|
||
def apply_logits_bias(
    self,
    logits: torch.Tensor,
    sampling_metadata: "SamplingMetadata",
) -> torch.Tensor:
    """Add per-request logit biases to ``logits`` in place.

    ``sampling_metadata.logit_bias`` is a per-row sequence whose i-th entry
    is either falsy (no bias for request i) or a mapping of
    ``token_id -> bias`` to add to ``logits[i, token_id]``.

    Args:
        logits: The logits tensor of shape (num_requests, vocab_size);
            mutated in place.
        sampling_metadata: Carries the per-request ``logit_bias`` mappings.

    Returns:
        The same ``logits`` tensor, after the biases have been applied.
    """
    # Gather all (row, token, bias) triples first so the update can be
    # issued as a single batched tensor op instead of one Python-level
    # tensor __setitem__ (and, on GPU, one kernel launch) per bias entry.
    rows: list[int] = []
    cols: list[int] = []
    vals: list[float] = []
    for i, logit_bias in enumerate(sampling_metadata.logit_bias):
        if logit_bias:
            for token_id, bias in logit_bias.items():
                rows.append(i)
                cols.append(token_id)
                vals.append(bias)
    if rows:
        # Dict keys are unique per row, so accumulate=True cannot
        # double-apply a bias; it simply performs `+=` at each index.
        logits.index_put_(
            (torch.tensor(rows, device=logits.device),
             torch.tensor(cols, device=logits.device)),
            torch.tensor(vals, dtype=logits.dtype, device=logits.device),
            accumulate=True)
    return logits
Uh oh!
There was an error while loading. Please reload this page.