Skip to content

Commit 7bb6fac

Browse files
committedMar 11, 2025·
Add support for Google Imagen AI models for image generation
Use the new Google GenAI client to generate images with Imagen
1 parent bd06fcd commit 7bb6fac

File tree

4 files changed

+70
-4
lines changed

4 files changed

+70
-4
lines changed
 
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Generated by Django 5.0.10 on 2025-03-11 16:58
2+
3+
from django.db import migrations, models
4+
5+
6+
class Migration(migrations.Migration):
7+
dependencies = [
8+
("database", "0085_alter_agent_output_modes"),
9+
]
10+
11+
operations = [
12+
migrations.AlterField(
13+
model_name="texttoimagemodelconfig",
14+
name="model_type",
15+
field=models.CharField(
16+
choices=[
17+
("openai", "Openai"),
18+
("stability-ai", "Stabilityai"),
19+
("replicate", "Replicate"),
20+
("google", "Google"),
21+
],
22+
default="openai",
23+
max_length=200,
24+
),
25+
),
26+
]

‎src/khoj/database/models/__init__.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,7 @@ class ModelType(models.TextChoices):
530530
OPENAI = "openai"
531531
STABILITYAI = "stability-ai"
532532
REPLICATE = "replicate"
533+
GOOGLE = "google"
533534

534535
model_name = models.CharField(max_length=200, default="dall-e-3")
535536
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
@@ -547,11 +548,11 @@ def clean(self):
547548
error[
548549
"ai_model_api"
549550
] = "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them."
550-
if self.model_type != self.ModelType.OPENAI:
551+
if self.model_type != self.ModelType.OPENAI and self.model_type != self.ModelType.GOOGLE:
551552
if not self.api_key:
552-
error["api_key"] = "The API key field must be set for non OpenAI models."
553+
error["api_key"] = "The API key field must be set for non OpenAI, non Google models."
553554
if self.ai_model_api:
554-
error["ai_model_api"] = "AI Model API cannot be set for non OpenAI models."
555+
error["ai_model_api"] = "AI Model API cannot be set for non OpenAI, non Google models."
555556
if error:
556557
raise ValidationError(error)
557558

‎src/khoj/processor/image/generate.py

+35
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
import openai
88
import requests
9+
from google import genai
10+
from google.genai import types as gtypes
911

1012
from khoj.database.adapters import ConversationAdapters
1113
from khoj.database.models import Agent, KhojUser, TextToImageModelConfig
@@ -86,6 +88,8 @@ async def text_to_image(
8688
webp_image_bytes = generate_image_with_stability(image_prompt, text_to_image_config, text2image_model)
8789
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.REPLICATE:
8890
webp_image_bytes = generate_image_with_replicate(image_prompt, text_to_image_config, text2image_model)
91+
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.GOOGLE:
92+
webp_image_bytes = generate_image_with_google(image_prompt, text_to_image_config, text2image_model)
8993
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
9094
if "content_policy_violation" in e.message:
9195
logger.error(f"Image Generation blocked by OpenAI: {e}")
@@ -99,6 +103,12 @@ async def text_to_image(
99103
status_code = e.status_code # type: ignore
100104
yield image_url or image, status_code, message
101105
return
106+
except ValueError as e:
107+
logger.error(f"Image Generation failed with {e}", exc_info=True)
108+
message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed due to an unknown error"
109+
status_code = 500
110+
yield image_url or image, status_code, message
111+
return
102112
except requests.RequestException as e:
103113
logger.error(f"Image Generation failed with {e}", exc_info=True)
104114
message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed due to a network error."
@@ -215,3 +225,28 @@ def generate_image_with_replicate(
215225
# Get the generated image
216226
image_url = get_prediction["output"][0] if isinstance(get_prediction["output"], list) else get_prediction["output"]
217227
return io.BytesIO(requests.get(image_url).content).getvalue()
228+
229+
230+
def generate_image_with_google(
231+
improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
232+
):
233+
"""Generate image using Google's AI over API"""
234+
235+
# Initialize the Google AI client
236+
api_key = text_to_image_config.api_key or text_to_image_config.ai_model_api.api_key
237+
client = genai.Client(api_key=api_key)
238+
239+
# Configure image generation settings
240+
config = gtypes.GenerateImagesConfig(number_of_images=1)
241+
242+
# Call the Gemini API to generate the image
243+
response = client.models.generate_images(model=text2image_model, prompt=improved_image_prompt, config=config)
244+
245+
if not response.generated_images:
246+
raise ValueError("Failed to generate image using Google AI")
247+
248+
# Extract the image bytes from the first generated image
249+
image_bytes = response.generated_images[0].image.image_bytes
250+
251+
# Convert to webp for faster loading
252+
return convert_image_to_webp(image_bytes)

‎src/khoj/routers/helpers.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1092,7 +1092,11 @@ async def generate_better_image_prompt(
10921092
online_results=simplified_online_results,
10931093
personality_context=personality_context,
10941094
)
1095-
elif model_type in [TextToImageModelConfig.ModelType.STABILITYAI, TextToImageModelConfig.ModelType.REPLICATE]:
1095+
elif model_type in [
1096+
TextToImageModelConfig.ModelType.STABILITYAI,
1097+
TextToImageModelConfig.ModelType.REPLICATE,
1098+
TextToImageModelConfig.ModelType.GOOGLE,
1099+
]:
10961100
image_prompt = prompts.image_generation_improve_prompt_sd.format(
10971101
query=q,
10981102
chat_history=conversation_history,

0 commit comments

Comments
 (0)
Please sign in to comment.