-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Several fixes to make Completion.acreate(stream=True)
work
#172
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
c2dc889
4f4f3cf
2e2e20e
9817bbd
21ed0ad
7088352
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -89,9 +89,9 @@ def _make_session() -> requests.Session: | |
return s | ||
|
||
|
||
def parse_stream_helper(line): | ||
def parse_stream_helper(line: bytes): | ||
if line: | ||
if line == b"data: [DONE]": | ||
if line.strip() == b"data: [DONE]": | ||
# return here will cause GeneratorExit exception in urllib3 | ||
# and it will close http connection with TCP Reset | ||
return None | ||
|
@@ -111,7 +111,7 @@ def parse_stream(rbody): | |
|
||
|
||
async def parse_stream_async(rbody: aiohttp.StreamReader): | ||
async for line in rbody: | ||
async for line, _ in rbody.iter_chunks(): | ||
_line = parse_stream_helper(line) | ||
if _line is not None: | ||
yield _line | ||
|
@@ -294,18 +294,29 @@ async def arequest( | |
request_id: Optional[str] = None, | ||
request_timeout: Optional[Union[float, Tuple[float, float]]] = None, | ||
) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]: | ||
async with aiohttp_session() as session: | ||
result = await self.arequest_raw( | ||
method.lower(), | ||
url, | ||
session, | ||
params=params, | ||
supplied_headers=headers, | ||
files=files, | ||
request_id=request_id, | ||
request_timeout=request_timeout, | ||
) | ||
resp, got_stream = await self._interpret_async_response(result, stream) | ||
ctx = aiohttp_session() | ||
session = await ctx.__aenter__() | ||
result = await self.arequest_raw( | ||
method.lower(), | ||
url, | ||
session, | ||
params=params, | ||
supplied_headers=headers, | ||
files=files, | ||
request_id=request_id, | ||
request_timeout=request_timeout, | ||
) | ||
resp, got_stream = await self._interpret_async_response(result, stream) | ||
if got_stream: | ||
|
||
async def wrap_resp(): | ||
async for r in resp: | ||
yield r | ||
await ctx.__aexit__(None, None, None) | ||
|
||
return wrap_resp(), got_stream, self.api_key | ||
else: | ||
await ctx.__aexit__(None, None, None) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. At first I thought it'd be easier to just fetch/create a There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yea I think the context manager is still worth it for that encapsulation |
||
return resp, got_stream, self.api_key | ||
|
||
def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ | |
|
||
import openai | ||
from openai import error | ||
from aiohttp import ClientSession | ||
|
||
|
||
pytestmark = [pytest.mark.asyncio] | ||
|
@@ -63,3 +64,26 @@ async def test_timeout_does_not_error(): | |
model="ada", | ||
request_timeout=10, | ||
) | ||
|
||
|
||
async def test_completions_stream_finishes_global_session(): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks for adding this test. |
||
async with ClientSession() as session: | ||
openai.aiosession.set(session) | ||
|
||
# A query that should be fast | ||
parts = [] | ||
async for part in await openai.Completion.acreate( | ||
prompt="test", model="ada", request_timeout=3, stream=True | ||
): | ||
parts.append(part) | ||
assert len(parts) > 1 | ||
|
||
|
||
async def test_completions_stream_finishes_local_session(): | ||
# A query that should be fast | ||
parts = [] | ||
async for part in await openai.Completion.acreate( | ||
prompt="test", model="ada", request_timeout=3, stream=True | ||
): | ||
parts.append(part) | ||
assert len(parts) > 1 |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess it's possible for this async generator to never complete (for example if the caller raises an exception before completing the iteration) in which case we'll never close this session, which I think raises an exception on the event loop?
Maybe we should create a session on the
APIRequestor
instance instead 🤔 There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch, I think should be good now in the latest commit