From 3c13436e5ec2e41cb6e7a574e64e14fb61fc5050 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Tue, 28 Nov 2023 04:09:13 -0500 Subject: [PATCH 1/9] Clean up llama.py into seperate modules. --- llama_cpp/_internals.py | 528 ++++++++++++++++++++++++++++ llama_cpp/llama.py | 738 +++------------------------------------ llama_cpp/llama_cache.py | 154 ++++++++ 3 files changed, 732 insertions(+), 688 deletions(-) create mode 100644 llama_cpp/_internals.py create mode 100644 llama_cpp/llama_cache.py diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py new file mode 100644 index 000000000..90e4a0605 --- /dev/null +++ b/llama_cpp/_internals.py @@ -0,0 +1,528 @@ +from __future__ import annotations + +import os + +from typing import ( + List, + Optional, + Sequence, +) + +import ctypes + +import numpy as np +import numpy.typing as npt + +from .llama_types import * +from .llama_grammar import LlamaGrammar + +import llama_cpp.llama_cpp as llama_cpp + +from ._utils import suppress_stdout_stderr + + + +class _LlamaModel: + """Intermediate Python wrapper for a llama.cpp llama_model. + + NOTE: For stability it's recommended you use the Llama class instead.""" + + _llama_free_model = None + + def __init__( + self, + *, + path_model: str, + params: llama_cpp.llama_model_params, + verbose: bool = True, + ): + self.path_model = path_model + self.params = params + self.verbose = verbose + + self._llama_free_model = llama_cpp._lib.llama_free_model # type: ignore + + if not os.path.exists(path_model): + raise ValueError(f"Model path does not exist: {path_model}") + + with suppress_stdout_stderr(disable=self.verbose): + self.model = llama_cpp.llama_load_model_from_file( + self.path_model.encode("utf-8"), self.params + ) + + def __del__(self): + with suppress_stdout_stderr(disable=self.verbose): + if self.model is not None and self._llama_free_model is not None: + self._llama_free_model(self.model) + self.model = None + + def vocab_type(self) -> int: + assert self.model is not None + return llama_cpp.llama_vocab_type(self.model) + + def n_vocab(self) -> int: + assert self.model is not None + return llama_cpp.llama_n_vocab(self.model) + + def n_ctx_train(self) -> int: + assert self.model is not None + return llama_cpp.llama_n_ctx_train(self.model) + + def n_embd(self) -> int: + assert self.model is not None + return llama_cpp.llama_n_embd(self.model) + + def rope_freq_scale_train(self) -> float: + assert self.model is not None + return llama_cpp.llama_rope_freq_scale_train(self.model) + + def desc(self) -> str: + assert self.model is not None + buf = ctypes.create_string_buffer(1024) + llama_cpp.llama_model_desc(self.model, buf, 1024) # type: ignore + return buf.value.decode("utf-8") + + def size(self) -> int: + assert self.model is not None + return llama_cpp.llama_model_size(self.model) + + def n_params(self) -> int: + assert self.model is not None + return llama_cpp.llama_model_n_params(self.model) + + def get_tensor(self, name: str) -> ctypes.c_void_p: + assert self.model is not None + return llama_cpp.llama_get_model_tensor(self.model, name.encode("utf-8")) + + def apply_lora_from_file( + self, + lora_path: str, + scale: float, + path_base_model: Optional[str], + n_threads: int, + ): + assert self.model is not None + return llama_cpp.llama_model_apply_lora_from_file( + self.model, + lora_path.encode("utf-8"), + scale, + path_base_model.encode("utf-8") + if path_base_model is not None + else llama_cpp.c_char_p(0), + n_threads, + ) + + # Vocab + + def token_get_text(self, token: int) -> str: + # TODO: 
Fix + assert self.model is not None + return llama_cpp.llama_token_get_text(self.model, token).decode("utf-8") + + def token_get_score(self, token: int) -> float: + assert self.model is not None + return llama_cpp.llama_token_get_score(self.model, token) + + def token_get_type(self, token: int) -> int: + assert self.model is not None + return llama_cpp.llama_token_get_type(self.model, token) + + # Special tokens + + def token_bos(self) -> int: + assert self.model is not None + return llama_cpp.llama_token_bos(self.model) + + def token_eos(self) -> int: + assert self.model is not None + return llama_cpp.llama_token_eos(self.model) + + def token_nl(self) -> int: + assert self.model is not None + return llama_cpp.llama_token_nl(self.model) + + def token_prefix(self) -> int: + assert self.model is not None + return llama_cpp.llama_token_prefix(self.model) + + def token_middle(self) -> int: + assert self.model is not None + return llama_cpp.llama_token_middle(self.model) + + def token_suffix(self) -> int: + assert self.model is not None + return llama_cpp.llama_token_suffix(self.model) + + def token_eot(self) -> int: + assert self.model is not None + return llama_cpp.llama_token_eot(self.model) + + # Tokenization + + def tokenize(self, text: bytes, add_bos: bool, special: bool): + assert self.model is not None + n_ctx = self.n_ctx_train() + tokens = (llama_cpp.llama_token * n_ctx)() + n_tokens = llama_cpp.llama_tokenize( + self.model, text, len(text), tokens, n_ctx, add_bos, special + ) + if n_tokens < 0: + n_tokens = abs(n_tokens) + tokens = (llama_cpp.llama_token * n_tokens)() + n_tokens = llama_cpp.llama_tokenize( + self.model, text, len(text), tokens, n_tokens, add_bos, special + ) + if n_tokens < 0: + raise RuntimeError( + f'Failed to tokenize: text="{text}" n_tokens={n_tokens}' + ) + return list(tokens[:n_tokens]) + + def token_to_piece(self, token: int) -> bytes: + assert self.model is not None + buf = ctypes.create_string_buffer(32) + llama_cpp.llama_token_to_piece(self.model, token, buf, 32) # type: ignore + return bytes(buf) + + def detokenize(self, tokens: List[int]) -> bytes: + assert self.model is not None + output = b"" + size = 32 + buffer = (ctypes.c_char * size)() + for token in tokens: + n = llama_cpp.llama_token_to_piece( + self.model, llama_cpp.llama_token(token), buffer, size + ) + assert n <= size + output += bytes(buffer[:n]) + # NOTE: Llama1 models automatically added a space at the start of the prompt + # this line removes a leading space if the first token is a beginning of sentence token + return ( + output[1:] if len(tokens) > 0 and tokens[0] == self.token_bos() else output + ) + + @staticmethod + def default_params(): + """Get the default llama_model_params.""" + return llama_cpp.llama_model_default_params() + + +class _LlamaContext: + """Intermediate Python wrapper for a llama.cpp llama_context. 
+ + NOTE: For stability it's recommended you use the Llama class instead.""" + + _llama_free = None + + def __init__( + self, + *, + model: _LlamaModel, + params: llama_cpp.llama_context_params, + verbose: bool = True, + ): + self.model = model + self.params = params + self.verbose = verbose + + self._llama_free = llama_cpp._lib.llama_free # type: ignore + + with suppress_stdout_stderr(disable=self.verbose): + self.ctx = llama_cpp.llama_new_context_with_model( + self.model.model, self.params + ) + + def __del__(self): + with suppress_stdout_stderr(disable=self.verbose): + if self.ctx is not None and self._llama_free is not None: + self._llama_free(self.ctx) + self.ctx = None + + def n_ctx(self) -> int: + assert self.ctx is not None + return llama_cpp.llama_n_ctx(self.ctx) + + def kv_cache_clear(self): + assert self.ctx is not None + llama_cpp.llama_kv_cache_clear(self.ctx) + + def kv_cache_seq_rm(self, seq_id: int, p0: int, p1: int): + assert self.ctx is not None + llama_cpp.llama_kv_cache_seq_rm(self.ctx, seq_id, p0, p1) + + def kv_cache_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int): + assert self.ctx is not None + llama_cpp.llama_kv_cache_seq_cp(self.ctx, seq_id_src, seq_id_dst, p0, p1) + + def kv_cache_seq_keep(self, seq_id: int): + assert self.ctx is not None + llama_cpp.llama_kv_cache_seq_keep(self.ctx, seq_id) + + def kv_cache_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int): + assert self.ctx is not None + llama_cpp.llama_kv_cache_seq_shift(self.ctx, seq_id, p0, p1, shift) + + def get_state_size(self) -> int: + assert self.ctx is not None + return llama_cpp.llama_get_state_size(self.ctx) + + # TODO: copy_state_data + + # TODO: set_state_data + + # TODO: llama_load_session_file + + # TODO: llama_save_session_file + + def decode(self, batch: "_LlamaBatch"): + assert self.ctx is not None + assert batch.batch is not None + return_code = llama_cpp.llama_decode( + ctx=self.ctx, + batch=batch.batch, + ) + if return_code != 0: + raise RuntimeError(f"llama_decode returned {return_code}") + + def set_n_threads(self, n_threads: int, n_threads_batch: int): + assert self.ctx is not None + llama_cpp.llama_set_n_threads(self.ctx, n_threads, n_threads_batch) + + def get_logits(self): + assert self.ctx is not None + return llama_cpp.llama_get_logits(self.ctx) + + def get_logits_ith(self, i: int): + assert self.ctx is not None + return llama_cpp.llama_get_logits_ith(self.ctx, i) + + def get_embeddings(self): + assert self.ctx is not None + return llama_cpp.llama_get_embeddings(self.ctx) + + # Sampling functions + + def set_rng_seed(self, seed: int): + assert self.ctx is not None + llama_cpp.llama_set_rng_seed(self.ctx, seed) + + def sample_repetition_penalties( + self, + candidates: "_LlamaTokenDataArray", + last_tokens_data: "llama_cpp.Array[llama_cpp.llama_token]", + penalty_last_n: int, + penalty_repeat: float, + penalty_freq: float, + penalty_present: float, + ): + assert self.ctx is not None + llama_cpp.llama_sample_repetition_penalties( + self.ctx, + ctypes.byref(candidates.candidates), # type: ignore + last_tokens_data, + penalty_last_n, + penalty_repeat, + penalty_freq, + penalty_present, + ) + + def sample_classifier_free_guidance( + self, + candidates: "_LlamaTokenDataArray", + guidance_ctx: "_LlamaContext", + scale: float, + ): + assert self.ctx is not None + assert guidance_ctx.ctx is not None + llama_cpp.llama_sample_classifier_free_guidance( + self.ctx, + ctypes.byref(candidates.candidates), # type: ignore + guidance_ctx.ctx, + scale, + ) + + def 
sample_softmax(self, candidates: "_LlamaTokenDataArray"): + assert self.ctx is not None + llama_cpp.llama_sample_softmax( + self.ctx, + ctypes.byref(candidates.candidates), # type: ignore + ) + + def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int): + assert self.ctx is not None + llama_cpp.llama_sample_top_k( + self.ctx, ctypes.byref(candidates.candidates), k, min_keep # type: ignore + ) + + def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int): + assert self.ctx is not None + llama_cpp.llama_sample_top_p( + self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore + ) + + def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int): + assert self.ctx is not None + llama_cpp.llama_sample_min_p( + self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore + ) + + def sample_tail_free( + self, candidates: "_LlamaTokenDataArray", z: float, min_keep: int + ): + assert self.ctx is not None + llama_cpp.llama_sample_tail_free( + self.ctx, ctypes.byref(candidates.candidates), z, min_keep # type: ignore + ) + + def sample_typical( + self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int + ): + assert self.ctx is not None + llama_cpp.llama_sample_typical( + self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore + ) + + def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float): + assert self.ctx is not None + llama_cpp.llama_sample_temp( + self.ctx, ctypes.byref(candidates.candidates), temp # type: ignore + ) + + def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar): + assert self.ctx is not None + assert grammar.grammar is not None + llama_cpp.llama_sample_grammar( + self.ctx, + ctypes.byref(candidates.candidates), # type: ignore + grammar.grammar, + ) + + def sample_token_mirostat( + self, + candidates: "_LlamaTokenDataArray", + tau: float, + eta: float, + m: int, + mu: float, + ) -> int: + assert self.ctx is not None + return llama_cpp.llama_sample_token_mirostat( + self.ctx, + ctypes.byref(candidates.candidates), # type: ignore + tau, + eta, + m, + ctypes.pointer(ctypes.c_float(mu)), + ) + + def sample_token_mirostat_v2( + self, candidates: "_LlamaTokenDataArray", tau: float, eta: float, mu: float + ) -> int: + assert self.ctx is not None + return llama_cpp.llama_sample_token_mirostat_v2( + self.ctx, + ctypes.byref(candidates.candidates), # type: ignore + tau, + eta, + ctypes.pointer(ctypes.c_float(mu)), + ) + + def sample_token_greedy(self, candidates: "_LlamaTokenDataArray") -> int: + assert self.ctx is not None + return llama_cpp.llama_sample_token_greedy( + self.ctx, + ctypes.byref(candidates.candidates), # type: ignore + ) + + def sample_token(self, candidates: "_LlamaTokenDataArray") -> int: + assert self.ctx is not None + return llama_cpp.llama_sample_token( + self.ctx, + ctypes.byref(candidates.candidates), # type: ignore + ) + + # Grammar + def grammar_accept_token(self, grammar: LlamaGrammar, token: int): + assert self.ctx is not None + assert grammar.grammar is not None + llama_cpp.llama_grammar_accept_token(self.ctx, grammar.grammar, token) + + def reset_timings(self): + assert self.ctx is not None + llama_cpp.llama_reset_timings(self.ctx) + + def print_timings(self): + assert self.ctx is not None + llama_cpp.llama_print_timings(self.ctx) + + # Utility functions + @staticmethod + def default_params(): + """Get the default llama_context_params.""" + return llama_cpp.llama_context_default_params() + + 
+class _LlamaBatch: + _llama_batch_free = None + + def __init__( + self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True + ): + self.n_tokens = n_tokens + self.embd = embd + self.n_seq_max = n_seq_max + self.verbose = verbose + + self._llama_batch_free = llama_cpp._lib.llama_batch_free # type: ignore + + with suppress_stdout_stderr(disable=self.verbose): + self.batch = llama_cpp.llama_batch_init( + self.n_tokens, self.embd, self.n_seq_max + ) + + def __del__(self): + with suppress_stdout_stderr(disable=self.verbose): + if self.batch is not None and self._llama_batch_free is not None: + self._llama_batch_free(self.batch) + self.batch = None + + def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool): + assert self.batch is not None + n_tokens = len(batch) + self.batch.n_tokens = n_tokens + for i in range(n_tokens): + self.batch.token[i] = batch[i] + self.batch.pos[i] = n_past + i + self.batch.seq_id[i][0] = 0 + self.batch.n_seq_id[i] = 1 + self.batch.logits[i] = logits_all + self.batch.logits[n_tokens - 1] = True + + +class _LlamaTokenDataArray: + def __init__(self, *, n_vocab: int): + self.n_vocab = n_vocab + self.candidates_data = np.array( + [], + dtype=np.dtype( + [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True + ), + ) + self.candidates_data.resize(3, self.n_vocab, refcheck=False) + self.candidates = llama_cpp.llama_token_data_array( + data=self.candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p), + size=self.n_vocab, + sorted=False, + ) + self.default_candidates_data_id = np.arange(self.n_vocab, dtype=np.intc) + self.default_candidates_data_p = np.zeros(self.n_vocab, dtype=np.single) + + def copy_logits(self, logits: npt.NDArray[np.single]): + self.candidates_data["id"][:] = self.default_candidates_data_id + self.candidates_data["logit"][:] = logits + self.candidates_data["p"][:] = self.default_candidates_data_p + self.candidates.data = self.candidates_data.ctypes.data_as( + llama_cpp.llama_token_data_p + ) + self.candidates.sorted = llama_cpp.c_bool(False) + self.candidates.size = llama_cpp.c_size_t(self.n_vocab) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 3cb07e524..d2ba8695c 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import os import sys import uuid import time import math import multiprocessing -from abc import ABC, abstractmethod + from typing import ( List, Optional, @@ -13,706 +15,26 @@ Sequence, Iterator, Deque, - Tuple, Callable, ) -from collections import deque, OrderedDict +from collections import deque -import diskcache import ctypes +import numpy as np +import numpy.typing as npt + from .llama_types import * from .llama_grammar import LlamaGrammar +from .llama_cache import BaseLlamaCache + import llama_cpp.llama_cpp as llama_cpp import llama_cpp.llama_chat_format as llama_chat_format -import numpy as np -import numpy.typing as npt - +from ._internals import _LlamaModel, _LlamaContext, _LlamaBatch, _LlamaTokenDataArray # type: ignore from ._utils import suppress_stdout_stderr -class BaseLlamaCache(ABC): - """Base cache class for a llama.cpp model.""" - - def __init__(self, capacity_bytes: int = (2 << 30)): - self.capacity_bytes = capacity_bytes - - @property - @abstractmethod - def cache_size(self) -> int: - raise NotImplementedError - - def _find_longest_prefix_key( - self, - key: Tuple[int, ...], - ) -> Optional[Tuple[int, ...]]: - pass - - @abstractmethod - def __getitem__(self, key: Sequence[int]) -> "LlamaState": - 
raise NotImplementedError - - @abstractmethod - def __contains__(self, key: Sequence[int]) -> bool: - raise NotImplementedError - - @abstractmethod - def __setitem__(self, key: Sequence[int], value: "LlamaState") -> None: - raise NotImplementedError - - -class LlamaRAMCache(BaseLlamaCache): - """Cache for a llama.cpp model using RAM.""" - - def __init__(self, capacity_bytes: int = (2 << 30)): - super().__init__(capacity_bytes) - self.capacity_bytes = capacity_bytes - self.cache_state: OrderedDict[Tuple[int, ...], "LlamaState"] = OrderedDict() - - @property - def cache_size(self): - return sum([state.llama_state_size for state in self.cache_state.values()]) - - def _find_longest_prefix_key( - self, - key: Tuple[int, ...], - ) -> Optional[Tuple[int, ...]]: - min_len = 0 - min_key = None - keys = ( - (k, Llama.longest_token_prefix(k, key)) for k in self.cache_state.keys() - ) - for k, prefix_len in keys: - if prefix_len > min_len: - min_len = prefix_len - min_key = k - return min_key - - def __getitem__(self, key: Sequence[int]) -> "LlamaState": - key = tuple(key) - _key = self._find_longest_prefix_key(key) - if _key is None: - raise KeyError("Key not found") - value = self.cache_state[_key] - self.cache_state.move_to_end(_key) - return value - - def __contains__(self, key: Sequence[int]) -> bool: - return self._find_longest_prefix_key(tuple(key)) is not None - - def __setitem__(self, key: Sequence[int], value: "LlamaState"): - key = tuple(key) - if key in self.cache_state: - del self.cache_state[key] - self.cache_state[key] = value - while self.cache_size > self.capacity_bytes and len(self.cache_state) > 0: - self.cache_state.popitem(last=False) - - -# Alias for backwards compatibility -LlamaCache = LlamaRAMCache - - -class LlamaDiskCache(BaseLlamaCache): - """Cache for a llama.cpp model using disk.""" - - def __init__( - self, cache_dir: str = ".cache/llama_cache", capacity_bytes: int = (2 << 30) - ): - super().__init__(capacity_bytes) - self.cache = diskcache.Cache(cache_dir) - - @property - def cache_size(self): - return int(self.cache.volume()) # type: ignore - - def _find_longest_prefix_key( - self, - key: Tuple[int, ...], - ) -> Optional[Tuple[int, ...]]: - min_len = 0 - min_key: Optional[Tuple[int, ...]] = None - for k in self.cache.iterkeys(): # type: ignore - prefix_len = Llama.longest_token_prefix(k, key) - if prefix_len > min_len: - min_len = prefix_len - min_key = k # type: ignore - return min_key - - def __getitem__(self, key: Sequence[int]) -> "LlamaState": - key = tuple(key) - _key = self._find_longest_prefix_key(key) - if _key is None: - raise KeyError("Key not found") - value: "LlamaState" = self.cache.pop(_key) # type: ignore - # NOTE: This puts an integer as key in cache, which breaks, - # Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens - # self.cache.push(_key, side="front") # type: ignore - return value - - def __contains__(self, key: Sequence[int]) -> bool: - return self._find_longest_prefix_key(tuple(key)) is not None - - def __setitem__(self, key: Sequence[int], value: "LlamaState"): - print("LlamaDiskCache.__setitem__: called", file=sys.stderr) - key = tuple(key) - if key in self.cache: - print("LlamaDiskCache.__setitem__: delete", file=sys.stderr) - del self.cache[key] - self.cache[key] = value - print("LlamaDiskCache.__setitem__: set", file=sys.stderr) - while self.cache_size > self.capacity_bytes and len(self.cache) > 0: - key_to_remove = next(iter(self.cache)) - del self.cache[key_to_remove] - print("LlamaDiskCache.__setitem__: 
trim", file=sys.stderr) - - -class LlamaState: - def __init__( - self, - input_ids: npt.NDArray[np.intc], - scores: npt.NDArray[np.single], - n_tokens: int, - llama_state: bytes, - llama_state_size: int, - ): - self.input_ids = input_ids - self.scores = scores - self.n_tokens = n_tokens - self.llama_state = llama_state - self.llama_state_size = llama_state_size - - -LogitsProcessor = Callable[ - [npt.NDArray[np.intc], npt.NDArray[np.single]], npt.NDArray[np.single] -] - - -class LogitsProcessorList(List[LogitsProcessor]): - def __call__( - self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single] - ) -> npt.NDArray[np.single]: - for processor in self: - scores = processor(input_ids, scores) - return scores - - -StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool] - - -class StoppingCriteriaList(List[StoppingCriteria]): - def __call__( - self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single] - ) -> bool: - return any([stopping_criteria(input_ids, logits) for stopping_criteria in self]) - - -class _LlamaModel: - """Intermediate Python wrapper for a llama.cpp llama_model. - - NOTE: For stability it's recommended you use the Llama class instead.""" - - _llama_free_model = None - - def __init__( - self, - *, - path_model: str, - params: llama_cpp.llama_model_params, - verbose: bool = True, - ): - self.path_model = path_model - self.params = params - self.verbose = verbose - - self._llama_free_model = llama_cpp._lib.llama_free_model # type: ignore - - if not os.path.exists(path_model): - raise ValueError(f"Model path does not exist: {path_model}") - - with suppress_stdout_stderr(disable=self.verbose): - self.model = llama_cpp.llama_load_model_from_file( - self.path_model.encode("utf-8"), self.params - ) - - def __del__(self): - with suppress_stdout_stderr(disable=self.verbose): - if self.model is not None and self._llama_free_model is not None: - self._llama_free_model(self.model) - self.model = None - - def vocab_type(self) -> int: - assert self.model is not None - return llama_cpp.llama_vocab_type(self.model) - - def n_vocab(self) -> int: - assert self.model is not None - return llama_cpp.llama_n_vocab(self.model) - - def n_ctx_train(self) -> int: - assert self.model is not None - return llama_cpp.llama_n_ctx_train(self.model) - - def n_embd(self) -> int: - assert self.model is not None - return llama_cpp.llama_n_embd(self.model) - - def rope_freq_scale_train(self) -> float: - assert self.model is not None - return llama_cpp.llama_rope_freq_scale_train(self.model) - - def desc(self) -> str: - assert self.model is not None - buf = ctypes.create_string_buffer(1024) - llama_cpp.llama_model_desc(self.model, buf, 1024) # type: ignore - return buf.value.decode("utf-8") - - def size(self) -> int: - assert self.model is not None - return llama_cpp.llama_model_size(self.model) - - def n_params(self) -> int: - assert self.model is not None - return llama_cpp.llama_model_n_params(self.model) - - def get_tensor(self, name: str) -> ctypes.c_void_p: - assert self.model is not None - return llama_cpp.llama_get_model_tensor(self.model, name.encode("utf-8")) - - def apply_lora_from_file( - self, - lora_path: str, - scale: float, - path_base_model: Optional[str], - n_threads: int, - ): - assert self.model is not None - return llama_cpp.llama_model_apply_lora_from_file( - self.model, - lora_path.encode("utf-8"), - scale, - path_base_model.encode("utf-8") - if path_base_model is not None - else llama_cpp.c_char_p(0), - n_threads, - ) - - # Vocab - - def 
token_get_text(self, token: int) -> str: - # TODO: Fix - assert self.model is not None - return llama_cpp.llama_token_get_text(self.model, token).decode("utf-8") - - def token_get_score(self, token: int) -> float: - assert self.model is not None - return llama_cpp.llama_token_get_score(self.model, token) - - def token_get_type(self, token: int) -> int: - assert self.model is not None - return llama_cpp.llama_token_get_type(self.model, token) - - # Special tokens - - def token_bos(self) -> int: - assert self.model is not None - return llama_cpp.llama_token_bos(self.model) - - def token_eos(self) -> int: - assert self.model is not None - return llama_cpp.llama_token_eos(self.model) - - def token_nl(self) -> int: - assert self.model is not None - return llama_cpp.llama_token_nl(self.model) - - def token_prefix(self) -> int: - assert self.model is not None - return llama_cpp.llama_token_prefix(self.model) - - def token_middle(self) -> int: - assert self.model is not None - return llama_cpp.llama_token_middle(self.model) - - def token_suffix(self) -> int: - assert self.model is not None - return llama_cpp.llama_token_suffix(self.model) - - def token_eot(self) -> int: - assert self.model is not None - return llama_cpp.llama_token_eot(self.model) - - # Tokenization - - def tokenize(self, text: bytes, add_bos: bool, special: bool): - assert self.model is not None - n_ctx = self.n_ctx_train() - tokens = (llama_cpp.llama_token * n_ctx)() - n_tokens = llama_cpp.llama_tokenize( - self.model, text, len(text), tokens, n_ctx, add_bos, special - ) - if n_tokens < 0: - n_tokens = abs(n_tokens) - tokens = (llama_cpp.llama_token * n_tokens)() - n_tokens = llama_cpp.llama_tokenize( - self.model, text, len(text), tokens, n_tokens, add_bos, special - ) - if n_tokens < 0: - raise RuntimeError( - f'Failed to tokenize: text="{text}" n_tokens={n_tokens}' - ) - return list(tokens[:n_tokens]) - - def token_to_piece(self, token: int) -> bytes: - assert self.model is not None - buf = ctypes.create_string_buffer(32) - llama_cpp.llama_token_to_piece(self.model, token, buf, 32) # type: ignore - return bytes(buf) - - def detokenize(self, tokens: List[int]) -> bytes: - assert self.model is not None - output = b"" - size = 32 - buffer = (ctypes.c_char * size)() - for token in tokens: - n = llama_cpp.llama_token_to_piece( - self.model, llama_cpp.llama_token(token), buffer, size - ) - assert n <= size - output += bytes(buffer[:n]) - # NOTE: Llama1 models automatically added a space at the start of the prompt - # this line removes a leading space if the first token is a beginning of sentence token - return ( - output[1:] if len(tokens) > 0 and tokens[0] == self.token_bos() else output - ) - - @staticmethod - def default_params(): - """Get the default llama_model_params.""" - return llama_cpp.llama_model_default_params() - - -class _LlamaContext: - """Intermediate Python wrapper for a llama.cpp llama_context. 
- - NOTE: For stability it's recommended you use the Llama class instead.""" - - _llama_free = None - - def __init__( - self, - *, - model: _LlamaModel, - params: llama_cpp.llama_context_params, - verbose: bool = True, - ): - self.model = model - self.params = params - self.verbose = verbose - - self._llama_free = llama_cpp._lib.llama_free # type: ignore - - with suppress_stdout_stderr(disable=self.verbose): - self.ctx = llama_cpp.llama_new_context_with_model( - self.model.model, self.params - ) - - def __del__(self): - with suppress_stdout_stderr(disable=self.verbose): - if self.ctx is not None and self._llama_free is not None: - self._llama_free(self.ctx) - self.ctx = None - - def n_ctx(self) -> int: - assert self.ctx is not None - return llama_cpp.llama_n_ctx(self.ctx) - - def kv_cache_clear(self): - assert self.ctx is not None - llama_cpp.llama_kv_cache_clear(self.ctx) - - def kv_cache_seq_rm(self, seq_id: int, p0: int, p1: int): - assert self.ctx is not None - llama_cpp.llama_kv_cache_seq_rm(self.ctx, seq_id, p0, p1) - - def kv_cache_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int): - assert self.ctx is not None - llama_cpp.llama_kv_cache_seq_cp(self.ctx, seq_id_src, seq_id_dst, p0, p1) - - def kv_cache_seq_keep(self, seq_id: int): - assert self.ctx is not None - llama_cpp.llama_kv_cache_seq_keep(self.ctx, seq_id) - - def kv_cache_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int): - assert self.ctx is not None - llama_cpp.llama_kv_cache_seq_shift(self.ctx, seq_id, p0, p1, shift) - - def get_state_size(self) -> int: - assert self.ctx is not None - return llama_cpp.llama_get_state_size(self.ctx) - - # TODO: copy_state_data - - # TODO: set_state_data - - # TODO: llama_load_session_file - - # TODO: llama_save_session_file - - def decode(self, batch: "_LlamaBatch"): - assert self.ctx is not None - assert batch.batch is not None - return_code = llama_cpp.llama_decode( - ctx=self.ctx, - batch=batch.batch, - ) - if return_code != 0: - raise RuntimeError(f"llama_decode returned {return_code}") - - def set_n_threads(self, n_threads: int, n_threads_batch: int): - assert self.ctx is not None - llama_cpp.llama_set_n_threads(self.ctx, n_threads, n_threads_batch) - - def get_logits(self): - assert self.ctx is not None - return llama_cpp.llama_get_logits(self.ctx) - - def get_logits_ith(self, i: int): - assert self.ctx is not None - return llama_cpp.llama_get_logits_ith(self.ctx, i) - - def get_embeddings(self): - assert self.ctx is not None - return llama_cpp.llama_get_embeddings(self.ctx) - - # Sampling functions - - def set_rng_seed(self, seed: int): - assert self.ctx is not None - llama_cpp.llama_set_rng_seed(self.ctx, seed) - - def sample_repetition_penalties( - self, - candidates: "_LlamaTokenDataArray", - last_tokens_data: "llama_cpp.Array[llama_cpp.llama_token]", - penalty_last_n: int, - penalty_repeat: float, - penalty_freq: float, - penalty_present: float, - ): - assert self.ctx is not None - llama_cpp.llama_sample_repetition_penalties( - self.ctx, - ctypes.byref(candidates.candidates), # type: ignore - last_tokens_data, - penalty_last_n, - penalty_repeat, - penalty_freq, - penalty_present, - ) - - def sample_classifier_free_guidance( - self, - candidates: "_LlamaTokenDataArray", - guidance_ctx: "_LlamaContext", - scale: float, - ): - assert self.ctx is not None - assert guidance_ctx.ctx is not None - llama_cpp.llama_sample_classifier_free_guidance( - self.ctx, - ctypes.byref(candidates.candidates), # type: ignore - guidance_ctx.ctx, - scale, - ) - - def 
sample_softmax(self, candidates: "_LlamaTokenDataArray"): - assert self.ctx is not None - llama_cpp.llama_sample_softmax( - self.ctx, - ctypes.byref(candidates.candidates), # type: ignore - ) - - def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int): - assert self.ctx is not None - llama_cpp.llama_sample_top_k( - self.ctx, ctypes.byref(candidates.candidates), k, min_keep # type: ignore - ) - - def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int): - assert self.ctx is not None - llama_cpp.llama_sample_top_p( - self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore - ) - - def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int): - assert self.ctx is not None - llama_cpp.llama_sample_min_p( - self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore - ) - - def sample_tail_free( - self, candidates: "_LlamaTokenDataArray", z: float, min_keep: int - ): - assert self.ctx is not None - llama_cpp.llama_sample_tail_free( - self.ctx, ctypes.byref(candidates.candidates), z, min_keep # type: ignore - ) - - def sample_typical( - self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int - ): - assert self.ctx is not None - llama_cpp.llama_sample_typical( - self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore - ) - - def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float): - assert self.ctx is not None - llama_cpp.llama_sample_temp( - self.ctx, ctypes.byref(candidates.candidates), temp # type: ignore - ) - - def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar): - assert self.ctx is not None - assert grammar.grammar is not None - llama_cpp.llama_sample_grammar( - self.ctx, - ctypes.byref(candidates.candidates), # type: ignore - grammar.grammar, - ) - - def sample_token_mirostat( - self, - candidates: "_LlamaTokenDataArray", - tau: float, - eta: float, - m: int, - mu: float, - ) -> int: - assert self.ctx is not None - return llama_cpp.llama_sample_token_mirostat( - self.ctx, - ctypes.byref(candidates.candidates), # type: ignore - tau, - eta, - m, - ctypes.pointer(ctypes.c_float(mu)), - ) - - def sample_token_mirostat_v2( - self, candidates: "_LlamaTokenDataArray", tau: float, eta: float, mu: float - ) -> int: - assert self.ctx is not None - return llama_cpp.llama_sample_token_mirostat_v2( - self.ctx, - ctypes.byref(candidates.candidates), # type: ignore - tau, - eta, - ctypes.pointer(ctypes.c_float(mu)), - ) - - def sample_token_greedy(self, candidates: "_LlamaTokenDataArray") -> int: - assert self.ctx is not None - return llama_cpp.llama_sample_token_greedy( - self.ctx, - ctypes.byref(candidates.candidates), # type: ignore - ) - - def sample_token(self, candidates: "_LlamaTokenDataArray") -> int: - assert self.ctx is not None - return llama_cpp.llama_sample_token( - self.ctx, - ctypes.byref(candidates.candidates), # type: ignore - ) - - # Grammar - def grammar_accept_token(self, grammar: LlamaGrammar, token: int): - assert self.ctx is not None - assert grammar.grammar is not None - llama_cpp.llama_grammar_accept_token(self.ctx, grammar.grammar, token) - - def reset_timings(self): - assert self.ctx is not None - llama_cpp.llama_reset_timings(self.ctx) - - def print_timings(self): - assert self.ctx is not None - llama_cpp.llama_print_timings(self.ctx) - - # Utility functions - @staticmethod - def default_params(): - """Get the default llama_context_params.""" - return llama_cpp.llama_context_default_params() - - 
-class _LlamaBatch: - _llama_batch_free = None - - def __init__( - self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True - ): - self.n_tokens = n_tokens - self.embd = embd - self.n_seq_max = n_seq_max - self.verbose = verbose - - self._llama_batch_free = llama_cpp._lib.llama_batch_free # type: ignore - - with suppress_stdout_stderr(disable=self.verbose): - self.batch = llama_cpp.llama_batch_init( - self.n_tokens, self.embd, self.n_seq_max - ) - - def __del__(self): - with suppress_stdout_stderr(disable=self.verbose): - if self.batch is not None and self._llama_batch_free is not None: - self._llama_batch_free(self.batch) - self.batch = None - - def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool): - assert self.batch is not None - n_tokens = len(batch) - self.batch.n_tokens = n_tokens - for i in range(n_tokens): - self.batch.token[i] = batch[i] - self.batch.pos[i] = n_past + i - self.batch.seq_id[i][0] = 0 - self.batch.n_seq_id[i] = 1 - self.batch.logits[i] = logits_all - self.batch.logits[n_tokens - 1] = True - - -class _LlamaTokenDataArray: - def __init__(self, *, n_vocab: int): - self.n_vocab = n_vocab - self.candidates_data = np.array( - [], - dtype=np.dtype( - [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True - ), - ) - self.candidates_data.resize(3, self.n_vocab, refcheck=False) - self.candidates = llama_cpp.llama_token_data_array( - data=self.candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p), - size=self.n_vocab, - sorted=False, - ) - self.default_candidates_data_id = np.arange(self.n_vocab, dtype=np.intc) - self.default_candidates_data_p = np.zeros(self.n_vocab, dtype=np.single) - - def copy_logits(self, logits: npt.NDArray[np.single]): - self.candidates_data["id"][:] = self.default_candidates_data_id - self.candidates_data["logit"][:] = logits - self.candidates_data["p"][:] = self.default_candidates_data_p - self.candidates.data = self.candidates_data.ctypes.data_as( - llama_cpp.llama_token_data_p - ) - self.candidates.sorted = llama_cpp.c_bool(False) - self.candidates.size = llama_cpp.c_size_t(self.n_vocab) - class Llama: """High-level Python wrapper for a llama.cpp model.""" @@ -2305,3 +1627,43 @@ def decode(self, tokens: List[int]) -> str: @classmethod def from_ggml_file(cls, path: str) -> "LlamaTokenizer": return cls(Llama(model_path=path, vocab_only=True)) + + +class LlamaState: + def __init__( + self, + input_ids: npt.NDArray[np.intc], + scores: npt.NDArray[np.single], + n_tokens: int, + llama_state: bytes, + llama_state_size: int, + ): + self.input_ids = input_ids + self.scores = scores + self.n_tokens = n_tokens + self.llama_state = llama_state + self.llama_state_size = llama_state_size + + +LogitsProcessor = Callable[ + [npt.NDArray[np.intc], npt.NDArray[np.single]], npt.NDArray[np.single] +] + + +class LogitsProcessorList(List[LogitsProcessor]): + def __call__( + self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single] + ) -> npt.NDArray[np.single]: + for processor in self: + scores = processor(input_ids, scores) + return scores + + +StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool] + + +class StoppingCriteriaList(List[StoppingCriteria]): + def __call__( + self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single] + ) -> bool: + return any([stopping_criteria(input_ids, logits) for stopping_criteria in self]) diff --git a/llama_cpp/llama_cache.py b/llama_cpp/llama_cache.py new file mode 100644 index 000000000..96c682186 --- /dev/null +++ 
b/llama_cpp/llama_cache.py @@ -0,0 +1,154 @@ +import sys + +from abc import ABC, abstractmethod +from typing import ( + Optional, + Sequence, + Tuple, +) +from collections import OrderedDict + +import llama_cpp.llama + + +class BaseLlamaCache(ABC): + """Base cache class for a llama.cpp model.""" + + def __init__(self, capacity_bytes: int = (2 << 30)): + self.capacity_bytes = capacity_bytes + + @property + @abstractmethod + def cache_size(self) -> int: + raise NotImplementedError + + def _find_longest_prefix_key( + self, + key: Tuple[int, ...], + ) -> Optional[Tuple[int, ...]]: + pass + + @abstractmethod + def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState": + raise NotImplementedError + + @abstractmethod + def __contains__(self, key: Sequence[int]) -> bool: + raise NotImplementedError + + @abstractmethod + def __setitem__( + self, key: Sequence[int], value: "llama_cpp.llama.LlamaState" + ) -> None: + raise NotImplementedError + + +class LlamaRAMCache(BaseLlamaCache): + """Cache for a llama.cpp model using RAM.""" + + def __init__(self, capacity_bytes: int = (2 << 30)): + super().__init__(capacity_bytes) + self.capacity_bytes = capacity_bytes + self.cache_state: OrderedDict[ + Tuple[int, ...], "llama_cpp.llama.LlamaState" + ] = OrderedDict() + + @property + def cache_size(self): + return sum([state.llama_state_size for state in self.cache_state.values()]) + + def _find_longest_prefix_key( + self, + key: Tuple[int, ...], + ) -> Optional[Tuple[int, ...]]: + min_len = 0 + min_key = None + keys = ( + (k, llama_cpp.llama.Llama.longest_token_prefix(k, key)) + for k in self.cache_state.keys() + ) + for k, prefix_len in keys: + if prefix_len > min_len: + min_len = prefix_len + min_key = k + return min_key + + def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState": + key = tuple(key) + _key = self._find_longest_prefix_key(key) + if _key is None: + raise KeyError("Key not found") + value = self.cache_state[_key] + self.cache_state.move_to_end(_key) + return value + + def __contains__(self, key: Sequence[int]) -> bool: + return self._find_longest_prefix_key(tuple(key)) is not None + + def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"): + key = tuple(key) + if key in self.cache_state: + del self.cache_state[key] + self.cache_state[key] = value + while self.cache_size > self.capacity_bytes and len(self.cache_state) > 0: + self.cache_state.popitem(last=False) + + +# Alias for backwards compatibility +LlamaCache = LlamaRAMCache + + +class LlamaDiskCache(BaseLlamaCache): + """Cache for a llama.cpp model using disk.""" + + def __init__( + self, cache_dir: str = ".cache/llama_cache", capacity_bytes: int = (2 << 30) + ): + import diskcache + + super().__init__(capacity_bytes) + self.cache = diskcache.Cache(cache_dir) + + @property + def cache_size(self): + return int(self.cache.volume()) # type: ignore + + def _find_longest_prefix_key( + self, + key: Tuple[int, ...], + ) -> Optional[Tuple[int, ...]]: + min_len = 0 + min_key: Optional[Tuple[int, ...]] = None + for k in self.cache.iterkeys(): # type: ignore + prefix_len = llama_cpp.llama.Llama.longest_token_prefix(k, key) + if prefix_len > min_len: + min_len = prefix_len + min_key = k # type: ignore + return min_key + + def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState": + key = tuple(key) + _key = self._find_longest_prefix_key(key) + if _key is None: + raise KeyError("Key not found") + value: "LlamaState" = self.cache.pop(_key) # type: ignore + # NOTE: This puts an 
integer as key in cache, which breaks, + # Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens + # self.cache.push(_key, side="front") # type: ignore + return value + + def __contains__(self, key: Sequence[int]) -> bool: + return self._find_longest_prefix_key(tuple(key)) is not None + + def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"): + print("LlamaDiskCache.__setitem__: called", file=sys.stderr) + key = tuple(key) + if key in self.cache: + print("LlamaDiskCache.__setitem__: delete", file=sys.stderr) + del self.cache[key] + self.cache[key] = value + print("LlamaDiskCache.__setitem__: set", file=sys.stderr) + while self.cache_size > self.capacity_bytes and len(self.cache) > 0: + key_to_remove = next(iter(self.cache)) + del self.cache[key_to_remove] + print("LlamaDiskCache.__setitem__: trim", file=sys.stderr) From 135b300b85ab09bb3aeea60a441dbcdc65751dfd Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Tue, 28 Nov 2023 04:12:32 -0500 Subject: [PATCH 2/9] docs: re-order api reference --- docs/api-reference.md | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/docs/api-reference.md b/docs/api-reference.md index c6d76be60..f763a77d6 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -26,16 +26,6 @@ title: API Reference - token_eos show_root_heading: true -::: llama_cpp.LlamaGrammar - options: - members: - - from_string - - from_json_schema - -::: llama_cpp.LlamaCache - options: - show_root_heading: true - ::: llama_cpp.LlamaState options: show_root_heading: true @@ -56,6 +46,13 @@ title: API Reference options: show_root_heading: true +::: llama_cpp.LlamaGrammar + options: + members: + - from_string + - from_json_schema + + ## Low Level API ::: llama_cpp.llama_cpp From 3a1ba774627c97e718b92e18fdf1398ad41ae25d Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Tue, 28 Nov 2023 06:30:03 -0500 Subject: [PATCH 3/9] Add sampling context --- llama_cpp/_internals.py | 173 ++++++++++++++++++++++++++++++++++++++-- llama_cpp/llama.py | 6 +- 2 files changed, 171 insertions(+), 8 deletions(-) diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py index 90e4a0605..18075c71f 100644 --- a/llama_cpp/_internals.py +++ b/llama_cpp/_internals.py @@ -1,14 +1,14 @@ from __future__ import annotations import os +import ctypes from typing import ( List, Optional, Sequence, ) - -import ctypes +from dataclasses import dataclass, field import numpy as np import numpy.typing as npt @@ -21,6 +21,8 @@ from ._utils import suppress_stdout_stderr +# Python wrappers over llama.h structs + class _LlamaModel: """Intermediate Python wrapper for a llama.cpp llama_model. 
@@ -403,7 +405,7 @@ def sample_token_mirostat( tau: float, eta: float, m: int, - mu: float, + mu: ctypes._Pointer[ctypes.c_float], # type: ignore ) -> int: assert self.ctx is not None return llama_cpp.llama_sample_token_mirostat( @@ -412,11 +414,11 @@ def sample_token_mirostat( tau, eta, m, - ctypes.pointer(ctypes.c_float(mu)), + mu, ) def sample_token_mirostat_v2( - self, candidates: "_LlamaTokenDataArray", tau: float, eta: float, mu: float + self, candidates: "_LlamaTokenDataArray", tau: float, eta: float, mu: ctypes._Pointer[ctypes.c_float] # type: ignore ) -> int: assert self.ctx is not None return llama_cpp.llama_sample_token_mirostat_v2( @@ -424,7 +426,7 @@ def sample_token_mirostat_v2( ctypes.byref(candidates.candidates), # type: ignore tau, eta, - ctypes.pointer(ctypes.c_float(mu)), + mu, ) def sample_token_greedy(self, candidates: "_LlamaTokenDataArray") -> int: @@ -526,3 +528,162 @@ def copy_logits(self, logits: npt.NDArray[np.single]): ) self.candidates.sorted = llama_cpp.c_bool(False) self.candidates.size = llama_cpp.c_size_t(self.n_vocab) + + +# Python wrappers over common/sampling structs + + +@dataclass +class _LlamaSamplingParams: + n_prev: int = 64 + n_probs: int = 0 + top_k: int = 40 + top_p: float = 0.95 + min_p: float = 0.05 + tfs_z: float = 1.00 + typical_p: float = 1.00 + temp: float = 0.80 + penalty_last_n: int = 64 + penalty_repeat: float = 1.10 + penalty_freq: float = 0.00 + penalty_present: float = 0.00 + mirostat: int = 0 + mirostat_tau: float = 5.00 + mirostat_eta: float = 0.10 + penalize_nl: bool = True + + grammar: str = "" + + cfg_negative_prompt: str = "" + cfg_scale: float = 1.00 + + logit_bias: dict[int, float] = field(default_factory=dict) + + +@dataclass +class _LlamaSamplingContext: + params: _LlamaSamplingParams = field(default_factory=_LlamaSamplingParams) + mirostat_mu: ctypes.c_float = field(default_factory=ctypes.c_float) + grammar: Optional[LlamaGrammar] = None + # NOTE: Missing parsed_grammar + prev: list[int] = field(default_factory=list) + cur: list[llama_cpp.llama_token_data] = field(default_factory=list) + + def reset(self): + self.prev = [] + self.cur = [] + if self.grammar is not None: + self.grammar.reset() + + def cp(self): + return _LlamaSamplingContext( + params=self.params, + mirostat_mu=self.mirostat_mu, + grammar=self.grammar, + prev=self.prev.copy(), + cur=self.cur.copy(), + ) + + def last(self) -> Optional[int]: + if len(self.prev) > 0: + return self.prev[-1] + else: + return None + + def prev_str(self, ctx_main: _LlamaContext, n: int) -> str: + return ctx_main.model.detokenize(self.prev[-n:]).decode("utf-8") + + def sample( + self, ctx_main: _LlamaContext, ctx_cfg: Optional[_LlamaContext], idx: int = 0 + ): + n_vocab = ctx_main.model.n_vocab() + id = 0 + logits = ctx_main.get_logits_ith(idx) + + # apply logit_bias + for token, logit_bias in self.params.logit_bias.items(): + logits[token] += logit_bias + + logits_array = np.array( + ctypes.cast(logits, ctypes.POINTER(ctypes.c_float * n_vocab)).contents, + dtype=np.single, + ) + token_data_array = _LlamaTokenDataArray( + n_vocab=n_vocab + ) # TODO: Only create this once + token_data_array.copy_logits(logits_array) + + if ctx_cfg is not None: + ctx_main.sample_classifier_free_guidance( + token_data_array, ctx_cfg, self.params.cfg_scale + ) + + # apply penalties + if len(self.prev) > 0: + nl_token = ctx_main.model.token_nl() + nl_logit = logits[nl_token] + if self.params.penalty_last_n > 0: + ctx_main.sample_repetition_penalties( + token_data_array, + # TODO: Only create this once + 
(llama_cpp.llama_token * len(self.prev))(*self.prev), # type: ignore + self.params.penalty_last_n, + self.params.penalty_repeat, + self.params.penalty_freq, + self.params.penalty_present, + ) + if not self.params.penalize_nl: + token_data_array.candidates_data["logit"][nl_token] = nl_logit + + if self.grammar is not None: + ctx_main.sample_grammar(token_data_array, self.grammar) + + if self.params.temp < 0: + ctx_main.sample_softmax(token_data_array) + id = token_data_array.candidates_data["id"][0] + elif self.params.temp == 0: + id = ctx_main.sample_token_greedy(token_data_array) + else: + if self.params.mirostat == 1: + mirostat_m = 100 + ctx_main.sample_temp(token_data_array, self.params.temp) + id = ctx_main.sample_token_mirostat( + token_data_array, + self.params.mirostat_tau, + self.params.mirostat_eta, + mirostat_m, + ctypes.pointer(self.mirostat_mu), + ) + elif self.params.mirostat == 2: + ctx_main.sample_temp(token_data_array, self.params.temp) + id = ctx_main.sample_token_mirostat_v2( + token_data_array, + self.params.mirostat_tau, + self.params.mirostat_eta, + ctypes.pointer(self.mirostat_mu), + ) + else: + min_keep = max(1, self.params.n_probs) + ctx_main.sample_top_k( + token_data_array, self.params.top_k, min_keep=min_keep + ) + ctx_main.sample_tail_free( + token_data_array, self.params.tfs_z, min_keep=min_keep + ) + ctx_main.sample_typical( + token_data_array, self.params.typical_p, min_keep=min_keep + ) + ctx_main.sample_top_p( + token_data_array, self.params.top_p, min_keep=min_keep + ) + ctx_main.sample_min_p( + token_data_array, self.params.min_p, min_keep=min_keep + ) + ctx_main.sample_temp(token_data_array, self.params.temp) + id = ctx_main.sample_token(token_data_array) + return id + + def accept(self, ctx_main: _LlamaContext, id: int, apply_grammar: bool): + if apply_grammar and self.grammar is not None: + ctx_main.grammar_accept_token(self.grammar, id) + self.prev.append(id) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index d2ba8695c..ed407c58d 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -466,21 +466,23 @@ def sample( elif temp == 0.0: id = self._ctx.sample_token_greedy(candidates=self._candidates) elif mirostat_mode == 1: + mu = 2.0 * mirostat_tau self._ctx.sample_temp(candidates=self._candidates, temp=temp) id = self._ctx.sample_token_mirostat( candidates=self._candidates, tau=mirostat_tau, eta=mirostat_eta, - mu=2.0 * mirostat_tau, + mu=ctypes.pointer(ctypes.c_float(mu)), m=100, ) elif mirostat_mode == 2: + mu = 2.0 * mirostat_tau self._ctx.sample_temp(candidates=self._candidates, temp=temp) id = self._ctx.sample_token_mirostat_v2( candidates=self._candidates, tau=mirostat_tau, eta=mirostat_eta, - mu=2.0 * mirostat_tau, + mu=ctypes.pointer(ctypes.c_float(mu)), ) else: self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1) From 7ae9a3e614b043ca01a739d0baaab558fb201728 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Wed, 29 Nov 2023 18:55:48 -0500 Subject: [PATCH 4/9] Add vocab utils --- llama_cpp/_internals.py | 75 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py index 6a3605839..b49cf3c22 100644 --- a/llama_cpp/_internals.py +++ b/llama_cpp/_internals.py @@ -536,6 +536,81 @@ def copy_logits(self, logits: npt.NDArray[np.single]): self.candidates.size = llama_cpp.c_size_t(self.n_vocab) +# Python wrappers over common/common +def _tokenize(model: _LlamaModel, text: str, add_bos: bool, special: bool) -> list[int]: + n_tokens = 
len(text) + 1 if add_bos else len(text) + result = (llama_cpp.llama_token * n_tokens)() + n_tokens = llama_cpp.llama_tokenize( + model.model, + text.encode("utf-8"), + len(text), + result, + n_tokens, + add_bos, + special, + ) + if n_tokens < 0: + result = (llama_cpp.llama_token * -n_tokens)() + check = llama_cpp.llama_tokenize( + model.model, + text.encode("utf-8"), + len(text), + result, + len(result), + add_bos, + special, + ) + if check != -n_tokens: + raise RuntimeError(f'Failed to tokenize: text="{text}" n_tokens={n_tokens}') + else: + result = result[:n_tokens] + return list(result) + + +def _token_to_piece(model: _LlamaModel, token: int) -> str: + assert model.model is not None + result = (ctypes.c_char * 8)(0) + n_tokens = llama_cpp.llama_token_to_piece(model.model, token, result, len(result)) + if n_tokens < 0: + result = (ctypes.c_char * -n_tokens)(0) + check = llama_cpp.llama_token_to_piece(model.model, token, result, len(result)) + if check != -n_tokens: + raise RuntimeError(f"Failed to get piece: token={token}") + else: + result = result[:n_tokens] + return bytes(result).decode("utf-8") + + +def _detokenize_spm(model: _LlamaModel, tokens: List[int]) -> str: + bos_id = model.token_bos() + result = "" + for i, token in enumerate(tokens): + piece = _token_to_piece(model, token) + if ( + (tokens[0] == bos_id and i == 1) or (tokens[0] != bos_id and i == 0) + ) and piece[0] == " ": + piece = piece[1:] + result += piece + return result + + +def _detokenize_bpe(model: _LlamaModel, tokens: List[int]) -> str: + result = "" + for token in tokens: + piece = _token_to_piece(model, token) + result += piece + return result + + +def _should_add_bos(model: _LlamaModel) -> bool: + assert model.model is not None + add_bos = llama_cpp.llama_add_bos_token(model.model) + if add_bos != -1: + return add_bos != 0 + else: + return llama_cpp.llama_vocab_type(model.model) == llama_cpp.LLAMA_VOCAB_TYPE_SPM + + # Python wrappers over common/sampling structs From 40f2293ff036b5a4f360afa57e3436ca66a8439c Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Thu, 30 Nov 2023 12:24:54 -0500 Subject: [PATCH 5/9] Refactor _create_completion --- llama_cpp/llama.py | 251 +++++++++++++++++++-------------------------- 1 file changed, 105 insertions(+), 146 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index ed407c58d..c830f7f44 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -780,6 +780,79 @@ def logit_bias_processor( if seed is not None: self._ctx.set_rng_seed(seed) + + def _completion_stream_response(text: str, logprobs_or_none: Optional[CompletionLogprobs] = None, finish_reason: Optional[Literal["stop", "length"]] = None) -> CreateCompletionStreamResponse: + return { + "id": completion_id, + "object": "text_completion", + "created": created, + "model": model_name, + "choices": [ + { + "text": text, + "index": 0, + "logprobs": logprobs_or_none, + "finish_reason": finish_reason, + } + ], + } + + def _completion_response(text: str, finish_reason: Literal["stop", "length"], logprobs_or_none: Optional[CompletionLogprobs] = None) -> CreateCompletionResponse: + return { + "id": completion_id, + "object": "text_completion", + "created": created, + "model": model_name, + "choices": [ + { + "text": text, + "index": 0, + "logprobs": logprobs_or_none, + "finish_reason": finish_reason, + } + ], + "usage": { + "prompt_tokens": len(prompt_tokens), + "completion_tokens": len(completion_tokens), + "total_tokens": len(prompt_tokens) + len(completion_tokens), + }, + } + + def 
_logprobs_or_none(all_tokens: List[int], all_token_strs: List[str], all_logprobs: List[List[float]], text_offset: int) -> CompletionLogprobs:
+            tokens: List[str] = []
+            text_offsets: List[int] = []
+            token_logprobs: List[Optional[float]] = []
+            top_logprobs: List[Optional[Dict[str, float]]] = []
+
+            for token, token_str, token_logprob in zip(
+                all_tokens, all_token_strs, all_logprobs
+            ):
+                if token == self.token_bos():
+                    continue
+
+                text_offset += len(token_str)
+                sorted_logprobs = list(
+                    sorted(
+                        zip(token_logprob, range(len(token_logprob))), reverse=True
+                    )
+                )
+                top_logprob = {
+                    self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
+                    for logprob, i in sorted_logprobs[:logprobs]
+                }
+                top_logprob.update({token_str: token_logprob[int(token)]})
+
+                tokens.append(token_str)
+                text_offsets.append(text_offset)
+                token_logprobs.append(token_logprob[int(token)])
+                top_logprobs.append(top_logprob)
+
+            return {
+                "tokens": tokens,
+                "text_offset": text_offsets,
+                "token_logprobs": token_logprobs,
+                "top_logprobs": top_logprobs,
+            }
         finish_reason = "length"
         multibyte_fix = 0
@@ -868,10 +941,10 @@ def logit_bias_processor(
                     )
                     token_offset = len(prompt_tokens) + returned_tokens
                     logits = self._scores[token_offset - 1, :].tolist()
-                    current_logprobs = Llama.logits_to_logprobs(logits)
+                    token_logprob = Llama.logits_to_logprobs(logits)
                     sorted_logprobs = list(
                         sorted(
-                            zip(current_logprobs, range(len(current_logprobs))),
+                            zip(token_logprob, range(len(token_logprob))),
                             reverse=True,
                         )
                     )
@@ -881,7 +954,7 @@ def logit_bias_processor(
                         ): logprob
                         for logprob, i in sorted_logprobs[:logprobs]
                     }
-                    top_logprob.update({token_str: current_logprobs[int(token)]})
+                    top_logprob.update({token_str: token_logprob[int(token)]})
                     logprobs_or_none = {
                         "tokens": [
                             self.detokenize([token]).decode(
                                 "utf-8", errors="ignore"
                             )
                         ],
                         "text_offset": [text_offset],
-                        "token_logprobs": [current_logprobs[int(token)]],
+                        "token_logprobs": [token_logprob[int(token)]],
                         "top_logprobs": [top_logprob],
                     }
                     returned_tokens += 1
-                    yield {
-                        "id": completion_id,
-                        "object": "text_completion",
-                        "created": created,
-                        "model": model_name,
-                        "choices": [
-                            {
-                                "text": self.detokenize([token]).decode(
-                                    "utf-8", errors="ignore"
-                                ),
-                                "index": 0,
-                                "logprobs": logprobs_or_none,
-                                "finish_reason": None,
-                            }
-                        ],
-                    }
+                    yield _completion_stream_response(
+                        self.detokenize([token]).decode("utf-8", errors="ignore"),
+                        logprobs_or_none,
+                    )
             else:
                 while len(remaining_tokens) > 0:
                     decode_success = False
@@ -933,20 +994,7 @@ def logit_bias_processor(
                     remaining_tokens = remaining_tokens[i:]
                     returned_tokens += i
-                    yield {
-                        "id": completion_id,
-                        "object": "text_completion",
-                        "created": created,
-                        "model": model_name,
-                        "choices": [
-                            {
-                                "text": ts,
-                                "index": 0,
-                                "logprobs": None,
-                                "finish_reason": None,
-                            }
-                        ],
-                    }
+                    yield _completion_stream_response(text=ts)
 
         if len(completion_tokens) >= max_tokens:
             text = self.detokenize(completion_tokens)
@@ -986,11 +1034,10 @@ def logit_bias_processor(
                     self.detokenize(completion_tokens[:returned_tokens])
                 )
                 token_offset = len(prompt_tokens) + returned_tokens - 1
-                logits = self._scores[token_offset, :].tolist()
-                current_logprobs = Llama.logits_to_logprobs(logits)
+                token_logprob = Llama.logits_to_logprobs(self._scores[token_offset, :].tolist())
                 sorted_logprobs = list(
                     sorted(
-                        zip(current_logprobs, range(len(current_logprobs))),
+                        zip(token_logprob, range(len(token_logprob))),
                         reverse=True,
                     )
                 )
@@ -998,13 +1045,11 @@ def logit_bias_processor(
                     self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
                     for logprob, i in sorted_logprobs[:logprobs]
                 }
-                top_logprob.update({token_str: current_logprobs[int(token)]})
+                top_logprob.update({token_str: token_logprob[int(token)]})
                 logprobs_or_none = {
-                    "tokens": [
-                        self.detokenize([token]).decode("utf-8", errors="ignore")
-                    ],
+                    "tokens": [token_str],
                     "text_offset": [text_offset],
-                    "token_logprobs": [current_logprobs[int(token)]],
+                    "token_logprobs": [token_logprob[int(token)]],
                     "top_logprobs": [top_logprob],
                 }
 
@@ -1013,54 +1058,17 @@ def logit_bias_processor(
                 if token_end_position == end - 1:
                     break
                 returned_tokens += 1
-                yield {
-                    "id": completion_id,
-                    "object": "text_completion",
-                    "created": created,
-                    "model": model_name,
-                    "choices": [
-                        {
-                            "text": last_text[
-                                : len(last_text) - (token_end_position - end)
-                            ].decode("utf-8", errors="ignore"),
-                            "index": 0,
-                            "logprobs": logprobs_or_none,
-                            "finish_reason": None,
-                        }
-                    ],
-                }
+                yield _completion_stream_response(
+                    text=last_text[: len(last_text) - (token_end_position - end)].decode("utf-8", errors="ignore"), logprobs_or_none=logprobs_or_none
+                )
                 break
             returned_tokens += 1
-            yield {
-                "id": completion_id,
-                "object": "text_completion",
-                "created": created,
-                "model": model_name,
-                "choices": [
-                    {
-                        "text": self.detokenize([token]).decode(
-                            "utf-8", errors="ignore"
-                        ),
-                        "index": 0,
-                        "logprobs": logprobs_or_none,
-                        "finish_reason": None,
-                    }
-                ],
-            }
-        yield {
-            "id": completion_id,
-            "object": "text_completion",
-            "created": created,
-            "model": model_name,
-            "choices": [
-                {
-                    "text": "",
-                    "index": 0,
-                    "logprobs": None,
-                    "finish_reason": finish_reason,
-                }
-            ],
-        }
+            yield _completion_stream_response(
+                text=self.detokenize([token]).decode("utf-8", errors="ignore"), logprobs_or_none=logprobs_or_none
+            )
+        yield _completion_stream_response(
+            text=self.detokenize(completion_tokens[returned_tokens:]).decode("utf-8", errors="ignore"), finish_reason=finish_reason
+        )
         if self.cache:
             if self.verbose:
                 print("Llama._create_completion: cache save", file=sys.stderr)
@@ -1076,6 +1084,7 @@ def logit_bias_processor(
         text_str = text.decode("utf-8", errors="ignore")
 
         if echo:
+            assert isinstance(prompt, str)
             text_str = prompt + text_str
 
         if suffix is not None:
@@ -1083,19 +1092,10 @@ def logit_bias_processor(
         logprobs_or_none: Optional[CompletionLogprobs] = None
         if logprobs is not None:
+            # Remove leading BOS token
+            all_tokens = prompt_tokens[1:] + completion_tokens if echo else completion_tokens
             text_offset = 0 if echo else len(prompt)
             token_offset = 0 if echo else len(prompt_tokens[1:])
-            text_offsets: List[int] = []
-            token_logprobs: List[Optional[float]] = []
-            tokens: List[str] = []
-            top_logprobs: List[Optional[Dict[str, float]]] = []
-
-            if echo:
-                # Remove leading BOS token
-                all_tokens = prompt_tokens[1:] + completion_tokens
-            else:
-                all_tokens = completion_tokens
-
             all_token_strs = [
                 self.detokenize([token]).decode("utf-8", errors="ignore")
                 for token in all_tokens
@@ -1103,58 +1103,17 @@ def logit_bias_processor(
             all_logprobs = [
                 Llama.logits_to_logprobs(row.tolist()) for row in self._scores
            ][token_offset:]
-            for token, token_str, logprobs_token in zip(
-                all_tokens, all_token_strs, all_logprobs
-            ):
-                if token == self.token_bos():
-                    continue
-                text_offsets.append(text_offset)
-                text_offset += len(token_str)
-                tokens.append(token_str)
-                sorted_logprobs = list(
-                    sorted(
-                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
-                    )
-                )
-                token_logprobs.append(logprobs_token[int(token)])
-                top_logprob: Optional[Dict[str, float]] = {
-                    self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
-                    for logprob, i in sorted_logprobs[:logprobs]
-                }
-                top_logprob.update({token_str: logprobs_token[int(token)]})
-                top_logprobs.append(top_logprob)
+            logprobs_or_none = _logprobs_or_none(
+                all_tokens, all_token_strs, all_logprobs, text_offset
+            )
 
             # Weird idosincracy of the OpenAI API where
             # token_logprobs and top_logprobs are null for
             # the first token.
             if echo and len(all_tokens) > 0:
-                token_logprobs[0] = None
-                top_logprobs[0] = None
-            logprobs_or_none = {
-                "tokens": tokens,
-                "text_offset": text_offsets,
-                "token_logprobs": token_logprobs,
-                "top_logprobs": top_logprobs,
-            }
+                logprobs_or_none["token_logprobs"][0] = None
+                logprobs_or_none["top_logprobs"][0] = None
 
-        yield {
-            "id": completion_id,
-            "object": "text_completion",
-            "created": created,
-            "model": model_name,
-            "choices": [
-                {
-                    "text": text_str,
-                    "index": 0,
-                    "logprobs": logprobs_or_none,
-                    "finish_reason": finish_reason,
-                }
-            ],
-            "usage": {
-                "prompt_tokens": len(prompt_tokens),
-                "completion_tokens": len(completion_tokens),
-                "total_tokens": len(prompt_tokens) + len(completion_tokens),
-            },
-        }
+        yield _completion_response(text_str, finish_reason, logprobs_or_none)
 
     def create_completion(
         self,
From fcbd177c95c77b49e8b11e4658ef7a6813aee6d0 Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Mon, 18 Dec 2023 18:38:04 -0500
Subject: [PATCH 6/9] Fix logits are not json serializable

---
 llama_cpp/llama.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index f5dfa22a7..aa010793d 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -946,7 +946,7 @@ def _logprobs_or_none(all_tokens: List[int], all_token_strs: List[str], all_logp
                     )
                     token_offset = len(prompt_tokens) + returned_tokens
                     logits = self._scores[token_offset - 1, :]
-                    token_logprob = Llama.logits_to_logprobs(logits)
+                    token_logprob = Llama.logits_to_logprobs(logits).tolist()
                     sorted_logprobs = list(
                         sorted(
                             zip(token_logprob, range(len(token_logprob))),
                             reverse=True,
                         )
                     )
@@ -1040,7 +1040,7 @@ def _logprobs_or_none(all_tokens: List[int], all_token_strs: List[str], all_logp
                 )
                 token_offset = len(prompt_tokens) + returned_tokens - 1
                 logits = self._scores[token_offset, :]
-                token_logprob = Llama.logits_to_logprobs(logits)
+                token_logprob = Llama.logits_to_logprobs(logits).tolist()
                 sorted_logprobs = list(
                     sorted(
                         zip(token_logprob, range(len(token_logprob))),
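For context on the `.tolist()` change above: `numpy.float32` values raise a `TypeError` in `json.dumps`, while `ndarray.tolist()` yields built-in Python floats that serialize cleanly. A small standalone illustration (not code from the patch):

import json

import numpy as np

logprobs = np.log(np.array([0.7, 0.2, 0.1], dtype=np.single))

try:
    # A raw np.float32 is not JSON serializable.
    json.dumps({"token_logprobs": [logprobs[0]]})
except TypeError as exc:
    print("rejected:", exc)

# tolist() converts the array to plain Python floats, which serialize fine.
print(json.dumps({"token_logprobs": logprobs.tolist()}))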
From e1cd61ed91fa56c923130f5078d6d9675c4cba59 Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Fri, 5 Jan 2024 04:57:57 -0500
Subject: [PATCH 7/9] Fix #1038

---
 llama_cpp/llama.py             | 11 ++++++++---
 llama_cpp/llama_chat_format.py | 20 ++++++++++++++++----
 2 files changed, 24 insertions(+), 7 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 3ff631d96..d6172edb6 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -872,7 +872,7 @@ def _completion_response(text: str, finish_reason: Literal["stop", "length"], lo
             break
 
         if stream:
-            remaining_tokens = completion_tokens[returned_tokens:]
+            remaining_tokens = completion_tokens[returned_tokens:-1]
             remaining_text = self.detokenize(remaining_tokens)
             remaining_length = len(remaining_text)
 
@@ -1030,9 +1030,14 @@ def _completion_response(text: str, finish_reason: Literal["stop", "length"], lo
                     break
                 returned_tokens += 1
                 yield _completion_stream_response(
-                    text=last_text[: len(last_text) - (token_end_position - end)].decode("utf-8", errors="ignore"), logprobs_or_none=logprobs_or_none
+                    text=last_text[: len(last_text) - (token_end_position - end)].decode("utf-8", errors="ignore"), logprobs_or_none=logprobs_or_none, finish_reason=finish_reason
                 )
-                break
+                if self.cache:
+                    if self.verbose:
+                        print("Llama._create_completion: cache save", file=sys.stderr)
+                    self.cache[prompt_tokens + completion_tokens] = self.save_state()
+                    print("Llama._create_completion: cache saved", file=sys.stderr)
+                return
             returned_tokens += 1
             yield _completion_stream_response(
                 text=self.detokenize([token]).decode("utf-8", errors="ignore"), logprobs_or_none=logprobs_or_none
diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index 0ef7bd4a8..b5d490950 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -260,13 +260,25 @@ def _convert_text_completion_chunks_to_chat(
                     "index": 0,
                     "delta": {
                         "content": chunk["choices"][0]["text"],
-                    }
-                    if chunk["choices"][0]["finish_reason"] is None
-                    else {},
-                    "finish_reason": chunk["choices"][0]["finish_reason"],
+                    },
+                    "finish_reason": None,
                 }
             ],
         }
+        if chunk["choices"][0]["finish_reason"] is not None:
+            yield {
+                "id": "chat" + chunk["id"],
+                "model": chunk["model"],
+                "created": chunk["created"],
+                "object": "chat.completion.chunk",
+                "choices": [
+                    {
+                        "index": 0,
+                        "delta": {},
+                        "finish_reason": chunk["choices"][0]["finish_reason"],
+                    }
+                ],
+            }
 
 
 def _convert_completion_to_chat(
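With the change above, a chat stream ends with a dedicated chunk whose delta is empty and whose finish_reason is set, matching OpenAI's chunk sequence. A hedged sketch of how a consumer might read such a stream; the chunk dicts below are hand-built examples for illustration, not output captured from the library:

from typing import Dict, Iterator


def fake_chat_stream() -> Iterator[Dict]:
    # Content chunks carry text in "delta"; the final chunk has an empty delta
    # and a non-null "finish_reason", mirroring the converter above.
    yield {"choices": [{"delta": {"content": "Hello"}, "finish_reason": None}]}
    yield {"choices": [{"delta": {"content": " world"}, "finish_reason": None}]}
    yield {"choices": [{"delta": {}, "finish_reason": "stop"}]}


pieces = []
finish_reason = None
for chunk in fake_chat_stream():
    choice = chunk["choices"][0]
    pieces.append(choice["delta"].get("content", ""))
    if choice["finish_reason"] is not None:
        finish_reason = choice["finish_reason"]

print("".join(pieces), "| finish_reason:", finish_reason)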
From 7f4ba48ada4cdd09611c66ec2614528514b46b9b Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Wed, 10 Jan 2024 08:29:54 -0500
Subject: [PATCH 8/9] Use sampling context

---
 llama_cpp/_internals.py |  20 +++---
 llama_cpp/llama.py      | 137 ++++++++++++++++++----------
 2 files changed, 74 insertions(+), 83 deletions(-)

diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py
index b49cf3c22..e993f7d84 100644
--- a/llama_cpp/_internals.py
+++ b/llama_cpp/_internals.py
@@ -675,20 +675,22 @@ def prev_str(self, ctx_main: _LlamaContext, n: int) -> str:
         return ctx_main.model.detokenize(self.prev[-n:]).decode("utf-8")
 
     def sample(
-        self, ctx_main: _LlamaContext, ctx_cfg: Optional[_LlamaContext], idx: int = 0
+        self, ctx_main: _LlamaContext, ctx_cfg: Optional[_LlamaContext] = None, idx: int = 0, logits_array: Optional[npt.NDArray[np.single]] = None
     ):
         n_vocab = ctx_main.model.n_vocab()
-        id = 0
-        logits = ctx_main.get_logits_ith(idx)
+        id: int = 0
+
+        if logits_array is None:
+            logits = ctx_main.get_logits_ith(idx)
+            logits_array = np.array(
+                ctypes.cast(logits, ctypes.POINTER(ctypes.c_float * n_vocab)).contents,
+                dtype=np.single,
+            )
 
         # apply logit_bias
         for token, logit_bias in self.params.logit_bias.items():
-            logits[token] += logit_bias
+            logits_array[token] += logit_bias
 
-        logits_array = np.array(
-            ctypes.cast(logits, ctypes.POINTER(ctypes.c_float * n_vocab)).contents,
-            dtype=np.single,
-        )
         token_data_array = _LlamaTokenDataArray(
             n_vocab=n_vocab
         )  # TODO: Only create this once
@@ -702,7 +704,7 @@ def sample(
         # apply penalties
         if len(self.prev) > 0:
             nl_token = ctx_main.model.token_nl()
-            nl_logit = logits[nl_token]
+            nl_logit = logits_array[nl_token]
             if self.params.penalty_last_n > 0:
                 ctx_main.sample_repetition_penalties(
                     token_data_array,
diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 3205b8c46..21292040a 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -30,7 +30,14 @@
 import llama_cpp.llama_cpp as llama_cpp
 import llama_cpp.llama_chat_format as llama_chat_format
 
-from ._internals import _LlamaModel, _LlamaContext, _LlamaBatch, _LlamaTokenDataArray  # type: ignore
+from ._internals import (
+    _LlamaModel,
+    _LlamaContext,
+    _LlamaBatch,
+    _LlamaTokenDataArray,  # type: ignore
+    _LlamaSamplingParams,
+    _LlamaSamplingContext,
+)
 from ._utils import suppress_stdout_stderr
 
@@ -427,79 +434,39 @@ def sample(
         """
         assert self._ctx is not None
         assert self.n_tokens > 0
-        last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
-            0, self.last_n_tokens_size - self.n_tokens
-        ) + self._input_ids[-self.last_n_tokens_size :].tolist()
-        last_n_tokens_size = len(last_n_tokens_data)
-        n_vocab = self._n_vocab
-        n_ctx = self._n_ctx
-        top_k = n_vocab if top_k <= 0 else top_k
-        last_n_tokens_size = n_ctx if last_n_tokens_size < 0 else last_n_tokens_size
-        last_n_tokens_data_c = (llama_cpp.llama_token * last_n_tokens_size)(
-            *last_n_tokens_data
-        )
+
         logits: npt.NDArray[np.single] = self._scores[-1, :]
 
         if logits_processor is not None:
             logits[:] = logits_processor(self._input_ids, logits)
 
-        nl_logit = logits[self._token_nl]
-        self._candidates.copy_logits(logits)
-        self._ctx.sample_repetition_penalties(
-            candidates=self._candidates,
-            last_tokens_data=last_n_tokens_data_c,
-            penalty_last_n=last_n_tokens_size,
+        sampling_params = _LlamaSamplingParams(
+            top_k=top_k,
+            top_p=top_p,
+            min_p=min_p,
+            tfs_z=tfs_z,
+            typical_p=typical_p,
+            temp=temp,
+            penalty_last_n=self.last_n_tokens_size,
             penalty_repeat=repeat_penalty,
             penalty_freq=frequency_penalty,
             penalty_present=presence_penalty,
+            mirostat=mirostat_mode,
+            mirostat_tau=mirostat_tau,
+            mirostat_eta=mirostat_eta,
+            penalize_nl=penalize_nl,
+        )
+        sampling_context = _LlamaSamplingContext(
+            params=sampling_params,
+            grammar=grammar,
+        )
+        sampling_context.prev = list(self.eval_tokens)
+        id = sampling_context.sample(ctx_main=self._ctx, logits_array=logits)
+        sampling_context.accept(
+            ctx_main=self._ctx,
+            id=id,
+            apply_grammar=grammar is not None,
         )
-        if not penalize_nl:
-            self._candidates.candidates.data[self._token_nl].logit = llama_cpp.c_float(
-                nl_logit
-            )
-
-        if grammar is not None:
-            self._ctx.sample_grammar(
-                candidates=self._candidates,
-                grammar=grammar,
-            )
-
-        if temp < 0.0:
-            self._ctx.sample_softmax(candidates=self._candidates)
-            id = self._candidates.candidates.data[0].id
-        elif temp == 0.0:
-            id = self._ctx.sample_token_greedy(candidates=self._candidates)
-        elif mirostat_mode == 1:
-            mu = 2.0 * mirostat_tau
-            self._ctx.sample_temp(candidates=self._candidates, temp=temp)
-            id = self._ctx.sample_token_mirostat(
-                candidates=self._candidates,
-                tau=mirostat_tau,
-                eta=mirostat_eta,
-                mu=ctypes.pointer(ctypes.c_float(mu)),
-                m=100,
-            )
-        elif mirostat_mode == 2:
-            mu = 2.0 * mirostat_tau
-            self._ctx.sample_temp(candidates=self._candidates, temp=temp)
-            id = self._ctx.sample_token_mirostat_v2(
-                candidates=self._candidates,
-                tau=mirostat_tau,
-                eta=mirostat_eta,
-                mu=ctypes.pointer(ctypes.c_float(mu)),
-            )
-        else:
-            self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
-            self._ctx.sample_tail_free(candidates=self._candidates, z=tfs_z, min_keep=1)
-            self._ctx.sample_typical(
-                candidates=self._candidates, p=typical_p, min_keep=1
-            )
-            self._ctx.sample_top_p(candidates=self._candidates, p=top_p, min_keep=1)
-            self._ctx.sample_min_p(candidates=self._candidates, p=min_p, min_keep=1)
-            self._ctx.sample_temp(candidates=self._candidates, temp=temp)
-            id = self._ctx.sample_token(candidates=self._candidates)
-        if grammar is not None:
-            self._ctx.grammar_accept_token(grammar=grammar, token=id)
         return id
 
     def generate(
@@ -786,8 +753,12 @@ def logit_bias_processor(
 
         if seed is not None:
             self._ctx.set_rng_seed(seed)
-
-        def _completion_stream_response(text: str, logprobs_or_none: Optional[CompletionLogprobs] = None, finish_reason: Optional[Literal["stop", "length"]] = None) -> CreateCompletionStreamResponse:
+
+        def _completion_stream_response(
+            text: str,
+            logprobs_or_none: Optional[CompletionLogprobs] = None,
+            finish_reason: Optional[Literal["stop", "length"]] = None,
+        ) -> CreateCompletionStreamResponse:
             return {
                 "id": completion_id,
                 "object": "text_completion",
@@ -803,7 +774,11 @@ def _completion_stream_response(text: str, logprobs_or_none: Optional[Completion
                 ],
             }
 
-        def _completion_response(text: str, finish_reason: Literal["stop", "length"], logprobs_or_none: Optional[CompletionLogprobs] = None) -> CreateCompletionResponse:
+        def _completion_response(
+            text: str,
+            finish_reason: Literal["stop", "length"],
+            logprobs_or_none: Optional[CompletionLogprobs] = None,
+        ) -> CreateCompletionResponse:
             return {
                 "id": completion_id,
                 "object": "text_completion",
@@ -1032,20 +1007,32 @@ def _completion_response(text: str, finish_reason: Literal["stop", "length"], lo
                     break
                 returned_tokens += 1
                 yield _completion_stream_response(
-                    text=last_text[: len(last_text) - (token_end_position - end)].decode("utf-8", errors="ignore"), logprobs_or_none=logprobs_or_none, finish_reason=finish_reason
+                    text=last_text[
+                        : len(last_text) - (token_end_position - end)
+                    ].decode("utf-8", errors="ignore"),
+                    logprobs_or_none=logprobs_or_none,
+                    finish_reason=finish_reason,
                 )
                 if self.cache:
                     if self.verbose:
-                        print("Llama._create_completion: cache save", file=sys.stderr)
-                    self.cache[prompt_tokens + completion_tokens] = self.save_state()
+                        print(
+                            "Llama._create_completion: cache save", file=sys.stderr
+                        )
+                    self.cache[
+                        prompt_tokens + completion_tokens
+                    ] = self.save_state()
                     print("Llama._create_completion: cache saved", file=sys.stderr)
                 return
             returned_tokens += 1
             yield _completion_stream_response(
-                text=self.detokenize([token]).decode("utf-8", errors="ignore"), logprobs_or_none=logprobs_or_none
+                text=self.detokenize([token]).decode("utf-8", errors="ignore"),
+                logprobs_or_none=logprobs_or_none,
             )
             yield _completion_stream_response(
-                text=self.detokenize(completion_tokens[returned_tokens:]).decode("utf-8", errors="ignore"), finish_reason=finish_reason
+                text=self.detokenize(completion_tokens[returned_tokens:]).decode(
+                    "utf-8", errors="ignore"
+                ),
+                finish_reason=finish_reason,
             )
         if self.cache:
             if self.verbose:
@@ -1071,7 +1058,9 @@ def _completion_response(text: str, finish_reason: Literal["stop", "length"], lo
         logprobs_or_none: Optional[CompletionLogprobs] = None
         if logprobs is not None:
             # Remove leading BOS token
-            all_tokens = prompt_tokens[1:] + completion_tokens if echo else completion_tokens
+            all_tokens = (
+                prompt_tokens[1:] + completion_tokens if echo else completion_tokens
+            )
             text_offset = 0 if echo else len(prompt)
             token_offset = 0 if echo else len(prompt_tokens[1:])
             text_offsets: List[int] = []
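The refactor above replaces the long if/elif sampler chain with a sampling context that owns the parameters and a sample/accept loop. A toy, dependency-free sketch of that shape follows; `ToySamplingParams` and `ToySamplingContext` are illustrative names invented for this note, not the `_internals` classes, and the penalty and top-k logic are deliberately simplified.

import math
import random
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToySamplingParams:
    temp: float = 0.8
    top_k: int = 3
    penalty_repeat: float = 1.1


@dataclass
class ToySamplingContext:
    params: ToySamplingParams
    prev: List[int] = field(default_factory=list)

    def sample(self, logits: List[float]) -> int:
        # Apply a simple repetition penalty to recently sampled tokens.
        adjusted = list(logits)
        for tok in self.prev:
            adjusted[tok] -= math.log(self.params.penalty_repeat)
        # Keep only the top-k candidates, then sample with temperature.
        ranked = sorted(range(len(adjusted)), key=lambda t: adjusted[t], reverse=True)
        candidates = ranked[: self.params.top_k]
        weights = [math.exp(adjusted[t] / self.params.temp) for t in candidates]
        return random.choices(candidates, weights=weights, k=1)[0]

    def accept(self, token: int) -> None:
        # Record the token so later penalties can see it, like accept() in the patch.
        self.prev.append(token)


ctx = ToySamplingContext(ToySamplingParams())
logits = [1.5, 0.2, 0.9, -0.3]
for _ in range(4):
    tok = ctx.sample(logits)
    ctx.accept(tok)
print(ctx.prev)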
From 6f080212803e989d0afa957ccf8aa95ea77e437d Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Wed, 17 Jan 2024 09:48:46 -0500
Subject: [PATCH 9/9] Cleanup pyproject

---
 pyproject.toml | 2 --
 1 file changed, 2 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 806127d89..413097201 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,8 +11,6 @@ license = { text = "MIT" }
 authors = [
     { name = "Andrei Betlen", email = "abetlen@gmail.com" },
 ]
-# mkdocs-martiral requires "jinja2~=3.0"
-# transformers requires "jinja2>=2.11.3"
 dependencies = [
     "typing-extensions>=4.5.0",
     "numpy>=1.20.0",