
Commit 258cd87

mitya52 authored and MarcMcIntosh committed

n_ctx from model assigner (#677)
* n_ctx from model assigner
* models_dict_patch
* fix missed fields and patch pass

1 parent 060b1bf commit 258cd87

File tree

2 files changed: 25 additions & 12 deletions

refact-server/refact_webgui/webgui/selfhost_fastapi_completions.py
refact-server/refact_webgui/webgui/selfhost_model_assigner.py

refact-server/refact_webgui/webgui/selfhost_fastapi_completions.py

Lines changed: 21 additions & 6 deletions
@@ -2,7 +2,6 @@
 import json
 import copy
 import asyncio
-import aiohttp
 import aiofiles
 import termcolor
 import os
@@ -13,6 +12,8 @@
 from fastapi import APIRouter, HTTPException, Query, Header
 from fastapi.responses import Response, StreamingResponse
 
+from itertools import chain
+
 from refact_utils.scripts import env
 from refact_utils.finetune.utils import running_models_and_loras
 from refact_utils.third_party.utils.models import available_third_party_models
@@ -250,8 +251,9 @@ def _select_default_model(models: List[str]) -> str:
         # completion models
         completion_models = {}
         for model_name in running_models.get("completion", []):
-            if model_info := self._model_assigner.models_db.get(_get_base_model_info(model_name)):
-                completion_models[model_name] = self._model_assigner.to_completion_model_record(model_info)
+            base_model_name = _get_base_model_info(model_name)
+            if model_info := self._model_assigner.models_db.get(base_model_name):
+                completion_models[model_name] = self._model_assigner.to_completion_model_record(base_model_name, model_info)
             elif model := available_third_party_models().get(model_name):
                 completion_models[model_name] = model.to_completion_model_record()
             else:
@@ -261,8 +263,9 @@ def _select_default_model(models: List[str]) -> str:
         # chat models
         chat_models = {}
         for model_name in running_models.get("chat", []):
-            if model_info := self._model_assigner.models_db.get(_get_base_model_info(model_name)):
-                chat_models[model_name] = self._model_assigner.to_chat_model_record(model_info)
+            base_model_name = _get_base_model_info(model_name)
+            if model_info := self._model_assigner.models_db.get(base_model_name):
+                chat_models[model_name] = self._model_assigner.to_chat_model_record(base_model_name, model_info)
             elif model := available_third_party_models().get(model_name):
                 chat_models[model_name] = model.to_chat_model_record()
             else:
@@ -337,6 +340,18 @@ def _parse_client_version(user_agent: str = Header(None)) -> Optional[Tuple[int,
 
     @staticmethod
     def _to_deprecated_caps_format(data: Dict[str, Any]):
+        models_dict_patch = {}
+        for model_name, model_record in chain(
+                data["completion"]["models"].items(),
+                data["completion"]["models"].items(),
+        ):
+            dict_patch = {}
+            if n_ctx := model_record.get("n_ctx"):
+                dict_patch["n_ctx"] = n_ctx
+            if supports_tools := model_record.get("supports_tools"):
+                dict_patch["supports_tools"] = supports_tools
+            if dict_patch:
+                models_dict_patch[model_name] = dict_patch
         return {
             "cloud_name": data["cloud_name"],
             "endpoint_template": data["completion"]["endpoint"],
@@ -349,7 +364,7 @@ def _to_deprecated_caps_format(data: Dict[str, Any]):
             "code_completion_default_model": data["completion"]["default_model"],
             "multiline_code_completion_default_model": data["completion"]["default_multiline_model"],
             "code_chat_default_model": data["chat"]["default_model"],
-            "models_dict_patch": {},  # NOTE: this actually should have n_ctx, but we're skiping it
+            "models_dict_patch": models_dict_patch,
             "default_embeddings_model": data["embedding"]["default_model"],
             "endpoint_embeddings_template": "v1/embeddings",
             "endpoint_embeddings_style": "openai",

refact-server/refact_webgui/webgui/selfhost_model_assigner.py

Lines changed: 4 additions & 6 deletions
@@ -107,17 +107,15 @@ def share_gpu_backends(self) -> Set[str]:
     def models_db(self) -> Dict[str, Any]:
         return models_mini_db
 
-    @staticmethod
-    def to_completion_model_record(model_info: Dict[str, Any]) -> Dict[str, Any]:
+    def to_completion_model_record(self, model_name: str, model_info: Dict[str, Any]) -> Dict[str, Any]:
         return {
-            "n_ctx": model_info["T"],
+            "n_ctx": min(self.model_assignment["model_assign"].get(model_name, {}).get("n_ctx", model_info["T"]), model_info["T"]),
             "supports_scratchpads": model_info["supports_scratchpads"]["completion"],
         }
 
-    @staticmethod
-    def to_chat_model_record(model_info: Dict[str, Any]) -> Dict[str, Any]:
+    def to_chat_model_record(self, model_name: str, model_info: Dict[str, Any]) -> Dict[str, Any]:
         return {
-            "n_ctx": model_info["T"],
+            "n_ctx": min(self.model_assignment["model_assign"].get(model_name, {}).get("n_ctx", model_info["T"]), model_info["T"]),
             "supports_scratchpads": model_info["supports_scratchpads"]["chat"],
         }

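The record helpers now take the per-model n_ctx configured in the model assignment, fall back to the model's maximum context model_info["T"] when none is assigned, and clamp the result so it never exceeds "T". A minimal sketch of that selection, with made-up assignment data:

# Sketch of the n_ctx selection in to_completion_model_record / to_chat_model_record.
# The assignment and model_info values below are illustrative only.
model_assign = {"some-model": {"n_ctx": 2048}}   # hypothetical webgui model assignment
model_info = {"T": 4096}                         # model's maximum context from models_db

def effective_n_ctx(model_name: str) -> int:
    assigned = model_assign.get(model_name, {}).get("n_ctx", model_info["T"])
    return min(assigned, model_info["T"])        # never exceed the model's own limit

print(effective_n_ctx("some-model"))        # 2048: assigned value, below the cap
print(effective_n_ctx("unassigned-model"))  # 4096: falls back to T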