From 9d3aa13cb18d15a1af13e6797d718ad4aa27d97e Mon Sep 17 00:00:00 2001 From: Chen Junyi Date: Tue, 7 Jan 2025 19:13:45 +0800 Subject: [PATCH 01/14] add xgrammar --- lightllm/server/api_cli.py | 7 +- .../server/router/model_infer/infer_batch.py | 16 +- .../model_infer/mode_backend/__init__.py | 1 + .../continues_batch/impl_for_xgrammar_mode.py | 140 ++++++++++++++++++ .../server/router/model_infer/model_rpc.py | 8 +- lightllm/server/router/req_queue/__init__.py | 3 +- lightllm/server/sampling_params.py | 3 + test/test_xgrammar_constraint.py | 78 ++++++++++ 8 files changed, 250 insertions(+), 6 deletions(-) create mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_xgrammar_mode.py create mode 100644 test/test_xgrammar_constraint.py diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 544f2e8cc..3d75e255b 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -149,7 +149,12 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument("--beam_mode", action="store_true", help="use beamsearch mode") parser.add_argument("--diverse_mode", action="store_true", help="diversity generation mode") parser.add_argument("--token_healing_mode", action="store_true", help="code model infer mode") - parser.add_argument("--simple_constraint_mode", action="store_true", help="output constraint mode") + # parser.add_argument("--simple_constraint_mode", action="store_true", help="output constraint mode") + parser.add_argument("--output_constraint_mode", type=str, + choices=["outlines", "xgrammar", "none"], + default="none", + help="set the output constraint backend, none means no output constraint",) + parser.add_argument("--guided_grammar", type=str, default=None, help="output constraint mode config") parser.add_argument( "--first_token_constraint_mode", action="store_true", diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index c94678a7a..280142bb4 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -40,6 +40,7 @@ def __init__( stop_sequences: List[List[int]] = [], input_penalty: bool = False, regular_constraint: Optional[str] = None, + guided_grammar: Optional[str] = None, allowed_token_ids: Optional[List[int]] = None, move_kv_to_decode_node: Optional[bool] = None, ) -> None: @@ -59,11 +60,20 @@ def __init__( if self.top_k == -1: self.top_k = vocab_size self.input_penalty = input_penalty - # output constraint states + + # constraint states self.regular_constraint = regular_constraint - self.regex_guide = None + self.guided_grammar = guided_grammar self.fsm_current_state: int = 0 self.allowed_token_ids = allowed_token_ids + + # Outlines constraint states + self.regex_guide = None + + # Xgrammar constraint states + self.xgrammar_compiled_grammar = None + self.xgrammar_matcher = None + # p d mode use params self.move_kv_to_decode_node = move_kv_to_decode_node # this check is not very good to placed here. to do... 
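A note on the two fields added above: xgrammar_compiled_grammar and xgrammar_matcher hold xgrammar's per-request state, a compiled grammar plus a stateful matcher that is advanced one accepted token at a time. The sketch below strings together only the xgrammar calls this patch itself uses, to show the intended lifecycle; the tokenizer path is a placeholder and the uniform logits stand in for a real model forward pass.

import torch
import xgrammar as xgr
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("/Meta-Llama-3-8B-Instruct")  # placeholder path
info = xgr.TokenizerInfo.from_huggingface(tokenizer)
compiler = xgr.GrammarCompiler(info, max_threads=8)

# per-request state, mirroring xgrammar_compiled_grammar / xgrammar_matcher
matcher = xgr.GrammarMatcher(compiler.compile_grammar('root ::= "yes" | "no"'))

# one decode step: ban grammar-invalid tokens, pick a token, advance the matcher
bitmask = xgr.allocate_token_bitmask(1, info.vocab_size)
matcher.fill_next_token_bitmask(bitmask)
logits = torch.zeros(1, info.vocab_size)  # stand-in for model output
xgr.apply_token_bitmask_inplace(logits, bitmask)  # disallowed logits become -inf
next_token_id = int(torch.argmax(logits, dim=-1))
assert matcher.accept_token(next_token_id)
print(matcher.is_terminated())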
@@ -74,7 +84,7 @@ def __init__( return def has_constraint_setting(self) -> bool: - return self.regular_constraint is not None or self.allowed_token_ids is not None + return self.regular_constraint is not None or self.allowed_token_ids is not None or self.guided_grammar is not None class InferReq: diff --git a/lightllm/server/router/model_infer/mode_backend/__init__.py b/lightllm/server/router/model_infer/mode_backend/__init__.py index 9cee2b95b..3b4a91f06 100644 --- a/lightllm/server/router/model_infer/mode_backend/__init__.py +++ b/lightllm/server/router/model_infer/mode_backend/__init__.py @@ -10,3 +10,4 @@ from .dp_backend.impl import DPBackend from .continues_batch.pd_mode.prefill_node_impl.prefill_impl import ContinuesBatchBackendForPrefillNode from .continues_batch.pd_mode.decode_node_impl.decode_impl import ContinuesBatchBackendForDecodeNode +from .continues_batch.impl_for_xgrammar_mode import XgrammarBackend diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_xgrammar_mode.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_xgrammar_mode.py new file mode 100644 index 000000000..c9b99550e --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_xgrammar_mode.py @@ -0,0 +1,140 @@ +import os +import shutil +import torch +import xgrammar as xgr + +from .impl import ContinuesBatchBackend +from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end +from lightllm.server.io_struct import FinishStatus +from lightllm.server.router.model_infer.infer_batch import InferBatch, InferReq, InferSamplingParams +from .pre_process import prepare_prefill_inputs, prepare_decode_inputs +from .post_process import sample +from lightllm.server.tokenizer import get_tokenizer +from typing import List +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class XgrammarBackend(ContinuesBatchBackend): + def __init__(self) -> None: + super().__init__() + + def init_custom(self): + self.tokenizer = get_tokenizer(self.args.model_dir, self.args.tokenizer_mode, trust_remote_code=self.args.trust_remote_code) + + tokenizer_info = xgr.TokenizerInfo.from_huggingface(self.tokenizer) + self.xgrammar_compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8) + self.xgrammar_token_bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size) + + eos_token_ids = [] + eos_token_ids.append(self.tokenizer.eos_token_id) + eos_token_ids.extend(self.args.eos_id) + self.tokenizer.eos_token_ids = eos_token_ids + logger.info(f"eos_ids {self.tokenizer.eos_token_ids}") + return + + @calculate_time(show=False, min_cost_ms=300) + def prefill_batch(self, batch_id): + output_dict = {} + batch: InferBatch = self.cache.pop(batch_id) + kwargs, run_reqs = prepare_prefill_inputs(batch, self.radix_cache, self.model.mem_manager) + run_reqs: List[InferReq] = run_reqs + + logics = self.model.forward(**kwargs) + + mask = torch.ones_like(logics, dtype=torch.bool) + for i, run_obj in enumerate(run_reqs): + run_obj: InferReq = run_obj + sample_params = run_obj.sampling_param + if sample_params.guided_grammar is not None: + xgrammar_compiled_grammar = self.xgrammar_compiler.compile_grammar(sample_params.guided_grammar) + sample_params.xgrammar_matcher = xgr.GrammarMatcher(xgrammar_compiled_grammar) + self._mask_req_out_token(i, run_obj, mask, logics[i]) + + # fix the logics with -inf to a large negative value + logics[logics == float("-inf")] = -1000000.0 + logics[mask] = -1000000.0 + + 
next_token_ids, next_token_probs = sample(logics, run_reqs, self.eos_id) + next_token_ids = next_token_ids.detach().cpu().numpy() + next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() + + for req_obj, next_token_id, next_token_logprob in zip(run_reqs, next_token_ids, next_token_logprobs): + req_obj.cur_kv_len = len(req_obj.input_token_ids) + req_obj.input_token_ids.append(next_token_id) + req_obj.out_token_id_count[next_token_id] += 1 + req_obj.update_finish_status(self.eos_id) + + self._handle_req_ans(req_obj, next_token_id, next_token_logprob, output_dict) + + self.cache[batch.batch_id] = batch + return output_dict + + @calculate_time(show=True, min_cost_ms=200) + def decode_batch(self, batch_id): + output_dict = {} + batch: InferBatch = self.cache.pop(batch_id) + kwargs, run_reqs = prepare_decode_inputs(batch, self.radix_cache) + run_reqs: List[InferReq] = run_reqs + + logits = self.model.forward(**kwargs) + + all_has_no_constraint = all([not e.sampling_param.has_constraint_setting() for e in run_reqs]) + if not all_has_no_constraint: + mask = torch.ones_like(logits, dtype=torch.bool) + for i, run_obj in enumerate(run_reqs): + self._mask_req_out_token(i, run_obj, mask, logits[i]) + logits[mask] = -1000000.0 + + logits[logits == float("-inf")] = -1000000.0 + next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) + next_token_ids = next_token_ids.detach().cpu().numpy() + next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() + + for req_obj, next_token_id, next_token_logprob in zip(run_reqs, next_token_ids, next_token_logprobs): + req_obj: InferReq = req_obj + req_obj.cur_kv_len = len(req_obj.input_token_ids) + req_obj.input_token_ids.append(next_token_id) + req_obj.out_token_id_count[next_token_id] += 1 + req_obj.update_finish_status(self.eos_id) + + self._handle_req_ans(req_obj, next_token_id, next_token_logprob, output_dict) + + self.cache[batch.batch_id] = batch + return output_dict + + def _handle_req_ans(self, req_obj: InferReq, next_token_id, next_token_logprob, output_dict): + next_token_id = int(next_token_id) + if req_obj.sampling_param.guided_grammar is not None: + sample_params = req_obj.sampling_param + if sample_params.xgrammar_matcher.is_terminated(): + req_obj.finish_status = FinishStatus.FINISHED_STOP + else: + assert(sample_params.xgrammar_matcher.accept_token(next_token_id)) + + metadata = { + "id": next_token_id, + "logprob": float(next_token_logprob), + } + output_dict[req_obj.r_id] = ( + req_obj.req_status, + req_obj.cur_kv_len, + req_obj.get_output_len(), + [(next_token_id, metadata)], + req_obj.finish_status.value, + None, + ) + return + + def _mask_req_out_token(self, i, run_obj: InferReq, mask, logits): + sample_params = run_obj.sampling_param + if sample_params.guided_grammar is not None: + sample_params.xgrammar_matcher.fill_next_token_bitmask(self.xgrammar_token_bitmask) + xgr.apply_token_bitmask_inplace(logits, self.xgrammar_token_bitmask.to(logits.device)) + mask[i, :] = False + elif sample_params.allowed_token_ids is not None: + mask[i, sample_params.allowed_token_ids] = False + else: + mask[i, :] = False + return diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index d5f1a6106..0c1734c4d 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -14,6 +14,7 @@ RewardModelBackend, TokenHealingBackend, SimpleConstraintBackend, + XgrammarBackend, FirstTokenConstraintBackend, 
ContinuesBatchBackendForPrefillNode, ContinuesBatchBackendForDecodeNode, @@ -46,11 +47,14 @@ def exposed_init_model(self, kvargs): is_token_healing = kvargs.get("is_token_healing", False) is_first_token_constraint_mode = kvargs.get("is_first_token_constraint_mode", False) if kvargs.get("args", None) is not None: - is_simple_constraint_mode = kvargs.get("args", None).simple_constraint_mode + is_simple_constraint_mode = kvargs.get("args", None).output_constraint_mode == "outlines" + is_xgrammar_constraint_mode = kvargs.get("args", None).output_constraint_mode == "xgrammar" + assert not (is_simple_constraint_mode and is_xgrammar_constraint_mode), "only one constraint mode can be true" is_prefill_node = kvargs.get("args", None).run_mode == "prefill" is_decode_node = kvargs.get("args", None).run_mode == "decode" else: is_simple_constraint_mode = False + is_xgrammar_constraint_mode = False is_prefill_node = False is_decode_node = False # use_dynamic_prompt_cache = kvargs.get("use_dynamic_prompt_cache", False) @@ -72,6 +76,8 @@ def exposed_init_model(self, kvargs): self.backend = TokenHealingBackend() elif is_simple_constraint_mode: self.backend = SimpleConstraintBackend() + elif is_xgrammar_constraint_mode: + self.backend = XgrammarBackend() elif is_first_token_constraint_mode: self.backend = FirstTokenConstraintBackend() elif kvargs.get("dp_size", 1) > 1: diff --git a/lightllm/server/router/req_queue/__init__.py b/lightllm/server/router/req_queue/__init__.py index afb690c67..2f20dd54f 100644 --- a/lightllm/server/router/req_queue/__init__.py +++ b/lightllm/server/router/req_queue/__init__.py @@ -17,7 +17,8 @@ def build_req_queue(args, router, dp_size: int): queue_class = BeamContinuesBatchQueue if args.token_healing_mode: queue_class = ContinuesBatchQueue - if args.simple_constraint_mode: + # if args.simple_constraint_mode: + if args.output_constraint_mode != "none": queue_class = ContinuesBatchQueue if args.first_token_constraint_mode: queue_class = ContinuesBatchQueue diff --git a/lightllm/server/sampling_params.py b/lightllm/server/sampling_params.py index af8e6215f..7c84f4d5a 100644 --- a/lightllm/server/sampling_params.py +++ b/lightllm/server/sampling_params.py @@ -42,6 +42,7 @@ def __init__( # Whether to count input tokens for presence_penalty, frequency_penalty and repetition_penalty input_penalty: bool = DEFAULT_INPUT_PENALTY, regular_constraint: Optional[str] = None, # Regular expressions constrain the output. + guided_grammar: Optional[str] = None, # EBNF constrain the output. # If provided, the engine will construct a logits, # processor which only retains scores for the given token ids. Defaults to None. # allowed_token_ids only can be used in "--simple_constraint_mode" started server. 
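With both fields in place, the same output constraint can be phrased for either backend: regular_constraint carries an Outlines regex, while guided_grammar carries an xgrammar EBNF grammar. A hypothetical pair of roughly equivalent request parameter blocks, using only the field names defined above and the constraint styles from the tests in this series:

params_outlines = {"do_sample": False, "regular_constraint": r"(Yes|No)"}
params_xgrammar = {"do_sample": False, "guided_grammar": 'root ::= "Yes" | "No"'}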
@@ -76,6 +77,7 @@ def __init__( self.add_spaces_between_special_tokens = add_spaces_between_special_tokens self.print_eos_token = print_eos_token self.regular_constraint = regular_constraint + self.guided_grammar = guided_grammar self.allowed_token_ids = allowed_token_ids self.group_request_id = group_request_id self.move_kv_to_decode_node = move_kv_to_decode_node @@ -251,6 +253,7 @@ def to_dict(self): ret["best_of"] = self.best_of ret["input_penalty"] = self.input_penalty ret["regular_constraint"] = self.regular_constraint + ret["guided_grammar"] = self.guided_grammar ret["allowed_token_ids"] = self.allowed_token_ids ret["move_kv_to_decode_node"] = self.move_kv_to_decode_node return ret diff --git a/test/test_xgrammar_constraint.py b/test/test_xgrammar_constraint.py new file mode 100644 index 000000000..2543c07c0 --- /dev/null +++ b/test/test_xgrammar_constraint.py @@ -0,0 +1,78 @@ +import time +import requests +import json +import threading + +""" +python -m lightllm.server.api_server --model_dir /Meta-Llama-3-8B-Instruct \ + --host 0.0.0.0 \ + --port 8017 \ + --tp 1 \ + --max_total_token_num 100000 \ + --simple_constraint_mode \ + --use_dynamic_prompt_cache +""" + + +class RequestThread(threading.Thread): + def __init__(self, url, headers, data): + threading.Thread.__init__(self) + self.url = url + self.headers = headers + self.data = data + + def run(self): + response = requests.post(self.url, headers=self.headers, data=json.dumps(self.data)) + if response.status_code == 200: + print(response.json()) + else: + print("Error:", response.status_code, response.text) + + +url = "http://localhost:9999/generate" +headers = {"Content-Type": "application/json"} +json_grammar_ebnf_str = r""" +root ::= basic_array | basic_object +basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object +basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? +basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? 
+basic_string ::= (([\"] basic_string_1 [\"])) +basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1 +escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] +basic_boolean ::= "true" | "false" +basic_null ::= "null" +basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]" +basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}" +ws ::= [ \n\t]* +""" + +for i in range(1): + data = { + "inputs": "Introduce yourself in JSON briefly.", + # 'temperature': 0.1, + "parameters": { + "do_sample": False, + "guided_grammar": json_grammar_ebnf_str, + "max_new_tokens": 200, + }, + } + thread = RequestThread(url, headers, data) + thread.start() + +time.sleep(2) + +for i in range(20): + data = { + "inputs": "12-(25+16)*7=", + "parameters": { + "do_sample": False, + "ignore_eos": True, + "max_new_tokens": 200, + "guided_grammar": r"""root ::= (expr "=" term)+ +expr ::= term ([-+*/] term)* +term ::= num | "(" expr ")" +num ::= [0-9]+""", + }, + } + thread = RequestThread(url, headers, data) + thread.start() \ No newline at end of file From e9739631e22b13cbe3ba7e8e80dd687126a2f179 Mon Sep 17 00:00:00 2001 From: Chen Junyi Date: Tue, 7 Jan 2025 19:15:33 +0800 Subject: [PATCH 02/14] refine --- lightllm/server/api_cli.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 3d75e255b..18c7171f9 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -154,7 +154,6 @@ def make_argument_parser() -> argparse.ArgumentParser: choices=["outlines", "xgrammar", "none"], default="none", help="set the output constraint backend, none means no output constraint",) - parser.add_argument("--guided_grammar", type=str, default=None, help="output constraint mode config") parser.add_argument( "--first_token_constraint_mode", action="store_true", From 60e3d79f6bbd8eeb7287204a3312ce1a6ed53828 Mon Sep 17 00:00:00 2001 From: Chen Junyi Date: Tue, 7 Jan 2025 19:19:56 +0800 Subject: [PATCH 03/14] fix code style --- .../mode_backend/continues_batch/impl_for_xgrammar_mode.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_xgrammar_mode.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_xgrammar_mode.py index c9b99550e..ce4b84269 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_xgrammar_mode.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_xgrammar_mode.py @@ -21,7 +21,9 @@ def __init__(self) -> None: super().__init__() def init_custom(self): - self.tokenizer = get_tokenizer(self.args.model_dir, self.args.tokenizer_mode, trust_remote_code=self.args.trust_remote_code) + self.tokenizer = get_tokenizer( + self.args.model_dir, self.args.tokenizer_mode, trust_remote_code=self.args.trust_remote_code + ) tokenizer_info = xgr.TokenizerInfo.from_huggingface(self.tokenizer) self.xgrammar_compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8) @@ -111,7 +113,7 @@ def _handle_req_ans(self, req_obj: InferReq, next_token_id, next_token_logprob, if sample_params.xgrammar_matcher.is_terminated(): req_obj.finish_status = FinishStatus.FINISHED_STOP else: - assert(sample_params.xgrammar_matcher.accept_token(next_token_id)) + assert sample_params.xgrammar_matcher.accept_token(next_token_id) metadata = { "id": next_token_id, From 
283111b5c2c7d9c1e8df735b3153e0b6f7198d78 Mon Sep 17 00:00:00 2001 From: Chen Junyi Date: Tue, 7 Jan 2025 19:25:37 +0800 Subject: [PATCH 04/14] fix code style --- lightllm/server/api_cli.py | 9 ++++++--- lightllm/server/router/model_infer/infer_batch.py | 6 ++++-- lightllm/server/router/model_infer/model_rpc.py | 5 ++++- test/test_xgrammar_constraint.py | 14 ++------------ 4 files changed, 16 insertions(+), 18 deletions(-) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 18c7171f9..761af79cb 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -149,11 +149,14 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument("--beam_mode", action="store_true", help="use beamsearch mode") parser.add_argument("--diverse_mode", action="store_true", help="diversity generation mode") parser.add_argument("--token_healing_mode", action="store_true", help="code model infer mode") - # parser.add_argument("--simple_constraint_mode", action="store_true", help="output constraint mode") - parser.add_argument("--output_constraint_mode", type=str, + + parser.add_argument( + "--output_constraint_mode", + type=str, choices=["outlines", "xgrammar", "none"], default="none", - help="set the output constraint backend, none means no output constraint",) + help="set the output constraint backend, none means no output constraint", + ) parser.add_argument( "--first_token_constraint_mode", action="store_true", diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 280142bb4..98205f104 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -64,11 +64,11 @@ def __init__( # constraint states self.regular_constraint = regular_constraint self.guided_grammar = guided_grammar - self.fsm_current_state: int = 0 self.allowed_token_ids = allowed_token_ids # Outlines constraint states self.regex_guide = None + self.fsm_current_state: int = 0 # Xgrammar constraint states self.xgrammar_compiled_grammar = None @@ -84,7 +84,9 @@ def __init__( return def has_constraint_setting(self) -> bool: - return self.regular_constraint is not None or self.allowed_token_ids is not None or self.guided_grammar is not None + return ( + self.regular_constraint is not None or self.allowed_token_ids is not None or self.guided_grammar is not None + ) class InferReq: diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index 0c1734c4d..a7a65cc7c 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -49,7 +49,9 @@ def exposed_init_model(self, kvargs): if kvargs.get("args", None) is not None: is_simple_constraint_mode = kvargs.get("args", None).output_constraint_mode == "outlines" is_xgrammar_constraint_mode = kvargs.get("args", None).output_constraint_mode == "xgrammar" - assert not (is_simple_constraint_mode and is_xgrammar_constraint_mode), "only one constraint mode can be true" + assert not ( + is_simple_constraint_mode and is_xgrammar_constraint_mode + ), "only one constraint mode can be true" is_prefill_node = kvargs.get("args", None).run_mode == "prefill" is_decode_node = kvargs.get("args", None).run_mode == "decode" else: @@ -77,6 +79,7 @@ def exposed_init_model(self, kvargs): elif is_simple_constraint_mode: self.backend = SimpleConstraintBackend() elif is_xgrammar_constraint_mode: + # now we prioritize simple_constraint_mode(Outlines) 
self.backend = XgrammarBackend() elif is_first_token_constraint_mode: self.backend = FirstTokenConstraintBackend() diff --git a/test/test_xgrammar_constraint.py b/test/test_xgrammar_constraint.py index 2543c07c0..16e5203d1 100644 --- a/test/test_xgrammar_constraint.py +++ b/test/test_xgrammar_constraint.py @@ -3,16 +3,6 @@ import json import threading -""" -python -m lightllm.server.api_server --model_dir /Meta-Llama-3-8B-Instruct \ - --host 0.0.0.0 \ - --port 8017 \ - --tp 1 \ - --max_total_token_num 100000 \ - --simple_constraint_mode \ - --use_dynamic_prompt_cache -""" - class RequestThread(threading.Thread): def __init__(self, url, headers, data): @@ -51,7 +41,7 @@ def run(self): "inputs": "Introduce yourself in JSON briefly.", # 'temperature': 0.1, "parameters": { - "do_sample": False, + "do_sample": False, "guided_grammar": json_grammar_ebnf_str, "max_new_tokens": 200, }, @@ -75,4 +65,4 @@ def run(self): }, } thread = RequestThread(url, headers, data) - thread.start() \ No newline at end of file + thread.start() From 64ff827d3ab1a3c812e98f2eb17e5c11196f23dd Mon Sep 17 00:00:00 2001 From: Chen Junyi Date: Tue, 21 Jan 2025 11:09:17 +0800 Subject: [PATCH 05/14] add json schema support --- .pre-commit-config.yaml | 2 +- .../server/router/model_infer/infer_batch.py | 10 +- .../continues_batch/impl_for_xgrammar_mode.py | 12 +- lightllm/server/sampling_params.py | 4 + test/format_out/test_xgrammar_constraint.py | 139 ++++++++++++++++++ test/test_xgrammar_constraint.py | 13 ++ 6 files changed, 173 insertions(+), 7 deletions(-) create mode 100644 test/format_out/test_xgrammar_constraint.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 678ac8b8f..446e69da7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,4 +11,4 @@ repos: hooks: - id: flake8 additional_dependencies: [flake8-typing-imports==1.9.0] - args: ['--config=.flake8', '--max-line-length=120', '--ignore=TYP001, E722, C901, E203, E266, E402, E302, E241, E902, E731, F403, E701, F405, F401, W292, W293, W503, W606'] \ No newline at end of file + args: ['--config=.flake8', '--max-line-length=120', '--ignore=TYP001, E722, C901, E203, E266, E402, E302, E241, E902, E731, F403, E701, F405, F401, W292, W293, W503, W606, W191, E101'] \ No newline at end of file diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 98205f104..9d081d2e1 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -1,13 +1,14 @@ import os import copy import time +from pydantic import BaseModel import torch import torch.distributed as dist import numpy as np import collections from dataclasses import dataclass, field -from typing import List, Dict, Tuple, Optional +from typing import List, Dict, Tuple, Optional, Union from lightllm.common.req_manager import ReqManager from lightllm.common.mem_manager import MemoryManager from lightllm.utils.infer_utils import mark_start, mark_end @@ -41,6 +42,7 @@ def __init__( input_penalty: bool = False, regular_constraint: Optional[str] = None, guided_grammar: Optional[str] = None, + guided_json: Optional[Union[str, dict, BaseModel]] = None, allowed_token_ids: Optional[List[int]] = None, move_kv_to_decode_node: Optional[bool] = None, ) -> None: @@ -64,6 +66,7 @@ def __init__( # constraint states self.regular_constraint = regular_constraint self.guided_grammar = guided_grammar + self.guided_json = guided_json self.allowed_token_ids = allowed_token_ids # 
Outlines constraint states @@ -85,7 +88,10 @@ def __init__( def has_constraint_setting(self) -> bool: return ( - self.regular_constraint is not None or self.allowed_token_ids is not None or self.guided_grammar is not None + self.regular_constraint is not None + or self.allowed_token_ids is not None + or self.guided_grammar is not None + or self.guided_json is not None ) diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_xgrammar_mode.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_xgrammar_mode.py index ce4b84269..16f499ae3 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_xgrammar_mode.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_xgrammar_mode.py @@ -1,3 +1,4 @@ +import json import os import shutil import torch @@ -32,8 +33,8 @@ def init_custom(self): eos_token_ids = [] eos_token_ids.append(self.tokenizer.eos_token_id) eos_token_ids.extend(self.args.eos_id) - self.tokenizer.eos_token_ids = eos_token_ids - logger.info(f"eos_ids {self.tokenizer.eos_token_ids}") + # self.tokenizer.eos_token_ids = eos_token_ids + # logger.info(f"eos_ids {self.tokenizer.eos_token_ids}") return @calculate_time(show=False, min_cost_ms=300) @@ -52,6 +53,9 @@ def prefill_batch(self, batch_id): if sample_params.guided_grammar is not None: xgrammar_compiled_grammar = self.xgrammar_compiler.compile_grammar(sample_params.guided_grammar) sample_params.xgrammar_matcher = xgr.GrammarMatcher(xgrammar_compiled_grammar) + elif sample_params.guided_json is not None: + xgrammar_compiled_grammar = self.xgrammar_compiler.compile_json_schema(sample_params.guided_json) + sample_params.xgrammar_matcher = xgr.GrammarMatcher(xgrammar_compiled_grammar) self._mask_req_out_token(i, run_obj, mask, logics[i]) # fix the logics with -inf to a large negative value @@ -108,7 +112,7 @@ def decode_batch(self, batch_id): def _handle_req_ans(self, req_obj: InferReq, next_token_id, next_token_logprob, output_dict): next_token_id = int(next_token_id) - if req_obj.sampling_param.guided_grammar is not None: + if req_obj.sampling_param.guided_grammar is not None or req_obj.sampling_param.guided_json is not None: sample_params = req_obj.sampling_param if sample_params.xgrammar_matcher.is_terminated(): req_obj.finish_status = FinishStatus.FINISHED_STOP @@ -131,7 +135,7 @@ def _handle_req_ans(self, req_obj: InferReq, next_token_id, next_token_logprob, def _mask_req_out_token(self, i, run_obj: InferReq, mask, logits): sample_params = run_obj.sampling_param - if sample_params.guided_grammar is not None: + if sample_params.guided_grammar is not None or sample_params.guided_json is not None: sample_params.xgrammar_matcher.fill_next_token_bitmask(self.xgrammar_token_bitmask) xgr.apply_token_bitmask_inplace(logits, self.xgrammar_token_bitmask.to(logits.device)) mask[i, :] = False diff --git a/lightllm/server/sampling_params.py b/lightllm/server/sampling_params.py index 7c84f4d5a..b102ea21a 100644 --- a/lightllm/server/sampling_params.py +++ b/lightllm/server/sampling_params.py @@ -1,6 +1,7 @@ """Sampling parameters for text generation.""" import os from typing import List, Optional, Union, Tuple +from pydantic import BaseModel from transformers import GenerationConfig from .req_id_generator import MAX_BEST_OF @@ -43,6 +44,7 @@ def __init__( input_penalty: bool = DEFAULT_INPUT_PENALTY, regular_constraint: Optional[str] = None, # Regular expressions constrain the output. 
guided_grammar: Optional[str] = None, # EBNF constrain the output. + guided_json: Optional[Union[str, dict, BaseModel]] = None, # JSON schema constrain the output. # If provided, the engine will construct a logits, # processor which only retains scores for the given token ids. Defaults to None. # allowed_token_ids only can be used in "--simple_constraint_mode" started server. @@ -78,6 +80,7 @@ def __init__( self.print_eos_token = print_eos_token self.regular_constraint = regular_constraint self.guided_grammar = guided_grammar + self.guided_json = guided_json self.allowed_token_ids = allowed_token_ids self.group_request_id = group_request_id self.move_kv_to_decode_node = move_kv_to_decode_node @@ -254,6 +257,7 @@ def to_dict(self): ret["input_penalty"] = self.input_penalty ret["regular_constraint"] = self.regular_constraint ret["guided_grammar"] = self.guided_grammar + ret["guided_json"] = self.guided_json ret["allowed_token_ids"] = self.allowed_token_ids ret["move_kv_to_decode_node"] = self.move_kv_to_decode_node return ret diff --git a/test/format_out/test_xgrammar_constraint.py b/test/format_out/test_xgrammar_constraint.py new file mode 100644 index 000000000..9d2ad0ef8 --- /dev/null +++ b/test/format_out/test_xgrammar_constraint.py @@ -0,0 +1,139 @@ +import time +import requests +import json +import threading +from transformers import AutoTokenizer + +tokenizer = AutoTokenizer.from_pretrained("/mnt/nvme0/chenjunyi/models/nb10_w8/") + + +class RequestThread(threading.Thread): + def __init__(self, url, headers, data): + threading.Thread.__init__(self) + self.url = url + self.headers = headers + self.data = data + + def run(self): + response = requests.post(self.url, headers=self.headers, data=json.dumps(self.data)) + if response.status_code == 200: + print(response.json()) + else: + print("Error:", response.status_code, response.text) + + +url = "http://localhost:9999/generate" +headers = {"Content-Type": "application/json"} +json_grammar_ebnf_str = r""" +root ::= basic_array | basic_object +basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object +basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? +basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? 
+basic_string ::= (([\"] basic_string_1 [\"])) +basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1 +escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] +basic_boolean ::= "true" | "false" +basic_null ::= "null" +basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]" +basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}" +ws ::= [ \n\t]* +""" + +json_schema_str = r""" +{ + "type": "array", + "items": { + "type": "object", + "properties": { + "金额": { + "type": "number" + }, + "标题": { + "type": "string" + }, + "类型": { + "type": "string" + }, + "大类": { + "type": "string" + }, + "小类": { + "type": "string" + }, + "日期": { + "type": "string" + }, + "时间": { + "type": "string" + } + }, + "required": [ + "金额", + "标题", + "类型", + "大类", + "小类", + "时间" + ] + } +} +""" + +person_schema = r"""{ + "title": "Person", + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "age": { + "type": "integer", + } + }, + "required": ["name", "age"] +} +""" + +system_prompt = open("system.md", "r").read() +user_input = open("user.md", "r").read() + +messages = [ + { + "role": "system", + "content": system_prompt, + }, + {"role": "user", "content": user_input}, +] + +inputs = tokenizer.apply_chat_template(messages, tokenize=False) + +for i in range(1): + data = { + "inputs": inputs, + # 'temperature': 0.1, + "parameters": { + "do_sample": False, + "guided_json": json_schema_str, + "max_new_tokens": 200, + }, + } + thread = RequestThread(url, headers, data) + thread.start() + +# time.sleep(2) + +# for i in range(20): +# data = { +# "inputs": "12-(25+16)*7=", +# "parameters": { +# "do_sample": False, +# "ignore_eos": True, +# "max_new_tokens": 200, +# "guided_grammar": r"""root ::= (expr "=" term)+ +# expr ::= term ([-+*/] term)* +# term ::= num | "(" expr ")" +# num ::= [0-9]+""", +# }, +# } +# thread = RequestThread(url, headers, data) +# thread.start() diff --git a/test/test_xgrammar_constraint.py b/test/test_xgrammar_constraint.py index 16e5203d1..dd0e5d8b8 100644 --- a/test/test_xgrammar_constraint.py +++ b/test/test_xgrammar_constraint.py @@ -3,6 +3,19 @@ import json import threading +""" +python -m lightllm.server.api_server --model_dir /mnt/nvme0/chenjunyi/models/nb10_w8/ \ + --host 0.0.0.0 \ + --port 9999 \ + --tp 1 \ + --nccl_port 65535 \ + --max_req_total_len 200000 \ + --max_total_token_num 400000 \ + --data_type bf16 \ + --trust_remote_code \ + --output_constraint_mode xgrammar +""" + class RequestThread(threading.Thread): def __init__(self, url, headers, data): From ed05d5505de3f69939d989535b29eaecf143fb5e Mon Sep 17 00:00:00 2001 From: Chen Junyi Date: Tue, 21 Jan 2025 11:10:22 +0800 Subject: [PATCH 06/14] add json schema support --- test/format_out/test_constraint_server.py | 67 +++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 test/format_out/test_constraint_server.py diff --git a/test/format_out/test_constraint_server.py b/test/format_out/test_constraint_server.py new file mode 100644 index 000000000..62b622031 --- /dev/null +++ b/test/format_out/test_constraint_server.py @@ -0,0 +1,67 @@ +import time +import requests +import json +import threading + +""" +python -m lightllm.server.api_server --model_dir /Meta-Llama-3-8B-Instruct \ + --host 0.0.0.0 \ + --port 8017 \ + --tp 1 \ + --max_total_token_num 100000 \ + --simple_constraint_mode \ + --use_dynamic_prompt_cache +""" + + +class RequestThread(threading.Thread): + def 
__init__(self, url, headers, data): + threading.Thread.__init__(self) + self.url = url + self.headers = headers + self.data = data + + def run(self): + response = requests.post(self.url, headers=self.headers, data=json.dumps(self.data)) + if response.status_code == 200: + print(response.json()) + else: + print("Error:", response.status_code, response.text) + + +url = "http://localhost:8017/generate" +headers = {"Content-Type": "application/json"} + +for i in range(1): + data = { + "inputs": "(100+1+3)*2=", + # 'temperature': 0.1, + "parameters": {"do_sample": False, "regular_constraint": r"-?\d+"}, + } + thread = RequestThread(url, headers, data) + thread.start() + +time.sleep(2) + +for i in range(20): + data = { + "inputs": "Are dog a man? ", + "parameters": { + "do_sample": False, + "ignore_eos": True, + "max_new_tokens": 200, + "regular_constraint": r"(Yes|No) Reason is [a-zA-Z\s]+", + }, + } + thread = RequestThread(url, headers, data) + thread.start() + +time.sleep(10) + +for i in range(20): + data = { + "inputs": "Are dog a man? ", + "parameters": {"do_sample": False, "ignore_eos": True, "max_new_tokens": 200, "allowed_token_ids": [2, 3]}, + } + thread = RequestThread(url, headers, data) + thread.start() From 09a67c73ec3929bb4feee00eb49396239d4d48ba Mon Sep 17 00:00:00 2001 From: Junyi Chen Date: Tue, 25 Feb 2025 15:41:27 +0800 Subject: [PATCH 07/14] save --- lightllm/server/core/objs/sampling_params.py | 57 ++- .../server/router/model_infer/infer_batch.py | 7 +- .../continues_batch/impl_for_xgrammar_mode.py | 93 ++-- test/format_out/system.md | 458 ++++++++++++++++++ test/format_out/test_xgrammar_constraint.py | 7 +- test/format_out/user.md | 5 + 6 files changed, 569 insertions(+), 58 deletions(-) create mode 100644 test/format_out/system.md create mode 100644 test/format_out/user.md diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index 2c041c570..86b35c365 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -12,7 +12,8 @@ ALLOWED_TOKEN_IDS_MAX_LENGTH = int(os.getenv("LIGHTLLM_ALLOWED_TOKEN_IDS_MAX_LENGTH", 256)) MAX_STOP_SEQUENCES = int(os.getenv("LIGHTLLM_MAX_STOP_SEQUENCES", 10)) REGULAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_REGULAR_CONSTRAINT_MAX_LENGTH", 2048)) - +GRAMMAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_GRAMMAR_CONSTRAINT_MAX_LENGTH", 2048)) +JSON_SCHEMA_MAX_LENGTH = int(os.getenv("LIGHTLLM_JSON_SCHEMA_MAX_LENGTH", 2048)) class StopSequence(ctypes.Structure): _pack_ = 4 @@ -98,6 +99,46 @@ def to_str(self): return bytes(self.constraint[0 : self.length]).decode("utf-8").rstrip("\x00") +class GuidedGrammar(ctypes.Structure): + _pack_ = 4 + _fields_ = [ + ("constraint", ctypes.c_byte * GRAMMAR_CONSTRAINT_MAX_LENGTH), + ("length", ctypes.c_int), + ] + + def initialize(self, constraint: str): + constraint_bytes = constraint.encode("utf-8") + assert len(constraint_bytes) < GRAMMAR_CONSTRAINT_MAX_LENGTH, "Guided grammar is too long." 
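+        # The raw UTF-8 bytes are copied below into a fixed-size ctypes array,
+        # so the assert above guards against overflowing GRAMMAR_CONSTRAINT_MAX_LENGTH.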
+ + ctypes.memmove(self.constraint, constraint_bytes, len(constraint_bytes)) + self.length = len(constraint_bytes) + # TODO: Later we can add a grammar parser to check the grammar + return + + def to_str(self): + return bytes(self.constraint[0 : self.length]).decode("utf-8").rstrip("\x00") + + +class GuidedJsonSchema(ctypes.Structure): + _pack_ = 4 + _fields_ = [ + ("constraint", ctypes.c_byte * JSON_SCHEMA_MAX_LENGTH), + ("length", ctypes.c_int), + ] + + def initialize(self, constraint: str): + constraint_bytes = constraint.encode("utf-8") + assert len(constraint_bytes) < JSON_SCHEMA_MAX_LENGTH, "Guided json schema is too long." + + ctypes.memmove(self.constraint, constraint_bytes, len(constraint_bytes)) + self.length = len(constraint_bytes) + # TODO: Later we can add a json schema parser to check the schema + return + + def to_str(self): + return bytes(self.constraint[0 : self.length]).decode("utf-8").rstrip("\x00") + + class AllowedTokenIds(ctypes.Structure): _pack_ = 4 _fields_ = [ @@ -191,6 +232,8 @@ class SamplingParams(ctypes.Structure): # Whether to count input tokens for presence_penalty, frequency_penalty and repetition_penalty ("input_penalty", ctypes.c_bool), ("regular_constraint", RegularConstraint), + ("guided_grammar", GuidedGrammar), + ("guided_json", GuidedJsonSchema), # If provided, the engine will construct a logits, # processor which only retains scores for the given token ids. Defaults to None. # allowed_token_ids only can be used in "--simple_constraint_mode" started server. @@ -251,6 +294,16 @@ def init(self, tokenizer, **kwargs): self.regular_constraint = RegularConstraint() self.regular_constraint.initialize(regular_constraint) + # Initialize guided_grammar + guided_grammar = kwargs.get("guided_grammar", "") + self.guided_grammar = GuidedGrammar() + self.guided_grammar.initialize(guided_grammar) + + # Initialize guided_json + guided_json = kwargs.get("guided_json", "") + self.guided_json = GuidedJsonSchema() + self.guided_json.initialize(guided_json) + # Initialize stop_sequence_groups stop_sequences = kwargs.get("stop_sequences", []) self.stop_sequences = StopSequenceGroups() @@ -342,6 +395,8 @@ def to_dict(self): "best_of": self.best_of, "input_penalty": self.input_penalty, "regular_constraint": self.regular_constraint.to_str(), + "guided_grammar": self.guided_grammar.to_str(), + "guided_json": self.guided_json.to_str(), "allowed_token_ids": self.allowed_token_ids.to_list(), "group_request_id": self.group_request_id, "move_kv_to_decode_node": self.move_kv_to_decode_node.to_dict(), diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 2b119b180..1bafe217d 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -195,10 +195,15 @@ def __init__( # output constraint states self.regular_constraint = self.shm_param.regular_constraint.to_str() + self.guided_grammar = self.shm_param.guided_grammar.to_str() + self.guided_json = self.shm_param.guided_json.to_str() if len(self.regular_constraint) == 0: self.regular_constraint = None + if len(self.guided_grammar) == 0: + self.guided_grammar = None + if len(self.guided_json) == 0: + self.guided_json = None - self.regex_guide = None self.fsm_current_state: int = 0 self.allowed_token_ids = self.shm_param.allowed_token_ids.to_list() if len(self.allowed_token_ids) == 0: diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_xgrammar_mode.py 
b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_xgrammar_mode.py index 16f499ae3..2f4b87636 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_xgrammar_mode.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_xgrammar_mode.py @@ -1,17 +1,17 @@ -import json import os import shutil import torch +from typing import List, Tuple import xgrammar as xgr from .impl import ContinuesBatchBackend -from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end -from lightllm.server.io_struct import FinishStatus -from lightllm.server.router.model_infer.infer_batch import InferBatch, InferReq, InferSamplingParams from .pre_process import prepare_prefill_inputs, prepare_decode_inputs from .post_process import sample + +from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end +from lightllm.server.core.objs import FinishStatus +from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq, InferSamplingParams from lightllm.server.tokenizer import get_tokenizer -from typing import List from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -38,11 +38,10 @@ def init_custom(self): return @calculate_time(show=False, min_cost_ms=300) - def prefill_batch(self, batch_id): - output_dict = {} - batch: InferBatch = self.cache.pop(batch_id) - kwargs, run_reqs = prepare_prefill_inputs(batch, self.radix_cache, self.model.mem_manager) - run_reqs: List[InferReq] = run_reqs + def prefill(self, reqs: List[Tuple]): + req_ids = self._init_reqs(reqs) + kwargs, run_reqs = prepare_prefill_inputs(req_ids, is_multimodal=self.is_multimodal) + run_reqs: List[InferReq] = reqs logics = self.model.forward(**kwargs) @@ -66,22 +65,13 @@ def prefill_batch(self, batch_id): next_token_ids = next_token_ids.detach().cpu().numpy() next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() - for req_obj, next_token_id, next_token_logprob in zip(run_reqs, next_token_ids, next_token_logprobs): - req_obj.cur_kv_len = len(req_obj.input_token_ids) - req_obj.input_token_ids.append(next_token_id) - req_obj.out_token_id_count[next_token_id] += 1 - req_obj.update_finish_status(self.eos_id) + self.post_handel(run_reqs, next_token_ids, next_token_logprobs) - self._handle_req_ans(req_obj, next_token_id, next_token_logprob, output_dict) - - self.cache[batch.batch_id] = batch - return output_dict + return @calculate_time(show=True, min_cost_ms=200) - def decode_batch(self, batch_id): - output_dict = {} - batch: InferBatch = self.cache.pop(batch_id) - kwargs, run_reqs = prepare_decode_inputs(batch, self.radix_cache) + def decode(self): + kwargs, run_reqs = prepare_decode_inputs(g_infer_context.infer_req_ids) run_reqs: List[InferReq] = run_reqs logits = self.model.forward(**kwargs) @@ -98,39 +88,40 @@ def decode_batch(self, batch_id): next_token_ids = next_token_ids.detach().cpu().numpy() next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() + self.post_handel(run_reqs, next_token_ids, next_token_logprobs) + return + + def post_handel(self, run_reqs: List[InferReq], next_token_ids, next_token_logprobs): + finished_req_ids = [] + for req_obj, next_token_id, next_token_logprob in zip(run_reqs, next_token_ids, next_token_logprobs): + # prefill and decode is same req_obj: InferReq = req_obj - req_obj.cur_kv_len = len(req_obj.input_token_ids) - req_obj.input_token_ids.append(next_token_id) + req_obj.cur_kv_len = req_obj.get_cur_total_len() + + 
req_obj.set_next_gen_token_id(next_token_id, next_token_logprob) + req_obj.cur_output_len += 1 + req_obj.out_token_id_count[next_token_id] += 1 req_obj.update_finish_status(self.eos_id) - self._handle_req_ans(req_obj, next_token_id, next_token_logprob, output_dict) - - self.cache[batch.batch_id] = batch - return output_dict - - def _handle_req_ans(self, req_obj: InferReq, next_token_id, next_token_logprob, output_dict): - next_token_id = int(next_token_id) - if req_obj.sampling_param.guided_grammar is not None or req_obj.sampling_param.guided_json is not None: - sample_params = req_obj.sampling_param - if sample_params.xgrammar_matcher.is_terminated(): - req_obj.finish_status = FinishStatus.FINISHED_STOP - else: - assert sample_params.xgrammar_matcher.accept_token(next_token_id) - - metadata = { - "id": next_token_id, - "logprob": float(next_token_logprob), - } - output_dict[req_obj.r_id] = ( - req_obj.req_status, - req_obj.cur_kv_len, - req_obj.get_output_len(), - [(next_token_id, metadata)], - req_obj.finish_status.value, - None, - ) + if req_obj.finish_status.is_finished() or req_obj.shm_req.router_aborted: + finished_req_ids.append(req_obj.shm_req.request_id) + + if self.tp_rank < self.dp_size: + # shm_cur_kv_len shm_cur_output_len 是 router 调度进程需要读的信息 + # finish_token_index finish_status candetoken_out_len 是 + # detokenization 进程需要的信息,注意这些变量的写入顺序避免异步协同问题。 + req_obj.shm_req.shm_cur_kv_len = req_obj.cur_kv_len + req_obj.shm_req.shm_cur_output_len = req_obj.cur_output_len + + if req_obj.finish_status.is_finished(): + req_obj.shm_req.finish_token_index = req_obj.get_cur_total_len() - 1 + req_obj.shm_req.finish_status = req_obj.finish_status + + req_obj.shm_req.candetoken_out_len = req_obj.cur_output_len + + g_infer_context.filter(finished_req_ids) return def _mask_req_out_token(self, i, run_obj: InferReq, mask, logits): diff --git a/test/format_out/system.md b/test/format_out/system.md new file mode 100644 index 000000000..dfae5b96d --- /dev/null +++ b/test/format_out/system.md @@ -0,0 +1,458 @@ +你是一款记账助手的抽取器,擅长从用户的输入中提取关键信息并输出为 JSON 格式。根据用户输入内容、结合提交时间和地点信息,判断其是否为记账任务或笔记任务,并提取相关信息。 + +用户的请求中的可能包含以下信息: + +- 时间:即用户提交请求的当前时间 +- 地点:即用户当前的 GPS 定位位置 +- 内容:即用户输入的文字内容,可能包含误识别的错别字,请尽量纠正 + + +## 需要提取的参数 + +- **type**:任务类型。如果是记账则为 `accounting`,笔记为 `note`。用户输入的内容中有明确金额发生的时候才能识别为记账 `accounting`,否则为笔记`note`。 + +- **direction**:收支方向,根据用户内容中描述的行为判断。 + - `outcome`:表示支出,例如,"花了"、"花费"、"付了"、"买了" 等。 + - `income`:表示收入,例如,"赚了"、"收到了"、"得到了"、"中了" 等。 + +- **category1**:一级分类。从给定的支出分类中为支出的账单选择对应的一级分类,如:"餐饮"、"交通"、"娱乐"等;从给定的收入分类中为收入的账单选择对应的一级分类,如:"工资"、"兼职"、"生活费"等,"收入"本身不是一级分类,必须严格匹配,收入的账单选择对应的一级分类时,不能选择属于支出的一级分类。当无法找到明显的匹配项时,选择对应的 "其他"分类,禁止自行输出不存在于给定分类中的分类。 + +- **category2**:二级分类。从对应的一级分类下的二级分类中选择,必须严格匹配。当无法找到明显的匹配项时,选择对应的 "其他"分类,禁止自行输出不存在于给定分类中的分类。 + +- **amount**:金额数值。提取用户输入中的金额,支持中文数字和阿拉伯数字,若出现两者混合的情况,请准确提取、组合,结果为保留两位小数的阿拉伯数字,不需要加单位。 + +- **date**:日期,格式为 `YYYY-MM-DD`。根据用户输入的日期信息提取出date值; +用户可能会输入 "昨天"、"上周五" 等,需要结合提交日期进行计算。如果请求问题的中提到的是上周几,你需要根据当天是周几(星期几)进行推理计算,推算出过去14天里符合条件的日期;如果请求问题的中提到的是上上周几,你需要根据当天是周几(星期几)进行推理计算,推算出过去21天里符合条件的日期,依此类推; +计算结果只能往前推算出过去已结束的日期,不能推算出未来未开始的日期; +如果用户输入中未提及date信息,默认为提交请求的当前日期。 + +- **time**:时间,格式为 `HH:MM`。根据用户输入的具体或模糊时间提取,模糊时间按照下文的映射规则确定具体时间。如果用户输入中提取不到时间,则需要结合用户输入或计算出的日期判断该填 `null`还是当前时间,判断规则是:若用户输入或解析到的日期非当天,则填`null`,若用户未提及日期,则该条的日期和时间为当前提交的日期和时间。 + +- **title**:对该笔记账的简短描述,尽量不超过 10 个字。可以省略金额、时间、分类等重复信息,突出主要内容。 + + +## 处理规则 + +1. 判断任务类型:如果输入中包含金额,则为记账任务,否则为笔记任务。记账任务可能包含多条,需要根据内容进行拆分,最多拆为20条。笔记任务只有一条,不再进行拆分。 +2. 
按任务类型提取对应的相关信息: + - 记账任务: + - direction, category1, category2, amount, date, time, title + - 笔记任务: + - date, time, title +3. 对于笔记任务,按用户提交时间判断日期,并提取时间信息,无需进行额外解析推断。 +4. 对于记账任务,先根据语义判断收支方向,然后据此判断分类,分类必须从以下的分类信息中选择,不能随意填写。当无法找到明显的匹配项时,选择对应的 "其他"分类,禁止自行输出不存在于给定分类中的分类。 +5. 对于记账任务的日期的解析,根据内容中的时间表达(如"昨天"、"上周五")结合当前日期计算实际日期。对于星期的表达,如"周一"、"周末"等,默认按过去七天内的日期计算。周末视为周六。每周开始按周一算,即周一到周日。 +6. 对于记账任务中具体提及的时间,提取为 `HH:MM` 格式。模糊时间可以按照以下规则映射(如果时间无法推断则 time 为 null): + - 凌晨 / 黎明 / 拂晓 / 早起:00:00-06:00(默认01:00) + - 清晨 / 早晨 / 早上 / 大清早 / 一大早 / 日出 / 早间:06:00-10:00(默认07:00) + - 上午 / 午前 / 上半天:09:00-12:00(默认10:00) + - 中午 / 正午 / 午时 / 晌午:12:00-14:00(默认13:00) + - 下午 / 午后 / 下半天:13:00-18:00(默认14:00) + - 傍晚 / 黄昏 / 日落 / 傍黑 / 薄暮:17:00-20:00(默认18:00) + - 晚上 / 夜晚:19:00-23:59(默认20:00) + - 深夜 / 半夜 / 大半夜 / 午夜 / 夜半:22:00-02:00(默认23:00) + - 夜间 / 夜里 / 黑夜:19:00-06:00(默认23:00) + - 白天:06:00-18:00(默认12:00) +7. title 标题为用户输入的内容的缩写,应简洁明了,突出主要信息,不重复金额、时间等已提取的内容。 +8. 必须按照标准正确的JSON格式输出,即一个 JSON Array,不要有 ```json 这种 Markdown 格式的 code fence,不要包含多余的标记、注释或解释性文字。 + +## 具体的分类信息 + +### 支出的一级和二级分类 + +- 餐饮 + - 中餐 + - 西餐 + - 日料 + - 韩料 + - 茶饮 + - 咖啡 + - 碳酸饮料 + - 果汁 + - 植物饮品 + - 其他饮品 + - 零食 + - 水果 + - 其他餐饮 +- 交通 + - 公交 + - 地铁 + - 共享单车 + - 共享电动车 + - 共享汽车 + - 高铁 + - 火车 + - 飞机 + - 轮渡 + - 客车 + - 油费 + - 停车费 + - 洗车费 + - 过路费 + - 维修保养费 + - 打车 + - 租车 + - 其他交通 +- 服饰 + - 上衣 + - 下装 + - 鞋类 + - 内衣 + - 配饰 + - 首饰 + - 其他服饰 +- 医疗 + - 门诊挂号费 + - 住院费 + - 药品 + - 保健 + - 护理 + - 检查 + - 手术 + - 医疗器械 + - 其他医疗 +- 美容服务 + - 护肤 + - 医美 + - 美发 + - 美甲 + - 其他美容服务 +- 赠予 + - 礼金 + - 红包 + - 捐款 + - 礼物 + - 捐物 +- 学习 + - 学费 + - 书籍 + - 文具 + - 其他学习 +- 娱乐 + - K歌 + - 演唱会 + - 话剧 + - 音乐会 + - 脱口秀 + - 桌游 + - 密室 + - 电子游戏 + - 彩票 + - 电影 + - 展览 + - 动漫展 + - 其他娱乐 +- 运动 + - 飞盘 + - 徒步 + - 羽毛球 + - 网球 + - 篮球 + - 足球 + - 其他运动 +- 住房 + - 购房费用 + - 租房费用 + - 物业费 + - 水电费 + - 燃气费 + - 家电购买或租赁费用 + - 家具购买或租赁费用 + - 家电家具维修费 + - 房屋装修费用 + - 其他住房费用 +- 日用品 + - 洗漱用品 + - 护肤品 + - 化妆品 + - 房屋清洁用品 + - 家用五金 + - 厨具用品 + - 母婴用品 + - 数码产品 + - 其他日用品 +- 旅行 + - 旅行住宿 + - 旅游活动 + - 旅行购物 + - 其他旅行 +- 宠物 + - 宠物食品 + - 宠物玩具 + - 宠物医疗 + - 购买宠物 + - 其他宠物 +- 通讯 + - 话费 + - 网费 +- 保险 + - 保险 +- 其他支出 + - 其他支出 + +### 收入的一级和二级分类 + +- 工资 + - 工资 +- 生活费 + - 生活费 +- 奖学金 + - 奖学金 +- 年终奖 + - 年终奖 +- 报销 + - 报销 +- 绩效奖金 + - 绩效奖金 +- 兼职 + - 兼职 +- 零花钱 + - 零花钱 +- 投资收入 + - 投资收入 +- 礼金红包 + - 礼金 + - 红包 +- 加班补贴 + - 加班补贴 +- 餐饮补贴 + - 餐饮补贴 +- 经营收入 + - 经营收入 +- 二手闲置 + - 二手闲置 +- 彩票 + - 彩票 +- 其他收入 + - 其他收入 + + +## 示例 + +### 输入上下文示例 + +1. 记账任务 +时间:2024-06-20 14:58 星期四 +地点:北京市 海淀区 海淀区中关村南大街 27 号 +内容:昨天吃了麦当劳,花了二十八块八 + +2. 记账任务 +时间:2024-11-05 11:09 星期二 +地点:北京市 海淀区 北四环西路 理想国际大厦 星巴克 +内容:刚喝了杯三十块的咖啡 + +3. 笔记任务 +时间:2024-05-17 21:09 星期五 +地点:北京市 丰台区 丰台区南四环西路 188 号 +内容:去园博园春游了,园博园的花开了很漂亮 + +4. 记账任务 +时间:2024-11-11 14:58 星期一 +地点:北京市 门头沟区 门头沟区滨河路 1 号 +内容:周六晚上看了一场音乐剧,票价168块,位置还不错,演出很精彩。 + +5. 记账任务 +时间:2024-10-24 17:58 星期四 +地点:北京市 朝阳区 北四环东路 10 号 +内容:公司团建,去了郊区农家乐,AA制人均200块,包含餐费和车费,活动很多,玩得挺开心的。 + +6. 记账任务 +时间:2024-11-11 12:10 星期一 +地点:北京市 海淀区 北四环西路 理想国际大厦 星巴克 +内容:婴儿推车1200元 + +7. 记账任务 +时间:2024-08-10 22:10 星期六 +地点:北京市 海淀区 北四环西路 理想国际大厦 星巴克 +内容:七夕酒吧驻唱赚了200 + +8. 多条记账任务 +时间:2024-11-12 18:12 星期二 +地点:北京市 海淀区 北四环西路 理想国际大厦 星巴克 +内容:买了一根口红花了二百三,一块眼影花了490 + +9. 多条记账任务 +时间:2024-09-02 22:58 星期一 +地点:北京市 门头沟区 门头沟区滨河路 1 号 +内容:今天早饭包子花了10块,下午奶茶花了25,夜宵点外卖花了30 + +### 输出 JSON 格式 + +1. 记账任务 + +```json +[ + { + "type": "accounting", + "category1": "餐饮", + "category2": "其他餐饮", + "direction": "outcome", + "amount": 28.8, + "date": "2024-06-19", + "time": null, + "title": "麦当劳用餐" + } +] +``` + +2. 
记账任务 + +```json +[ + { + "type": "accounting", + "category1": "餐饮", + "category2": "咖啡", + "direction": "outcome", + "amount": 30, + "date": "2024-11-05", + "time": "11:09", + "title": "星巴克咖啡" + } +] +``` + +3. 笔记任务 + +```json +[ + { + "type": "note", + "date": "2024-05-17", + "time": "21:09", + "title": "园博园春游" + } +] +``` + +4. 记账任务 + +```json +[ + { + "type": "accounting", + "category1": "娱乐", + "category2": "音乐会", + "direction": "outcome", + "amount": 168, + "date": "2024-11-09", + "time": "20:00", + "title": "音乐剧观赏" + } +] +``` + +5. 记账任务 + +```json +[ + { + "type": "accounting", + "category1": "娱乐", + "category2": "其它娱乐", + "direction": "outcome", + "amount": 200, + "date": "2024-10-24", + "time": null, + "title": "公司团建农家乐" + } +] +``` + +6. 记账任务 + +```json +[ + { + "type": "accounting", + "category1": "日用品", + "category2": "母婴用品", + "direction": "outcome", + "amount": 1200, + "date": "2024-11-11", + "time": "12:10", + "title": "购买婴儿推车" + } +] +``` + +7. 记账任务 + +```json +[ + { + "type": "accounting", + "category1": "兼职", + "category2": "兼职", + "direction": "income", + "amount": 200, + "date": "2024-08-10", + "time": "22:10", + "title": "兼职酒吧驻唱" + } +] +``` + +8. 多条记账任务 + +```json +[ + { + "type": "accounting", + "category1": "日用品", + "category2": "化妆品", + "direction": "outcome", + "amount": 230, + "date": "2024-11-12", + "time": "18:12", + "title": "购买口红" + }, + { + "type": "accounting", + "category1": "日用品", + "category2": "化妆品", + "direction": "outcome", + "amount": 490, + "date": "2024-11-12", + "time": "18:12", + "title": "购买眼影" + } +] +``` + +9. 多条记账任务 + +```json +[ + { + "type": "accounting", + "category1": "餐饮", + "category2": "中餐", + "direction": "outcome", + "amount": 10, + "date": "2024-09-02", + "time": "07:00", + "title": "早饭10元" + }, + { + "type": "accounting", + "category1": "餐饮", + "category2": "茶饮", + "direction": "outcome", + "amount": 25, + "date": "2024-09-02", + "time": "14:00", + "title": "下午茶25元" + }, + { + "type": "accounting", + "category1": "餐饮", + "category2": "其他餐饮", + "direction": "outcome", + "amount": 30, + "date": "2024-09-02", + "time": "23:00", + "title": "夜宵30元" + } +] +``` diff --git a/test/format_out/test_xgrammar_constraint.py b/test/format_out/test_xgrammar_constraint.py index 9d2ad0ef8..e04be684d 100644 --- a/test/format_out/test_xgrammar_constraint.py +++ b/test/format_out/test_xgrammar_constraint.py @@ -22,7 +22,7 @@ def run(self): print("Error:", response.status_code, response.text) -url = "http://localhost:9999/generate" +url = "http://0.0.0.0:8888/generate" headers = {"Content-Type": "application/json"} json_grammar_ebnf_str = r""" root ::= basic_array | basic_object @@ -98,10 +98,7 @@ def run(self): user_input = open("user.md", "r").read() messages = [ - { - "role": "system", - "content": system_prompt, - }, + {"role": "system", "content": system_prompt,}, {"role": "user", "content": user_input}, ] diff --git a/test/format_out/user.md b/test/format_out/user.md new file mode 100644 index 000000000..fac770a43 --- /dev/null +++ b/test/format_out/user.md @@ -0,0 +1,5 @@ +以下是用户当前的时间、位置信息和语音录入的内容。请根据这些信息,输出符合要求的 JSON。只有当用户输入的内容中有明确金额发生的时候才能识别为记账,否则为笔记。必须按照标准正确的JSON格式输出,即一个 JSON Array,不要有 ```json 这种 Markdown 格式的 code fence,不要包含多余的标记、注释或解释性文字—— + +时间:2024-11-16 21:46 星期六 +地点:北京市 海淀区 中关村大街 18 号 1 号楼 10 层 +内容:10月9号晚上看了一场音乐剧,票价168块,位置还不错,演出很精彩。Í From 2b2e17f4c6bc6353791bc5d5e5e205acf7324885 Mon Sep 17 00:00:00 2001 From: Junyi Chen Date: Tue, 25 Feb 2025 16:30:58 +0800 Subject: [PATCH 08/14] save --- lightllm/server/core/objs/sampling_params.py | 7 
+- .../continues_batch/impl_for_xgrammar_mode.py | 14 +- test/format_out/system.md | 458 ------------------ test/format_out/test_xgrammar_constraint.py | 8 +- test/format_out/user.md | 5 - 5 files changed, 17 insertions(+), 475 deletions(-) delete mode 100644 test/format_out/system.md delete mode 100644 test/format_out/user.md diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index 86b35c365..7a21ec9ff 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -15,6 +15,7 @@ GRAMMAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_GRAMMAR_CONSTRAINT_MAX_LENGTH", 2048)) JSON_SCHEMA_MAX_LENGTH = int(os.getenv("LIGHTLLM_JSON_SCHEMA_MAX_LENGTH", 2048)) + class StopSequence(ctypes.Structure): _pack_ = 4 _fields_ = [ @@ -77,7 +78,7 @@ def to_list(self): class RegularConstraint(ctypes.Structure): _pack_ = 4 _fields_ = [ - ("constraint", ctypes.c_byte * REGULAR_CONSTRAINT_MAX_LENGTH), + ("constraint", ctypes.c_ubyte * REGULAR_CONSTRAINT_MAX_LENGTH), ("length", ctypes.c_int), ] @@ -102,7 +103,7 @@ def to_str(self): class GuidedGrammar(ctypes.Structure): _pack_ = 4 _fields_ = [ - ("constraint", ctypes.c_byte * GRAMMAR_CONSTRAINT_MAX_LENGTH), + ("constraint", ctypes.c_ubyte * GRAMMAR_CONSTRAINT_MAX_LENGTH), ("length", ctypes.c_int), ] @@ -122,7 +123,7 @@ def to_str(self): class GuidedJsonSchema(ctypes.Structure): _pack_ = 4 _fields_ = [ - ("constraint", ctypes.c_byte * JSON_SCHEMA_MAX_LENGTH), + ("constraint", ctypes.c_ubyte * JSON_SCHEMA_MAX_LENGTH), ("length", ctypes.c_int), ] diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_xgrammar_mode.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_xgrammar_mode.py index 2f4b87636..85fb0229e 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_xgrammar_mode.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_xgrammar_mode.py @@ -33,7 +33,7 @@ def init_custom(self): eos_token_ids = [] eos_token_ids.append(self.tokenizer.eos_token_id) eos_token_ids.extend(self.args.eos_id) - # self.tokenizer.eos_token_ids = eos_token_ids + self.tokenizer.eos_token_ids = eos_token_ids # logger.info(f"eos_ids {self.tokenizer.eos_token_ids}") return @@ -41,7 +41,6 @@ def init_custom(self): def prefill(self, reqs: List[Tuple]): req_ids = self._init_reqs(reqs) kwargs, run_reqs = prepare_prefill_inputs(req_ids, is_multimodal=self.is_multimodal) - run_reqs: List[InferReq] = reqs logics = self.model.forward(**kwargs) @@ -59,7 +58,7 @@ def prefill(self, reqs: List[Tuple]): # fix the logics with -inf to a large negative value logics[logics == float("-inf")] = -1000000.0 - logics[mask] = -1000000.0 + # logics[mask] = -1000000.0 next_token_ids, next_token_probs = sample(logics, run_reqs, self.eos_id) next_token_ids = next_token_ids.detach().cpu().numpy() @@ -81,7 +80,7 @@ def decode(self): mask = torch.ones_like(logits, dtype=torch.bool) for i, run_obj in enumerate(run_reqs): self._mask_req_out_token(i, run_obj, mask, logits[i]) - logits[mask] = -1000000.0 + # logits[mask] = -1000000.0 logits[logits == float("-inf")] = -1000000.0 next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) @@ -90,7 +89,7 @@ def decode(self): self.post_handel(run_reqs, next_token_ids, next_token_logprobs) return - + def post_handel(self, run_reqs: List[InferReq], next_token_ids, next_token_logprobs): finished_req_ids = [] @@ -105,7 +104,10 @@ def post_handel(self, 
run_reqs: List[InferReq], next_token_ids, next_token_logpr req_obj.out_token_id_count[next_token_id] += 1 req_obj.update_finish_status(self.eos_id) - if req_obj.finish_status.is_finished() or req_obj.shm_req.router_aborted: + matcher = req_obj.sampling_param.xgrammar_matcher + assert matcher.accept_token(next_token_id) + + if req_obj.finish_status.is_finished() or req_obj.shm_req.router_aborted or matcher.is_terminated(): finished_req_ids.append(req_obj.shm_req.request_id) if self.tp_rank < self.dp_size: diff --git a/test/format_out/system.md b/test/format_out/system.md deleted file mode 100644 index dfae5b96d..000000000 --- a/test/format_out/system.md +++ /dev/null @@ -1,458 +0,0 @@ -你是一款记账助手的抽取器,擅长从用户的输入中提取关键信息并输出为 JSON 格式。根据用户输入内容、结合提交时间和地点信息,判断其是否为记账任务或笔记任务,并提取相关信息。 - -用户的请求中的可能包含以下信息: - -- 时间:即用户提交请求的当前时间 -- 地点:即用户当前的 GPS 定位位置 -- 内容:即用户输入的文字内容,可能包含误识别的错别字,请尽量纠正 - - -## 需要提取的参数 - -- **type**:任务类型。如果是记账则为 `accounting`,笔记为 `note`。用户输入的内容中有明确金额发生的时候才能识别为记账 `accounting`,否则为笔记`note`。 - -- **direction**:收支方向,根据用户内容中描述的行为判断。 - - `outcome`:表示支出,例如,"花了"、"花费"、"付了"、"买了" 等。 - - `income`:表示收入,例如,"赚了"、"收到了"、"得到了"、"中了" 等。 - -- **category1**:一级分类。从给定的支出分类中为支出的账单选择对应的一级分类,如:"餐饮"、"交通"、"娱乐"等;从给定的收入分类中为收入的账单选择对应的一级分类,如:"工资"、"兼职"、"生活费"等,"收入"本身不是一级分类,必须严格匹配,收入的账单选择对应的一级分类时,不能选择属于支出的一级分类。当无法找到明显的匹配项时,选择对应的 "其他"分类,禁止自行输出不存在于给定分类中的分类。 - -- **category2**:二级分类。从对应的一级分类下的二级分类中选择,必须严格匹配。当无法找到明显的匹配项时,选择对应的 "其他"分类,禁止自行输出不存在于给定分类中的分类。 - -- **amount**:金额数值。提取用户输入中的金额,支持中文数字和阿拉伯数字,若出现两者混合的情况,请准确提取、组合,结果为保留两位小数的阿拉伯数字,不需要加单位。 - -- **date**:日期,格式为 `YYYY-MM-DD`。根据用户输入的日期信息提取出date值; -用户可能会输入 "昨天"、"上周五" 等,需要结合提交日期进行计算。如果请求问题的中提到的是上周几,你需要根据当天是周几(星期几)进行推理计算,推算出过去14天里符合条件的日期;如果请求问题的中提到的是上上周几,你需要根据当天是周几(星期几)进行推理计算,推算出过去21天里符合条件的日期,依此类推; -计算结果只能往前推算出过去已结束的日期,不能推算出未来未开始的日期; -如果用户输入中未提及date信息,默认为提交请求的当前日期。 - -- **time**:时间,格式为 `HH:MM`。根据用户输入的具体或模糊时间提取,模糊时间按照下文的映射规则确定具体时间。如果用户输入中提取不到时间,则需要结合用户输入或计算出的日期判断该填 `null`还是当前时间,判断规则是:若用户输入或解析到的日期非当天,则填`null`,若用户未提及日期,则该条的日期和时间为当前提交的日期和时间。 - -- **title**:对该笔记账的简短描述,尽量不超过 10 个字。可以省略金额、时间、分类等重复信息,突出主要内容。 - - -## 处理规则 - -1. 判断任务类型:如果输入中包含金额,则为记账任务,否则为笔记任务。记账任务可能包含多条,需要根据内容进行拆分,最多拆为20条。笔记任务只有一条,不再进行拆分。 -2. 按任务类型提取对应的相关信息: - - 记账任务: - - direction, category1, category2, amount, date, time, title - - 笔记任务: - - date, time, title -3. 对于笔记任务,按用户提交时间判断日期,并提取时间信息,无需进行额外解析推断。 -4. 对于记账任务,先根据语义判断收支方向,然后据此判断分类,分类必须从以下的分类信息中选择,不能随意填写。当无法找到明显的匹配项时,选择对应的 "其他"分类,禁止自行输出不存在于给定分类中的分类。 -5. 对于记账任务的日期的解析,根据内容中的时间表达(如"昨天"、"上周五")结合当前日期计算实际日期。对于星期的表达,如"周一"、"周末"等,默认按过去七天内的日期计算。周末视为周六。每周开始按周一算,即周一到周日。 -6. 对于记账任务中具体提及的时间,提取为 `HH:MM` 格式。模糊时间可以按照以下规则映射(如果时间无法推断则 time 为 null): - - 凌晨 / 黎明 / 拂晓 / 早起:00:00-06:00(默认01:00) - - 清晨 / 早晨 / 早上 / 大清早 / 一大早 / 日出 / 早间:06:00-10:00(默认07:00) - - 上午 / 午前 / 上半天:09:00-12:00(默认10:00) - - 中午 / 正午 / 午时 / 晌午:12:00-14:00(默认13:00) - - 下午 / 午后 / 下半天:13:00-18:00(默认14:00) - - 傍晚 / 黄昏 / 日落 / 傍黑 / 薄暮:17:00-20:00(默认18:00) - - 晚上 / 夜晚:19:00-23:59(默认20:00) - - 深夜 / 半夜 / 大半夜 / 午夜 / 夜半:22:00-02:00(默认23:00) - - 夜间 / 夜里 / 黑夜:19:00-06:00(默认23:00) - - 白天:06:00-18:00(默认12:00) -7. title 标题为用户输入的内容的缩写,应简洁明了,突出主要信息,不重复金额、时间等已提取的内容。 -8. 
必须按照标准正确的JSON格式输出,即一个 JSON Array,不要有 ```json 这种 Markdown 格式的 code fence,不要包含多余的标记、注释或解释性文字。 - -## 具体的分类信息 - -### 支出的一级和二级分类 - -- 餐饮 - - 中餐 - - 西餐 - - 日料 - - 韩料 - - 茶饮 - - 咖啡 - - 碳酸饮料 - - 果汁 - - 植物饮品 - - 其他饮品 - - 零食 - - 水果 - - 其他餐饮 -- 交通 - - 公交 - - 地铁 - - 共享单车 - - 共享电动车 - - 共享汽车 - - 高铁 - - 火车 - - 飞机 - - 轮渡 - - 客车 - - 油费 - - 停车费 - - 洗车费 - - 过路费 - - 维修保养费 - - 打车 - - 租车 - - 其他交通 -- 服饰 - - 上衣 - - 下装 - - 鞋类 - - 内衣 - - 配饰 - - 首饰 - - 其他服饰 -- 医疗 - - 门诊挂号费 - - 住院费 - - 药品 - - 保健 - - 护理 - - 检查 - - 手术 - - 医疗器械 - - 其他医疗 -- 美容服务 - - 护肤 - - 医美 - - 美发 - - 美甲 - - 其他美容服务 -- 赠予 - - 礼金 - - 红包 - - 捐款 - - 礼物 - - 捐物 -- 学习 - - 学费 - - 书籍 - - 文具 - - 其他学习 -- 娱乐 - - K歌 - - 演唱会 - - 话剧 - - 音乐会 - - 脱口秀 - - 桌游 - - 密室 - - 电子游戏 - - 彩票 - - 电影 - - 展览 - - 动漫展 - - 其他娱乐 -- 运动 - - 飞盘 - - 徒步 - - 羽毛球 - - 网球 - - 篮球 - - 足球 - - 其他运动 -- 住房 - - 购房费用 - - 租房费用 - - 物业费 - - 水电费 - - 燃气费 - - 家电购买或租赁费用 - - 家具购买或租赁费用 - - 家电家具维修费 - - 房屋装修费用 - - 其他住房费用 -- 日用品 - - 洗漱用品 - - 护肤品 - - 化妆品 - - 房屋清洁用品 - - 家用五金 - - 厨具用品 - - 母婴用品 - - 数码产品 - - 其他日用品 -- 旅行 - - 旅行住宿 - - 旅游活动 - - 旅行购物 - - 其他旅行 -- 宠物 - - 宠物食品 - - 宠物玩具 - - 宠物医疗 - - 购买宠物 - - 其他宠物 -- 通讯 - - 话费 - - 网费 -- 保险 - - 保险 -- 其他支出 - - 其他支出 - -### 收入的一级和二级分类 - -- 工资 - - 工资 -- 生活费 - - 生活费 -- 奖学金 - - 奖学金 -- 年终奖 - - 年终奖 -- 报销 - - 报销 -- 绩效奖金 - - 绩效奖金 -- 兼职 - - 兼职 -- 零花钱 - - 零花钱 -- 投资收入 - - 投资收入 -- 礼金红包 - - 礼金 - - 红包 -- 加班补贴 - - 加班补贴 -- 餐饮补贴 - - 餐饮补贴 -- 经营收入 - - 经营收入 -- 二手闲置 - - 二手闲置 -- 彩票 - - 彩票 -- 其他收入 - - 其他收入 - - -## 示例 - -### 输入上下文示例 - -1. 记账任务 -时间:2024-06-20 14:58 星期四 -地点:北京市 海淀区 海淀区中关村南大街 27 号 -内容:昨天吃了麦当劳,花了二十八块八 - -2. 记账任务 -时间:2024-11-05 11:09 星期二 -地点:北京市 海淀区 北四环西路 理想国际大厦 星巴克 -内容:刚喝了杯三十块的咖啡 - -3. 笔记任务 -时间:2024-05-17 21:09 星期五 -地点:北京市 丰台区 丰台区南四环西路 188 号 -内容:去园博园春游了,园博园的花开了很漂亮 - -4. 记账任务 -时间:2024-11-11 14:58 星期一 -地点:北京市 门头沟区 门头沟区滨河路 1 号 -内容:周六晚上看了一场音乐剧,票价168块,位置还不错,演出很精彩。 - -5. 记账任务 -时间:2024-10-24 17:58 星期四 -地点:北京市 朝阳区 北四环东路 10 号 -内容:公司团建,去了郊区农家乐,AA制人均200块,包含餐费和车费,活动很多,玩得挺开心的。 - -6. 记账任务 -时间:2024-11-11 12:10 星期一 -地点:北京市 海淀区 北四环西路 理想国际大厦 星巴克 -内容:婴儿推车1200元 - -7. 记账任务 -时间:2024-08-10 22:10 星期六 -地点:北京市 海淀区 北四环西路 理想国际大厦 星巴克 -内容:七夕酒吧驻唱赚了200 - -8. 多条记账任务 -时间:2024-11-12 18:12 星期二 -地点:北京市 海淀区 北四环西路 理想国际大厦 星巴克 -内容:买了一根口红花了二百三,一块眼影花了490 - -9. 多条记账任务 -时间:2024-09-02 22:58 星期一 -地点:北京市 门头沟区 门头沟区滨河路 1 号 -内容:今天早饭包子花了10块,下午奶茶花了25,夜宵点外卖花了30 - -### 输出 JSON 格式 - -1. 记账任务 - -```json -[ - { - "type": "accounting", - "category1": "餐饮", - "category2": "其他餐饮", - "direction": "outcome", - "amount": 28.8, - "date": "2024-06-19", - "time": null, - "title": "麦当劳用餐" - } -] -``` - -2. 记账任务 - -```json -[ - { - "type": "accounting", - "category1": "餐饮", - "category2": "咖啡", - "direction": "outcome", - "amount": 30, - "date": "2024-11-05", - "time": "11:09", - "title": "星巴克咖啡" - } -] -``` - -3. 笔记任务 - -```json -[ - { - "type": "note", - "date": "2024-05-17", - "time": "21:09", - "title": "园博园春游" - } -] -``` - -4. 记账任务 - -```json -[ - { - "type": "accounting", - "category1": "娱乐", - "category2": "音乐会", - "direction": "outcome", - "amount": 168, - "date": "2024-11-09", - "time": "20:00", - "title": "音乐剧观赏" - } -] -``` - -5. 记账任务 - -```json -[ - { - "type": "accounting", - "category1": "娱乐", - "category2": "其它娱乐", - "direction": "outcome", - "amount": 200, - "date": "2024-10-24", - "time": null, - "title": "公司团建农家乐" - } -] -``` - -6. 记账任务 - -```json -[ - { - "type": "accounting", - "category1": "日用品", - "category2": "母婴用品", - "direction": "outcome", - "amount": 1200, - "date": "2024-11-11", - "time": "12:10", - "title": "购买婴儿推车" - } -] -``` - -7. 
记账任务 - -```json -[ - { - "type": "accounting", - "category1": "兼职", - "category2": "兼职", - "direction": "income", - "amount": 200, - "date": "2024-08-10", - "time": "22:10", - "title": "兼职酒吧驻唱" - } -] -``` - -8. 多条记账任务 - -```json -[ - { - "type": "accounting", - "category1": "日用品", - "category2": "化妆品", - "direction": "outcome", - "amount": 230, - "date": "2024-11-12", - "time": "18:12", - "title": "购买口红" - }, - { - "type": "accounting", - "category1": "日用品", - "category2": "化妆品", - "direction": "outcome", - "amount": 490, - "date": "2024-11-12", - "time": "18:12", - "title": "购买眼影" - } -] -``` - -9. 多条记账任务 - -```json -[ - { - "type": "accounting", - "category1": "餐饮", - "category2": "中餐", - "direction": "outcome", - "amount": 10, - "date": "2024-09-02", - "time": "07:00", - "title": "早饭10元" - }, - { - "type": "accounting", - "category1": "餐饮", - "category2": "茶饮", - "direction": "outcome", - "amount": 25, - "date": "2024-09-02", - "time": "14:00", - "title": "下午茶25元" - }, - { - "type": "accounting", - "category1": "餐饮", - "category2": "其他餐饮", - "direction": "outcome", - "amount": 30, - "date": "2024-09-02", - "time": "23:00", - "title": "夜宵30元" - } -] -``` diff --git a/test/format_out/test_xgrammar_constraint.py b/test/format_out/test_xgrammar_constraint.py index e04be684d..fb4e7b3f8 100644 --- a/test/format_out/test_xgrammar_constraint.py +++ b/test/format_out/test_xgrammar_constraint.py @@ -4,7 +4,7 @@ import threading from transformers import AutoTokenizer -tokenizer = AutoTokenizer.from_pretrained("/mnt/nvme0/chenjunyi/models/nb10_w8/") +tokenizer = AutoTokenizer.from_pretrained("/mnt/nvme0/models/Meta-Llama-3.1-8B-Instruct") class RequestThread(threading.Thread): @@ -97,8 +97,10 @@ def run(self): system_prompt = open("system.md", "r").read() user_input = open("user.md", "r").read() +# user_input = """generate a person information for me, for example, {'name': 'John', 'age': 25}.""" + messages = [ - {"role": "system", "content": system_prompt,}, + {"role": "system", "content": system_prompt}, {"role": "user", "content": user_input}, ] @@ -110,7 +112,7 @@ def run(self): # 'temperature': 0.1, "parameters": { "do_sample": False, - "guided_json": json_schema_str, + "guided_grammar": json_grammar_ebnf_str, "max_new_tokens": 200, }, } diff --git a/test/format_out/user.md b/test/format_out/user.md deleted file mode 100644 index fac770a43..000000000 --- a/test/format_out/user.md +++ /dev/null @@ -1,5 +0,0 @@ -以下是用户当前的时间、位置信息和语音录入的内容。请根据这些信息,输出符合要求的 JSON。只有当用户输入的内容中有明确金额发生的时候才能识别为记账,否则为笔记。必须按照标准正确的JSON格式输出,即一个 JSON Array,不要有 ```json 这种 Markdown 格式的 code fence,不要包含多余的标记、注释或解释性文字—— - -时间:2024-11-16 21:46 星期六 -地点:北京市 海淀区 中关村大街 18 号 1 号楼 10 层 -内容:10月9号晚上看了一场音乐剧,票价168块,位置还不错,演出很精彩。 From f67156457946a769db4459e34c596cfc5a565b7f Mon Sep 17 00:00:00 2001 From: Junyi Chen Date: Wed, 26 Feb 2025 15:02:37 +0800 Subject: [PATCH 09/14] fix --- .pre-commit-config.yaml | 2 +- .../server/core/objs/py_sampling_params.py | 2 +- lightllm/server/core/objs/sampling_params.py | 48 +++++++++-- lightllm/server/core/objs/start_args_type.py | 2 +- .../model_infer/mode_backend/__init__.py | 2 +- ...y => impl_for_outlines_constraint_mode.py} | 2 +- .../continues_batch/impl_for_xgrammar_mode.py | 28 +++---- .../server/router/model_infer/model_rpc.py | 13 ++- lightllm/server/router/req_queue/__init__.py | 1 - test/format_out/test_xgrammar_constraint.py | 2 +- test/test_xgrammar_constraint.py | 81 ------------------- 11 files changed, 66 insertions(+), 117 deletions(-) rename
lightllm/server/router/model_infer/mode_backend/continues_batch/{impl_for_simple_constraint_mode.py => impl_for_outlines_constraint_mode.py} (99%) delete mode 100644 test/test_xgrammar_constraint.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 446e69da7..678ac8b8f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,4 +11,4 @@ repos: hooks: - id: flake8 additional_dependencies: [flake8-typing-imports==1.9.0] - args: ['--config=.flake8', '--max-line-length=120', '--ignore=TYP001, E722, C901, E203, E266, E402, E302, E241, E902, E731, F403, E701, F405, F401, W292, W293, W503, W606, W191, E101'] \ No newline at end of file + args: ['--config=.flake8', '--max-line-length=120', '--ignore=TYP001, E722, C901, E203, E266, E402, E302, E241, E902, E731, F403, E701, F405, F401, W292, W293, W503, W606'] \ No newline at end of file diff --git a/lightllm/server/core/objs/py_sampling_params.py b/lightllm/server/core/objs/py_sampling_params.py index 6dd998170..0d0212cf9 100644 --- a/lightllm/server/core/objs/py_sampling_params.py +++ b/lightllm/server/core/objs/py_sampling_params.py @@ -52,7 +52,7 @@ def __init__( guided_json: Optional[Union[str, dict, BaseModel]] = None, # JSON schema constrain the output. # If provided, the engine will construct a logits, # processor which only retains scores for the given token ids. Defaults to None. - # allowed_token_ids only can be used in "--simple_constraint_mode" started server. + # allowed_token_ids only can be used in "--output_constraint_mode outlines" started server. allowed_token_ids: Optional[List[int]] = None, # p d mode used params group_request_id: Optional[int] = None, diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index 7a21ec9ff..772615d11 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -107,16 +107,27 @@ class GuidedGrammar(ctypes.Structure): ("length", ctypes.c_int), ] - def initialize(self, constraint: str): + def initialize(self, constraint: str, tokenizer): constraint_bytes = constraint.encode("utf-8") assert len(constraint_bytes) < GRAMMAR_CONSTRAINT_MAX_LENGTH, "Guided grammar is too long." ctypes.memmove(self.constraint, constraint_bytes, len(constraint_bytes)) self.length = len(constraint_bytes) - # TODO: Later we can add a grammar parser to check the grammar + try: + if self.length > 0: + import xgrammar as xgr + + tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer) + xgrammar_compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8) + print(constraint) + xgrammar_compiler.compile_grammar(constraint) + except Exception as e: + raise ValueError(f"guided_grammar '{constraint}' has compile_grammar_error: {str(e)}") return def to_str(self): + if self.length == 0: + return "" return bytes(self.constraint[0 : self.length]).decode("utf-8").rstrip("\x00") @@ -127,16 +138,26 @@ class GuidedJsonSchema(ctypes.Structure): ("length", ctypes.c_int), ] - def initialize(self, constraint: str): + def initialize(self, constraint: str, tokenizer): constraint_bytes = constraint.encode("utf-8") assert len(constraint_bytes) < JSON_SCHEMA_MAX_LENGTH, "Guided json schema is too long." 
ctypes.memmove(self.constraint, constraint_bytes, len(constraint_bytes)) self.length = len(constraint_bytes) - # TODO: Later we can add a json schema parser to check the schema + try: + if self.length > 0: + import xgrammar as xgr + + tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer) + xgrammar_compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8) + xgrammar_compiler.compile_json_schema(constraint) + except Exception as e: + raise ValueError(f"guided_json '{constraint}' has compile_json_schema_error: {str(e)}") return def to_str(self): + if self.length == 0: + return "" return bytes(self.constraint[0 : self.length]).decode("utf-8").rstrip("\x00") @@ -237,7 +258,7 @@ class SamplingParams(ctypes.Structure): ("guided_json", GuidedJsonSchema), # If provided, the engine will construct a logits, # processor which only retains scores for the given token ids. Defaults to None. - # allowed_token_ids only can be used in "--simple_constraint_mode" started server. + # allowed_token_ids only can be used in "--output_constraint_mode outlines" started server. ("allowed_token_ids", AllowedTokenIds), ("stop_sequences", StopSequenceGroups), ("exponential_decay_length_penalty", ExponentialDecayLengthPenalty), @@ -298,12 +319,12 @@ def init(self, tokenizer, **kwargs): # Initialize guided_grammar guided_grammar = kwargs.get("guided_grammar", "") self.guided_grammar = GuidedGrammar() - self.guided_grammar.initialize(guided_grammar) + self.guided_grammar.initialize(guided_grammar, tokenizer) # Initialize guided_json guided_json = kwargs.get("guided_json", "") self.guided_json = GuidedJsonSchema() - self.guided_json.initialize(guided_json) + self.guided_json.initialize(guided_json, tokenizer) # Initialize stop_sequence_groups stop_sequences = kwargs.get("stop_sequences", []) @@ -370,13 +391,26 @@ def verify(self): ) self._verify_allowed_token_ids() + self._verify_grammar_constraint() return + def _verify_grammar_constraint(self): + if self.guided_grammar.length != 0: + if self.regular_constraint.length != 0: + raise ValueError("guided_grammar and regular_constraint can not be used in same time") + if self.guided_json.length != 0: + raise ValueError("guided_grammar and guided_json can not be used in same time") + return + def _verify_allowed_token_ids(self): if self.allowed_token_ids.size != 0: if self.regular_constraint.length != 0: raise ValueError("allowed_token_ids and regular_constraint can not be used in same time") + if self.guided_grammar.length != 0: + raise ValueError("allowed_token_ids and guided_grammar can not be used in same time") + if self.guided_json.length != 0: + raise ValueError("allowed_token_ids and guided_json can not be used in same time") return def to_dict(self): diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 32898a668..7c2e3c6c2 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -41,7 +41,7 @@ class StartArgs: enable_chunked_prefill: bool = field(default=False) diverse_mode: bool = field(default=False) token_healing_mode: bool = field(default=False) - simple_constraint_mode: bool = field(default=False) + output_constraint_mode: str = field(default="none", metadata={"choices": ["none", "outlines", "xgrammar"]}) first_token_constraint_mode: bool = field(default=False) enable_multimodal: bool = field(default=False) cache_capacity: int = field(default=200) diff --git a/lightllm/server/router/model_infer/mode_backend/__init__.py
b/lightllm/server/router/model_infer/mode_backend/__init__.py index 9bbae9fec..509f1dd70 100644 --- a/lightllm/server/router/model_infer/mode_backend/__init__.py +++ b/lightllm/server/router/model_infer/mode_backend/__init__.py @@ -4,7 +4,7 @@ from .chunked_prefill.impl import ChunkedPrefillBackend from .diverse_backend.impl import DiversehBackend from .continues_batch.impl_for_token_healing import TokenHealingBackend -from .continues_batch.impl_for_simple_constraint_mode import SimpleConstraintBackend +from .continues_batch.impl_for_outlines_constraint_mode import OutlinesConstraintBackend from .continues_batch.impl_for_first_token_constraint_mode import FirstTokenConstraintBackend from .dp_backend.impl import DPBackend from .continues_batch.pd_mode.prefill_node_impl.prefill_impl import ContinuesBatchBackendForPrefillNode diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_simple_constraint_mode.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_outlines_constraint_mode.py similarity index 99% rename from lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_simple_constraint_mode.py rename to lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_outlines_constraint_mode.py index 963cf4ff8..00af16be4 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_simple_constraint_mode.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_outlines_constraint_mode.py @@ -14,7 +14,7 @@ logger = init_logger(__name__) -class SimpleConstraintBackend(ContinuesBatchBackend): +class OutlinesConstraintBackend(ContinuesBatchBackend): def __init__(self) -> None: super().__init__() diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_xgrammar_mode.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_xgrammar_mode.py index 85fb0229e..3d880614c 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_xgrammar_mode.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_xgrammar_mode.py @@ -2,7 +2,6 @@ import shutil import torch from typing import List, Tuple -import xgrammar as xgr from .impl import ContinuesBatchBackend from .pre_process import prepare_prefill_inputs, prepare_decode_inputs @@ -22,6 +21,8 @@ def __init__(self) -> None: super().__init__() def init_custom(self): + import xgrammar as xgr + self.tokenizer = get_tokenizer( self.args.model_dir, self.args.tokenizer_mode, trust_remote_code=self.args.trust_remote_code ) @@ -33,18 +34,17 @@ def init_custom(self): eos_token_ids = [] eos_token_ids.append(self.tokenizer.eos_token_id) eos_token_ids.extend(self.args.eos_id) - self.tokenizer.eos_token_ids = eos_token_ids - # logger.info(f"eos_ids {self.tokenizer.eos_token_ids}") return @calculate_time(show=False, min_cost_ms=300) def prefill(self, reqs: List[Tuple]): + import xgrammar as xgr + req_ids = self._init_reqs(reqs) kwargs, run_reqs = prepare_prefill_inputs(req_ids, is_multimodal=self.is_multimodal) logics = self.model.forward(**kwargs) - mask = torch.ones_like(logics, dtype=torch.bool) for i, run_obj in enumerate(run_reqs): run_obj: InferReq = run_obj sample_params = run_obj.sampling_param @@ -54,11 +54,10 @@ def prefill(self, reqs: List[Tuple]): elif sample_params.guided_json is not None: xgrammar_compiled_grammar = self.xgrammar_compiler.compile_json_schema(sample_params.guided_json) sample_params.xgrammar_matcher = 
xgr.GrammarMatcher(xgrammar_compiled_grammar) - self._mask_req_out_token(i, run_obj, mask, logics[i]) + self._mask_req_out_token(i, run_obj, logics[i]) # fix the logics with -inf to a large negative value logics[logics == float("-inf")] = -1000000.0 - # logics[mask] = -1000000.0 next_token_ids, next_token_probs = sample(logics, run_reqs, self.eos_id) next_token_ids = next_token_ids.detach().cpu().numpy() @@ -70,6 +69,8 @@ def prefill(self, reqs: List[Tuple]): @calculate_time(show=True, min_cost_ms=200) def decode(self): + import xgrammar as xgr + kwargs, run_reqs = prepare_decode_inputs(g_infer_context.infer_req_ids) run_reqs: List[InferReq] = run_reqs @@ -77,10 +78,8 @@ def decode(self): all_has_no_constraint = all([not e.sampling_param.has_constraint_setting() for e in run_reqs]) if not all_has_no_constraint: - mask = torch.ones_like(logits, dtype=torch.bool) for i, run_obj in enumerate(run_reqs): - self._mask_req_out_token(i, run_obj, mask, logits[i]) - # logits[mask] = -1000000.0 + self._mask_req_out_token(i, run_obj, logits[i]) logits[logits == float("-inf")] = -1000000.0 next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) @@ -91,6 +90,8 @@ def decode(self): return def post_handel(self, run_reqs: List[InferReq], next_token_ids, next_token_logprobs): + import xgrammar as xgr + finished_req_ids = [] for req_obj, next_token_id, next_token_logprob in zip(run_reqs, next_token_ids, next_token_logprobs): @@ -126,14 +127,11 @@ def post_handel(self, run_reqs: List[InferReq], next_token_ids, next_token_logpr g_infer_context.filter(finished_req_ids) return - def _mask_req_out_token(self, i, run_obj: InferReq, mask, logits): + def _mask_req_out_token(self, i, run_obj: InferReq, logits): + import xgrammar as xgr + sample_params = run_obj.sampling_param if sample_params.guided_grammar is not None or sample_params.guided_json is not None: sample_params.xgrammar_matcher.fill_next_token_bitmask(self.xgrammar_token_bitmask) xgr.apply_token_bitmask_inplace(logits, self.xgrammar_token_bitmask.to(logits.device)) - mask[i, :] = False - elif sample_params.allowed_token_ids is not None: - mask[i, sample_params.allowed_token_ids] = False - else: - mask[i, :] = False return diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index fd99c2016..2a63e6b21 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -13,7 +13,7 @@ DiversehBackend, RewardModelBackend, TokenHealingBackend, - SimpleConstraintBackend, + OutlinesConstraintBackend, XgrammarBackend, FirstTokenConstraintBackend, ContinuesBatchBackendForPrefillNode, @@ -107,15 +107,15 @@ def init_model(self, kvargs): is_token_healing = kvargs.get("is_token_healing", False) is_first_token_constraint_mode = kvargs.get("is_first_token_constraint_mode", False) if kvargs.get("args", None) is not None: - is_simple_constraint_mode = kvargs.get("args", None).output_constraint_mode == "outlines" + is_outlines_constraint_mode = kvargs.get("args", None).output_constraint_mode == "outlines" is_xgrammar_constraint_mode = kvargs.get("args", None).output_constraint_mode == "xgrammar" assert not ( - is_simple_constraint_mode and is_xgrammar_constraint_mode + is_outlines_constraint_mode and is_xgrammar_constraint_mode ), "only one constraint mode can be true" is_prefill_node = kvargs.get("args", None).run_mode == "prefill" is_decode_node = kvargs.get("args", None).run_mode == "decode" else: - is_simple_constraint_mode = False + 
is_outlines_constraint_mode = False is_xgrammar_constraint_mode = False is_prefill_node = False is_decode_node = False @@ -134,10 +134,9 @@ def init_model(self, kvargs): self.backend = DiversehBackend() elif is_token_healing: self.backend = TokenHealingBackend() - elif is_simple_constraint_mode: - self.backend = SimpleConstraintBackend() + elif is_outlines_constraint_mode: + self.backend = OutlinesConstraintBackend() elif is_xgrammar_constraint_mode: - # now we prioritize simple_constraint_mode(Outlines) self.backend = XgrammarBackend() elif is_first_token_constraint_mode: self.backend = FirstTokenConstraintBackend() diff --git a/lightllm/server/router/req_queue/__init__.py b/lightllm/server/router/req_queue/__init__.py index 03d03e249..793b81662 100644 --- a/lightllm/server/router/req_queue/__init__.py +++ b/lightllm/server/router/req_queue/__init__.py @@ -15,7 +15,6 @@ def build_req_queue(args, router, dp_size: int): queue_class = ChunkedPrefillQueue if args.token_healing_mode: queue_class = ContinuesBatchQueue - # if args.simple_constraint_mode: if args.output_constraint_mode != "none": queue_class = ContinuesBatchQueue if args.first_token_constraint_mode: diff --git a/test/format_out/test_xgrammar_constraint.py b/test/format_out/test_xgrammar_constraint.py index fb4e7b3f8..67490dd31 100644 --- a/test/format_out/test_xgrammar_constraint.py +++ b/test/format_out/test_xgrammar_constraint.py @@ -112,7 +112,7 @@ def run(self): # 'temperature': 0.1, "parameters": { "do_sample": False, - "guided_grammar": json_grammar_ebnf_str, + # "guided_json": json_schema_str, "max_new_tokens": 200, }, } diff --git a/test/test_xgrammar_constraint.py b/test/test_xgrammar_constraint.py deleted file mode 100644 index dd0e5d8b8..000000000 --- a/test/test_xgrammar_constraint.py +++ /dev/null @@ -1,81 +0,0 @@ -import time -import requests -import json -import threading - -""" -python -m lightllm.server.api_server --model_dir /mnt/nvme0/chenjunyi/models/nb10_w8/ \ - --host 0.0.0.0 \ - --port 9999 \ - --tp 1 \ - --nccl_port 65535 \ - --max_req_total_len 200000 \ - --max_total_token_num 400000 \ - --data_type bf16 \ - --trust_remote_code \ - --output_constraint_mode xgrammar -""" - - -class RequestThread(threading.Thread): - def __init__(self, url, headers, data): - threading.Thread.__init__(self) - self.url = url - self.headers = headers - self.data = data - - def run(self): - response = requests.post(self.url, headers=self.headers, data=json.dumps(self.data)) - if response.status_code == 200: - print(response.json()) - else: - print("Error:", response.status_code, response.text) - - -url = "http://localhost:9999/generate" -headers = {"Content-Type": "application/json"} -json_grammar_ebnf_str = r""" -root ::= basic_array | basic_object -basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object -basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? -basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? 
-basic_string ::= (([\"] basic_string_1 [\"])) -basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1 -escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] -basic_boolean ::= "true" | "false" -basic_null ::= "null" -basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]" -basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}" -ws ::= [ \n\t]* -""" - -for i in range(1): - data = { - "inputs": "Introduce yourself in JSON briefly.", - # 'temperature': 0.1, - "parameters": { - "do_sample": False, - "guided_grammar": json_grammar_ebnf_str, - "max_new_tokens": 200, - }, - } - thread = RequestThread(url, headers, data) - thread.start() - -time.sleep(2) - -for i in range(20): - data = { - "inputs": "12-(25+16)*7=", - "parameters": { - "do_sample": False, - "ignore_eos": True, - "max_new_tokens": 200, - "guided_grammar": r"""root ::= (expr "=" term)+ -expr ::= term ([-+*/] term)* -term ::= num | "(" expr ")" -num ::= [0-9]+""", - }, - } - thread = RequestThread(url, headers, data) - thread.start() From b13f372c22da661815d14041f0f29d7bff4eacac Mon Sep 17 00:00:00 2001 From: Junyi Chen Date: Wed, 26 Feb 2025 15:11:02 +0800 Subject: [PATCH 10/14] add-test --- .../server/core/objs/test_sampling_params.py | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/unit_tests/server/core/objs/test_sampling_params.py b/unit_tests/server/core/objs/test_sampling_params.py index 8cb89f681..489f8ae34 100644 --- a/unit_tests/server/core/objs/test_sampling_params.py +++ b/unit_tests/server/core/objs/test_sampling_params.py @@ -7,11 +7,31 @@ ExponentialDecayLengthPenalty, DecodeNode, SamplingParams, + GuidedGrammar, + GuidedJsonSchema, STOP_SEQUENCE_MAX_LENGTH, REGULAR_CONSTRAINT_MAX_LENGTH, ALLOWED_TOKEN_IDS_MAX_LENGTH, ) +grammar_str = r"""root ::= (expr "=" term)+ +expr ::= term ([-+*/] term)* +term ::= num | "(" expr ")" +num ::= [0-9]+""" + +schema_str = r"""{ + "type": "array", + "items": { + "type": "object", + "properties": { + "Title": {"type": "string"}, + "Date": {"type": "string"}, + "Time": {"type": "string"} + }, + "required": ["Title", "Time", "Date"] + } +}""" + @pytest.mark.parametrize( "sequence, expected", @@ -58,6 +78,24 @@ def test_regular_constraint_initialization(): constraint.initialize("a" * (REGULAR_CONSTRAINT_MAX_LENGTH + 1)) +def test_guided_grammar_initialization(): + grammar = GuidedGrammar() + grammar.initialize(grammar_str) + assert grammar.to_str() == grammar_str + + with pytest.raises(AssertionError): + grammar.initialize("a" * (REGULAR_CONSTRAINT_MAX_LENGTH + 1)) + + +def test_guided_json_schema_initialization(): + schema = GuidedJsonSchema() + schema.initialize(schema_str) + assert schema.to_str() == schema_str + + with pytest.raises(AssertionError): + schema.initialize("a" * (REGULAR_CONSTRAINT_MAX_LENGTH + 1)) + + def test_allowed_token_ids_initialization(): allowed_ids = AllowedTokenIds() allowed_ids.initialize([1, 2, 3]) From 70e03a477b2f152d5cfd32476c63ad8ca13006f1 Mon Sep 17 00:00:00 2001 From: Junyi Chen Date: Wed, 26 Feb 2025 15:13:28 +0800 Subject: [PATCH 11/14] fix --- lightllm/server/core/objs/py_sampling_params.py | 3 +-- lightllm/server/router/model_infer/infer_batch.py | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/lightllm/server/core/objs/py_sampling_params.py b/lightllm/server/core/objs/py_sampling_params.py index 0d0212cf9..342f6f7b9 100644 --- a/lightllm/server/core/objs/py_sampling_params.py +++ 
b/lightllm/server/core/objs/py_sampling_params.py @@ -4,7 +4,6 @@ """ import os from typing import List, Optional, Union, Tuple -from pydantic import BaseModel from transformers import GenerationConfig from lightllm.server.req_id_generator import MAX_BEST_OF @@ -49,7 +48,7 @@ def __init__( input_penalty: bool = DEFAULT_INPUT_PENALTY, regular_constraint: Optional[str] = None, # Regular expressions constrain the output. guided_grammar: Optional[str] = None, # EBNF constrain the output. - guided_json: Optional[Union[str, dict, BaseModel]] = None, # JSON schema constrain the output. + guided_json: Optional[Union[str, dict]] = None, # JSON schema constrain the output. # If provided, the engine will construct a logits, # processor which only retains scores for the given token ids. Defaults to None. # allowed_token_ids only can be used in "--output_constraint_mode outlines" started server. diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 1bafe217d..d2a1e84f3 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -1,7 +1,6 @@ import os import copy import time -from pydantic import BaseModel import torch import torch.distributed as dist import numpy as np From c62cbf2a8411d50f5d3a7900e33dcf30eaff74e0 Mon Sep 17 00:00:00 2001 From: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com> Date: Wed, 26 Feb 2025 17:33:34 +0800 Subject: [PATCH 12/14] fix --- lightllm/server/core/objs/sampling_params.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index 772615d11..959b2b629 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -119,7 +119,6 @@ def initialize(self, constraint: str, tokenizer): tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer) xgrammar_compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8) - print(constraint) xgrammar_compiler.compile_grammar(constraint) except Exception as e: raise ValueError(f"guided_grammar '{constraint}' has compile_grammar_error: {str(e)}") From 95403f983ab6173d596a03acba2d2b0852707658 Mon Sep 17 00:00:00 2001 From: Junyi Chen Date: Thu, 27 Feb 2025 10:43:05 +0800 Subject: [PATCH 13/14] fix unit test --- lightllm/server/core/objs/sampling_params.py | 4 ++-- unit_tests/server/core/objs/test_sampling_params.py | 10 ++++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index 772615d11..bd9669803 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -114,7 +114,7 @@ def initialize(self, constraint: str, tokenizer): ctypes.memmove(self.constraint, constraint_bytes, len(constraint_bytes)) self.length = len(constraint_bytes) try: - if self.length > 0: + if self.length > 0 and tokenizer is not None: import xgrammar as xgr tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer) @@ -145,7 +145,7 @@ def initialize(self, constraint: str, tokenizer): ctypes.memmove(self.constraint, constraint_bytes, len(constraint_bytes)) self.length = len(constraint_bytes) try: - if self.length > 0: + if self.length > 0 and tokenizer is not None: import xgrammar as xgr tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer) diff --git a/unit_tests/server/core/objs/test_sampling_params.py 
b/unit_tests/server/core/objs/test_sampling_params.py index 489f8ae34..ee6aea12d 100644 --- a/unit_tests/server/core/objs/test_sampling_params.py +++ b/unit_tests/server/core/objs/test_sampling_params.py @@ -12,6 +12,8 @@ STOP_SEQUENCE_MAX_LENGTH, REGULAR_CONSTRAINT_MAX_LENGTH, ALLOWED_TOKEN_IDS_MAX_LENGTH, + GRAMMAR_CONSTRAINT_MAX_LENGTH, + JSON_SCHEMA_MAX_LENGTH, ) grammar_str = r"""root ::= (expr "=" term)+ @@ -80,20 +82,20 @@ def test_regular_constraint_initialization(): def test_guided_grammar_initialization(): grammar = GuidedGrammar() - grammar.initialize(grammar_str) + grammar.initialize(grammar_str, None) assert grammar.to_str() == grammar_str with pytest.raises(AssertionError): - grammar.initialize("a" * (REGULAR_CONSTRAINT_MAX_LENGTH + 1)) + grammar.initialize("a" * (GRAMMAR_CONSTRAINT_MAX_LENGTH + 1), None) def test_guided_json_schema_initialization(): schema = GuidedJsonSchema() - schema.initialize(schema_str) + schema.initialize(schema_str, None) assert schema.to_str() == schema_str with pytest.raises(AssertionError): - schema.initialize("a" * (REGULAR_CONSTRAINT_MAX_LENGTH + 1)) + schema.initialize("a" * (JSON_SCHEMA_MAX_LENGTH + 1), None) def test_allowed_token_ids_initialization(): From fede819084ddad3d651a1399eff6abc62c21b5d1 Mon Sep 17 00:00:00 2001 From: Junyi Chen Date: Thu, 27 Feb 2025 10:52:59 +0800 Subject: [PATCH 14/14] fix unit test --- lightllm/server/core/objs/sampling_params.py | 4 ++-- unit_tests/server/core/objs/test_sampling_params.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index 959b2b629..872e5d78f 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -114,7 +114,7 @@ def initialize(self, constraint: str, tokenizer): ctypes.memmove(self.constraint, constraint_bytes, len(constraint_bytes)) self.length = len(constraint_bytes) try: - if self.length > 0: + if self.length > 0 and tokenizer is not None: import xgrammar as xgr tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer) @@ -144,7 +144,7 @@ def initialize(self, constraint: str, tokenizer): ctypes.memmove(self.constraint, constraint_bytes, len(constraint_bytes)) self.length = len(constraint_bytes) try: - if self.length > 0: + if self.length > 0 and tokenizer is not None: import xgrammar as xgr tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer) diff --git a/unit_tests/server/core/objs/test_sampling_params.py b/unit_tests/server/core/objs/test_sampling_params.py index bdec8512b..ca684274e 100644 --- a/unit_tests/server/core/objs/test_sampling_params.py +++ b/unit_tests/server/core/objs/test_sampling_params.py @@ -12,6 +12,8 @@ STOP_SEQUENCE_MAX_LENGTH, REGULAR_CONSTRAINT_MAX_LENGTH, ALLOWED_TOKEN_IDS_MAX_LENGTH, + JSON_SCHEMA_MAX_LENGTH, + GRAMMAR_CONSTRAINT_MAX_LENGTH, ) grammar_str = r"""root ::= (expr "=" term)+ @@ -84,7 +86,7 @@ def test_guided_grammar_initialization(): assert grammar.to_str() == grammar_str with pytest.raises(AssertionError): - grammar.initialize("a" * (REGULAR_CONSTRAINT_MAX_LENGTH + 1), None) + grammar.initialize("a" * (GRAMMAR_CONSTRAINT_MAX_LENGTH + 1), None) def test_guided_json_schema_initialization(): @@ -93,7 +95,7 @@ def test_guided_json_schema_initialization(): assert schema.to_str() == schema_str with pytest.raises(AssertionError): - schema.initialize("a" * (REGULAR_CONSTRAINT_MAX_LENGTH + 1), None) + schema.initialize("a" * (JSON_SCHEMA_MAX_LENGTH + 1), None) def 
test_allowed_token_ids_initialization():
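
For reviewers who want to exercise the new constraint backend end to end, the sketch below condenses the client usage that this series first adds and later deletes in test/test_xgrammar_constraint.py. It is a minimal sketch, not part of the series: it assumes a server already started with --output_constraint_mode xgrammar, and the host, port, prompts, and PERSON_SCHEMA below are illustrative placeholders rather than values fixed by these patches.

import json

import requests

# Placeholder endpoint; matches the default used by the deleted test script.
URL = "http://localhost:9999/generate"
HEADERS = {"Content-Type": "application/json"}

# EBNF grammar accepted by the new `guided_grammar` parameter: one or more
# arithmetic expressions each followed by "=" and a result, e.g. "1+2=3".
ARITH_GRAMMAR = r"""root ::= (expr "=" term)+
expr ::= term ([-+*/] term)*
term ::= num | "(" expr ")"
num ::= [0-9]+"""

# Illustrative JSON schema for the new `guided_json` parameter (an assumption,
# not a schema shipped with the series).
PERSON_SCHEMA = json.dumps(
    {
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
        "required": ["name", "age"],
    }
)

requests_to_send = [
    # Exercises GuidedGrammar -> GrammarCompiler.compile_grammar on the server.
    {
        "inputs": "12-(25+16)*7=",
        "parameters": {"do_sample": False, "max_new_tokens": 200, "guided_grammar": ARITH_GRAMMAR},
    },
    # Exercises GuidedJsonSchema -> GrammarCompiler.compile_json_schema.
    {
        "inputs": "Introduce yourself in JSON briefly.",
        "parameters": {"do_sample": False, "max_new_tokens": 200, "guided_json": PERSON_SCHEMA},
    },
]

for data in requests_to_send:
    resp = requests.post(URL, headers=HEADERS, data=json.dumps(data))
    print(resp.status_code, resp.text)

Note that each request carries exactly one constraint: combining guided_grammar with guided_json (or either with allowed_token_ids) in a single request is rejected by the _verify_grammar_constraint and _verify_allowed_token_ids checks added in this series.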