diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 62a988ea6f..91c448bc6b 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -6,13 +6,7 @@
 
 from collections.abc import AsyncIterator
 from enum import Enum
-from typing import (
-    Annotated,
-    Any,
-    Literal,
-    Protocol,
-    runtime_checkable,
-)
+from typing import Annotated, Any, Literal, Protocol, runtime_checkable
 
 from pydantic import BaseModel, Field, field_validator
 from typing_extensions import TypedDict
@@ -357,34 +351,6 @@ class CompletionRequest(BaseModel):
     logprobs: LogProbConfig | None = None
 
 
-@json_schema_type
-class CompletionResponse(MetricResponseMixin):
-    """Response from a completion request.
-
-    :param content: The generated completion text
-    :param stop_reason: Reason why generation stopped
-    :param logprobs: Optional log probabilities for generated tokens
-    """
-
-    content: str
-    stop_reason: StopReason
-    logprobs: list[TokenLogProbs] | None = None
-
-
-@json_schema_type
-class CompletionResponseStreamChunk(MetricResponseMixin):
-    """A chunk of a streamed completion response.
-
-    :param delta: New content generated since last chunk. This can be one or more tokens.
-    :param stop_reason: Optional reason why generation stopped, if complete
-    :param logprobs: Optional log probabilities for generated tokens
-    """
-
-    delta: str
-    stop_reason: StopReason | None = None
-    logprobs: list[TokenLogProbs] | None = None
-
-
 class SystemMessageBehavior(Enum):
     """Config for how to override the default system prompt.
 
@@ -1010,7 +976,7 @@ class InferenceProvider(Protocol):
     async def rerank(
         self,
         model: str,
-        query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
+        query: (str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam),
         items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
         max_num_results: int | None = None,
     ) -> RerankResponse:
@@ -1025,7 +991,12 @@ async def rerank(
         raise NotImplementedError("Reranking is not implemented")
         return  # this is so mypy's safe-super rule will consider the method concrete
 
-    @webmethod(route="/openai/v1/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
+    @webmethod(
+        route="/openai/v1/completions",
+        method="POST",
+        level=LLAMA_STACK_API_V1,
+        deprecated=True,
+    )
     @webmethod(route="/completions", method="POST", level=LLAMA_STACK_API_V1)
     async def openai_completion(
         self,
@@ -1079,7 +1050,12 @@ async def openai_completion(
         """
         ...
 
