10
10
# See the License for the specific language governing permissions and
11
11
# limitations under the License.
12
12
13
- import pretend
13
+ import uuid
14
14
15
- from sqlalchemy .sql .expression import func , literal
15
+ import pretend
16
+ import pytest
16
17
18
+ from tests .common .db .oidc import GitHubPublisherFactory
17
19
from warehouse .oidc import utils
18
- from warehouse .oidc .models import GitHubPublisher
19
20
20
21
21
22
def test_find_publisher_by_issuer_bad_issuer_url ():
@@ -27,46 +28,45 @@ def test_find_publisher_by_issuer_bad_issuer_url():
27
28
)
28
29
29
30
30
- def test_find_publisher_by_issuer_github ():
31
- publisher = pretend .stub ()
32
- one_or_none = pretend .call_recorder (lambda : publisher )
33
- filter_ = pretend .call_recorder (lambda * a : pretend .stub (one_or_none = one_or_none ))
34
- filter_by = pretend .call_recorder (lambda ** kw : pretend .stub (filter = filter_ ))
35
- session = pretend .stub (
36
- query = pretend .call_recorder (lambda cls : pretend .stub (filter_by = filter_by ))
31
+ @pytest .mark .parametrize (
32
+ "environment, expected_id" ,
33
+ [
34
+ (None , uuid .UUID ("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" )),
35
+ ("some_other_environment" , uuid .UUID ("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" )),
36
+ ("some_environment" , uuid .UUID ("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb" )),
37
+ ],
38
+ )
39
+ def test_find_publisher_by_issuer_github (db_request , environment , expected_id ):
40
+ GitHubPublisherFactory (
41
+ id = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" ,
42
+ repository_owner = "foo" ,
43
+ repository_name = "bar" ,
44
+ repository_owner_id = "1234" ,
45
+ workflow_filename = "ci.yml" ,
46
+ environment = None , # No environment
47
+ )
48
+ GitHubPublisherFactory (
49
+ id = "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb" ,
50
+ repository_owner = "foo" ,
51
+ repository_name = "bar" ,
52
+ repository_owner_id = "1234" ,
53
+ workflow_filename = "ci.yml" ,
54
+ environment = "some_environment" , # Environment set
37
55
)
56
+
38
57
signed_claims = {
39
58
"repository" : "foo/bar" ,
40
59
"job_workflow_ref" : "foo/bar/.github/workflows/ci.yml@refs/heads/main" ,
41
60
"repository_owner_id" : "1234" ,
42
61
}
62
+ if environment :
63
+ signed_claims ["environment" ] = environment
43
64
44
65
assert (
45
66
utils .find_publisher_by_issuer (
46
- session , "https://token.actions.githubusercontent.com" , signed_claims
47
- )
48
- == publisher
49
- )
50
-
51
- assert session .query .calls == [pretend .call (GitHubPublisher )]
52
- assert filter_by .calls == [
53
- pretend .call (
54
- repository_name = "bar" , repository_owner = "foo" , repository_owner_id = "1234"
55
- )
56
- ]
57
-
58
- # SQLAlchemy BinaryExpression objects don't support comparison with __eq__,
59
- # so we need to dig into the callset and compare the argument manually.
60
- assert len (filter_ .calls ) == 1
61
- assert len (filter_ .calls [0 ].args ) == 1
62
- assert (
63
- filter_ .calls [0 ]
64
- .args [0 ]
65
- .compare (
66
- literal ("ci.yml@refs/heads/main" ).like (
67
- func .concat (GitHubPublisher .workflow_filename , "%" )
68
- )
69
- )
67
+ db_request .db ,
68
+ "https://token.actions.githubusercontent.com" ,
69
+ signed_claims ,
70
+ ).id
71
+ == expected_id
70
72
)
71
-
72
- assert one_or_none .calls == [pretend .call ()]
0 commit comments