From 8c7b4c14925b1f9ef4928220c36abfa9c86a847d Mon Sep 17 00:00:00 2001
From: "antoine.lizee"
Date: Mon, 30 Oct 2023 10:54:18 +0000
Subject: [PATCH] fix: tokenization of special characters: behave like
 llama.cpp, where most out-of-the-box usages treat special characters as
 special tokens

---
 llama_cpp/llama.py      | 6 +++---
 llama_cpp/server/app.py | 2 +-
 tests/test_llama.py     | 9 +++++++++
 3 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index e53c9c8ae..9efccbe62 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -856,7 +856,7 @@ def create_embedding(
         data: List[EmbeddingData] = []
         total_tokens = 0
         for index, input in enumerate(inputs):
-            tokens = self.tokenize(input.encode("utf-8"))
+            tokens = self.tokenize(input.encode("utf-8"), special=True)
             self.reset()
             self.eval(tokens)
             n_tokens = len(tokens)
@@ -927,7 +927,7 @@ def _create_completion(
         completion_tokens: List[int] = []
         # Add blank space to start of prompt to match OG llama tokenizer
         prompt_tokens: List[int] = (
-            self.tokenize(prompt.encode("utf-8"))
+            self.tokenize(prompt.encode("utf-8"), special=True)
             if prompt != ""
             else [self.token_bos()]
         )
@@ -1823,7 +1823,7 @@ def __init__(self, llama: Llama):
 
     def encode(self, text: str, add_bos: bool = True) -> List[int]:
         return self.llama.tokenize(
-            text.encode("utf-8", errors="ignore"), add_bos=add_bos
+            text.encode("utf-8", errors="ignore"), add_bos=add_bos, special=True
         )
 
     def decode(self, tokens: List[int]) -> str:
diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py
index 3dd0a38fe..1547f7b86 100644
--- a/llama_cpp/server/app.py
+++ b/llama_cpp/server/app.py
@@ -589,7 +589,7 @@ def make_logit_bias_processor(
     elif logit_bias_type == "tokens":
         for token, score in logit_bias.items():
             token = token.encode("utf-8")
-            for input_id in llama.tokenize(token, add_bos=False):
+            for input_id in llama.tokenize(token, add_bos=False, special=True):
                 to_bias[input_id] = score
 
     def logit_bias_processor(
diff --git a/tests/test_llama.py b/tests/test_llama.py
index 76291fbca..330b69b9c 100644
--- a/tests/test_llama.py
+++ b/tests/test_llama.py
@@ -25,6 +25,15 @@ def test_llama_cpp_tokenization():
     detokenized = llama.detokenize(tokens)
     assert detokenized != text
 
+    text = b"Hello World</s>"
+    tokens = llama.tokenize(text)
+    assert tokens[-1] != llama.token_eos()
+    assert tokens == [1, 15043, 2787, 829, 29879, 29958]
+
+    tokens = llama.tokenize(text, special=True)
+    assert tokens[-1] == llama.token_eos()
+    assert tokens == [1, 10994, 2787, 2]
+
 
 def test_llama_patch(monkeypatch):
     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
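
Usage sketch (illustrative only, not part of the patch): with special=True,
Llama.tokenize() maps special tokens such as "</s>" to their token ids instead
of tokenizing them as literal text, matching llama.cpp's default behavior. A
minimal sketch, assuming a local GGUF model file; the path below is a
placeholder:

    from llama_cpp import Llama

    # vocab_only=True loads just the tokenizer, as in the test above.
    llama = Llama(model_path="./models/llama-2-7b.Q4_0.gguf", vocab_only=True)

    text = b"Hello World</s>"
    # Default tokenization splits "</s>" into literal pieces ("</", "s", ">").
    print(llama.tokenize(text))
    # With special=True, "</s>" is emitted as the EOS token id (llama.token_eos()).
    print(llama.tokenize(text, special=True))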