server: add model alias presets #14083

Open · wants to merge 2 commits into master
7 changes: 7 additions & 0 deletions common/arg.cpp
@@ -2776,6 +2776,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
key_file.close();
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--alias-presets-file"}, "FNAME",
"path to file containing alias preset configurations (default: none)",
[](common_params & params, const std::string & value) {
params.alias_presets_file = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--ssl-key-file"}, "FNAME",
"path to file a PEM-encoded SSL private key",
1 change: 1 addition & 0 deletions common/common.h
@@ -377,6 +377,7 @@ struct common_params {

std::string ssl_file_key = ""; // NOLINT
std::string ssl_file_cert = ""; // NOLINT
std::string alias_presets_file = ""; // NOLINT

// "advanced" endpoints are disabled by default for better security
bool webui = true;
2 changes: 2 additions & 0 deletions tools/server/README.md
@@ -490,6 +490,8 @@ These words will not be included in the completion, so make sure to add them to

`lora`: A list of LoRA adapters to be applied to this specific request. Each object in the list must contain `id` and `scale` fields. For example: `[{"id": 0, "scale": 0.5}, {"id": 1, "scale": 1.1}]`. If a LoRA adapter is not specified in the list, its scale will default to `0.0`. Please note that requests with different LoRA configurations will not be batched together, which may result in performance degradation.

`alias-presets-file`: A JSON file that maps model aliases to parameter presets, e.g. `{ "llama-low": {"temperature": 0.1}, "llama-high": {"temperature": 1.0} }`. If the request specifies a `model` for which a preset exists, the preset's parameters are applied before the completion is handled. When a parameter appears in both the request and the preset, the request's value takes precedence.
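A minimal sketch of a presets file (the alias names and parameter values below are illustrative, not part of this change):

```json
{
  "llama-low": {
    "temperature": 0.1,
    "top_p": 0.9
  },
  "llama-high": {
    "temperature": 1.0
  }
}
```

With the server started as `llama-server ... --alias-presets-file presets.json`, a request carrying `"model": "llama-high"` would run with `temperature: 1.0` unless the request itself sets `temperature`.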

**Response format**

- Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.
47 changes: 47 additions & 0 deletions tools/server/server.cpp
@@ -23,6 +23,7 @@
#include <condition_variable>
#include <cstddef>
#include <cinttypes>
#include <fstream>
#include <deque>
#include <memory>
#include <mutex>
@@ -1886,6 +1887,8 @@ struct server_context {
common_chat_templates_ptr chat_templates;
oaicompat_parser_options oai_parser_opt;

std::unordered_map<std::string, json> model_alias_presets;

~server_context() {
mtmd_free(mctx);

@@ -1906,6 +1909,33 @@
llama_batch_free(batch);
}

void load_model_alias_presets(const std::string & alias_presets_file) {
try {
std::ifstream file(alias_presets_file);
if (!file) {
SRV_ERR("failed to open alias presets file '%s'\n", alias_presets_file.c_str());
return;
}

json presets_json;
file >> presets_json;
file.close();

for (const auto & [model_alias_name, preset] : presets_json.items()) {
if (preset.is_object()) {
model_alias_presets[model_alias_name] = preset;
SRV_INF("loaded preset for model alias '%s'\n", model_alias_name.c_str());
} else {
SRV_WRN("skipping invalid preset for model alias '%s' (not an object)\n", model_alias_name.c_str());
}
}

SRV_INF("loaded %zu model alias presets from '%s'\n", model_alias_presets.size(), alias_presets_file.c_str());
} catch (const std::exception & e) {
SRV_ERR("failed to parse alias presets file '%s': %s\n", alias_presets_file.c_str(), e.what());
}
}

bool load_model(const common_params & params) {
SRV_INF("loading model '%s'\n", params.model.path.c_str());

@@ -2023,6 +2053,10 @@
}
}

if (!params_base.alias_presets_file.empty()) {
load_model_alias_presets(params_base.alias_presets_file);
}

return true;
}

@@ -4181,6 +4215,17 @@ int main(int argc, char ** argv) {
return;
}

// apply presets if available
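// request parameters take precedence: only keys absent from the request are filled from the preset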
const std::string model_alias = json_value(data, "model", std::string());
if (!model_alias.empty() && ctx_server.model_alias_presets.find(model_alias) != ctx_server.model_alias_presets.end()) {
const auto & preset = ctx_server.model_alias_presets.at(model_alias);
for (const auto & [key, value] : preset.items()) {
if (!data.contains(key)) {
data[key] = value;
}
}
}

auto completion_id = gen_chatcmplid();
std::unordered_set<int> task_ids;
try {
@@ -4245,6 +4290,8 @@
}
}

tasks.reserve(inputs.size());
for (size_t i = 0; i < inputs.size(); i++) {
server_task task = server_task(type);
139 changes: 139 additions & 0 deletions tools/server/tests/unit/test_alias_presets.py
@@ -0,0 +1,139 @@
import json
import os
import tempfile
from pathlib import Path
import sys

import pytest

# ensure grandparent path is in sys.path
path = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(path))

from utils import *

server = ServerPreset.stories15m_moe()

LORA_FILE_URL = "https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe_shakespeare15M.gguf"

@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.stories15m_moe()
server.lora_files = [download_file(LORA_FILE_URL)]


def test_alias_presets_per_request():
global server
server.n_slots = 4

preset_data = {
"bedtime-stories": {
"lora": [{"id": 0, "scale": 0.0}]
},
"shakespeare-light": {
"lora": [{"id": 0, "scale": 0.3}]
},
"shakespeare-medium": {
"lora": [{"id": 0, "scale": 0.7}]
},
"shakespeare-full": {
"lora": [{"id": 0, "scale": 1.0}]
}
}

with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
json.dump(preset_data, f)
preset_file_path = f.name

try:
server.alias_presets_file = preset_file_path
server.start()

# running the same prompt with different model aliases, all in parallel
# each prompt will be processed by a different slot
prompt = "Look in thy glass"
alias_config = [
("bedtime-stories", "(bright|day|many|happy)+"),
("bedtime-stories", "(bright|day|many|happy)+"),
("shakespeare-light", "(special|thing|gifted)+"),
("shakespeare-medium", "(far|from|home|away)+"),
("shakespeare-full", "(eye|love|glass|sun)+"),
("shakespeare-full", "(eye|love|glass|sun)+"),
]

tasks = [(
server.make_request,
("POST", "/completions", {
"model": model_alias,
"prompt": prompt,
"seed": 42,
"temperature": 0.0,
"cache_prompt": False,
})
) for model_alias, _ in alias_config]
results = parallel_function_calls(tasks)

assert all([res.status_code == 200 for res in results])
for res, (_, re_test) in zip(results, alias_config):
assert match_regex(re_test, res.body["content"])

finally:
server.stop()
os.unlink(preset_file_path)

def test_alias_override():
# test whether we honor the user's override even in case a preset is set
global server
server.n_slots = 2

# Use the same preset data as test_alias_presets_per_request
preset_data = {
"bedtime-stories": {
"lora": [{"id": 0, "scale": 0.0}]
},
"shakespeare-light": {
"lora": [{"id": 0, "scale": 0.3}]
},
"shakespeare-medium": {
"lora": [{"id": 0, "scale": 0.7}]
},
"shakespeare-full": {
"lora": [{"id": 0, "scale": 1.0}]
}
}

with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
json.dump(preset_data, f)
preset_file_path = f.name

try:
server.alias_presets_file = preset_file_path
server.start()

prompt = "Look in thy glass"

res1 = server.make_request("POST", "/completions", {
"model": "bedtime-stories",
"prompt": prompt,
"cache_prompt": False,
})

# override to shakespeare
res2 = server.make_request("POST", "/completions", {
"model": "bedtime-stories",
"prompt": prompt,
"cache_prompt": False,
"lora": [{"id": 0, "scale": 1.0}],
})

assert res1.status_code == 200
assert res2.status_code == 200

assert match_regex("(bright|day|many|happy)+", res1.body["content"])
assert match_regex("(eye|love|glass|sun)+", res2.body["content"])
assert res1.body["content"] != res2.body["content"]

finally:
server.stop()
os.unlink(preset_file_path)
3 changes: 3 additions & 0 deletions tools/server/tests/utils.py
@@ -88,6 +88,7 @@ class ServerProcess:
reasoning_budget: int | None = None
chat_template: str | None = None
chat_template_file: str | None = None
alias_presets_file: str | None = None
server_path: str | None = None
mmproj_url: str | None = None

@@ -198,6 +199,8 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
server_args.extend(["--chat-template", self.chat_template])
if self.chat_template_file:
server_args.extend(["--chat-template-file", self.chat_template_file])
if self.alias_presets_file:
server_args.extend(["--alias-presets-file", self.alias_presets_file])
if self.mmproj_url:
server_args.extend(["--mmproj-url", self.mmproj_url])
