Add grammar-based sampling #572
Merged 4 commits on Aug 8, 2023
37 changes: 32 additions & 5 deletions llama_cpp/llama.py
@@ -1,4 +1,5 @@
import os
from pathlib import Path
import sys
import uuid
import time
@@ -23,6 +24,7 @@

from . import llama_cpp
from .llama_types import *
from .llama_grammar import LlamaGrammar

import numpy as np
import numpy.typing as npt
@@ -223,6 +225,7 @@ def __init__(
tensor_split: Optional[List[float]] = None,
rope_freq_base: float = 10000.0,
rope_freq_scale: float = 1.0,
grammar: Optional[Union[str, Path]] = None,
n_gqa: Optional[int] = None, # (TEMPORARY) must be 8 for llama2 70b
rms_norm_eps: Optional[float] = None, # (TEMPORARY)
verbose: bool = True,
@@ -248,6 +251,7 @@
tensor_split: List of floats to split the model across multiple GPUs. If None, the model is not split.
rope_freq_base: Base frequency for rope sampling.
rope_freq_scale: Scale factor for rope sampling.
grammar: Path to a BNF grammar file to use for grammar based sampling.
verbose: Print verbose output to stderr.

Raises:
@@ -358,6 +362,12 @@ def __init__(
self.scores: npt.NDArray[np.single] = np.ndarray(
(n_ctx, self._n_vocab), dtype=np.single
)
if grammar is not None:
self.grammar = LlamaGrammar.from_file(
grammar, verbose=verbose
) # type: Optional[LlamaGrammar]
else:
self.grammar = None

@property
def _input_ids(self) -> npt.NDArray[np.intc]:
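The `grammar` argument wired up above expects a path to a plain-text GBNF grammar file, the same format llama.cpp ships in its grammars/ directory. A minimal sketch of producing such a file (the yesno.gbnf name and its single rule are illustrative, not part of this diff):

```python
from pathlib import Path

# Illustrative only: a one-rule GBNF grammar that restricts generation
# to the literal strings "yes" or "no".
Path("yesno.gbnf").write_text('root ::= "yes" | "no"\n')
```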
@@ -542,8 +552,16 @@ def _sample(
)
if not penalize_nl:
candidates.data[self._token_nl].logit = llama_cpp.c_float(nl_logit)

if self.grammar is not None:
llama_cpp.llama_sample_grammar(
ctx=self.ctx,
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
grammar=self.grammar.grammar,
)

if temp.value == 0.0:
return llama_cpp.llama_sample_token_greedy(
id = llama_cpp.llama_sample_token_greedy(
ctx=self.ctx,
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
)
@@ -555,7 +573,7 @@
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
temp=temp,
)
return llama_cpp.llama_sample_token_mirostat(
id = llama_cpp.llama_sample_token_mirostat(
ctx=self.ctx,
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
tau=mirostat_tau,
@@ -570,7 +588,7 @@
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
temp=temp,
)
return llama_cpp.llama_sample_token_mirostat_v2(
id = llama_cpp.llama_sample_token_mirostat_v2(
ctx=self.ctx,
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
tau=mirostat_tau,
@@ -607,10 +625,17 @@ def _sample(
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
temp=temp,
)
return llama_cpp.llama_sample_token(
id = llama_cpp.llama_sample_token(
ctx=self.ctx,
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
)
if self.grammar is not None:
llama_cpp.llama_grammar_accept_token(
ctx=self.ctx,
grammar=self.grammar.grammar,
token=llama_cpp.ctypes.c_int(id),
)
return id

def sample(
self,
@@ -698,7 +723,6 @@ def generate(
The generated tokens.
"""
assert self.ctx is not None

if reset and len(self._input_ids) > 0:
longest_prefix = 0
for a, b in zip(self._input_ids, tokens[:-1]):
@@ -716,6 +740,9 @@
if reset:
self.reset()

if self.grammar is not None:
self.grammar.reset()

while True:
self.eval(tokens)
token = self.sample(
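With the llama.py changes in place, grammar-constrained generation only needs the new constructor argument. A minimal usage sketch, assuming a local model and the yesno.gbnf file from the sketch above (both paths are placeholders):

```python
from llama_cpp import Llama

# Hypothetical paths; `grammar` is the constructor argument added in this PR.
llm = Llama(model_path="./models/7B/ggml-model.bin", grammar="./yesno.gbnf")
out = llm("Q: Is water wet? A:", max_tokens=4)
print(out["choices"][0]["text"])  # constrained by the grammar to "yes" or "no"
```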
34 changes: 34 additions & 0 deletions llama_cpp/llama_cpp.py
@@ -1157,6 +1157,23 @@ def llama_sample_temperature(
_lib.llama_sample_temperature.restype = None


# LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar);
def llama_sample_grammar(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
grammar, # type: llama_grammar_p
):
return _lib.llama_sample_grammar(ctx, candidates, grammar)


_lib.llama_sample_grammar.argtypes = [
llama_context_p,
llama_token_data_array_p,
llama_grammar_p,
]
_lib.llama_sample_grammar.restype = None


# @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
# @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
# @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
@@ -1244,6 +1261,23 @@ def llama_sample_token(
_lib.llama_sample_token.restype = llama_token


# /// @details Accepts the sampled token into the grammar
# LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
def llama_grammar_accept_token(
ctx: llama_context_p,
grammar: llama_grammar_p,
token: llama_token,
) -> None:
_lib.llama_grammar_accept_token(ctx, grammar, token)


_lib.llama_grammar_accept_token.argtypes = [
llama_context_p,
llama_grammar_p,
llama_token,
]
_lib.llama_grammar_accept_token.restype = None

# Performance information


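These two bindings mirror the order the C API expects: llama_sample_grammar filters the candidate logits down to tokens the grammar can currently accept, and llama_grammar_accept_token advances the grammar state once a token has been chosen. A minimal sketch of that call order, reduced to plain temperature sampling (it assumes `ctx`, `grammar`, and a ctypes reference to an already-filled llama_token_data_array, just as `Llama._sample` above provides them):

```python
from llama_cpp import llama_cpp  # low-level ctypes bindings

def sample_with_grammar(ctx, candidates, grammar, temp):
    # Sketch only: `candidates` is ctypes.byref(llama_token_data_array),
    # `temp` is a llama_cpp.c_float.
    # Mask out candidates the grammar cannot accept at this position.
    llama_cpp.llama_sample_grammar(ctx=ctx, candidates=candidates, grammar=grammar)
    # Ordinary temperature sampling over the constrained distribution.
    llama_cpp.llama_sample_temperature(ctx=ctx, candidates=candidates, temp=temp)
    token = llama_cpp.llama_sample_token(ctx=ctx, candidates=candidates)
    # Feed the chosen token back so the grammar state tracks the generated text.
    llama_cpp.llama_grammar_accept_token(
        ctx=ctx, grammar=grammar, token=llama_cpp.ctypes.c_int(token)
    )
    return token
```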