diff --git a/.changeset/nostalgic-tireless-kestrel.md b/.changeset/nostalgic-tireless-kestrel.md new file mode 100644 index 00000000..e7bd4fb9 --- /dev/null +++ b/.changeset/nostalgic-tireless-kestrel.md @@ -0,0 +1,5 @@ +--- +"stagehand": patch +--- + +Add LLM customization support (eg. api_base) diff --git a/stagehand/config.py b/stagehand/config.py index a577230d..8805e40b 100644 --- a/stagehand/config.py +++ b/stagehand/config.py @@ -20,6 +20,7 @@ class StagehandConfig(BaseModel): browserbase_session_id (Optional[str]): Session ID for resuming Browserbase sessions. model_name (Optional[str]): Name of the model to use. model_api_key (Optional[str]): Model API key. + model_client_options (Optional[dict[str, Any]]): Options for the model client. logger (Optional[Callable[[Any], None]]): Custom logging function. verbose (Optional[int]): Verbosity level for logs (1=minimal, 2=medium, 3=detailed). use_rich_logging (bool): Whether to use Rich for colorized logging. @@ -50,6 +51,11 @@ class StagehandConfig(BaseModel): model_api_key: Optional[str] = Field( None, alias="modelApiKey", description="Model API key" ) + model_client_options: Optional[dict[str, Any]] = Field( + None, + alias="modelClientOptions", + description="Configuration options for the language model client (i.e. api_base)", + ) verbose: Optional[int] = Field( 1, description="Verbosity level for logs: 0=minimal (ERROR), 1=medium (INFO), 2=detailed (DEBUG)", diff --git a/stagehand/llm/client.py b/stagehand/llm/client.py index ac0ead4a..e9fbefe5 100644 --- a/stagehand/llm/client.py +++ b/stagehand/llm/client.py @@ -54,7 +54,7 @@ def __init__( setattr(litellm, key, value) self.logger.debug(f"Set global litellm.{key}", category="llm") # Handle common aliases or expected config names if necessary - elif key == "api_base": # Example: map api_base if needed + elif key == "api_base" or key == "baseURL": litellm.api_base = value self.logger.debug( f"Set global litellm.api_base to {value}", category="llm" diff --git a/stagehand/main.py b/stagehand/main.py index 4a201adb..03d6d45f 100644 --- a/stagehand/main.py +++ b/stagehand/main.py @@ -161,7 +161,11 @@ def __init__( # Handle non-config parameters self.api_url = self.config.api_url + + # Handle model-related settings + self.model_client_options = self.config.model_client_options or {} self.model_api_key = self.config.model_api_key or os.getenv("MODEL_API_KEY") + self.model_name = self.config.model_name # Extract frequently used values from config for convenience @@ -181,11 +185,11 @@ def __init__( self.local_browser_launch_options = ( self.config.local_browser_launch_options or {} ) - - # Handle model-related settings - self.model_client_options = {} - if self.model_api_key and "apiKey" not in self.model_client_options: + if self.model_api_key: self.model_client_options["apiKey"] = self.model_api_key + else: + if "apiKey" in self.model_client_options: + self.model_api_key = self.model_client_options["apiKey"] # Handle browserbase session create params self.browserbase_session_create_params = make_serializable( diff --git a/tests/unit/llm/test_llm_integration.py b/tests/unit/llm/test_llm_integration.py index a01e7a73..00d09c40 100644 --- a/tests/unit/llm/test_llm_integration.py +++ b/tests/unit/llm/test_llm_integration.py @@ -40,6 +40,7 @@ def test_llm_client_with_custom_options(self): api_key="test-key", default_model="gpt-4o-mini", stagehand_logger=StagehandLogger(), + api_base="https://test-api-base.com", ) assert client.default_model == "gpt-4o-mini" diff --git a/tests/unit/test_client_initialization.py b/tests/unit/test_client_initialization.py index 237ada9b..afec5c6b 100644 --- a/tests/unit/test_client_initialization.py +++ b/tests/unit/test_client_initialization.py @@ -228,3 +228,32 @@ async def mock_create_session(): # Call _create_session and expect error with pytest.raises(RuntimeError, match="Invalid response format"): await client._create_session() + + @mock.patch.dict(os.environ, {"MODEL_API_KEY": "test-model-api-key"}, clear=True) + def test_init_with_model_api_key_in_env(self): + config = StagehandConfig(env="LOCAL") + client = Stagehand(config=config) + assert client.model_api_key == "test-model-api-key" + + def test_init_with_custom_llm(self): + config = StagehandConfig( + env="LOCAL", + model_client_options={"apiKey": "custom-llm-key", "baseURL": "https://custom-llm.com"} + ) + client = Stagehand(config=config) + assert client.model_api_key == "custom-llm-key" + assert client.model_client_options["apiKey"] == "custom-llm-key" + assert client.model_client_options["baseURL"] == "https://custom-llm.com" + + def test_init_with_custom_llm_override(self): + config = StagehandConfig( + env="LOCAL", + model_client_options={"apiKey": "custom-llm-key", "baseURL": "https://custom-llm.com"} + ) + client = Stagehand( + config=config, + model_client_options={"apiKey": "override-llm-key", "baseURL": "https://override-llm.com"} + ) + assert client.model_api_key == "override-llm-key" + assert client.model_client_options["apiKey"] == "override-llm-key" + assert client.model_client_options["baseURL"] == "https://override-llm.com" \ No newline at end of file