Added support for min_p #921

Merged (2 commits) on Nov 21, 2023

30 changes: 28 additions & 2 deletions llama_cpp/llama.py
@@ -1046,6 +1046,8 @@ def sample(
self,
top_k: int = 40,
top_p: float = 0.95,
min_p: float = 0.05,
typical_p: float = 1.0,
temp: float = 0.80,
repeat_penalty: float = 1.1,
frequency_penalty: float = 0.0,
@@ -1108,7 +1110,10 @@ def sample(
grammar=grammar,
)

if temp == 0.0:
if temp < 0.0:
self._ctx.sample_softmax(candidates=self._candidates)
id = self._candidates.candidates.data[0].id
elif temp == 0.0:
id = self._ctx.sample_token_greedy(candidates=self._candidates)
elif mirostat_mode == 1:
self._ctx.sample_temp(candidates=self._candidates, temp=temp)
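The new negative-temperature branch is effectively greedy decoding with the candidate probabilities materialized first: sample_softmax normalizes and sorts the candidates, and the id of the top entry is taken. A minimal standalone sketch of that behavior, using numpy instead of the llama.cpp bindings (the function name is illustrative):

import numpy as np

def sample_negative_temp(logits: np.ndarray) -> int:
    # Mirror of the temp < 0.0 branch: softmax the logits (as sample_softmax
    # does), then return the id of the highest-probability candidate.
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    return int(np.argmax(probs))
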
@@ -1130,8 +1135,9 @@
else:
self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
self._ctx.sample_tail_free(candidates=self._candidates, z=tfs_z, min_keep=1)
self._ctx.sample_typical(candidates=self._candidates, p=1.0, min_keep=1)
self._ctx.sample_typical(candidates=self._candidates, p=typical_p, min_keep=1)
self._ctx.sample_top_p(candidates=self._candidates, p=top_p, min_keep=1)
self._ctx.sample_min_p(candidates=self._candidates, p=min_p, min_keep=1)
self._ctx.sample_temp(candidates=self._candidates, temp=temp)
id = self._ctx.sample_token(candidates=self._candidates)
if grammar is not None:
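This hunk is the heart of the change: sample_min_p is slotted into the sampler chain after top_p, and the typical sampling probability, previously hardcoded to 1.0, is now forwarded from the new typical_p parameter. A rough sketch of what the min_p filter does (this approximates llama.cpp's llama_sample_min_p; apply_min_p and the numpy formulation are illustrative, not the actual bindings):

import numpy as np

def apply_min_p(probs: np.ndarray, min_p: float = 0.05, min_keep: int = 1) -> np.ndarray:
    # Drop tokens whose probability falls below min_p * max(probs),
    # always keeping at least `min_keep` of the most likely tokens.
    threshold = min_p * probs.max()
    keep = probs >= threshold
    if keep.sum() < min_keep:
        keep[np.argsort(probs)[::-1][:min_keep]] = True
    filtered = np.where(keep, probs, 0.0)
    return filtered / filtered.sum()  # renormalize the surviving tokens
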
@@ -1143,6 +1149,8 @@ def generate(
tokens: Sequence[int],
top_k: int = 40,
top_p: float = 0.95,
min_p: float = 0.05,
typical_p: float = 1.0,
temp: float = 0.80,
repeat_penalty: float = 1.1,
reset: bool = True,
@@ -1200,6 +1208,8 @@ def generate(
token = self.sample(
top_k=top_k,
top_p=top_p,
min_p=min_p,
typical_p=typical_p,
temp=temp,
repeat_penalty=repeat_penalty,
frequency_penalty=frequency_penalty,
@@ -1298,6 +1308,8 @@ def _create_completion(
max_tokens: Optional[int] = 16,
temperature: float = 0.8,
top_p: float = 0.95,
min_p: float = 0.05,
typical_p: float = 1.0,
logprobs: Optional[int] = None,
echo: bool = False,
stop: Optional[Union[str, List[str]]] = [],
@@ -1396,6 +1408,8 @@ def _create_completion(
prompt_tokens,
top_k=top_k,
top_p=top_p,
min_p=min_p,
typical_p=typical_p,
temp=temperature,
tfs_z=tfs_z,
mirostat_mode=mirostat_mode,
@@ -1764,6 +1778,8 @@ def create_completion(
max_tokens: Optional[int] = 16,
temperature: float = 0.8,
top_p: float = 0.95,
min_p: float = 0.05,
typical_p: float = 1.0,
logprobs: Optional[int] = None,
echo: bool = False,
stop: Optional[Union[str, List[str]]] = [],
@@ -1810,6 +1826,8 @@ def create_completion(
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
min_p=min_p,
typical_p=typical_p,
logprobs=logprobs,
echo=echo,
stop=stop,
@@ -1841,6 +1859,8 @@ def __call__(
max_tokens: int = 128,
temperature: float = 0.8,
top_p: float = 0.95,
min_p: float = 0.05,
typical_p: float = 1.0,
logprobs: Optional[int] = None,
echo: bool = False,
stop: Optional[Union[str, List[str]]] = [],
@@ -1887,6 +1907,8 @@ def __call__(
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
min_p=min_p,
typical_p=typical_p,
logprobs=logprobs,
echo=echo,
stop=stop,
@@ -1916,6 +1938,8 @@ def create_chat_completion(
temperature: float = 0.2,
top_p: float = 0.95,
top_k: int = 40,
min_p: float = 0.05,
typical_p: float = 1.0,
stream: bool = False,
stop: Optional[Union[str, List[str]]] = [],
seed: Optional[int] = None,
@@ -1962,6 +1986,8 @@ def create_chat_completion(
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
typical_p=typical_p,
stream=stream,
stop=stop,
seed=seed,
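With the parameter now threaded through create_completion, __call__, and create_chat_completion, it can be set directly from the high-level API. A usage sketch (model path and prompt are placeholders):

from llama_cpp import Llama

llm = Llama(model_path="./models/model.gguf")  # placeholder path

out = llm.create_completion(
    "Q: Name the planets in the solar system. A:",
    max_tokens=64,
    temperature=0.8,
    top_p=1.0,   # disable nucleus sampling ...
    min_p=0.05,  # ... and let min_p prune unlikely tokens instead
)
print(out["choices"][0]["text"])

One common way to use min_p is with top_p left at 1.0, so that a single truncation sampler does the pruning rather than stacking both.
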
16 changes: 16 additions & 0 deletions llama_cpp/llama_chat_format.py
@@ -26,6 +26,8 @@ def __call__(
temperature: float = 0.2,
top_p: float = 0.95,
top_k: int = 40,
min_p: float = 0.05,
typical_p: float = 1.0,
stream: bool = False,
stop: Optional[Union[str, List[str]]] = [],
seed: Optional[int] = None,
@@ -287,6 +289,8 @@ def basic_create_chat_completion(
temperature: float = 0.2,
top_p: float = 0.95,
top_k: int = 40,
min_p: float = 0.05,
typical_p: float = 1.0,
stream: bool = False,
stop: Optional[Union[str, List[str]]] = [],
seed: Optional[int] = None,
@@ -330,6 +334,8 @@ def basic_create_chat_completion(
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
typical_p=typical_p,
stream=stream,
stop=stop,
seed=seed,
@@ -579,6 +585,8 @@ def functionary_chat_handler(
temperature: float = 0.2,
top_p: float = 0.95,
top_k: int = 40,
min_p: float = 0.05,
typical_p: float = 1.0,
stream: bool = False,
stop: Optional[Union[str, List[str]]] = [],
response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
@@ -761,6 +769,8 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
typical_p=typical_p,
stream=stream,
stop=["user:", "</s>"],
max_tokens=max_tokens,
@@ -831,6 +841,8 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
typical_p=typical_p,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
repeat_penalty=repeat_penalty,
@@ -929,6 +941,8 @@ def __call__(
temperature: float = 0.2,
top_p: float = 0.95,
top_k: int = 40,
min_p: float = 0.05,
typical_p: float = 1.0,
stream: bool = False,
stop: Optional[Union[str, List[str]]] = [],
response_format: Optional[
@@ -1045,6 +1059,8 @@ def __call__(
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
typical_p=typical_p,
stream=stream,
stop=stop,
max_tokens=max_tokens,
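The chat format handlers now forward min_p and typical_p as well, so the same knobs are available from create_chat_completion. A usage sketch (model path and chat_format value are assumptions):

from llama_cpp import Llama

llm = Llama(model_path="./models/chat-model.gguf", chat_format="llama-2")  # placeholders

resp = llm.create_chat_completion(
    messages=[
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Explain min_p sampling in one sentence."},
    ],
    temperature=0.7,
    min_p=0.1,       # forwarded by the handler into Llama.sample()
    typical_p=1.0,   # left at the default, i.e. disabled
)
print(resp["choices"][0]["message"]["content"])
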
10 changes: 10 additions & 0 deletions llama_cpp/server/app.py
@@ -521,6 +521,14 @@ async def get_event_publisher(
+ "Top-p sampling, also known as nucleus sampling, is another text generation method that selects the next token from a subset of tokens that together have a cumulative probability of at least p. This method provides a balance between diversity and quality by considering both the probabilities of tokens and the number of tokens to sample from. A higher value for top_p (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text.",
)

min_p_field = Field(
default=0.05,
ge=0.0,
le=1.0,
description="Sets a minimum base probability threshold for token selection.\n\n"
+ "The Min-P sampling method was designed as an alternative to Top-P, and aims to ensure a balance of quality and variety. The parameter min_p represents the minimum probability for a token to be considered, relative to the probability of the most likely token. For example, with min_p=0.05 and the most likely token having a probability of 0.9, logits with a value less than 0.045 are filtered out.",
)

stop_field = Field(
default=None,
description="A list of tokens at which to stop generation. If None, no stop tokens are used.",
@@ -593,6 +601,7 @@ class CreateCompletionRequest(BaseModel):
max_tokens: int = max_tokens_field
temperature: float = temperature_field
top_p: float = top_p_field
min_p: float = min_p_field
echo: bool = Field(
default=False,
description="Whether to echo the prompt in the generated text. Useful for chatbots.",
@@ -788,6 +797,7 @@ class CreateChatCompletionRequest(BaseModel):
)
temperature: float = temperature_field
top_p: float = top_p_field
min_p: float = min_p_field
stop: Optional[List[str]] = stop_field
stream: bool = stream_field
presence_penalty: Optional[float] = presence_penalty_field
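Because min_p is now a field on both request models, it can be passed in the JSON body of the server endpoints. An example request against a locally running llama_cpp.server instance (host, port, and endpoint reflect the server's usual defaults and are assumptions here):

import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "prompt": "Once upon a time",
        "max_tokens": 32,
        "temperature": 0.8,
        "min_p": 0.1,  # new field added by this PR
    },
)
print(resp.json()["choices"][0]["text"])
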