Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 26 additions & 15 deletions openai/api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ def _make_session() -> requests.Session:
return s


def parse_stream_helper(line):
def parse_stream_helper(line: bytes):
if line:
if line == b"data: [DONE]":
if line.strip() == b"data: [DONE]":
# return here will cause GeneratorExit exception in urllib3
# and it will close http connection with TCP Reset
return None
Expand All @@ -111,7 +111,7 @@ def parse_stream(rbody):


async def parse_stream_async(rbody: aiohttp.StreamReader):
    """Asynchronously yield parsed SSE payloads from an aiohttp response body.

    Iterates the stream by HTTP chunk (rather than by line) so that each
    server-sent event is handed to ``parse_stream_helper`` as soon as it
    arrives; ``iter_chunks`` yields ``(bytes, end_of_chunk)`` tuples and the
    boundary flag is not needed here.
    """
    async for line, _ in rbody.iter_chunks():
        _line = parse_stream_helper(line)
        # parse_stream_helper returns None for keep-alives / the [DONE]
        # sentinel; only forward real data payloads.
        if _line is not None:
            yield _line
Expand Down Expand Up @@ -294,18 +294,29 @@ async def arequest(
request_id: Optional[str] = None,
request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]:
async with aiohttp_session() as session:
result = await self.arequest_raw(
method.lower(),
url,
session,
params=params,
supplied_headers=headers,
files=files,
request_id=request_id,
request_timeout=request_timeout,
)
resp, got_stream = await self._interpret_async_response(result, stream)
ctx = aiohttp_session()
session = await ctx.__aenter__()
result = await self.arequest_raw(
method.lower(),
url,
session,
params=params,
supplied_headers=headers,
files=files,
request_id=request_id,
request_timeout=request_timeout,
)
resp, got_stream = await self._interpret_async_response(result, stream)
if got_stream:

async def wrap_resp():
async for r in resp:
yield r
await ctx.__aexit__(None, None, None)
Copy link
Contributor

@ddeville ddeville Jan 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it's possible for this async generator to never complete (for example if the caller raises an exception before completing the iteration) in which case we'll never close this session, which I think raises an exception on the event loop?

Maybe we should create a session on the APIRequestor instance instead 🤔

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch — I think it should be good now in the latest commit.


return wrap_resp(), got_stream, self.api_key
else:
await ctx.__aexit__(None, None, None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At first I thought it'd be easier to just fetch/create a ClientSession here rather than getting the async generator and calling __aenter__ and __aexit__ manually, but since we have to deal with manually closing one while being careful not to close the other, I think it's probably fine.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea I think the context manager is still worth it for that encapsulation

return resp, got_stream, self.api_key

def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False):
Expand Down
2 changes: 1 addition & 1 deletion openai/api_resources/abstract/engine_api_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ async def acreate(
engine=engine,
plain_old_data=cls.plain_old_data,
)
for line in response
async for line in response
)
else:
obj = util.convert_to_openai_object(
Expand Down
24 changes: 24 additions & 0 deletions openai/tests/asyncio/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import openai
from openai import error
from aiohttp import ClientSession


pytestmark = [pytest.mark.asyncio]
Expand Down Expand Up @@ -63,3 +64,26 @@ async def test_timeout_does_not_error():
model="ada",
request_timeout=10,
)


async def test_completions_stream_finishes_global_session():
    """Streaming completion using a caller-provided (global) aiohttp session.

    Verifies that when the session is installed via ``openai.aiosession``,
    a streamed request completes and yields more than one chunk — i.e. the
    library does not close the shared session out from under the stream.
    """
    async with ClientSession() as session:
        openai.aiosession.set(session)

        # A query that should be fast
        parts = []
        async for part in await openai.Completion.acreate(
            prompt="test", model="ada", request_timeout=3, stream=True
        ):
            parts.append(part)
        assert len(parts) > 1


async def test_completions_stream_finishes_local_session():
    """Streaming completion with the library's own per-request session.

    Without a user-supplied session, a streamed request should still run to
    completion and produce more than one chunk.
    """
    # A query that should be fast
    received = [
        chunk
        async for chunk in await openai.Completion.acreate(
            prompt="test", model="ada", request_timeout=3, stream=True
        )
    ]
    assert len(received) > 1