Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions 01-getting-started/15-a2a/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info

# Virtual environments
.venv
sessions/
1 change: 1 addition & 0 deletions 01-getting-started/15-a2a/.python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.13
10 changes: 10 additions & 0 deletions 01-getting-started/15-a2a/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Agent to Agent (A2A) protocol using Strands Agent SDK

An example Strands agent that helps with searching AWS documentation.

## Getting started

1. Install [uv](https://docs.astral.sh/uv/getting-started/installation/).
2. Configure AWS credentials, follow instructions [here](https://cuddly-sniffle-lrmk2y7.pages.github.io/0.1.x-strands/user-guide/quickstart/#configuring-credentials).
3. Start the A2A server using `uv run __main__.py`.
4. Run the test client `uv run test_client.py`.
62 changes: 62 additions & 0 deletions 01-getting-started/15-a2a/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import click
from agent import StrandAgent

from a2a.server.apps import A2AStarletteApplication
from a2a.server.request_handlers import DefaultRequestHandler
from a2a.server.tasks import InMemoryTaskStore
from a2a.types import (
AgentAuthentication,
AgentCapabilities,
AgentCard,
AgentSkill,
)

from agent_executor import StrandsAgentExecutor


@click.command()
@click.option("--host", "host", default="localhost")
@click.option("--port", "port", default=10000)
def main(host: str, port: int):
request_handler = DefaultRequestHandler(
agent_executor=StrandsAgentExecutor(),
task_store=InMemoryTaskStore(),
)

server = A2AStarletteApplication(
agent_card=get_agent_card(host, port), http_handler=request_handler
)
import uvicorn

uvicorn.run(server.build(), host=host, port=port)


def get_agent_card(host: str, port: int):
"""Returns the Agent Card for the Currency Agent."""
capabilities = AgentCapabilities(streaming=True, pushNotifications=True)
skill = AgentSkill(
id="search_aws_docs",
name="AWS Documentation search",
description="Search AWS documentation for topics related to AWS services.",
tags=["AWS Documentation researcher"],
examples=[
"What is Amazon Bedrock?",
"What is Amazon Bedrock pricing model?",
"How to enable AWS lambda trigger from Amazon S3?",
],
)
return AgentCard(
name="AWS Documentation researcher",
description="Helps with queries related to AWS services.",
url=f"http://{host}:{port}/",
version="1.0.0",
defaultInputModes=StrandAgent.SUPPORTED_CONTENT_TYPES,
defaultOutputModes=StrandAgent.SUPPORTED_CONTENT_TYPES,
capabilities=capabilities,
skills=[skill],
authentication=AgentAuthentication(schemes=["public"]),
)


if __name__ == "__main__":
main()
125 changes: 125 additions & 0 deletions 01-getting-started/15-a2a/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from mcp import StdioServerParameters, stdio_client
from strands import Agent
from strands.tools.mcp import MCPClient
from strands_tools import file_write
import os
import json
import asyncio


class StrandAgent:
SUPPORTED_CONTENT_TYPES = ["text", "text/plain"]

def __init__(self):
self.agent = None

try:
os.makedirs("sessions", exist_ok=True)
self.documentation_mcp_server = MCPClient(
lambda: stdio_client(
StdioServerParameters(
command="uvx",
args=["awslabs.aws-documentation-mcp-server@latest"],
)
)
)
self.documentation_mcp_server.start()
self.tools = self.documentation_mcp_server.list_tools_sync() + [file_write]

except Exception as e:
return f"Error initializing agent: {str(e)}"

def _load_agent_from_memory(self, session_id: str) -> str:
session_path = os.path.join("sessions", f"{session_id}.json")
agent = None

try:
if os.path.isfile(session_path):
with open(session_path, "r") as f:
state = json.load(f)

agent = Agent(
messages=state["messages"],
system_prompt=state["system_prompt"],
tools=self.tools,
callback_handler=None,
)
else:
agent = Agent(
system_prompt="""You are a thorough AWS researcher specialized in finding accurate
information online. For each question:

1. Determine what information you need
2. Search the AWS Documentation for reliable information
3. Extract key information and cite your sources
4. Store important findings in memory for future reference
5. Synthesize what you've found into a clear, comprehensive answer

When researching, focus only on AWS documentation. Always provide citations
for the information you find.

Finally output your response to a file in current directory.
""",
tools=self.tools,
callback_handler=None,
)
return agent
except Exception as e:
raise f"Error Loading agent from memory: {e}"

def _store_agent_into_memory(self, agent: Agent, session_id: str) -> bool:
session_path = os.path.join("sessions", f"{session_id}.json")
state = {"messages": agent.messages, "system_prompt": agent.system_prompt}
with open(session_path, "w") as f:
json.dump(state, f)
return True

async def stream(self, query: str, session_id: str):
agent = self._load_agent_from_memory(session_id=session_id)
response = str()
try:
async for event in agent.stream_async(query):
if "data" in event:
# Only stream text chunks to the client
response += event["data"]
yield {
"is_task_complete": "complete" in event,
"require_user_input": False,
"content": event["data"],
}

except Exception as e:
yield {
"is_task_complete": False,
"require_user_input": True,
"content": f"We are unable to process your request at the moment. Error: {e}",
}
finally:
self._store_agent_into_memory(agent, session_id)
yield {
"is_task_complete": True,
"require_user_input": False,
"content": response,
}

def invoke(self, query: str, session_id: str):
agent = self._load_agent_from_memory(session_id=session_id)
try:
response = str(agent(query))

self._store_agent_into_memory(agent, session_id)

except Exception as e:
raise f"Error invoking agent: {e}"
return response


async def main():
agent = StrandAgent()

async for chunk in agent.stream("hello", "123"):
print(chunk, "")


if __name__ == "__main__":
asyncio.run(main())
81 changes: 81 additions & 0 deletions 01-getting-started/15-a2a/agent_executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from agent import StrandAgent
from typing_extensions import override

from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.events.event_queue import EventQueue
from a2a.types import (
TaskArtifactUpdateEvent,
TaskState,
TaskStatus,
TaskStatusUpdateEvent,
UnsupportedOperationError
)
from a2a.utils import new_agent_text_message, new_task, new_text_artifact
from a2a.utils.errors import ServerError


class StrandsAgentExecutor(AgentExecutor):
"""Currency AgentExecutor Example."""

def __init__(self):
self.agent = StrandAgent()

@override
async def execute(
self,
context: RequestContext,
event_queue: EventQueue,
) -> None:
query = context.get_user_input()
task = context.current_task

if not context.message:
raise Exception("No message provided")

if not task:
task = new_task(context.message)
event_queue.enqueue_event(task)

async for event in self.agent.stream(query, task.contextId):
if event["is_task_complete"]:
event_queue.enqueue_event(
TaskArtifactUpdateEvent(
append=False,
contextId=task.contextId,
taskId=task.id,
lastChunk=True,
artifact=new_text_artifact(
name="current_result",
description="Result of request to agent.",
text=event["content"],
),
)
)
event_queue.enqueue_event(
TaskStatusUpdateEvent(
status=TaskStatus(state=TaskState.completed),
final=True,
contextId=task.contextId,
taskId=task.id,
)
)
else:
event_queue.enqueue_event(
TaskStatusUpdateEvent(
status=TaskStatus(
state=TaskState.working,
message=new_agent_text_message(
event["content"],
task.contextId,
task.id,
),
),
final=False,
contextId=task.contextId,
taskId=task.id,
)
)

@override
async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None:
raise ServerError(error=UnsupportedOperationError())
12 changes: 12 additions & 0 deletions 01-getting-started/15-a2a/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[project]
name = "aws-assistant-strands"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.13"
dependencies = [
"a2a-sdk>=0.2.1a1",
"mcp[cli]>=1.9.0",
"strands-agents>=0.1.0",
"strands-agents-tools>=0.1.0",
]
Loading