Skip to content

Commit f7e277e

Browse files
committed
Give the IntegrityService access to the session
1 parent c9a774f commit f7e277e

File tree

3 files changed

+23
-11
lines changed

3 files changed

+23
-11
lines changed

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,7 @@ def activestate_oidc_service(db_session):
555555

556556
@pytest.fixture
557557
def integrity_service(db_session):
558-
return attestations_services.NullIntegrityService()
558+
return attestations_services.NullIntegrityService(db_session)
559559

560560

561561
@pytest.fixture

tests/unit/attestations/test_services.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def test_get_provenance_digest(self, db_request):
6666
)
6767

6868
file = FileFactory.create()
69-
service = services.NullIntegrityService()
69+
service = services.NullIntegrityService(session=db_request.db)
7070

7171
provenance = service.generate_provenance(db_request, file, [VALID_ATTESTATION])
7272
assert isinstance(provenance, Provenance)
@@ -87,6 +87,7 @@ def test_persist_attestations_succeeds(self, db_request, storage_service):
8787
integrity_service = IntegrityService(
8888
storage=storage_service,
8989
metrics=pretend.stub(),
90+
session=db_request.db,
9091
)
9192

9293
file = FileFactory.create()
@@ -113,6 +114,7 @@ def test_parse_attestations_fails_no_publisher(self, db_request):
113114
integrity_service = IntegrityService(
114115
storage=pretend.stub(),
115116
metrics=pretend.stub(),
117+
session=db_request.db,
116118
)
117119

118120
db_request.oidc_publisher = None
@@ -126,6 +128,7 @@ def test_parse_attestations_fails_unsupported_publisher(self, db_request):
126128
integrity_service = IntegrityService(
127129
storage=pretend.stub(),
128130
metrics=pretend.stub(),
131+
session=db_request.db,
129132
)
130133
db_request.oidc_publisher = pretend.stub(publisher_name="not-existing")
131134
with pytest.raises(
@@ -138,6 +141,7 @@ def test_parse_attestations_fails_malformed_attestation(self, metrics, db_reques
138141
integrity_service = IntegrityService(
139142
storage=pretend.stub(),
140143
metrics=metrics,
144+
session=db_request.db,
141145
)
142146

143147
db_request.oidc_publisher = pretend.stub(publisher_name="GitHub")
@@ -157,6 +161,7 @@ def test_parse_attestations_fails_multiple_attestations(self, metrics, db_reques
157161
integrity_service = IntegrityService(
158162
storage=pretend.stub(),
159163
metrics=metrics,
164+
session=db_request.db,
160165
)
161166

162167
db_request.oidc_publisher = pretend.stub(publisher_name="GitHub")
@@ -195,6 +200,7 @@ def test_parse_attestations_fails_verification(
195200
integrity_service = IntegrityService(
196201
storage=pretend.stub(),
197202
metrics=metrics,
203+
session=db_request.db,
198204
)
199205

200206
db_request.oidc_publisher = pretend.stub(
@@ -224,6 +230,7 @@ def test_parse_attestations_fails_wrong_predicate(
224230
integrity_service = IntegrityService(
225231
storage=pretend.stub(),
226232
metrics=metrics,
233+
session=db_request.db,
227234
)
228235
db_request.oidc_publisher = pretend.stub(
229236
publisher_name="GitHub",
@@ -258,6 +265,7 @@ def test_parse_attestations_succeeds(self, metrics, monkeypatch, db_request):
258265
integrity_service = IntegrityService(
259266
storage=pretend.stub(),
260267
metrics=metrics,
268+
session=db_request.db,
261269
)
262270
db_request.oidc_publisher = pretend.stub(
263271
publisher_name="GitHub",
@@ -283,6 +291,7 @@ def test_generate_provenance_fails_unsupported_publisher(self, db_request, metri
283291
integrity_service = IntegrityService(
284292
storage=pretend.stub(),
285293
metrics=pretend.stub(),
294+
session=db_request.db,
286295
)
287296

288297
db_request.oidc_publisher = pretend.stub(publisher_name="not-existing")
@@ -309,6 +318,7 @@ def test_generate_provenance_succeeds(
309318
integrity_service = IntegrityService(
310319
storage=storage_service,
311320
metrics=metrics,
321+
session=db_request.db,
312322
)
313323

314324
file = FileFactory.create()
@@ -361,6 +371,7 @@ def test_persist_provenance_succeeds(self, db_request, storage_service, metrics)
361371
integrity_service = IntegrityService(
362372
storage=storage_service,
363373
metrics=metrics,
374+
session=db_request.db,
364375
)
365376
file = FileFactory.create()
366377
assert integrity_service._persist_provenance(provenance, file) is None
@@ -376,6 +387,7 @@ def test_get_provenance_digest_succeeds(self, db_request, metrics, storage_servi
376387
integrity_service = IntegrityService(
377388
storage=storage_service,
378389
metrics=metrics,
390+
session=db_request.db,
379391
)
380392

381393
db_request.oidc_publisher = GitHubPublisherFactory.create()
@@ -396,6 +408,7 @@ def test_get_provenance_digest_fails_no_attestations(self, db_request):
396408
integrity_service = IntegrityService(
397409
storage=pretend.stub(),
398410
metrics=pretend.stub(),
411+
session=db_request.db,
399412
)
400413

401414
assert integrity_service.get_provenance_digest(file) is None

warehouse/attestations/services.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -111,17 +111,18 @@ def _extract_attestations_from_request(request: Request) -> list[Attestation]:
111111

112112
@implementer(IIntegrityService)
113113
class NullIntegrityService:
114-
def __init__(self):
114+
def __init__(self, session):
115115
warnings.warn(
116116
"NullIntegrityService is intended only for use in development, "
117117
"you should not use it in production due to the lack of actual "
118118
"attestation verification.",
119119
InsecureIntegrityServiceWarning,
120120
)
121+
self.db = session
121122

122123
@classmethod
123-
def create_service(cls, _context, _request):
124-
return cls()
124+
def create_service(cls, _context, request):
125+
return cls(session=request.db)
125126

126127
def parse_attestations(
127128
self, request: Request, _distribution: Distribution
@@ -165,19 +166,17 @@ def get_provenance_digest(self, file: File) -> str | None:
165166

166167
@implementer(IIntegrityService)
167168
class IntegrityService:
168-
def __init__(
169-
self,
170-
storage: IFileStorage,
171-
metrics: IMetricsService,
172-
):
169+
def __init__(self, storage: IFileStorage, metrics: IMetricsService, session):
173170
self.storage: IFileStorage = storage
174171
self.metrics: IMetricsService = metrics
172+
self.db = session
175173

176174
@classmethod
177-
def create_service(cls, _context, request: Request):
175+
def create_service(cls, _context, request):
178176
return cls(
179177
storage=request.find_service(IFileStorage, name="archive"),
180178
metrics=request.find_service(IMetricsService),
179+
session=request.db,
181180
)
182181

183182
def parse_attestations(

0 commit comments

Comments
 (0)