diff --git a/tests/unit/admin/test_routes.py b/tests/unit/admin/test_routes.py index 425d50167529..28538ad12dac 100644 --- a/tests/unit/admin/test_routes.py +++ b/tests/unit/admin/test_routes.py @@ -132,6 +132,11 @@ def test_includeme(): "/admin/checks/{check_name}/change_state", domain=warehouse, ), + pretend.call( + "admin.checks.run_backfill", + "/admin/checks/{check_name}/run_backfill", + domain=warehouse, + ), pretend.call("admin.verdicts.list", "/admin/verdicts/", domain=warehouse), pretend.call( "admin.verdicts.detail", "/admin/verdicts/{verdict_id}", domain=warehouse diff --git a/tests/unit/admin/views/test_checks.py b/tests/unit/admin/views/test_checks.py index 601e26e79858..273bb8cd3fe3 100644 --- a/tests/unit/admin/views/test_checks.py +++ b/tests/unit/admin/views/test_checks.py @@ -10,8 +10,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import uuid - import pretend import pytest @@ -67,6 +65,12 @@ def test_get_check_not_found(self, db_request): class TestChangeCheckState: + def test_no_check_state(self, db_request): + check = MalwareCheckFactory.create() + db_request.matchdict["check_name"] = check.name + with pytest.raises(HTTPNotFound): + views.change_check_state(db_request) + @pytest.mark.parametrize( ("final_state"), [MalwareCheckState.disabled, MalwareCheckState.wiped_out] ) @@ -75,7 +79,7 @@ def test_change_to_valid_state(self, db_request, final_state): name="MyCheck", state=MalwareCheckState.disabled ) - db_request.POST = {"id": check.id, "check_state": final_state.value} + db_request.POST = {"check_state": final_state.value} db_request.matchdict["check_name"] = check.name db_request.session = pretend.stub( @@ -107,7 +111,7 @@ def test_change_to_invalid_state(self, db_request): check = MalwareCheckFactory.create(name="MyCheck") initial_state = check.state invalid_check_state = "cancelled" - db_request.POST = {"id": check.id, "check_state": invalid_check_state} + db_request.POST = {"check_state": invalid_check_state} db_request.matchdict["check_name"] = check.name db_request.session = pretend.stub( @@ -124,13 +128,62 @@ def test_change_to_invalid_state(self, db_request): ] assert check.state == initial_state - def test_check_not_found(self, db_request): - db_request.POST = {"id": uuid.uuid4(), "check_state": "enabled"} - db_request.matchdict["check_name"] = "DoesNotExist" + +class TestRunBackfill: + @pytest.mark.parametrize( + ("check_state", "message"), + [ + ( + MalwareCheckState.disabled, + "Check must be in 'enabled' or 'evaluation' state to run a backfill.", + ), + ( + MalwareCheckState.wiped_out, + "Check must be in 'enabled' or 'evaluation' state to run a backfill.", + ), + ], + ) + def test_invalid_backfill_parameters(self, db_request, check_state, message): + check = MalwareCheckFactory.create(state=check_state) + db_request.matchdict["check_name"] = check.name + + db_request.session = pretend.stub( + flash=pretend.call_recorder(lambda *a, **kw: None) + ) db_request.route_path = pretend.call_recorder( - lambda *a, **kw: "/admin/checks/DoesNotExist/change_state" + lambda *a, **kw: "/admin/checks/%s/run_backfill" % check.name ) - with pytest.raises(HTTPNotFound): - views.change_check_state(db_request) + views.run_backfill(db_request) + + assert db_request.session.flash.calls == [pretend.call(message, queue="error")] + + def test_sucess(self, db_request): + check = MalwareCheckFactory.create(state=MalwareCheckState.enabled) + db_request.matchdict["check_name"] = check.name + + db_request.session = pretend.stub( + flash=pretend.call_recorder(lambda *a, **kw: None) + ) + + db_request.route_path = pretend.call_recorder( + lambda *a, **kw: "/admin/checks/%s/run_backfill" % check.name + ) + + backfill_recorder = pretend.stub( + delay=pretend.call_recorder(lambda *a, **kw: None) + ) + + db_request.task = pretend.call_recorder(lambda *a, **kw: backfill_recorder) + + views.run_backfill(db_request) + + assert db_request.session.flash.calls == [ + pretend.call( + "Running %s on 10000 %ss!" % (check.name, check.hooked_object.value), + queue="success", + ) + ] + + assert backfill_recorder.delay.calls == [pretend.call(check.name, 10000)] diff --git a/tests/unit/malware/test_tasks.py b/tests/unit/malware/test_tasks.py index 234d05dba5bb..91ce94966438 100644 --- a/tests/unit/malware/test_tasks.py +++ b/tests/unit/malware/test_tasks.py @@ -19,76 +19,112 @@ import warehouse.malware.checks as checks from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict -from warehouse.malware.tasks import remove_verdicts, run_check, sync_checks +from warehouse.malware.tasks import backfill, remove_verdicts, run_check, sync_checks from ...common.db.malware import MalwareCheckFactory, MalwareVerdictFactory from ...common.db.packaging import FileFactory, ProjectFactory, ReleaseFactory class TestRunCheck: - def test_success(self, monkeypatch, db_request): - project = ProjectFactory.create(name="foo") - release = ReleaseFactory.create(project=project) - file0 = FileFactory.create(release=release, filename="foo.bar") + def test_success(self, db_request): + file0 = FileFactory.create() MalwareCheckFactory.create(name="ExampleCheck", state=MalwareCheckState.enabled) - task = pretend.stub() run_check(task, db_request, "ExampleCheck", file0.id) + assert db_request.db.query(MalwareVerdict).one() - def test_missing_check_id(self, monkeypatch, db_session): - exc = NoResultFound("No row was found for one()") + def test_disabled_check(self, db_request): + MalwareCheckFactory.create( + name="ExampleCheck", state=MalwareCheckState.disabled + ) + task = pretend.stub() - class FakeMalwareCheck: - def __init__(self, db): - raise exc + with pytest.raises(NoResultFound): + run_check( + task, + db_request, + "ExampleCheck", + "d03d75d1-2511-4a8b-9759-62294a6fe3a7", + ) - checks.FakeMalwareCheck = FakeMalwareCheck + def test_missing_check(self, db_request): + task = pretend.stub() + with pytest.raises(AttributeError): + run_check( + task, + db_request, + "DoesNotExistCheck", + "d03d75d1-2511-4a8b-9759-62294a6fe3a7", + ) - class Task: - @staticmethod - @pretend.call_recorder - def retry(exc): - raise celery.exceptions.Retry + def test_retry(self, db_session, monkeypatch): + MalwareCheckFactory.create( + name="ExampleCheck", state=MalwareCheckState.evaluation + ) - task = Task() + exc = Exception("Scan failed") + def scan(self, file_id): + raise exc + + monkeypatch.setattr(checks.ExampleCheck, "scan", scan) + + task = pretend.stub( + retry=pretend.call_recorder(pretend.raiser(celery.exceptions.Retry)), + ) request = pretend.stub( db=db_session, - log=pretend.stub( - error=pretend.call_recorder(lambda *args, **kwargs: None), - ), + log=pretend.stub(error=pretend.call_recorder(lambda *args, **kwargs: None)), ) with pytest.raises(celery.exceptions.Retry): run_check( - task, - request, - "FakeMalwareCheck", - "d03d75d1-2511-4a8b-9759-62294a6fe3a7", + task, request, "ExampleCheck", "d03d75d1-2511-4a8b-9759-62294a6fe3a7" ) assert request.log.error.calls == [ - pretend.call( - "Error executing check %s: %s", - "FakeMalwareCheck", - "No row was found for one()", - ) + pretend.call("Error executing check ExampleCheck: Scan failed") ] assert task.retry.calls == [pretend.call(exc=exc)] - del checks.FakeMalwareCheck - def test_missing_check(self, db_request): +class TestBackfill: + def test_invalid_check_name(self, db_request): task = pretend.stub() with pytest.raises(AttributeError): - run_check( - task, - db_request, - "DoesNotExistCheck", - "d03d75d1-2511-4a8b-9759-62294a6fe3a7", - ) + backfill(task, db_request, "DoesNotExist", 1) + + @pytest.mark.parametrize( + ("num_objects", "num_runs"), [(11, 1), (11, 11), (101, 90)], + ) + def test_run(self, db_session, num_objects, num_runs): + files = [] + for i in range(num_objects): + files.append(FileFactory.create()) + + MalwareCheckFactory.create(name="ExampleCheck", state=MalwareCheckState.enabled) + enqueue_recorder = pretend.stub( + delay=pretend.call_recorder(lambda *a, **kw: None) + ) + task = pretend.call_recorder(lambda *args, **kwargs: enqueue_recorder) + + request = pretend.stub( + db=db_session, + log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None)), + task=task, + ) + + backfill(task, request, "ExampleCheck", num_runs) + + assert request.log.info.calls == [ + pretend.call("Running backfill on %d Files." % num_runs), + ] + + assert enqueue_recorder.delay.calls == [ + pretend.call("ExampleCheck", files[i].id) for i in range(num_runs) + ] class TestSyncChecks: diff --git a/warehouse/admin/routes.py b/warehouse/admin/routes.py index 284932c29190..2788f51519bf 100644 --- a/warehouse/admin/routes.py +++ b/warehouse/admin/routes.py @@ -139,6 +139,11 @@ def includeme(config): "/admin/checks/{check_name}/change_state", domain=warehouse, ) + config.add_route( + "admin.checks.run_backfill", + "/admin/checks/{check_name}/run_backfill", + domain=warehouse, + ) config.add_route("admin.verdicts.list", "/admin/verdicts/", domain=warehouse) config.add_route( "admin.verdicts.detail", "/admin/verdicts/{verdict_id}", domain=warehouse diff --git a/warehouse/admin/templates/admin/malware/checks/detail.html b/warehouse/admin/templates/admin/malware/checks/detail.html index 2cf80f20cc8a..77914488bb51 100644 --- a/warehouse/admin/templates/admin/malware/checks/detail.html +++ b/warehouse/admin/templates/admin/malware/checks/detail.html @@ -48,11 +48,10 @@