diff --git a/docs/docs/providers/inference/remote_vllm.mdx b/docs/docs/providers/inference/remote_vllm.mdx
index 598f97b198..884ca8922e 100644
--- a/docs/docs/providers/inference/remote_vllm.mdx
+++ b/docs/docs/providers/inference/remote_vllm.mdx
@@ -20,6 +20,7 @@ Remote vLLM inference provider for connecting to vLLM servers.
 | `api_token` | `str \| None` | No | fake | The API token |
 | `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |
 | `refresh_models` | `` | No | False | Whether to refresh models periodically |
+| `enable_model_discovery` | `` | No | True | Whether to enable model discovery from the vLLM server |

 ## Sample Configuration

@@ -28,4 +29,5 @@ url: ${env.VLLM_URL:=}
 max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
 api_token: ${env.VLLM_API_TOKEN:=fake}
 tls_verify: ${env.VLLM_TLS_VERIFY:=true}
+enable_model_discovery: ${env.VLLM_ENABLE_MODEL_DISCOVERY:=true}
 ```
diff --git a/llama_stack/distributions/ci-tests/run.yaml b/llama_stack/distributions/ci-tests/run.yaml
index b14477a9ad..81c947f569 100644
--- a/llama_stack/distributions/ci-tests/run.yaml
+++ b/llama_stack/distributions/ci-tests/run.yaml
@@ -31,6 +31,7 @@ providers:
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
+      enable_model_discovery: ${env.VLLM_ENABLE_MODEL_DISCOVERY:=true}
   - provider_id: ${env.TGI_URL:+tgi}
     provider_type: remote::tgi
     config:
diff --git a/llama_stack/distributions/postgres-demo/run.yaml b/llama_stack/distributions/postgres-demo/run.yaml
index 0cf0e82e6a..98e784e764 100644
--- a/llama_stack/distributions/postgres-demo/run.yaml
+++ b/llama_stack/distributions/postgres-demo/run.yaml
@@ -16,6 +16,7 @@ providers:
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
+      enable_model_discovery: ${env.VLLM_ENABLE_MODEL_DISCOVERY:=true}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
   vector_io:
diff --git a/llama_stack/distributions/starter-gpu/run.yaml b/llama_stack/distributions/starter-gpu/run.yaml
index de5fe56811..187e3ccde4 100644
--- a/llama_stack/distributions/starter-gpu/run.yaml
+++ b/llama_stack/distributions/starter-gpu/run.yaml
@@ -31,6 +31,7 @@ providers:
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
+      enable_model_discovery: ${env.VLLM_ENABLE_MODEL_DISCOVERY:=true}
   - provider_id: ${env.TGI_URL:+tgi}
     provider_type: remote::tgi
     config:
diff --git a/llama_stack/distributions/starter/run.yaml b/llama_stack/distributions/starter/run.yaml
index c440e4e4b6..d02bd439d9 100644
--- a/llama_stack/distributions/starter/run.yaml
+++ b/llama_stack/distributions/starter/run.yaml
@@ -31,6 +31,7 @@ providers:
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
+      enable_model_discovery: ${env.VLLM_ENABLE_MODEL_DISCOVERY:=true}
   - provider_id: ${env.TGI_URL:+tgi}
     provider_type: remote::tgi
     config:
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index fd65fa10d3..f272040c00 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -71,6 +71,9 @@ async def openai_completion(self, *args, **kwargs):
     async def should_refresh_models(self) -> bool:
         return False

+    async def enable_model_discovery(self) -> bool:
+        return True
+
     async def list_models(self) -> list[Model] | None:
         return None

diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
index b984d97bf1..3dd5d2b899 100644
--- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
+++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
@@ -52,6 +52,9 @@ async def shutdown(self) -> None:
     async def should_refresh_models(self) -> bool:
         return False

+    async def enable_model_discovery(self) -> bool:
+        return True
+
     async def list_models(self) -> list[Model] | None:
         return [
             Model(
diff --git a/llama_stack/providers/remote/inference/vllm/config.py b/llama_stack/providers/remote/inference/vllm/config.py
index 86ef3fe268..3887107ddb 100644
--- a/llama_stack/providers/remote/inference/vllm/config.py
+++ b/llama_stack/providers/remote/inference/vllm/config.py
@@ -34,6 +34,10 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
         default=False,
         description="Whether to refresh models periodically",
     )
+    enable_model_discovery: bool = Field(
+        default=True,
+        description="Whether to enable model discovery from the vLLM server",
+    )

     @field_validator("tls_verify")
     @classmethod
@@ -59,4 +63,5 @@ def sample_run_config(
             "max_tokens": "${env.VLLM_MAX_TOKENS:=4096}",
             "api_token": "${env.VLLM_API_TOKEN:=fake}",
             "tls_verify": "${env.VLLM_TLS_VERIFY:=true}",
+            "enable_model_discovery": "${env.VLLM_ENABLE_MODEL_DISCOVERY:=true}",
         }
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 54ac8e1dc0..3b3edb5933 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -283,6 +283,11 @@ async def should_refresh_models(self) -> bool:
         return self.config.refresh_models

     async def list_models(self) -> list[Model] | None:
+        log.debug(f"VLLM list_models called, enable_model_discovery={self.config.enable_model_discovery}")
+        if not self.config.enable_model_discovery:
+            log.debug("VLLM list_models returning None due to enable_model_discovery=False")
+            return None
+
         models = []
         async for m in self.client.models.list():
             model_type = ModelType.llm  # unclear how to determine embedding vs. llm models
@@ -337,13 +342,43 @@ async def register_model(self, model: Model) -> Model:
             model = await self.register_helper.register_model(model)
         except ValueError:
             pass  # Ignore statically unknown model, will check live listing
+
+        # Check if provider enables model discovery before making HTTP request
+        if not self.config.enable_model_discovery:
+            log.debug("Model discovery disabled for vLLM: Trusting model exists")
+            # Warn if API key is set but model discovery is disabled
+            if self.config.api_token:
+                log.warning(
+                    "Model discovery is disabled but VLLM_API_TOKEN is set. "
+                    "If you're not using model discovery, you may not need to set the API token. "
+                    "Consider removing VLLM_API_TOKEN from your configuration or setting enable_model_discovery=true."
+                )
+            return model
+
         try:
             res = self.client.models.list()
         except APIConnectionError as e:
             raise ValueError(
                 f"Failed to connect to vLLM at {self.config.url}. Please check if vLLM is running and accessible at that URL."
             ) from e
-        available_models = [m.id async for m in res]
+
+        try:
+            available_models = [m.id async for m in res]
+        except Exception as e:
+            # Provide a helpful error message for model discovery failures
+            log.error(f"Model discovery failed with the following output from the vLLM server: {e}.\n")
+            log.error(
+                f"Model discovery failed: This typically occurs when a provider (like vLLM) is configured "
+                f"with model discovery enabled but the provider server doesn't support the /models endpoint. "
+                f"To resolve this, either:\n"
+                f"1. Check that {self.config.url} correctly points to the vLLM server, or\n"
+                f"2. Ensure your provider server supports the /v1/models endpoint and, if it requires authentication, that VLLM_API_TOKEN is set, or\n"
+                f"3. Set enable_model_discovery=false for the problematic provider in your configuration\n"
+            )
+            raise ValueError(
+                f"Model discovery failed for vLLM at {self.config.url}. Please check the server configuration and logs."
+            ) from e
+
         if model.provider_resource_id not in available_models:
             raise ValueError(
                 f"Model {model.provider_resource_id} is not being served by vLLM. "
diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py
index 4913c2e1fb..17c43fe388 100644
--- a/llama_stack/providers/utils/inference/model_registry.py
+++ b/llama_stack/providers/utils/inference/model_registry.py
@@ -100,6 +100,9 @@ async def list_models(self) -> list[Model] | None:
     async def should_refresh_models(self) -> bool:
         return False

+    async def enable_model_discovery(self) -> bool:
+        return True
+
     def get_provider_model_id(self, identifier: str) -> str | None:
         return self.alias_to_provider_id_map.get(identifier, None)

diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py
index 4354b067e6..a00a45963e 100644
--- a/llama_stack/providers/utils/inference/openai_mixin.py
+++ b/llama_stack/providers/utils/inference/openai_mixin.py
@@ -425,3 +425,6 @@ async def check_model_availability(self, model: str) -> bool:

     async def should_refresh_models(self) -> bool:
         return False
+
+    async def enable_model_discovery(self) -> bool:
+        return True
diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py
index 54a9dd72e2..ec7aca27e4 100644
--- a/tests/unit/distribution/routers/test_routing_tables.py
+++ b/tests/unit/distribution/routers/test_routing_tables.py
@@ -52,6 +52,9 @@ async def unregister_model(self, model_id: str):
     async def should_refresh_models(self):
         return False

+    async def enable_model_discovery(self):
+        return True
+
     async def list_models(self):
         return [
             Model(
diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py
index cd31e4943d..701282179f 100644
--- a/tests/unit/providers/inference/test_remote_vllm.py
+++ b/tests/unit/providers/inference/test_remote_vllm.py
@@ -636,27 +636,75 @@ async def test_should_refresh_models():
     Test the should_refresh_models method with different refresh_models configurations.

     This test verifies that:
-    1. When refresh_models is True, should_refresh_models returns True regardless of api_token
-    2. When refresh_models is False, should_refresh_models returns False regardless of api_token
+    1. When refresh_models is True, should_refresh_models returns True
+    2. When refresh_models is False, should_refresh_models returns False
     """

-    # Test case 1: refresh_models is True, api_token is None
-    config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=None, refresh_models=True)
+    # Test case 1: refresh_models is True
+    config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", refresh_models=True)
     adapter1 = VLLMInferenceAdapter(config1)
     result1 = await adapter1.should_refresh_models()
     assert result1 is True, "should_refresh_models should return True when refresh_models is True"

-    # Test case 2: refresh_models is True, api_token is empty string
-    config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="", refresh_models=True)
+    # Test case 2: refresh_models is False
+    config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", refresh_models=False)
     adapter2 = VLLMInferenceAdapter(config2)
     result2 = await adapter2.should_refresh_models()
-    assert result2 is True, "should_refresh_models should return True when refresh_models is True"
+    assert result2 is False, "should_refresh_models should return False when refresh_models is False"

-    # Test case 3: refresh_models is True, api_token is "fake" (default)
-    config3 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="fake", refresh_models=True)
+
+async def test_enable_model_discovery_flag():
+    """
+    Test the enable_model_discovery flag functionality.
+
+    This test verifies that:
+    1. When enable_model_discovery is True (default), list_models returns models from the server
+    2. When enable_model_discovery is False, list_models returns None without calling the server
+    """
+
+    # Test case 1: enable_model_discovery is True (default)
+    config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", enable_model_discovery=True)
+    adapter1 = VLLMInferenceAdapter(config1)
+    adapter1.__provider_id__ = "test-vllm"
+
+    # Mock the client.models.list() method
+    with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
+        mock_client = MagicMock()
+
+        async def mock_models_list():
+            yield OpenAIModel(id="test-model-1", created=1, object="model", owned_by="test")
+            yield OpenAIModel(id="test-model-2", created=2, object="model", owned_by="test")
+
+        mock_client.models.list.return_value = mock_models_list()
+        mock_client_property.return_value = mock_client
+
+        models = await adapter1.list_models()
+        assert models is not None, "list_models should return models when enable_model_discovery is True"
+        assert len(models) == 2, "Should return 2 models"
+        assert models[0].identifier == "test-model-1"
+        assert models[1].identifier == "test-model-2"
+        mock_client.models.list.assert_called_once()
+
+    # Test case 2: enable_model_discovery is False
+    config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", enable_model_discovery=False)
+    adapter2 = VLLMInferenceAdapter(config2)
+    adapter2.__provider_id__ = "test-vllm"
+
+    # Mock the client.models.list() method
+    with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
+        mock_client = MagicMock()
+        mock_client_property.return_value = mock_client
+
+        models = await adapter2.list_models()
+        assert models is None, "list_models should return None when enable_model_discovery is False"
+        mock_client.models.list.assert_not_called()
+
+    # Test case 3: enable_model_discovery defaults to True
+    config3 = VLLMInferenceAdapterConfig(url="http://test.localhost")
     adapter3 = VLLMInferenceAdapter(config3)
-    result3 = await adapter3.should_refresh_models()
-    assert result3 is True, "should_refresh_models should return True when refresh_models is True"
+    adapter3.__provider_id__ = "test-vllm"
+    result3 = await adapter3.enable_model_discovery()
+    assert result3 is True, "enable_model_discovery should return True by default"

     # Test case 4: refresh_models is True, api_token is real token
     config4 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-123", refresh_models=True)
@@ -670,6 +718,22 @@ async def test_should_refresh_models():
     result5 = await adapter5.should_refresh_models()
     assert result5 is False, "should_refresh_models should return False when refresh_models is False"

+    # Mock the client.models.list() method
+    with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
+        mock_client = MagicMock()
+
+        async def mock_models_list():
+            yield OpenAIModel(id="default-model", created=1, object="model", owned_by="test")
+
+        mock_client.models.list.return_value = mock_models_list()
+        mock_client_property.return_value = mock_client
+
+        models = await adapter3.list_models()
+        assert models is not None, "list_models should return models when enable_model_discovery defaults to True"
+        assert len(models) == 1, "Should return 1 model"
+        assert models[0].identifier == "default-model"
+        mock_client.models.list.assert_called_once()
+

 async def test_provider_data_var_context_propagation(vllm_inference_adapter):
     """
diff --git a/tests/unit/server/test_access_control.py b/tests/unit/server/test_access_control.py
index 55449804a0..3cb393e0e9 100644
--- a/tests/unit/server/test_access_control.py
+++ b/tests/unit/server/test_access_control.py
@@ -32,6 +32,7 @@ async def test_setup(cached_disk_dist_registry):
     mock_inference.__provider_spec__ = MagicMock()
     mock_inference.__provider_spec__.api = Api.inference
     mock_inference.register_model = AsyncMock(side_effect=_return_model)
+    mock_inference.enable_model_discovery = AsyncMock(return_value=True)
     routing_table = ModelsRoutingTable(
         impls_by_provider_id={"test_provider": mock_inference},
         dist_registry=cached_disk_dist_registry,
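
For context, a `run.yaml` provider entry that opts out of discovery might look like the sketch below. It mirrors the sample configuration from the docs hunk above; the `provider_id` and the URL default are illustrative placeholders, not part of this change.

```yaml
providers:
  inference:
  - provider_id: vllm                                  # illustrative id, not from this change
    provider_type: remote::vllm
    config:
      url: ${env.VLLM_URL:=http://localhost:8000/v1}   # assumed local vLLM endpoint
      max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
      api_token: ${env.VLLM_API_TOKEN:=fake}
      tls_verify: ${env.VLLM_TLS_VERIFY:=true}
      # New flag from this change; can also be set via VLLM_ENABLE_MODEL_DISCOVERY
      enable_model_discovery: ${env.VLLM_ENABLE_MODEL_DISCOVERY:=false}
```

With the flag set to `false`, `list_models` returns `None` and `register_model` skips the `/v1/models` lookup and trusts the configured model id, per the `vllm.py` hunk above.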