76 changes: 35 additions & 41 deletions llama_stack/apis/inference/inference.py
@@ -6,13 +6,7 @@

from collections.abc import AsyncIterator
from enum import Enum
from typing import (
Annotated,
Any,
Literal,
Protocol,
runtime_checkable,
)
from typing import Annotated, Any, Literal, Protocol, runtime_checkable

from pydantic import BaseModel, Field, field_validator
from typing_extensions import TypedDict
@@ -357,34 +351,6 @@ class CompletionRequest(BaseModel):
logprobs: LogProbConfig | None = None


@json_schema_type
class CompletionResponse(MetricResponseMixin):
"""Response from a completion request.

:param content: The generated completion text
:param stop_reason: Reason why generation stopped
:param logprobs: Optional log probabilities for generated tokens
"""

content: str
stop_reason: StopReason
logprobs: list[TokenLogProbs] | None = None


@json_schema_type
class CompletionResponseStreamChunk(MetricResponseMixin):
"""A chunk of a streamed completion response.

:param delta: New content generated since last chunk. This can be one or more tokens.
:param stop_reason: Optional reason why generation stopped, if complete
:param logprobs: Optional log probabilities for generated tokens
"""

delta: str
stop_reason: StopReason | None = None
logprobs: list[TokenLogProbs] | None = None


class SystemMessageBehavior(Enum):
"""Config for how to override the default system prompt.

@@ -1010,7 +976,7 @@ class InferenceProvider(Protocol):
async def rerank(
self,
model: str,
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
query: (str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam),
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
max_num_results: int | None = None,
) -> RerankResponse:
@@ -1025,7 +991,12 @@ async def rerank(
raise NotImplementedError("Reranking is not implemented")
return # this is so mypy's safe-super rule will consider the method concrete

@webmethod(route="/openai/v1/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(
route="/openai/v1/completions",
method="POST",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(route="/completions", method="POST", level=LLAMA_STACK_API_V1)
async def openai_completion(
self,
@@ -1079,7 +1050,12 @@ async def openai_completion(
"""
...

@webmethod(route="/openai/v1/chat/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(
route="/openai/v1/chat/completions",
method="POST",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(route="/chat/completions", method="POST", level=LLAMA_STACK_API_V1)
async def openai_chat_completion(
self,
@@ -1138,7 +1114,12 @@ async def openai_chat_completion(
"""
...

@webmethod(route="/openai/v1/embeddings", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(
route="/openai/v1/embeddings",
method="POST",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(route="/embeddings", method="POST", level=LLAMA_STACK_API_V1)
async def openai_embeddings(
self,
@@ -1172,7 +1153,12 @@ class Inference(InferenceProvider):
- Embedding models: these models generate embeddings to be used for semantic search.
"""

@webmethod(route="/openai/v1/chat/completions", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(
route="/openai/v1/chat/completions",
method="GET",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(route="/chat/completions", method="GET", level=LLAMA_STACK_API_V1)
async def list_chat_completions(
self,
@@ -1192,7 +1178,15 @@ async def list_chat_completions(
raise NotImplementedError("List chat completions is not implemented")

@webmethod(
route="/openai/v1/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
route="/openai/v1/chat/completions/{completion_id}",
method="GET",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/chat/completions/{completion_id}",
method="GET",
level=LLAMA_STACK_API_V1,
)
@webmethod(route="/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1)
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
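
The recurring edit in this file is mechanical: each legacy /openai/v1/... route keeps its @webmethod registration (now marked deprecated=True) but is reflowed onto multiple lines, and the equivalent un-prefixed route remains registered alongside it. Below is a minimal sketch of that dual-registration pattern; example_endpoint is a hypothetical method, and the import paths are assumptions not shown in this diff.

# Sketch only: the dual-route registration pattern repeated throughout
# inference.py. `example_endpoint` is hypothetical; the import paths
# below are assumed, not taken from this diff.
from typing import Protocol

from llama_stack.apis.version import LLAMA_STACK_API_V1  # assumed location
from llama_stack.schema_utils import webmethod  # assumed location


class ExampleAPI(Protocol):
    @webmethod(
        route="/openai/v1/example",  # legacy prefix, kept but deprecated
        method="POST",
        level=LLAMA_STACK_API_V1,
        deprecated=True,
    )
    @webmethod(route="/example", method="POST", level=LLAMA_STACK_API_V1)
    async def example_endpoint(self) -> None:
        ...

Both decorators attach to the same handler, so (assuming webmethod only records route metadata) clients still calling the /openai/v1/... paths keep resolving to the same method while the prefix is flagged as deprecated.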
30 changes: 18 additions & 12 deletions llama_stack/core/routers/inference.py
@@ -10,21 +10,21 @@
from datetime import UTC, datetime
from typing import Annotated, Any

from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
from openai.types.chat import (
ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam,
)
from openai.types.chat import (
ChatCompletionToolParam as OpenAIChatCompletionToolParam,
)
from pydantic import Field, TypeAdapter

from llama_stack.apis.common.content_types import (
InterleavedContent,
)
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
CompletionResponse,
CompletionResponseStreamChunk,
Inference,
ListOpenAIChatCompletionResponse,
Message,
@@ -51,7 +51,10 @@
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.telemetry.tracing import enqueue_event, get_current_span
from llama_stack.providers.utils.telemetry.tracing import (
enqueue_event,
get_current_span,
)

logger = get_logger(name=__name__, category="core::routers")

@@ -434,7 +437,7 @@ async def stream_tokens_and_compute_metrics(
prompt_tokens,
model,
tool_prompt_format: ToolPromptFormat | None = None,
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None] | AsyncGenerator[CompletionResponseStreamChunk, None]:
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
completion_text = ""
async for chunk in response:
complete = False
@@ -500,7 +503,7 @@ async def stream_tokens_and_compute_metrics(

async def count_tokens_and_compute_metrics(
self,
response: ChatCompletionResponse | CompletionResponse,
response: ChatCompletionResponse,
prompt_tokens,
model,
tool_prompt_format: ToolPromptFormat | None = None,
@@ -522,7 +525,10 @@ async def count_tokens_and_compute_metrics(
model=model,
)
for metric in completion_metrics:
if metric.metric in ["completion_tokens", "total_tokens"]: # Only log completion and total tokens
if metric.metric in [
"completion_tokens",
"total_tokens",
]: # Only log completion and total tokens
enqueue_event(metric)

# Return metrics in response
@@ -646,7 +652,7 @@ async def stream_tokens_and_compute_metrics_openai_chat(
message = OpenAIAssistantMessageParam(
role="assistant",
content=content_str if content_str else None,
tool_calls=assembled_tool_calls if assembled_tool_calls else None,
tool_calls=(assembled_tool_calls if assembled_tool_calls else None),
)
logprobs_content = choice_data["logprobs_content_parts"]
final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None
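
One behavioural detail worth calling out in this hunk is the metric filter: of the computed completion metrics, only the completion and total token counts are forwarded to telemetry, while the full set is still returned on the response. A small self-contained sketch of that filtering step, using a stand-in MetricEvent dataclass rather than the router's real telemetry type:

# Sketch only: mirrors the filtering in count_tokens_and_compute_metrics.
# `MetricEvent` and `metrics_to_enqueue` are stand-ins, not real names
# from the codebase.
from dataclasses import dataclass


@dataclass
class MetricEvent:
    metric: str
    value: int


def metrics_to_enqueue(completion_metrics: list[MetricEvent]) -> list[MetricEvent]:
    # Only completion and total tokens are logged to telemetry; the rest
    # stay on the response object instead of being enqueued here.
    wanted = {"completion_tokens", "total_tokens"}
    return [m for m in completion_metrics if m.metric in wanted]


metrics = [
    MetricEvent("prompt_tokens", 12),
    MetricEvent("completion_tokens", 48),
    MetricEvent("total_tokens", 60),
]
assert [m.metric for m in metrics_to_enqueue(metrics)] == ["completion_tokens", "total_tokens"]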
llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
@@ -23,17 +23,13 @@
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
)

from .config import SentenceTransformersInferenceConfig

log = get_logger(name=__name__, category="inference")


class SentenceTransformersInferenceImpl(
OpenAIChatCompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin,
InferenceProvider,
ModelsProtocolPrivate,
10 changes: 5 additions & 5 deletions llama_stack/providers/remote/inference/runpod/runpod.py
@@ -7,9 +7,10 @@

from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.inference import OpenAIEmbeddingsResponse

# from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, build_hf_repo_model_entry
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
)
@@ -36,13 +37,12 @@
"Llama3.2-3B": "meta-llama/Llama-3.2-3B",
}

SAFETY_MODELS_ENTRIES = []

# Create MODEL_ENTRIES from RUNPOD_SUPPORTED_MODELS for compatibility with starter template
MODEL_ENTRIES = [
build_hf_repo_model_entry(provider_model_id, model_descriptor)
for provider_model_id, model_descriptor in RUNPOD_SUPPORTED_MODELS.items()
] + SAFETY_MODELS_ENTRIES
]


class RunpodInferenceAdapter(
3 changes: 1 addition & 2 deletions llama_stack/providers/remote/inference/watsonx/watsonx.py
@@ -13,7 +13,6 @@

from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
GreedySamplingStrategy,
Inference,
OpenAIChatCompletion,
@@ -81,7 +80,7 @@ def _get_openai_client(self) -> AsyncOpenAI:
)
return self._openai_client

async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
async def _get_params(self, request: ChatCompletionRequest) -> dict:
input_dict = {"params": {}}
media_present = request_has_media(request)
llama_model = self.get_llama_model(request.model)