Skip to content
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions openai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@
organization = os.environ.get("OPENAI_ORGANIZATION")
api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
api_type = os.environ.get("OPENAI_API_TYPE", "open_ai")
api_version = (
api_version = os.environ.get("OPENAI_API_VERSION", (
"2022-12-01" if api_type in ("azure", "azure_ad", "azuread") else None
)
))
verify_ssl_certs = True # No effect. Certificates are always verified.
proxy = None
app_info = None
Expand Down
39 changes: 39 additions & 0 deletions openai/api_requestor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import json
import time
import platform
import sys
import threading
Expand Down Expand Up @@ -144,6 +145,44 @@ def format_app_info(cls, info):
if info["url"]:
str += " (%s)" % (info["url"],)
return str

def poll(
self,
method,
url,
until,
params = None,
headers = None,
interval = None,
delay = None
) -> Tuple[Iterator[OpenAIResponse], bool, str]:
if delay:
time.sleep(delay)

response, b, api_key = self.request(method, url, params, headers)
while not until(response):
time.sleep(interval or response.retry_after or 1)
response, b, api_key = self.request(method, url)
return response, b, api_key

async def apoll(
self,
method,
url,
until,
params = None,
headers = None,
interval = None,
delay = None
) -> Tuple[Iterator[OpenAIResponse], bool, str]:
if delay:
await asyncio.sleep(delay)

response, b, api_key = await self.arequest(method, url, params, headers)
while not until(response):
await asyncio.sleep(interval or response.retry_after or 1)
response, b, api_key = await self.arequest(method, url)
return response, b, api_key

@overload
def request(
Expand Down
1 change: 1 addition & 0 deletions openai/api_resources/abstract/api_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
class APIResource(OpenAIObject):
api_prefix = ""
azure_api_prefix = "openai"
azure_dalle_prefix = "dalle"
azure_deployments_prefix = "deployments"

@classmethod
Expand Down
43 changes: 33 additions & 10 deletions openai/api_resources/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,15 @@ class Image(APIResource):
OBJECT_NAME = "images"

@classmethod
def _get_url(cls, action):
return cls.class_url() + f"/{action}"
def _get_url(cls, openai_action, azure_action, api_type, api_version):
if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
return f"/{cls.azure_dalle_prefix}{cls.class_url()}/{azure_action}?api-version={api_version}"
else:
return f"{cls.class_url()}/{openai_action}"

@classmethod
def _get_azure_operations_url(cls, operation_id, api_version):
return f"/{cls.azure_dalle_prefix}/operations/{operation_id}?api-version={api_version}"

@classmethod
def create(
Expand All @@ -31,12 +38,20 @@ def create(
organization=organization,
)

_, api_version = cls._get_api_type_and_version(api_type, api_version)
api_type, api_version = cls._get_api_type_and_version(api_type, api_version)

response, _, api_key = requestor.request(
"post", cls._get_url("generations"), params
"post", cls._get_url("generations", "generate", api_type=api_type, api_version=api_version), params
)

if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
url = cls._get_azure_operations_url(response.data['id'], api_version)
response, _, api_key = requestor.poll(
"get", url,
until=lambda response: response.data["status"] not in ["NotStarted", "Running"],
delay=response.retry_after
)

return util.convert_to_openai_object(
response, api_key, api_version, organization
)
Expand All @@ -60,12 +75,20 @@ async def acreate(
organization=organization,
)

_, api_version = cls._get_api_type_and_version(api_type, api_version)
api_type, api_version = cls._get_api_type_and_version(api_type, api_version)

response, _, api_key = await requestor.arequest(
"post", cls._get_url("generations"), params
"post", cls._get_url("generations", "generate", api_type=api_type, api_version=api_version), params
)

if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
url = cls._get_azure_operations_url(response.data['id'], api_version)
response, _, api_key = await requestor.apoll(
"get", url,
until=lambda response: response.data["status"] not in ["NotStarted", "Running"],
delay=response.retry_after
)

return util.convert_to_openai_object(
response, api_key, api_version, organization
)
Expand All @@ -88,9 +111,9 @@ def _prepare_create_variation(
api_version=api_version,
organization=organization,
)
_, api_version = cls._get_api_type_and_version(api_type, api_version)
api_type, api_version = cls._get_api_type_and_version(api_type, api_version)

url = cls._get_url("variations")
url = cls._get_url("variations", None, api_type=api_type, api_version=api_version)

files: List[Any] = []
for key, value in params.items():
Expand Down Expand Up @@ -171,9 +194,9 @@ def _prepare_create_edit(
api_version=api_version,
organization=organization,
)
_, api_version = cls._get_api_type_and_version(api_type, api_version)
api_type, api_version = cls._get_api_type_and_version(api_type, api_version)

url = cls._get_url("edits")
url = cls._get_url("edits", None, api_type=api_type, api_version=api_version)

files: List[Any] = []
for key, value in params.items():
Expand Down
7 changes: 7 additions & 0 deletions openai/openai_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@ def __init__(self, data, headers):
@property
def request_id(self) -> Optional[str]:
return self._headers.get("request-id")

@property
def retry_after(self) -> Optional[int]:
try:
return int(self._headers.get("retry-after"))
except ValueError:
return None

@property
def organization(self) -> Optional[str]:
Expand Down