-    @webmethod(route="/openai/v1/chat/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
+    @webmethod(
+        route="/openai/v1/chat/completions",
+        method="POST",
+        level=LLAMA_STACK_API_V1,
+        deprecated=True,
+    )
     @webmethod(route="/chat/completions", method="POST", level=LLAMA_STACK_API_V1)
     async def openai_chat_completion(
         self,
@@ -1138,7 +1114,12 @@ async def openai_chat_completion(
         """
         ...
 
-    @webmethod(route="/openai/v1/embeddings", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
+    @webmethod(
+        route="/openai/v1/embeddings",
+        method="POST",
+        level=LLAMA_STACK_API_V1,
+        deprecated=True,
+    )
     @webmethod(route="/embeddings", method="POST", level=LLAMA_STACK_API_V1)
     async def openai_embeddings(
         self,
@@ -1172,7 +1153,12 @@ class Inference(InferenceProvider):
     - Embedding models: these models generate embeddings to be used for semantic search.
""" - @webmethod(route="/openai/v1/chat/completions", method="GET", level=LLAMA_STACK_API_V1, deprecated=True) + @webmethod( + route="/openai/v1/chat/completions", + method="GET", + level=LLAMA_STACK_API_V1, + deprecated=True, + ) @webmethod(route="/chat/completions", method="GET", level=LLAMA_STACK_API_V1) async def list_chat_completions( self, @@ -1192,7 +1178,15 @@ async def list_chat_completions( raise NotImplementedError("List chat completions is not implemented") @webmethod( - route="/openai/v1/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True + route="/openai/v1/chat/completions/{completion_id}", + method="GET", + level=LLAMA_STACK_API_V1, + deprecated=True, + ) + @webmethod( + route="/chat/completions/{completion_id}", + method="GET", + level=LLAMA_STACK_API_V1, ) @webmethod(route="/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1) async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages: diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py index c4338e6140..ff86a89ed4 100644 --- a/llama_stack/core/routers/inference.py +++ b/llama_stack/core/routers/inference.py @@ -10,21 +10,21 @@ from datetime import UTC, datetime from typing import Annotated, Any -from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam -from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam +from openai.types.chat import ( + ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam, +) +from openai.types.chat import ( + ChatCompletionToolParam as OpenAIChatCompletionToolParam, +) from pydantic import Field, TypeAdapter -from llama_stack.apis.common.content_types import ( - InterleavedContent, -) +from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError from llama_stack.apis.inference import ( ChatCompletionResponse, ChatCompletionResponseEventType, ChatCompletionResponseStreamChunk, CompletionMessage, - CompletionResponse, - CompletionResponseStreamChunk, Inference, ListOpenAIChatCompletionResponse, Message, @@ -51,7 +51,10 @@ from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable from llama_stack.providers.utils.inference.inference_store import InferenceStore -from llama_stack.providers.utils.telemetry.tracing import enqueue_event, get_current_span +from llama_stack.providers.utils.telemetry.tracing import ( + enqueue_event, + get_current_span, +) logger = get_logger(name=__name__, category="core::routers") @@ -434,7 +437,7 @@ async def stream_tokens_and_compute_metrics( prompt_tokens, model, tool_prompt_format: ToolPromptFormat | None = None, - ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None] | AsyncGenerator[CompletionResponseStreamChunk, None]: + ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]: completion_text = "" async for chunk in response: complete = False @@ -500,7 +503,7 @@ async def stream_tokens_and_compute_metrics( async def count_tokens_and_compute_metrics( self, - response: ChatCompletionResponse | CompletionResponse, + response: ChatCompletionResponse, prompt_tokens, model, tool_prompt_format: ToolPromptFormat | None = None, @@ -522,7 +525,10 @@ async def count_tokens_and_compute_metrics( model=model, ) for metric in completion_metrics: - if 
+            if metric.metric in [
+                "completion_tokens",
+                "total_tokens",
+            ]:  # Only log completion and total tokens
                 enqueue_event(metric)
 
         # Return metrics in response
@@ -646,7 +652,7 @@ async def stream_tokens_and_compute_metrics_openai_chat(
                 message = OpenAIAssistantMessageParam(
                     role="assistant",
                     content=content_str if content_str else None,
-                    tool_calls=assembled_tool_calls if assembled_tool_calls else None,
+                    tool_calls=(assembled_tool_calls if assembled_tool_calls else None),
                 )
                 logprobs_content = choice_data["logprobs_content_parts"]
                 final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None
diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
index b984d97bf1..c495565fda 100644
--- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
+++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
@@ -23,9 +23,6 @@
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
 )
-from llama_stack.providers.utils.inference.openai_compat import (
-    OpenAIChatCompletionToLlamaStackMixin,
-)
 
 from .config import SentenceTransformersInferenceConfig
 
@@ -33,7 +30,6 @@
 
 
 class SentenceTransformersInferenceImpl(
-    OpenAIChatCompletionToLlamaStackMixin,
     SentenceTransformerEmbeddingMixin,
     InferenceProvider,
     ModelsProtocolPrivate,
diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py
index 08652f8c0b..a7802b2828 100644
--- a/llama_stack/providers/remote/inference/runpod/runpod.py
+++ b/llama_stack/providers/remote/inference/runpod/runpod.py
@@ -7,9 +7,10 @@
 
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.inference import OpenAIEmbeddingsResponse
-
-# from llama_stack.providers.datatypes import ModelsProtocolPrivate
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, build_hf_repo_model_entry
+from llama_stack.providers.utils.inference.model_registry import (
+    ModelRegistryHelper,
+    build_hf_repo_model_entry,
+)
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
 )
@@ -36,13 +37,12 @@
     "Llama3.2-3B": "meta-llama/Llama-3.2-3B",
 }
 
-SAFETY_MODELS_ENTRIES = []
 
 # Create MODEL_ENTRIES from RUNPOD_SUPPORTED_MODELS for compatibility with starter template
 MODEL_ENTRIES = [
     build_hf_repo_model_entry(provider_model_id, model_descriptor)
     for provider_model_id, model_descriptor in RUNPOD_SUPPORTED_MODELS.items()
-] + SAFETY_MODELS_ENTRIES
+]
 
 
 class RunpodInferenceAdapter(
diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py
index fc58691e2e..63b877b5e8 100644
--- a/llama_stack/providers/remote/inference/watsonx/watsonx.py
+++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py
@@ -13,7 +13,6 @@
 
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
-    CompletionRequest,
     GreedySamplingStrategy,
     Inference,
     OpenAIChatCompletion,
@@ -81,7 +80,7 @@ def _get_openai_client(self) -> AsyncOpenAI:
             )
         return self._openai_client
 
-    async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
+    async def _get_params(self, request: ChatCompletionRequest) -> dict:
         input_dict = {"params": {}}
         media_present = request_has_media(request)
         llama_model = self.get_llama_model(request.model)
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index d863eb53a5..e3f1d0913e 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -10,9 +10,7 @@
 import uuid
 import warnings
 from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Iterable
-from typing import (
-    Any,
-)
+from typing import Any
 
 from openai import AsyncStream
 from openai.types.chat import (
@@ -97,8 +95,6 @@
     ChatCompletionResponseEventType,
     ChatCompletionResponseStreamChunk,
     CompletionMessage,
-    CompletionResponse,
-    CompletionResponseStreamChunk,
     GreedySamplingStrategy,
     JsonSchemaResponseFormat,
     Message,
@@ -244,31 +240,6 @@ def convert_openai_completion_logprobs_stream(text: str, logprobs: float | OpenA
     return None
 
 
-def process_completion_response(
-    response: OpenAICompatCompletionResponse,
-) -> CompletionResponse:
-    choice = response.choices[0]
-    # drop suffix if present and return stop reason as end of turn
-    if choice.text.endswith("<|eot_id|>"):
-        return CompletionResponse(
-            stop_reason=StopReason.end_of_turn,
-            content=choice.text[: -len("<|eot_id|>")],
-            logprobs=convert_openai_completion_logprobs(choice.logprobs),
-        )
-    # drop suffix if present and return stop reason as end of message
-    if choice.text.endswith("<|eom_id|>"):
-        return CompletionResponse(
-            stop_reason=StopReason.end_of_message,
-            content=choice.text[: -len("<|eom_id|>")],
-            logprobs=convert_openai_completion_logprobs(choice.logprobs),
-        )
-    return CompletionResponse(
-        stop_reason=get_stop_reason(choice.finish_reason),
-        content=choice.text,
-        logprobs=convert_openai_completion_logprobs(choice.logprobs),
-    )
-
-
 def process_chat_completion_response(
     response: OpenAICompatCompletionResponse,
     request: ChatCompletionRequest,
@@ -335,40 +306,40 @@ def process_chat_completion_response(
     )
 
 
-async def process_completion_stream_response(
-    stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
-) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
-    stop_reason = None
-
-    async for chunk in stream:
-        choice = chunk.choices[0]
-        finish_reason = choice.finish_reason
-
-        text = text_from_choice(choice)
-        if text == "<|eot_id|>":
-            stop_reason = StopReason.end_of_turn
-            text = ""
-            continue
-        elif text == "<|eom_id|>":
-            stop_reason = StopReason.end_of_message
-            text = ""
-            continue
-        yield CompletionResponseStreamChunk(
-            delta=text,
-            stop_reason=stop_reason,
-            logprobs=convert_openai_completion_logprobs_stream(text, choice.logprobs),
-        )
-        if finish_reason:
-            if finish_reason in ["stop", "eos", "eos_token"]:
-                stop_reason = StopReason.end_of_turn
-            elif finish_reason == "length":
-                stop_reason = StopReason.out_of_tokens
-            break
-
-    yield CompletionResponseStreamChunk(
-        delta="",
-        stop_reason=stop_reason,
-    )
+# async def process_completion_stream_response(
+#     stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
+# ) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
+#     stop_reason = None
+
+#     async for chunk in stream:
+#         choice = chunk.choices[0]
+#         finish_reason = choice.finish_reason
+
+#         text = text_from_choice(choice)
+#         if text == "<|eot_id|>":
+#             stop_reason = StopReason.end_of_turn
+#             text = ""
+#             continue
+#         elif text == "<|eom_id|>":
+#             stop_reason = StopReason.end_of_message
+#             text = ""
+#             continue
+#         yield CompletionResponseStreamChunk(
+#             delta=text,
+#             stop_reason=stop_reason,
+#             logprobs=convert_openai_completion_logprobs_stream(text, choice.logprobs),
+#         )
+#         if finish_reason:
+#             if finish_reason in ["stop", "eos", "eos_token"]:
+#                 stop_reason = StopReason.end_of_turn
+#             elif finish_reason == "length":
+#                 stop_reason = StopReason.out_of_tokens
+#             break
+
+#     yield CompletionResponseStreamChunk(
+#         delta="",
+#         stop_reason=stop_reason,
+#     )
 
 
 async def process_chat_completion_stream_response(
@@ -984,7 +955,9 @@ def openai_messages_to_messages(
     return converted_messages
 
 
-def openai_content_to_content(content: str | Iterable[OpenAIChatCompletionContentPartParam] | None):
+def openai_content_to_content(
+    content: str | Iterable[OpenAIChatCompletionContentPartParam] | None,
+):
     if content is None:
         return ""
     if isinstance(content, str):
diff --git a/tests/unit/server/test_resolver.py b/tests/unit/server/test_resolver.py
index 1ee1b2f470..df22747544 100644
--- a/tests/unit/server/test_resolver.py
+++ b/tests/unit/server/test_resolver.py
@@ -12,11 +12,7 @@
 from pydantic import BaseModel, Field
 
 from llama_stack.apis.inference import Inference
-from llama_stack.core.datatypes import (
-    Api,
-    Provider,
-    StackRunConfig,
-)
+from llama_stack.core.datatypes import Api, Provider, StackRunConfig
 from llama_stack.core.resolver import resolve_impls
 from llama_stack.core.routers.inference import InferenceRouter
 from llama_stack.core.routing_tables.models import ModelsRoutingTable
@@ -54,7 +50,12 @@ def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
 
 
 class SampleImpl:
-    def __init__(self, config: SampleConfig, deps: dict[Api, Any], provider_spec: ProviderSpec = None):
+    def __init__(
+        self,
+        config: SampleConfig,
+        deps: dict[Api, Any],
+        provider_spec: ProviderSpec = None,
+    ):
         self.__provider_id__ = "test_provider"
         self.__provider_spec__ = provider_spec
         self.__provider_config__ = config