Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions llama_stack/providers/remote/inference/nvidia/nvidia.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
OpenAIEmbeddingUsage,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

from . import NVIDIAConfig
Expand All @@ -21,9 +22,7 @@
logger = get_logger(name=__name__, category="inference::nvidia")


class NVIDIAInferenceAdapter(OpenAIMixin):
config: NVIDIAConfig

class NVIDIAInferenceAdapter(OpenAIMixin, ModelRegistryHelper):
"""
NVIDIA Inference Adapter for Llama Stack.

Expand All @@ -37,12 +36,27 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
- ModelRegistryHelper.check_model_availability() just returns False and shows a warning
"""

def __init__(self, config: NVIDIAConfig) -> None:
"""Initialize the NVIDIA inference adapter with configuration."""
# Initialize ModelRegistryHelper with empty model entries since NVIDIA uses dynamic model discovery
ModelRegistryHelper.__init__(self, model_entries=[], allowed_models=config.allowed_models)
self.config = config

# source: https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html
embedding_model_metadata: dict[str, dict[str, int]] = {
"nvidia/llama-3.2-nv-embedqa-1b-v2": {"embedding_dimension": 2048, "context_length": 8192},
"nvidia/llama-3.2-nv-embedqa-1b-v2": {
"embedding_dimension": 2048,
"context_length": 8192,
},
"nvidia/nv-embedqa-e5-v5": {"embedding_dimension": 512, "context_length": 1024},
"nvidia/nv-embedqa-mistral-7b-v2": {"embedding_dimension": 512, "context_length": 4096},
"snowflake/arctic-embed-l": {"embedding_dimension": 512, "context_length": 1024},
"nvidia/nv-embedqa-mistral-7b-v2": {
"embedding_dimension": 512,
"context_length": 4096,
},
"snowflake/arctic-embed-l": {
"embedding_dimension": 512,
"context_length": 1024,
},
}

async def initialize(self) -> None:
Expand Down Expand Up @@ -95,7 +109,7 @@ async def openai_embeddings(
response = await self.client.embeddings.create(
model=await self._get_provider_model_id(model),
input=input,
encoding_format=encoding_format if encoding_format is not None else NOT_GIVEN,
encoding_format=(encoding_format if encoding_format is not None else NOT_GIVEN),
dimensions=dimensions if dimensions is not None else NOT_GIVEN,
user=user if user is not None else NOT_GIVEN,
extra_body=extra_body,
Expand Down