Skip to content

Commit 4e91963

Browse files
dimiketheman
andauthored
Fix find_publisher_by_issuer environment filter (#13566)
* Make the tests less mocked * Add some failing tests * Conditionally filter publisher search on environment * Fix additional tests * Apply suggestions from code review Co-authored-by: Mike Fiedler <[email protected]> * Clarify with conditionals * Linting --------- Co-authored-by: Mike Fiedler <[email protected]>
1 parent d841df1 commit 4e91963

File tree

4 files changed

+78
-47
lines changed

4 files changed

+78
-47
lines changed

tests/unit/oidc/test_services.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -823,6 +823,7 @@ def test_find_publisher_full_pending(self, oidc_service):
823823
repository_owner="foo",
824824
repository_owner_id="123",
825825
workflow_filename="example.yml",
826+
environment=None,
826827
)
827828

828829
claims = {
@@ -862,6 +863,7 @@ def test_find_publisher_full(self, oidc_service):
862863
repository_owner="foo",
863864
repository_owner_id="123",
864865
workflow_filename="example.yml",
866+
environment=None,
865867
)
866868

867869
claims = {

tests/unit/oidc/test_utils.py

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,13 @@
1010
# See the License for the specific language governing permissions and
1111
# limitations under the License.
1212

13-
import pretend
13+
import uuid
1414

15-
from sqlalchemy.sql.expression import func, literal
15+
import pretend
16+
import pytest
1617

18+
from tests.common.db.oidc import GitHubPublisherFactory
1719
from warehouse.oidc import utils
18-
from warehouse.oidc.models import GitHubPublisher
1920

2021

2122
def test_find_publisher_by_issuer_bad_issuer_url():
@@ -27,46 +28,45 @@ def test_find_publisher_by_issuer_bad_issuer_url():
2728
)
2829

2930

30-
def test_find_publisher_by_issuer_github():
31-
publisher = pretend.stub()
32-
one_or_none = pretend.call_recorder(lambda: publisher)
33-
filter_ = pretend.call_recorder(lambda *a: pretend.stub(one_or_none=one_or_none))
34-
filter_by = pretend.call_recorder(lambda **kw: pretend.stub(filter=filter_))
35-
session = pretend.stub(
36-
query=pretend.call_recorder(lambda cls: pretend.stub(filter_by=filter_by))
31+
@pytest.mark.parametrize(
32+
"environment, expected_id",
33+
[
34+
(None, uuid.UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa")),
35+
("some_other_environment", uuid.UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa")),
36+
("some_environment", uuid.UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb")),
37+
],
38+
)
39+
def test_find_publisher_by_issuer_github(db_request, environment, expected_id):
40+
GitHubPublisherFactory(
41+
id="aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
42+
repository_owner="foo",
43+
repository_name="bar",
44+
repository_owner_id="1234",
45+
workflow_filename="ci.yml",
46+
environment=None, # No environment
47+
)
48+
GitHubPublisherFactory(
49+
id="bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb",
50+
repository_owner="foo",
51+
repository_name="bar",
52+
repository_owner_id="1234",
53+
workflow_filename="ci.yml",
54+
environment="some_environment", # Environment set
3755
)
56+
3857
signed_claims = {
3958
"repository": "foo/bar",
4059
"job_workflow_ref": "foo/bar/.github/workflows/ci.yml@refs/heads/main",
4160
"repository_owner_id": "1234",
4261
}
62+
if environment:
63+
signed_claims["environment"] = environment
4364

4465
assert (
4566
utils.find_publisher_by_issuer(
46-
session, "https://token.actions.githubusercontent.com", signed_claims
47-
)
48-
== publisher
49-
)
50-
51-
assert session.query.calls == [pretend.call(GitHubPublisher)]
52-
assert filter_by.calls == [
53-
pretend.call(
54-
repository_name="bar", repository_owner="foo", repository_owner_id="1234"
55-
)
56-
]
57-
58-
# SQLAlchemy BinaryExpression objects don't support comparison with __eq__,
59-
# so we need to dig into the callset and compare the argument manually.
60-
assert len(filter_.calls) == 1
61-
assert len(filter_.calls[0].args) == 1
62-
assert (
63-
filter_.calls[0]
64-
.args[0]
65-
.compare(
66-
literal("ci.yml@refs/heads/main").like(
67-
func.concat(GitHubPublisher.workflow_filename, "%")
68-
)
69-
)
67+
db_request.db,
68+
"https://token.actions.githubusercontent.com",
69+
signed_claims,
70+
).id
71+
== expected_id
7072
)
71-
72-
assert one_or_none.calls == [pretend.call()]

tests/unit/oidc/test_views.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ def test_mint_token_from_oidc_pending_publisher_ok(
250250
repository_owner="foo",
251251
repository_owner_id="123",
252252
workflow_filename="example.yml",
253+
environment=None,
253254
)
254255

255256
db_request.registry.settings = {"warehouse.oidc.enabled": True}
@@ -306,6 +307,7 @@ def test_mint_token_from_pending_trusted_publisher_invalidates_others(
306307
repository_owner="foo",
307308
repository_owner_id="123",
308309
workflow_filename="example.yml",
310+
environment=None,
309311
)
310312

311313
# Create some other pending publishers for the same nonexistent project,

warehouse/oidc/utils.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,20 +47,47 @@ def find_publisher_by_issuer(session, issuer_url, signed_claims, *, pending=Fals
4747

4848
publisher_cls = GitHubPublisher if not pending else PendingGitHubPublisher
4949

50-
return (
51-
session.query(publisher_cls)
52-
.filter_by(
53-
repository_name=repository_name,
54-
repository_owner=repository_owner,
55-
repository_owner_id=signed_claims["repository_owner_id"],
50+
publisher = None
51+
# If an environment exists in the claim set, try finding a publisher
52+
# that matches the provided environment first.
53+
if environment := signed_claims.get("environment"):
54+
publisher = (
55+
session.query(publisher_cls)
56+
.filter_by(
57+
repository_name=repository_name,
58+
repository_owner=repository_owner,
59+
repository_owner_id=signed_claims["repository_owner_id"],
60+
environment=environment,
61+
)
62+
.filter(
63+
literal(workflow_ref).like(
64+
func.concat(publisher_cls.workflow_filename, "%")
65+
)
66+
)
67+
.one_or_none()
5668
)
57-
.filter(
58-
literal(workflow_ref).like(
59-
func.concat(publisher_cls.workflow_filename, "%")
69+
70+
# There are no publishers for that specific environment, try finding a
71+
# publisher without a restriction on the environment
72+
if not publisher:
73+
publisher = (
74+
session.query(publisher_cls)
75+
.filter_by(
76+
repository_name=repository_name,
77+
repository_owner=repository_owner,
78+
repository_owner_id=signed_claims["repository_owner_id"],
79+
environment=None,
6080
)
81+
.filter(
82+
literal(workflow_ref).like(
83+
func.concat(publisher_cls.workflow_filename, "%")
84+
)
85+
)
86+
.one_or_none()
6187
)
62-
.one_or_none()
63-
)
88+
89+
return publisher
90+
6491
else:
6592
# Unreachable; same logic error as above.
6693
return None # pragma: no cover

0 commit comments

Comments
 (0)