8
8
from collections import deque
9
9
from contextlib import asynccontextmanager
10
10
from http import HTTPStatus
11
- from typing import Optional , Type , Union
11
+ from typing import Callable , Optional , Type , Union
12
12
13
13
import aiohttp
14
14
import uvicorn
15
15
from fastapi import FastAPI , HTTPException
16
16
from fastapi .exceptions import RequestValidationError
17
17
from fastapi .responses import JSONResponse , Response , StreamingResponse
18
- from starlette .status import HTTP_429_TOO_MANY_REQUESTS
18
+ from starlette .status import HTTP_500_INTERNAL_SERVER_ERROR
19
19
20
20
# yapf: disable
21
21
from tensorrt_llm .executor import CppExecutorError
@@ -42,7 +42,8 @@ def __init__(self,
42
42
config : DisaggServerConfig ,
43
43
req_timeout_secs : int = 180 ,
44
44
server_start_timeout_secs : int = 180 ,
45
- metadata_server_cfg : Optional [MetadataServerConfig ] = None ):
45
+ metadata_server_cfg : Optional [MetadataServerConfig ] = None ,
46
+ metrics_interval_secs : int = 0 ):
46
47
47
48
self .ctx_servers , self .gen_servers = get_ctx_gen_server_urls (config .server_configs )
48
49
self .metadata_server = create_metadata_server (metadata_server_cfg )
@@ -68,6 +69,17 @@ def __init__(self,
68
69
if config .max_retries < 0 :
69
70
raise ValueError (f"Max retries { config .max_retries } must be greater than or equal to 0" )
70
71
self .max_retries = config .max_retries
72
+ # Metrics counters and synchronization
73
+ self ._metrics = {
74
+ "ctx_total_requests" : 0 ,
75
+ "ctx_completed_requests" : 0 ,
76
+ "gen_total_requests" : 0 ,
77
+ "gen_completed_requests" : 0 ,
78
+ }
79
+ self ._metrics_lock = asyncio .Lock ()
80
+ self ._metrics_task = None
81
+ self .metrics_interval_secs = metrics_interval_secs
82
+
71
83
logger .info (f"Server max retries: { self .max_retries } " )
72
84
73
85
if (len (self .gen_servers ) == 0 ):
@@ -98,13 +110,25 @@ async def lifespan(app: FastAPI):
98
110
await self .ctx_router .start_server_monitoring (metadata_server_cfg .refresh_interval )
99
111
await self .gen_router .start_server_monitoring (metadata_server_cfg .refresh_interval )
100
112
113
+ # Start periodic metrics logging
114
+ if self .metrics_interval_secs > 0 :
115
+ self ._metrics_task = asyncio .create_task (self ._log_metrics_periodically (self .metrics_interval_secs ))
116
+
101
117
yield
102
118
103
119
if self .metadata_server :
104
120
logger .info ("Stopping server monitoring via metadata service" )
105
121
await self .ctx_router .stop_server_monitoring ()
106
122
await self .gen_router .stop_server_monitoring ()
107
123
124
+ # Stop periodic metrics logging
125
+ if self ._metrics_task is not None :
126
+ self ._metrics_task .cancel ()
127
+ try :
128
+ await self ._metrics_task
129
+ except asyncio .CancelledError :
130
+ pass
131
+
108
132
await self .session .close () # Ensure session cleanup
109
133
110
134
self .app = FastAPI (lifespan = lifespan )
@@ -115,6 +139,29 @@ async def validation_exception_handler(_, exc):
115
139
116
140
self .register_routes ()
117
141
142
async def _increment_metric(self, key: str, amount: int = 1):
    """Add ``amount`` to the metrics counter stored under ``key``.

    The call does nothing when ``self.metrics_interval_secs`` is not
    positive, so counters are only maintained while periodic metrics
    logging is enabled. The update happens while holding
    ``self._metrics_lock``.

    Args:
        key: Name of an existing entry in ``self._metrics``.
        amount: Value added to the counter; defaults to 1.

    Raises:
        KeyError: If ``key`` is not present in ``self._metrics``.
    """
    # Metrics disabled: skip the lock and leave the counters untouched.
    if self.metrics_interval_secs <= 0:
        return

    async with self._metrics_lock:
        updated_value = self._metrics[key] + amount
        self._metrics[key] = updated_value
146
+
147
async def _get_metrics_snapshot(self):
    """Return a shallow copy of the metrics counters.

    The copy is taken while holding ``self._metrics_lock``; the returned
    dict is a separate object, so callers can read or modify it without
    affecting ``self._metrics``.
    """
    async with self._metrics_lock:
        counters_copy = self._metrics.copy()
    return counters_copy
150
+
151
async def _log_metrics_periodically(self, interval_seconds: int):
    """Log the request counters every ``interval_seconds`` seconds, forever.

    Each iteration sleeps for ``interval_seconds``, takes a snapshot via
    ``self._get_metrics_snapshot()`` and writes one ``[Statistics]`` line
    to the module logger. The loop only ends when the surrounding task is
    cancelled.

    Cancellation is intentionally *not* swallowed here:
    ``asyncio.CancelledError`` propagates so the task ends in the
    cancelled state, as asyncio recommends. Code that cancels and then
    awaits this task must therefore handle ``asyncio.CancelledError``.

    Args:
        interval_seconds: Delay between two consecutive log lines.

    Raises:
        asyncio.CancelledError: When the task running this coroutine is
            cancelled.
    """
    while True:
        await asyncio.sleep(interval_seconds)
        snapshot = await self._get_metrics_snapshot()
        logger.info(
            f"[Statistics] total_context_requests={snapshot['ctx_total_requests']}, "
            f"completed_context_requests={snapshot['ctx_completed_requests']}, "
            f"total_generation_requests={snapshot['gen_total_requests']}, "
            f"completed_generation_requests={snapshot['gen_completed_requests']}"
        )
164
+
118
165
@staticmethod
119
166
def create_error_response (
120
167
message : str ,
@@ -198,15 +245,15 @@ async def merge_streaming_responses(self, ctx_response,
198
245
gen_server : str ,
199
246
gen_req : Union [CompletionRequest , ChatCompletionRequest ]):
200
247
try :
201
-
202
248
if ctx_response is not None and len (ctx_response .choices ) != 1 :
203
249
raise ValueError ("Context server did not return a single choice. This is not expected" )
204
250
205
251
#If request finished after first token not due to length, return right away and skip gen
206
252
if ctx_response is not None and ctx_response .choices [0 ].finish_reason not in ["length" , "not_finished" ]:
207
- yield f "data: [DONE]\n \n " .encode ('utf-8' )
253
+ yield "data: [DONE]\n \n " .encode ('utf-8' )
208
254
else :
209
255
# Then yield the generation responses
256
+ await self ._increment_metric ("gen_total_requests" )
210
257
if isinstance (gen_req , CompletionRequest ):
211
258
gen_response = await self .send_completion_request (gen_server , gen_req )
212
259
elif isinstance (gen_req , ChatCompletionRequest ):
@@ -216,6 +263,7 @@ async def merge_streaming_responses(self, ctx_response,
216
263
217
264
async for chunk in gen_response .body_iterator :
218
265
yield chunk
266
+ await self ._increment_metric ("gen_completed_requests" )
219
267
220
268
finally :
221
269
await self .gen_router .finish_request (gen_req )
@@ -258,6 +306,7 @@ async def _send_context_request(self, ctx_server: str, ctx_req: Union[Completion
258
306
ctx_req .stream_options = None
259
307
260
308
logger .debug ("Sending request to ctx server: %s" , ctx_server )
309
+ await self ._increment_metric ("ctx_total_requests" )
261
310
try :
262
311
if isinstance (ctx_req , ChatCompletionRequest ):
263
312
ctx_response = await self .send_chat_request (ctx_server , ctx_req )
@@ -266,6 +315,7 @@ async def _send_context_request(self, ctx_server: str, ctx_req: Union[Completion
266
315
ctx_response = await self .send_completion_request (ctx_server , ctx_req )
267
316
finally :
268
317
await self .ctx_router .finish_request (ctx_req )
318
+ await self ._increment_metric ("ctx_completed_requests" )
269
319
270
320
choices = ctx_response .choices
271
321
if len (choices ) > 1 :
@@ -342,11 +392,13 @@ async def _send_disagg_request(self, req: Union[CompletionRequest, ChatCompletio
342
392
del ctx_response .choices [0 ].disaggregated_params
343
393
return ctx_response
344
394
else :
395
+ await self ._increment_metric ("gen_total_requests" )
345
396
if isinstance (req , CompletionRequest ):
346
397
gen_response = await self .send_completion_request (gen_server , req )
347
398
else :
348
399
assert isinstance (req , ChatCompletionRequest )
349
400
gen_response = await self .send_chat_request (gen_server , req )
401
+ await self ._increment_metric ("gen_completed_requests" )
350
402
return gen_response
351
403
finally :
352
404
if gen_server is not None :
@@ -400,7 +452,7 @@ async def send_request(self, url: str,
400
452
request : Union [CompletionRequest , ChatCompletionRequest ],
401
453
endpoint : str ,
402
454
response_type : Type [Union [CompletionResponse , ChatCompletionResponse ]],
403
- create_generator : callable ) -> Union [CompletionResponse , ChatCompletionResponse , StreamingResponse ]:
455
+ create_generator : Callable ) -> Union [CompletionResponse , ChatCompletionResponse , StreamingResponse ]:
404
456
for attempt in range (self .max_retries + 1 ):
405
457
try :
406
458
if request .stream :
@@ -419,7 +471,7 @@ async def send_request(self, url: str,
419
471
return response_type (** response_dict )
420
472
except (aiohttp .ClientError , OSError ) as e :
421
473
if attempt == self .max_retries :
422
- raise HTTPException (status_code = HTTP_429_TOO_MANY_REQUESTS , detail = f"Too many requests " ) from e
474
+ raise HTTPException (status_code = HTTP_500_INTERNAL_SERVER_ERROR , detail = f"Internal server error " ) from e
423
475
logger .error (f"Client error: { e } - retry { attempt } of { self .max_retries } " )
424
476
# TODO : add a configurable retry interval
425
477
await asyncio .sleep (1 )
0 commit comments