From ac188a21f3a2e9530b36e6103de36f5ba655376e Mon Sep 17 00:00:00 2001 From: c0sogi Date: Sat, 5 Aug 2023 14:43:35 +0900 Subject: [PATCH 1/4] Added low level grammar API --- llama_cpp/llama_cpp.py | 34 + llama_cpp/llama_grammar.py | 1331 ++++++++++++++++++++++++++++++++++++ 2 files changed, 1365 insertions(+) create mode 100644 llama_cpp/llama_grammar.py diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 423a4a043..d9a68a96c 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -1157,6 +1157,23 @@ def llama_sample_temperature( _lib.llama_sample_temperature.restype = None +# LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar); +def llama_sample_grammar( + ctx: llama_context_p, + candidates, # type: _Pointer[llama_token_data_array] + grammar, # type: llama_grammar_p +): + return _lib.llama_sample_grammar(ctx, candidates, grammar) + + +_lib.llama_sample_grammar.argtypes = [ + llama_context_p, + llama_token_data_array_p, + llama_grammar_p, +] +_lib.llama_sample_grammar.restype = None + + # @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. # @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. # @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. 
@@ -1244,6 +1261,23 @@ def llama_sample_token( _lib.llama_sample_token.restype = llama_token +# /// @details Accepts the sampled token into the grammar +# LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token); +def llama_grammar_accept_token( + ctx: llama_context_p, + grammar: llama_grammar_p, + token: llama_token, +) -> None: + _lib.llama_grammar_accept_token(ctx, grammar, token) + + +_lib.llama_grammar_accept_token.argtypes = [ + llama_context_p, + llama_grammar_p, + llama_token, +] +_lib.llama_grammar_accept_token.restype = None + # Performance information diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py new file mode 100644 index 000000000..07a120f2c --- /dev/null +++ b/llama_cpp/llama_grammar.py @@ -0,0 +1,1331 @@ +"""C++ implementation of the llama grammar parser.""" +# flake8: noqa +import argparse +from pathlib import Path +import sys +from ctypes import Array, c_int, c_size_t, c_uint32, cast +from enum import Enum +from itertools import islice +from typing import ( + Callable, + Generic, + List, + Optional, + OrderedDict, + TextIO, + Tuple, + TypeVar, + Union, +) + +import llama_cpp + +T = TypeVar("T") +U = TypeVar("U") +V = TypeVar("V") +W = TypeVar("W") +size_t = uint8_t = uint32_t = int +static_cast_uint8_t = ord + + +class Sentinel: + pass + + +class const_char_p: + """C++ implementation of const char*.""" + + def __init__(self, value: Union[str, "const_char_p"]): + if isinstance(value, const_char_p): + # We're copying an existing const_char_p + self.value = value.value + self.pos = value.pos + return + + # We're creating a new const_char_p + self.value = value + self.pos = 0 + + def __str__(self) -> str: + return self.value[self.pos :] + + def __add__(self, increment: int) -> "const_char_p": + # To avoid side effects, we create a new const_char_p object + new = self.__class__(self.value) + new.pos = self.pos + increment + return new + + def __sub__(self, decrement: 
int) -> "const_char_p": + # To avoid side effects, we create a new const_char_p object + new = self.__class__(self.value) + new.pos = self.pos - decrement + return new + + def __lt__(self, other: "const_char_p") -> bool: + return self.pos < other.pos and self.value == other.value + + def __gt__(self, other: "const_char_p") -> bool: + return self.pos > other.pos and self.value == other.value + + def __eq__(self, other: "const_char_p") -> bool: + return self.pos == other.pos and self.value == other.value + + def add(self, other: "const_char_p") -> int: + if self.value != other.value: + raise ValueError("Can't add pointers to different strings") + return self.pos + other.pos + + def sub(self, other: "const_char_p") -> int: + if self.value != other.value: + raise ValueError("Can't subtract pointers to different strings") + return self.pos - other.pos + + def plus_plus(self) -> None: + self.pos += 1 + + def minus_minus(self) -> None: + self.pos -= 1 + + @property + def derefer(self) -> Optional[str]: + if self.pos >= len(self.value): + # We've reached the end of the string + return None + + return self.value[self.pos] + + +class std__vector(Generic[T], List[T]): + """C++ implementation of std::vector.""" + + class iterator: + def __init__(self, vector: "std__vector[T]", index: int): + self._vector = vector + self._index = index + self._version = vector._version + + def _check_version(self): + if self._version != self._vector._version: + raise RuntimeError("Iterator used after vector was modified.") + + def __iter__(self): + return self + + def __next__(self) -> T: + self._check_version() + if self._index >= self._vector.size(): + raise StopIteration + value = self._vector[self._index] + self._index += 1 + return value + + def __add__(self, value: int) -> "std__vector[T].iterator": + return self.__class__(self._vector, self._index + value) + + def __sub__(self, value: int) -> "std__vector[T].iterator": + return self.__class__(self._vector, self._index - value) + + def 
__init__(self): + self._version = 0 + + def modify(self): + # This is a bit of a hack to make sure iterators are invalidated + self._version += 1 + + def push_back(self, value: T) -> None: + self.modify() + self.append(value) + + def pop_back(self) -> None: + self.modify() + if not self.empty(): + self.pop() + + def back(self) -> T: + return self[-1] + + def size(self) -> int: + return len(self) + + # def clear(self) -> None: + # super().clear() + + def empty(self) -> bool: + return self.size() == 0 + + def data(self) -> "std__vector[T]": + return self + + def resize( + self, + new_size: int, + fill_value_factory: Optional[Callable[[], T]] = None, + ) -> None: + if new_size > self.size(): + if fill_value_factory is None: + raise ValueError( + "A fill value factory function must be provided." + ) + self.reserve(new_size, fill_value_factory) + elif new_size < self.size(): + self[:] = self[:new_size] + + def reserve( + self, capacity: int, fill_value_factory: Callable[[], T] + ) -> None: + if capacity > self.size(): + fill_value = fill_value_factory() + self.extend([fill_value] * (capacity - self.size())) + + def front(self) -> T: + if not self.empty(): + return self[0] + else: + raise IndexError("Vector is empty.") + + def assign(self, count: int, value: T) -> None: + self.clear() + self.extend([value] * count) + + def insert( + self, + pos: "std__vector[T].iterator", + first: "std__vector[T].iterator", + last: "std__vector[T].iterator", + ) -> None: + self[pos._index : pos._index] = list( + islice(first._vector, first._index, last._index) + ) + + def begin(self) -> "std__vector[T].iterator": + return self.iterator(self, 0) + + def end(self) -> "std__vector[T].iterator": + return self.iterator(self, self.size()) + + +class std__map(Generic[T, U], OrderedDict[T, U]): + """C++ implementation of std::map.""" + + class iterator(Generic[V, W]): + def __init__(self, _map: "std__map[T, U]", key: Union[T, Sentinel]): + self._map = _map + self.iter = iter(_map) + self.key = 
key + self._advance() + + def _sanitize_key(self) -> T: + if isinstance(self.key, Sentinel): + raise StopIteration + return self.key + + def _advance(self) -> None: + try: + while next(self.iter) != self.key: + pass + except StopIteration: + self.key = Sentinel() + + def __next__(self) -> Tuple[T, U]: + key = self._sanitize_key() + if key in self._map: + value = self._map[key] + self._advance() + return key, value + else: + raise StopIteration + + def get(self) -> Tuple[T, U]: + key = self._sanitize_key() + return key, self._map[key] + + @property + def first(self) -> T: + return self._sanitize_key() + + @property + def second(self) -> U: + return self._map[self._sanitize_key()] + + def insert( + self, key: T, value: U + ) -> Tuple["std__map[T, U].iterator[T, U]", bool]: + if key in self: + return self.iterator(self, key), False + else: + self[key] = value + return self.iterator(self, key), True + + def find(self, key: T) -> "std__map[T, U].iterator[T, U]": + if key in self: + return self.iterator(self, key) + else: + return self.end() + + def at(self, key: T) -> U: + if key in self: + return self[key] + else: + raise KeyError("The provided key is not found in the map.") + + def erase(self, iterator: "std__map[T, U].iterator[T, U]") -> None: + key = iterator.first + if key in self: + del self[key] + + def size(self) -> int: + return len(self) + + def empty(self) -> bool: + return self.size() == 0 + + def lower_bound(self, key: T) -> "std__map[T, U].iterator[T, U]": + try: + keys = sorted(list(self.keys())) # type: ignore + for k in keys: + if k >= key: + return self.iterator(self, k) + raise ValueError( + "No key found that is not less than the input key" + ) + except TypeError: + raise TypeError("Keys of type T cannot be sorted.") + + def begin(self) -> "std__map[T, U].iterator[T, U]": + return self.iterator(self, next(iter(self))) + + def end(self) -> "std__map[T, U].iterator[T, U]": + return self.iterator(self, Sentinel()) + + +class std__string(str): + def 
__new__(cls, ptr: const_char_p, length: Optional[int] = None): + if length is not None: + return super().__new__(cls, str(ptr)[:length]) + return super().__new__(cls, str(ptr)) + + +# // grammar element type +# enum llama_gretype { +# // end of rule definition +# LLAMA_GRETYPE_END = 0, + +# // start of alternate definition for rule +# LLAMA_GRETYPE_ALT = 1, + +# // non-terminal element: reference to rule +# LLAMA_GRETYPE_RULE_REF = 2, + +# // terminal element: character (code point) +# LLAMA_GRETYPE_CHAR = 3, + +# // inverse char(s) ([^a], [^a-b] [^abc]) +# LLAMA_GRETYPE_CHAR_NOT = 4, + +# // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to +# // be an inclusive range ([a-z]) +# LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, + + +# // modifies a preceding LLAMA_GRETYPE_CHAR or +# // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) +# LLAMA_GRETYPE_CHAR_ALT = 6, +# }; +class llama_gretype(Enum): + """grammar element type""" + + LLAMA_GRETYPE_END = 0 # end of rule definition + LLAMA_GRETYPE_ALT = 1 # start of alternate definition for rule + LLAMA_GRETYPE_RULE_REF = 2 # non-terminal element: reference to rule + LLAMA_GRETYPE_CHAR = 3 # terminal element: character (code point) + LLAMA_GRETYPE_CHAR_NOT = 4 # inverse char(s) ([^a], [^a-b] [^abc]) + LLAMA_GRETYPE_CHAR_RNG_UPPER = 5 # modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to be an inclusive range ([a-z]) + LLAMA_GRETYPE_CHAR_ALT = 6 # modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) + + +# typedef struct llama_grammar_element { +# enum llama_gretype type; +# uint32_t value; // Unicode code point or rule ID +# } llama_grammar_element; + + +# class llama_grammar_element(Structure): +# _fields_ = [ +# ("type", c_int), +# ("value", c_uint32), +# ] + + +class llama_grammar_element: + def __init__(self, type: llama_gretype, value: uint32_t): + self.type = type + self.value = value # Unicode code 
point or rule ID + + def __repr__(self): # debug + return f"llama_grammar_element({self.type}, {self.value})" + + +# struct parse_state { +# std::map symbol_ids; +# std::vector> rules; +# std::vector c_rules(); +# }; +class parse_state: + def __init__(self): + self.symbol_ids: std__map[str, uint32_t] = std__map() + self.rules: std__vector[ + std__vector[llama_grammar_element] + ] = std__vector() + + # std::vector parse_state::c_rules() { + # std::vector ret; + # for (const auto & rule : rules) { + # ret.push_back(rule.data()); + # } + # return ret; + # } + def c_rules(self) -> std__vector[std__vector[llama_grammar_element]]: + ret = ( + std__vector() + ) # type: std__vector[std__vector[llama_grammar_element]] + for rule in self.rules: + ret.push_back(rule.data()) + return ret + + +# struct llama_grammar { +# const std::vector> rules; +# std::vector> stacks; +# }; +class llama_grammar: + def __init__( + self, + rules: std__vector[std__vector[llama_grammar_element]], + stacks: std__vector[std__vector[llama_grammar_element]], + ): + self.rules = rules + self.stacks = stacks + + +# uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) { +# uint32_t next_id = static_cast(state.symbol_ids.size()); +# auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id)); +# return result.first->second; +# } +def get_symbol_id(state: parse_state, src: const_char_p, len: size_t) -> int: + next_id = uint32_t(state.symbol_ids.size()) # type: uint32_t + result = state.symbol_ids.insert(str(std__string(src, len)), next_id) + return result[0].second # type: ignore + + +# uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) { +# uint32_t next_id = static_cast(state.symbol_ids.size()); +# state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id; +# return next_id; +# } +def generate_symbol_id(state: parse_state, base_name: str) -> uint32_t: + next_id = state.symbol_ids.size() # type: uint32_t + 
state.symbol_ids[base_name + "_" + str(next_id)] = next_id + return next_id + + +# void add_rule( +# parse_state & state, +# uint32_t rule_id, +# const std::vector & rule) { +# if (state.rules.size() <= rule_id) { +# state.rules.resize(rule_id + 1); +# } +# state.rules[rule_id] = rule; +# } +def add_rule( + state: parse_state, + rule_id: uint32_t, + rule: std__vector[llama_grammar_element], +) -> None: + if state.rules.size() <= rule_id: + state.rules.resize( + rule_id + 1, fill_value_factory=std__vector[llama_grammar_element] + ) + state.rules[rule_id] = rule + + +# std::pair decode_utf8(const char * src) { +# static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; +# uint8_t first_byte = static_cast(*src); +# uint8_t highbits = first_byte >> 4; +# int len = lookup[highbits]; +# uint8_t mask = (1 << (8 - len)) - 1; +# uint32_t value = first_byte & mask; +# const char * end = src + len; // may overrun! +# const char * pos = src + 1; +# for ( ; pos < end && *pos; pos++) { +# value = (value << 6) + (static_cast(*pos) & 0x3F); +# } +# return std::make_pair(value, pos); +# } +def decode_utf8(src: const_char_p) -> Tuple[uint32_t, const_char_p]: + """Decodes a UTF-8 character from the source string.""" + lookup = (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4) + first_byte = static_cast_uint8_t(src.derefer or "") # type: uint8_t + highbits = first_byte >> 4 # type: uint8_t + len = lookup[highbits] # type: int + mask = (1 << (8 - len)) - 1 # type: uint8_t + value = first_byte & mask # type: uint32_t + end = src + len # type: const_char_p # may overrun! 
+    pos = src + 1  # type: const_char_p
+    while pos < end and pos.derefer:
+        value = (value << 6) + (static_cast_uint8_t(pos.derefer or "") & 0x3F)
+        pos.plus_plus()
+    return value, pos
+
+
+# bool is_word_char(char c) {
+#     return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
+# }
+def is_word_char(c: str) -> bool:
+    return (
+        ("a" <= c <= "z") or ("A" <= c <= "Z") or c == "-" or ("0" <= c <= "9")
+    )
+
+
+# std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
+#     const char * pos = src;
+#     const char * end = src + size;
+#     uint32_t value = 0;
+#     for ( ; pos < end && *pos; pos++) {
+#         value <<= 4;
+#         char c = *pos;
+#         if ('a' <= c && c <= 'f') {
+#             value += c - 'a' + 10;
+#         } else if ('A' <= c && c <= 'F') {
+#             value += c - 'A' + 10;
+#         } else if ('0' <= c && c <= '9') {
+#             value += c - '0';
+#         } else {
+#             break;
+#         }
+#     }
+#     if (pos != end) {
+#         throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
+#     }
+#     return std::make_pair(value, pos);
+# }
+def parse_hex(src: const_char_p, size: int) -> Tuple[uint32_t, const_char_p]:
+    pos = const_char_p(src)  # type: const_char_p
+    end = src + size  # type: const_char_p
+    value = 0  # type: uint32_t
+    while pos < end and pos.derefer:
+        value <<= 4
+        c = pos.derefer  # type: str
+        if "a" <= c <= "f":
+            value += static_cast_uint8_t(c) - static_cast_uint8_t("a") + 10
+        elif "A" <= c <= "F":
+            value += static_cast_uint8_t(c) - static_cast_uint8_t("A") + 10
+        elif "0" <= c <= "9":
+            value += static_cast_uint8_t(c) - static_cast_uint8_t("0")
+        else:
+            break
+        pos.plus_plus()
+    if pos != end:
+        raise RuntimeError(
+            "expecting " + str(size) + " hex chars at " + str(src)
+        )
+    return (value, pos)
+
+
+# std::pair<uint32_t, const char *> parse_char(const char * src) {
+#     if (*src == '\\') {
+#         switch (src[1]) {
+#             case 'x': return parse_hex(src + 2, 2);
+#             case 'u': return parse_hex(src + 2, 4);
+#             case 'U': return parse_hex(src + 2, 8);
+#             case 't': return std::make_pair('\t', src + 2);
+#             
case 'r': return std::make_pair('\r', src + 2); +# case 'n': return std::make_pair('\n', src + 2); +# case '\\': +# case '"': +# case '[': +# case ']': +# return std::make_pair(src[1], src + 2); +# default: +# throw std::runtime_error(std::string("unknown escape at ") + src); +# } +# } else if (*src) { +# return decode_utf8(src); +# } +# throw std::runtime_error("unexpected end of input"); +# } +def parse_char(src: const_char_p) -> Tuple[uint32_t, const_char_p]: + if src.derefer == "\\": + switch = (src + 1).derefer # type: Optional[str] + if switch == "x": + return parse_hex(src + 2, 2) + elif switch == "u": + return parse_hex(src + 2, 4) + elif switch == "U": + return parse_hex(src + 2, 8) + elif switch == "t": + return (static_cast_uint8_t("\t"), src + 2) # implicit cast + elif switch == "r": + return (static_cast_uint8_t("\r"), src + 2) # implicit cast + elif switch == "n": + return (static_cast_uint8_t("\n"), src + 2) # implicit cast + elif switch in ("\\", '"', "[", "]"): + return (static_cast_uint8_t(switch), src + 2) # implicit cast + else: + raise RuntimeError("unknown escape at " + str(src)) + elif src.derefer: + return decode_utf8(src) + else: + raise RuntimeError("unexpected end of input") + + +# const char * parse_name(const char * src) { +# const char * pos = src; +# while (is_word_char(*pos)) { +# pos++; +# } +# if (pos == src) { +# throw std::runtime_error(std::string("expecting name at ") + src); +# } +# return pos; +# } +def parse_name(src: const_char_p) -> const_char_p: + pos = const_char_p(src) # type: const_char_p + while is_word_char(pos.derefer or ""): + pos.plus_plus() + if pos == src: + raise RuntimeError("expecting name at " + str(src)) + return pos + + +# const char * parse_space(const char * src, bool newline_ok) { +# const char * pos = src; +# while (*pos == ' ' || *pos == '\t' || *pos == '#' || +# (newline_ok && (*pos == '\r' || *pos == '\n'))) { +# if (*pos == '#') { +# while (*pos && *pos != '\r' && *pos != '\n') { +# pos++; +# } +# 
} else { +# pos++; +# } +# } +# return pos; +# } +def parse_space(src: const_char_p, newline_ok: bool) -> const_char_p: + # Using a copy of `src` to avoid side effects + pos = const_char_p(src) + + while pos.derefer in (" ", "\t", "#") or ( + newline_ok and pos.derefer in ("\r", "\n") + ): + if pos.derefer == "#": + while pos.derefer is not None and pos.derefer not in ("\r", "\n"): + pos.plus_plus() + else: + pos.plus_plus() + + return pos + + +# const char * parse_sequence( +# parse_state & state, +# const char * src, +# const std::string & rule_name, +# std::vector & out_elements, +# bool is_nested) { +def parse_sequence( + state: parse_state, + src: const_char_p, + rule_name: str, + out_elements: std__vector[llama_grammar_element], + is_nested: bool, +) -> const_char_p: + # size_t last_sym_start = out_elements.size(); + # const char * pos = src; + last_sym_start = out_elements.size() # type: size_t + pos = const_char_p(src) # type: const_char_p + # while (*pos) { + while pos.derefer: + # if (*pos == '"') { // literal string + # pos++; + # last_sym_start = out_elements.size(); + # while (*pos != '"') { + # auto char_pair = parse_char(pos); + # pos = char_pair.second; + # out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); + # } + # pos = parse_space(pos + 1, is_nested); + if pos.derefer == '"': # literal string + pos.plus_plus() + last_sym_start = out_elements.size() + while pos.derefer != '"': + char_pair = parse_char( + pos + ) # type: Tuple[uint32_t, const_char_p] + pos = char_pair[1] + out_elements.push_back( + llama_grammar_element( + llama_gretype.LLAMA_GRETYPE_CHAR, char_pair[0] + ) + ) + pos = parse_space(pos + 1, is_nested) + # } else if (*pos == '[') { // char range(s) + # pos++; + # enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; + elif pos.derefer == "[": # char range(s) + pos.plus_plus() + start_type = ( + llama_gretype.LLAMA_GRETYPE_CHAR + ) # type: llama_gretype + # if (*pos == '^') { + # pos++; + # start_type = 
LLAMA_GRETYPE_CHAR_NOT; + # } + # last_sym_start = out_elements.size(); + if pos.derefer == "^": + pos.plus_plus() + start_type = llama_gretype.LLAMA_GRETYPE_CHAR_NOT + last_sym_start = out_elements.size() + # while (*pos != ']') { + # auto char_pair = parse_char(pos); + # pos = char_pair.second; + # enum llama_gretype type = last_sym_start < out_elements.size() + # ? LLAMA_GRETYPE_CHAR_ALT + # : start_type; + # out_elements.push_back({type, char_pair.first}); + while pos.derefer != "]": + char_pair = parse_char( + pos + ) # type: Tuple[uint32_t, const_char_p] + pos = char_pair[1] + type = ( + llama_gretype.LLAMA_GRETYPE_CHAR_ALT + if last_sym_start < out_elements.size() + else start_type + ) # type: llama_gretype + out_elements.push_back( + llama_grammar_element(type, char_pair[0]) + ) + # if (pos[0] == '-' && pos[1] != ']') { + # auto endchar_pair = parse_char(pos + 1); + # pos = endchar_pair.second; + # out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); + # } + # } + if pos.derefer == "-" and (pos + 1).derefer != "]": + endchar_pair = parse_char( + pos + 1 + ) # type: Tuple[uint32_t, const_char_p] + pos = endchar_pair[1] + out_elements.push_back( + llama_grammar_element( + llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER, + endchar_pair[0], + ) + ) + # pos = parse_space(pos + 1, is_nested); + pos = parse_space(pos + 1, is_nested) + # } else if (is_word_char(*pos)) { // rule reference + # const char * name_end = parse_name(pos); + # uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos); + # pos = parse_space(name_end, is_nested); + # last_sym_start = out_elements.size(); + # out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); + elif is_word_char(pos.derefer): # rule reference + name_end = parse_name(pos) # type: const_char_p + ref_rule_id = get_symbol_id( + state, pos, name_end.sub(pos) + ) # type: uint32_t + pos = parse_space(name_end, is_nested) + last_sym_start = out_elements.size() + out_elements.push_back( + 
llama_grammar_element( + llama_gretype.LLAMA_GRETYPE_RULE_REF, ref_rule_id + ) + ) + # } else if (*pos == '(') { // grouping + # // parse nested alternates into synthesized rule + # pos = parse_space(pos + 1, true); + # uint32_t sub_rule_id = generate_symbol_id(state, rule_name); + # pos = parse_alternates(state, pos, rule_name, sub_rule_id, true); + # last_sym_start = out_elements.size(); + # // output reference to synthesized rule + # out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); + # if (*pos != ')') { + # throw std::runtime_error(std::string("expecting ')' at ") + pos); + # } + # pos = parse_space(pos + 1, is_nested); + elif pos.derefer == "(": # grouping + pos = parse_space(pos + 1, True) + sub_rule_id = generate_symbol_id( + state, rule_name + ) # type: uint32_t + pos = parse_alternates(state, pos, rule_name, sub_rule_id, True) + last_sym_start = out_elements.size() + out_elements.push_back( + llama_grammar_element( + llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id + ) + ) + if pos.derefer != ")": + raise RuntimeError("expecting ')' at " + str(pos)) + pos = parse_space(pos + 1, is_nested) + # } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator + # if (last_sym_start == out_elements.size()) { + # throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos); + # } + elif pos.derefer in ("*", "+", "?"): # repetition operator + if last_sym_start == out_elements.size(): + raise RuntimeError( + "expecting preceding item to */+/? at " + str(pos) + ) + # // apply transformation to previous symbol (last_sym_start to end) according to + # // rewrite rules: + # // S* --> S' ::= S S' | + # // S+ --> S' ::= S S' | S + # // S? 
--> S' ::= S | + # uint32_t sub_rule_id = generate_symbol_id(state, rule_name); + # std::vector sub_rule; + # // add preceding symbol to generated rule + # sub_rule.insert( + # sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); + sub_rule_id = generate_symbol_id( + state, rule_name + ) # type: uint32_t + sub_rule = std__vector[llama_grammar_element]() + sub_rule.insert( + sub_rule.end(), + out_elements.begin() + last_sym_start, + out_elements.end(), + ) + # if (*pos == '*' || *pos == '+') { + # // cause generated rule to recurse + # sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); + # } + # // mark start of alternate def + # sub_rule.push_back({LLAMA_GRETYPE_ALT, 0}); + if pos.derefer in ("*", "+"): + sub_rule.push_back( + llama_grammar_element( + llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id + ) + ) + sub_rule.push_back( + llama_grammar_element(llama_gretype.LLAMA_GRETYPE_ALT, 0) + ) + # if (*pos == '+') { + # // add preceding symbol as alternate only for '+' (otherwise empty) + # sub_rule.insert( + # sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); + # } + # sub_rule.push_back({LLAMA_GRETYPE_END, 0}); + # add_rule(state, sub_rule_id, sub_rule); + # // in original rule, replace previous symbol with reference to generated rule + # out_elements.resize(last_sym_start); + # out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); + # pos = parse_space(pos + 1, is_nested); + if pos.derefer == "+": + sub_rule.insert( + sub_rule.end(), + out_elements.begin() + last_sym_start, + out_elements.end(), + ) + sub_rule.push_back( + llama_grammar_element(llama_gretype.LLAMA_GRETYPE_END, 0) + ) + add_rule(state, sub_rule_id, sub_rule) + out_elements.resize(last_sym_start) + out_elements.push_back( + llama_grammar_element( + llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id + ) + ) + pos = parse_space(pos + 1, is_nested) + # } else { + # break; + # } + else: + break + # } + # return pos; + # } + return pos + + 
+# const char * parse_alternates( +# parse_state & state, +# const char * src, +# const std::string & rule_name, +# uint32_t rule_id, +# bool is_nested) { +# std::vector rule; +# const char * pos = parse_sequence(state, src, rule_name, rule, is_nested); +# while (*pos == '|') { +# rule.push_back({LLAMA_GRETYPE_ALT, 0}); +# pos = parse_space(pos + 1, true); +# pos = parse_sequence(state, pos, rule_name, rule, is_nested); +# } +# rule.push_back({LLAMA_GRETYPE_END, 0}); +# add_rule(state, rule_id, rule); +# return pos; +# } +def parse_alternates( + state: parse_state, + src: const_char_p, + rule_name: str, + rule_id: uint32_t, + is_nested: bool, +) -> const_char_p: + rule = std__vector() # type: std__vector[llama_grammar_element] + pos = parse_sequence( + state, src, rule_name, rule, is_nested + ) # type: const_char_p + while pos.derefer == "|": + rule.push_back( + llama_grammar_element(llama_gretype.LLAMA_GRETYPE_ALT, 0) + ) + pos = parse_space(pos + 1, True) + pos = parse_sequence(state, pos, rule_name, rule, is_nested) + rule.push_back(llama_grammar_element(llama_gretype.LLAMA_GRETYPE_END, 0)) + add_rule(state, rule_id, rule) + return pos + + +# const char * parse_rule(parse_state & state, const char * src) { +# const char * name_end = parse_name(src); +# const char * pos = parse_space(name_end, false); +# size_t name_len = name_end - src; +# uint32_t rule_id = get_symbol_id(state, src, name_len); +# const std::string name(src, name_len); + +# if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) { +# throw std::runtime_error(std::string("expecting ::= at ") + pos); +# } +# pos = parse_space(pos + 3, true); + +# pos = parse_alternates(state, pos, name, rule_id, false); + + +# if (*pos == '\r') { +# pos += pos[1] == '\n' ? 
2 : 1; +# } else if (*pos == '\n') { +# pos++; +# } else if (*pos) { +# throw std::runtime_error(std::string("expecting newline or end at ") + pos); +# } +# return parse_space(pos, true); +# } +def parse_rule(state: parse_state, src: const_char_p) -> const_char_p: + name_end = parse_name(src) # type: const_char_p + pos = parse_space(name_end, False) # type: const_char_p + name_len = name_end.sub(src) # type: size_t + rule_id = get_symbol_id(state, src, name_len) # type: uint32_t + name = std__string(src, name_len) # type: std__string + + if not ( + pos.derefer == ":" + and (pos + 1).derefer == ":" + and (pos + 2).derefer == "=" + ): + raise RuntimeError("expecting ::= at " + str(pos)) + + pos = parse_space(pos + 3, True) # type: const_char_p + pos = parse_alternates( + state, pos, name, rule_id, False + ) # type: const_char_p + + if pos.derefer == "\r": + pos += 2 if (pos + 1).derefer == "\n" else 1 + elif pos.derefer == "\n": + pos.plus_plus() + elif pos.derefer: + raise RuntimeError("expecting newline or end at " + str(pos)) + return parse_space(pos, True) + + +# parse_state parse(const char * src) { +# try { +# parse_state state; +# const char * pos = parse_space(src, true); +# while (*pos) { +# pos = parse_rule(state, pos); +# } +# return state; +# } catch (const std::exception & err) { +# fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); +# return parse_state(); +# } +# } +def parse(src: const_char_p) -> parse_state: + try: + state = parse_state() # type: parse_state + pos = parse_space(src, True) # type: const_char_p + while pos.derefer: + pos = parse_rule(state, pos) + return state + except Exception as err: + print(f"{parse.__name__}: error parsing grammar: {err}") + return parse_state() + + +# void print_grammar_char(FILE * file, uint32_t c) { +# if (0x20 <= c && c <= 0x7f) { +# fprintf(file, "%c", static_cast(c)); +# } else { +# // cop out of encoding UTF-8 +# fprintf(file, "", c); +# } +# } +def print_grammar_char(file: TextIO, c: 
uint32_t) -> None:
+    if 0x20 <= c and c <= 0x7F:
+        file.write(chr(c))
+    else:
+        # cop out of encoding UTF-8
+        file.write(f"<U+{c:04X}>")
+
+
+# bool is_char_element(llama_grammar_element elem) {
+#     switch (elem.type) {
+#         case LLAMA_GRETYPE_CHAR:           return true;
+#         case LLAMA_GRETYPE_CHAR_NOT:       return true;
+#         case LLAMA_GRETYPE_CHAR_ALT:       return true;
+#         case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
+#         default:                           return false;
+#     }
+# }
+def is_char_element(elem: llama_grammar_element) -> bool:
+    return elem.type in (
+        llama_gretype.LLAMA_GRETYPE_CHAR,
+        llama_gretype.LLAMA_GRETYPE_CHAR_NOT,
+        llama_gretype.LLAMA_GRETYPE_CHAR_ALT,
+        llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER,
+    )
+
+
+# void print_rule(
+#     FILE * file,
+#     uint32_t rule_id,
+#     const std::vector<llama_grammar_element> & rule,
+#     const std::map<uint32_t, std::string> & symbol_id_names) {
+def print_rule(
+    file: TextIO,
+    rule_id: uint32_t,
+    rule: std__vector[llama_grammar_element],
+    symbol_id_names: std__map[uint32_t, str],
+) -> None:
+    # if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
+    #     throw std::runtime_error(
+    #         "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
+    # }
+    # fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
+    if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END:
+        raise RuntimeError(
+            "malformed rule, does not end with LLAMA_GRETYPE_END: "
+            + str(rule_id)
+        )
+    print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ")
+    # for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
+    #     llama_grammar_element elem = rule[i];
+    #     switch (elem.type) {
+    #         case LLAMA_GRETYPE_END:
+    #             throw std::runtime_error(
+    #                 "unexpected end of rule: " + std::to_string(rule_id) + "," +
+    #                 std::to_string(i));
+    #         case LLAMA_GRETYPE_ALT:
+    #             fprintf(file, "| ");
+    #             break;
+    #         case LLAMA_GRETYPE_RULE_REF:
+    #             fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
+    #             break;
+    #         case LLAMA_GRETYPE_CHAR:
+    #             fprintf(file, "[");
+    #             print_grammar_char(file, elem.value);
+ # break; + # case LLAMA_GRETYPE_CHAR_NOT: + # fprintf(file, "[^"); + # print_grammar_char(file, elem.value); + # break; + # case LLAMA_GRETYPE_CHAR_RNG_UPPER: + # if (i == 0 || !is_char_element(rule[i - 1])) { + # throw std::runtime_error( + # "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " + + # std::to_string(rule_id) + "," + std::to_string(i)); + # } + # fprintf(file, "-"); + # print_grammar_char(file, elem.value); + # break; + # case LLAMA_GRETYPE_CHAR_ALT: + # if (i == 0 || !is_char_element(rule[i - 1])) { + # throw std::runtime_error( + # "LLAMA_GRETYPE_CHAR_ALT without preceding char: " + + # std::to_string(rule_id) + "," + std::to_string(i)); + # } + # print_grammar_char(file, elem.value); + # break; + # } + for i, elem in enumerate(rule[:-1]): + switch = elem.type # type: llama_gretype + if switch == llama_gretype.LLAMA_GRETYPE_END: + raise RuntimeError( + "unexpected end of rule: " + str(rule_id) + "," + str(i) + ) + elif switch == llama_gretype.LLAMA_GRETYPE_ALT: + print("| ", file=file, end="") + elif switch == llama_gretype.LLAMA_GRETYPE_RULE_REF: + print(f"{symbol_id_names.at(elem.value)} ", file=file, end="") + elif switch == llama_gretype.LLAMA_GRETYPE_CHAR: + print("[", file=file, end="") + print_grammar_char(file, elem.value) + elif switch == llama_gretype.LLAMA_GRETYPE_CHAR_NOT: + print("[^", file=file, end="") + print_grammar_char(file, elem.value) + elif switch == llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER: + if i == 0 or not is_char_element(rule[i - 1]): + raise RuntimeError( + "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " + + str(rule_id) + + "," + + str(i) + ) + print("-", file=file, end="") + print_grammar_char(file, elem.value) + elif switch == llama_gretype.LLAMA_GRETYPE_CHAR_ALT: + if i == 0 or not is_char_element(rule[i - 1]): + raise RuntimeError( + "LLAMA_GRETYPE_CHAR_ALT without preceding char: " + + str(rule_id) + + "," + + str(i) + ) + print_grammar_char(file, elem.value) + # if (is_char_element(elem)) { + # 
switch (rule[i + 1].type) { + # case LLAMA_GRETYPE_CHAR_ALT: + # case LLAMA_GRETYPE_CHAR_RNG_UPPER: + # break; + # default: + # fprintf(file, "] "); + if is_char_element(elem): + if rule[i + 1].type in ( + llama_gretype.LLAMA_GRETYPE_CHAR_ALT, + llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER, + ): + pass + else: + print("] ", file=file, end="") + # } + # } + # } + # fprintf(file, "\n"); + # } + print(file=file) + + +# void print_grammar(FILE * file, const parse_state & state) { +# try { +# std::map symbol_id_names; +# for (auto kv : state.symbol_ids) { +# symbol_id_names[kv.second] = kv.first; +# } +# for (size_t i = 0, end = state.rules.size(); i < end; i++) { +# // fprintf(file, "%zu: ", i); +# // print_rule_binary(file, state.rules[i]); +# print_rule(file, i, state.rules[i], symbol_id_names); +# // fprintf(file, "\n"); +# } +# } catch (const std::exception & err) { +# fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what()); +# } +# } +def print_grammar(file: TextIO, state: parse_state) -> None: + try: + symbol_id_names = std__map() # type: std__map[uint32_t, str] + for kv in state.symbol_ids.items(): + symbol_id_names[kv[1]] = kv[0] + + for i, rule in enumerate(state.rules): + print_rule(file, i, rule, symbol_id_names) + except Exception as err: + print( + f"{print_grammar.__name__}: error printing grammar: {err}", + file=sys.stderr, + ) + + +def convert_to_rules( + llama_grammar_elements: std__vector[std__vector[llama_grammar_element]], +) -> Array[llama_cpp.llama_grammar_element_p]: + """Make an Array object that is used for `llama_grammer_init`""" + + # Step 1: Convert to c_llama_grammar_element + llama_grammar_element_p_p = ( + [] + ) # type: List[List[llama_cpp.llama_grammar_element]] + for subvector in llama_grammar_elements: + llama_grammar_element_p_p.append([]) + for elem in subvector: + c_llama_grammar_element = llama_cpp.llama_grammar_element() + c_llama_grammar_element.type = c_int(elem.type.value) + c_llama_grammar_element.value = 
c_uint32(elem.value) + llama_grammar_element_p_p[-1].append(c_llama_grammar_element) + + # Step 2: Convert each list to llama_grammar_element array and get pointer + element_arrays = [ + (llama_cpp.llama_grammar_element * len(sublist))(*sublist) + for sublist in llama_grammar_element_p_p + ] # type: List[Array[llama_cpp.llama_grammar_element]] + + # Step 3: Get pointer of each array + element_array_pointers = [ + cast(sublist, llama_cpp.llama_grammar_element_p) + for sublist in element_arrays + ] # type: List[llama_cpp.llama_grammar_element_p] + + # Step 4: Make array of these pointers and get its pointer + return (llama_cpp.llama_grammar_element_p * len(element_array_pointers))( + *element_array_pointers + ) + + +def parse_grammar_init_args( + bnf: str, +) -> Tuple[Array[llama_cpp.llama_grammar_element_p], c_size_t, c_size_t]: + """Parse a GBNF string and return tuple of `grammar rules` and `root symbol id`""" + parsed_grammar = parse(const_char_p(bnf)) # type: parse_state + if parsed_grammar.rules.empty(): + raise Exception( + f"{parse_grammar_init_args.__name__}: error parsing grammar file: parsed_grammar.rules is empty" + ) + print(f"{parse_grammar_init_args.__name__} grammar:", file=sys.stderr) + print_grammar(sys.stdout, parsed_grammar) + print(file=sys.stderr) + grammar_rules = ( + parsed_grammar.c_rules() + ) # type: std__vector[std__vector[llama_grammar_element]] + return ( + convert_to_rules(grammar_rules), + c_size_t(grammar_rules.size()), + c_size_t(parsed_grammar.symbol_ids.at("root")), + ) + + +def parse_grammar_init_args_from_file( + bnf_path: Union[str, Path] +) -> Tuple[Array[llama_cpp.llama_grammar_element_p], c_size_t, c_size_t]: + """Parse a GBNF file and return tuple of `grammar rules` and `root symbol id`""" + try: + with open(bnf_path) as f: + params_grammer = f.read() + except Exception as err: + raise Exception( + f"{parse_grammar_init_args_from_file.__name__}: error reading grammar file: {err}" + ) + + if params_grammer: + return 
parse_grammar_init_args(params_grammer) + + raise Exception( + f"{parse_grammar_init_args_from_file.__name__}: error parsing grammar file: params_grammer is empty" + ) + + +# def get_grammar_p(bnf: str) -> llama_cpp.llama_grammar_p: +# """Parse a GBNF string and return pointer to `llama_grammar`""" + +# grammar_rules, root_symbol_id = parse_rules(bnf) + +# grammar_element_p_p = convert_to_double_ptr( +# grammar_rules +# ) # type: llama_cpp.llama_grammar_element_p_p + +# c_llama_grammar_p = llama_cpp.llama_grammar_init( +# grammar_element_p_p, +# c_size_t(grammar_rules.size()), +# c_size_t(root_symbol_id), +# ) # type: llama_cpp.llama_grammar_p +# return c_llama_grammar_p + + +# def get_grammar_p_from_file( +# bnf_path: Union[str, Path] +# ) -> llama_cpp.llama_grammar_p: +# """Parse a GBNF file and return pointer to `llama_grammar`""" +# try: +# with open(bnf_path) as f: +# params_grammer = f.read() +# except Exception as err: +# raise Exception( +# f"{get_grammar_p_from_file.__name__}: error reading grammar file: {err}" +# ) + +# if params_grammer: +# return get_grammar_p(params_grammer) + +# raise Exception( +# f"{get_grammar_p_from_file.__name__}: error parsing grammar file: params_grammer is empty" +# ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Generate C++ parser from GBNF grammar" + ) + parser.add_argument( + "-g", + "--grammar", + type=str, + default="./vendor/llama.cpp/grammars/json.gbnf", + help="path to GBNF grammar file", + ) + + args = parser.parse_args() + rules, n_rules, start_rule_index = parse_grammar_init_args_from_file( + args.grammar + ) + llama_grammar_p = llama_cpp.llama_grammar_init( + rules, + n_rules, + start_rule_index, + ) # type: llama_cpp.llama_grammar_p + + # ----- USAGE: + # llama_cpp.llama_sample_grammar(ctx=..., candidates=..., grammar=llama_grammar_p) + # llama_cpp.llama_grammar_accept_token(ctx=..., grammar=llama_grammar_p, token=...) 
+ + # ----- SAMPLE OUTPUT: + # main grammar: + # root ::= object + # object ::= [{] ws object_11 [}] ws + # value ::= object | array | string | number | value_6 ws + # array ::= [[] ws array_15 []] ws + # string ::= ["] string_18 ["] ws + # number ::= number_19 number_25 number_29 ws + # value_6 ::= [t] [r] [u] [e] | [f] [a] [l] [s] [e] | [n] [u] [l] [l] + # ws ::= ws_31 + # object_8 ::= string [:] ws value object_10 + # object_9 ::= [,] ws string [:] ws value + # object_10 ::= object_9 object_10 | + # object_11 ::= object_8 | + # array_12 ::= value array_14 + # array_13 ::= [,] ws value + # array_14 ::= array_13 array_14 | + # array_15 ::= array_12 | + # string_16 ::= [^"\] | [\] string_17 + # string_17 ::= ["\/bfnrt] | [u] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] + # string_18 ::= string_16 string_18 | + # number_19 ::= number_20 number_21 + # number_20 ::= [-] | + # number_21 ::= [0-9] | [1-9] number_22 + # number_22 ::= [0-9] number_22 | + # number_23 ::= [.] number_24 + # number_24 ::= [0-9] number_24 | [0-9] + # number_25 ::= number_23 | + # number_26 ::= [eE] number_27 number_28 + # number_27 ::= [-+] | + # number_28 ::= [0-9] number_28 | [0-9] + # number_29 ::= number_26 | + # ws_30 ::= [ ] ws + # ws_31 ::= ws_30 | From 418aa83b01bc7ea9fdd26546efb9d9899061cc4a Mon Sep 17 00:00:00 2001 From: c0sogi Date: Mon, 7 Aug 2023 02:21:37 +0900 Subject: [PATCH 2/4] Added grammar based sampling --- llama_cpp/llama.py | 36 +- llama_cpp/llama_grammar.py | 994 ++++++++++++++++++------------------- 2 files changed, 512 insertions(+), 518 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 66c76c9bf..ab99ee5bd 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -1,4 +1,5 @@ import os +from pathlib import Path import sys import uuid import time @@ -23,6 +24,7 @@ from . 
import llama_cpp from .llama_types import * +from .llama_grammar import LlamaGrammar import numpy as np import numpy.typing as npt @@ -223,6 +225,7 @@ def __init__( tensor_split: Optional[List[float]] = None, rope_freq_base: float = 10000.0, rope_freq_scale: float = 1.0, + grammar: Optional[Union[str, Path]] = None, n_gqa: Optional[int] = None, # (TEMPORARY) must be 8 for llama2 70b rms_norm_eps: Optional[float] = None, # (TEMPORARY) verbose: bool = True, @@ -248,6 +251,7 @@ def __init__( tensor_split: List of floats to split the model across multiple GPUs. If None, the model is not split. rope_freq_base: Base frequency for rope sampling. rope_freq_scale: Scale factor for rope sampling. + grammar: Path to a BNF grammar file to use for grammar based sampling. verbose: Print verbose output to stderr. Raises: @@ -358,6 +362,12 @@ def __init__( self.scores: npt.NDArray[np.single] = np.ndarray( (n_ctx, self._n_vocab), dtype=np.single ) + if grammar is not None: + self.grammar = LlamaGrammar.from_file( + grammar + ) # type: Optional[LlamaGrammar] + else: + self.grammar = None @property def _input_ids(self) -> npt.NDArray[np.intc]: @@ -542,8 +552,16 @@ def _sample( ) if not penalize_nl: candidates.data[self._token_nl].logit = llama_cpp.c_float(nl_logit) + + if self.grammar is not None: + llama_cpp.llama_sample_grammar( + ctx=self.ctx, + candidates=llama_cpp.ctypes.byref(candidates), # type: ignore + grammar=self.grammar.grammar, + ) + if temp.value == 0.0: - return llama_cpp.llama_sample_token_greedy( + id = llama_cpp.llama_sample_token_greedy( ctx=self.ctx, candidates=llama_cpp.ctypes.byref(candidates), # type: ignore ) @@ -555,7 +573,7 @@ def _sample( candidates=llama_cpp.ctypes.byref(candidates), # type: ignore temp=temp, ) - return llama_cpp.llama_sample_token_mirostat( + id = llama_cpp.llama_sample_token_mirostat( ctx=self.ctx, candidates=llama_cpp.ctypes.byref(candidates), # type: ignore tau=mirostat_tau, @@ -570,7 +588,7 @@ def _sample( 
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore temp=temp, ) - return llama_cpp.llama_sample_token_mirostat_v2( + id = llama_cpp.llama_sample_token_mirostat_v2( ctx=self.ctx, candidates=llama_cpp.ctypes.byref(candidates), # type: ignore tau=mirostat_tau, @@ -607,10 +625,17 @@ def _sample( candidates=llama_cpp.ctypes.byref(candidates), # type: ignore temp=temp, ) - return llama_cpp.llama_sample_token( + id = llama_cpp.llama_sample_token( ctx=self.ctx, candidates=llama_cpp.ctypes.byref(candidates), # type: ignore ) + if self.grammar is not None: + llama_cpp.llama_grammar_accept_token( + ctx=self.ctx, + grammar=self.grammar.grammar, + token=llama_cpp.ctypes.c_int(id), + ) + return id def sample( self, @@ -1509,6 +1534,9 @@ def __del__(self): if self.ctx is not None: llama_cpp.llama_free(self.ctx) self.ctx = None + if self.grammar is not None: + llama_cpp.llama_grammar_free(self.grammar.grammar) + self.grammar = None def __getstate__(self): return dict( diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index 07a120f2c..06b2b7ff2 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -3,7 +3,7 @@ import argparse from pathlib import Path import sys -from ctypes import Array, c_int, c_size_t, c_uint32, cast +from ctypes import * # type: ignore from enum import Enum from itertools import islice from typing import ( @@ -16,293 +16,379 @@ Tuple, TypeVar, Union, + overload, ) import llama_cpp +# Type aliases +llama_grammar_element = llama_cpp.llama_grammar_element +llama_grammar_element_p = llama_cpp.llama_grammar_element_p +llama_grammar_p = llama_cpp.llama_grammar_p + +# Type variables +Ptr = TypeVar("Ptr", bound="const_char_p") T = TypeVar("T") U = TypeVar("U") V = TypeVar("V") W = TypeVar("W") -size_t = uint8_t = uint32_t = int -static_cast_uint8_t = ord class Sentinel: - pass + """Used to mark the end of a iterator of std::vector & std::map.""" + + +class LlamaGrammar: + """Keeps reference counts of all the arguments, 
so that they are not + garbage collected by Python.""" + + def __init__( + self, + parsed_grammar: "parse_state", + ) -> None: + grammar_rules = ( + parsed_grammar.c_rules() + ) # type: std.vector[std.vector[llama_grammar_element]] + + # Step 1: Convert each list to llama_grammar_element array and get pointer + self.element_arrays = [ + (llama_grammar_element * len(sublist))(*sublist) + for sublist in grammar_rules + ] # type: List[Array[llama_grammar_element]] + + # Step 2: Get pointer of each array + self.element_array_pointers = [ + cast(subarray, llama_grammar_element_p) + for subarray in self.element_arrays + ] # type: List[llama_grammar_element_p] + + # Step 3: Make array of these pointers and get its pointer + self.rules = ( + llama_grammar_element_p * len(self.element_array_pointers) + )(*self.element_array_pointers) + + self.n_rules = c_size_t(grammar_rules.size()) + self.start_rule_index = c_size_t(parsed_grammar.symbol_ids.at("root")) + self.grammar = self.init_grammar() + + @classmethod + def from_string(cls, grammar: str) -> "LlamaGrammar": + parsed_grammar = parse(const_char_p(grammar)) # type: parse_state + if parsed_grammar.rules.empty(): + raise ValueError( + f"{cls.from_string.__name__}: error parsing grammar file: parsed_grammar.rules is empty" + ) + print(f"{cls.from_string.__name__} grammar:", file=sys.stderr) + print_grammar(sys.stdout, parsed_grammar) + print(file=sys.stderr) + return cls(parsed_grammar) + + @classmethod + def from_file(cls, file: Union[str, Path]) -> "LlamaGrammar": + try: + with open(file) as f: + grammar = f.read() + except Exception as err: + raise Exception( + f"{cls.from_file.__name__}: error reading grammar file: {err}" + ) + + if grammar: + return cls.from_string(grammar) + + raise ValueError( + f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty" + ) + + def init_grammar(self) -> llama_grammar_p: + return llama_cpp.llama_grammar_init( + self.rules, self.n_rules, self.start_rule_index + ) 
class const_char_p: - """C++ implementation of const char*.""" + """C++ implementation of const char *.""" - def __init__(self, value: Union[str, "const_char_p"]): + def __init__(self, value: Union[str, Ptr], move: Optional[int] = None): if isinstance(value, const_char_p): # We're copying an existing const_char_p self.value = value.value - self.pos = value.pos + self.pos = value.pos + (move or 0) return # We're creating a new const_char_p self.value = value - self.pos = 0 + self.pos = move or 0 def __str__(self) -> str: + assert self.value is not None, "null pointer" return self.value[self.pos :] - def __add__(self, increment: int) -> "const_char_p": - # To avoid side effects, we create a new const_char_p object - new = self.__class__(self.value) - new.pos = self.pos + increment - return new - - def __sub__(self, decrement: int) -> "const_char_p": - # To avoid side effects, we create a new const_char_p object - new = self.__class__(self.value) - new.pos = self.pos - decrement - return new - - def __lt__(self, other: "const_char_p") -> bool: - return self.pos < other.pos and self.value == other.value - - def __gt__(self, other: "const_char_p") -> bool: - return self.pos > other.pos and self.value == other.value - - def __eq__(self, other: "const_char_p") -> bool: - return self.pos == other.pos and self.value == other.value - - def add(self, other: "const_char_p") -> int: - if self.value != other.value: - raise ValueError("Can't add pointers to different strings") - return self.pos + other.pos - - def sub(self, other: "const_char_p") -> int: - if self.value != other.value: - raise ValueError("Can't subtract pointers to different strings") - return self.pos - other.pos - - def plus_plus(self) -> None: - self.pos += 1 - - def minus_minus(self) -> None: - self.pos -= 1 - - @property - def derefer(self) -> Optional[str]: - if self.pos >= len(self.value): - # We've reached the end of the string - return None - - return self.value[self.pos] - + def __getitem__(self, index: 
int) -> str: + value = str(self) + return value[index] if index < len(value) else "" -class std__vector(Generic[T], List[T]): - """C++ implementation of std::vector.""" + @overload + def __add__(self: Ptr, other: int) -> Ptr: + ... - class iterator: - def __init__(self, vector: "std__vector[T]", index: int): - self._vector = vector - self._index = index - self._version = vector._version + @overload + def __add__(self: Ptr, other: Ptr) -> int: + ... - def _check_version(self): - if self._version != self._vector._version: - raise RuntimeError("Iterator used after vector was modified.") - - def __iter__(self): - return self + def __add__(self: Ptr, other: Union[int, Ptr]) -> Union[int, Ptr]: + return ( + self.__class__(self.value, self.pos + other) + if isinstance(other, int) + else self.pos + other.pos + ) - def __next__(self) -> T: - self._check_version() - if self._index >= self._vector.size(): - raise StopIteration - value = self._vector[self._index] - self._index += 1 - return value + @overload + def __sub__(self: Ptr, other: int) -> Ptr: + ... - def __add__(self, value: int) -> "std__vector[T].iterator": - return self.__class__(self._vector, self._index + value) + @overload + def __sub__(self: Ptr, other: Ptr) -> int: + ... 
- def __sub__(self, value: int) -> "std__vector[T].iterator": - return self.__class__(self._vector, self._index - value) + def __sub__(self: Ptr, other: Union[int, Ptr]) -> Union[int, Ptr]: + return ( + self.__class__(self.value, self.pos - other) + if isinstance(other, int) + else self.pos - other.pos + ) - def __init__(self): - self._version = 0 + def __eq__(self: Ptr, other: Ptr) -> bool: + assert ( + self.value == other.value + ), "comparing pointers from different strings" + return self.pos == other.pos + + def __lt__(self: Ptr, other: Ptr) -> bool: + assert ( + self.value == other.value + ), "comparing pointers from different strings" + return self.pos < other.pos + + def __gt__(self: Ptr, other: Ptr) -> bool: + assert ( + self.value == other.value + ), "comparing pointers from different strings" + return self.pos > other.pos + + +class std: + @staticmethod + def string(ptr: const_char_p, length: Optional[int] = None) -> str: + """C++ implementation of std::string constructor.""" + value = str(ptr) + if length is not None: + value = value[:length] + return value + + class vector(Generic[T], List[T]): + """C++ implementation of std::vector.""" + + class iterator: + def __init__(self, vector: "std.vector[T]", index: int): + self._vector = vector + self._index = index + self._version = vector._version + + def _check_version(self): + if self._version != self._vector._version: + raise RuntimeError( + "Iterator used after vector was modified." 
+ ) - def modify(self): - # This is a bit of a hack to make sure iterators are invalidated - self._version += 1 + def __iter__(self): + return self - def push_back(self, value: T) -> None: - self.modify() - self.append(value) + def __next__(self) -> T: + self._check_version() + if self._index >= self._vector.size(): + raise StopIteration + value = self._vector[self._index] + self._index += 1 + return value - def pop_back(self) -> None: - self.modify() - if not self.empty(): - self.pop() + def __add__(self, value: int) -> "std.vector[T].iterator": + return self.__class__(self._vector, self._index + value) - def back(self) -> T: - return self[-1] + def __sub__(self, value: int) -> "std.vector[T].iterator": + return self.__class__(self._vector, self._index - value) - def size(self) -> int: - return len(self) + def __init__(self): + self._version = 0 - # def clear(self) -> None: - # super().clear() + def modify(self): + # This is a bit of a hack to make sure iterators are invalidated + self._version += 1 - def empty(self) -> bool: - return self.size() == 0 + def push_back(self, value: T) -> None: + self.modify() + self.append(value) - def data(self) -> "std__vector[T]": - return self + def pop_back(self) -> None: + self.modify() + if not self.empty(): + self.pop() - def resize( - self, - new_size: int, - fill_value_factory: Optional[Callable[[], T]] = None, - ) -> None: - if new_size > self.size(): - if fill_value_factory is None: - raise ValueError( - "A fill value factory function must be provided." 
- ) - self.reserve(new_size, fill_value_factory) - elif new_size < self.size(): - self[:] = self[:new_size] + def back(self) -> T: + return self[-1] - def reserve( - self, capacity: int, fill_value_factory: Callable[[], T] - ) -> None: - if capacity > self.size(): - fill_value = fill_value_factory() - self.extend([fill_value] * (capacity - self.size())) + def size(self) -> int: + return len(self) - def front(self) -> T: - if not self.empty(): - return self[0] - else: - raise IndexError("Vector is empty.") + def clear(self) -> None: + self.modify() + super().clear() - def assign(self, count: int, value: T) -> None: - self.clear() - self.extend([value] * count) + def empty(self) -> bool: + return self.size() == 0 - def insert( - self, - pos: "std__vector[T].iterator", - first: "std__vector[T].iterator", - last: "std__vector[T].iterator", - ) -> None: - self[pos._index : pos._index] = list( - islice(first._vector, first._index, last._index) - ) - - def begin(self) -> "std__vector[T].iterator": - return self.iterator(self, 0) - - def end(self) -> "std__vector[T].iterator": - return self.iterator(self, self.size()) - - -class std__map(Generic[T, U], OrderedDict[T, U]): - """C++ implementation of std::map.""" - - class iterator(Generic[V, W]): - def __init__(self, _map: "std__map[T, U]", key: Union[T, Sentinel]): - self._map = _map - self.iter = iter(_map) - self.key = key - self._advance() - - def _sanitize_key(self) -> T: - if isinstance(self.key, Sentinel): - raise StopIteration - return self.key + def data(self) -> "std.vector[T]": + return self - def _advance(self) -> None: - try: - while next(self.iter) != self.key: - pass - except StopIteration: - self.key = Sentinel() - - def __next__(self) -> Tuple[T, U]: - key = self._sanitize_key() - if key in self._map: - value = self._map[key] - self._advance() - return key, value + def resize( + self, + new_size: int, + fill_value_factory: Optional[Callable[[], T]] = None, + ) -> None: + if new_size > self.size(): + if 
fill_value_factory is None: + raise ValueError( + "A fill value factory function must be provided." + ) + self.reserve(new_size, fill_value_factory) + elif new_size < self.size(): + self[:] = self[:new_size] + + def reserve( + self, capacity: int, fill_value_factory: Callable[[], T] + ) -> None: + if capacity > self.size(): + fill_value = fill_value_factory() + self.extend([fill_value] * (capacity - self.size())) + + def front(self) -> T: + if not self.empty(): + return self[0] else: - raise StopIteration - - def get(self) -> Tuple[T, U]: - key = self._sanitize_key() - return key, self._map[key] + raise IndexError("Vector is empty.") + + def assign(self, count: int, value: T) -> None: + self.clear() + self.extend([value] * count) + + def insert( + self, + pos: "std.vector[T].iterator", + first: "std.vector[T].iterator", + last: "std.vector[T].iterator", + ) -> None: + self[pos._index : pos._index] = list( + islice(first._vector, first._index, last._index) + ) - @property - def first(self) -> T: - return self._sanitize_key() + def begin(self) -> "std.vector[T].iterator": + return self.iterator(self, 0) - @property - def second(self) -> U: - return self._map[self._sanitize_key()] + def end(self) -> "std.vector[T].iterator": + return self.iterator(self, self.size()) - def insert( - self, key: T, value: U - ) -> Tuple["std__map[T, U].iterator[T, U]", bool]: - if key in self: - return self.iterator(self, key), False - else: - self[key] = value - return self.iterator(self, key), True + class map(Generic[T, U], OrderedDict[T, U]): + """C++ implementation of std::map.""" - def find(self, key: T) -> "std__map[T, U].iterator[T, U]": - if key in self: - return self.iterator(self, key) - else: - return self.end() + class iterator(Generic[V, W]): + def __init__(self, _map: "std.map[T, U]", key: Union[T, Sentinel]): + self._map = _map + self.iter = iter(_map) + self.key = key + self._advance() - def at(self, key: T) -> U: - if key in self: - return self[key] - else: - raise 
KeyError("The provided key is not found in the map.") + def _sanitize_key(self) -> T: + if isinstance(self.key, Sentinel): + raise StopIteration + return self.key + + def _advance(self) -> None: + try: + while next(self.iter) != self.key: + pass + except StopIteration: + self.key = Sentinel() + + def __next__(self) -> Tuple[T, U]: + key = self._sanitize_key() + if key in self._map: + value = self._map[key] + self._advance() + return key, value + else: + raise StopIteration + + def get(self) -> Tuple[T, U]: + key = self._sanitize_key() + return key, self._map[key] + + @property + def first(self) -> T: + return self._sanitize_key() + + @property + def second(self) -> U: + return self._map[self._sanitize_key()] + + def insert( + self, key: T, value: U + ) -> Tuple["std.map[T, U].iterator[T, U]", bool]: + if key in self: + return self.iterator(self, key), False + else: + self[key] = value + return self.iterator(self, key), True - def erase(self, iterator: "std__map[T, U].iterator[T, U]") -> None: - key = iterator.first - if key in self: - del self[key] + def find(self, key: T) -> "std.map[T, U].iterator[T, U]": + if key in self: + return self.iterator(self, key) + else: + return self.end() - def size(self) -> int: - return len(self) + def at(self, key: T) -> U: + if key in self: + return self[key] + else: + raise KeyError("The provided key is not found in the map.") - def empty(self) -> bool: - return self.size() == 0 + def erase(self, iterator: "std.map[T, U].iterator[T, U]") -> None: + key = iterator.first + if key in self: + del self[key] - def lower_bound(self, key: T) -> "std__map[T, U].iterator[T, U]": - try: - keys = sorted(list(self.keys())) # type: ignore - for k in keys: - if k >= key: - return self.iterator(self, k) - raise ValueError( - "No key found that is not less than the input key" - ) - except TypeError: - raise TypeError("Keys of type T cannot be sorted.") + def size(self) -> int: + return len(self) - def begin(self) -> "std__map[T, U].iterator[T, 
U]": - return self.iterator(self, next(iter(self))) + def empty(self) -> bool: + return self.size() == 0 - def end(self) -> "std__map[T, U].iterator[T, U]": - return self.iterator(self, Sentinel()) + def lower_bound(self, key: T) -> "std.map[T, U].iterator[T, U]": + try: + keys = sorted(list(self.keys())) # type: ignore + for k in keys: + if k >= key: + return self.iterator(self, k) + raise ValueError( + "No key found that is not less than the input key" + ) + except TypeError: + raise TypeError("Keys of type T cannot be sorted.") + def begin(self) -> "std.map[T, U].iterator[T, U]": + return self.iterator(self, next(iter(self))) -class std__string(str): - def __new__(cls, ptr: const_char_p, length: Optional[int] = None): - if length is not None: - return super().__new__(cls, str(ptr)[:length]) - return super().__new__(cls, str(ptr)) + def end(self) -> "std.map[T, U].iterator[T, U]": + return self.iterator(self, Sentinel()) # // grammar element type @@ -343,28 +429,6 @@ class llama_gretype(Enum): LLAMA_GRETYPE_CHAR_ALT = 6 # modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) -# typedef struct llama_grammar_element { -# enum llama_gretype type; -# uint32_t value; // Unicode code point or rule ID -# } llama_grammar_element; - - -# class llama_grammar_element(Structure): -# _fields_ = [ -# ("type", c_int), -# ("value", c_uint32), -# ] - - -class llama_grammar_element: - def __init__(self, type: llama_gretype, value: uint32_t): - self.type = type - self.value = value # Unicode code point or rule ID - - def __repr__(self): # debug - return f"llama_grammar_element({self.type}, {self.value})" - - # struct parse_state { # std::map symbol_ids; # std::vector> rules; @@ -372,10 +436,10 @@ def __repr__(self): # debug # }; class parse_state: def __init__(self): - self.symbol_ids: std__map[str, uint32_t] = std__map() - self.rules: std__vector[ - std__vector[llama_grammar_element] - ] = std__vector() + 
self.symbol_ids: std.map[str, int] = std.map() + self.rules: std.vector[ + std.vector[llama_grammar_element] + ] = std.vector() # std::vector parse_state::c_rules() { # std::vector ret; @@ -384,27 +448,30 @@ def __init__(self): # } # return ret; # } - def c_rules(self) -> std__vector[std__vector[llama_grammar_element]]: + def c_rules(self) -> std.vector[std.vector[llama_grammar_element]]: ret = ( - std__vector() - ) # type: std__vector[std__vector[llama_grammar_element]] + std.vector() + ) # type: std.vector[std.vector[llama_grammar_element]] for rule in self.rules: ret.push_back(rule.data()) return ret + def __repr__(self) -> str: + return f"parse_state(symbol_ids={len(self.symbol_ids)}, rules={len(self.rules)})" + # struct llama_grammar { # const std::vector> rules; # std::vector> stacks; # }; -class llama_grammar: - def __init__( - self, - rules: std__vector[std__vector[llama_grammar_element]], - stacks: std__vector[std__vector[llama_grammar_element]], - ): - self.rules = rules - self.stacks = stacks +# class llama_grammar: +# def __init__( +# self, +# rules: std.vector[std.vector[llama_grammar_element]], +# stacks: std.vector[std.vector[llama_grammar_element]], +# ): +# self.rules = rules +# self.stacks = stacks # uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) { @@ -412,9 +479,9 @@ def __init__( # auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id)); # return result.first->second; # } -def get_symbol_id(state: parse_state, src: const_char_p, len: size_t) -> int: - next_id = uint32_t(state.symbol_ids.size()) # type: uint32_t - result = state.symbol_ids.insert(str(std__string(src, len)), next_id) +def get_symbol_id(state: parse_state, src: const_char_p, len: int) -> int: + next_id = state.symbol_ids.size() # type: int + result = state.symbol_ids.insert(std.string(src, len), next_id) return result[0].second # type: ignore @@ -423,8 +490,8 @@ def get_symbol_id(state: parse_state, src: const_char_p, len: 
size_t) -> int: # state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id; # return next_id; # } -def generate_symbol_id(state: parse_state, base_name: str) -> uint32_t: - next_id = state.symbol_ids.size() # type: uint32_t +def generate_symbol_id(state: parse_state, base_name: str) -> int: + next_id = state.symbol_ids.size() # type: int state.symbol_ids[base_name + "_" + str(next_id)] = next_id return next_id @@ -440,12 +507,13 @@ def generate_symbol_id(state: parse_state, base_name: str) -> uint32_t: # } def add_rule( state: parse_state, - rule_id: uint32_t, - rule: std__vector[llama_grammar_element], + rule_id: int, + rule: std.vector[llama_grammar_element], ) -> None: if state.rules.size() <= rule_id: state.rules.resize( - rule_id + 1, fill_value_factory=std__vector[llama_grammar_element] + rule_id + 1, + fill_value_factory=std.vector[llama_grammar_element], ) state.rules[rule_id] = rule @@ -464,19 +532,19 @@ def add_rule( # } # return std::make_pair(value, pos); # } -def decode_utf8(src: const_char_p) -> Tuple[uint32_t, const_char_p]: +def decode_utf8(src: const_char_p) -> Tuple[int, const_char_p]: """Decodes a UTF-8 character from the source string.""" lookup = (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4) - first_byte = static_cast_uint8_t(src.derefer or "") # type: uint8_t - highbits = first_byte >> 4 # type: uint8_t + first_byte = ord(src[0]) # type: int + highbits = first_byte >> 4 # type: int len = lookup[highbits] # type: int - mask = (1 << (8 - len)) - 1 # type: uint8_t - value = first_byte & mask # type: uint32_t + mask = (1 << (8 - len)) - 1 # type: int + value = first_byte & mask # type: int end = src + len # type: const_char_p # may overrun! 
pos = src + 1 # type: const_char_p - while pos < end and pos.derefer: - value = (value << 6) + (static_cast_uint8_t(src.derefer or "") & 0x3F) - pos.plus_plus() + while pos < end and pos[0]: + value = (value << 6) + (ord(pos[0]) & 0x3F) + pos += 1 return value, pos @@ -511,22 +579,22 @@ def is_word_char(c: str) -> bool: # } # return std::make_pair(value, pos); # } -def parse_hex(src: const_char_p, size: int) -> Tuple[uint32_t, const_char_p]: +def parse_hex(src: const_char_p, size: int) -> Tuple[int, const_char_p]: pos = const_char_p(src) # type: const_char_p end = src + size # type: const_char_p - value = 0 # type: uint32_t - while pos < end and pos.derefer: + value = 0 # type: int + while pos < end and pos[0]: value <<= 4 - c = pos.derefer # type: str + c = pos[0] # type: str if "a" <= c <= "f": - value += static_cast_uint8_t(c) - static_cast_uint8_t("a") + 10 + value += ord(c) - ord("a") + 10 elif "A" <= c <= "F": - value += static_cast_uint8_t(c) - static_cast_uint8_t("A") + 10 + value += ord(c) - ord("A") + 10 elif "0" <= c <= "9": - value += static_cast_uint8_t(c) - static_cast_uint8_t("0") + value += ord(c) - ord("0") else: break - pos.plus_plus() + pos += 1 if pos != end: raise RuntimeError( "expecting " + str(size) + " hex chars at " + str(src) @@ -556,26 +624,26 @@ def parse_hex(src: const_char_p, size: int) -> Tuple[uint32_t, const_char_p]: # } # throw std::runtime_error("unexpected end of input"); # } -def parse_char(src: const_char_p) -> Tuple[uint32_t, const_char_p]: - if src.derefer == "\\": - switch = (src + 1).derefer # type: Optional[str] - if switch == "x": +def parse_char(src: const_char_p) -> Tuple[int, const_char_p]: + if src[0] == "\\": + case = src[1] # type: str + if case == "x": return parse_hex(src + 2, 2) - elif switch == "u": + elif case == "u": return parse_hex(src + 2, 4) - elif switch == "U": + elif case == "U": return parse_hex(src + 2, 8) - elif switch == "t": - return (static_cast_uint8_t("\t"), src + 2) # implicit cast - elif 
switch == "r": - return (static_cast_uint8_t("\r"), src + 2) # implicit cast - elif switch == "n": - return (static_cast_uint8_t("\n"), src + 2) # implicit cast - elif switch in ("\\", '"', "[", "]"): - return (static_cast_uint8_t(switch), src + 2) # implicit cast + elif case == "t": + return (ord("\t"), src + 2) # implicit cast + elif case == "r": + return (ord("\r"), src + 2) # implicit cast + elif case == "n": + return (ord("\n"), src + 2) # implicit cast + elif case in ("\\", '"', "[", "]"): + return (ord(case), src + 2) # implicit cast else: raise RuntimeError("unknown escape at " + str(src)) - elif src.derefer: + elif src[0]: return decode_utf8(src) else: raise RuntimeError("unexpected end of input") @@ -593,8 +661,8 @@ def parse_char(src: const_char_p) -> Tuple[uint32_t, const_char_p]: # } def parse_name(src: const_char_p) -> const_char_p: pos = const_char_p(src) # type: const_char_p - while is_word_char(pos.derefer or ""): - pos.plus_plus() + while is_word_char(pos[0]): + pos += 1 if pos == src: raise RuntimeError("expecting name at " + str(src)) return pos @@ -615,18 +683,15 @@ def parse_name(src: const_char_p) -> const_char_p: # return pos; # } def parse_space(src: const_char_p, newline_ok: bool) -> const_char_p: - # Using a copy of `src` to avoid side effects - pos = const_char_p(src) - - while pos.derefer in (" ", "\t", "#") or ( - newline_ok and pos.derefer in ("\r", "\n") + pos = const_char_p(src) # type: const_char_p + while pos[0] in (" ", "\t", "#") or ( + newline_ok and pos[0] in ("\r", "\n") ): - if pos.derefer == "#": - while pos.derefer is not None and pos.derefer not in ("\r", "\n"): - pos.plus_plus() + if pos[0] == "#": + while pos[0] is not None and pos[0] not in ("\r", "\n"): + pos += 1 else: - pos.plus_plus() - + pos += 1 return pos @@ -640,15 +705,15 @@ def parse_sequence( state: parse_state, src: const_char_p, rule_name: str, - out_elements: std__vector[llama_grammar_element], + out_elements: std.vector[llama_grammar_element], is_nested: 
bool, ) -> const_char_p: # size_t last_sym_start = out_elements.size(); # const char * pos = src; - last_sym_start = out_elements.size() # type: size_t + last_sym_start = out_elements.size() # type: int pos = const_char_p(src) # type: const_char_p # while (*pos) { - while pos.derefer: + while pos[0]: # if (*pos == '"') { // literal string # pos++; # last_sym_start = out_elements.size(); @@ -658,25 +723,23 @@ def parse_sequence( # out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); # } # pos = parse_space(pos + 1, is_nested); - if pos.derefer == '"': # literal string - pos.plus_plus() + if pos[0] == '"': # literal string + pos += 1 last_sym_start = out_elements.size() - while pos.derefer != '"': - char_pair = parse_char( - pos - ) # type: Tuple[uint32_t, const_char_p] + while pos[0] != '"': + char_pair = parse_char(pos) # type: Tuple[int, const_char_p] pos = char_pair[1] out_elements.push_back( llama_grammar_element( - llama_gretype.LLAMA_GRETYPE_CHAR, char_pair[0] + llama_gretype.LLAMA_GRETYPE_CHAR.value, char_pair[0] ) ) pos = parse_space(pos + 1, is_nested) # } else if (*pos == '[') { // char range(s) # pos++; # enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; - elif pos.derefer == "[": # char range(s) - pos.plus_plus() + elif pos[0] == "[": # char range(s) + pos += 1 start_type = ( llama_gretype.LLAMA_GRETYPE_CHAR ) # type: llama_gretype @@ -685,8 +748,8 @@ def parse_sequence( # start_type = LLAMA_GRETYPE_CHAR_NOT; # } # last_sym_start = out_elements.size(); - if pos.derefer == "^": - pos.plus_plus() + if pos[0] == "^": + pos += 1 start_type = llama_gretype.LLAMA_GRETYPE_CHAR_NOT last_sym_start = out_elements.size() # while (*pos != ']') { @@ -696,10 +759,8 @@ def parse_sequence( # ? 
LLAMA_GRETYPE_CHAR_ALT # : start_type; # out_elements.push_back({type, char_pair.first}); - while pos.derefer != "]": - char_pair = parse_char( - pos - ) # type: Tuple[uint32_t, const_char_p] + while pos[0] != "]": + char_pair = parse_char(pos) # type: Tuple[int, const_char_p] pos = char_pair[1] type = ( llama_gretype.LLAMA_GRETYPE_CHAR_ALT @@ -707,7 +768,7 @@ def parse_sequence( else start_type ) # type: llama_gretype out_elements.push_back( - llama_grammar_element(type, char_pair[0]) + llama_grammar_element(type.value, char_pair[0]) ) # if (pos[0] == '-' && pos[1] != ']') { # auto endchar_pair = parse_char(pos + 1); @@ -715,14 +776,14 @@ def parse_sequence( # out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); # } # } - if pos.derefer == "-" and (pos + 1).derefer != "]": + if pos[0] == "-" and pos[1] != "]": endchar_pair = parse_char( pos + 1 - ) # type: Tuple[uint32_t, const_char_p] + ) # type: Tuple[int, const_char_p] pos = endchar_pair[1] out_elements.push_back( llama_grammar_element( - llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER, + llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value, endchar_pair[0], ) ) @@ -734,16 +795,16 @@ def parse_sequence( # pos = parse_space(name_end, is_nested); # last_sym_start = out_elements.size(); # out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); - elif is_word_char(pos.derefer): # rule reference + elif is_word_char(pos[0]): # rule reference name_end = parse_name(pos) # type: const_char_p ref_rule_id = get_symbol_id( - state, pos, name_end.sub(pos) - ) # type: uint32_t + state, pos, name_end - pos + ) # type: int pos = parse_space(name_end, is_nested) last_sym_start = out_elements.size() out_elements.push_back( llama_grammar_element( - llama_gretype.LLAMA_GRETYPE_RULE_REF, ref_rule_id + llama_gretype.LLAMA_GRETYPE_RULE_REF.value, ref_rule_id ) ) # } else if (*pos == '(') { // grouping @@ -758,26 +819,26 @@ def parse_sequence( # throw std::runtime_error(std::string("expecting ')' at ") + pos); # } 
# pos = parse_space(pos + 1, is_nested); - elif pos.derefer == "(": # grouping + elif pos[0] == "(": # grouping + # parse nested alternates into synthesized rule pos = parse_space(pos + 1, True) - sub_rule_id = generate_symbol_id( - state, rule_name - ) # type: uint32_t + sub_rule_id = generate_symbol_id(state, rule_name) # type: int pos = parse_alternates(state, pos, rule_name, sub_rule_id, True) last_sym_start = out_elements.size() + # output reference to synthesized rule out_elements.push_back( llama_grammar_element( - llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id + llama_gretype.LLAMA_GRETYPE_RULE_REF.value, sub_rule_id ) ) - if pos.derefer != ")": + if pos[0] != ")": raise RuntimeError("expecting ')' at " + str(pos)) pos = parse_space(pos + 1, is_nested) # } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator # if (last_sym_start == out_elements.size()) { # throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos); # } - elif pos.derefer in ("*", "+", "?"): # repetition operator + elif pos[0] in ("*", "+", "?"): # repetition operator if last_sym_start == out_elements.size(): raise RuntimeError( "expecting preceding item to */+/? 
at " + str(pos) @@ -792,10 +853,10 @@ def parse_sequence( # // add preceding symbol to generated rule # sub_rule.insert( # sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); - sub_rule_id = generate_symbol_id( - state, rule_name - ) # type: uint32_t - sub_rule = std__vector[llama_grammar_element]() + sub_rule_id = generate_symbol_id(state, rule_name) # type: int + sub_rule = std.vector[ + llama_grammar_element + ]() # type: std.vector[llama_grammar_element] sub_rule.insert( sub_rule.end(), out_elements.begin() + last_sym_start, @@ -807,14 +868,14 @@ def parse_sequence( # } # // mark start of alternate def # sub_rule.push_back({LLAMA_GRETYPE_ALT, 0}); - if pos.derefer in ("*", "+"): + if pos[0] in ("*", "+"): sub_rule.push_back( llama_grammar_element( - llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id + llama_gretype.LLAMA_GRETYPE_RULE_REF.value, sub_rule_id ) ) sub_rule.push_back( - llama_grammar_element(llama_gretype.LLAMA_GRETYPE_ALT, 0) + llama_grammar_element(llama_gretype.LLAMA_GRETYPE_ALT.value, 0) ) # if (*pos == '+') { # // add preceding symbol as alternate only for '+' (otherwise empty) @@ -827,20 +888,22 @@ def parse_sequence( # out_elements.resize(last_sym_start); # out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); # pos = parse_space(pos + 1, is_nested); - if pos.derefer == "+": + if pos[0] == "+": + # add preceding symbol as alternate only for '+' (otherwise empty) sub_rule.insert( sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end(), ) sub_rule.push_back( - llama_grammar_element(llama_gretype.LLAMA_GRETYPE_END, 0) + llama_grammar_element(llama_gretype.LLAMA_GRETYPE_END.value, 0) ) add_rule(state, sub_rule_id, sub_rule) + # in original rule, replace previous symbol with reference to generated rule out_elements.resize(last_sym_start) out_elements.push_back( llama_grammar_element( - llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id + llama_gretype.LLAMA_GRETYPE_RULE_REF.value, sub_rule_id ) ) 
pos = parse_space(pos + 1, is_nested) @@ -876,20 +939,22 @@ def parse_alternates( state: parse_state, src: const_char_p, rule_name: str, - rule_id: uint32_t, + rule_id: int, is_nested: bool, ) -> const_char_p: - rule = std__vector() # type: std__vector[llama_grammar_element] + rule = std.vector() # type: std.vector[llama_grammar_element] pos = parse_sequence( state, src, rule_name, rule, is_nested ) # type: const_char_p - while pos.derefer == "|": + while pos[0] == "|": rule.push_back( - llama_grammar_element(llama_gretype.LLAMA_GRETYPE_ALT, 0) + llama_grammar_element(llama_gretype.LLAMA_GRETYPE_ALT.value, 0) ) pos = parse_space(pos + 1, True) pos = parse_sequence(state, pos, rule_name, rule, is_nested) - rule.push_back(llama_grammar_element(llama_gretype.LLAMA_GRETYPE_END, 0)) + rule.push_back( + llama_grammar_element(llama_gretype.LLAMA_GRETYPE_END.value, 0) + ) add_rule(state, rule_id, rule) return pos @@ -921,15 +986,11 @@ def parse_alternates( def parse_rule(state: parse_state, src: const_char_p) -> const_char_p: name_end = parse_name(src) # type: const_char_p pos = parse_space(name_end, False) # type: const_char_p - name_len = name_end.sub(src) # type: size_t - rule_id = get_symbol_id(state, src, name_len) # type: uint32_t - name = std__string(src, name_len) # type: std__string - - if not ( - pos.derefer == ":" - and (pos + 1).derefer == ":" - and (pos + 2).derefer == "=" - ): + name_len = name_end - src # type: int + rule_id = get_symbol_id(state, src, name_len) # type: int + name = std.string(src, name_len) # type: str + + if not (pos[0] == ":" and pos[1] == ":" and pos[2] == "="): raise RuntimeError("expecting ::= at " + str(pos)) pos = parse_space(pos + 3, True) # type: const_char_p @@ -937,11 +998,11 @@ def parse_rule(state: parse_state, src: const_char_p) -> const_char_p: state, pos, name, rule_id, False ) # type: const_char_p - if pos.derefer == "\r": - pos += 2 if (pos + 1).derefer == "\n" else 1 - elif pos.derefer == "\n": - pos.plus_plus() - elif 
pos.derefer: + if pos[0] == "\r": + pos += 2 if pos[1] == "\n" else 1 + elif pos[0] == "\n": + pos += 1 + elif pos[0]: raise RuntimeError("expecting newline or end at " + str(pos)) return parse_space(pos, True) @@ -963,7 +1024,7 @@ def parse(src: const_char_p) -> parse_state: try: state = parse_state() # type: parse_state pos = parse_space(src, True) # type: const_char_p - while pos.derefer: + while pos[0]: pos = parse_rule(state, pos) return state except Exception as err: @@ -979,7 +1040,7 @@ def parse(src: const_char_p) -> parse_state: # fprintf(file, "", c); # } # } -def print_grammar_char(file: TextIO, c: uint32_t) -> None: +def print_grammar_char(file: TextIO, c: int) -> None: if 0x20 <= c and c <= 0x7F: file.write(chr(c)) else: @@ -998,10 +1059,10 @@ def print_grammar_char(file: TextIO, c: uint32_t) -> None: # } def is_char_element(elem: llama_grammar_element) -> bool: return elem.type in ( - llama_gretype.LLAMA_GRETYPE_CHAR, - llama_gretype.LLAMA_GRETYPE_CHAR_NOT, - llama_gretype.LLAMA_GRETYPE_CHAR_ALT, - llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER, + llama_gretype.LLAMA_GRETYPE_CHAR.value, + llama_gretype.LLAMA_GRETYPE_CHAR_NOT.value, + llama_gretype.LLAMA_GRETYPE_CHAR_ALT.value, + llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value, ) @@ -1012,16 +1073,19 @@ def is_char_element(elem: llama_grammar_element) -> bool: # const std::map & symbol_id_names) { def print_rule( file: TextIO, - rule_id: uint32_t, - rule: std__vector[llama_grammar_element], - symbol_id_names: std__map[uint32_t, str], + rule_id: int, + rule: std.vector[llama_grammar_element], + symbol_id_names: std.map[int, str], ) -> None: # if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) { # throw std::runtime_error( # "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); # } # fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); - if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END: + if ( + rule.empty() + or rule.back().type != 
llama_gretype.LLAMA_GRETYPE_END.value + ): raise RuntimeError( "malformed rule, does not end with LLAMA_GRETYPE_END: " + str(rule_id) @@ -1067,22 +1131,22 @@ def print_rule( # break; # } for i, elem in enumerate(rule[:-1]): - switch = elem.type # type: llama_gretype - if switch == llama_gretype.LLAMA_GRETYPE_END: + case = elem.type # type: int + if case == llama_gretype.LLAMA_GRETYPE_END.value: raise RuntimeError( "unexpected end of rule: " + str(rule_id) + "," + str(i) ) - elif switch == llama_gretype.LLAMA_GRETYPE_ALT: + elif case == llama_gretype.LLAMA_GRETYPE_ALT.value: print("| ", file=file, end="") - elif switch == llama_gretype.LLAMA_GRETYPE_RULE_REF: + elif case == llama_gretype.LLAMA_GRETYPE_RULE_REF.value: print(f"{symbol_id_names.at(elem.value)} ", file=file, end="") - elif switch == llama_gretype.LLAMA_GRETYPE_CHAR: + elif case == llama_gretype.LLAMA_GRETYPE_CHAR.value: print("[", file=file, end="") print_grammar_char(file, elem.value) - elif switch == llama_gretype.LLAMA_GRETYPE_CHAR_NOT: + elif case == llama_gretype.LLAMA_GRETYPE_CHAR_NOT.value: print("[^", file=file, end="") print_grammar_char(file, elem.value) - elif switch == llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER: + elif case == llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value: if i == 0 or not is_char_element(rule[i - 1]): raise RuntimeError( "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " @@ -1092,7 +1156,7 @@ def print_rule( ) print("-", file=file, end="") print_grammar_char(file, elem.value) - elif switch == llama_gretype.LLAMA_GRETYPE_CHAR_ALT: + elif case == llama_gretype.LLAMA_GRETYPE_CHAR_ALT.value: if i == 0 or not is_char_element(rule[i - 1]): raise RuntimeError( "LLAMA_GRETYPE_CHAR_ALT without preceding char: " @@ -1110,8 +1174,8 @@ def print_rule( # fprintf(file, "] "); if is_char_element(elem): if rule[i + 1].type in ( - llama_gretype.LLAMA_GRETYPE_CHAR_ALT, - llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER, + llama_gretype.LLAMA_GRETYPE_CHAR_ALT.value, + 
llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value, ): pass else: @@ -1142,7 +1206,7 @@ def print_rule( # } def print_grammar(file: TextIO, state: parse_state) -> None: try: - symbol_id_names = std__map() # type: std__map[uint32_t, str] + symbol_id_names = std.map() # type: std.map[int, str] for kv in state.symbol_ids.items(): symbol_id_names[kv[1]] = kv[0] @@ -1155,117 +1219,25 @@ def print_grammar(file: TextIO, state: parse_state) -> None: ) -def convert_to_rules( - llama_grammar_elements: std__vector[std__vector[llama_grammar_element]], -) -> Array[llama_cpp.llama_grammar_element_p]: - """Make an Array object that is used for `llama_grammer_init`""" - - # Step 1: Convert to c_llama_grammar_element - llama_grammar_element_p_p = ( - [] - ) # type: List[List[llama_cpp.llama_grammar_element]] - for subvector in llama_grammar_elements: - llama_grammar_element_p_p.append([]) - for elem in subvector: - c_llama_grammar_element = llama_cpp.llama_grammar_element() - c_llama_grammar_element.type = c_int(elem.type.value) - c_llama_grammar_element.value = c_uint32(elem.value) - llama_grammar_element_p_p[-1].append(c_llama_grammar_element) - - # Step 2: Convert each list to llama_grammar_element array and get pointer - element_arrays = [ - (llama_cpp.llama_grammar_element * len(sublist))(*sublist) - for sublist in llama_grammar_element_p_p - ] # type: List[Array[llama_cpp.llama_grammar_element]] - - # Step 3: Get pointer of each array - element_array_pointers = [ - cast(sublist, llama_cpp.llama_grammar_element_p) - for sublist in element_arrays - ] # type: List[llama_cpp.llama_grammar_element_p] - - # Step 4: Make array of these pointers and get its pointer - return (llama_cpp.llama_grammar_element_p * len(element_array_pointers))( - *element_array_pointers - ) - +# def convert_to_rules( +# llama_grammar_elements: std.vector[std.vector[llama_grammar_element]], +# ) -> Array[llama_grammar_element_p]: +# """Make an Array object that is used for `llama_grammer_init`""" -def 
parse_grammar_init_args( - bnf: str, -) -> Tuple[Array[llama_cpp.llama_grammar_element_p], c_size_t, c_size_t]: - """Parse a GBNF string and return tuple of `grammar rules` and `root symbol id`""" - parsed_grammar = parse(const_char_p(bnf)) # type: parse_state - if parsed_grammar.rules.empty(): - raise Exception( - f"{parse_grammar_init_args.__name__}: error parsing grammar file: parsed_grammar.rules is empty" - ) - print(f"{parse_grammar_init_args.__name__} grammar:", file=sys.stderr) - print_grammar(sys.stdout, parsed_grammar) - print(file=sys.stderr) - grammar_rules = ( - parsed_grammar.c_rules() - ) # type: std__vector[std__vector[llama_grammar_element]] - return ( - convert_to_rules(grammar_rules), - c_size_t(grammar_rules.size()), - c_size_t(parsed_grammar.symbol_ids.at("root")), - ) +# # Step 1: Convert each list to llama_grammar_element array and get pointer +# element_arrays = [ +# (llama_grammar_element * len(subvector))(*subvector) +# for subvector in llama_grammar_elements +# ] # type: List[Array[llama_grammar_element]] +# # Step 2: Get pointer of each array +# element_array_pointers = [ +# cast(subarray, llama_grammar_element_p) for subarray in element_arrays +# ] # type: List[llama_grammar_element_p] -def parse_grammar_init_args_from_file( - bnf_path: Union[str, Path] -) -> Tuple[Array[llama_cpp.llama_grammar_element_p], c_size_t, c_size_t]: - """Parse a GBNF file and return tuple of `grammar rules` and `root symbol id`""" - try: - with open(bnf_path) as f: - params_grammer = f.read() - except Exception as err: - raise Exception( - f"{parse_grammar_init_args_from_file.__name__}: error reading grammar file: {err}" - ) - - if params_grammer: - return parse_grammar_init_args(params_grammer) - - raise Exception( - f"{parse_grammar_init_args_from_file.__name__}: error parsing grammar file: params_grammer is empty" - ) - - -# def get_grammar_p(bnf: str) -> llama_cpp.llama_grammar_p: -# """Parse a GBNF string and return pointer to `llama_grammar`""" - -# 
grammar_rules, root_symbol_id = parse_rules(bnf) - -# grammar_element_p_p = convert_to_double_ptr( -# grammar_rules -# ) # type: llama_cpp.llama_grammar_element_p_p - -# c_llama_grammar_p = llama_cpp.llama_grammar_init( -# grammar_element_p_p, -# c_size_t(grammar_rules.size()), -# c_size_t(root_symbol_id), -# ) # type: llama_cpp.llama_grammar_p -# return c_llama_grammar_p - - -# def get_grammar_p_from_file( -# bnf_path: Union[str, Path] -# ) -> llama_cpp.llama_grammar_p: -# """Parse a GBNF file and return pointer to `llama_grammar`""" -# try: -# with open(bnf_path) as f: -# params_grammer = f.read() -# except Exception as err: -# raise Exception( -# f"{get_grammar_p_from_file.__name__}: error reading grammar file: {err}" -# ) - -# if params_grammer: -# return get_grammar_p(params_grammer) - -# raise Exception( -# f"{get_grammar_p_from_file.__name__}: error parsing grammar file: params_grammer is empty" +# # Step 3: Make array of these pointers and get its pointer +# return (llama_grammar_element_p * len(element_array_pointers))( +# *element_array_pointers # ) @@ -1282,14 +1254,8 @@ def parse_grammar_init_args_from_file( ) args = parser.parse_args() - rules, n_rules, start_rule_index = parse_grammar_init_args_from_file( - args.grammar - ) - llama_grammar_p = llama_cpp.llama_grammar_init( - rules, - n_rules, - start_rule_index, - ) # type: llama_cpp.llama_grammar_p + llama_grammar = LlamaGrammar.from_file(Path(args.grammar)) + llama_grammar_ptr = llama_grammar.init_grammar() # ----- USAGE: # llama_cpp.llama_sample_grammar(ctx=..., candidates=..., grammar=llama_grammar_p) From b07713cb9f26d97e0d5f740e722e60ad0d7e9ddf Mon Sep 17 00:00:00 2001 From: c0sogi Date: Mon, 7 Aug 2023 15:16:25 +0900 Subject: [PATCH 3/4] reset grammar for every generation --- llama_cpp/llama.py | 9 ++- llama_cpp/llama_grammar.py | 125 +++++++++++-------------------------- 2 files changed, 39 insertions(+), 95 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 
ab99ee5bd..9328c5b6f 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -364,7 +364,7 @@ def __init__( ) if grammar is not None: self.grammar = LlamaGrammar.from_file( - grammar + grammar, verbose=verbose ) # type: Optional[LlamaGrammar] else: self.grammar = None @@ -723,7 +723,6 @@ def generate( The generated tokens. """ assert self.ctx is not None - if reset and len(self._input_ids) > 0: longest_prefix = 0 for a, b in zip(self._input_ids, tokens[:-1]): @@ -741,6 +740,9 @@ def generate( if reset: self.reset() + if self.grammar is not None: + self.grammar.reset() + while True: self.eval(tokens) token = self.sample( @@ -1534,9 +1536,6 @@ def __del__(self): if self.ctx is not None: llama_cpp.llama_free(self.ctx) self.ctx = None - if self.grammar is not None: - llama_cpp.llama_grammar_free(self.grammar.grammar) - self.grammar = None def __getstate__(self): return dict( diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index 06b2b7ff2..538867613 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -1,6 +1,5 @@ """C++ implementation of the llama grammar parser.""" # flake8: noqa -import argparse from pathlib import Path import sys from ctypes import * # type: ignore @@ -19,7 +18,7 @@ overload, ) -import llama_cpp +from . 
import llama_cpp # Type aliases llama_grammar_element = llama_cpp.llama_grammar_element @@ -41,11 +40,19 @@ class Sentinel: class LlamaGrammar: """Keeps reference counts of all the arguments, so that they are not garbage collected by Python.""" + + def __del__(self) -> None: + """Free the grammar pointer when the object is deleted.""" + if self.grammar is not None: + llama_cpp.llama_grammar_free(self.grammar) + self.grammar = None def __init__( self, parsed_grammar: "parse_state", ) -> None: + """Initialize the grammar pointer from the parsed state.""" + self.parsed_grammar = parsed_grammar grammar_rules = ( parsed_grammar.c_rules() ) # type: std.vector[std.vector[llama_grammar_element]] @@ -69,22 +76,25 @@ def __init__( self.n_rules = c_size_t(grammar_rules.size()) self.start_rule_index = c_size_t(parsed_grammar.symbol_ids.at("root")) - self.grammar = self.init_grammar() + self._grammar = llama_cpp.llama_grammar_init( + self.rules, self.n_rules, self.start_rule_index + ) @classmethod - def from_string(cls, grammar: str) -> "LlamaGrammar": + def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar": parsed_grammar = parse(const_char_p(grammar)) # type: parse_state if parsed_grammar.rules.empty(): raise ValueError( f"{cls.from_string.__name__}: error parsing grammar file: parsed_grammar.rules is empty" ) - print(f"{cls.from_string.__name__} grammar:", file=sys.stderr) - print_grammar(sys.stdout, parsed_grammar) - print(file=sys.stderr) + if verbose: + print(f"{cls.from_string.__name__} grammar:", file=sys.stderr) + print_grammar(sys.stdout, parsed_grammar) + print(file=sys.stderr) return cls(parsed_grammar) @classmethod - def from_file(cls, file: Union[str, Path]) -> "LlamaGrammar": + def from_file(cls, file: Union[str, Path], verbose: bool = True) -> "LlamaGrammar": try: with open(file) as f: grammar = f.read() @@ -94,14 +104,27 @@ def from_file(cls, file: Union[str, Path]) -> "LlamaGrammar": ) if grammar: - return cls.from_string(grammar) + return 
cls.from_string(grammar, verbose=verbose) raise ValueError( f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty" ) - def init_grammar(self) -> llama_grammar_p: - return llama_cpp.llama_grammar_init( + @property + def grammar(self) -> llama_grammar_p: + if self._grammar is None: + raise ValueError( + f"{self.__class__.__name__}.grammar: grammar is freed" + ) + return self._grammar + + @grammar.setter + def grammar(self, value: Optional[llama_grammar_p]) -> None: + self._grammar = value + + def reset(self) -> None: + llama_cpp.llama_grammar_free(self.grammar) + self.grammar = llama_cpp.llama_grammar_init( self.rules, self.n_rules, self.start_rule_index ) @@ -1216,82 +1239,4 @@ def print_grammar(file: TextIO, state: parse_state) -> None: print( f"{print_grammar.__name__}: error printing grammar: {err}", file=sys.stderr, - ) - - -# def convert_to_rules( -# llama_grammar_elements: std.vector[std.vector[llama_grammar_element]], -# ) -> Array[llama_grammar_element_p]: -# """Make an Array object that is used for `llama_grammer_init`""" - -# # Step 1: Convert each list to llama_grammar_element array and get pointer -# element_arrays = [ -# (llama_grammar_element * len(subvector))(*subvector) -# for subvector in llama_grammar_elements -# ] # type: List[Array[llama_grammar_element]] - -# # Step 2: Get pointer of each array -# element_array_pointers = [ -# cast(subarray, llama_grammar_element_p) for subarray in element_arrays -# ] # type: List[llama_grammar_element_p] - -# # Step 3: Make array of these pointers and get its pointer -# return (llama_grammar_element_p * len(element_array_pointers))( -# *element_array_pointers -# ) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Generate C++ parser from GBNF grammar" - ) - parser.add_argument( - "-g", - "--grammar", - type=str, - default="./vendor/llama.cpp/grammars/json.gbnf", - help="path to GBNF grammar file", - ) - - args = parser.parse_args() - llama_grammar = 
LlamaGrammar.from_file(Path(args.grammar)) - llama_grammar_ptr = llama_grammar.init_grammar() - - # ----- USAGE: - # llama_cpp.llama_sample_grammar(ctx=..., candidates=..., grammar=llama_grammar_p) - # llama_cpp.llama_grammar_accept_token(ctx=..., grammar=llama_grammar_p, token=...) - - # ----- SAMPLE OUTPUT: - # main grammar: - # root ::= object - # object ::= [{] ws object_11 [}] ws - # value ::= object | array | string | number | value_6 ws - # array ::= [[] ws array_15 []] ws - # string ::= ["] string_18 ["] ws - # number ::= number_19 number_25 number_29 ws - # value_6 ::= [t] [r] [u] [e] | [f] [a] [l] [s] [e] | [n] [u] [l] [l] - # ws ::= ws_31 - # object_8 ::= string [:] ws value object_10 - # object_9 ::= [,] ws string [:] ws value - # object_10 ::= object_9 object_10 | - # object_11 ::= object_8 | - # array_12 ::= value array_14 - # array_13 ::= [,] ws value - # array_14 ::= array_13 array_14 | - # array_15 ::= array_12 | - # string_16 ::= [^"\] | [\] string_17 - # string_17 ::= ["\/bfnrt] | [u] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] - # string_18 ::= string_16 string_18 | - # number_19 ::= number_20 number_21 - # number_20 ::= [-] | - # number_21 ::= [0-9] | [1-9] number_22 - # number_22 ::= [0-9] number_22 | - # number_23 ::= [.] 
number_24 - # number_24 ::= [0-9] number_24 | [0-9] - # number_25 ::= number_23 | - # number_26 ::= [eE] number_27 number_28 - # number_27 ::= [-+] | - # number_28 ::= [0-9] number_28 | [0-9] - # number_29 ::= number_26 | - # ws_30 ::= [ ] ws - # ws_31 ::= ws_30 | + ) \ No newline at end of file From 0d7d2031a9401a483293d9e91749ee64a9f64d54 Mon Sep 17 00:00:00 2001 From: c0sogi Date: Mon, 7 Aug 2023 17:02:33 +0900 Subject: [PATCH 4/4] prevent memory access error by llama_grammar_free --- llama_cpp/llama_grammar.py | 251 ++++++++++++++----------------------- 1 file changed, 97 insertions(+), 154 deletions(-) diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index 538867613..f35f9fa4b 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -40,7 +40,7 @@ class Sentinel: class LlamaGrammar: """Keeps reference counts of all the arguments, so that they are not garbage collected by Python.""" - + def __del__(self) -> None: """Free the grammar pointer when the object is deleted.""" if self.grammar is not None: @@ -52,33 +52,12 @@ def __init__( parsed_grammar: "parse_state", ) -> None: """Initialize the grammar pointer from the parsed state.""" - self.parsed_grammar = parsed_grammar - grammar_rules = ( + self._grammar_rules = ( parsed_grammar.c_rules() - ) # type: std.vector[std.vector[llama_grammar_element]] - - # Step 1: Convert each list to llama_grammar_element array and get pointer - self.element_arrays = [ - (llama_grammar_element * len(sublist))(*sublist) - for sublist in grammar_rules - ] # type: List[Array[llama_grammar_element]] - - # Step 2: Get pointer of each array - self.element_array_pointers = [ - cast(subarray, llama_grammar_element_p) - for subarray in self.element_arrays - ] # type: List[llama_grammar_element_p] - - # Step 3: Make array of these pointers and get its pointer - self.rules = ( - llama_grammar_element_p * len(self.element_array_pointers) - )(*self.element_array_pointers) - - self.n_rules = 
c_size_t(grammar_rules.size()) - self.start_rule_index = c_size_t(parsed_grammar.symbol_ids.at("root")) - self._grammar = llama_cpp.llama_grammar_init( - self.rules, self.n_rules, self.start_rule_index - ) + ) # type: std.vector[std.vector[LlamaGrammarElement]] + self._n_rules = self._grammar_rules.size() # type: int + self._start_rule_index = parsed_grammar.symbol_ids.at("root") # type: int + self.grammar = self.init() @classmethod def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar": @@ -110,24 +89,46 @@ def from_file(cls, file: Union[str, Path], verbose: bool = True) -> "LlamaGramma f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty" ) - @property - def grammar(self) -> llama_grammar_p: - if self._grammar is None: - raise ValueError( - f"{self.__class__.__name__}.grammar: grammar is freed" - ) - return self._grammar - - @grammar.setter - def grammar(self, value: Optional[llama_grammar_p]) -> None: - self._grammar = value + def init(self) -> None: + # Step 1: Convert LlamaGrammarElement to llama_grammar_element + self._element_lists = [ + [ + llama_grammar_element(c_int(elem.type.value), c_uint32(elem.value)) + for elem in subvector + ] + for subvector in self._grammar_rules + ] # type: List[List[llama_grammar_element]] + + # Step 2: Convert each list to llama_grammar_element array and get pointer + self._element_arrays = [ + (llama_grammar_element * len(sublist))(*sublist) + for sublist in self._element_lists + ] # type: List[Array[llama_grammar_element]] - def reset(self) -> None: - llama_cpp.llama_grammar_free(self.grammar) + # Step 3: Get pointer of each array + self._element_array_pointers = [ + cast(subarray, llama_grammar_element_p) for subarray in self._element_arrays + ] # type: List[llama_grammar_element_p] + + # Step 4: Make array of these pointers and get its pointer + self._rules = (llama_grammar_element_p * len(self._element_array_pointers))( + *self._element_array_pointers + ) self.grammar = 
llama_cpp.llama_grammar_init( - self.rules, self.n_rules, self.start_rule_index + self._rules, c_size_t(self._n_rules), c_size_t(self._start_rule_index) ) + def reset(self) -> None: + if self.grammar is not None: + llama_cpp.llama_grammar_free(self.grammar) + self.init() + + +class LlamaGrammarElement: + def __init__(self, type: "llama_gretype", value: int): + self.type = type + self.value = value # Unicode code point or rule ID + class const_char_p: """C++ implementation of const char *.""" @@ -182,21 +183,15 @@ def __sub__(self: Ptr, other: Union[int, Ptr]) -> Union[int, Ptr]: ) def __eq__(self: Ptr, other: Ptr) -> bool: - assert ( - self.value == other.value - ), "comparing pointers from different strings" + assert self.value == other.value, "comparing pointers from different strings" return self.pos == other.pos def __lt__(self: Ptr, other: Ptr) -> bool: - assert ( - self.value == other.value - ), "comparing pointers from different strings" + assert self.value == other.value, "comparing pointers from different strings" return self.pos < other.pos def __gt__(self: Ptr, other: Ptr) -> bool: - assert ( - self.value == other.value - ), "comparing pointers from different strings" + assert self.value == other.value, "comparing pointers from different strings" return self.pos > other.pos @@ -220,9 +215,7 @@ def __init__(self, vector: "std.vector[T]", index: int): def _check_version(self): if self._version != self._vector._version: - raise RuntimeError( - "Iterator used after vector was modified." - ) + raise RuntimeError("Iterator used after vector was modified.") def __iter__(self): return self @@ -280,16 +273,12 @@ def resize( ) -> None: if new_size > self.size(): if fill_value_factory is None: - raise ValueError( - "A fill value factory function must be provided." 
- ) + raise ValueError("A fill value factory function must be provided.") self.reserve(new_size, fill_value_factory) elif new_size < self.size(): self[:] = self[:new_size] - def reserve( - self, capacity: int, fill_value_factory: Callable[[], T] - ) -> None: + def reserve(self, capacity: int, fill_value_factory: Callable[[], T]) -> None: if capacity > self.size(): fill_value = fill_value_factory() self.extend([fill_value] * (capacity - self.size())) @@ -401,9 +390,7 @@ def lower_bound(self, key: T) -> "std.map[T, U].iterator[T, U]": for k in keys: if k >= key: return self.iterator(self, k) - raise ValueError( - "No key found that is not less than the input key" - ) + raise ValueError("No key found that is not less than the input key") except TypeError: raise TypeError("Keys of type T cannot be sorted.") @@ -460,9 +447,7 @@ class llama_gretype(Enum): class parse_state: def __init__(self): self.symbol_ids: std.map[str, int] = std.map() - self.rules: std.vector[ - std.vector[llama_grammar_element] - ] = std.vector() + self.rules: std.vector[std.vector[LlamaGrammarElement]] = std.vector() # std::vector parse_state::c_rules() { # std::vector ret; @@ -471,16 +456,16 @@ def __init__(self): # } # return ret; # } - def c_rules(self) -> std.vector[std.vector[llama_grammar_element]]: - ret = ( - std.vector() - ) # type: std.vector[std.vector[llama_grammar_element]] + def c_rules(self) -> std.vector[std.vector[LlamaGrammarElement]]: + ret = std.vector() # type: std.vector[std.vector[LlamaGrammarElement]] for rule in self.rules: ret.push_back(rule.data()) return ret def __repr__(self) -> str: - return f"parse_state(symbol_ids={len(self.symbol_ids)}, rules={len(self.rules)})" + return ( + f"parse_state(symbol_ids={len(self.symbol_ids)}, rules={len(self.rules)})" + ) # struct llama_grammar { @@ -531,12 +516,12 @@ def generate_symbol_id(state: parse_state, base_name: str) -> int: def add_rule( state: parse_state, rule_id: int, - rule: std.vector[llama_grammar_element], + rule: 
std.vector[LlamaGrammarElement], ) -> None: if state.rules.size() <= rule_id: state.rules.resize( rule_id + 1, - fill_value_factory=std.vector[llama_grammar_element], + fill_value_factory=std.vector[LlamaGrammarElement], ) state.rules[rule_id] = rule @@ -575,9 +560,7 @@ def decode_utf8(src: const_char_p) -> Tuple[int, const_char_p]: # return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); # } def is_word_char(c: str) -> bool: - return ( - ("a" <= c <= "z") or ("A" <= c <= "Z") or c == "-" or ("0" <= c <= "9") - ) + return ("a" <= c <= "z") or ("A" <= c <= "Z") or c == "-" or ("0" <= c <= "9") # std::pair parse_hex(const char * src, int size) { @@ -619,9 +602,7 @@ def parse_hex(src: const_char_p, size: int) -> Tuple[int, const_char_p]: break pos += 1 if pos != end: - raise RuntimeError( - "expecting " + str(size) + " hex chars at " + str(src) - ) + raise RuntimeError("expecting " + str(size) + " hex chars at " + str(src)) return (value, pos) @@ -707,9 +688,7 @@ def parse_name(src: const_char_p) -> const_char_p: # } def parse_space(src: const_char_p, newline_ok: bool) -> const_char_p: pos = const_char_p(src) # type: const_char_p - while pos[0] in (" ", "\t", "#") or ( - newline_ok and pos[0] in ("\r", "\n") - ): + while pos[0] in (" ", "\t", "#") or (newline_ok and pos[0] in ("\r", "\n")): if pos[0] == "#": while pos[0] is not None and pos[0] not in ("\r", "\n"): pos += 1 @@ -728,7 +707,7 @@ def parse_sequence( state: parse_state, src: const_char_p, rule_name: str, - out_elements: std.vector[llama_grammar_element], + out_elements: std.vector[LlamaGrammarElement], is_nested: bool, ) -> const_char_p: # size_t last_sym_start = out_elements.size(); @@ -753,9 +732,7 @@ def parse_sequence( char_pair = parse_char(pos) # type: Tuple[int, const_char_p] pos = char_pair[1] out_elements.push_back( - llama_grammar_element( - llama_gretype.LLAMA_GRETYPE_CHAR.value, char_pair[0] - ) + LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_CHAR, 
char_pair[0]) ) pos = parse_space(pos + 1, is_nested) # } else if (*pos == '[') { // char range(s) @@ -763,9 +740,7 @@ def parse_sequence( # enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; elif pos[0] == "[": # char range(s) pos += 1 - start_type = ( - llama_gretype.LLAMA_GRETYPE_CHAR - ) # type: llama_gretype + start_type = llama_gretype.LLAMA_GRETYPE_CHAR # type: llama_gretype # if (*pos == '^') { # pos++; # start_type = LLAMA_GRETYPE_CHAR_NOT; @@ -790,9 +765,7 @@ def parse_sequence( if last_sym_start < out_elements.size() else start_type ) # type: llama_gretype - out_elements.push_back( - llama_grammar_element(type.value, char_pair[0]) - ) + out_elements.push_back(LlamaGrammarElement(type, char_pair[0])) # if (pos[0] == '-' && pos[1] != ']') { # auto endchar_pair = parse_char(pos + 1); # pos = endchar_pair.second; @@ -800,13 +773,11 @@ def parse_sequence( # } # } if pos[0] == "-" and pos[1] != "]": - endchar_pair = parse_char( - pos + 1 - ) # type: Tuple[int, const_char_p] + endchar_pair = parse_char(pos + 1) # type: Tuple[int, const_char_p] pos = endchar_pair[1] out_elements.push_back( - llama_grammar_element( - llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value, + LlamaGrammarElement( + llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair[0], ) ) @@ -820,15 +791,11 @@ def parse_sequence( # out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); elif is_word_char(pos[0]): # rule reference name_end = parse_name(pos) # type: const_char_p - ref_rule_id = get_symbol_id( - state, pos, name_end - pos - ) # type: int + ref_rule_id = get_symbol_id(state, pos, name_end - pos) # type: int pos = parse_space(name_end, is_nested) last_sym_start = out_elements.size() out_elements.push_back( - llama_grammar_element( - llama_gretype.LLAMA_GRETYPE_RULE_REF.value, ref_rule_id - ) + LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, ref_rule_id) ) # } else if (*pos == '(') { // grouping # // parse nested alternates into synthesized rule @@ -850,9 +817,7 @@ def 
parse_sequence( last_sym_start = out_elements.size() # output reference to synthesized rule out_elements.push_back( - llama_grammar_element( - llama_gretype.LLAMA_GRETYPE_RULE_REF.value, sub_rule_id - ) + LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id) ) if pos[0] != ")": raise RuntimeError("expecting ')' at " + str(pos)) @@ -863,9 +828,7 @@ def parse_sequence( # } elif pos[0] in ("*", "+", "?"): # repetition operator if last_sym_start == out_elements.size(): - raise RuntimeError( - "expecting preceding item to */+/? at " + str(pos) - ) + raise RuntimeError("expecting preceding item to */+/? at " + str(pos)) # // apply transformation to previous symbol (last_sym_start to end) according to # // rewrite rules: # // S* --> S' ::= S S' | @@ -878,8 +841,8 @@ def parse_sequence( # sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); sub_rule_id = generate_symbol_id(state, rule_name) # type: int sub_rule = std.vector[ - llama_grammar_element - ]() # type: std.vector[llama_grammar_element] + LlamaGrammarElement + ]() # type: std.vector[LlamaGrammarElement] sub_rule.insert( sub_rule.end(), out_elements.begin() + last_sym_start, @@ -893,13 +856,11 @@ def parse_sequence( # sub_rule.push_back({LLAMA_GRETYPE_ALT, 0}); if pos[0] in ("*", "+"): sub_rule.push_back( - llama_grammar_element( - llama_gretype.LLAMA_GRETYPE_RULE_REF.value, sub_rule_id + LlamaGrammarElement( + llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id ) ) - sub_rule.push_back( - llama_grammar_element(llama_gretype.LLAMA_GRETYPE_ALT.value, 0) - ) + sub_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0)) # if (*pos == '+') { # // add preceding symbol as alternate only for '+' (otherwise empty) # sub_rule.insert( @@ -918,16 +879,12 @@ def parse_sequence( out_elements.begin() + last_sym_start, out_elements.end(), ) - sub_rule.push_back( - llama_grammar_element(llama_gretype.LLAMA_GRETYPE_END.value, 0) - ) + 
sub_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0)) add_rule(state, sub_rule_id, sub_rule) # in original rule, replace previous symbol with reference to generated rule out_elements.resize(last_sym_start) out_elements.push_back( - llama_grammar_element( - llama_gretype.LLAMA_GRETYPE_RULE_REF.value, sub_rule_id - ) + LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id) ) pos = parse_space(pos + 1, is_nested) # } else { @@ -965,19 +922,13 @@ def parse_alternates( rule_id: int, is_nested: bool, ) -> const_char_p: - rule = std.vector() # type: std.vector[llama_grammar_element] - pos = parse_sequence( - state, src, rule_name, rule, is_nested - ) # type: const_char_p + rule = std.vector() # type: std.vector[LlamaGrammarElement] + pos = parse_sequence(state, src, rule_name, rule, is_nested) # type: const_char_p while pos[0] == "|": - rule.push_back( - llama_grammar_element(llama_gretype.LLAMA_GRETYPE_ALT.value, 0) - ) + rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0)) pos = parse_space(pos + 1, True) pos = parse_sequence(state, pos, rule_name, rule, is_nested) - rule.push_back( - llama_grammar_element(llama_gretype.LLAMA_GRETYPE_END.value, 0) - ) + rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0)) add_rule(state, rule_id, rule) return pos @@ -1017,9 +968,7 @@ def parse_rule(state: parse_state, src: const_char_p) -> const_char_p: raise RuntimeError("expecting ::= at " + str(pos)) pos = parse_space(pos + 3, True) # type: const_char_p - pos = parse_alternates( - state, pos, name, rule_id, False - ) # type: const_char_p + pos = parse_alternates(state, pos, name, rule_id, False) # type: const_char_p if pos[0] == "\r": pos += 2 if pos[1] == "\n" else 1 @@ -1080,7 +1029,7 @@ def print_grammar_char(file: TextIO, c: int) -> None: # default: return false; # } # } -def is_char_element(elem: llama_grammar_element) -> bool: +def is_char_element(elem: LlamaGrammarElement) -> bool: return elem.type in 
( llama_gretype.LLAMA_GRETYPE_CHAR.value, llama_gretype.LLAMA_GRETYPE_CHAR_NOT.value, @@ -1097,7 +1046,7 @@ def is_char_element(elem: llama_grammar_element) -> bool: def print_rule( file: TextIO, rule_id: int, - rule: std.vector[llama_grammar_element], + rule: std.vector[LlamaGrammarElement], symbol_id_names: std.map[int, str], ) -> None: # if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) { @@ -1105,13 +1054,9 @@ def print_rule( # "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); # } # fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); - if ( - rule.empty() - or rule.back().type != llama_gretype.LLAMA_GRETYPE_END.value - ): + if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END.value: raise RuntimeError( - "malformed rule, does not end with LLAMA_GRETYPE_END: " - + str(rule_id) + "malformed rule, does not end with LLAMA_GRETYPE_END: " + str(rule_id) ) print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ") # for (size_t i = 0, end = rule.size() - 1; i < end; i++) { @@ -1154,22 +1099,20 @@ def print_rule( # break; # } for i, elem in enumerate(rule[:-1]): - case = elem.type # type: int - if case == llama_gretype.LLAMA_GRETYPE_END.value: - raise RuntimeError( - "unexpected end of rule: " + str(rule_id) + "," + str(i) - ) - elif case == llama_gretype.LLAMA_GRETYPE_ALT.value: + case = elem.type # type: llama_gretype + if case is llama_gretype.LLAMA_GRETYPE_END.value: + raise RuntimeError("unexpected end of rule: " + str(rule_id) + "," + str(i)) + elif case is llama_gretype.LLAMA_GRETYPE_ALT: print("| ", file=file, end="") - elif case == llama_gretype.LLAMA_GRETYPE_RULE_REF.value: + elif case is llama_gretype.LLAMA_GRETYPE_RULE_REF: print(f"{symbol_id_names.at(elem.value)} ", file=file, end="") - elif case == llama_gretype.LLAMA_GRETYPE_CHAR.value: + elif case is llama_gretype.LLAMA_GRETYPE_CHAR: print("[", file=file, end="") print_grammar_char(file, elem.value) - elif case == 
llama_gretype.LLAMA_GRETYPE_CHAR_NOT.value: + elif case is llama_gretype.LLAMA_GRETYPE_CHAR_NOT: print("[^", file=file, end="") print_grammar_char(file, elem.value) - elif case == llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value: + elif case is llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER: if i == 0 or not is_char_element(rule[i - 1]): raise RuntimeError( "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " @@ -1179,7 +1122,7 @@ def print_rule( ) print("-", file=file, end="") print_grammar_char(file, elem.value) - elif case == llama_gretype.LLAMA_GRETYPE_CHAR_ALT.value: + elif case is llama_gretype.LLAMA_GRETYPE_CHAR_ALT: if i == 0 or not is_char_element(rule[i - 1]): raise RuntimeError( "LLAMA_GRETYPE_CHAR_ALT without preceding char: " @@ -1239,4 +1182,4 @@ def print_grammar(file: TextIO, state: parse_state) -> None: print( f"{print_grammar.__name__}: error printing grammar: {err}", file=sys.stderr, - ) \ No newline at end of file + )