Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tests/unit/oidc/test_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,7 @@ def test_find_publisher_full_pending(self, oidc_service):
repository_owner="foo",
repository_owner_id="123",
workflow_filename="example.yml",
environment=None,
)

claims = {
Expand Down Expand Up @@ -862,6 +863,7 @@ def test_find_publisher_full(self, oidc_service):
repository_owner="foo",
repository_owner_id="123",
workflow_filename="example.yml",
environment=None,
)

claims = {
Expand Down
72 changes: 36 additions & 36 deletions tests/unit/oidc/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pretend
import uuid

from sqlalchemy.sql.expression import func, literal
import pretend
import pytest

from tests.common.db.oidc import GitHubPublisherFactory
from warehouse.oidc import utils
from warehouse.oidc.models import GitHubPublisher


def test_find_publisher_by_issuer_bad_issuer_url():
Expand All @@ -27,46 +28,45 @@ def test_find_publisher_by_issuer_bad_issuer_url():
)


def test_find_publisher_by_issuer_github():
publisher = pretend.stub()
one_or_none = pretend.call_recorder(lambda: publisher)
filter_ = pretend.call_recorder(lambda *a: pretend.stub(one_or_none=one_or_none))
filter_by = pretend.call_recorder(lambda **kw: pretend.stub(filter=filter_))
session = pretend.stub(
query=pretend.call_recorder(lambda cls: pretend.stub(filter_by=filter_by))
@pytest.mark.parametrize(
"environment, expected_id",
[
(None, uuid.UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa")),
("some_other_environment", uuid.UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa")),
("some_environment", uuid.UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb")),
],
)
def test_find_publisher_by_issuer_github(db_request, environment, expected_id):
GitHubPublisherFactory(
id="aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
repository_owner="foo",
repository_name="bar",
repository_owner_id="1234",
workflow_filename="ci.yml",
environment=None, # No environment
)
GitHubPublisherFactory(
id="bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb",
repository_owner="foo",
repository_name="bar",
repository_owner_id="1234",
workflow_filename="ci.yml",
environment="some_environment", # Environment set
)

signed_claims = {
"repository": "foo/bar",
"job_workflow_ref": "foo/bar/.github/workflows/ci.yml@refs/heads/main",
"repository_owner_id": "1234",
}
if environment:
signed_claims["environment"] = environment

assert (
utils.find_publisher_by_issuer(
session, "https://token.actions.githubusercontent.com", signed_claims
)
== publisher
)

assert session.query.calls == [pretend.call(GitHubPublisher)]
assert filter_by.calls == [
pretend.call(
repository_name="bar", repository_owner="foo", repository_owner_id="1234"
)
]

# SQLAlchemy BinaryExpression objects don't support comparison with __eq__,
# so we need to dig into the callset and compare the argument manually.
assert len(filter_.calls) == 1
assert len(filter_.calls[0].args) == 1
assert (
filter_.calls[0]
.args[0]
.compare(
literal("ci.yml@refs/heads/main").like(
func.concat(GitHubPublisher.workflow_filename, "%")
)
)
db_request.db,
"https://token.actions.githubusercontent.com",
signed_claims,
).id
== expected_id
)

assert one_or_none.calls == [pretend.call()]
2 changes: 2 additions & 0 deletions tests/unit/oidc/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ def test_mint_token_from_oidc_pending_publisher_ok(
repository_owner="foo",
repository_owner_id="123",
workflow_filename="example.yml",
environment=None,
)

db_request.registry.settings = {"warehouse.oidc.enabled": True}
Expand Down Expand Up @@ -306,6 +307,7 @@ def test_mint_token_from_pending_trusted_publisher_invalidates_others(
repository_owner="foo",
repository_owner_id="123",
workflow_filename="example.yml",
environment=None,
)

# Create some other pending publishers for the same nonexistent project,
Expand Down
49 changes: 38 additions & 11 deletions warehouse/oidc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,47 @@ def find_publisher_by_issuer(session, issuer_url, signed_claims, *, pending=Fals

publisher_cls = GitHubPublisher if not pending else PendingGitHubPublisher

return (
session.query(publisher_cls)
.filter_by(
repository_name=repository_name,
repository_owner=repository_owner,
repository_owner_id=signed_claims["repository_owner_id"],
publisher = None
# If an environment exists in the claim set, try finding a publisher
# that matches the provided environment first.
if environment := signed_claims.get("environment"):
publisher = (
session.query(publisher_cls)
.filter_by(
repository_name=repository_name,
repository_owner=repository_owner,
repository_owner_id=signed_claims["repository_owner_id"],
environment=environment,
)
.filter(
literal(workflow_ref).like(
func.concat(publisher_cls.workflow_filename, "%")
)
)
.one_or_none()
)
.filter(
literal(workflow_ref).like(
func.concat(publisher_cls.workflow_filename, "%")

# There are no publishers for that specific environment, try finding a
# publisher without a restriction on the environment
if not publisher:
publisher = (
session.query(publisher_cls)
.filter_by(
repository_name=repository_name,
repository_owner=repository_owner,
repository_owner_id=signed_claims["repository_owner_id"],
environment=None,
)
.filter(
literal(workflow_ref).like(
func.concat(publisher_cls.workflow_filename, "%")
)
)
.one_or_none()
)
.one_or_none()
)

return publisher

else:
# Unreachable; same logic error as above.
return None # pragma: no cover