2 changes: 2 additions & 0 deletions docs/docs/providers/inference/remote_vllm.mdx
@@ -20,6 +20,7 @@ Remote vLLM inference provider for connecting to vLLM servers.
| `api_token` | `str \| None` | No | fake | The API token |
| `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically |
| `enable_model_discovery` | `<class 'bool'>` | No | True | Whether to enable model discovery from the vLLM server |

## Sample Configuration

@@ -28,4 +29,5 @@ url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
enable_model_discovery: ${env.VLLM_ENABLE_MODEL_DISCOVERY:=true}
```
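By default discovery stays on; a deployment that registers models statically can turn it off instead. A minimal sketch of such a provider entry is shown below — the `provider_id` is illustrative and the `provider_type` of `remote::vllm` is assumed, while the config keys come from this diff:

```yaml
# Hypothetical run.yaml excerpt: disable discovery so models must be registered statically.
providers:
  inference:
  - provider_id: vllm              # illustrative id
    provider_type: remote::vllm    # assumed provider type for this adapter
    config:
      url: ${env.VLLM_URL:=}
      api_token: ${env.VLLM_API_TOKEN:=fake}
      tls_verify: ${env.VLLM_TLS_VERIFY:=true}
      enable_model_discovery: ${env.VLLM_ENABLE_MODEL_DISCOVERY:=false}
```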
1 change: 1 addition & 0 deletions llama_stack/distributions/ci-tests/run.yaml
@@ -31,6 +31,7 @@ providers:
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
enable_model_discovery: ${env.VLLM_ENABLE_MODEL_DISCOVERY:=true}
- provider_id: ${env.TGI_URL:+tgi}
provider_type: remote::tgi
config:
1 change: 1 addition & 0 deletions llama_stack/distributions/postgres-demo/run.yaml
@@ -16,6 +16,7 @@ providers:
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
enable_model_discovery: ${env.VLLM_ENABLE_MODEL_DISCOVERY:=true}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
vector_io:
1 change: 1 addition & 0 deletions llama_stack/distributions/starter-gpu/run.yaml
@@ -31,6 +31,7 @@ providers:
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
enable_model_discovery: ${env.VLLM_ENABLE_MODEL_DISCOVERY:=true}
- provider_id: ${env.TGI_URL:+tgi}
provider_type: remote::tgi
config:
1 change: 1 addition & 0 deletions llama_stack/distributions/starter/run.yaml
@@ -31,6 +31,7 @@ providers:
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
enable_model_discovery: ${env.VLLM_ENABLE_MODEL_DISCOVERY:=true}
- provider_id: ${env.TGI_URL:+tgi}
provider_type: remote::tgi
config:
@@ -71,6 +71,9 @@ async def openai_completion(self, *args, **kwargs):
async def should_refresh_models(self) -> bool:
return False

async def enable_model_discovery(self) -> bool:
return True

async def list_models(self) -> list[Model] | None:
return None

@@ -52,6 +52,9 @@ async def shutdown(self) -> None:
async def should_refresh_models(self) -> bool:
return False

async def enable_model_discovery(self) -> bool:
return True

async def list_models(self) -> list[Model] | None:
return [
Model(
5 changes: 5 additions & 0 deletions llama_stack/providers/remote/inference/vllm/config.py
@@ -34,6 +34,10 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
default=False,
description="Whether to refresh models periodically",
)
enable_model_discovery: bool = Field(
default=True,
description="Whether to enable model discovery from the vLLM server",
)

@field_validator("tls_verify")
@classmethod
@@ -59,4 +63,5 @@ def sample_run_config(
"max_tokens": "${env.VLLM_MAX_TOKENS:=4096}",
"api_token": "${env.VLLM_API_TOKEN:=fake}",
"tls_verify": "${env.VLLM_TLS_VERIFY:=true}",
"enable_model_discovery": "${env.VLLM_ENABLE_MODEL_DISCOVERY:=true}",
}
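For context, a minimal sketch of how the new field is exercised programmatically (it mirrors the unit-test usage later in this diff; the import path is assumed from this file's location and is not shown in the diff itself):

```python
# Sketch: construct the adapter config with discovery turned off.
# Import path assumed from llama_stack/providers/remote/inference/vllm/config.py.
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig

config = VLLMInferenceAdapterConfig(
    url="http://test.localhost",   # illustrative URL, as in the unit tests
    enable_model_discovery=False,  # new flag; defaults to True
)
assert config.refresh_models is False          # existing default
assert config.enable_model_discovery is False  # discovery disabled for this provider
```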
37 changes: 36 additions & 1 deletion llama_stack/providers/remote/inference/vllm/vllm.py
@@ -283,6 +283,11 @@ async def should_refresh_models(self) -> bool:
return self.config.refresh_models

async def list_models(self) -> list[Model] | None:
log.debug(f"VLLM list_models called, enable_model_discovery={self.config.enable_model_discovery}")
if not self.config.enable_model_discovery:
log.debug("VLLM list_models returning None due to enable_model_discovery=False")
return None

models = []
async for m in self.client.models.list():
model_type = ModelType.llm # unclear how to determine embedding vs. llm models
@@ -337,13 +342,43 @@ async def register_model(self, model: Model) -> Model:
model = await self.register_helper.register_model(model)
except ValueError:
pass # Ignore statically unknown model, will check live listing

# Check if provider enables model discovery before making HTTP request
if not self.config.enable_model_discovery:
log.debug("Model discovery disabled for vLLM: Trusting model exists")
# Warn if API key is set but model discovery is disabled
if self.config.api_token:
log.warning(
"Model discovery is disabled but VLLM_API_TOKEN is set. "
"If you're not using model discovery, you may not need to set the API token. "
"Consider removing VLLM_API_TOKEN from your configuration or setting enable_model_discovery=true."
)
return model

try:
res = self.client.models.list()
except APIConnectionError as e:
raise ValueError(
f"Failed to connect to vLLM at {self.config.url}. Please check if vLLM is running and accessible at that URL."
) from e
available_models = [m.id async for m in res]

try:
available_models = [m.id async for m in res]
except Exception as e:
# Provide helpful error message for model discovery failures
log.error(f"Model discovery failed with the following output from vLLM server: {e}.\n")
log.error(
f"Model discovery failed: This typically occurs when a provider (like vLLM) is configured "
f"with model discovery enabled but the provider server doesn't support the /models endpoint. "
f"To resolve this, either:\n"
f"1. Check that {self.config.url} correctly points to the vLLM server, or\n"
f"2. Ensure your provider server supports the /v1/models endpoint and if authenticated that VLLM_API_TOKEN is set, or\n"
f"3. Set enable_model_discovery=false for the problematic provider in your configuration\n"
)
raise ValueError(
f"Model discovery failed for vLLM at {self.config.url}. Please check the server configuration and logs."
) from e

if model.provider_resource_id not in available_models:
raise ValueError(
f"Model {model.provider_resource_id} is not being served by vLLM. "
3 changes: 3 additions & 0 deletions llama_stack/providers/utils/inference/model_registry.py
@@ -100,6 +100,9 @@ async def list_models(self) -> list[Model] | None:
async def should_refresh_models(self) -> bool:
return False

async def enable_model_discovery(self) -> bool:
return True

def get_provider_model_id(self, identifier: str) -> str | None:
return self.alias_to_provider_id_map.get(identifier, None)

3 changes: 3 additions & 0 deletions llama_stack/providers/utils/inference/openai_mixin.py
@@ -425,3 +425,6 @@ async def check_model_availability(self, model: str) -> bool:

async def should_refresh_models(self) -> bool:
return False

async def enable_model_discovery(self) -> bool:
return True
3 changes: 3 additions & 0 deletions tests/unit/distribution/routers/test_routing_tables.py
@@ -52,6 +52,9 @@ async def unregister_model(self, model_id: str):
async def should_refresh_models(self):
return False

async def enable_model_discovery(self):
return True

async def list_models(self):
return [
Model(
86 changes: 75 additions & 11 deletions tests/unit/providers/inference/test_remote_vllm.py
@@ -636,27 +636,75 @@ async def test_should_refresh_models():
Test the should_refresh_models method with different refresh_models configurations.

This test verifies that:
1. When refresh_models is True, should_refresh_models returns True regardless of api_token
2. When refresh_models is False, should_refresh_models returns False regardless of api_token
1. When refresh_models is True, should_refresh_models returns True
2. When refresh_models is False, should_refresh_models returns False
"""

# Test case 1: refresh_models is True, api_token is None
config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=None, refresh_models=True)
# Test case 1: refresh_models is True
config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", refresh_models=True)
adapter1 = VLLMInferenceAdapter(config1)
result1 = await adapter1.should_refresh_models()
assert result1 is True, "should_refresh_models should return True when refresh_models is True"

# Test case 2: refresh_models is True, api_token is empty string
config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="", refresh_models=True)
# Test case 2: refresh_models is False
config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", refresh_models=False)
adapter2 = VLLMInferenceAdapter(config2)
result2 = await adapter2.should_refresh_models()
assert result2 is True, "should_refresh_models should return True when refresh_models is True"
assert result2 is False, "should_refresh_models should return False when refresh_models is False"

# Test case 3: refresh_models is True, api_token is "fake" (default)
config3 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="fake", refresh_models=True)

async def test_enable_model_discovery_flag():
"""
Test the enable_model_discovery flag functionality.

This test verifies that:
1. When enable_model_discovery is True (default), list_models returns models from the server
2. When enable_model_discovery is False, list_models returns None without calling the server
"""

# Test case 1: enable_model_discovery is True (default)
config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", enable_model_discovery=True)
adapter1 = VLLMInferenceAdapter(config1)
adapter1.__provider_id__ = "test-vllm"

# Mock the client.models.list() method
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
mock_client = MagicMock()

async def mock_models_list():
yield OpenAIModel(id="test-model-1", created=1, object="model", owned_by="test")
yield OpenAIModel(id="test-model-2", created=2, object="model", owned_by="test")

mock_client.models.list.return_value = mock_models_list()
mock_client_property.return_value = mock_client

models = await adapter1.list_models()
assert models is not None, "list_models should return models when enable_model_discovery is True"
assert len(models) == 2, "Should return 2 models"
assert models[0].identifier == "test-model-1"
assert models[1].identifier == "test-model-2"
mock_client.models.list.assert_called_once()

# Test case 2: enable_model_discovery is False
config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", enable_model_discovery=False)
adapter2 = VLLMInferenceAdapter(config2)
adapter2.__provider_id__ = "test-vllm"

# Mock the client.models.list() method
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
mock_client = MagicMock()
mock_client_property.return_value = mock_client

models = await adapter2.list_models()
assert models is None, "list_models should return None when enable_model_discovery is False"
mock_client.models.list.assert_not_called()

# Test case 3: enable_model_discovery defaults to True
config3 = VLLMInferenceAdapterConfig(url="http://test.localhost")
adapter3 = VLLMInferenceAdapter(config3)
result3 = await adapter3.should_refresh_models()
assert result3 is True, "should_refresh_models should return True when refresh_models is True"
adapter3.__provider_id__ = "test-vllm"
result3 = await adapter3.enable_model_discovery()
assert result3 is True, "enable_model_discovery should return True by default"

# Test case 4: refresh_models is True, api_token is real token
config4 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-123", refresh_models=True)
@@ -670,6 +718,22 @@ async def test_should_refresh_models():
result5 = await adapter5.should_refresh_models()
assert result5 is False, "should_refresh_models should return False when refresh_models is False"

# Mock the client.models.list() method
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
mock_client = MagicMock()

async def mock_models_list():
yield OpenAIModel(id="default-model", created=1, object="model", owned_by="test")

mock_client.models.list.return_value = mock_models_list()
mock_client_property.return_value = mock_client

models = await adapter3.list_models()
assert models is not None, "list_models should return models when enable_model_discovery defaults to True"
assert len(models) == 1, "Should return 1 model"
assert models[0].identifier == "default-model"
mock_client.models.list.assert_called_once()


async def test_provider_data_var_context_propagation(vllm_inference_adapter):
"""
1 change: 1 addition & 0 deletions tests/unit/server/test_access_control.py
@@ -32,6 +32,7 @@ async def test_setup(cached_disk_dist_registry):
mock_inference.__provider_spec__ = MagicMock()
mock_inference.__provider_spec__.api = Api.inference
mock_inference.register_model = AsyncMock(side_effect=_return_model)
mock_inference.enable_model_discovery = AsyncMock(return_value=True)
routing_table = ModelsRoutingTable(
impls_by_provider_id={"test_provider": mock_inference},
dist_registry=cached_disk_dist_registry,