Skip to content

Commit c37a029

Browse files
committed
add bindings for llama_grammar_parse / llama_grammar_from_state
1 parent 5ff88a3 commit c37a029

File tree

4 files changed

+27
-19
lines changed

4 files changed

+27
-19
lines changed

grammar_test.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from llama_cpp import Llama
22

3-
grammar = """
4-
root ::= nav eol (commands eol)*
3+
grammar = """root ::= nav eol (commands eol)*
54
commands ::= t | info
65
nav ::= "nav(\\"admin/" [a-z/]* "\\")"
76
info ::= "info(" setting ")"
@@ -17,15 +16,13 @@
1716

1817
llm = Llama(
1918
model_path="/Users/alex/llama-7b.ggmlv3.q8_0.bin",
20-
lora_base="/Users/alex/llama-7b.ggml.f16.bin",
19+
# lora_base="/Users/alex/llama-7b.ggml.f16.bin",
2120
# python ~/llama.cpp/convert-lora-to-ggml.py .
22-
lora_path="/Users/alex/src/github.com/Shopify/sidekick-data/src/webapp/models/ggml-adapter-model.bin",
21+
# lora_path="/Users/alex/src/github.com/Shopify/sidekick-data/src/webapp/models/ggml-adapter-model.bin",
2322
# n_gpu_layers=1000,
2423
n_ctx=2048,
2524
grammar=grammar,
2625
)
2726

28-
# response = llm("make my theme orange")
29-
3027
import code
3128
code.interact(local=globals())

llama_cpp/llama.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -273,12 +273,6 @@ def __init__(
273273

274274
self.lora_base = lora_base
275275
self.lora_path = lora_path
276-
self.grammar = grammar
277-
278-
if grammar:
279-
self.grammar = llama_cpp.llama_parse_grammar(
280-
llama_cpp.c_char_p(self.grammar.encode("utf-8"))
281-
)
282276

283277
### DEPRECATED ###
284278
self.n_parts = n_parts
@@ -306,6 +300,12 @@ def __init__(
306300
f"Failed to apply LoRA from lora path: {self.lora_path} to base path: {self.lora_base}"
307301
)
308302

303+
if grammar:
304+
self.parse_state = llama_cpp.llama_grammar_parse(
305+
llama_cpp.c_char_p(grammar.encode("utf-8"))
306+
)
307+
self.grammar = llama_cpp.llama_grammar_from_state(self.parse_state)
308+
309309
if self.verbose:
310310
print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
311311

@@ -582,7 +582,6 @@ def _sample(
582582
)
583583

584584
if self.grammar:
585-
breakpoint()
586585
id = llama_cpp.llama_grammar_accept_token(
587586
self.ctx,
588587
self.grammar,
@@ -890,7 +889,8 @@ def _create_completion(
890889
stopping_criteria=stopping_criteria,
891890
logits_processor=logits_processor,
892891
):
893-
if token == self._token_eos:
892+
893+
if token == self._token_eos: #or token == self._token_nl:
894894
text = self.detokenize(completion_tokens)
895895
finish_reason = "stop"
896896
break

llama_cpp/llama_cpp.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ def _load_shared_library(lib_base_name: str):
114114
llama_token = c_int
115115
llama_token_p = POINTER(llama_token)
116116

117+
# struct llama_grammar
118+
parse_state_p = c_void_p
117119
llama_grammar_p = c_void_p
118120

119121

@@ -796,13 +798,22 @@ def llama_sample_temperature(
796798
_lib.llama_sample_temperature.restype = None
797799

798800

799-
def llama_parse_grammar(grammar: str):
800-
return _lib.llama_parse_grammar(grammar)
801+
def llama_grammar_parse(grammar: str):
802+
return _lib.llama_grammar_parse(grammar)
801803

802-
_lib.llama_parse_grammar.argtypes = [
804+
_lib.llama_grammar_parse.argtypes = [
803805
c_char_p,
804806
]
805-
_lib.llama_parse_grammar.restype = llama_grammar_p
807+
_lib.llama_grammar_parse.restype = parse_state_p
808+
809+
810+
def llama_grammar_from_state(parse_state: parse_state_p):
811+
return _lib.llama_grammar_from_state(parse_state)
812+
813+
_lib.llama_grammar_from_state.argtypes = [
814+
parse_state_p
815+
]
816+
_lib.llama_grammar_from_state.restype = llama_grammar_p
806817

807818

808819
def llama_sample_grammar(

0 commit comments

Comments (0)