From 5476623e205c3624be129db63b4d8e3b375544ac Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Fri, 23 May 2025 15:41:46 +0100 Subject: [PATCH 1/6] Support getting REPLICATE_API_TOKEN from cog context This commit introduces support for the cog context into the Replicate SDK. The `current_scope` helper now makes per-prediction context available via the `current_scope().context` dict. A cog model can then provide a REPLICATE_API_TOKEN on a per-prediction basis to be used by the model. def predict(prompt: str) -> str: replicate = Replicate() output = replicate.run("anthropic/claude-3.5-haiku", {input: {"prompt": "prompt"}}) return output --- replicate/client.py | 18 +++++- tests/test_client.py | 146 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 163 insertions(+), 1 deletion(-) diff --git a/replicate/client.py b/replicate/client.py index 3e767d6..164a5d9 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -348,6 +348,22 @@ def close(self) -> None: self._wrapped_transport.close() # type: ignore +def _get_api_token_from_environment() -> Optional[str]: + """Get API token from cog current scope if available, otherwise from environment.""" + try: + import cog + + if hasattr(cog, "current_scope"): + scope = cog.current_scope() + if scope and hasattr(scope, "content") and isinstance(scope.content, dict): + if "replicate_api_token" in scope.content: + return scope.content["replicate_api_token"] + except (ImportError, AttributeError, Exception): + pass + + return os.environ.get("REPLICATE_API_TOKEN") + + def _build_httpx_client( client_type: Type[Union[httpx.Client, httpx.AsyncClient]], api_token: Optional[str] = None, @@ -359,7 +375,7 @@ def _build_httpx_client( if "User-Agent" not in headers: headers["User-Agent"] = f"replicate-python/{__version__}" if "Authorization" not in headers and ( - api_token := api_token or os.environ.get("REPLICATE_API_TOKEN") + api_token := api_token or _get_api_token_from_environment() ): headers["Authorization"] = f"Bearer {api_token}" diff --git a/tests/test_client.py b/tests/test_client.py index 2c585b1..be0469e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,10 +1,13 @@ import os +import sys from unittest import mock import httpx import pytest import respx +from replicate.client import _get_api_token_from_environment + @pytest.mark.asyncio async def test_authorization_when_setting_environ_after_import(): @@ -114,3 +117,146 @@ def mock_send(request): pass mock_send_wrapper.assert_called_once() + + +class TestGetApiToken: + """Test cases for _get_api_token_from_environment function covering all import paths.""" + + def test_cog_not_available_falls_back_to_env(self): + """Test fallback to environment when cog package is not available.""" + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": None}): + token = _get_api_token_from_environment() + assert token == "env-token" + + def test_cog_import_error_falls_back_to_env(self): + """Test fallback to environment when cog import raises exception.""" + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch( + "builtins.__import__", + side_effect=ModuleNotFoundError("No module named 'cog'"), + ): + token = _get_api_token_from_environment() + assert token == "env-token" + + def test_cog_no_current_scope_method_falls_back_to_env(self): + """Test fallback when cog exists but has no current_scope method.""" + mock_cog = mock.MagicMock() + del mock_cog.current_scope # Remove the method + + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + token = _get_api_token_from_environment() + assert token == "env-token" + + def test_cog_current_scope_returns_none_falls_back_to_env(self): + """Test fallback when current_scope() returns None.""" + mock_cog = mock.MagicMock() + mock_cog.current_scope.return_value = None + + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + token = _get_api_token_from_environment() + assert token == "env-token" + + def test_cog_scope_no_content_attr_falls_back_to_env(self): + """Test fallback when scope has no content attribute.""" + mock_scope = mock.MagicMock() + del mock_scope.content # Remove the content attribute + + mock_cog = mock.MagicMock() + mock_cog.current_scope.return_value = mock_scope + + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + token = _get_api_token_from_environment() + assert token == "env-token" + + def test_cog_scope_content_not_dict_falls_back_to_env(self): + """Test fallback when scope.content is not a dictionary.""" + mock_scope = mock.MagicMock() + mock_scope.content = "not a dict" + + mock_cog = mock.MagicMock() + mock_cog.current_scope.return_value = mock_scope + + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + token = _get_api_token_from_environment() + assert token == "env-token" + + def test_cog_scope_no_replicate_api_token_key_falls_back_to_env(self): + """Test fallback when replicate_api_token key is missing from content.""" + mock_scope = mock.MagicMock() + mock_scope.content = {"other_key": "other_value"} # Missing replicate_api_token + + mock_cog = mock.MagicMock() + mock_cog.current_scope.return_value = mock_scope + + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + token = _get_api_token_from_environment() + assert token == "env-token" + + def test_cog_scope_replicate_api_token_valid_string(self): + """Test successful retrieval of non-empty token from cog.""" + mock_scope = mock.MagicMock() + mock_scope.content = {"replicate_api_token": "cog-token"} + + mock_cog = mock.MagicMock() + mock_cog.current_scope.return_value = mock_scope + + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + token = _get_api_token_from_environment() + assert token == "cog-token" + + def test_cog_scope_replicate_api_token_empty_string(self): + """Test that empty string from cog is returned (not falling back to env).""" + mock_scope = mock.MagicMock() + mock_scope.content = {"replicate_api_token": ""} # Empty string + + mock_cog = mock.MagicMock() + mock_cog.current_scope.return_value = mock_scope + + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + token = _get_api_token_from_environment() + assert token == "" # Should return empty string, not env token + + def test_cog_scope_replicate_api_token_none(self): + """Test that None from cog is returned (not falling back to env).""" + mock_scope = mock.MagicMock() + mock_scope.content = {"replicate_api_token": None} + + mock_cog = mock.MagicMock() + mock_cog.current_scope.return_value = mock_scope + + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + token = _get_api_token_from_environment() + assert token is None # Should return None, not env token + + def test_cog_current_scope_raises_exception_falls_back_to_env(self): + """Test fallback when current_scope() raises an exception.""" + mock_cog = mock.MagicMock() + mock_cog.current_scope.side_effect = RuntimeError("Scope error") + + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + token = _get_api_token_from_environment() + assert token == "env-token" + + def test_no_env_token_returns_none(self): + """Test that None is returned when no environment token is set and cog unavailable.""" + with mock.patch.dict(os.environ, {}, clear=True): # Clear all env vars + with mock.patch.dict(sys.modules, {"cog": None}): + token = _get_api_token_from_environment() + assert token is None + + def test_env_token_empty_string(self): + """Test that empty string from environment is returned.""" + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": ""}): + with mock.patch.dict(sys.modules, {"cog": None}): + token = _get_api_token_from_environment() + assert token == "" From 80f89c973eaa8dad5557eb9f2765f06c5da17e53 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Fri, 23 May 2025 16:00:58 +0100 Subject: [PATCH 2/6] Document alternative authentication --- README.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/README.md b/README.md index 8367552..e501528 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,16 @@ export REPLICATE_API_TOKEN= We recommend not adding the token directly to your source code, because you don't want to put your credentials in source control. If anyone used your API key, their usage would be charged to your account. +
+ +Alternative authentication + +As of [replicate 1.0.5](https://github.com/replicate/replicate-python/releases/tag/1.0.5) and [cog 0.14.11](https://github.com/replicate/cog/releases/tag/v0.14.11) it is possible to pass a `REPLICATE_API_TOKEN` via the `context` as part of a prediction request. + +The `Replicate()` constructor will now use this context when available. This grants cog models the ability to use the Replicate client libraries, scoped to a user on a per request basis. + +
+ ## Run a model Create a new Python file and add the following code, replacing the model identifier and input with your own: From 903e6afd59a253539608e1d531ec1a1f3715d298 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Fri, 23 May 2025 16:25:09 +0100 Subject: [PATCH 3/6] Make scope lookup case insensitive --- replicate/client.py | 9 ++++---- tests/test_client.py | 53 +++++++++++++++++++++++++++----------------- 2 files changed, 38 insertions(+), 24 deletions(-) diff --git a/replicate/client.py b/replicate/client.py index 164a5d9..fcff0cc 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -355,10 +355,11 @@ def _get_api_token_from_environment() -> Optional[str]: if hasattr(cog, "current_scope"): scope = cog.current_scope() - if scope and hasattr(scope, "content") and isinstance(scope.content, dict): - if "replicate_api_token" in scope.content: - return scope.content["replicate_api_token"] - except (ImportError, AttributeError, Exception): + if scope and hasattr(scope, "context") and isinstance(scope.context, dict): + for key, value in scope.context.items(): + if key.upper() == "REPLICATE_API_TOKEN": + return scope.context[key] + except: # noqa: S110,E722,BLE001 we don't want this code to cause clients to fail pass return os.environ.get("REPLICATE_API_TOKEN") diff --git a/tests/test_client.py b/tests/test_client.py index be0469e..6ba6aea 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -127,7 +127,7 @@ def test_cog_not_available_falls_back_to_env(self): with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): with mock.patch.dict(sys.modules, {"cog": None}): token = _get_api_token_from_environment() - assert token == "env-token" + assert token == "env-token" # noqa: S105 def test_cog_import_error_falls_back_to_env(self): """Test fallback to environment when cog import raises exception.""" @@ -137,7 +137,7 @@ def test_cog_import_error_falls_back_to_env(self): side_effect=ModuleNotFoundError("No module named 'cog'"), ): token = _get_api_token_from_environment() - assert token == "env-token" + assert token == "env-token" # noqa: S105 def test_cog_no_current_scope_method_falls_back_to_env(self): """Test fallback when cog exists but has no current_scope method.""" @@ -147,7 +147,7 @@ def test_cog_no_current_scope_method_falls_back_to_env(self): with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): with mock.patch.dict(sys.modules, {"cog": mock_cog}): token = _get_api_token_from_environment() - assert token == "env-token" + assert token == "env-token" # noqa: S105 def test_cog_current_scope_returns_none_falls_back_to_env(self): """Test fallback when current_scope() returns None.""" @@ -157,12 +157,12 @@ def test_cog_current_scope_returns_none_falls_back_to_env(self): with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): with mock.patch.dict(sys.modules, {"cog": mock_cog}): token = _get_api_token_from_environment() - assert token == "env-token" + assert token == "env-token" # noqa: S105 - def test_cog_scope_no_content_attr_falls_back_to_env(self): - """Test fallback when scope has no content attribute.""" + def test_cog_scope_no_context_attr_falls_back_to_env(self): + """Test fallback when scope has no context attribute.""" mock_scope = mock.MagicMock() - del mock_scope.content # Remove the content attribute + del mock_scope.context # Remove the context attribute mock_cog = mock.MagicMock() mock_cog.current_scope.return_value = mock_scope @@ -170,12 +170,12 @@ def test_cog_scope_no_content_attr_falls_back_to_env(self): with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): with mock.patch.dict(sys.modules, {"cog": mock_cog}): token = _get_api_token_from_environment() - assert token == "env-token" + assert token == "env-token" # noqa: S105 - def test_cog_scope_content_not_dict_falls_back_to_env(self): - """Test fallback when scope.content is not a dictionary.""" + def test_cog_scope_context_not_dict_falls_back_to_env(self): + """Test fallback when scope.context is not a dictionary.""" mock_scope = mock.MagicMock() - mock_scope.content = "not a dict" + mock_scope.context = "not a dict" mock_cog = mock.MagicMock() mock_cog.current_scope.return_value = mock_scope @@ -183,12 +183,12 @@ def test_cog_scope_content_not_dict_falls_back_to_env(self): with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): with mock.patch.dict(sys.modules, {"cog": mock_cog}): token = _get_api_token_from_environment() - assert token == "env-token" + assert token == "env-token" # noqa: S105 def test_cog_scope_no_replicate_api_token_key_falls_back_to_env(self): - """Test fallback when replicate_api_token key is missing from content.""" + """Test fallback when replicate_api_token key is missing from context.""" mock_scope = mock.MagicMock() - mock_scope.content = {"other_key": "other_value"} # Missing replicate_api_token + mock_scope.context = {"other_key": "other_value"} # Missing replicate_api_token mock_cog = mock.MagicMock() mock_cog.current_scope.return_value = mock_scope @@ -196,12 +196,12 @@ def test_cog_scope_no_replicate_api_token_key_falls_back_to_env(self): with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): with mock.patch.dict(sys.modules, {"cog": mock_cog}): token = _get_api_token_from_environment() - assert token == "env-token" + assert token == "env-token" # noqa: S105 def test_cog_scope_replicate_api_token_valid_string(self): """Test successful retrieval of non-empty token from cog.""" mock_scope = mock.MagicMock() - mock_scope.content = {"replicate_api_token": "cog-token"} + mock_scope.context = {"REPLICATE_API_TOKEN": "cog-token"} mock_cog = mock.MagicMock() mock_cog.current_scope.return_value = mock_scope @@ -209,12 +209,25 @@ def test_cog_scope_replicate_api_token_valid_string(self): with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): with mock.patch.dict(sys.modules, {"cog": mock_cog}): token = _get_api_token_from_environment() - assert token == "cog-token" + assert token == "cog-token" # noqa: S105 + + def test_cog_scope_replicate_api_token_case_insensitive(self): + """Test successful retrieval of non-empty token from cog ignoring case.""" + mock_scope = mock.MagicMock() + mock_scope.context = {"replicate_api_token": "cog-token"} + + mock_cog = mock.MagicMock() + mock_cog.current_scope.return_value = mock_scope + + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + token = _get_api_token_from_environment() + assert token == "cog-token" # noqa: S105 def test_cog_scope_replicate_api_token_empty_string(self): """Test that empty string from cog is returned (not falling back to env).""" mock_scope = mock.MagicMock() - mock_scope.content = {"replicate_api_token": ""} # Empty string + mock_scope.context = {"replicate_api_token": ""} # Empty string mock_cog = mock.MagicMock() mock_cog.current_scope.return_value = mock_scope @@ -227,7 +240,7 @@ def test_cog_scope_replicate_api_token_empty_string(self): def test_cog_scope_replicate_api_token_none(self): """Test that None from cog is returned (not falling back to env).""" mock_scope = mock.MagicMock() - mock_scope.content = {"replicate_api_token": None} + mock_scope.context = {"replicate_api_token": None} mock_cog = mock.MagicMock() mock_cog.current_scope.return_value = mock_scope @@ -245,7 +258,7 @@ def test_cog_current_scope_raises_exception_falls_back_to_env(self): with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): with mock.patch.dict(sys.modules, {"cog": mock_cog}): token = _get_api_token_from_environment() - assert token == "env-token" + assert token == "env-token" # noqa: S105 def test_no_env_token_returns_none(self): """Test that None is returned when no environment token is set and cog unavailable.""" From e358745fde5139b0fb148bebb14a3541af0d4e76 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Fri, 23 May 2025 16:25:13 +0100 Subject: [PATCH 4/6] Linting --- replicate/helpers.py | 10 ++++------ replicate/run.py | 4 ++-- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/replicate/helpers.py b/replicate/helpers.py index c6ac907..62c7d67 100644 --- a/replicate/helpers.py +++ b/replicate/helpers.py @@ -44,8 +44,7 @@ def encode_json( if isinstance(obj, io.IOBase): if file_encoding_strategy == "base64": return base64_encode_file(obj) - else: - return client.files.create(obj).urls["get"] + return client.files.create(obj).urls["get"] if HAS_NUMPY: if isinstance(obj, np.integer): # type: ignore return int(obj) @@ -82,8 +81,7 @@ async def async_encode_json( if file_encoding_strategy == "base64": # TODO: This should ideally use an async based file reader path. return base64_encode_file(obj) - else: - return (await client.files.async_create(obj)).urls["get"] + return (await client.files.async_create(obj)).urls["get"] if HAS_NUMPY: if isinstance(obj, np.integer): # type: ignore return int(obj) @@ -183,9 +181,9 @@ def transform_output(value: Any, client: "Client") -> Any: def transform(obj: Any) -> Any: if isinstance(obj, Mapping): return {k: transform(v) for k, v in obj.items()} - elif isinstance(obj, Sequence) and not isinstance(obj, str): + if isinstance(obj, Sequence) and not isinstance(obj, str): return [transform(item) for item in obj] - elif isinstance(obj, str) and ( + if isinstance(obj, str) and ( obj.startswith("https:") or obj.startswith("data:") ): return FileOutput(obj, client) diff --git a/replicate/run.py b/replicate/run.py index 19db492..e82ffb4 100644 --- a/replicate/run.py +++ b/replicate/run.py @@ -38,7 +38,7 @@ def run( if "wait" not in params: params["wait"] = True - is_blocking = params["wait"] != False # noqa: E712 + is_blocking = params["wait"] is not False version, owner, name, version_id = identifier._resolve(ref) @@ -108,7 +108,7 @@ async def async_run( if "wait" not in params: params["wait"] = True - is_blocking = params["wait"] != False # noqa: E712 + is_blocking = params["wait"] is not False version, owner, name, version_id = identifier._resolve(ref) From 073b368596eed2f7586f3a21abcb96f74cf0fa1b Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Fri, 23 May 2025 17:00:36 +0100 Subject: [PATCH 5/6] Update replicate/client.py Co-authored-by: Philip Potter --- replicate/client.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/replicate/client.py b/replicate/client.py index fcff0cc..122c65a 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -353,12 +353,9 @@ def _get_api_token_from_environment() -> Optional[str]: try: import cog - if hasattr(cog, "current_scope"): - scope = cog.current_scope() - if scope and hasattr(scope, "context") and isinstance(scope.context, dict): - for key, value in scope.context.items(): - if key.upper() == "REPLICATE_API_TOKEN": - return scope.context[key] + for key, value in cog.current_scope().context.items(): + if key.upper() == "REPLICATE_API_TOKEN": + return value except: # noqa: S110,E722,BLE001 we don't want this code to cause clients to fail pass From 855a5969d1eea166e891cfeb52e3de5e5b7b242c Mon Sep 17 00:00:00 2001 From: Philip Potter Date: Tue, 27 May 2025 11:42:41 +0100 Subject: [PATCH 6/6] format and lint --- replicate/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/replicate/client.py b/replicate/client.py index 122c65a..6a79813 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -351,7 +351,7 @@ def close(self) -> None: def _get_api_token_from_environment() -> Optional[str]: """Get API token from cog current scope if available, otherwise from environment.""" try: - import cog + import cog # noqa: I001 # pyright: ignore [reportMissingImports] for key, value in cog.current_scope().context.items(): if key.upper() == "REPLICATE_API_TOKEN":