Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions fastdeploy/inter_communicator/zmq_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,17 +223,15 @@ def recv_control_cmd(self):
Recieve control command from client
"""
self._ensure_socket()
while self.running:
try:
client, _, task_data = self.socket.recv_multipart(flags=zmq.NOBLOCK)
task = msgpack.unpackb(task_data)
task_id_str = task["task_id"]
except zmq.Again:
time.sleep(0.001)
continue
with self.mutex:
self.req_dict[task_id_str] = client
return task
try:
client, _, task_data = self.socket.recv_multipart(flags=zmq.NOBLOCK)
task = msgpack.unpackb(task_data)
task_id_str = task["task_id"]
except zmq.Again:
return None
with self.mutex:
self.req_dict[task_id_str] = client
return task

def response_for_control_cmd(self, task_id, result):
"""
Expand All @@ -251,7 +249,7 @@ def response_for_control_cmd(self, task_id, result):

with self.mutex:
self.req_dict.pop(task_id, None)
llm_logger.info(f"response control cmd finished, task_id: {task_id}")
llm_logger.debug(f"response control cmd finished, task_id: {task_id}")

def close(self):
"""
Expand Down
3 changes: 2 additions & 1 deletion fastdeploy/output/token_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,8 @@ def _process_batch_output(self):
for token_id in token_ids:
self.tokens_counter[task_id] += 1
if token_id != RECOVERY_STOP_SIGNAL:
result.outputs.token_ids.append(token_id)
if not (envs.FD_ENABLE_INTERNAL_ADAPTER and token_id in task.eos_token_ids):
result.outputs.token_ids.append(token_id)
task.output_token_ids.append(token_id)
if token_id in task.eos_token_ids or is_prefill or recovery_stop:
result.finished = True
Expand Down
12 changes: 8 additions & 4 deletions fastdeploy/splitwise/internal_adapter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(self, cfg, engine, dp_rank):
self.engine = engine
self.dp_rank = dp_rank
recv_control_cmd_ports = envs.FD_ZMQ_CONTROL_CMD_SERVER_PORTS.split(",")
self.response_lock = threading.Lock() # prevent to call send_multipart in zmq concurrently
self.recv_control_cmd_server = ZmqTcpServer(port=recv_control_cmd_ports[dp_rank], mode=zmq.ROUTER)
self.recv_external_instruct_thread = threading.Thread(
target=self._recv_external_module_control_instruct, daemon=True
Expand All @@ -43,7 +44,6 @@ def __init__(self, cfg, engine, dp_rank):
target=self._response_external_module_control_instruct, daemon=True
)
self.response_external_instruct_thread.start()
self.response_lock = threading.Lock() # prevent to call send_multipart in zmq concurrently

def _get_current_server_info(self):
"""
Expand Down Expand Up @@ -71,13 +71,17 @@ def _recv_external_module_control_instruct(self):
"""
while True:
try:
task = self.recv_control_cmd_server.recv_control_cmd()
with self.response_lock:
task = self.recv_control_cmd_server.recv_control_cmd()
if task is None:
time.sleep(0.001)
continue
logger.info(f"Recieve control task: {task}")
task_id_str = task["task_id"]
if task["cmd"] == "get_payload":
payload_info = self._get_current_server_info()
result = {"task_id": task_id_str, "result": payload_info}
logger.info(f"Response for task: {task_id_str}")
logger.debug(f"Response for task: {task_id_str}")
with self.response_lock:
self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)

Expand All @@ -87,7 +91,7 @@ def _recv_external_module_control_instruct(self):
extra_register_func=lambda reg: main_process_metrics.register_all(reg, workers=1),
)
result = {"task_id": task_id_str, "result": metrics_text}
logger.info(f"Response for task: {task_id_str}")
logger.debug(f"Response for task: {task_id_str}")
with self.response_lock:
self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)
elif task["cmd"] == "connect_rdma":
Expand Down