Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 11 additions & 167 deletions llama_stack/providers/remote/inference/fireworks/fireworks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
from collections.abc import AsyncGenerator

from fireworks.client import Fireworks
from openai import AsyncOpenAI

from llama_stack.apis.common.content_types import (
InterleavedContent,
Expand All @@ -24,12 +22,6 @@
Inference,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
ResponseFormatType,
SamplingParams,
Expand All @@ -45,15 +37,14 @@
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
convert_message_to_openai_dict,
get_sampling_options,
prepare_openai_completion_params,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
Expand All @@ -68,7 +59,7 @@
logger = get_logger(name=__name__, category="inference::fireworks")


class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
def __init__(self, config: FireworksImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
self.config = config
Expand All @@ -79,7 +70,7 @@ async def initialize(self) -> None:
async def shutdown(self) -> None:
    """Shut the adapter down; the body does nothing, so this is a no-op."""
    return None

def _get_api_key(self) -> str:
def get_api_key(self) -> str:
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
if config_api_key:
return config_api_key
Expand All @@ -91,15 +82,18 @@ def _get_api_key(self) -> str:
)
return provider_data.fireworks_api_key

def _get_base_url(self) -> str:
def get_base_url(self) -> str:
    """Return the fixed URL of the Fireworks inference endpoint."""
    base_url = "https://api.fireworks.ai/inference/v1"
    return base_url

def _get_client(self) -> Fireworks:
fireworks_api_key = self._get_api_key()
fireworks_api_key = self.get_api_key()
return Fireworks(api_key=fireworks_api_key)

def _get_openai_client(self) -> AsyncOpenAI:
    """Build an ``AsyncOpenAI`` client aimed at the Fireworks endpoint.

    A new client object is constructed on every call, using the adapter's
    base URL and API key.
    """
    endpoint = self._get_base_url()
    credential = self._get_api_key()
    return AsyncOpenAI(base_url=endpoint, api_key=credential)
def _preprocess_prompt_for_fireworks(self, prompt: str) -> str:
"""Remove BOS token as Fireworks automatically prepends it"""
if prompt.startswith("<|begin_of_text|>"):
return prompt[len("<|begin_of_text|>") :]
return prompt

async def completion(
self,
Expand Down Expand Up @@ -285,153 +279,3 @@ async def embeddings(

embeddings = [data.embedding for data in response.data]
return EmbeddingsResponse(embeddings=embeddings)

async def openai_embeddings(
    self,
    model: str,
    input: str | list[str],
    encoding_format: str | None = "float",
    dimensions: int | None = None,
    user: str | None = None,
) -> OpenAIEmbeddingsResponse:
    """OpenAI-style embeddings entry point; not implemented here.

    Raises:
        NotImplementedError: Always, whatever arguments are passed.
    """
    raise NotImplementedError()

async def openai_completion(
    self,
    model: str,
    prompt: str | list[str] | list[int] | list[list[int]],
    best_of: int | None = None,
    echo: bool | None = None,
    frequency_penalty: float | None = None,
    logit_bias: dict[str, float] | None = None,
    logprobs: bool | None = None,
    max_tokens: int | None = None,
    n: int | None = None,
    presence_penalty: float | None = None,
    seed: int | None = None,
    stop: str | list[str] | None = None,
    stream: bool | None = None,
    stream_options: dict[str, Any] | None = None,
    temperature: float | None = None,
    top_p: float | None = None,
    user: str | None = None,
    guided_choice: list[str] | None = None,
    prompt_logprobs: int | None = None,
    suffix: str | None = None,
) -> OpenAICompletion:
    """Send an OpenAI-style text completion request to Fireworks.

    The registered model is looked up in ``self.model_store`` and its
    ``provider_resource_id`` is sent as the model name. The remaining
    arguments are passed through ``prepare_openai_completion_params`` and
    then to the OpenAI client's ``completions.create``.

    Note that ``guided_choice``, ``prompt_logprobs`` and ``suffix`` are
    accepted but are not forwarded to the request.
    """
    resolved_model = await self.model_store.get_model(model)

    # Fireworks always prepends BOS itself, so drop a leading one from plain
    # string prompts; list-valued prompts are passed through untouched.
    bos_token = "<|begin_of_text|>"
    if isinstance(prompt, str) and prompt.startswith(bos_token):
        prompt = prompt[len(bos_token) :]

    request_fields = {
        "model": resolved_model.provider_resource_id,
        "prompt": prompt,
        "best_of": best_of,
        "echo": echo,
        "frequency_penalty": frequency_penalty,
        "logit_bias": logit_bias,
        "logprobs": logprobs,
        "max_tokens": max_tokens,
        "n": n,
        "presence_penalty": presence_penalty,
        "seed": seed,
        "stop": stop,
        "stream": stream,
        "stream_options": stream_options,
        "temperature": temperature,
        "top_p": top_p,
        "user": user,
    }
    params = await prepare_openai_completion_params(**request_fields)

    client = self._get_openai_client()
    return await client.completions.create(**params)

async def openai_chat_completion(
    self,
    model: str,
    messages: list[OpenAIMessageParam],
    frequency_penalty: float | None = None,
    function_call: str | dict[str, Any] | None = None,
    functions: list[dict[str, Any]] | None = None,
    logit_bias: dict[str, float] | None = None,
    logprobs: bool | None = None,
    max_completion_tokens: int | None = None,
    max_tokens: int | None = None,
    n: int | None = None,
    parallel_tool_calls: bool | None = None,
    presence_penalty: float | None = None,
    response_format: OpenAIResponseFormatParam | None = None,
    seed: int | None = None,
    stop: str | list[str] | None = None,
    stream: bool | None = None,
    stream_options: dict[str, Any] | None = None,
    temperature: float | None = None,
    tool_choice: str | dict[str, Any] | None = None,
    tools: list[dict[str, Any]] | None = None,
    top_logprobs: int | None = None,
    top_p: float | None = None,
    user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
    """Serve an OpenAI-style chat completion request.

    The registered model is looked up in ``self.model_store``. If
    ``self.get_llama_model`` recognises its ``provider_resource_id``, the call
    is delegated to ``OpenAIChatCompletionToLlamaStackMixin`` with the
    caller's original ``model`` identifier. Otherwise the arguments go
    through ``prepare_openai_completion_params`` and are sent to the OpenAI
    client's ``chat.completions.create`` with the provider resource id as the
    model name.
    """
    model_obj = await self.model_store.get_model(model)

    # Both code paths take exactly the same request arguments, so gather
    # them once instead of spelling the keyword list out twice.
    request_kwargs = {
        "messages": messages,
        "frequency_penalty": frequency_penalty,
        "function_call": function_call,
        "functions": functions,
        "logit_bias": logit_bias,
        "logprobs": logprobs,
        "max_completion_tokens": max_completion_tokens,
        "max_tokens": max_tokens,
        "n": n,
        "parallel_tool_calls": parallel_tool_calls,
        "presence_penalty": presence_penalty,
        "response_format": response_format,
        "seed": seed,
        "stop": stop,
        "stream": stream,
        "stream_options": stream_options,
        "temperature": temperature,
        "tool_choice": tool_choice,
        "tools": tools,
        "top_logprobs": top_logprobs,
        "top_p": top_p,
        "user": user,
    }

    # Llama models are diverted through the Llama Stack inference APIs
    # because the Fireworks OpenAI-compatible chat completions API does not
    # support tool calls properly.
    llama_model = self.get_llama_model(model_obj.provider_resource_id)
    if llama_model:
        return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(
            self,
            model=model,
            **request_kwargs,
        )

    params = await prepare_openai_completion_params(**request_kwargs)

    logger.debug(f"fireworks params: {params}")
    client = self._get_openai_client()
    return await client.chat.completions.create(model=model_obj.provider_resource_id, **params)
3 changes: 2 additions & 1 deletion tests/integration/inference/test_openai_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def skip_if_model_doesnt_support_user_param(client, model_id):
provider = provider_from_model(client, model_id)
if provider.provider_type in (
"remote::together", # service returns 400
"remote::fireworks", # service returns 400 malformed input
):
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} does not support user param.")

Expand All @@ -41,6 +42,7 @@ def skip_if_model_doesnt_support_encoding_format_base64(client, model_id):
provider = provider_from_model(client, model_id)
if provider.provider_type in (
"remote::together", # param silently ignored, always returns floats
"remote::fireworks", # param silently ignored, always returns list of floats
):
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} does not support encoding_format='base64'.")

Expand Down Expand Up @@ -287,7 +289,6 @@ def test_openai_embeddings_base64_batch_processing(compat_client, client_with_mo
input=input_texts,
encoding_format="base64",
)

# Validate response structure
assert response.object == "list"
assert response.model == embedding_model_id
Expand Down
9 changes: 9 additions & 0 deletions tests/integration/suites.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,15 @@ class Setup(BaseModel):
"embedding_model": "together/togethercomputer/m2-bert-80M-32k-retrieval",
},
),
"fireworks": Setup(
name="fireworks",
description="Fireworks provider with a text model",
defaults={
"text_model": "accounts/fireworks/models/llama-v3p1-8b-instruct",
"vision_model": "accounts/fireworks/models/llama-v3p2-90b-vision-instruct",
"embedding_model": "nomic-ai/nomic-embed-text-v1.5",
},
),
}


Expand Down
Loading