Skip to content

Commit c1370bf

Browse files
xmunozewdurbin
authored andcommitted
Misc cleanup and TODOs on malware checks. (#7355)
* Misc cleanup and TODOs on malware checks. - Change backfill function to invoke `IMalwareCheckService` interface - Add support for `kwargs to `IMalwareCheckService` interface - Rename variable from reserved word `file` to `release_file` - Add `FatalCheckException` for non-retryable exceptions - Replace `MALWARE_CHECK_BACKEND` in dev/environment * Make `IMalwareService` the entrypoint for `run_check` - Add `run_scheduled_check` task that invokes this interface. - Remove useless utility method - Move `FatalCheckException` into warehouse/malware/errors.py.
1 parent 1d3791b commit c1370bf

File tree

15 files changed

+190
-68
lines changed

15 files changed

+190
-68
lines changed

dev/environment

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@ MAIL_BACKEND=warehouse.email.services.SMTPEmailSender host=smtp port=2525 ssl=fa
2929

3030
BREACHED_PASSWORDS=warehouse.accounts.NullPasswordBreachedService
3131

32-
#TODO: change this to PrinterMalwareCheckService before deploy
33-
MALWARE_CHECK_BACKEND=warehouse.malware.services.DatabaseMalwareCheckService
32+
MALWARE_CHECK_BACKEND=warehouse.malware.services.PrinterMalwareCheckService
3433

3534
METRICS_BACKEND=warehouse.metrics.DataDogMetrics host=notdatadog
3635

tests/common/checks/hooked.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# limitations under the License.
1212

1313
from warehouse.malware.checks.base import MalwareCheckBase
14+
from warehouse.malware.errors import FatalCheckException
1415
from warehouse.malware.models import VerdictClassification, VerdictConfidence
1516

1617

@@ -29,7 +30,7 @@ def __init__(self, db):
2930
def scan(self, **kwargs):
3031
file_id = kwargs.get("obj_id")
3132
if file_id is None:
32-
return
33+
raise FatalCheckException("Missing required kwarg `obj_id`")
3334

3435
self.add_verdict(
3536
file_id=file_id,

tests/unit/admin/views/test_checks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from warehouse.admin.views import checks as views
1919
from warehouse.malware.models import MalwareCheckState, MalwareCheckType
20-
from warehouse.malware.tasks import backfill, run_check
20+
from warehouse.malware.tasks import backfill, run_scheduled_check
2121

2222
from ....common.db.malware import MalwareCheckFactory
2323

@@ -208,7 +208,7 @@ def test_success(self, db_request, check_type):
208208
assert db_request.session.flash.calls == [
209209
pretend.call("Running %s now!" % check.name, queue="success",)
210210
]
211-
assert db_request.task.calls == [pretend.call(run_check)]
211+
assert db_request.task.calls == [pretend.call(run_scheduled_check)]
212212
assert backfill_recorder.delay.calls == [
213213
pretend.call(check.name, manually_triggered=True)
214214
]

tests/unit/malware/checks/setup_patterns/test_check.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,8 @@ def test_scan_missing_kwargs(db_session, obj, file_url):
4343
name="SetupPatternCheck", state=MalwareCheckState.Enabled
4444
)
4545
check = c.SetupPatternCheck(db_session)
46-
check.scan(obj=obj, file_url=file_url)
47-
48-
assert check._verdicts == []
46+
with pytest.raises(c.FatalCheckException):
47+
check.scan(obj=obj, file_url=file_url)
4948

5049

5150
def test_scan_non_sdist(db_session):

tests/unit/malware/test_init.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,8 @@
1717
from celery.schedules import crontab
1818

1919
from warehouse import malware
20-
from warehouse.malware import utils
2120
from warehouse.malware.interfaces import IMalwareCheckService
22-
from warehouse.malware.tasks import run_check
21+
from warehouse.malware.tasks import run_scheduled_check
2322

2423
from ...common import checks as test_checks
2524
from ...common.db.accounts import UserFactory
@@ -30,7 +29,7 @@ def test_determine_malware_checks_no_checks(monkeypatch, db_request):
3029
def get_enabled_hooked_checks(session):
3130
return defaultdict(list)
3231

33-
monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks)
32+
monkeypatch.setattr(malware, "get_enabled_hooked_checks", get_enabled_hooked_checks)
3433

3534
project = ProjectFactory.create(name="foo")
3635
release = ReleaseFactory.create(project=project)
@@ -49,7 +48,7 @@ def get_enabled_hooked_checks(session):
4948
result["Release"] = ["Check3"]
5049
return result
5150

52-
monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks)
51+
monkeypatch.setattr(malware, "get_enabled_hooked_checks", get_enabled_hooked_checks)
5352

5453
project = ProjectFactory.create(name="foo")
5554
release = ReleaseFactory.create(project=project)
@@ -68,7 +67,7 @@ def get_enabled_hooked_checks(session):
6867
result["Release"] = ["Check3"]
6968
return result
7069

71-
monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks)
70+
monkeypatch.setattr(malware, "get_enabled_hooked_checks", get_enabled_hooked_checks)
7271

7372
user = UserFactory.create()
7473

@@ -85,7 +84,7 @@ def get_enabled_hooked_checks(session):
8584
result["Release"] = ["Check3"]
8685
return result
8786

88-
monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks)
87+
monkeypatch.setattr(malware, "get_enabled_hooked_checks", get_enabled_hooked_checks)
8988

9089
project = ProjectFactory.create(name="foo")
9190
release = ReleaseFactory.create(project=project)
@@ -105,7 +104,7 @@ def get_enabled_hooked_checks(session):
105104
result["Release"] = ["Check3"]
106105
return result
107106

108-
monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks)
107+
monkeypatch.setattr(malware, "get_enabled_hooked_checks", get_enabled_hooked_checks)
109108

110109
project = ProjectFactory.create(name="foo")
111110
release = ReleaseFactory.create(project=project)
@@ -179,6 +178,8 @@ def test_includeme(monkeypatch):
179178

180179
assert config.add_periodic_task.calls == [
181180
pretend.call(
182-
crontab(minute="0", hour="*/8"), run_check, args=("ExampleScheduledCheck",)
181+
crontab(minute="0", hour="*/8"),
182+
run_scheduled_check,
183+
args=("ExampleScheduledCheck",),
183184
)
184185
]

tests/unit/malware/test_services.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# limitations under the License.
1212

1313
import pretend
14+
import pytest
1415

1516
from zope.interface.verify import verifyClass
1617

@@ -31,13 +32,14 @@ def test_create_service(self):
3132
service = PrinterMalwareCheckService.create_service(None, request)
3233
assert service.executor == print
3334

34-
def test_run_checks(self, capfd):
35+
@pytest.mark.parametrize(("kwargs"), [{}, {"manually_triggered": True}])
36+
def test_run_checks(self, capfd, kwargs):
3537
request = pretend.stub()
3638
service = PrinterMalwareCheckService.create_service(None, request)
3739
checks = ["one", "two", "three"]
38-
service.run_checks(checks)
40+
service.run_checks(checks, **kwargs)
3941
out, err = capfd.readouterr()
40-
assert out == "one\ntwo\nthree\n"
42+
assert out == "".join(["%s %s\n" % (check, kwargs) for check in checks])
4143

4244

4345
class TestDatabaseMalwareService:
@@ -50,12 +52,36 @@ def test_create_service(self, db_request):
5052
service = DatabaseMalwareCheckService.create_service(None, db_request)
5153
assert service.executor == db_request.task(run_check).delay
5254

53-
def test_run_checks(self, db_request):
54-
_delay = pretend.call_recorder(lambda *args: None)
55+
def test_run_hooked_check(self, db_request):
56+
_delay = pretend.call_recorder(lambda *args, **kwargs: None)
57+
db_request.task = lambda x: pretend.stub(delay=_delay)
58+
service = DatabaseMalwareCheckService.create_service(None, db_request)
59+
checks = [
60+
"MyTestCheck:ba70267f-fabf-496f-9ac2-d237a983b187",
61+
"AnotherCheck:44f57b0e-c5b0-47c5-8713-341cf392efe2",
62+
"FinalCheck:e8518a15-8f01-430e-8f5b-87644007c9c0",
63+
]
64+
service.run_checks(checks)
65+
assert _delay.calls == [
66+
pretend.call("MyTestCheck", obj_id="ba70267f-fabf-496f-9ac2-d237a983b187"),
67+
pretend.call("AnotherCheck", obj_id="44f57b0e-c5b0-47c5-8713-341cf392efe2"),
68+
pretend.call("FinalCheck", obj_id="e8518a15-8f01-430e-8f5b-87644007c9c0"),
69+
]
70+
71+
def test_run_scheduled_check(self, db_request):
72+
_delay = pretend.call_recorder(lambda *args, **kwargs: None)
5573
db_request.task = lambda x: pretend.stub(delay=_delay)
5674
service = DatabaseMalwareCheckService.create_service(None, db_request)
57-
checks = ["MyTestCheck:ba70267f-fabf-496f-9ac2-d237a983b187"]
75+
checks = ["MyTestScheduledCheck"]
5876
service.run_checks(checks)
77+
assert _delay.calls == [pretend.call("MyTestScheduledCheck")]
78+
79+
def test_run_triggered_check(self, db_request):
80+
_delay = pretend.call_recorder(lambda *args, **kwargs: None)
81+
db_request.task = lambda x: pretend.stub(delay=_delay)
82+
service = DatabaseMalwareCheckService.create_service(None, db_request)
83+
checks = ["MyTriggeredCheck"]
84+
service.run_checks(checks, manually_triggered=True)
5985
assert _delay.calls == [
60-
pretend.call("MyTestCheck", "ba70267f-fabf-496f-9ac2-d237a983b187")
86+
pretend.call("MyTriggeredCheck", manually_triggered=True)
6187
]

tests/unit/malware/test_tasks.py

Lines changed: 77 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from warehouse.malware import tasks
1818
from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict
19+
from warehouse.malware.services import PrinterMalwareCheckService
1920

2021
from ...common import checks as test_checks
2122
from ...common.db.malware import MalwareCheckFactory, MalwareVerdictFactory
@@ -101,6 +102,28 @@ def test_missing_check(self, db_request, monkeypatch):
101102
task, db_request, "DoesNotExistCheck",
102103
)
103104

105+
def test_missing_obj_id(self, db_session, monkeypatch):
106+
monkeypatch.setattr(tasks, "checks", test_checks)
107+
task = pretend.stub()
108+
109+
MalwareCheckFactory.create(
110+
name="ExampleHookedCheck", state=MalwareCheckState.Enabled
111+
)
112+
task = pretend.stub()
113+
114+
request = pretend.stub(
115+
db=db_session,
116+
log=pretend.stub(error=pretend.call_recorder(lambda *args, **kwargs: None)),
117+
)
118+
119+
tasks.run_check(task, request, "ExampleHookedCheck")
120+
121+
assert request.log.error.calls == [
122+
pretend.call(
123+
"Fatal exception: ExampleHookedCheck: Missing required kwarg `obj_id`"
124+
)
125+
]
126+
104127
def test_retry(self, db_session, monkeypatch):
105128
monkeypatch.setattr(tasks, "checks", test_checks)
106129
exc = Exception("Scan failed")
@@ -135,6 +158,37 @@ def scan(self, **kwargs):
135158
assert task.retry.calls == [pretend.call(exc=exc)]
136159

137160

161+
class TestRunScheduledCheck:
162+
def test_invalid_check_name(self, db_request, monkeypatch):
163+
monkeypatch.setattr(tasks, "checks", test_checks)
164+
task = pretend.stub()
165+
with pytest.raises(AttributeError):
166+
tasks.run_scheduled_check(task, db_request, "DoesNotExist")
167+
168+
def test_run_check(self, db_session, capfd, monkeypatch):
169+
MalwareCheckFactory.create(
170+
name="ExampleScheduledCheck", state=MalwareCheckState.Enabled
171+
)
172+
173+
request = pretend.stub(
174+
db=db_session,
175+
find_service_factory=pretend.call_recorder(
176+
lambda interface: PrinterMalwareCheckService.create_service
177+
),
178+
)
179+
180+
task = pretend.stub()
181+
182+
tasks.run_scheduled_check(task, request, "ExampleScheduledCheck")
183+
184+
assert request.find_service_factory.calls == [
185+
pretend.call(tasks.IMalwareCheckService)
186+
]
187+
188+
out, err = capfd.readouterr()
189+
assert out == "ExampleScheduledCheck {'manually_triggered': False}\n"
190+
191+
138192
class TestBackfill:
139193
def test_invalid_check_name(self, db_request, monkeypatch):
140194
monkeypatch.setattr(tasks, "checks", test_checks)
@@ -145,33 +199,47 @@ def test_invalid_check_name(self, db_request, monkeypatch):
145199
@pytest.mark.parametrize(
146200
("num_objects", "num_runs"), [(11, 1), (11, 11), (101, 90)]
147201
)
148-
def test_run(self, db_session, num_objects, num_runs, monkeypatch):
202+
def test_run(self, db_session, capfd, num_objects, num_runs, monkeypatch):
149203
monkeypatch.setattr(tasks, "checks", test_checks)
204+
205+
ids = []
150206
for i in range(num_objects):
151-
FileFactory.create()
207+
ids.append(FileFactory.create().id)
152208

153209
MalwareCheckFactory.create(
154210
name="ExampleHookedCheck", state=MalwareCheckState.Enabled
155211
)
156212

157-
enqueue_recorder = pretend.stub(
158-
delay=pretend.call_recorder(lambda *a, **kw: None)
159-
)
160-
task = pretend.call_recorder(lambda *args, **kwargs: enqueue_recorder)
161-
162213
request = pretend.stub(
163214
db=db_session,
164215
log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None)),
165-
task=task,
216+
find_service_factory=pretend.call_recorder(
217+
lambda interface: PrinterMalwareCheckService.create_service
218+
),
166219
)
167220

221+
task = pretend.stub()
222+
168223
tasks.backfill(task, request, "ExampleHookedCheck", num_runs)
169224

170225
assert request.log.info.calls == [
171226
pretend.call("Running backfill on %d Files." % num_runs)
172227
]
173228

174-
assert len(enqueue_recorder.delay.calls) == num_runs
229+
assert request.find_service_factory.calls == [
230+
pretend.call(tasks.IMalwareCheckService)
231+
]
232+
233+
out, err = capfd.readouterr()
234+
num_output_lines = 0
235+
for file_id in ids:
236+
logged_output = "ExampleHookedCheck:%s %s\n" % (
237+
file_id,
238+
{"manually_triggered": True},
239+
)
240+
num_output_lines += 1 if logged_output in out else 0
241+
242+
assert num_output_lines == num_runs
175243

176244

177245
class TestSyncChecks:

warehouse/admin/views/checks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from sqlalchemy.orm.exc import NoResultFound
1616

1717
from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareCheckType
18-
from warehouse.malware.tasks import backfill, remove_verdicts, run_check
18+
from warehouse.malware.tasks import backfill, remove_verdicts, run_scheduled_check
1919

2020
EVALUATION_RUN_SIZE = 10000
2121

@@ -94,7 +94,7 @@ def run_evaluation(request):
9494

9595
else:
9696
request.session.flash(f"Running {check.name} now!", queue="success")
97-
request.task(run_check).delay(check.name, manually_triggered=True)
97+
request.task(run_scheduled_check).delay(check.name, manually_triggered=True)
9898

9999
return HTTPSeeOther(
100100
request.route_path("admin.checks.detail", check_name=check.name)

warehouse/malware/__init__.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717
import warehouse.malware.checks as checks
1818

1919
from warehouse import db
20-
from warehouse.malware import utils
2120
from warehouse.malware.interfaces import IMalwareCheckService
22-
from warehouse.malware.tasks import run_check
21+
from warehouse.malware.models import MalwareCheckObjectType
22+
from warehouse.malware.tasks import run_scheduled_check
23+
from warehouse.malware.utils import get_enabled_hooked_checks
2324

2425

2526
@db.listens_for(db.Session, "after_flush")
@@ -31,13 +32,13 @@ def determine_malware_checks(config, session, flush_context):
3132
[
3233
obj.__class__.__name__
3334
for obj in session.new
34-
if obj.__class__.__name__ in utils.valid_check_types()
35+
if obj.__class__.__name__ in MalwareCheckObjectType.__members__
3536
]
3637
):
3738
return
3839

3940
malware_checks = session.info.setdefault("warehouse.malware.checks", set())
40-
enabled_checks = utils.get_enabled_hooked_checks(session)
41+
enabled_checks = get_enabled_hooked_checks(session)
4142
for obj in session.new:
4243
for check_name in enabled_checks.get(obj.__class__.__name__, []):
4344
malware_checks.update([f"{check_name}:{obj.id}"])
@@ -71,5 +72,5 @@ def includeme(config):
7172
check = check_obj[1]
7273
if check.check_type == "scheduled":
7374
config.add_periodic_task(
74-
crontab(**check.schedule), run_check, args=(check_obj[0],)
75+
crontab(**check.schedule), run_scheduled_check, args=(check_obj[0],)
7576
)

0 commit comments

Comments
 (0)