6
6
7
7
import openai
8
8
import requests
9
+ from google import genai
10
+ from google .genai import types as gtypes
9
11
10
12
from khoj .database .adapters import ConversationAdapters
11
13
from khoj .database .models import Agent , KhojUser , TextToImageModelConfig
@@ -86,6 +88,8 @@ async def text_to_image(
86
88
webp_image_bytes = generate_image_with_stability (image_prompt , text_to_image_config , text2image_model )
87
89
elif text_to_image_config .model_type == TextToImageModelConfig .ModelType .REPLICATE :
88
90
webp_image_bytes = generate_image_with_replicate (image_prompt , text_to_image_config , text2image_model )
91
+ elif text_to_image_config .model_type == TextToImageModelConfig .ModelType .GOOGLE :
92
+ webp_image_bytes = generate_image_with_google (image_prompt , text_to_image_config , text2image_model )
89
93
except openai .OpenAIError or openai .BadRequestError or openai .APIConnectionError as e :
90
94
if "content_policy_violation" in e .message :
91
95
logger .error (f"Image Generation blocked by OpenAI: { e } " )
@@ -99,6 +103,12 @@ async def text_to_image(
99
103
status_code = e .status_code # type: ignore
100
104
yield image_url or image , status_code , message
101
105
return
106
+ except ValueError as e :
107
+ logger .error (f"Image Generation failed with { e } " , exc_info = True )
108
+ message = f"Image generation using { text2image_model } via { text_to_image_config .model_type } failed due to an unknown error"
109
+ status_code = 500
110
+ yield image_url or image , status_code , message
111
+ return
102
112
except requests .RequestException as e :
103
113
logger .error (f"Image Generation failed with { e } " , exc_info = True )
104
114
message = f"Image generation using { text2image_model } via { text_to_image_config .model_type } failed due to a network error."
@@ -215,3 +225,28 @@ def generate_image_with_replicate(
215
225
# Get the generated image
216
226
image_url = get_prediction ["output" ][0 ] if isinstance (get_prediction ["output" ], list ) else get_prediction ["output" ]
217
227
return io .BytesIO (requests .get (image_url ).content ).getvalue ()
228
+
229
+
230
def generate_image_with_google(
    improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
):
    """Generate an image using Google's AI over API and return it converted to webp.

    Args:
        improved_image_prompt: Prompt text sent to the image generation model.
        text_to_image_config: Image model configuration. Its own `api_key` is used
            when set, otherwise the `api_key` of its linked `ai_model_api`.
        text2image_model: Name of the Google image model to call.

    Returns:
        The result of `convert_image_to_webp` applied to the bytes of the first
        generated image.

    Raises:
        ValueError: If the response has no generated images, or if the first
            generated image carries no image bytes.
    """
    # Initialize the Google AI client
    api_key = text_to_image_config.api_key or text_to_image_config.ai_model_api.api_key
    client = genai.Client(api_key=api_key)

    # Configure image generation settings: a single image per request
    config = gtypes.GenerateImagesConfig(number_of_images=1)

    # Call the Gemini API to generate the image
    response = client.models.generate_images(model=text2image_model, prompt=improved_image_prompt, config=config)

    if not response.generated_images:
        raise ValueError("Failed to generate image using Google AI")

    # The first entry may come back without image data (presumably when the
    # image was filtered out by the service -- confirm against the SDK docs).
    # Raise ValueError, the same failure type used above, instead of letting
    # an AttributeError escape or handing None to the webp converter.
    generated_image = response.generated_images[0]
    image = generated_image.image
    image_bytes = image.image_bytes if image is not None else None
    if not image_bytes:
        filtered_reason = getattr(generated_image, "rai_filtered_reason", None)
        detail = f": {filtered_reason}" if filtered_reason else ""
        raise ValueError(f"Google AI returned no image data{detail}")

    # Convert to webp for faster loading
    return convert_image_to_webp(image_bytes)
0 commit comments