Skip to content

Commit bc84758

Browse files
authored
[None][feat] Add logging for OAI disagg server (#7232)
1 parent d0d8903 commit bc84758

File tree

3 files changed

+71
-11
lines changed

3 files changed

+71
-11
lines changed

tensorrt_llm/commands/serve.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -457,10 +457,17 @@ def serve_encoder(model: str, host: str, port: int, log_level: str,
457457
type=click.Choice(severity_map.keys()),
458458
default='info',
459459
help="The logging level.")
460+
@click.option(
461+
"--metrics-log-interval",
462+
type=int,
463+
default=0,
464+
help=
465+
"The interval of logging metrics in seconds. Set to 0 to disable metrics logging."
466+
)
460467
def disaggregated(config_file: Optional[str],
461468
metadata_server_config_file: Optional[str],
462469
server_start_timeout: int, request_timeout: int,
463-
log_level: str):
470+
log_level: str, metrics_log_interval: int):
464471
"""Running server in disaggregated mode"""
465472

466473
logger.set_level(log_level)
@@ -473,7 +480,8 @@ def disaggregated(config_file: Optional[str],
473480
server = OpenAIDisaggServer(config=disagg_cfg,
474481
req_timeout_secs=request_timeout,
475482
server_start_timeout_secs=server_start_timeout,
476-
metadata_server_cfg=metadata_server_cfg)
483+
metadata_server_cfg=metadata_server_cfg,
484+
metrics_interval_secs=metrics_log_interval)
477485

478486
asyncio.run(server(disagg_cfg.hostname, disagg_cfg.port))
479487

tensorrt_llm/llmapi/disagg_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class DisaggServerConfig():
5151
ctx_router_config: Optional[RouterConfig] = None
5252
gen_router_config: Optional[RouterConfig] = None
5353
conditional_disagg_config: Optional[ConditionalDisaggConfig] = None
54-
max_retries: int = 3
54+
max_retries: int = 1
5555
perf_metrics_max_requests: int = 0
5656

5757

@@ -91,7 +91,7 @@ def parse_disagg_config_file(yaml_config_file: str):
9191

9292
def extract_disagg_cfg(hostname: str = 'localhost',
9393
port: int = 8000,
94-
max_retries: int = 3,
94+
max_retries: int = 1,
9595
perf_metrics_max_requests: int = 0,
9696
context_servers: Optional[dict] = None,
9797
generation_servers: Optional[dict] = None,

tensorrt_llm/serve/openai_disagg_server.py

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,14 @@
88
from collections import deque
99
from contextlib import asynccontextmanager
1010
from http import HTTPStatus
11-
from typing import Optional, Type, Union
11+
from typing import Callable, Optional, Type, Union
1212

1313
import aiohttp
1414
import uvicorn
1515
from fastapi import FastAPI, HTTPException
1616
from fastapi.exceptions import RequestValidationError
1717
from fastapi.responses import JSONResponse, Response, StreamingResponse
18-
from starlette.status import HTTP_429_TOO_MANY_REQUESTS
18+
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR
1919

2020
# yapf: disable
2121
from tensorrt_llm.executor import CppExecutorError
@@ -42,7 +42,8 @@ def __init__(self,
4242
config: DisaggServerConfig,
4343
req_timeout_secs: int = 180,
4444
server_start_timeout_secs: int = 180,
45-
metadata_server_cfg: Optional[MetadataServerConfig] = None):
45+
metadata_server_cfg: Optional[MetadataServerConfig] = None,
46+
metrics_interval_secs: int = 0):
4647

