Skip to content

Read REPLICATE_API_TOKEN from cog context #434

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,16 @@ export REPLICATE_API_TOKEN=<your token>

We recommend not adding the token directly to your source code, because you don't want to put your credentials in source control. If anyone used your API key, their usage would be charged to your account.

<details>

<summary>Alternative authentication</summary>

As of [replicate 1.0.5](https://github.com/replicate/replicate-python/releases/tag/1.0.5) and [cog 0.14.11](https://github.com/replicate/cog/releases/tag/v0.14.11) it is possible to pass a `REPLICATE_API_TOKEN` via the `context` as part of a prediction request.

The `Replicate()` constructor will now use this context when available. This grants cog models the ability to use the Replicate client libraries, scoped to a user on a per request basis.

</details>

## Run a model

Create a new Python file and add the following code, replacing the model identifier and input with your own:
Expand Down
16 changes: 15 additions & 1 deletion replicate/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,20 @@ def close(self) -> None:
self._wrapped_transport.close() # type: ignore


def _get_api_token_from_environment() -> Optional[str]:
"""Get API token from cog current scope if available, otherwise from environment."""
try:
import cog # noqa: I001 # pyright: ignore [reportMissingImports]

for key, value in cog.current_scope().context.items():
if key.upper() == "REPLICATE_API_TOKEN":
return value
except: # noqa: S110,E722,BLE001 we don't want this code to cause clients to fail
pass

return os.environ.get("REPLICATE_API_TOKEN")


def _build_httpx_client(
client_type: Type[Union[httpx.Client, httpx.AsyncClient]],
api_token: Optional[str] = None,
Expand All @@ -359,7 +373,7 @@ def _build_httpx_client(
if "User-Agent" not in headers:
headers["User-Agent"] = f"replicate-python/{__version__}"
if "Authorization" not in headers and (
api_token := api_token or os.environ.get("REPLICATE_API_TOKEN")
api_token := api_token or _get_api_token_from_environment()
):
headers["Authorization"] = f"Bearer {api_token}"

Expand Down
10 changes: 4 additions & 6 deletions replicate/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ def encode_json(
if isinstance(obj, io.IOBase):
if file_encoding_strategy == "base64":
return base64_encode_file(obj)
else:
return client.files.create(obj).urls["get"]
return client.files.create(obj).urls["get"]
if HAS_NUMPY:
if isinstance(obj, np.integer): # type: ignore
return int(obj)
Expand Down Expand Up @@ -82,8 +81,7 @@ async def async_encode_json(
if file_encoding_strategy == "base64":
# TODO: This should ideally use an async based file reader path.
return base64_encode_file(obj)
else:
return (await client.files.async_create(obj)).urls["get"]
return (await client.files.async_create(obj)).urls["get"]
if HAS_NUMPY:
if isinstance(obj, np.integer): # type: ignore
return int(obj)
Expand Down Expand Up @@ -183,9 +181,9 @@ def transform_output(value: Any, client: "Client") -> Any:
def transform(obj: Any) -> Any:
if isinstance(obj, Mapping):
return {k: transform(v) for k, v in obj.items()}
elif isinstance(obj, Sequence) and not isinstance(obj, str):
if isinstance(obj, Sequence) and not isinstance(obj, str):
return [transform(item) for item in obj]
elif isinstance(obj, str) and (
if isinstance(obj, str) and (
obj.startswith("https:") or obj.startswith("data:")
):
return FileOutput(obj, client)
Expand Down
4 changes: 2 additions & 2 deletions replicate/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def run(

if "wait" not in params:
params["wait"] = True
is_blocking = params["wait"] != False # noqa: E712
is_blocking = params["wait"] is not False

version, owner, name, version_id = identifier._resolve(ref)

Expand Down Expand Up @@ -108,7 +108,7 @@ async def async_run(

if "wait" not in params:
params["wait"] = True
is_blocking = params["wait"] != False # noqa: E712
is_blocking = params["wait"] is not False

version, owner, name, version_id = identifier._resolve(ref)

Expand Down
159 changes: 159 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import os
import sys
from unittest import mock

import httpx
import pytest
import respx

from replicate.client import _get_api_token_from_environment


@pytest.mark.asyncio
async def test_authorization_when_setting_environ_after_import():
Expand Down Expand Up @@ -114,3 +117,159 @@ def mock_send(request):
pass

mock_send_wrapper.assert_called_once()


class TestGetApiToken:
"""Test cases for _get_api_token_from_environment function covering all import paths."""

def test_cog_not_available_falls_back_to_env(self):
"""Test fallback to environment when cog package is not available."""
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
with mock.patch.dict(sys.modules, {"cog": None}):
token = _get_api_token_from_environment()
assert token == "env-token" # noqa: S105

def test_cog_import_error_falls_back_to_env(self):
"""Test fallback to environment when cog import raises exception."""
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
with mock.patch(
"builtins.__import__",
side_effect=ModuleNotFoundError("No module named 'cog'"),
):
token = _get_api_token_from_environment()
assert token == "env-token" # noqa: S105

def test_cog_no_current_scope_method_falls_back_to_env(self):
"""Test fallback when cog exists but has no current_scope method."""
mock_cog = mock.MagicMock()
del mock_cog.current_scope # Remove the method

with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
token = _get_api_token_from_environment()
assert token == "env-token" # noqa: S105

def test_cog_current_scope_returns_none_falls_back_to_env(self):
"""Test fallback when current_scope() returns None."""
mock_cog = mock.MagicMock()
mock_cog.current_scope.return_value = None

with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
token = _get_api_token_from_environment()
assert token == "env-token" # noqa: S105

def test_cog_scope_no_context_attr_falls_back_to_env(self):
"""Test fallback when scope has no context attribute."""
mock_scope = mock.MagicMock()
del mock_scope.context # Remove the context attribute

mock_cog = mock.MagicMock()
mock_cog.current_scope.return_value = mock_scope

with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
token = _get_api_token_from_environment()
assert token == "env-token" # noqa: S105

def test_cog_scope_context_not_dict_falls_back_to_env(self):
"""Test fallback when scope.context is not a dictionary."""
mock_scope = mock.MagicMock()
mock_scope.context = "not a dict"

mock_cog = mock.MagicMock()
mock_cog.current_scope.return_value = mock_scope

with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
token = _get_api_token_from_environment()
assert token == "env-token" # noqa: S105

def test_cog_scope_no_replicate_api_token_key_falls_back_to_env(self):
"""Test fallback when replicate_api_token key is missing from context."""
mock_scope = mock.MagicMock()
mock_scope.context = {"other_key": "other_value"} # Missing replicate_api_token

mock_cog = mock.MagicMock()
mock_cog.current_scope.return_value = mock_scope

with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
token = _get_api_token_from_environment()
assert token == "env-token" # noqa: S105

def test_cog_scope_replicate_api_token_valid_string(self):
"""Test successful retrieval of non-empty token from cog."""
mock_scope = mock.MagicMock()
mock_scope.context = {"REPLICATE_API_TOKEN": "cog-token"}

mock_cog = mock.MagicMock()
mock_cog.current_scope.return_value = mock_scope

with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
token = _get_api_token_from_environment()
assert token == "cog-token" # noqa: S105

def test_cog_scope_replicate_api_token_case_insensitive(self):
"""Test successful retrieval of non-empty token from cog ignoring case."""
mock_scope = mock.MagicMock()
mock_scope.context = {"replicate_api_token": "cog-token"}

mock_cog = mock.MagicMock()
mock_cog.current_scope.return_value = mock_scope

with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
token = _get_api_token_from_environment()
assert token == "cog-token" # noqa: S105

def test_cog_scope_replicate_api_token_empty_string(self):
"""Test that empty string from cog is returned (not falling back to env)."""
mock_scope = mock.MagicMock()
mock_scope.context = {"replicate_api_token": ""} # Empty string

mock_cog = mock.MagicMock()
mock_cog.current_scope.return_value = mock_scope

with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
token = _get_api_token_from_environment()
assert token == "" # Should return empty string, not env token

def test_cog_scope_replicate_api_token_none(self):
"""Test that None from cog is returned (not falling back to env)."""
mock_scope = mock.MagicMock()
mock_scope.context = {"replicate_api_token": None}

mock_cog = mock.MagicMock()
mock_cog.current_scope.return_value = mock_scope

with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
token = _get_api_token_from_environment()
assert token is None # Should return None, not env token

def test_cog_current_scope_raises_exception_falls_back_to_env(self):
"""Test fallback when current_scope() raises an exception."""
mock_cog = mock.MagicMock()
mock_cog.current_scope.side_effect = RuntimeError("Scope error")

with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
token = _get_api_token_from_environment()
assert token == "env-token" # noqa: S105

def test_no_env_token_returns_none(self):
"""Test that None is returned when no environment token is set and cog unavailable."""
with mock.patch.dict(os.environ, {}, clear=True): # Clear all env vars
with mock.patch.dict(sys.modules, {"cog": None}):
token = _get_api_token_from_environment()
assert token is None

def test_env_token_empty_string(self):
"""Test that empty string from environment is returned."""
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": ""}):
with mock.patch.dict(sys.modules, {"cog": None}):
token = _get_api_token_from_environment()
assert token == ""
Loading