Skip to content

Add Dalle Support (#147) #131

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions openai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Optional

from openai.api_resources import (
DALLE,
Answer,
Classification,
Completion,
Expand Down Expand Up @@ -50,6 +51,7 @@
"Completion",
"Customer",
"Edit",
"DALLE",
"Deployment",
"Embedding",
"Engine",
Expand Down
3 changes: 2 additions & 1 deletion openai/api_resources/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
from openai.api_resources.classification import Classification # noqa: F401
from openai.api_resources.completion import Completion # noqa: F401
from openai.api_resources.customer import Customer # noqa: F401
from openai.api_resources.edit import Edit # noqa: F401
from openai.api_resources.dalle import DALLE # noqa: F401
from openai.api_resources.deployment import Deployment # noqa: F401
from openai.api_resources.edit import Edit # noqa: F401
from openai.api_resources.embedding import Embedding # noqa: F401
from openai.api_resources.engine import Engine # noqa: F401
from openai.api_resources.error_object import ErrorObject # noqa: F401
Expand Down
90 changes: 90 additions & 0 deletions openai/api_resources/dalle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# WARNING: This interface is considered experimental and may changed in the future without warning.
from typing import Any, List

import openai
from openai import api_requestor, util
from openai.api_resources.abstract import APIResource


class DALLE(APIResource):
OBJECT_NAME = "images"

@classmethod
def _get_url(cls, action):
return cls.class_url() + f"/{action}"

@classmethod
def generations(
cls,
**params,
):
instance = cls()
return instance.request("post", cls._get_url("generations"), params)

@classmethod
def variations(
cls,
image,
api_key=None,
api_base=None,
api_type=None,
api_version=None,
organization=None,
**params,
):
requestor = api_requestor.APIRequestor(
api_key,
api_base=api_base or openai.api_base,
api_type=api_type,
api_version=api_version,
organization=organization,
)
_, api_version = cls._get_api_type_and_version(api_type, api_version)

url = cls._get_url("variations")

files: List[Any] = []
for key, value in params.items():
files.append((key, (None, value)))
files.append(("image", ("image", image, "application/octet-stream")))

response, _, api_key = requestor.request("post", url, files=files)

return util.convert_to_openai_object(
response, api_key, api_version, organization
)

@classmethod
def edits(
cls,
image,
mask,
api_key=None,
api_base=None,
api_type=None,
api_version=None,
organization=None,
**params,
):
requestor = api_requestor.APIRequestor(
api_key,
api_base=api_base or openai.api_base,
api_type=api_type,
api_version=api_version,
organization=organization,
)
_, api_version = cls._get_api_type_and_version(api_type, api_version)

url = cls._get_url("edits")

files: List[Any] = []
for key, value in params.items():
files.append((key, (None, value)))
files.append(("image", ("image", image, "application/octet-stream")))
files.append(("mask", ("mask", mask, "application/octet-stream")))

response, _, api_key = requestor.request("post", url, files=files)

return util.convert_to_openai_object(
response, api_key, api_version, organization
)
94 changes: 94 additions & 0 deletions openai/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,49 @@ def list(cls, args):
print(file)


class DALLE:
@classmethod
def generations(cls, args):
resp = openai.DALLE.generations(
prompt=args.prompt,
model=args.model,
size=args.size,
num_images=args.num_images,
response_format=args.response_format,
)
print(resp)

@classmethod
def variations(cls, args):
with open(args.image, "rb") as file_reader:
buffer_reader = BufferReader(file_reader.read(), desc="Upload progress")
resp = openai.DALLE.variations(
image=buffer_reader,
model=args.model,
size=args.size,
num_images=args.num_images,
response_format=args.response_format,
)
print(resp)

@classmethod
def edits(cls, args):
with open(args.image, "rb") as file_reader:
image_reader = BufferReader(file_reader.read(), desc="Upload progress")
with open(args.mask, "rb") as file_reader:
mask_reader = BufferReader(file_reader.read(), desc="Upload progress")
resp = openai.DALLE.edits(
image=image_reader,
mask=mask_reader,
prompt=args.prompt,
model=args.model,
size=args.size,
num_images=args.num_images,
response_format=args.response_format,
)
print(resp)


class Search:
@classmethod
def prepare_data(cls, args, purpose):
Expand Down Expand Up @@ -983,6 +1026,57 @@ def help(args):
sub.add_argument("-i", "--id", required=True, help="The id of the fine-tune job")
sub.set_defaults(func=FineTune.cancel)

# DALLE
sub = subparsers.add_parser("dalle.generations")
sub.add_argument("-m", "--model", type=str, default="image-alpha-001")
sub.add_argument("-p", "--prompt", type=str, required=True)
sub.add_argument("-n", "--num-images", type=int, default=1)
sub.add_argument(
"-s", "--size", type=str, default="1024x1024", help="Size of the output image"
)
sub.add_argument("--response-format", type=str, default="url")
sub.set_defaults(func=DALLE.generations)

sub = subparsers.add_parser("dalle.edits")
sub.add_argument("-m", "--model", type=str, default="image-alpha-001")
sub.add_argument("-p", "--prompt", type=str, required=True)
sub.add_argument("-n", "--num-images", type=int, default=1)
sub.add_argument(
"-I",
"--image",
type=str,
required=True,
help="Image to modify. Should be a local path and a PNG encoded image.",
)
sub.add_argument(
"-s", "--size", type=str, default="1024x1024", help="Size of the output image"
)
sub.add_argument("--response-format", type=str, default="url")
sub.add_argument(
"-M",
"--mask",
type=str,
required=True,
help="Path to a mask image. It should be the same size as the image you're editing and a RGBA PNG image. The Alpha channel acts as the mask.",
)
sub.set_defaults(func=DALLE.edits)

sub = subparsers.add_parser("dalle.variations")
sub.add_argument("-m", "--model", type=str, default="image-alpha-001")
sub.add_argument("-n", "--num-images", type=int, default=1)
sub.add_argument(
"-I",
"--image",
type=str,
required=True,
help="Image to modify. Should be a local path and a PNG encoded image.",
)
sub.add_argument(
"-s", "--size", type=str, default="1024x1024", help="Size of the output image"
)
sub.add_argument("--response-format", type=str, default="url")
sub.set_defaults(func=DALLE.variations)


def wandb_register(parser):
subparsers = parser.add_subparsers(
Expand Down
2 changes: 1 addition & 1 deletion openai/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
VERSION = "0.23.1"
VERSION = "0.24.0"