4748
self.ctx_servers, self.gen_servers = get_ctx_gen_server_urls(config.server_configs)
4849
self.metadata_server = create_metadata_server(metadata_server_cfg)
@@ -68,6 +69,17 @@ def __init__(self,
6869
if config.max_retries < 0:
6970
raise ValueError(f"Max retries {config.max_retries} must be greater than or equal to 0")
7071
self.max_retries = config.max_retries
72+
# Metrics counters and synchronization
73+
self._metrics = {
74+
"ctx_total_requests": 0,
75+
"ctx_completed_requests": 0,
76+
"gen_total_requests": 0,
77+
"gen_completed_requests": 0,
78+
}
79+
self._metrics_lock = asyncio.Lock()
80+
self._metrics_task = None
81+
self.metrics_interval_secs = metrics_interval_secs
82+
7183
logger.info(f"Server max retries: {self.max_retries}")
7284

7385
if (len(self.gen_servers) == 0):
@@ -98,13 +110,25 @@ async def lifespan(app: FastAPI):
98110
await self.ctx_router.start_server_monitoring(metadata_server_cfg.refresh_interval)
99111
await self.gen_router.start_server_monitoring(metadata_server_cfg.refresh_interval)
100112

113+
# Start periodic metrics logging
114+
if self.metrics_interval_secs > 0:
115+
self._metrics_task = asyncio.create_task(self._log_metrics_periodically(self.metrics_interval_secs))
116+
101117
yield
102118

103119
if self.metadata_server:
104120
logger.info("Stopping server monitoring via metadata service")
105121
await self.ctx_router.stop_server_monitoring()
106122
await self.gen_router.stop_server_monitoring()
107123

124+
# Stop periodic metrics logging
125+
if self._metrics_task is not None:
126+
self._metrics_task.cancel()
127+
try:
128+
await self._metrics_task
129+
except asyncio.CancelledError:
130+
pass
131+
108132
await self.session.close() # Ensure session cleanup
109133

110134
self.app = FastAPI(lifespan=lifespan)
@@ -115,6 +139,29 @@ async def validation_exception_handler(_, exc):
115139

116140
self.register_routes()
117141

142+
async def _increment_metric(self, key: str, amount: int = 1):
143+
if self.metrics_interval_secs > 0:
144+
async with self._metrics_lock:
145+
self._metrics[key] += amount
146+
147+
async def _get_metrics_snapshot(self):
148+
async with self._metrics_lock:
149+
return dict(self._metrics)
150+
151+
async def _log_metrics_periodically(self, interval_seconds: int):
152+
try:
153+
while True:
154+
await asyncio.sleep(interval_seconds)
155+
snapshot = await self._get_metrics_snapshot()
156+
logger.info(
157+
(
158+
f"[Statistics] total_context_requests={snapshot['ctx_total_requests']}, completed_context_requests={snapshot['ctx_completed_requests']}, "
159+
f"total_generation_requests={snapshot['gen_total_requests']}, completed_generation_requests={snapshot['gen_completed_requests']}"
160+
)
161+
)
162+
except asyncio.CancelledError:
163+
pass
164+
118165
@staticmethod
119166
def create_error_response(
120167
message: str,
@@ -198,15 +245,15 @@ async def merge_streaming_responses(self, ctx_response,
198245
gen_server: str,
199246
gen_req: Union[CompletionRequest, ChatCompletionRequest]):
200247
try:
201-
202248
if ctx_response is not None and len(ctx_response.choices) != 1:
203249
raise ValueError("Context server did not return a single choice. This is not expected")
204250

205251
#If request finished after first token not due to length, return right away and skip gen
206252
if ctx_response is not None and ctx_response.choices[0].finish_reason not in ["length", "not_finished"]:
207-
yield f"data: [DONE]\n\n".encode('utf-8')
253+
yield "data: [DONE]\n\n".encode('utf-8')
208254
else:
209255
# Then yield the generation responses
256+
await self._increment_metric("gen_total_requests")
210257
if isinstance(gen_req, CompletionRequest):
211258
gen_response = await self.send_completion_request(gen_server, gen_req)
212259
elif isinstance(gen_req, ChatCompletionRequest):
@@ -216,6 +263,7 @@ async def merge_streaming_responses(self, ctx_response,
216263

217264
async for chunk in gen_response.body_iterator:
218265
yield chunk
266+
await self._increment_metric("gen_completed_requests")
219267

220268
finally:
221269
await self.gen_router.finish_request(gen_req)
@@ -258,6 +306,7 @@ async def _send_context_request(self, ctx_server: str, ctx_req: Union[Completion
258306
ctx_req.stream_options = None
259307

260308
logger.debug("Sending request to ctx server: %s", ctx_server)
309+
await self._increment_metric("ctx_total_requests")
261310
try:
262311
if isinstance(ctx_req, ChatCompletionRequest):
263312
ctx_response = await self.send_chat_request(ctx_server, ctx_req)
@@ -266,6 +315,7 @@ async def _send_context_request(self, ctx_server: str, ctx_req: Union[Completion
266315
ctx_response = await self.send_completion_request(ctx_server, ctx_req)
267316
finally:
268317
await self.ctx_router.finish_request(ctx_req)
318+
await self._increment_metric("ctx_completed_requests")
269319

270320
choices = ctx_response.choices
271321
if len(choices) > 1:
@@ -342,11 +392,13 @@ async def _send_disagg_request(self, req: Union[CompletionRequest, ChatCompletio
342392
del ctx_response.choices[0].disaggregated_params
343393
return ctx_response
344394
else:
395+
await self._increment_metric("gen_total_requests")
345396
if isinstance(req, CompletionRequest):
346397
gen_response = await self.send_completion_request(gen_server, req)
347398
else:
348399
assert isinstance(req, ChatCompletionRequest)
349400
gen_response = await self.send_chat_request(gen_server, req)
401+
await self._increment_metric("gen_completed_requests")
350402
return gen_response
351403
finally:
352404
if gen_server is not None:
@@ -400,7 +452,7 @@ async def send_request(self, url: str,
400452
request: Union[CompletionRequest, ChatCompletionRequest],
401453
endpoint: str,
402454
response_type: Type[Union[CompletionResponse, ChatCompletionResponse]],
403-
create_generator: callable) -> Union[CompletionResponse, ChatCompletionResponse, StreamingResponse]:
455+
create_generator: Callable) -> Union[CompletionResponse, ChatCompletionResponse, StreamingResponse]:
404456
for attempt in range(self.max_retries + 1):
405457
try:
406458
if request.stream:
@@ -419,7 +471,7 @@ async def send_request(self, url: str,
419471
return response_type(**response_dict)
420472
except (aiohttp.ClientError, OSError) as e:
421473
if attempt == self.max_retries:
422-
raise HTTPException(status_code=HTTP_429_TOO_MANY_REQUESTS, detail=f"Too many requests") from e
474+
raise HTTPException(status_code=HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Internal server error") from e
423475
logger.error(f"Client error: {e} - retry {attempt} of {self.max_retries}")
424476
# TODO : add a configurable retry interval
425477
await asyncio.sleep(1)

0 commit comments

Comments (0)