Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ def __init__(
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
# Track final messages after all tool executions
self.final_messages: list[OpenAIMessageParam] = []
# Maps file_id -> filename; used to build file-citation annotations on output messages
self.citation_files: dict[str, str] = {}

async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
# Initialize output messages
Expand Down Expand Up @@ -126,6 +128,7 @@ async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
# Text is the default response format for chat completion so don't need to pass it
# (some providers don't support non-empty response_format when tools are present)
response_format = None if self.ctx.response_format.type == "text" else self.ctx.response_format
logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}")
completion_result = await self.inference_api.openai_chat_completion(
model=self.ctx.model,
messages=messages,
Expand Down Expand Up @@ -160,7 +163,7 @@ async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
# Handle choices with no tool calls
for choice in current_response.choices:
if not (choice.message.tool_calls and self.ctx.response_tools):
output_messages.append(await convert_chat_choice_to_response_message(choice))
output_messages.append(await convert_chat_choice_to_response_message(choice, self.citation_files))

# Execute tool calls and coordinate results
async for stream_event in self._coordinate_tool_execution(
Expand Down Expand Up @@ -211,6 +214,8 @@ def _separate_tool_calls(self, current_response, messages) -> tuple[list, list,

for choice in current_response.choices:
next_turn_messages.append(choice.message)
logger.debug(f"Choice message content: {choice.message.content}")
logger.debug(f"Choice message tool_calls: {choice.message.tool_calls}")

if choice.message.tool_calls and self.ctx.response_tools:
for tool_call in choice.message.tool_calls:
Expand Down Expand Up @@ -470,6 +475,8 @@ async def _coordinate_tool_execution(
tool_call_log = result.final_output_message
tool_response_message = result.final_input_message
self.sequence_number = result.sequence_number
if result.citation_files:
self.citation_files.update(result.citation_files)

if tool_call_log:
output_messages.append(tool_call_log)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,10 @@ async def execute_tool_call(

# Yield the final result
yield ToolExecutionResult(
sequence_number=sequence_number, final_output_message=output_message, final_input_message=input_message
sequence_number=sequence_number,
final_output_message=output_message,
final_input_message=input_message,
citation_files=result.metadata.get("citation_files") if result and result.metadata else None,
)

async def _execute_knowledge_search_via_vector_store(
Expand Down Expand Up @@ -129,36 +132,65 @@ async def search_single_store(vector_store_id):
for results in all_results:
search_results.extend(results)

# Convert search results to tool result format matching memory.py
# Format the results as interleaved content similar to memory.py
content_items = []
content_items.append(
TextContentItem(
text=f"knowledge_search tool found {len(search_results)} chunks:\nBEGIN of knowledge_search tool results.\n"
)
)

unique_files = set()
for i, result_item in enumerate(search_results):
chunk_text = result_item.content[0].text if result_item.content else ""
metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}"
# Get file_id from attributes if result_item.file_id is empty
file_id = result_item.file_id or (
result_item.attributes.get("document_id") if result_item.attributes else None
)
metadata_text = f"document_id: {file_id}, score: {result_item.score}"
if result_item.attributes:
metadata_text += f", attributes: {result_item.attributes}"
text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n"

text_content = f"[{i + 1}] {metadata_text} (cite as <|{file_id}|>)\n{chunk_text}\n"
content_items.append(TextContentItem(text=text_content))
unique_files.add(file_id)

content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))

citation_instruction = ""
if unique_files:
citation_instruction = (
" Cite sources immediately at the end of sentences before punctuation, using `<|file-id|>` format (e.g., 'This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.'). "
"Do not add extra punctuation. Use only the file IDs provided (do not invent new ones)."
)

content_items.append(
TextContentItem(
text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n',
text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.{citation_instruction}\n',
)
)

# Build the file_id -> filename map, falling back to chunk attributes when file_id/filename are missing (older records)
citation_files = {}
for result in search_results:
file_id = result.file_id
if not file_id and result.attributes:
file_id = result.attributes.get("document_id")

filename = result.filename
if not filename and result.attributes:
filename = result.attributes.get("filename")
if not filename:
filename = "unknown"

citation_files[file_id] = filename

return ToolInvocationResult(
content=content_items,
metadata={
"document_ids": [r.file_id for r in search_results],
"chunks": [r.content[0].text if r.content else "" for r in search_results],
"scores": [r.score for r in search_results],
"citation_files": citation_files,
},
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class ToolExecutionResult(BaseModel):
sequence_number: int
final_output_message: OpenAIResponseOutput | None = None
final_input_message: OpenAIMessageParam | None = None
citation_files: dict[str, str] | None = None


@dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import re
import uuid

from llama_stack.apis.agents.openai_responses import (
OpenAIResponseAnnotationFileCitation,
OpenAIResponseInput,
OpenAIResponseInputFunctionToolCallOutput,
OpenAIResponseInputMessageContent,
Expand Down Expand Up @@ -45,7 +47,9 @@
)


async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
async def convert_chat_choice_to_response_message(
choice: OpenAIChoice, citation_files: dict[str, str] | None = None
) -> OpenAIResponseMessage:
"""Convert an OpenAI Chat Completion choice into an OpenAI Response output message."""
output_content = ""
if isinstance(choice.message.content, str):
Expand All @@ -57,9 +61,11 @@ async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenA
f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
)

annotations, clean_text = _extract_citations_from_text(output_content, citation_files or {})

return OpenAIResponseMessage(
id=f"msg_{uuid.uuid4()}",
content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=annotations)],
status="completed",
role="assistant",
)
Expand Down Expand Up @@ -200,6 +206,53 @@ async def get_message_type_by_role(role: str):
return role_to_type.get(role)


def _extract_citations_from_text(
    text: str, citation_files: dict[str, str]
) -> tuple[list[OpenAIResponseAnnotationFileCitation], str]:
    """Strip citation markers from text and build file-citation annotations.

    A citation marker has the form ``<|ID|>``, e.g.
    ``<|file-Cn3MSNn72ENTiiq11Qda4A|>``. A ``<|...|>`` span is treated as a
    citation marker when its ID either looks like a file ID
    (``file-`` followed by letters, digits, ``_`` or ``-``) or is a key of
    ``citation_files``. Any other ``<|...|>`` span is left in the text untouched.

    Every citation marker is removed from the text, together with a single
    space directly in front of it. An annotation is only emitted when the ID
    is a key of ``citation_files``; markers carrying an unknown file ID are
    removed without producing an annotation.

    Args:
        text: The text containing citation markers like <|file-Cn3MSNn72ENTiiq11Qda4A|>
        citation_files: Dictionary mapping file_id to filename

    Returns:
        Tuple of (annotations_list, clean_text_without_markers). Each
        annotation's ``index`` is the offset in the cleaned text at which the
        marker was removed.
    """
    # Candidate markers: any <|token|> whose token has no whitespace or marker delimiters.
    marker_regex = re.compile(r"<\|(?P<file_id>[^<>|\s]+)\|>")
    # Shape of a generated file ID; such markers are stripped even when the ID is unknown.
    file_id_regex = re.compile(r"file-[A-Za-z0-9_-]+")

    annotations = []
    parts = []
    total_len = 0
    last_end = 0

    for m in marker_regex.finditer(text):
        fid = m.group("file_id")
        is_known = fid in citation_files

        # IDs taken from chunk attributes need not start with "file-", so a known
        # ID always counts as a citation. Anything else that does not look like a
        # file ID is not a citation marker and stays in the text.
        if not is_known and not file_id_regex.fullmatch(fid):
            continue

        # segment before the marker
        prefix = text[last_end : m.start()]

        # drop one space if it exists (since marker is at sentence end)
        if prefix.endswith(" "):
            prefix = prefix[:-1]

        parts.append(prefix)
        total_len += len(prefix)

        if is_known:
            annotations.append(
                OpenAIResponseAnnotationFileCitation(
                    file_id=fid,
                    filename=citation_files[fid],
                    index=total_len,  # offset in the cleaned text where the marker was removed
                )
            )

        last_end = m.end()

    # trailing text after the last citation marker
    parts.append(text[last_end:])
    cleaned_text = "".join(parts)
    return annotations, cleaned_text


def is_function_tool_call(
tool_call: OpenAIChatCompletionToolCall,
tools: list[OpenAIResponseInputTool],
Expand Down
5 changes: 4 additions & 1 deletion llama_stack/providers/inline/tool_runtime/rag/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,5 +331,8 @@ async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvoc

return ToolInvocationResult(
content=result.content or [],
metadata=result.metadata,
metadata={
**(result.metadata or {}),
"citation_files": getattr(result, "citation_files", None),
},
)
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ async def openai_search_vector_store(
content = self._chunk_to_vector_store_content(chunk)

response_data_item = VectorStoreSearchResponse(
file_id=chunk.metadata.get("file_id", ""),
file_id=chunk.metadata.get("document_id", ""),
filename=chunk.metadata.get("filename", ""),
score=score,
attributes=chunk.metadata,
Expand Down Expand Up @@ -608,12 +608,15 @@ async def openai_attach_file_to_vector_store(

content = content_from_data_and_mime_type(content_response.body, mime_type)

chunk_attributes = attributes.copy()
chunk_attributes["filename"] = file_response.filename

chunks = make_overlapped_chunks(
file_id,
content,
max_chunk_size_tokens,
chunk_overlap_tokens,
attributes,
chunk_attributes,
)

if not chunks:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest

from llama_stack.apis.agents.openai_responses import (
OpenAIResponseAnnotationFileCitation,
OpenAIResponseInputFunctionToolCallOutput,
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
Expand Down Expand Up @@ -35,6 +36,7 @@
OpenAIUserMessageParam,
)
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
_extract_citations_from_text,
convert_chat_choice_to_response_message,
convert_response_content_to_chat_content,
convert_response_input_to_chat_messages,
Expand Down Expand Up @@ -340,3 +342,26 @@ def test_is_function_tool_call_false_wrong_type(self):

result = is_function_tool_call(tool_call, tools)
assert result is False


class TestExtractCitationsFromText:
    """Tests for stripping citation markers and building file-citation annotations."""

    def test_extract_citations_and_annotations(self):
        """Known markers are removed and annotated at the punctuation that follows them."""
        file_mapping = {"file-abc123": "doc1.pdf", "file-def456": "doc2.txt"}
        text = (
            "Start [not-a-file]. New source <|file-abc123|>. "
            "Other source <|file-def456|>? Repeat source <|file-abc123|>! No citation."
        )

        annotations, cleaned_text = _extract_citations_from_text(text, file_mapping)

        expected_clean_text = "Start [not-a-file]. New source. Other source? Repeat source! No citation."
        first_citation = OpenAIResponseAnnotationFileCitation(file_id="file-abc123", filename="doc1.pdf", index=30)
        second_citation = OpenAIResponseAnnotationFileCitation(file_id="file-def456", filename="doc2.txt", index=44)
        repeat_citation = OpenAIResponseAnnotationFileCitation(file_id="file-abc123", filename="doc1.pdf", index=59)
        expected_annotations = [first_citation, second_citation, repeat_citation]

        assert cleaned_text == expected_clean_text
        assert annotations == expected_annotations

        # OpenAI cites at the end of the sentence, so each index lands on the closing punctuation
        for citation, punctuation in zip(expected_annotations, (".", "?", "!")):
            assert cleaned_text[citation.index] == punctuation
Loading