
Commit e8dc892

feat(model): introduce new Gemini model

Added integration tests and unit tests for the Gemini model.

1 parent a03b74c

5 files changed: +679 -1 lines changed

README.md

Lines changed: 11 additions & 0 deletions
@@ -107,6 +107,7 @@ from strands import Agent
 from strands.models import BedrockModel
 from strands.models.ollama import OllamaModel
 from strands.models.llamaapi import LlamaAPIModel
+from strands.models.gemini import GeminiModel

 # Bedrock
 bedrock_model = BedrockModel(
@@ -130,11 +131,21 @@ llama_model = LlamaAPIModel(
 )
 agent = Agent(model=llama_model)
 response = agent("Tell me about Agentic AI")
+
+# Gemini
+gemini_model = GeminiModel(
+    model_id="gemini-pro",
+    max_tokens=1024,
+    params={"temperature": 0.7}
+)
+agent = Agent(model=gemini_model)
+response = agent("Tell me about Agentic AI")
 ```

 Built-in providers:
 - [Amazon Bedrock](https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/)
 - [Anthropic](https://strandsagents.com/latest/user-guide/concepts/model-providers/anthropic/)
+- [Gemini](https://strandsagents.com/latest/user-guide/concepts/model-providers/gemini/)
 - [LiteLLM](https://strandsagents.com/latest/user-guide/concepts/model-providers/litellm/)
 - [LlamaAPI](https://strandsagents.com/latest/user-guide/concepts/model-providers/llamaapi/)
 - [Ollama](https://strandsagents.com/latest/user-guide/concepts/model-providers/ollama/)
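
Note that the README example configures the model but not credentials. Judging from the provider's __init__ and the integration test further down, an API key can be passed through client_args; a minimal sketch, assuming a GOOGLE_API_KEY environment variable holds a valid key:

import os

from strands import Agent
from strands.models.gemini import GeminiModel

# client_args are forwarded to the Gemini client configuration; api_key is the
# usual argument (read from the environment here rather than hard-coded).
gemini_model = GeminiModel(
    client_args={"api_key": os.getenv("GOOGLE_API_KEY")},
    model_id="gemini-pro",
    max_tokens=1024,
    params={"temperature": 0.7},
)
agent = Agent(model=gemini_model)
response = agent("Tell me about Agentic AI")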

pyproject.toml

Lines changed: 4 additions & 1 deletion
@@ -75,9 +75,12 @@ ollama = [
 llamaapi = [
     "llama-api-client>=0.1.0,<1.0.0",
 ]
+gemini = [
+    "google-generativeai>=0.8.5",
+]

 [tool.hatch.envs.hatch-static-analysis]
-features = ["anthropic", "litellm", "llamaapi", "ollama"]
+features = ["anthropic", "litellm", "llamaapi", "ollama", "gemini"]
 dependencies = [
     "mypy>=1.15.0,<2.0.0",
     "ruff>=0.11.6,<0.12.0",

src/strands/models/gemini.py

Lines changed: 269 additions & 0 deletions

"""Google Gemini model provider.

- Docs: https://ai.google.dev/docs/gemini_api_overview
"""

import base64
import json
import logging
import mimetypes
from typing import Any, Iterable, Optional, TypedDict

import google.generativeai as genai  # type: ignore
from typing_extensions import Required, Unpack, override

from ..types.content import ContentBlock, Messages
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
from ..types.models import Model
from ..types.streaming import StreamEvent
from ..types.tools import ToolSpec

logger = logging.getLogger(__name__)


class GeminiModel(Model):
    """Google Gemini model provider implementation."""

    EVENT_TYPES = {
        "message_start",
        "content_block_start",
        "content_block_delta",
        "content_block_stop",
        "message_stop",
    }

    OVERFLOW_MESSAGES = {
        "input is too long",
        "input length exceeds context window",
        "input and output tokens exceed your context limit",
    }

    class GeminiConfig(TypedDict, total=False):
        """Configuration options for Gemini models.

        Attributes:
            max_tokens: Maximum number of tokens to generate.
            model_id: Gemini model ID (e.g., "gemini-pro").
                For a complete list of supported models, see
                https://ai.google.dev/models/gemini.
            params: Additional model parameters (e.g., temperature).
                For a complete list of supported parameters, see
                https://ai.google.dev/docs/gemini_api_overview#generation_config.
        """

        max_tokens: Required[int]
        model_id: Required[str]
        params: Optional[dict[str, Any]]

    def __init__(self, *, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[GeminiConfig]):
        """Initialize provider instance.

        Args:
            client_args: Arguments for the underlying Gemini client (e.g., api_key).
                For a complete list of supported arguments, see
                https://ai.google.dev/docs/gemini_api_overview#client_libraries.
            **model_config: Configuration options for the Gemini model.
        """
        self.config = GeminiModel.GeminiConfig(**model_config)

        logger.debug("config=<%s> | initializing", self.config)

        client_args = client_args or {}
        genai.configure(**client_args)
        self.model = genai.GenerativeModel(self.config["model_id"])

    @override
    def update_config(self, **model_config: Unpack[GeminiConfig]) -> None:  # type: ignore[override]
        """Update the Gemini model configuration with the provided arguments.

        Args:
            **model_config: Configuration overrides.
        """
        self.config.update(model_config)
        self.model = genai.GenerativeModel(self.config["model_id"])

    @override
    def get_config(self) -> GeminiConfig:
        """Get the Gemini model configuration.

        Returns:
            The Gemini model configuration.
        """
        return self.config

    def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]:
        """Format a Gemini content block.

        Args:
            content: Message content.

        Returns:
            Gemini formatted content block.
        """
        if "image" in content:
            return {
                "inline_data": {
                    "data": base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8"),
                    "mime_type": mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream"),
                }
            }

        if "text" in content:
            return {"text": content["text"]}

        return {"text": json.dumps(content)}

    def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]:
        """Format a Gemini messages array.

        Args:
            messages: List of message objects to be processed by the model.

        Returns:
            A Gemini messages array.
        """
        formatted_messages = []

        for message in messages:
            formatted_contents = []

            for content in message["content"]:
                if "cachePoint" in content:
                    continue

                formatted_contents.append(self._format_request_message_content(content))

            if formatted_contents:
                # Gemini uses "model" rather than "assistant" for the assistant role.
                role = "model" if message["role"] == "assistant" else message["role"]
                formatted_messages.append({"role": role, "parts": formatted_contents})

        return formatted_messages

    @override
    def format_request(
        self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
    ) -> dict[str, Any]:
        """Format a Gemini streaming request.

        Args:
            messages: List of message objects to be processed by the model.
            tool_specs: List of tool specifications to make available to the model.
            system_prompt: System prompt to provide context to the model.

        Returns:
            A Gemini streaming request.
        """
        generation_config = {"max_output_tokens": self.config["max_tokens"], **(self.config.get("params") or {})}

        return {
            "contents": self._format_request_messages(messages),
            "generation_config": generation_config,
            "tools": [
                {
                    "function_declarations": [
                        {
                            "name": tool_spec["name"],
                            "description": tool_spec["description"],
                            "parameters": tool_spec["inputSchema"]["json"],
                        }
                        for tool_spec in tool_specs or []
                    ]
                }
            ]
            if tool_specs
            else None,
            "system_instruction": system_prompt,
        }

    @override
    def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
        """Format the Gemini response events into standardized message chunks.

        Args:
            event: A response event from the Gemini model.

        Returns:
            The formatted chunk.

        Raises:
            RuntimeError: If chunk_type is not recognized.
                This error should never be encountered as we control chunk_type in the stream method.
        """
        match event["type"]:
            case "message_start":
                return {"messageStart": {"role": "assistant"}}

            case "content_block_start":
                return {"contentBlockStart": {"start": {}}}

            case "content_block_delta":
                return {"contentBlockDelta": {"delta": {"text": event["text"]}}}

            case "content_block_stop":
                return {"contentBlockStop": {}}

            case "message_stop":
                return {"messageStop": {"stopReason": event["stop_reason"]}}

            case "metadata":
                return {
                    "metadata": {
                        "usage": {
                            "inputTokens": event["usage"]["prompt_token_count"],
                            "outputTokens": event["usage"]["candidates_token_count"],
                            "totalTokens": event["usage"]["total_token_count"],
                        },
                        "metrics": {
                            "latencyMs": 0,
                        },
                    }
                }

            case _:
                raise RuntimeError(f"event_type=<{event['type']}> | unknown type")

    @override
    def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
        """Send the request to the Gemini model and get the streaming response.

        Args:
            request: The formatted request to send to the Gemini model.

        Returns:
            An iterable of response events from the Gemini model.

        Raises:
            ContextWindowOverflowException: If the input exceeds the model's context window.
            ModelThrottledException: If the request is throttled by Gemini.
        """
        try:
            # The Gemini SDK accepts system_instruction on the model object rather than
            # on generate_content, so split it out of the request before sending.
            request = dict(request)
            system_instruction = request.pop("system_instruction", None)
            model = (
                genai.GenerativeModel(self.config["model_id"], system_instruction=system_instruction)
                if system_instruction
                else self.model
            )

            response = model.generate_content(**request, stream=True)

            yield {"type": "message_start"}
            yield {"type": "content_block_start"}

            for chunk in response:
                if chunk.text:
                    yield {"type": "content_block_delta", "text": chunk.text}

            yield {"type": "content_block_stop"}
            yield {"type": "message_stop", "stop_reason": "end_turn"}

            # Get usage information
            usage = response.usage_metadata
            yield {
                "type": "metadata",
                "usage": {
                    "prompt_token_count": usage.prompt_token_count,
                    "candidates_token_count": usage.candidates_token_count,
                    "total_token_count": usage.total_token_count,
                },
            }

        except Exception as error:
            if "quota" in str(error).lower():
                raise ModelThrottledException(str(error)) from error

            if any(overflow_message in str(error).lower() for overflow_message in GeminiModel.OVERFLOW_MESSAGES):
                raise ContextWindowOverflowException(str(error)) from error

            raise error
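
Taken together, format_request, stream, and format_chunk are the provider's contract with the Agent event loop. A minimal direct-usage sketch (hypothetical; the Agent normally drives this loop, and it assumes a valid GOOGLE_API_KEY plus the gemini-pro model ID used elsewhere in this commit):

import os

from strands.models.gemini import GeminiModel

# Build the provider as in the README example, with credentials via client_args.
model = GeminiModel(
    client_args={"api_key": os.getenv("GOOGLE_API_KEY")},
    model_id="gemini-pro",
    max_tokens=256,
    params={"temperature": 0.2},
)

# format_request -> stream -> format_chunk is the flow the Agent relies on.
request = model.format_request(
    messages=[{"role": "user", "content": [{"text": "Say hello in one sentence."}]}],
    system_prompt="Be brief.",
)

for event in model.stream(request):
    chunk = model.format_chunk(event)
    if "contentBlockDelta" in chunk:
        print(chunk["contentBlockDelta"]["delta"]["text"], end="")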

tests-integ/test_model_gemini.py

Lines changed: 51 additions & 0 deletions
"""Integration tests for the Gemini model provider."""

import os

import pytest

import strands
from strands import Agent
from strands.models.gemini import GeminiModel


@pytest.fixture
def model():
    return GeminiModel(
        client_args={
            "api_key": os.getenv("GOOGLE_API_KEY"),
        },
        model_id="gemini-pro",
        max_tokens=512,
    )


@pytest.fixture
def tools():
    @strands.tool
    def tool_time() -> str:
        return "12:00"

    @strands.tool
    def tool_weather() -> str:
        return "sunny"

    return [tool_time, tool_weather]


@pytest.fixture
def system_prompt():
    return "You are an AI assistant that uses & instead of ."


@pytest.fixture
def agent(model, tools, system_prompt):
    return Agent(model=model, tools=tools, system_prompt=system_prompt)


@pytest.mark.skipif("GOOGLE_API_KEY" not in os.environ, reason="GOOGLE_API_KEY environment variable missing")
def test_agent(agent):
    result = agent("What is the time and weather in New York?")
    text = result.message["content"][0]["text"].lower()

    assert all(string in text for string in ["12:00", "sunny", "&"])

0 commit comments
