diff --git a/tests/unit/admin/test_routes.py b/tests/unit/admin/test_routes.py index 425d50167529..28538ad12dac 100644 --- a/tests/unit/admin/test_routes.py +++ b/tests/unit/admin/test_routes.py @@ -132,6 +132,11 @@ def test_includeme(): "/admin/checks/{check_name}/change_state", domain=warehouse, ), + pretend.call( + "admin.checks.run_backfill", + "/admin/checks/{check_name}/run_backfill", + domain=warehouse, + ), pretend.call("admin.verdicts.list", "/admin/verdicts/", domain=warehouse), pretend.call( "admin.verdicts.detail", "/admin/verdicts/{verdict_id}", domain=warehouse diff --git a/tests/unit/admin/views/test_checks.py b/tests/unit/admin/views/test_checks.py index 601e26e79858..273bb8cd3fe3 100644 --- a/tests/unit/admin/views/test_checks.py +++ b/tests/unit/admin/views/test_checks.py @@ -10,8 +10,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import uuid - import pretend import pytest @@ -67,6 +65,12 @@ def test_get_check_not_found(self, db_request): class TestChangeCheckState: + def test_no_check_state(self, db_request): + check = MalwareCheckFactory.create() + db_request.matchdict["check_name"] = check.name + with pytest.raises(HTTPNotFound): + views.change_check_state(db_request) + @pytest.mark.parametrize( ("final_state"), [MalwareCheckState.disabled, MalwareCheckState.wiped_out] ) @@ -75,7 +79,7 @@ def test_change_to_valid_state(self, db_request, final_state): name="MyCheck", state=MalwareCheckState.disabled ) - db_request.POST = {"id": check.id, "check_state": final_state.value} + db_request.POST = {"check_state": final_state.value} db_request.matchdict["check_name"] = check.name db_request.session = pretend.stub( @@ -107,7 +111,7 @@ def test_change_to_invalid_state(self, db_request): check = MalwareCheckFactory.create(name="MyCheck") initial_state = check.state invalid_check_state = "cancelled" - db_request.POST = {"id": check.id, "check_state": invalid_check_state} + db_request.POST = {"check_state": invalid_check_state} db_request.matchdict["check_name"] = check.name db_request.session = pretend.stub( @@ -124,13 +128,62 @@ def test_change_to_invalid_state(self, db_request): ] assert check.state == initial_state - def test_check_not_found(self, db_request): - db_request.POST = {"id": uuid.uuid4(), "check_state": "enabled"} - db_request.matchdict["check_name"] = "DoesNotExist" + +class TestRunBackfill: + @pytest.mark.parametrize( + ("check_state", "message"), + [ + ( + MalwareCheckState.disabled, + "Check must be in 'enabled' or 'evaluation' state to run a backfill.", + ), + ( + MalwareCheckState.wiped_out, + "Check must be in 'enabled' or 'evaluation' state to run a backfill.", + ), + ], + ) + def test_invalid_backfill_parameters(self, db_request, check_state, message): + check = MalwareCheckFactory.create(state=check_state) + db_request.matchdict["check_name"] = check.name + + db_request.session = pretend.stub( + flash=pretend.call_recorder(lambda *a, **kw: None) + ) db_request.route_path = pretend.call_recorder( - lambda *a, **kw: "/admin/checks/DoesNotExist/change_state" + lambda *a, **kw: "/admin/checks/%s/run_backfill" % check.name ) - with pytest.raises(HTTPNotFound): - views.change_check_state(db_request) + views.run_backfill(db_request) + + assert db_request.session.flash.calls == [pretend.call(message, queue="error")] + + def test_sucess(self, db_request): + check = MalwareCheckFactory.create(state=MalwareCheckState.enabled) + db_request.matchdict["check_name"] = check.name + + db_request.session = pretend.stub( + flash=pretend.call_recorder(lambda *a, **kw: None) + ) + + db_request.route_path = pretend.call_recorder( + lambda *a, **kw: "/admin/checks/%s/run_backfill" % check.name + ) + + backfill_recorder = pretend.stub( + delay=pretend.call_recorder(lambda *a, **kw: None) + ) + + db_request.task = pretend.call_recorder(lambda *a, **kw: backfill_recorder) + + views.run_backfill(db_request) + + assert db_request.session.flash.calls == [ + pretend.call( + "Running %s on 10000 %ss!" % (check.name, check.hooked_object.value), + queue="success", + ) + ] + + assert backfill_recorder.delay.calls == [pretend.call(check.name, 10000)] diff --git a/tests/unit/malware/test_tasks.py b/tests/unit/malware/test_tasks.py index 234d05dba5bb..91ce94966438 100644 --- a/tests/unit/malware/test_tasks.py +++ b/tests/unit/malware/test_tasks.py @@ -19,76 +19,112 @@ import warehouse.malware.checks as checks from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict -from warehouse.malware.tasks import remove_verdicts, run_check, sync_checks +from warehouse.malware.tasks import backfill, remove_verdicts, run_check, sync_checks from ...common.db.malware import MalwareCheckFactory, MalwareVerdictFactory from ...common.db.packaging import FileFactory, ProjectFactory, ReleaseFactory class TestRunCheck: - def test_success(self, monkeypatch, db_request): - project = ProjectFactory.create(name="foo") - release = ReleaseFactory.create(project=project) - file0 = FileFactory.create(release=release, filename="foo.bar") + def test_success(self, db_request): + file0 = FileFactory.create() MalwareCheckFactory.create(name="ExampleCheck", state=MalwareCheckState.enabled) - task = pretend.stub() run_check(task, db_request, "ExampleCheck", file0.id) + assert db_request.db.query(MalwareVerdict).one() - def test_missing_check_id(self, monkeypatch, db_session): - exc = NoResultFound("No row was found for one()") + def test_disabled_check(self, db_request): + MalwareCheckFactory.create( + name="ExampleCheck", state=MalwareCheckState.disabled + ) + task = pretend.stub() - class FakeMalwareCheck: - def __init__(self, db): - raise exc + with pytest.raises(NoResultFound): + run_check( + task, + db_request, + "ExampleCheck", + "d03d75d1-2511-4a8b-9759-62294a6fe3a7", + ) - checks.FakeMalwareCheck = FakeMalwareCheck + def test_missing_check(self, db_request): + task = pretend.stub() + with pytest.raises(AttributeError): + run_check( + task, + db_request, + "DoesNotExistCheck", + "d03d75d1-2511-4a8b-9759-62294a6fe3a7", + ) - class Task: - @staticmethod - @pretend.call_recorder - def retry(exc): - raise celery.exceptions.Retry + def test_retry(self, db_session, monkeypatch): + MalwareCheckFactory.create( + name="ExampleCheck", state=MalwareCheckState.evaluation + ) - task = Task() + exc = Exception("Scan failed") + def scan(self, file_id): + raise exc + + monkeypatch.setattr(checks.ExampleCheck, "scan", scan) + + task = pretend.stub( + retry=pretend.call_recorder(pretend.raiser(celery.exceptions.Retry)), + ) request = pretend.stub( db=db_session, - log=pretend.stub( - error=pretend.call_recorder(lambda *args, **kwargs: None), - ), + log=pretend.stub(error=pretend.call_recorder(lambda *args, **kwargs: None)), ) with pytest.raises(celery.exceptions.Retry): run_check( - task, - request, - "FakeMalwareCheck", - "d03d75d1-2511-4a8b-9759-62294a6fe3a7", + task, request, "ExampleCheck", "d03d75d1-2511-4a8b-9759-62294a6fe3a7" ) assert request.log.error.calls == [ - pretend.call( - "Error executing check %s: %s", - "FakeMalwareCheck", - "No row was found for one()", - ) + pretend.call("Error executing check ExampleCheck: Scan failed") ] assert task.retry.calls == [pretend.call(exc=exc)] - del checks.FakeMalwareCheck - def test_missing_check(self, db_request): +class TestBackfill: + def test_invalid_check_name(self, db_request): task = pretend.stub() with pytest.raises(AttributeError): - run_check( - task, - db_request, - "DoesNotExistCheck", - "d03d75d1-2511-4a8b-9759-62294a6fe3a7", - ) + backfill(task, db_request, "DoesNotExist", 1) + + @pytest.mark.parametrize( + ("num_objects", "num_runs"), [(11, 1), (11, 11), (101, 90)], + ) + def test_run(self, db_session, num_objects, num_runs): + files = [] + for i in range(num_objects): + files.append(FileFactory.create()) + + MalwareCheckFactory.create(name="ExampleCheck", state=MalwareCheckState.enabled) + enqueue_recorder = pretend.stub( + delay=pretend.call_recorder(lambda *a, **kw: None) + ) + task = pretend.call_recorder(lambda *args, **kwargs: enqueue_recorder) + + request = pretend.stub( + db=db_session, + log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None)), + task=task, + ) + + backfill(task, request, "ExampleCheck", num_runs) + + assert request.log.info.calls == [ + pretend.call("Running backfill on %d Files." % num_runs), + ] + + assert enqueue_recorder.delay.calls == [ + pretend.call("ExampleCheck", files[i].id) for i in range(num_runs) + ] class TestSyncChecks: diff --git a/warehouse/admin/routes.py b/warehouse/admin/routes.py index 284932c29190..2788f51519bf 100644 --- a/warehouse/admin/routes.py +++ b/warehouse/admin/routes.py @@ -139,6 +139,11 @@ def includeme(config): "/admin/checks/{check_name}/change_state", domain=warehouse, ) + config.add_route( + "admin.checks.run_backfill", + "/admin/checks/{check_name}/run_backfill", + domain=warehouse, + ) config.add_route("admin.verdicts.list", "/admin/verdicts/", domain=warehouse) config.add_route( "admin.verdicts.detail", "/admin/verdicts/{verdict_id}", domain=warehouse diff --git a/warehouse/admin/templates/admin/malware/checks/detail.html b/warehouse/admin/templates/admin/malware/checks/detail.html index 2cf80f20cc8a..77914488bb51 100644 --- a/warehouse/admin/templates/admin/malware/checks/detail.html +++ b/warehouse/admin/templates/admin/malware/checks/detail.html @@ -48,11 +48,10 @@

Revision History

Change State

-
- {% for state in states %}
+
+
+

Run Evaluation

+
+
+ +
+

Run this check against 10,000 {{ check.hooked_object.value }}s, selected at random. This is used to evaluate the efficacy of a check.

+
+ +
+
+
+
{% endblock %} diff --git a/warehouse/admin/views/checks.py b/warehouse/admin/views/checks.py index f1ad86685d94..7fd3fa9aacf0 100644 --- a/warehouse/admin/views/checks.py +++ b/warehouse/admin/views/checks.py @@ -15,7 +15,7 @@ from sqlalchemy.orm.exc import NoResultFound from warehouse.malware.models import MalwareCheck, MalwareCheckState -from warehouse.malware.tasks import remove_verdicts +from warehouse.malware.tasks import backfill, remove_verdicts @view_config( @@ -26,7 +26,7 @@ uses_session=True, ) def get_checks(request): - all_checks = request.db.query(MalwareCheck).all() + all_checks = request.db.query(MalwareCheck) active_checks = [] for check in all_checks: if not check.is_stale: @@ -43,19 +43,49 @@ def get_checks(request): uses_session=True, ) def get_check(request): - query = ( + check = get_check_by_name(request.db, request.matchdict["check_name"]) + + all_checks = ( request.db.query(MalwareCheck) .filter(MalwareCheck.name == request.matchdict["check_name"]) .order_by(MalwareCheck.version.desc()) + .all() ) - try: - # Throw an exception if and only if no results are returned. - newest = query.limit(1).one() - except NoResultFound: - raise HTTPNotFound + return {"check": check, "checks": all_checks, "states": MalwareCheckState} - return {"check": newest, "checks": query.all(), "states": MalwareCheckState} + +@view_config( + route_name="admin.checks.run_backfill", + permission="admin", + request_method="POST", + uses_session=True, + require_methods=False, + require_csrf=True, +) +def run_backfill(request): + check = get_check_by_name(request.db, request.matchdict["check_name"]) + num_objects = 10000 + + if check.state not in (MalwareCheckState.enabled, MalwareCheckState.evaluation): + request.session.flash( + f"Check must be in 'enabled' or 'evaluation' state to run a backfill.", + queue="error", + ) + return HTTPSeeOther( + request.route_path("admin.checks.detail", check_name=check.name) + ) + + request.session.flash( + f"Running {check.name} on {num_objects} {check.hooked_object.value}s!", + queue="success", + ) + + request.task(backfill).delay(check.name, num_objects) + + return HTTPSeeOther( + request.route_path("admin.checks.detail", check_name=check.name) + ) @view_config( @@ -67,18 +97,16 @@ def get_check(request): require_csrf=True, ) def change_check_state(request): + check = get_check_by_name(request.db, request.matchdict["check_name"]) + try: - check = ( - request.db.query(MalwareCheck) - .filter(MalwareCheck.id == request.POST["id"]) - .one() - ) - except NoResultFound: + check_state = request.POST["check_state"] + except KeyError: raise HTTPNotFound try: - check.state = getattr(MalwareCheckState, request.POST["check_state"]) - except (AttributeError, KeyError): + check.state = getattr(MalwareCheckState, check_state) + except AttributeError: request.session.flash("Invalid check state provided.", queue="error") else: if check.state == MalwareCheckState.wiped_out: @@ -90,3 +118,19 @@ def change_check_state(request): return HTTPSeeOther( request.route_path("admin.checks.detail", check_name=check.name) ) + + +def get_check_by_name(db, check_name): + try: + # Throw an exception if and only if no results are returned. + newest = ( + db.query(MalwareCheck) + .filter(MalwareCheck.name == check_name) + .order_by(MalwareCheck.version.desc()) + .limit(1) + .one() + ) + except NoResultFound: + raise HTTPNotFound + + return newest diff --git a/warehouse/malware/checks/base.py b/warehouse/malware/checks/base.py index 6c94937abbde..72406954c982 100644 --- a/warehouse/malware/checks/base.py +++ b/warehouse/malware/checks/base.py @@ -46,6 +46,10 @@ def _load_check_id(self): self.id = ( self.db.query(MalwareCheck.id) .filter(MalwareCheck.name == self._name) - .filter(MalwareCheck.state == MalwareCheckState.enabled) + .filter( + MalwareCheck.state.in_( + [MalwareCheckState.enabled, MalwareCheckState.evaluation] + ) + ) .one() ) diff --git a/warehouse/malware/tasks.py b/warehouse/malware/tasks.py index a70b7ea344d6..5028ed00e7ef 100644 --- a/warehouse/malware/tasks.py +++ b/warehouse/malware/tasks.py @@ -13,22 +13,38 @@ import inspect import warehouse.malware.checks as checks +import warehouse.packaging.models as packaging_models from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict from warehouse.malware.utils import get_check_fields from warehouse.tasks import task -@task(bind=True, ignore_result=True, acks_late=True) +@task(bind=True, ignore_result=True, acks_late=True, retry_backoff=True) def run_check(task, request, check_name, obj_id): + check = getattr(checks, check_name)(request.db) try: - check = getattr(checks, check_name)(request.db) check.run(obj_id) except Exception as exc: - request.log.error("Error executing check %s: %s", check_name, str(exc)) + request.log.error("Error executing check %s: %s" % (check_name, str(exc))) raise task.retry(exc=exc) +@task(bind=True, ignore_result=True, acks_late=True) +def backfill(task, request, check_name, num_objects): + """ + Runs a backfill on a fixed number of objects. + """ + check = getattr(checks, check_name)(request.db) + target_object = getattr(packaging_models, check.hooked_object) + query = request.db.query(target_object.id).limit(num_objects) + + request.log.info("Running backfill on %d %ss." % (num_objects, check.hooked_object)) + + for (elem_id,) in query: + request.task(run_check).delay(check_name, elem_id) + + @task(bind=True, ignore_result=True, acks_late=True) def sync_checks(task, request): code_checks = inspect.getmembers(checks, inspect.isclass)