Skip to content

Commit e5db5ab

Browse files
authored
warehouse, tests: handle malformed JWTs gracefully (#13541)
* warehouse, tests: handle malformed JWTs gracefully * oidc/services: scope sentry on exception Signed-off-by: William Woodruff <[email protected]> * oidc/services: even more scopeage Signed-off-by: William Woodruff <[email protected]> * test: fix OIDC service tests Use a slightly less invasive monkeypatch/stub. Signed-off-by: William Woodruff <[email protected]> --------- Signed-off-by: William Woodruff <[email protected]>
1 parent 93dc80b commit e5db5ab

File tree

2 files changed

+70
-12
lines changed

2 files changed

+70
-12
lines changed

tests/unit/oidc/test_services.py

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import pytest
1616

1717
from cryptography.hazmat.primitives.asymmetric import rsa
18-
from jwt import PyJWK, PyJWTError, algorithms
18+
from jwt import DecodeError, PyJWK, PyJWTError, algorithms
1919
from zope.interface.verify import verifyClass
2020

2121
from tests.common.db.oidc import GitHubPublisherFactory, PendingGitHubPublisherFactory
@@ -104,6 +104,44 @@ def test_verify_jwt_signature(self, monkeypatch):
104104
)
105105
]
106106

107+
@pytest.mark.parametrize("exc", [DecodeError, TypeError("foo")])
108+
def test_verify_jwt_signature_get_key_for_token_fails(self, monkeypatch, exc):
109+
service = services.OIDCPublisherService(
110+
session=pretend.stub(),
111+
publisher="fakepublisher",
112+
issuer_url=pretend.stub(),
113+
audience="fakeaudience",
114+
cache_url=pretend.stub(),
115+
metrics=pretend.stub(
116+
increment=pretend.call_recorder(lambda *a, **kw: None)
117+
),
118+
)
119+
120+
token = pretend.stub()
121+
jwt = pretend.stub(decode=pretend.raiser(exc), PyJWTError=PyJWTError)
122+
monkeypatch.setattr(service, "_get_key_for_token", pretend.raiser(exc))
123+
monkeypatch.setattr(services, "jwt", jwt)
124+
monkeypatch.setattr(
125+
services.sentry_sdk,
126+
"capture_message",
127+
pretend.call_recorder(lambda s: None),
128+
)
129+
130+
assert service.verify_jwt_signature(token) is None
131+
assert service.metrics.increment.calls == [
132+
pretend.call(
133+
"warehouse.oidc.verify_jwt_signature.malformed_jwt",
134+
tags=["publisher:fakepublisher"],
135+
)
136+
]
137+
138+
if exc != DecodeError:
139+
assert services.sentry_sdk.capture_message.calls == [
140+
pretend.call(f"JWT backend raised generic error: {exc}")
141+
]
142+
else:
143+
assert services.sentry_sdk.capture_message.calls == []
144+
107145
@pytest.mark.parametrize("exc", [PyJWTError, TypeError("foo")])
108146
def test_verify_jwt_signature_fails(self, monkeypatch, exc):
109147
service = services.OIDCPublisherService(
@@ -124,9 +162,11 @@ def test_verify_jwt_signature_fails(self, monkeypatch, exc):
124162
service, "_get_key_for_token", pretend.call_recorder(lambda t: key)
125163
)
126164
monkeypatch.setattr(services, "jwt", jwt)
127-
128-
sentry_sdk = pretend.stub(capture_message=pretend.call_recorder(lambda s: None))
129-
monkeypatch.setattr(services, "sentry_sdk", sentry_sdk)
165+
monkeypatch.setattr(
166+
services.sentry_sdk,
167+
"capture_message",
168+
pretend.call_recorder(lambda s: None),
169+
)
130170

131171
assert service.verify_jwt_signature(token) is None
132172
assert service.metrics.increment.calls == [
@@ -137,11 +177,11 @@ def test_verify_jwt_signature_fails(self, monkeypatch, exc):
137177
]
138178

139179
if exc != PyJWTError:
140-
assert sentry_sdk.capture_message.calls == [
141-
pretend.call(f"JWT verify raised generic error: {exc}")
180+
assert services.sentry_sdk.capture_message.calls == [
181+
pretend.call(f"JWT backend raised generic error: {exc}")
142182
]
143183
else:
144-
assert sentry_sdk.capture_message.calls == []
184+
assert services.sentry_sdk.capture_message.calls == []
145185

146186
def test_find_publisher(self, monkeypatch):
147187
service = services.OIDCPublisherService(

warehouse/oidc/services.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,23 @@ def _get_key_for_token(self, token):
220220
return self._get_key(unverified_header["kid"])
221221

222222
def verify_jwt_signature(self, unverified_token: str) -> SignedClaims | None:
223-
key = self._get_key_for_token(unverified_token)
223+
try:
224+
key = self._get_key_for_token(unverified_token)
225+
except Exception as e:
226+
# The user might feed us an entirely nonsense JWT, e.g. one
227+
# with missing components.
228+
self.metrics.increment(
229+
"warehouse.oidc.verify_jwt_signature.malformed_jwt",
230+
tags=[f"publisher:{self.publisher}"],
231+
)
232+
233+
if not isinstance(e, jwt.PyJWTError):
234+
with sentry_sdk.push_scope() as scope:
235+
scope.fingerprint = e
236+
# Similar to below: Other exceptions indicate an abstraction
237+
# leak, so we log them for upstream reporting.
238+
sentry_sdk.capture_message(f"JWT backend raised generic error: {e}")
239+
return None
224240

225241
try:
226242
# NOTE: Many of the keyword arguments here are defaults, but we
@@ -252,10 +268,12 @@ def verify_jwt_signature(self, unverified_token: str) -> SignedClaims | None:
252268
tags=[f"publisher:{self.publisher}"],
253269
)
254270
if not isinstance(e, jwt.PyJWTError):
255-
# We expect pyjwt to only raise subclasses of PyJWTError, but
256-
# we can't enforce this. Other exceptions indicate an abstraction
257-
# leak, so we log them for upstream reporting.
258-
sentry_sdk.capture_message(f"JWT verify raised generic error: {e}")
271+
with sentry_sdk.push_scope() as scope:
272+
scope.fingerprint = e
273+
# We expect pyjwt to only raise subclasses of PyJWTError, but
274+
# we can't enforce this. Other exceptions indicate an abstraction
275+
# leak, so we log them for upstream reporting.
276+
sentry_sdk.capture_message(f"JWT backend raised generic error: {e}")
259277
return None
260278

261279
def find_publisher(

0 commit comments

Comments
 (0)