diff --git a/README.md b/README.md index 8367552..e501528 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,16 @@ export REPLICATE_API_TOKEN= We recommend not adding the token directly to your source code, because you don't want to put your credentials in source control. If anyone used your API key, their usage would be charged to your account. +
+ +Alternative authentication + +As of [replicate 1.0.5](https://github.com/replicate/replicate-python/releases/tag/1.0.5) and [cog 0.14.11](https://github.com/replicate/cog/releases/tag/v0.14.11) it is possible to pass a `REPLICATE_API_TOKEN` via the `context` as part of a prediction request. + +The `Replicate()` constructor will now use this context when available. This grants cog models the ability to use the Replicate client libraries, scoped to a user on a per request basis. + +
+ ## Run a model Create a new Python file and add the following code, replacing the model identifier and input with your own: diff --git a/replicate/client.py b/replicate/client.py index 3e767d6..6a79813 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -348,6 +348,20 @@ def close(self) -> None: self._wrapped_transport.close() # type: ignore +def _get_api_token_from_environment() -> Optional[str]: + """Get API token from cog current scope if available, otherwise from environment.""" + try: + import cog # noqa: I001 # pyright: ignore [reportMissingImports] + + for key, value in cog.current_scope().context.items(): + if key.upper() == "REPLICATE_API_TOKEN": + return value + except: # noqa: S110,E722,BLE001 we don't want this code to cause clients to fail + pass + + return os.environ.get("REPLICATE_API_TOKEN") + + def _build_httpx_client( client_type: Type[Union[httpx.Client, httpx.AsyncClient]], api_token: Optional[str] = None, @@ -359,7 +373,7 @@ def _build_httpx_client( if "User-Agent" not in headers: headers["User-Agent"] = f"replicate-python/{__version__}" if "Authorization" not in headers and ( - api_token := api_token or os.environ.get("REPLICATE_API_TOKEN") + api_token := api_token or _get_api_token_from_environment() ): headers["Authorization"] = f"Bearer {api_token}" diff --git a/replicate/helpers.py b/replicate/helpers.py index c6ac907..62c7d67 100644 --- a/replicate/helpers.py +++ b/replicate/helpers.py @@ -44,8 +44,7 @@ def encode_json( if isinstance(obj, io.IOBase): if file_encoding_strategy == "base64": return base64_encode_file(obj) - else: - return client.files.create(obj).urls["get"] + return client.files.create(obj).urls["get"] if HAS_NUMPY: if isinstance(obj, np.integer): # type: ignore return int(obj) @@ -82,8 +81,7 @@ async def async_encode_json( if file_encoding_strategy == "base64": # TODO: This should ideally use an async based file reader path. return base64_encode_file(obj) - else: - return (await client.files.async_create(obj)).urls["get"] + return (await client.files.async_create(obj)).urls["get"] if HAS_NUMPY: if isinstance(obj, np.integer): # type: ignore return int(obj) @@ -183,9 +181,9 @@ def transform_output(value: Any, client: "Client") -> Any: def transform(obj: Any) -> Any: if isinstance(obj, Mapping): return {k: transform(v) for k, v in obj.items()} - elif isinstance(obj, Sequence) and not isinstance(obj, str): + if isinstance(obj, Sequence) and not isinstance(obj, str): return [transform(item) for item in obj] - elif isinstance(obj, str) and ( + if isinstance(obj, str) and ( obj.startswith("https:") or obj.startswith("data:") ): return FileOutput(obj, client) diff --git a/replicate/run.py b/replicate/run.py index 19db492..e82ffb4 100644 --- a/replicate/run.py +++ b/replicate/run.py @@ -38,7 +38,7 @@ def run( if "wait" not in params: params["wait"] = True - is_blocking = params["wait"] != False # noqa: E712 + is_blocking = params["wait"] is not False version, owner, name, version_id = identifier._resolve(ref) @@ -108,7 +108,7 @@ async def async_run( if "wait" not in params: params["wait"] = True - is_blocking = params["wait"] != False # noqa: E712 + is_blocking = params["wait"] is not False version, owner, name, version_id = identifier._resolve(ref) diff --git a/tests/test_client.py b/tests/test_client.py index 2c585b1..6ba6aea 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,10 +1,13 @@ import os +import sys from unittest import mock import httpx import pytest import respx +from replicate.client import _get_api_token_from_environment + @pytest.mark.asyncio async def test_authorization_when_setting_environ_after_import(): @@ -114,3 +117,159 @@ def mock_send(request): pass mock_send_wrapper.assert_called_once() + + +class TestGetApiToken: + """Test cases for _get_api_token_from_environment function covering all import paths.""" + + def test_cog_not_available_falls_back_to_env(self): + """Test fallback to environment when cog package is not available.""" + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": None}): + token = _get_api_token_from_environment() + assert token == "env-token" # noqa: S105 + + def test_cog_import_error_falls_back_to_env(self): + """Test fallback to environment when cog import raises exception.""" + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch( + "builtins.__import__", + side_effect=ModuleNotFoundError("No module named 'cog'"), + ): + token = _get_api_token_from_environment() + assert token == "env-token" # noqa: S105 + + def test_cog_no_current_scope_method_falls_back_to_env(self): + """Test fallback when cog exists but has no current_scope method.""" + mock_cog = mock.MagicMock() + del mock_cog.current_scope # Remove the method + + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + token = _get_api_token_from_environment() + assert token == "env-token" # noqa: S105 + + def test_cog_current_scope_returns_none_falls_back_to_env(self): + """Test fallback when current_scope() returns None.""" + mock_cog = mock.MagicMock() + mock_cog.current_scope.return_value = None + + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + token = _get_api_token_from_environment() + assert token == "env-token" # noqa: S105 + + def test_cog_scope_no_context_attr_falls_back_to_env(self): + """Test fallback when scope has no context attribute.""" + mock_scope = mock.MagicMock() + del mock_scope.context # Remove the context attribute + + mock_cog = mock.MagicMock() + mock_cog.current_scope.return_value = mock_scope + + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + token = _get_api_token_from_environment() + assert token == "env-token" # noqa: S105 + + def test_cog_scope_context_not_dict_falls_back_to_env(self): + """Test fallback when scope.context is not a dictionary.""" + mock_scope = mock.MagicMock() + mock_scope.context = "not a dict" + + mock_cog = mock.MagicMock() + mock_cog.current_scope.return_value = mock_scope + + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + token = _get_api_token_from_environment() + assert token == "env-token" # noqa: S105 + + def test_cog_scope_no_replicate_api_token_key_falls_back_to_env(self): + """Test fallback when replicate_api_token key is missing from context.""" + mock_scope = mock.MagicMock() + mock_scope.context = {"other_key": "other_value"} # Missing replicate_api_token + + mock_cog = mock.MagicMock() + mock_cog.current_scope.return_value = mock_scope + + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + token = _get_api_token_from_environment() + assert token == "env-token" # noqa: S105 + + def test_cog_scope_replicate_api_token_valid_string(self): + """Test successful retrieval of non-empty token from cog.""" + mock_scope = mock.MagicMock() + mock_scope.context = {"REPLICATE_API_TOKEN": "cog-token"} + + mock_cog = mock.MagicMock() + mock_cog.current_scope.return_value = mock_scope + + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + token = _get_api_token_from_environment() + assert token == "cog-token" # noqa: S105 + + def test_cog_scope_replicate_api_token_case_insensitive(self): + """Test successful retrieval of non-empty token from cog ignoring case.""" + mock_scope = mock.MagicMock() + mock_scope.context = {"replicate_api_token": "cog-token"} + + mock_cog = mock.MagicMock() + mock_cog.current_scope.return_value = mock_scope + + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + token = _get_api_token_from_environment() + assert token == "cog-token" # noqa: S105 + + def test_cog_scope_replicate_api_token_empty_string(self): + """Test that empty string from cog is returned (not falling back to env).""" + mock_scope = mock.MagicMock() + mock_scope.context = {"replicate_api_token": ""} # Empty string + + mock_cog = mock.MagicMock() + mock_cog.current_scope.return_value = mock_scope + + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + token = _get_api_token_from_environment() + assert token == "" # Should return empty string, not env token + + def test_cog_scope_replicate_api_token_none(self): + """Test that None from cog is returned (not falling back to env).""" + mock_scope = mock.MagicMock() + mock_scope.context = {"replicate_api_token": None} + + mock_cog = mock.MagicMock() + mock_cog.current_scope.return_value = mock_scope + + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + token = _get_api_token_from_environment() + assert token is None # Should return None, not env token + + def test_cog_current_scope_raises_exception_falls_back_to_env(self): + """Test fallback when current_scope() raises an exception.""" + mock_cog = mock.MagicMock() + mock_cog.current_scope.side_effect = RuntimeError("Scope error") + + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + token = _get_api_token_from_environment() + assert token == "env-token" # noqa: S105 + + def test_no_env_token_returns_none(self): + """Test that None is returned when no environment token is set and cog unavailable.""" + with mock.patch.dict(os.environ, {}, clear=True): # Clear all env vars + with mock.patch.dict(sys.modules, {"cog": None}): + token = _get_api_token_from_environment() + assert token is None + + def test_env_token_empty_string(self): + """Test that empty string from environment is returned.""" + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": ""}): + with mock.patch.dict(sys.modules, {"cog": None}): + token = _get_api_token_from_environment() + assert token == ""