
Commit 4c7cdcc

Add interruptible streaming requests for llama-cpp-python server. Closes #183
1 parent 98ae4e5 commit 4c7cdcc

File tree

2 files changed (+29, -6 lines)


CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+## [Added]
+
+- (server) Streaming requests can now be interrupted prematurely when a concurrent request is made. Can be controlled with the `interrupt_requests` setting.
+
 ## [0.1.68]
 
 ## [Added]

llama_cpp/server/app.py

Lines changed: 25 additions & 6 deletions
@@ -146,12 +146,27 @@ def set_settings(_settings: Settings):
     return app
 
 
-llama_lock = Lock()
+llama_outer_lock = Lock()
+llama_inner_lock = Lock()
 
 
 def get_llama():
-    with llama_lock:
-        yield llama
+    # NOTE: This double lock allows the currently streaming llama model to
+    # check if any other requests are pending in the same thread and cancel
+    # the stream if so.
+    llama_outer_lock.acquire()
+    release_outer_lock = True
+    try:
+        llama_inner_lock.acquire()
+        try:
+            llama_outer_lock.release()
+            release_outer_lock = False
+            yield llama
+        finally:
+            llama_inner_lock.release()
+    finally:
+        if release_outer_lock:
+            llama_outer_lock.release()
 
 
 def get_settings():
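
To make the NOTE in the hunk above concrete, here is a minimal, self-contained sketch of the same double-lock handoff. The names `acquire_resource` and `stream_tokens` are illustrative only and not part of the commit: the holder keeps only the inner lock while it works, so a second caller ends up blocked on the inner lock while holding the outer one, and the holder can poll `outer_lock.locked()` to see that someone is waiting.

```python
# Minimal sketch of the double-lock handoff used by get_llama() above.
# acquire_resource/stream_tokens are hypothetical names, not from the commit.
import contextlib
import threading
import time

outer_lock = threading.Lock()
inner_lock = threading.Lock()

@contextlib.contextmanager
def acquire_resource():
    # Same shape as get_llama(): hold the outer lock only while waiting for
    # the inner one, then drop it so the next caller can signal us by waiting.
    outer_lock.acquire()
    release_outer = True
    try:
        inner_lock.acquire()
        try:
            outer_lock.release()
            release_outer = False
            yield "llama"  # stand-in for the shared model
        finally:
            inner_lock.release()
    finally:
        if release_outer:
            outer_lock.release()

def stream_tokens(name):
    with acquire_resource():
        for i in range(50):
            time.sleep(0.05)  # pretend to generate one token
            if outer_lock.locked():  # another caller is queued in acquire_resource()
                print(f"{name}: interrupted after {i} tokens")
                return
        print(f"{name}: finished all 50 tokens")

first = threading.Thread(target=stream_tokens, args=("first",))
second = threading.Thread(target=stream_tokens, args=("second",))
first.start()
time.sleep(0.3)   # let the first "stream" get going
second.start()    # blocks in acquire_resource(), holding outer_lock
first.join()
second.join()
```

Because a waiter holds the outer lock for the whole time it is queued, `locked()` works as a cheap "another request is pending" signal with no extra flags or bookkeeping.
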
@@ -364,14 +379,16 @@ async def event_publisher(inner_send_chan: MemoryObjectSendStream):
                         await inner_send_chan.send(dict(data=json.dumps(chunk)))
                         if await request.is_disconnected():
                             raise anyio.get_cancelled_exc_class()()
+                        if llama_outer_lock.locked():
+                            await inner_send_chan.send(dict(data="[DONE]"))
+                            raise anyio.get_cancelled_exc_class()()
                     await inner_send_chan.send(dict(data="[DONE]"))
                 except anyio.get_cancelled_exc_class() as e:
                     print("disconnected")
                     with anyio.move_on_after(1, shield=True):
                         print(
                             f"Disconnected from client (via refresh/close) {request.client}"
                         )
-                        await inner_send_chan.send(dict(closing=True))
                         raise e
 
         return EventSourceResponse(
@@ -494,14 +511,16 @@ async def event_publisher(inner_send_chan: MemoryObjectSendStream):
                         await inner_send_chan.send(dict(data=json.dumps(chat_chunk)))
                         if await request.is_disconnected():
                             raise anyio.get_cancelled_exc_class()()
+                        if llama_outer_lock.locked():
+                            await inner_send_chan.send(dict(data="[DONE]"))
+                            raise anyio.get_cancelled_exc_class()()
                     await inner_send_chan.send(dict(data="[DONE]"))
                 except anyio.get_cancelled_exc_class() as e:
                     print("disconnected")
                     with anyio.move_on_after(1, shield=True):
                         print(
                             f"Disconnected from client (via refresh/close) {request.client}"
                         )
-                        await inner_send_chan.send(dict(closing=True))
                         raise e
 
         return EventSourceResponse(
@@ -533,8 +552,8 @@ class ModelList(TypedDict):
 @router.get("/v1/models", response_model=GetModelResponse)
 async def get_models(
     settings: Settings = Depends(get_settings),
-    llama: llama_cpp.Llama = Depends(get_llama),
 ) -> ModelList:
+    assert llama is not None
     return {
         "object": "list",
         "data": [
