Skip to content

Commit 315ce7a

Browse files
committed
Issue #254/#691 introduce _on_auth_update handler
- to make sure all cases are covered - include authenticate_oidc_access_token
1 parent 1923035 commit 315ce7a

File tree

3 files changed

+93
-56
lines changed

3 files changed

+93
-56
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2626

2727
### Fixed
2828

29+
- Clear capabilities cache on login ([#254](https://github.com/Open-EO/openeo-python-client/issues/254))
30+
2931

3032
## [0.36.0] - 2024-12-10
3133

openeo/rest/connection.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ def __init__(
113113
slow_response_threshold: Optional[float] = None,
114114
):
115115
self._root_url = root_url
116+
self._auth = None
116117
self.auth = auth or NullAuth()
117118
self.session = session or requests.Session()
118119
self.default_timeout = default_timeout or DEFAULT_TIMEOUT
@@ -129,6 +130,18 @@ def __init__(
129130
def root_url(self):
130131
return self._root_url
131132

133+
@property
134+
def auth(self) -> Union[AuthBase, None]:
135+
return self._auth
136+
137+
@auth.setter
138+
def auth(self, auth: Union[AuthBase, None]):
139+
self._auth = auth
140+
self._on_auth_update()
141+
142+
def _on_auth_update(self):
143+
pass
144+
132145
def build_url(self, path: str):
133146
return url_join(self._root_url, path)
134147

@@ -340,12 +353,12 @@ def __init__(
340353
if "://" not in url:
341354
url = "https://" + url
342355
self._orig_url = url
356+
self._capabilities_cache = LazyLoadCache()
343357
super().__init__(
344358
root_url=self.version_discovery(url, session=session, timeout=default_timeout),
345359
auth=auth, session=session, default_timeout=default_timeout,
346360
slow_response_threshold=slow_response_threshold,
347361
)
348-
self._capabilities_cache = LazyLoadCache()
349362

350363
# Initial API version check.
351364
self._api_version.require_at_least(self._MINIMUM_API_VERSION)
@@ -380,6 +393,10 @@ def version_discovery(
380393
# Be very lenient about failing on the well-known URI strategy.
381394
return url
382395

396+
def _on_auth_update(self):
397+
super()._on_auth_update()
398+
self._capabilities_cache.clear()
399+
383400
def _get_auth_config(self) -> AuthConfig:
384401
if self._auth_config is None:
385402
self._auth_config = AuthConfig()
@@ -411,7 +428,6 @@ def authenticate_basic(self, username: Optional[str] = None, password: Optional[
411428
).json()
412429
# Switch to bearer based authentication in further requests.
413430
self.auth = BasicBearerAuth(access_token=resp["access_token"])
414-
self._capabilities_cache.clear()
415431
return self
416432

417433
def _get_oidc_provider(
@@ -546,7 +562,6 @@ def _authenticate_oidc(
546562
_log.warning("No OIDC refresh token to store.")
547563
token = tokens.access_token
548564
self.auth = OidcBearerAuth(provider_id=provider_id, access_token=token)
549-
self._capabilities_cache.clear()
550565
self._oidc_auth_renewer = oidc_auth_renewer
551566
return self
552567

tests/rest/test_connection.py

Lines changed: 73 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949

5050
API_URL = "https://oeo.test/"
5151

52+
# TODO: eliminate this and replace with `build_capabilities` usage
5253
BASIC_ENDPOINTS = [{"path": "/credentials/basic", "methods": ["GET"]}]
5354

5455

@@ -551,83 +552,102 @@ def test_capabilities_caching(requests_mock):
551552
assert con.capabilities().api_version() == "1.0.0"
552553
assert m.call_count == 1
553554

554-
def test_capabilities_caching_after_authenticate_basic(requests_mock):
555-
user, pwd = "john262", "J0hndo3"
556555

557-
def get_capabilities(request, context):
558-
endpoints = BASIC_ENDPOINTS.copy()
559-
if "Authorization" in request.headers:
560-
endpoints.append({"path": "/account/status", "methods": ["GET"]})
561-
return {"api_version": "1.0.0", "endpoints": endpoints}
556+
def _get_capabilities_auth_dependent(request, context):
557+
capabilities = build_capabilities()
558+
capabilities["endpoints"] = [
559+
{"methods": ["GET"], "path": "/credentials/basic"},
560+
{"methods": ["GET"], "path": "/credentials/oidc"},
561+
]
562+
if "Authorization" in request.headers:
563+
capabilities["endpoints"].append({"methods": ["GET"], "path": "/me"})
564+
return capabilities
565+
562566

563-
get_capabilities_mock = requests_mock.get(API_URL, json=get_capabilities)
567+
def test_capabilities_caching_after_authenticate_basic(requests_mock):
568+
user, pwd = "john262", "J0hndo3"
569+
get_capabilities_mock = requests_mock.get(API_URL, json=_get_capabilities_auth_dependent)
564570
requests_mock.get(API_URL + 'credentials/basic', text=_credentials_basic_handler(user, pwd))
565571

566572
con = Connection(API_URL)
567-
assert con.capabilities().capabilities == {
568-
"api_version": "1.0.0",
569-
"endpoints": [
570-
{"methods": ["GET"], "path": "/credentials/basic"},
571-
],
572-
}
573+
assert con.capabilities().capabilities["endpoints"] == [
574+
{"methods": ["GET"], "path": "/credentials/basic"},
575+
{"methods": ["GET"], "path": "/credentials/oidc"},
576+
]
573577
assert get_capabilities_mock.call_count == 1
574578
con.capabilities()
575579
assert get_capabilities_mock.call_count == 1
576580

577-
con.authenticate_basic(user, pwd)
581+
con.authenticate_basic(username=user, password=pwd)
578582
assert get_capabilities_mock.call_count == 1
579-
assert con.capabilities().capabilities == {
580-
"api_version": "1.0.0",
581-
"endpoints": [
582-
{"methods": ["GET"], "path": "/credentials/basic"},
583-
{"methods": ["GET"], "path": "/account/status"},
584-
],
585-
}
586-
assert get_capabilities_mock.call_count == 2
583+
assert con.capabilities().capabilities["endpoints"] == [
584+
{"methods": ["GET"], "path": "/credentials/basic"},
585+
{"methods": ["GET"], "path": "/credentials/oidc"},
586+
{"methods": ["GET"], "path": "/me"},
587+
]
587588

589+
assert get_capabilities_mock.call_count == 2
588590

589591

590-
def test_capabilities_caching_after_authenticate_oidc(requests_mock):
592+
def test_capabilities_caching_after_authenticate_oidc_refresh_token(requests_mock):
591593
client_id = "myclient"
592-
593-
def get_capabilities(request, context):
594-
endpoints = BASIC_ENDPOINTS.copy()
595-
if "Authorization" in request.headers:
596-
endpoints.append({"path": "/account/status", "methods": ["GET"]})
597-
return {"api_version": "1.0.0", "endpoints": endpoints}
598-
599-
get_capabilities_mock = requests_mock.get(API_URL, json=get_capabilities)
600-
requests_mock.get(API_URL + 'credentials/oidc', json={
601-
"providers": [{"id": "fauth", "issuer": "https://fauth.test", "title": "Foo Auth", "scopes": ["openid", "im"]}]
602-
})
594+
refresh_token = "fr65h!"
595+
get_capabilities_mock = requests_mock.get(API_URL, json=_get_capabilities_auth_dependent)
596+
requests_mock.get(
597+
API_URL + "credentials/oidc",
598+
json={"providers": [{"id": "oi", "issuer": "https://oidc.test", "title": "OI!", "scopes": ["openid"]}]},
599+
)
603600
oidc_mock = OidcMock(
604601
requests_mock=requests_mock,
605-
expected_grant_type="authorization_code",
602+
expected_grant_type="refresh_token",
606603
expected_client_id=client_id,
607-
expected_fields={"scope": "im openid"},
608-
oidc_issuer="https://fauth.test",
609-
scopes_supported=["openid", "im"],
604+
expected_fields={"refresh_token": refresh_token},
610605
)
606+
611607
conn = Connection(API_URL)
612-
assert conn.capabilities().capabilities == {
613-
"api_version": "1.0.0",
614-
"endpoints": [
615-
{"methods": ["GET"], "path": "/credentials/basic"},
616-
],
617-
}
608+
assert conn.capabilities().capabilities["endpoints"] == [
609+
{"methods": ["GET"], "path": "/credentials/basic"},
610+
{"methods": ["GET"], "path": "/credentials/oidc"},
611+
]
612+
618613
assert get_capabilities_mock.call_count == 1
619614
conn.capabilities()
620615
assert get_capabilities_mock.call_count == 1
621616

622-
conn.authenticate_oidc_authorization_code(client_id=client_id, webbrowser_open=oidc_mock.webbrowser_open)
617+
conn.authenticate_oidc_refresh_token(client_id=client_id, refresh_token=refresh_token)
623618
assert get_capabilities_mock.call_count == 1
624-
assert conn.capabilities().capabilities == {
625-
"api_version": "1.0.0",
626-
"endpoints": [
627-
{"methods": ["GET"], "path": "/credentials/basic"},
628-
{"methods": ["GET"], "path": "/account/status"},
629-
],
630-
}
619+
assert conn.capabilities().capabilities["endpoints"] == [
620+
{"methods": ["GET"], "path": "/credentials/basic"},
621+
{"methods": ["GET"], "path": "/credentials/oidc"},
622+
{"methods": ["GET"], "path": "/me"},
623+
]
624+
assert get_capabilities_mock.call_count == 2
625+
626+
627+
def test_capabilities_caching_after_authenticate_oidc_access_token(requests_mock):
628+
get_capabilities_mock = requests_mock.get(API_URL, json=_get_capabilities_auth_dependent)
629+
requests_mock.get(
630+
API_URL + "credentials/oidc",
631+
json={"providers": [{"id": "oi", "issuer": "https://oidc.test", "title": "OI!", "scopes": ["openid"]}]},
632+
)
633+
634+
conn = Connection(API_URL)
635+
assert conn.capabilities().capabilities["endpoints"] == [
636+
{"methods": ["GET"], "path": "/credentials/basic"},
637+
{"methods": ["GET"], "path": "/credentials/oidc"},
638+
]
639+
640+
assert get_capabilities_mock.call_count == 1
641+
conn.capabilities()
642+
assert get_capabilities_mock.call_count == 1
643+
644+
conn.authenticate_oidc_access_token(access_token="6cc355!")
645+
assert get_capabilities_mock.call_count == 1
646+
assert conn.capabilities().capabilities["endpoints"] == [
647+
{"methods": ["GET"], "path": "/credentials/basic"},
648+
{"methods": ["GET"], "path": "/credentials/oidc"},
649+
{"methods": ["GET"], "path": "/me"},
650+
]
631651
assert get_capabilities_mock.call_count == 2
632652

633653

0 commit comments

Comments
 (0)