
Commit 9da8bab

server: add test for token probs
1 parent 9afdffe commit 9da8bab

3 files changed: +74 -8 lines changed

examples/server/README.md

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ page cache before using this. See https://github.com/ggerganov/llama.cpp/issues/
 - `--api-key`: Set an API key for request authorization. By default, the server responds to every request. With an API key set, requests must have the Authorization header set with the API key as a Bearer token. May be used multiple times to enable multiple valid keys.
 - `--api-key-file`: Path to a file containing API keys delimited by new lines. If set, requests must include one of the keys for access. May be used in conjunction with `--api-key`s.
 - `--embeddings`: Enable embedding vector output and the OAI-compatible endpoint /v1/embeddings. Physical batch size (`--ubatch-size`) must be carefully defined. Default: disabled
-- `-np N`, `--parallel N`: Set the number of slots for processing requests. Default: `1`
+- `-np N`, `--parallel N`: Set the number of slots for processing requests. Default: `1`. Values > 1 will allow for higher throughput with multiple parallel requests, but the results will **not** be deterministic due to differences in rounding error.
 - `-cb`, `--cont-batching`: Enable continuous batching (a.k.a. dynamic batching). Default: disabled
 - `-spf FNAME`, `--system-prompt-file FNAME`: Set a file to load a system prompt (initial prompt of all slots). This is useful for chat applications. [See more](#change-system-prompt-on-runtime)
 - `--mmproj MMPROJ_FILE`: Path to a multimodal projector file for LLaVA.
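
A rough sketch (not part of this diff) of how a client might authorize against a server started with `--api-key`; the URL, the key value, and the prompt below are placeholder assumptions:

import requests  # assumed HTTP client; any client that can set headers works

# With --api-key set, every request must carry one of the configured keys as a Bearer token.
# "secret-key" and the local URL are illustrative placeholders, not values from this commit.
response = requests.post(
    "http://localhost:8080/completion",
    headers={"Authorization": "Bearer secret-key"},
    json={"prompt": "The meaning of life is", "n_predict": 8},
    timeout=60,
)
response.raise_for_status()
print(response.json()["content"])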

examples/server/tests/features/results.feature

Lines changed: 37 additions & 6 deletions
@@ -70,12 +70,43 @@ Feature: Results
     Then all predictions are equal
     Examples:
       | n_parallel | temp |
-      | 1          | 0.0  |
-      | 2          | 0.0  |
-      | 4          | 0.0  |
-      | 1          | 1.0  |
+      | 1          | 0.0  |
+      | 2          | 0.0  |
+      | 4          | 0.0  |
+      | 1          | 1.0  |
       # FIXME: These tests fail on master. The problem seems to be the unified KV cache.
       # See https://github.com/ggerganov/whisper.cpp/issues/1941#issuecomment-1986923227
       # and https://github.com/ggerganov/llama.cpp/pull/6122#discussion_r1531405574 .
-      # | 2          | 1.0  |
-      # | 4          | 1.0  |
+      # | 2          | 1.0  |
+      # | 4          | 1.0  |
+
+  Scenario Outline: consistent token probs with same seed and prompt
+    Given <n_slots> slots
+    And 1 threads
+    And 1.0 temperature
+    And <n_predict> max tokens to predict
+    Then the server is starting
+    Then the server is healthy
+
+    Given 1 prompts "The meaning of life is" with seed 42
+    And concurrent completion requests
+    # Then the server is busy # Not all slots will be utilized.
+    Then the server is idle
+    And all slots are idle
+
+    Given <n_parallel> prompts "The meaning of life is" with seed 42
+    And concurrent completion requests
+    # Then the server is busy # Not all slots will be utilized.
+    Then the server is idle
+    And all slots are idle
+
+    Then all token probabilities are equal
+    Examples:
+      | n_slots | n_parallel | n_predict |
+      | 4       | 1          | 1         |
+      | 4       | 1          | 10        |
+      | 4       | 4          | 1         |
+      # FIXME: These tests fail on master. The problem seems to be the unified KV cache.
+      # See https://github.com/ggerganov/whisper.cpp/issues/1941#issuecomment-1986923227
+      # and https://github.com/ggerganov/llama.cpp/pull/6122#discussion_r1531405574 .
+      # | 4       | 4          | 10        |
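
The property this new scenario checks can be sketched outside the Gherkin harness roughly as follows (illustrative only; it assumes a server already running locally with a model loaded, and uses the same `n_probs` request field and `completion_probabilities` response field that the step definitions below rely on):

import requests  # the test suite itself uses aiohttp; requests is used here only for brevity

payload = {
    "prompt": "The meaning of life is",
    "seed": 42,
    "temperature": 1.0,
    "n_predict": 10,
    "n_probs": 2,  # ask for the top-2 token probabilities at each position
}

# Two requests with an identical prompt and seed ...
first = requests.post("http://localhost:8080/completion", json=payload, timeout=3600).json()
second = requests.post("http://localhost:8080/completion", json=payload, timeout=3600).json()

# ... should report identical token probabilities at every generated position.
for pos, (a, b) in enumerate(zip(first["completion_probabilities"],
                                 second["completion_probabilities"])):
    assert a["probs"] == b["probs"], f"token probabilities differ at position {pos}"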

examples/server/tests/features/steps/steps.py

Lines changed: 36 additions & 1 deletion
@@ -23,6 +23,7 @@
 def step_server_config(context, server_fqdn, server_port):
     context.server_fqdn = server_fqdn
     context.server_port = int(server_port)
+    context.n_threads = None
     context.n_gpu_layer = None
     if 'PORT' in os.environ:
         context.server_port = int(os.environ['PORT'])
@@ -109,6 +110,11 @@ def step_n_gpu_layer(context, ngl):
     context.n_gpu_layer = ngl
 
 
+@step('{n_threads:d} threads')
+def step_n_threads(context, n_threads):
+    context.n_threads = n_threads
+
+
 @step('{draft:d} as draft')
 def step_draft(context, draft):
     context.draft = draft
@@ -274,13 +280,22 @@ async def step_predictions_equal(context):
 
 @step('all predictions are different')
 @async_run_until_complete
-async def step_predictions_equal(context):
+async def step_predictions_different(context):
     n_completions = await gather_tasks_results(context)
     assert n_completions >= 2, "need at least 2 completions"
     assert_all_predictions_different(context.tasks_result)
     context.tasks_result = []
 
 
+@step('all token probabilities are equal')
+@async_run_until_complete
+async def step_token_probabilities_equal(context):
+    n_completions = await gather_tasks_results(context)
+    assert n_completions >= 2, "need at least 2 completions"
+    assert_all_token_probabilities_equal(context.tasks_result)
+    context.tasks_result = []
+
+
 @step('the completion is truncated')
 def step_assert_completion_truncated(context):
     step_assert_completion_truncated(context, '')
@@ -869,6 +884,7 @@ async def request_completion(prompt,
             "id_slot": id_slot,
             "seed": seed if seed is not None else 42,
             "temperature": temperature if temperature is not None else "0.8f",
+            "n_probs": 2,
         },
         headers=headers,
         timeout=3600) as response:
@@ -1123,6 +1139,23 @@ def assert_all_predictions_different(completion_responses):
         assert content_i != content_j, "contents not different"
 
 
+def assert_all_token_probabilities_equal(completion_responses):
+    n_predict = len(completion_responses[0]['completion_probabilities'])
+    if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
+        for pos in range(n_predict):
+            for i, response_i in enumerate(completion_responses):
+                probs_i = response_i['completion_probabilities'][pos]['probs']
+                print(f"pos {pos}, probs {i}: {probs_i}")
+    for pos in range(n_predict):
+        for i, response_i in enumerate(completion_responses):
+            probs_i = response_i['completion_probabilities'][pos]['probs']
+            for j, response_j in enumerate(completion_responses):
+                if i == j:
+                    continue
+                probs_j = response_j['completion_probabilities'][pos]['probs']
+                assert probs_i == probs_j, "token probabilities not equal"
+
+
 async def gather_tasks_results(context):
     n_tasks = len(context.concurrent_tasks)
     if context.debug:
@@ -1261,6 +1294,8 @@ def start_server_background(context):
         server_args.extend(['--batch-size', context.n_batch])
     if context.n_ubatch:
         server_args.extend(['--ubatch-size', context.n_ubatch])
+    if context.n_threads:
+        server_args.extend(['--threads', context.n_threads])
     if context.n_gpu_layer:
         server_args.extend(['--n-gpu-layers', context.n_gpu_layer])
     if context.draft is not None:
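
For orientation, `assert_all_token_probabilities_equal` above indexes into the per-token data the server returns when `n_probs` is set. A sketch of the shape being compared follows; only the outer `completion_probabilities` list and the `probs` key per position are relied on by the test, and the inner field names shown are assumptions rather than something this commit guarantees:

# Illustrative only: approximate shape of one completion response when "n_probs": 2 is sent.
# The assertion compares response['completion_probabilities'][pos]['probs'] across responses.
example_response = {
    "content": " to be happy",
    "completion_probabilities": [
        {"probs": [{"tok_str": " to", "prob": 0.61}, {"tok_str": " a", "prob": 0.12}]},    # position 0
        {"probs": [{"tok_str": " be", "prob": 0.48}, {"tok_str": " find", "prob": 0.09}]},  # position 1
    ],
}

probs_at_first_position = example_response["completion_probabilities"][0]["probs"]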
