diff --git a/refact-server/refact_webgui/webgui/selfhost_fastapi_completions.py b/refact-server/refact_webgui/webgui/selfhost_fastapi_completions.py
index 09ad460b1..b76c64abd 100644
--- a/refact-server/refact_webgui/webgui/selfhost_fastapi_completions.py
+++ b/refact-server/refact_webgui/webgui/selfhost_fastapi_completions.py
@@ -2,7 +2,6 @@
 import json
 import copy
 import asyncio
-import aiohttp
 import aiofiles
 import termcolor
 import os
@@ -13,6 +12,8 @@
 from fastapi import APIRouter, HTTPException, Query, Header
 from fastapi.responses import Response, StreamingResponse
 
+from itertools import chain
+
 from refact_utils.scripts import env
 from refact_utils.finetune.utils import running_models_and_loras
 from refact_utils.third_party.utils.models import available_third_party_models
@@ -250,8 +251,9 @@ def _select_default_model(models: List[str]) -> str:
         # completion models
         completion_models = {}
         for model_name in running_models.get("completion", []):
-            if model_info := self._model_assigner.models_db.get(_get_base_model_info(model_name)):
-                completion_models[model_name] = self._model_assigner.to_completion_model_record(model_info)
+            base_model_name = _get_base_model_info(model_name)
+            if model_info := self._model_assigner.models_db.get(base_model_name):
+                completion_models[model_name] = self._model_assigner.to_completion_model_record(base_model_name, model_info)
             elif model := available_third_party_models().get(model_name):
                 completion_models[model_name] = model.to_completion_model_record()
             else:
@@ -261,8 +263,9 @@ def _select_default_model(models: List[str]) -> str:
         # chat models
         chat_models = {}
         for model_name in running_models.get("chat", []):
-            if model_info := self._model_assigner.models_db.get(_get_base_model_info(model_name)):
-                chat_models[model_name] = self._model_assigner.to_chat_model_record(model_info)
+            base_model_name = _get_base_model_info(model_name)
+            if model_info := self._model_assigner.models_db.get(base_model_name):
+                chat_models[model_name] = self._model_assigner.to_chat_model_record(base_model_name, model_info)
             elif model := available_third_party_models().get(model_name):
                 chat_models[model_name] = model.to_chat_model_record()
             else:
@@ -337,6 +340,18 @@ def _parse_client_version(user_agent: str = Header(None)) -> Optional[Tuple[int,
 
     @staticmethod
     def _to_deprecated_caps_format(data: Dict[str, Any]):
+        models_dict_patch = {}
+        for model_name, model_record in chain(
+            data["completion"]["models"].items(),
+            data["chat"]["models"].items(),
+        ):
+            dict_patch = {}
+            if n_ctx := model_record.get("n_ctx"):
+                dict_patch["n_ctx"] = n_ctx
+            if supports_tools := model_record.get("supports_tools"):
+                dict_patch["supports_tools"] = supports_tools
+            if dict_patch:
+                models_dict_patch[model_name] = dict_patch
         return {
             "cloud_name": data["cloud_name"],
             "endpoint_template": data["completion"]["endpoint"],
@@ -349,7 +364,7 @@ def _parse_client_version(user_agent: str = Header(None)) -> Optional[Tuple[int,
             "code_completion_default_model": data["completion"]["default_model"],
             "multiline_code_completion_default_model": data["completion"]["default_multiline_model"],
             "code_chat_default_model": data["chat"]["default_model"],
-            "models_dict_patch": {},  # NOTE: this actually should have n_ctx, but we're skiping it
+            "models_dict_patch": models_dict_patch,
             "default_embeddings_model": data["embedding"]["default_model"],
             "endpoint_embeddings_template": "v1/embeddings",
             "endpoint_embeddings_style": "openai",
diff --git a/refact-server/refact_webgui/webgui/selfhost_model_assigner.py b/refact-server/refact_webgui/webgui/selfhost_model_assigner.py
index 69024743a..192de301f 100644
--- a/refact-server/refact_webgui/webgui/selfhost_model_assigner.py
+++ b/refact-server/refact_webgui/webgui/selfhost_model_assigner.py
@@ -107,17 +107,15 @@ def share_gpu_backends(self) -> Set[str]:
     def models_db(self) -> Dict[str, Any]:
         return models_mini_db
 
-    @staticmethod
-    def to_completion_model_record(model_info: Dict[str, Any]) -> Dict[str, Any]:
+    def to_completion_model_record(self, model_name: str, model_info: Dict[str, Any]) -> Dict[str, Any]:
         return {
-            "n_ctx": model_info["T"],
+            "n_ctx": min(self.model_assignment["model_assign"].get(model_name, {}).get("n_ctx", model_info["T"]), model_info["T"]),
             "supports_scratchpads": model_info["supports_scratchpads"]["completion"],
         }
 
-    @staticmethod
-    def to_chat_model_record(model_info: Dict[str, Any]) -> Dict[str, Any]:
+    def to_chat_model_record(self, model_name: str, model_info: Dict[str, Any]) -> Dict[str, Any]:
         return {
-            "n_ctx": model_info["T"],
+            "n_ctx": min(self.model_assignment["model_assign"].get(model_name, {}).get("n_ctx", model_info["T"]), model_info["T"]),
             "supports_scratchpads": model_info["supports_scratchpads"]["chat"],
         }
 