2
2
import json
3
3
import copy
4
4
import asyncio
5
- import aiohttp
6
5
import aiofiles
7
6
import termcolor
8
7
import os
13
12
from fastapi import APIRouter , HTTPException , Query , Header
14
13
from fastapi .responses import Response , StreamingResponse
15
14
15
+ from itertools import chain
16
+
16
17
from refact_utils .scripts import env
17
18
from refact_utils .finetune .utils import running_models_and_loras
18
19
from refact_utils .third_party .utils .models import available_third_party_models
@@ -250,8 +251,9 @@ def _select_default_model(models: List[str]) -> str:
250
251
# completion models
251
252
completion_models = {}
252
253
for model_name in running_models .get ("completion" , []):
253
- if model_info := self ._model_assigner .models_db .get (_get_base_model_info (model_name )):
254
- completion_models [model_name ] = self ._model_assigner .to_completion_model_record (model_info )
254
+ base_model_name = _get_base_model_info (model_name )
255
+ if model_info := self ._model_assigner .models_db .get (base_model_name ):
256
+ completion_models [model_name ] = self ._model_assigner .to_completion_model_record (base_model_name , model_info )
255
257
elif model := available_third_party_models ().get (model_name ):
256
258
completion_models [model_name ] = model .to_completion_model_record ()
257
259
else :
@@ -261,8 +263,9 @@ def _select_default_model(models: List[str]) -> str:
261
263
# chat models
262
264
chat_models = {}
263
265
for model_name in running_models .get ("chat" , []):
264
- if model_info := self ._model_assigner .models_db .get (_get_base_model_info (model_name )):
265
- chat_models [model_name ] = self ._model_assigner .to_chat_model_record (model_info )
266
+ base_model_name = _get_base_model_info (model_name )
267
+ if model_info := self ._model_assigner .models_db .get (base_model_name ):
268
+ chat_models [model_name ] = self ._model_assigner .to_chat_model_record (base_model_name , model_info )
266
269
elif model := available_third_party_models ().get (model_name ):
267
270
chat_models [model_name ] = model .to_chat_model_record ()
268
271
else :
@@ -337,6 +340,18 @@ def _parse_client_version(user_agent: str = Header(None)) -> Optional[Tuple[int,
337
340
338
341
@staticmethod
339
342
def _to_deprecated_caps_format (data : Dict [str , Any ]):
343
+ models_dict_patch = {}
344
+ for model_name , model_record in chain (
345
+ data ["completion" ]["models" ].items (),
346
+ data ["completion" ]["models" ].items (),
347
+ ):
348
+ dict_patch = {}
349
+ if n_ctx := model_record .get ("n_ctx" ):
350
+ dict_patch ["n_ctx" ] = n_ctx
351
+ if supports_tools := model_record .get ("supports_tools" ):
352
+ dict_patch ["supports_tools" ] = supports_tools
353
+ if dict_patch :
354
+ models_dict_patch [model_name ] = dict_patch
340
355
return {
341
356
"cloud_name" : data ["cloud_name" ],
342
357
"endpoint_template" : data ["completion" ]["endpoint" ],
@@ -349,7 +364,7 @@ def _to_deprecated_caps_format(data: Dict[str, Any]):
349
364
"code_completion_default_model" : data ["completion" ]["default_model" ],
350
365
"multiline_code_completion_default_model" : data ["completion" ]["default_multiline_model" ],
351
366
"code_chat_default_model" : data ["chat" ]["default_model" ],
352
- "models_dict_patch" : {}, # NOTE: this actually should have n_ctx, but we're skiping it
367
+ "models_dict_patch" : models_dict_patch ,
353
368
"default_embeddings_model" : data ["embedding" ]["default_model" ],
354
369
"endpoint_embeddings_template" : "v1/embeddings" ,
355
370
"endpoint_embeddings_style" : "openai" ,
0 commit comments