From b051793420f5a96fd28a30a4f66390522508e6ae Mon Sep 17 00:00:00 2001 From: Cristina Date: Thu, 19 Dec 2019 11:26:55 -0800 Subject: [PATCH 01/16] Add new models for malware detection. (#7118) * Add new models for malware detection. Fixes #7090 and #7092. * Code review changes. - FK on release_file.id field instead of md5 - Change message type from String to Text - Change Enum class in model to singular form --- warehouse/malware/__init__.py | 11 ++ warehouse/malware/models.py | 116 ++++++++++++++++++ ...1ff3d24c22_add_malware_detection_tables.py | 110 +++++++++++++++++ 3 files changed, 237 insertions(+) create mode 100644 warehouse/malware/__init__.py create mode 100644 warehouse/malware/models.py create mode 100644 warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py diff --git a/warehouse/malware/__init__.py b/warehouse/malware/__init__.py new file mode 100644 index 000000000000..164f68b09175 --- /dev/null +++ b/warehouse/malware/__init__.py @@ -0,0 +1,11 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/warehouse/malware/models.py b/warehouse/malware/models.py new file mode 100644 index 000000000000..fc6576c7dfcd --- /dev/null +++ b/warehouse/malware/models.py @@ -0,0 +1,116 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum + +from citext import CIText +from sqlalchemy import ( + Boolean, + Column, + DateTime, + Enum, + ForeignKey, + Integer, + String, + Text, + UniqueConstraint, + orm, + sql, +) +from sqlalchemy.dialects.postgresql import JSONB + +from warehouse import db +from warehouse.utils.attrs import make_repr + + +class MalwareCheckType(enum.Enum): + + EventHook = "event_hook" + Scheduled = "scheduled" + + +class MalwareCheckState(enum.Enum): + + Enabled = "enabled" + Evaluation = "evaluation" + Disabled = "disabled" + WipedOut = "wiped_out" + + +class VerdictClassification(enum.Enum): + + Threat = "threat" + Indeterminate = "indeterminate" + Benign = "benign" + + +class VerdictConfidence(enum.Enum): + + Low = "low" + Medium = "medium" + High = "high" + + +class MalwareCheck(db.Model): + + __tablename__ = "malware_checks" + __table_args__ = (UniqueConstraint("name", "version"),) + __repr__ = make_repr("name", "version") + + name = Column(CIText, nullable=False) + version = Column(Integer, default=0, nullable=False) + short_description = Column(String(length=128), nullable=False) + long_description = Column(Text, nullable=False) + check_type = Column( + Enum(MalwareCheckType, values_callable=lambda x: [e.value for e in x]), + nullable=False, + ) + # This field contains the same content as the ProjectEvent and UserEvent "tag" + # fields. + hook_name = Column(String, nullable=True) + state = Column( + Enum(MalwareCheckState, values_callable=lambda x: [e.value for e in x]), + nullable=False, + server_default=("disabled"), + ) + created = Column(DateTime, nullable=False, server_default=sql.func.now()) + + +class MalwareVerdict(db.Model): + __tablename__ = "malware_verdicts" + + run_date = Column(DateTime, nullable=False, server_default=sql.func.now()) + check_id = Column( + ForeignKey("malware_checks.id", onupdate="CASCADE", ondelete="CASCADE"), + nullable=False, + index=True, + ) + file_id = Column(ForeignKey("release_files.id"), nullable=False) + classification = Column( + Enum(VerdictClassification, values_callable=lambda x: [e.value for e in x]), + nullable=False, + ) + confidence = Column( + Enum(VerdictConfidence, values_callable=lambda x: [e.value for e in x]), + nullable=False, + ) + message = Column(Text, nullable=True) + details = Column(JSONB, nullable=True) + manually_reviewed = Column(Boolean, nullable=False, server_default=sql.false()) + administrator_verdict = Column( + Enum(VerdictClassification, values_callable=lambda x: [e.value for e in x]), + nullable=True, + ) + full_report_link = Column(String, nullable=True) + + check = orm.relationship("MalwareCheck", foreign_keys=[check_id], lazy=True) + release_file = orm.relationship("File", foreign_keys=[file_id], lazy=True) diff --git a/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py b/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py new file mode 100644 index 000000000000..899cc51e0f57 --- /dev/null +++ b/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py @@ -0,0 +1,110 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Add malware detection tables + +Revision ID: 061ff3d24c22 +Revises: b5bb5d08543d +Create Date: 2019-12-18 17:27:00.183542 +""" +import citext +import sqlalchemy as sa + +from alembic import op +from sqlalchemy.dialects import postgresql + +revision = "061ff3d24c22" +down_revision = "b5bb5d08543d" + +MalwareCheckTypes = sa.Enum("event_hook", "scheduled", name="malwarechecktypes") + +MalwareCheckStates = sa.Enum( + "enabled", "evaluation", "disabled", "wiped_out", name="malwarecheckstate" +) + +VerdictClassifications = sa.Enum( + "threat", "indeterminate", "benign", name="verdictclassification" +) +VerdictConfidences = sa.Enum("low", "medium", "high", name="verdictconfidence") + + +def upgrade(): + op.create_table( + "malware_checks", + sa.Column( + "id", + postgresql.UUID(as_uuid=True), + server_default=sa.text("gen_random_uuid()"), + nullable=False, + ), + sa.Column("name", citext.CIText(), nullable=False), + sa.Column("version", sa.Integer(), nullable=False), + sa.Column("short_description", sa.String(length=128), nullable=False), + sa.Column("long_description", sa.Text(), nullable=False), + sa.Column("check_type", MalwareCheckTypes, nullable=False), + sa.Column("hook_name", sa.String(), nullable=True), + sa.Column( + "state", MalwareCheckStates, server_default="disabled", nullable=False, + ), + sa.Column( + "created", sa.DateTime(), server_default=sa.text("now()"), nullable=False + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("name", "version"), + ) + op.create_table( + "malware_verdicts", + sa.Column( + "id", + postgresql.UUID(as_uuid=True), + server_default=sa.text("gen_random_uuid()"), + nullable=False, + ), + sa.Column( + "run_date", sa.DateTime(), server_default=sa.text("now()"), nullable=False + ), + sa.Column("check_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("file_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("classification", VerdictClassifications, nullable=False,), + sa.Column("confidence", VerdictConfidences, nullable=False,), + sa.Column("message", sa.Text(), nullable=True), + sa.Column("details", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column( + "manually_reviewed", + sa.Boolean(), + server_default=sa.text("false"), + nullable=False, + ), + sa.Column("administrator_verdict", VerdictClassifications, nullable=True,), + sa.Column("full_report_link", sa.String(), nullable=True), + sa.ForeignKeyConstraint( + ["check_id"], ["malware_checks.id"], onupdate="CASCADE", ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(["file_id"], ["release_files.id"]), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_malware_verdicts_check_id"), + "malware_verdicts", + ["check_id"], + unique=False, + ) + + +def downgrade(): + op.drop_index(op.f("ix_malware_verdicts_check_id"), table_name="malware_verdicts") + op.drop_table("malware_verdicts") + op.drop_table("malware_checks") + MalwareCheckTypes.drop(op.get_bind()) + MalwareCheckStates.drop(op.get_bind()) + VerdictClassifications.drop(op.get_bind()) + VerdictConfidences.drop(op.get_bind()) From ce21ebfd81ae15c136e1ab8cb80692b67a2326c4 Mon Sep 17 00:00:00 2001 From: Cristina Date: Thu, 26 Dec 2019 06:03:27 -0800 Subject: [PATCH 02/16] Add admin interface to view and enable checks (#7134) * Add admin interface to view and enable checks - Implement list, detail and change_state views (#7133) - Add unit tests for check admin view * Add comprehensive test coverage for check admin --- tests/common/db/malware.py | 40 ++++++ tests/unit/admin/test_routes.py | 9 ++ tests/unit/admin/views/test_checks.py | 122 ++++++++++++++++++ warehouse/admin/routes.py | 11 ++ warehouse/admin/templates/admin/base.html | 5 + .../admin/malware/checks/detail.html | 70 ++++++++++ .../templates/admin/malware/checks/index.html | 57 ++++++++ warehouse/admin/views/checks.py | 89 +++++++++++++ warehouse/malware/models.py | 31 +++-- ...1ff3d24c22_add_malware_detection_tables.py | 2 +- 10 files changed, 425 insertions(+), 11 deletions(-) create mode 100644 tests/common/db/malware.py create mode 100644 tests/unit/admin/views/test_checks.py create mode 100644 warehouse/admin/templates/admin/malware/checks/detail.html create mode 100644 warehouse/admin/templates/admin/malware/checks/index.html create mode 100644 warehouse/admin/views/checks.py diff --git a/tests/common/db/malware.py b/tests/common/db/malware.py new file mode 100644 index 000000000000..263fa82fa684 --- /dev/null +++ b/tests/common/db/malware.py @@ -0,0 +1,40 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime + +import factory +import factory.fuzzy + +from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareCheckType + +from .base import WarehouseFactory + + +class MalwareCheckFactory(WarehouseFactory): + class Meta: + model = MalwareCheck + + name = factory.fuzzy.FuzzyText(length=12) + version = 1 + short_description = factory.fuzzy.FuzzyText(length=80) + long_description = factory.fuzzy.FuzzyText(length=300) + check_type = factory.fuzzy.FuzzyChoice([e for e in MalwareCheckType]) + hook_name = ( + "project:release:file:upload" + if check_type == MalwareCheckType.event_hook + else None + ) + state = factory.fuzzy.FuzzyChoice([e for e in MalwareCheckState]) + created = factory.fuzzy.FuzzyNaiveDateTime( + datetime.datetime.utcnow() - datetime.timedelta(days=7) + ) diff --git a/tests/unit/admin/test_routes.py b/tests/unit/admin/test_routes.py index 2600365a7ed6..e10962ac54dc 100644 --- a/tests/unit/admin/test_routes.py +++ b/tests/unit/admin/test_routes.py @@ -123,4 +123,13 @@ def test_includeme(): pretend.call("admin.flags.edit", "/admin/flags/edit/", domain=warehouse), pretend.call("admin.squats", "/admin/squats/", domain=warehouse), pretend.call("admin.squats.review", "/admin/squats/review/", domain=warehouse), + pretend.call("admin.checks.list", "/admin/checks/", domain=warehouse), + pretend.call( + "admin.checks.detail", "/admin/checks/{check_name}", domain=warehouse + ), + pretend.call( + "admin.checks.change_state", + "/admin/checks/{check_name}/change_state", + domain=warehouse, + ), ] diff --git a/tests/unit/admin/views/test_checks.py b/tests/unit/admin/views/test_checks.py new file mode 100644 index 000000000000..55a45b2d5ac0 --- /dev/null +++ b/tests/unit/admin/views/test_checks.py @@ -0,0 +1,122 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import uuid + +import pretend +import pytest + +from pyramid.httpexceptions import HTTPNotFound + +from warehouse.admin.views import checks as views +from warehouse.malware.models import MalwareCheckState + +from ....common.db.malware import MalwareCheckFactory + + +class TestListChecks: + def test_get_checks_none(self, db_request): + assert views.get_checks(db_request) == {"checks": []} + + def test_get_checks(self, db_request): + checks = [MalwareCheckFactory.create() for _ in range(10)] + assert views.get_checks(db_request) == {"checks": checks} + + def test_get_checks_different_versions(self, db_request): + checks = [MalwareCheckFactory.create() for _ in range(5)] + checks_same = [ + MalwareCheckFactory.create(name="MyCheck", version=i) for i in range(1, 6) + ] + checks.append(checks_same[-1]) + assert views.get_checks(db_request) == {"checks": checks} + + +class TestGetCheck: + def test_get_check(self, db_request): + check = MalwareCheckFactory.create() + db_request.matchdict["check_name"] = check.name + assert views.get_check(db_request) == { + "check": check, + "checks": [check], + "states": MalwareCheckState, + } + + def test_get_check_many_versions(self, db_request): + check1 = MalwareCheckFactory.create(name="MyCheck", version="1") + check2 = MalwareCheckFactory.create(name="MyCheck", version="2") + db_request.matchdict["check_name"] = check1.name + assert views.get_check(db_request) == { + "check": check2, + "checks": [check2, check1], + "states": MalwareCheckState, + } + + def test_get_check_not_found(self, db_request): + db_request.matchdict["check_name"] = "DoesNotExist" + with pytest.raises(HTTPNotFound): + views.get_check(db_request) + + +class TestChangeCheckState: + def test_change_to_enabled(self, db_request): + check = MalwareCheckFactory.create( + name="MyCheck", state=MalwareCheckState.disabled + ) + + db_request.POST = {"id": check.id, "check_state": "enabled"} + db_request.matchdict["check_name"] = check.name + + db_request.session = pretend.stub( + flash=pretend.call_recorder(lambda *a, **kw: None) + ) + db_request.route_path = pretend.call_recorder( + lambda *a, **kw: "/admin/checks/MyCheck/change_state" + ) + + views.change_check_state(db_request) + + assert db_request.session.flash.calls == [ + pretend.call("Changed 'MyCheck' check to 'enabled'!", queue="success") + ] + assert check.state == MalwareCheckState.enabled + + def test_change_to_invalid_state(self, db_request): + check = MalwareCheckFactory.create(name="MyCheck") + initial_state = check.state + invalid_check_state = "cancelled" + db_request.POST = {"id": check.id, "check_state": invalid_check_state} + db_request.matchdict["check_name"] = check.name + + db_request.session = pretend.stub( + flash=pretend.call_recorder(lambda *a, **kw: None) + ) + db_request.route_path = pretend.call_recorder( + lambda *a, **kw: "/admin/checks/MyCheck/change_state" + ) + + views.change_check_state(db_request) + + assert db_request.session.flash.calls == [ + pretend.call("Invalid check state provided.", queue="error") + ] + assert check.state == initial_state + + def test_check_not_found(self, db_request): + db_request.POST = {"id": uuid.uuid4(), "check_state": "enabled"} + db_request.matchdict["check_name"] = "DoesNotExist" + + db_request.route_path = pretend.call_recorder( + lambda *a, **kw: "/admin/checks/DoesNotExist/change_state" + ) + + with pytest.raises(HTTPNotFound): + views.change_check_state(db_request) diff --git a/warehouse/admin/routes.py b/warehouse/admin/routes.py index b0f1afde424d..4180df8188a6 100644 --- a/warehouse/admin/routes.py +++ b/warehouse/admin/routes.py @@ -128,3 +128,14 @@ def includeme(config): # Squats config.add_route("admin.squats", "/admin/squats/", domain=warehouse) config.add_route("admin.squats.review", "/admin/squats/review/", domain=warehouse) + + # Malware checks + config.add_route("admin.checks.list", "/admin/checks/", domain=warehouse) + config.add_route( + "admin.checks.detail", "/admin/checks/{check_name}", domain=warehouse + ) + config.add_route( + "admin.checks.change_state", + "/admin/checks/{check_name}/change_state", + domain=warehouse, + ) diff --git a/warehouse/admin/templates/admin/base.html b/warehouse/admin/templates/admin/base.html index 410d70fe1e60..cc74b4675b81 100644 --- a/warehouse/admin/templates/admin/base.html +++ b/warehouse/admin/templates/admin/base.html @@ -125,6 +125,11 @@ Squats +
  • + + Checks + +
  • diff --git a/warehouse/admin/templates/admin/malware/checks/detail.html b/warehouse/admin/templates/admin/malware/checks/detail.html new file mode 100644 index 000000000000..2cf80f20cc8a --- /dev/null +++ b/warehouse/admin/templates/admin/malware/checks/detail.html @@ -0,0 +1,70 @@ +{# + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. +-#} +{% extends "admin/base.html" %} + +{% block title %}{{ check.name }}{% endblock %} + +{% block breadcrumb %} +
  • Checks
  • +
  • {{ check.name }}
  • +{% endblock %} + +{% block content %} +
    +
    +

    {{ check.long_description }}

    +

    Revision History

    +
    + + + + + + + {% for c in checks %} + + + + + + {% endfor %} +
    VersionStateCreated
    {{ c.version }}{{ c.state.value }}{{ c.created }}
    +
    +
    +
    +
    +
    +

    Change State

    +
    +
    + + +
    +
    + +
    +
    + +
    +
    +
    +
    + +{% endblock %} diff --git a/warehouse/admin/templates/admin/malware/checks/index.html b/warehouse/admin/templates/admin/malware/checks/index.html new file mode 100644 index 000000000000..5717849e2579 --- /dev/null +++ b/warehouse/admin/templates/admin/malware/checks/index.html @@ -0,0 +1,57 @@ +{# + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. +-#} +{% extends "admin/base.html" %} + +{% block title %}Malware Checks{% endblock %} + +{% block breadcrumb %} +
  • Checks
  • +{% endblock %} + +{% block content %} +
    +
    + + + + + + + + + {% for check in checks %} + + + + + + + + {% else %} + + + + {% endfor %} +
    Check NameStateRevisionsLast ModifiedDescription
    + + {{ check.name }} + + {{ check.state.value }}{{ check.version }}{{ check.created }}{{ check.short_description }}
    +
    + No checks! +
    +
    +
    +
    +{% endblock content %} diff --git a/warehouse/admin/views/checks.py b/warehouse/admin/views/checks.py new file mode 100644 index 000000000000..e3d38a88a0c0 --- /dev/null +++ b/warehouse/admin/views/checks.py @@ -0,0 +1,89 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyramid.httpexceptions import HTTPNotFound, HTTPSeeOther +from pyramid.view import view_config +from sqlalchemy.orm.exc import NoResultFound + +from warehouse.malware.models import MalwareCheck, MalwareCheckState + + +@view_config( + route_name="admin.checks.list", + renderer="admin/malware/checks/index.html", + permission="moderator", + request_method="GET", + uses_session=True, +) +def get_checks(request): + all_checks = request.db.query(MalwareCheck).all() + active_checks = [] + for check in all_checks: + if not check.is_stale: + active_checks.append(check) + + return {"checks": active_checks} + + +@view_config( + route_name="admin.checks.detail", + renderer="admin/malware/checks/detail.html", + permission="moderator", + request_method="GET", + uses_session=True, +) +def get_check(request): + query = ( + request.db.query(MalwareCheck) + .filter(MalwareCheck.name == request.matchdict["check_name"]) + .order_by(MalwareCheck.version.desc()) + ) + + try: + # Throw an exception if and only if no results are returned. + newest = query.limit(1).one() + except NoResultFound: + raise HTTPNotFound + + return {"check": newest, "checks": query.all(), "states": MalwareCheckState} + + +@view_config( + route_name="admin.checks.change_state", + permission="admin", + request_method="POST", + uses_session=True, + require_methods=False, + require_csrf=True, +) +def change_check_state(request): + try: + check = ( + request.db.query(MalwareCheck) + .filter(MalwareCheck.id == request.POST["id"]) + .one() + ) + except NoResultFound: + raise HTTPNotFound + + try: + check.state = getattr(MalwareCheckState, request.POST["check_state"]) + except (AttributeError, KeyError): + request.session.flash("Invalid check state provided.", queue="error") + else: + request.session.flash( + f"Changed {check.name!r} check to {check.state.value!r}!", queue="success" + ) + finally: + return HTTPSeeOther( + request.route_path("admin.checks.detail", check_name=check.name) + ) diff --git a/warehouse/malware/models.py b/warehouse/malware/models.py index fc6576c7dfcd..0c51e3006991 100644 --- a/warehouse/malware/models.py +++ b/warehouse/malware/models.py @@ -34,23 +34,23 @@ class MalwareCheckType(enum.Enum): - EventHook = "event_hook" - Scheduled = "scheduled" + event_hook = "event_hook" + scheduled = "scheduled" class MalwareCheckState(enum.Enum): - Enabled = "enabled" - Evaluation = "evaluation" - Disabled = "disabled" - WipedOut = "wiped_out" + enabled = "enabled" + evaluation = "evaluation" + disabled = "disabled" + wiped_out = "wiped_out" class VerdictClassification(enum.Enum): - Threat = "threat" - Indeterminate = "indeterminate" - Benign = "benign" + threat = "threat" + indeterminate = "indeterminate" + benign = "benign" class VerdictConfidence(enum.Enum): @@ -67,7 +67,7 @@ class MalwareCheck(db.Model): __repr__ = make_repr("name", "version") name = Column(CIText, nullable=False) - version = Column(Integer, default=0, nullable=False) + version = Column(Integer, default=1, nullable=False) short_description = Column(String(length=128), nullable=False) long_description = Column(Text, nullable=False) check_type = Column( @@ -84,6 +84,17 @@ class MalwareCheck(db.Model): ) created = Column(DateTime, nullable=False, server_default=sql.func.now()) + @property + def is_stale(self): + session = orm.object_session(self) + newest = ( + session.query(MalwareCheck) + .filter(MalwareCheck.name == self.name) + .order_by(MalwareCheck.version.desc()) + .first() + ) + return self.version != newest.version + class MalwareVerdict(db.Model): __tablename__ = "malware_verdicts" diff --git a/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py b/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py index 899cc51e0f57..569cc0f100b1 100644 --- a/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py +++ b/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py @@ -47,7 +47,7 @@ def upgrade(): nullable=False, ), sa.Column("name", citext.CIText(), nullable=False), - sa.Column("version", sa.Integer(), nullable=False), + sa.Column("version", sa.Integer(), default=1, nullable=False), sa.Column("short_description", sa.String(length=128), nullable=False), sa.Column("long_description", sa.Text(), nullable=False), sa.Column("check_type", MalwareCheckTypes, nullable=False), From 52589e114b44f187628e9fb3fc9bbc130fffa46c Mon Sep 17 00:00:00 2001 From: Cristina Date: Mon, 6 Jan 2020 08:52:47 -0800 Subject: [PATCH 03/16] Add initial hook-based check execution mechanism (#7160) * Add initial hook-based check execution mechanism * scratch/poc * Add initial hook-based check execution mechanism * Use sqlalchemy event hooks for malware checks * Fix unit tests * Add enum for MalwareCheckObjectType * Add unit tests for init. * Add tests for tasks, services, and utils. Also, some small bugfixes in MalwareCheckFactory and the get_enabled_checks method. * Fix spurious task test. * Add missing drop enum to downgrade function. * Added TODO to dev/environment * Be more explicit in check lookup Co-authored-by: Ernest W. Durbin III --- dev/environment | 3 + tests/common/db/malware.py | 13 +- tests/conftest.py | 3 + tests/unit/malware/__init__.py | 11 ++ tests/unit/malware/test_init.py | 171 ++++++++++++++++++ tests/unit/malware/test_services.py | 61 +++++++ tests/unit/malware/test_tasks.py | 85 +++++++++ tests/unit/malware/test_utils.py | 47 +++++ tests/unit/test_config.py | 1 + warehouse/config.py | 4 + warehouse/malware/__init__.py | 48 +++++ warehouse/malware/checks/__init__.py | 13 ++ warehouse/malware/checks/base.py | 45 +++++ warehouse/malware/checks/example.py | 43 +++++ warehouse/malware/interfaces.py | 26 +++ warehouse/malware/models.py | 16 +- warehouse/malware/services.py | 45 +++++ warehouse/malware/tasks.py | 26 +++ warehouse/malware/utils.py | 42 +++++ ...1ff3d24c22_add_malware_detection_tables.py | 7 +- 20 files changed, 700 insertions(+), 10 deletions(-) create mode 100644 tests/unit/malware/__init__.py create mode 100644 tests/unit/malware/test_init.py create mode 100644 tests/unit/malware/test_services.py create mode 100644 tests/unit/malware/test_tasks.py create mode 100644 tests/unit/malware/test_utils.py create mode 100644 warehouse/malware/checks/__init__.py create mode 100644 warehouse/malware/checks/base.py create mode 100644 warehouse/malware/checks/example.py create mode 100644 warehouse/malware/interfaces.py create mode 100644 warehouse/malware/services.py create mode 100644 warehouse/malware/tasks.py create mode 100644 warehouse/malware/utils.py diff --git a/dev/environment b/dev/environment index e7ae3673787b..ec7eeae6d2f3 100644 --- a/dev/environment +++ b/dev/environment @@ -29,6 +29,9 @@ MAIL_BACKEND=warehouse.email.services.SMTPEmailSender host=smtp port=2525 ssl=fa BREACHED_PASSWORDS=warehouse.accounts.NullPasswordBreachedService +#TODO: change this to PrinterMalwareCheckService before deploy +MALWARE_CHECK_BACKEND=warehouse.malware.services.DatabaseMalwareCheckService + METRICS_BACKEND=warehouse.metrics.DataDogMetrics host=notdatadog STATUSPAGE_URL=https://2p66nmmycsj3.statuspage.io diff --git a/tests/common/db/malware.py b/tests/common/db/malware.py index 263fa82fa684..7b01dc4723d6 100644 --- a/tests/common/db/malware.py +++ b/tests/common/db/malware.py @@ -15,7 +15,12 @@ import factory import factory.fuzzy -from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareCheckType +from warehouse.malware.models import ( + MalwareCheck, + MalwareCheckObjectType, + MalwareCheckState, + MalwareCheckType, +) from .base import WarehouseFactory @@ -29,11 +34,7 @@ class Meta: short_description = factory.fuzzy.FuzzyText(length=80) long_description = factory.fuzzy.FuzzyText(length=300) check_type = factory.fuzzy.FuzzyChoice([e for e in MalwareCheckType]) - hook_name = ( - "project:release:file:upload" - if check_type == MalwareCheckType.event_hook - else None - ) + hooked_object = factory.fuzzy.FuzzyChoice([e for e in MalwareCheckObjectType]) state = factory.fuzzy.FuzzyChoice([e for e in MalwareCheckState]) created = factory.fuzzy.FuzzyNaiveDateTime( datetime.datetime.utcnow() - datetime.timedelta(days=7) diff --git a/tests/conftest.py b/tests/conftest.py index 3623ec94998c..1b3a9b8b93e5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -174,6 +174,9 @@ def app_config(database): "files.backend": "warehouse.packaging.services.LocalFileStorage", "docs.backend": "warehouse.packaging.services.LocalFileStorage", "mail.backend": "warehouse.email.services.SMTPEmailSender", + "malware_check.backend": ( + "warehouse.malware.services.PrinterMalwareCheckService" + ), "files.url": "http://localhost:7000/", "sessions.secret": "123456", "sessions.url": "redis://localhost:0/", diff --git a/tests/unit/malware/__init__.py b/tests/unit/malware/__init__.py new file mode 100644 index 000000000000..164f68b09175 --- /dev/null +++ b/tests/unit/malware/__init__.py @@ -0,0 +1,11 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/malware/test_init.py b/tests/unit/malware/test_init.py new file mode 100644 index 000000000000..5db87f2c396e --- /dev/null +++ b/tests/unit/malware/test_init.py @@ -0,0 +1,171 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict + +import pretend + +from warehouse import malware +from warehouse.malware import utils +from warehouse.malware.interfaces import IMalwareCheckService + +from ...common.db.accounts import UserFactory +from ...common.db.packaging import FileFactory, ProjectFactory, ReleaseFactory + + +def test_determine_malware_checks_no_checks(monkeypatch, db_request): + def get_enabled_checks(session): + return defaultdict(list) + + monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks) + + project = ProjectFactory.create(name="foo") + release = ReleaseFactory.create(project=project) + file0 = FileFactory.create(release=release, filename="foo.bar") + + session = pretend.stub(info={}, new={file0, release, project}, dirty={}, deleted={}) + + malware.determine_malware_checks(pretend.stub(), session, pretend.stub()) + assert session.info["warehouse.malware.checks"] == set() + + +def test_determine_malware_checks_nothing_new(monkeypatch, db_request): + def get_enabled_checks(session): + result = defaultdict(list) + result["File"] = ["Check1", "Check2"] + result["Release"] = ["Check3"] + return result + + monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks) + + project = ProjectFactory.create(name="foo") + release = ReleaseFactory.create(project=project) + file0 = FileFactory.create(release=release, filename="foo.bar") + + session = pretend.stub(info={}, new={}, dirty={file0, release}, deleted={}) + + malware.determine_malware_checks(pretend.stub(), session, pretend.stub()) + assert session.info.get("warehouse.malware.checks") is None + + +def test_determine_malware_checks_unsupported_object(monkeypatch, db_request): + def get_enabled_checks(session): + result = defaultdict(list) + result["File"] = ["Check1", "Check2"] + result["Release"] = ["Check3"] + return result + + monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks) + + user = UserFactory.create() + + session = pretend.stub(info={}, new={user}, dirty={}, deleted={}) + + malware.determine_malware_checks(pretend.stub(), session, pretend.stub()) + assert session.info.get("warehouse.malware.checks") is None + + +def test_determine_malware_checks_file_only(monkeypatch, db_request): + def get_enabled_checks(session): + result = defaultdict(list) + result["File"] = ["Check1", "Check2"] + result["Release"] = ["Check3"] + return result + + monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks) + + project = ProjectFactory.create(name="foo") + release = ReleaseFactory.create(project=project) + file0 = FileFactory.create(release=release, filename="foo.bar") + + session = pretend.stub(info={}, new={file0}, dirty={}, deleted={}) + + checks = set(["Check%d:%s" % (x, file0.id) for x in range(1, 3)]) + malware.determine_malware_checks(pretend.stub(), session, pretend.stub()) + assert session.info["warehouse.malware.checks"] == checks + + +def test_determine_malware_checks_file_and_release(monkeypatch, db_request): + def get_enabled_checks(session): + result = defaultdict(list) + result["File"] = ["Check1", "Check2"] + result["Release"] = ["Check3"] + return result + + monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks) + + project = ProjectFactory.create(name="foo") + release = ReleaseFactory.create(project=project) + file0 = FileFactory.create(release=release, filename="foo.bar") + file1 = FileFactory.create(release=release, filename="foo.baz") + + session = pretend.stub( + info={}, new={project, release, file0, file1}, dirty={}, deleted={} + ) + + checks = set(["Check%d:%s" % (x, file0.id) for x in range(1, 3)]) + checks.update(["Check%d:%s" % (x, file1.id) for x in range(1, 3)]) + checks.add("Check3:%s" % release.id) + + malware.determine_malware_checks(pretend.stub(), session, pretend.stub()) + + assert session.info["warehouse.malware.checks"] == checks + + +def test_enqueue_malware_checks(app_config): + malware_check = pretend.stub( + run_checks=pretend.call_recorder(lambda malware_checks: None) + ) + factory = pretend.call_recorder(lambda ctx, config: malware_check) + app_config.register_service_factory(factory, IMalwareCheckService) + app_config.commit() + session = pretend.stub( + info={ + "warehouse.malware.checks": {"Check1:ba70267f-fabf-496f-9ac2-d237a983b187"} + } + ) + + malware.queue_malware_checks(app_config, session) + + assert factory.calls == [pretend.call(None, app_config)] + assert malware_check.run_checks.calls == [ + pretend.call({"Check1:ba70267f-fabf-496f-9ac2-d237a983b187"}) + ] + assert "warehouse.malware.checks" not in session.info + + +def test_enqueue_malware_checks_no_checks(app_config): + session = pretend.stub(info={}) + malware.queue_malware_checks(app_config, session) + assert "warehouse.malware.checks" not in session.info + + +def test_includeme(): + malware_check_class = pretend.stub( + create_service=pretend.call_recorder(lambda *a, **kw: pretend.stub()) + ) + + config = pretend.stub( + maybe_dotted=lambda dotted: malware_check_class, + register_service_factory=pretend.call_recorder( + lambda factory, iface, name=None: None + ), + registry=pretend.stub( + settings={"malware_check.backend": "TestMalwareCheckService"} + ), + ) + + malware.includeme(config) + + assert config.register_service_factory.calls == [ + pretend.call(malware_check_class.create_service, IMalwareCheckService), + ] diff --git a/tests/unit/malware/test_services.py b/tests/unit/malware/test_services.py new file mode 100644 index 000000000000..7a9cb636f720 --- /dev/null +++ b/tests/unit/malware/test_services.py @@ -0,0 +1,61 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pretend + +from zope.interface.verify import verifyClass + +from warehouse.malware.interfaces import IMalwareCheckService +from warehouse.malware.services import ( + DatabaseMalwareCheckService, + PrinterMalwareCheckService, +) +from warehouse.malware.tasks import run_check + + +class TestPrinterMalwareCheckService: + def test_verify_service(self): + assert verifyClass(IMalwareCheckService, PrinterMalwareCheckService) + + def test_create_service(self): + request = pretend.stub() + service = PrinterMalwareCheckService.create_service(None, request) + assert service.executor == print + + def test_run_checks(self, capfd): + request = pretend.stub() + service = PrinterMalwareCheckService.create_service(None, request) + checks = ["one", "two", "three"] + service.run_checks(checks) + out, err = capfd.readouterr() + assert out == "one\ntwo\nthree\n" + + +class TestDatabaseMalwareService: + def test_verify_service(self): + assert verifyClass(IMalwareCheckService, DatabaseMalwareCheckService) + + def test_create_service(self, db_request): + _delay = pretend.call_recorder(lambda *args: None) + db_request.task = lambda x: pretend.stub(delay=_delay) + service = DatabaseMalwareCheckService.create_service(None, db_request) + assert service.executor == db_request.task(run_check).delay + + def test_run_checks(self, db_request): + _delay = pretend.call_recorder(lambda *args: None) + db_request.task = lambda x: pretend.stub(delay=_delay) + service = DatabaseMalwareCheckService.create_service(None, db_request) + checks = ["MyTestCheck:ba70267f-fabf-496f-9ac2-d237a983b187"] + service.run_checks(checks) + assert _delay.calls == [ + pretend.call("MyTestCheck", "ba70267f-fabf-496f-9ac2-d237a983b187") + ] diff --git a/tests/unit/malware/test_tasks.py b/tests/unit/malware/test_tasks.py new file mode 100644 index 000000000000..38f79b201ba2 --- /dev/null +++ b/tests/unit/malware/test_tasks.py @@ -0,0 +1,85 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import celery +import pretend +import pytest + +from sqlalchemy.orm.exc import NoResultFound + +import warehouse.malware.checks as checks + +from warehouse.malware.models import MalwareVerdict +from warehouse.malware.tasks import run_check + +from ...common.db.malware import MalwareCheckFactory +from ...common.db.packaging import FileFactory, ProjectFactory, ReleaseFactory + + +def test_run_check(monkeypatch, db_request): + project = ProjectFactory.create(name="foo") + release = ReleaseFactory.create(project=project) + file0 = FileFactory.create(release=release, filename="foo.bar") + MalwareCheckFactory.create(name="ExampleCheck", state="enabled") + + task = pretend.stub() + run_check(task, db_request, "ExampleCheck", file0.id) + assert db_request.db.query(MalwareVerdict).one() + + +def test_run_check_missing_check_id(monkeypatch, db_session): + exc = NoResultFound("No row was found for one()") + + class FakeMalwareCheck: + def __init__(self, db): + raise exc + + class Task: + @staticmethod + @pretend.call_recorder + def retry(exc): + raise celery.exceptions.Retry + + task = Task() + + checks.FakeMalwareCheck = FakeMalwareCheck + + request = pretend.stub( + db=db_session, + log=pretend.stub(error=pretend.call_recorder(lambda *args, **kwargs: None),), + ) + + with pytest.raises(celery.exceptions.Retry): + run_check( + task, request, "FakeMalwareCheck", "d03d75d1-2511-4a8b-9759-62294a6fe3a7" + ) + + assert request.log.error.calls == [ + pretend.call( + "Error executing check %s: %s", + "FakeMalwareCheck", + "No row was found for one()", + ) + ] + + assert task.retry.calls == [pretend.call(exc=exc)] + + +def test_run_check_missing_check(db_request): + task = pretend.stub() + with pytest.raises(AttributeError): + run_check( + task, + db_request, + "DoesNotExistCheck", + "d03d75d1-2511-4a8b-9759-62294a6fe3a7", + ) diff --git a/tests/unit/malware/test_utils.py b/tests/unit/malware/test_utils.py new file mode 100644 index 000000000000..b995147f6872 --- /dev/null +++ b/tests/unit/malware/test_utils.py @@ -0,0 +1,47 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict + +from warehouse.malware.models import MalwareCheckState, MalwareCheckType +from warehouse.malware.utils import get_enabled_checks + +from ...common.db.malware import MalwareCheckFactory + + +def test_get_enabled_checks(db_session): + check = MalwareCheckFactory.create( + state=MalwareCheckState.enabled, check_type=MalwareCheckType.event_hook + ) + result = defaultdict(list) + result[check.hooked_object.value].append(check.name) + checks = get_enabled_checks(db_session) + assert checks == result + + +def test_get_enabled_checks_many(db_session): + result = defaultdict(list) + for i in range(10): + check = MalwareCheckFactory.create() + if ( + check.state == MalwareCheckState.enabled + and check.check_type == MalwareCheckType.event_hook + ): + result[check.hooked_object.value].append(check.name) + + checks = get_enabled_checks(db_session) + assert checks == result + + +def test_get_enabled_checks_none(db_session): + checks = get_enabled_checks(db_session) + assert checks == defaultdict(list) diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 53fab7d8235a..d65a976bc91b 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -317,6 +317,7 @@ def __init__(self): pretend.call(".email"), pretend.call(".accounts"), pretend.call(".macaroons"), + pretend.call(".malware"), pretend.call(".manage"), pretend.call(".packaging"), pretend.call(".redirects"), diff --git a/warehouse/config.py b/warehouse/config.py index 84a49e19d814..fb7efcbe7cd9 100644 --- a/warehouse/config.py +++ b/warehouse/config.py @@ -203,6 +203,7 @@ def configure(settings=None): maybe_set_compound(settings, "mail", "backend", "MAIL_BACKEND") maybe_set_compound(settings, "metrics", "backend", "METRICS_BACKEND") maybe_set_compound(settings, "breached_passwords", "backend", "BREACHED_PASSWORDS") + maybe_set_compound(settings, "malware_check", "backend", "MALWARE_CHECK_BACKEND") # Add the settings we use when the environment is set to development. if settings["warehouse.env"] == Environment.development: @@ -389,6 +390,9 @@ def configure(settings=None): # Register support for Macaroon based authentication config.include(".macaroons") + # Register support for malware checks + config.include(".malware") + # Register logged-in views config.include(".manage") diff --git a/warehouse/malware/__init__.py b/warehouse/malware/__init__.py index 164f68b09175..ee0e36b808aa 100644 --- a/warehouse/malware/__init__.py +++ b/warehouse/malware/__init__.py @@ -9,3 +9,51 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from warehouse import db +from warehouse.malware import utils +from warehouse.malware.interfaces import IMalwareCheckService + + +@db.listens_for(db.Session, "after_flush") +def determine_malware_checks(config, session, flush_context): + if not session.new: + return + + if not any( + [ + obj.__class__.__name__ + for obj in session.new + if obj.__class__.__name__ in utils.valid_check_types() + ] + ): + return + + malware_checks = session.info.setdefault("warehouse.malware.checks", set()) + enabled_checks = utils.get_enabled_checks(session) + for obj in session.new: + for check_name in enabled_checks.get(obj.__class__.__name__, []): + malware_checks.update([f"{check_name}:{obj.id}"]) + + +@db.listens_for(db.Session, "after_commit") +def queue_malware_checks(config, session): + + malware_checks = session.info.pop("warehouse.malware.checks", set()) + if not malware_checks: + return + + malware_check_factory = config.find_service_factory(IMalwareCheckService) + + malware_check = malware_check_factory(None, config) + malware_check.run_checks(malware_checks) + + +def includeme(config): + malware_check_class = config.maybe_dotted( + config.registry.settings["malware_check.backend"] + ) + # Register the malware check service + config.register_service_factory( + malware_check_class.create_service, IMalwareCheckService + ) diff --git a/warehouse/malware/checks/__init__.py b/warehouse/malware/checks/__init__.py new file mode 100644 index 000000000000..a627b7d18159 --- /dev/null +++ b/warehouse/malware/checks/__init__.py @@ -0,0 +1,13 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .example import ExampleCheck # noqa diff --git a/warehouse/malware/checks/base.py b/warehouse/malware/checks/base.py new file mode 100644 index 000000000000..b5102bfe6f71 --- /dev/null +++ b/warehouse/malware/checks/base.py @@ -0,0 +1,45 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from warehouse.malware.models import MalwareCheck, MalwareCheckState + + +class MalwareCheckBase: + def __init__(self, db): + self.db = db + self._name = self.__class__.__name__ + self._load_check() + + def run(self, obj_id): + """ + Executes the check. + """ + + def backfill(self, sample=1): + """ + Runs the check across all historical data in PyPI. The sample value represents + the percentage of files to file the check against. By default, it will run the + backfill on the entire corpus. + """ + + def update(self): + """ + Update the check definition in the database. + """ + + def _load_check(self): + self.id = ( + self.db.query(MalwareCheck.id) + .filter(MalwareCheck.name == self._name) + .filter(MalwareCheck.state == MalwareCheckState.enabled) + .one() + ) diff --git a/warehouse/malware/checks/example.py b/warehouse/malware/checks/example.py new file mode 100644 index 000000000000..c55748cdaf44 --- /dev/null +++ b/warehouse/malware/checks/example.py @@ -0,0 +1,43 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from warehouse.malware.checks.base import MalwareCheckBase +from warehouse.malware.models import ( + MalwareVerdict, + VerdictClassification, + VerdictConfidence, +) + +VERSION = 1 +SHORT_DESCRIPTION = "An example hook-based check" +LONG_DESCRIPTION = """The purpose of this check is to demonstrate the implementation \ +of a hook-based check. This check will generate verdicts if enabled.""" + + +class ExampleCheck(MalwareCheckBase): + + version = VERSION + short_description = SHORT_DESCRIPTION + long_description = LONG_DESCRIPTION + + def __init__(self, db): + super().__init__(db) + + def run(self, file_id): + verdict = MalwareVerdict( + check_id=self.id, + file_id=file_id, + classification=VerdictClassification.benign, + confidence=VerdictConfidence.High, + message="Nothing to see here!", + ) + self.db.add(verdict) diff --git a/warehouse/malware/interfaces.py b/warehouse/malware/interfaces.py new file mode 100644 index 000000000000..f179aa374d55 --- /dev/null +++ b/warehouse/malware/interfaces.py @@ -0,0 +1,26 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from zope.interface import Interface + + +class IMalwareCheckService(Interface): + def create_service(context, request): + """ + Create the service, given the context and request for which it is being + created for. + """ + + def run_checks(checks): + """ + Run a given set of Checks + """ diff --git a/warehouse/malware/models.py b/warehouse/malware/models.py index 0c51e3006991..257e7bfa2bd5 100644 --- a/warehouse/malware/models.py +++ b/warehouse/malware/models.py @@ -46,6 +46,13 @@ class MalwareCheckState(enum.Enum): wiped_out = "wiped_out" +class MalwareCheckObjectType(enum.Enum): + + File = "File" + Release = "Release" + Project = "Project" + + class VerdictClassification(enum.Enum): threat = "threat" @@ -74,9 +81,12 @@ class MalwareCheck(db.Model): Enum(MalwareCheckType, values_callable=lambda x: [e.value for e in x]), nullable=False, ) - # This field contains the same content as the ProjectEvent and UserEvent "tag" - # fields. - hook_name = Column(String, nullable=True) + # This field contains the object name that check operates on, e.g. + # Project, File, Release + hooked_object = Column( + Enum(MalwareCheckObjectType, values_callable=lambda x: [e.value for e in x]), + nullable=True, + ) state = Column( Enum(MalwareCheckState, values_callable=lambda x: [e.value for e in x]), nullable=False, diff --git a/warehouse/malware/services.py b/warehouse/malware/services.py new file mode 100644 index 000000000000..f2f454b964e2 --- /dev/null +++ b/warehouse/malware/services.py @@ -0,0 +1,45 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from zope.interface import implementer + +from warehouse.malware.interfaces import IMalwareCheckService +from warehouse.malware.tasks import run_check + + +@implementer(IMalwareCheckService) +class PrinterMalwareCheckService: + def __init__(self, executor): + self.executor = executor + + @classmethod + def create_service(cls, context, request): + return cls(print) + + def run_checks(self, checks): + for check in checks: + self.executor(check) + + +@implementer(IMalwareCheckService) +class DatabaseMalwareCheckService: + def __init__(self, executor): + self.executor = executor + + @classmethod + def create_service(cls, context, request): + return cls(request.task(run_check).delay) + + def run_checks(self, checks): + for check_info in checks: + check_name, obj_id = check_info.split(":") + self.executor(check_name, obj_id) diff --git a/warehouse/malware/tasks.py b/warehouse/malware/tasks.py new file mode 100644 index 000000000000..ade9f7a9ae73 --- /dev/null +++ b/warehouse/malware/tasks.py @@ -0,0 +1,26 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import warehouse.malware.checks as checks + +from warehouse.tasks import task + + +@task(bind=True, ignore_result=True, acks_late=True) +def run_check(task, request, check_name, obj_id): + try: + check = getattr(checks, check_name)(request.db) + check.run(obj_id) + except Exception as exc: + request.log.error("Error executing check %s: %s", check_name, str(exc)) + raise task.retry(exc=exc) diff --git a/warehouse/malware/utils.py b/warehouse/malware/utils.py new file mode 100644 index 000000000000..b0a5bf5b49e7 --- /dev/null +++ b/warehouse/malware/utils.py @@ -0,0 +1,42 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools + +from collections import defaultdict + +from warehouse.malware.models import ( + MalwareCheck, + MalwareCheckObjectType, + MalwareCheckState, + MalwareCheckType, +) + + +@functools.lru_cache() +def valid_check_types(): + return set([t.value for t in MalwareCheckObjectType]) + + +def get_enabled_checks(session): + checks = ( + session.query(MalwareCheck.name, MalwareCheck.hooked_object) + .filter(MalwareCheck.check_type == MalwareCheckType.event_hook) + .filter(MalwareCheck.state == MalwareCheckState.enabled) + .all() + ) + results = defaultdict(list) + + for check_name, object_type in checks: + results[object_type.value].append(check_name) + + return results diff --git a/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py b/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py index 569cc0f100b1..e74a9ddabe94 100644 --- a/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py +++ b/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py @@ -31,6 +31,10 @@ "enabled", "evaluation", "disabled", "wiped_out", name="malwarecheckstate" ) +MalwareCheckObjectTypes = sa.Enum( + "File", "Release", "Project", name="malwarecheckobjecttype" +) + VerdictClassifications = sa.Enum( "threat", "indeterminate", "benign", name="verdictclassification" ) @@ -51,7 +55,7 @@ def upgrade(): sa.Column("short_description", sa.String(length=128), nullable=False), sa.Column("long_description", sa.Text(), nullable=False), sa.Column("check_type", MalwareCheckTypes, nullable=False), - sa.Column("hook_name", sa.String(), nullable=True), + sa.Column("hooked_object", MalwareCheckObjectTypes, nullable=True), sa.Column( "state", MalwareCheckStates, server_default="disabled", nullable=False, ), @@ -106,5 +110,6 @@ def downgrade(): op.drop_table("malware_checks") MalwareCheckTypes.drop(op.get_bind()) MalwareCheckStates.drop(op.get_bind()) + MalwareCheckObjectTypes.drop(op.get_bind()) VerdictClassifications.drop(op.get_bind()) VerdictConfidences.drop(op.get_bind()) From 7fd996404d4d4a25c3653cb471d136c8d5139a14 Mon Sep 17 00:00:00 2001 From: Cristina Date: Tue, 7 Jan 2020 07:20:31 -0800 Subject: [PATCH 04/16] Add malware check syncing mechanism (#7190) * Add malware check syncing mechanism * Code review changes. --- bin/release | 3 + tests/unit/cli/test_malware.py | 36 ++++ tests/unit/malware/test_checks.py | 45 +++++ tests/unit/malware/test_tasks.py | 258 +++++++++++++++++++++++----- tests/unit/malware/test_utils.py | 74 +++++--- warehouse/cli/malware.py | 34 ++++ warehouse/malware/checks/example.py | 14 +- warehouse/malware/tasks.py | 62 +++++++ warehouse/malware/utils.py | 12 ++ 9 files changed, 464 insertions(+), 74 deletions(-) create mode 100644 tests/unit/cli/test_malware.py create mode 100644 tests/unit/malware/test_checks.py create mode 100644 warehouse/cli/malware.py diff --git a/bin/release b/bin/release index 0edce3b489f7..1759c65fe3d4 100755 --- a/bin/release +++ b/bin/release @@ -5,3 +5,6 @@ set -eo pipefail # Migrate our database to the latest revision. python -m warehouse db upgrade head + +# Insert/upgrade malware checks. +python -m warehouse malware sync-checks diff --git a/tests/unit/cli/test_malware.py b/tests/unit/cli/test_malware.py new file mode 100644 index 000000000000..69613bf4ace1 --- /dev/null +++ b/tests/unit/cli/test_malware.py @@ -0,0 +1,36 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pretend + +from warehouse.cli.malware import sync_checks +from warehouse.malware.tasks import sync_checks as _sync_checks + + +class TestCLIMalware: + def test_sync_checks(self, cli): + request = pretend.stub() + task = pretend.stub( + get_request=pretend.call_recorder(lambda *a, **kw: request), + run=pretend.call_recorder(lambda *a, **kw: None), + ) + config = pretend.stub(task=pretend.call_recorder(lambda *a, **kw: task)) + + result = cli.invoke(sync_checks, obj=config) + + assert result.exit_code == 0 + assert config.task.calls == [ + pretend.call(_sync_checks), + pretend.call(_sync_checks), + ] + assert task.get_request.calls == [pretend.call()] + assert task.run.calls == [pretend.call(request)] diff --git a/tests/unit/malware/test_checks.py b/tests/unit/malware/test_checks.py new file mode 100644 index 000000000000..2ce63c624965 --- /dev/null +++ b/tests/unit/malware/test_checks.py @@ -0,0 +1,45 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +import warehouse.malware.checks as checks + +from warehouse.malware.checks.base import MalwareCheckBase +from warehouse.malware.utils import get_check_fields + + +def test_checks_subclass_base(): + checks_from_module = inspect.getmembers(checks, inspect.isclass) + + subclasses_of_malware_base = { + cls.__name__: cls for cls in MalwareCheckBase.__subclasses__() + } + + assert len(checks_from_module) == len(subclasses_of_malware_base) + + for check_name, check in checks_from_module: + assert subclasses_of_malware_base[check_name] == check + + +def test_checks_fields(): + checks_from_module = inspect.getmembers(checks, inspect.isclass) + + for check_name, check in checks_from_module: + elems = inspect.getmembers(check, lambda a: not (inspect.isroutine(a))) + inspection_fields = {"name": check_name} + for elem_name, value in elems: + if not elem_name.startswith("__"): + inspection_fields[elem_name] = value + fields = get_check_fields(check) + + assert inspection_fields == fields diff --git a/tests/unit/malware/test_tasks.py b/tests/unit/malware/test_tasks.py index 38f79b201ba2..1057af6855a5 100644 --- a/tests/unit/malware/test_tasks.py +++ b/tests/unit/malware/test_tasks.py @@ -18,68 +18,240 @@ import warehouse.malware.checks as checks -from warehouse.malware.models import MalwareVerdict -from warehouse.malware.tasks import run_check +from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict +from warehouse.malware.tasks import run_check, sync_checks from ...common.db.malware import MalwareCheckFactory from ...common.db.packaging import FileFactory, ProjectFactory, ReleaseFactory -def test_run_check(monkeypatch, db_request): - project = ProjectFactory.create(name="foo") - release = ReleaseFactory.create(project=project) - file0 = FileFactory.create(release=release, filename="foo.bar") - MalwareCheckFactory.create(name="ExampleCheck", state="enabled") +class TestRunCheck: + def test_success(self, monkeypatch, db_request): + project = ProjectFactory.create(name="foo") + release = ReleaseFactory.create(project=project) + file0 = FileFactory.create(release=release, filename="foo.bar") + MalwareCheckFactory.create(name="ExampleCheck", state=MalwareCheckState.enabled) - task = pretend.stub() - run_check(task, db_request, "ExampleCheck", file0.id) - assert db_request.db.query(MalwareVerdict).one() + task = pretend.stub() + run_check(task, db_request, "ExampleCheck", file0.id) + assert db_request.db.query(MalwareVerdict).one() + def test_missing_check_id(self, monkeypatch, db_session): + exc = NoResultFound("No row was found for one()") -def test_run_check_missing_check_id(monkeypatch, db_session): - exc = NoResultFound("No row was found for one()") + class FakeMalwareCheck: + def __init__(self, db): + raise exc - class FakeMalwareCheck: - def __init__(self, db): - raise exc + checks.FakeMalwareCheck = FakeMalwareCheck - class Task: - @staticmethod - @pretend.call_recorder - def retry(exc): - raise celery.exceptions.Retry + class Task: + @staticmethod + @pretend.call_recorder + def retry(exc): + raise celery.exceptions.Retry - task = Task() + task = Task() - checks.FakeMalwareCheck = FakeMalwareCheck + request = pretend.stub( + db=db_session, + log=pretend.stub( + error=pretend.call_recorder(lambda *args, **kwargs: None), + ), + ) + + with pytest.raises(celery.exceptions.Retry): + run_check( + task, + request, + "FakeMalwareCheck", + "d03d75d1-2511-4a8b-9759-62294a6fe3a7", + ) + + assert request.log.error.calls == [ + pretend.call( + "Error executing check %s: %s", + "FakeMalwareCheck", + "No row was found for one()", + ) + ] + + assert task.retry.calls == [pretend.call(exc=exc)] + + del checks.FakeMalwareCheck + + def test_missing_check(self, db_request): + task = pretend.stub() + with pytest.raises(AttributeError): + run_check( + task, + db_request, + "DoesNotExistCheck", + "d03d75d1-2511-4a8b-9759-62294a6fe3a7", + ) + + +class TestSyncChecks: + def test_no_updates(self, db_session): + MalwareCheckFactory.create( + name="ExampleCheck", state=MalwareCheckState.disabled + ) + + task = pretend.stub() + + request = pretend.stub( + db=db_session, + log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),), + ) + + sync_checks(task, request) + + assert request.log.info.calls == [ + pretend.call("1 malware checks found in codebase."), + pretend.call("ExampleCheck is unmodified."), + ] - request = pretend.stub( - db=db_session, - log=pretend.stub(error=pretend.call_recorder(lambda *args, **kwargs: None),), + @pytest.mark.parametrize( + ("final_state"), [MalwareCheckState.enabled, MalwareCheckState.disabled] ) + def test_upgrade_check(self, monkeypatch, db_session, final_state): + MalwareCheckFactory.create(name="ExampleCheck", state=final_state) + + class ExampleCheck: + version = 2 + short_description = "This is a short description." + long_description = "This is a longer description." + check_type = "scheduled" + + monkeypatch.setattr(checks, "ExampleCheck", ExampleCheck) + + task = pretend.stub() + request = pretend.stub( + db=db_session, + log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),), + ) + + sync_checks(task, request) + + assert request.log.info.calls == [ + pretend.call("1 malware checks found in codebase."), + pretend.call("Updating existing ExampleCheck."), + ] + db_checks = ( + db_session.query(MalwareCheck) + .filter(MalwareCheck.name == "ExampleCheck") + .all() + ) + + assert len(db_checks) == 2 - with pytest.raises(celery.exceptions.Retry): - run_check( - task, request, "FakeMalwareCheck", "d03d75d1-2511-4a8b-9759-62294a6fe3a7" + if final_state == MalwareCheckState.disabled: + assert ( + db_checks[0].state == db_checks[1].state == MalwareCheckState.disabled + ) + + else: + for c in db_checks: + if c.state == final_state: + assert c.version == 2 + else: + assert c.version == 1 + + def test_one_new_check(self, db_session): + task = pretend.stub() + + class FakeMalwareCheck: + version = 1 + short_description = "This is a short description." + long_description = "This is a longer description." + check_type = "scheduled" + + checks.FakeMalwareCheck = FakeMalwareCheck + + request = pretend.stub( + db=db_session, + log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),), + ) + + MalwareCheckFactory.create( + name="ExampleCheck", state=MalwareCheckState.evaluation + ) + + sync_checks(task, request) + + assert request.log.info.calls == [ + pretend.call("2 malware checks found in codebase."), + pretend.call("ExampleCheck is unmodified."), + pretend.call("Adding new FakeMalwareCheck to the database."), + ] + assert db_session.query(MalwareCheck).count() == 2 + + new_check = ( + db_session.query(MalwareCheck) + .filter(MalwareCheck.name == "FakeMalwareCheck") + .one() ) - assert request.log.error.calls == [ - pretend.call( - "Error executing check %s: %s", - "FakeMalwareCheck", - "No row was found for one()", + assert new_check.state == MalwareCheckState.disabled + + del checks.FakeMalwareCheck + + def test_too_many_db_checks(self, db_session): + task = pretend.stub() + + MalwareCheckFactory.create(name="ExampleCheck", state=MalwareCheckState.enabled) + MalwareCheckFactory.create( + name="AnotherCheck", state=MalwareCheckState.disabled + ) + MalwareCheckFactory.create( + name="AnotherCheck", state=MalwareCheckState.evaluation, version=2 + ) + + request = pretend.stub( + db=db_session, + log=pretend.stub( + info=pretend.call_recorder(lambda *args, **kwargs: None), + error=pretend.call_recorder(lambda *args, **kwargs: None), + ), ) - ] - assert task.retry.calls == [pretend.call(exc=exc)] + with pytest.raises(Exception): + sync_checks(task, request) + assert request.log.info.calls == [ + pretend.call("1 malware checks found in codebase."), + ] -def test_run_check_missing_check(db_request): - task = pretend.stub() - with pytest.raises(AttributeError): - run_check( - task, - db_request, - "DoesNotExistCheck", - "d03d75d1-2511-4a8b-9759-62294a6fe3a7", + assert request.log.error.calls == [ + pretend.call( + "Found 2 active checks in the db, but only 1 checks in code. Please \ +manually move superfluous checks to the wiped_out state in the check admin: \ +AnotherCheck" + ), + ] + + def test_only_wiped_out(self, db_session): + task = pretend.stub() + MalwareCheckFactory.create( + name="ExampleCheck", state=MalwareCheckState.wiped_out + ) + request = pretend.stub( + db=db_session, + log=pretend.stub( + info=pretend.call_recorder(lambda *args, **kwargs: None), + error=pretend.call_recorder(lambda *args, **kwargs: None), + ), ) + + sync_checks(task, request) + + assert request.log.info.calls == [ + pretend.call("1 malware checks found in codebase."), + ] + + assert request.log.error.calls == [ + pretend.call( + "ExampleCheck is wiped_out and cannot be synced. Please remove check \ +from codebase." + ), + ] diff --git a/tests/unit/malware/test_utils.py b/tests/unit/malware/test_utils.py index b995147f6872..0f6ea532debf 100644 --- a/tests/unit/malware/test_utils.py +++ b/tests/unit/malware/test_utils.py @@ -12,36 +12,64 @@ from collections import defaultdict +import pytest + from warehouse.malware.models import MalwareCheckState, MalwareCheckType -from warehouse.malware.utils import get_enabled_checks +from warehouse.malware.utils import get_check_fields, get_enabled_checks from ...common.db.malware import MalwareCheckFactory -def test_get_enabled_checks(db_session): - check = MalwareCheckFactory.create( - state=MalwareCheckState.enabled, check_type=MalwareCheckType.event_hook - ) - result = defaultdict(list) - result[check.hooked_object.value].append(check.name) - checks = get_enabled_checks(db_session) - assert checks == result +class TestGetEnabledChecks: + def test_one(self, db_session): + check = MalwareCheckFactory.create( + state=MalwareCheckState.enabled, check_type=MalwareCheckType.event_hook + ) + result = defaultdict(list) + result[check.hooked_object.value].append(check.name) + checks = get_enabled_checks(db_session) + assert checks == result + + def test_many(self, db_session): + result = defaultdict(list) + for i in range(10): + check = MalwareCheckFactory.create() + if ( + check.state == MalwareCheckState.enabled + and check.check_type == MalwareCheckType.event_hook + ): + result[check.hooked_object.value].append(check.name) + + checks = get_enabled_checks(db_session) + assert checks == result + + def test_none(self, db_session): + checks = get_enabled_checks(db_session) + assert checks == defaultdict(list) -def test_get_enabled_checks_many(db_session): - result = defaultdict(list) - for i in range(10): - check = MalwareCheckFactory.create() - if ( - check.state == MalwareCheckState.enabled - and check.check_type == MalwareCheckType.event_hook - ): - result[check.hooked_object.value].append(check.name) +class TestGetCheckFields: + def test_success(self): + class MySampleCheck: + version = 6 + foo = "bar" + short_description = "This is the description" + long_description = "This is the description" + check_type = "scheduled" - checks = get_enabled_checks(db_session) - assert checks == result + result = get_check_fields(MySampleCheck) + assert result == { + "name": "MySampleCheck", + "version": 6, + "short_description": "This is the description", + "long_description": "This is the description", + "check_type": "scheduled", + } + def test_failure(self): + class MySampleCheck: + version = 1 + status = True -def test_get_enabled_checks_none(db_session): - checks = get_enabled_checks(db_session) - assert checks == defaultdict(list) + with pytest.raises(AttributeError): + get_check_fields(MySampleCheck) diff --git a/warehouse/cli/malware.py b/warehouse/cli/malware.py new file mode 100644 index 000000000000..ad08f557ebaf --- /dev/null +++ b/warehouse/cli/malware.py @@ -0,0 +1,34 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import click + +from warehouse.cli import warehouse +from warehouse.malware.tasks import sync_checks as _sync_checks + + +@warehouse.group() # pragma: no branch +def malware(): + """ + Manage the Warehouse Malware Checks. + """ + + +@malware.command() +@click.pass_obj +def sync_checks(config): + """ + Sync the Warehouse database with the malware checks in malware/checks. + """ + + request = config.task(_sync_checks).get_request() + config.task(_sync_checks).run(request) diff --git a/warehouse/malware/checks/example.py b/warehouse/malware/checks/example.py index c55748cdaf44..519edecfd4df 100644 --- a/warehouse/malware/checks/example.py +++ b/warehouse/malware/checks/example.py @@ -17,17 +17,15 @@ VerdictConfidence, ) -VERSION = 1 -SHORT_DESCRIPTION = "An example hook-based check" -LONG_DESCRIPTION = """The purpose of this check is to demonstrate the implementation \ -of a hook-based check. This check will generate verdicts if enabled.""" - class ExampleCheck(MalwareCheckBase): - version = VERSION - short_description = SHORT_DESCRIPTION - long_description = LONG_DESCRIPTION + version = 1 + short_description = "An example hook-based check" + long_description = """The purpose of this check is to demonstrate the \ +implementation of a hook-based check. This check will generate verdicts if enabled.""" + check_type = "event_hook" + hooked_object = "File" def __init__(self, db): super().__init__(db) diff --git a/warehouse/malware/tasks.py b/warehouse/malware/tasks.py index ade9f7a9ae73..1548d28e66d7 100644 --- a/warehouse/malware/tasks.py +++ b/warehouse/malware/tasks.py @@ -10,9 +10,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import warehouse.malware.checks as checks +from warehouse.malware.models import MalwareCheck, MalwareCheckState +from warehouse.malware.utils import get_check_fields from warehouse.tasks import task @@ -24,3 +27,62 @@ def run_check(task, request, check_name, obj_id): except Exception as exc: request.log.error("Error executing check %s: %s", check_name, str(exc)) raise task.retry(exc=exc) + + +@task(bind=True, ignore_result=True, acks_late=True) +def sync_checks(task, request): + code_checks = inspect.getmembers(checks, inspect.isclass) + request.log.info("%d malware checks found in codebase." % len(code_checks)) + + all_checks = request.db.query(MalwareCheck).all() + active_checks = {} + wiped_out_checks = {} + for check in all_checks: + if not check.is_stale: + if check.state == MalwareCheckState.wiped_out: + wiped_out_checks[check.name] = check + else: + active_checks[check.name] = check + + if len(active_checks) > len(code_checks): + code_check_names = set([name for name, cls in code_checks]) + missing = ", ".join(set(active_checks.keys()) - code_check_names) + request.log.error( + "Found %d active checks in the db, but only %d checks in \ +code. Please manually move superfluous checks to the wiped_out state \ +in the check admin: %s" + % (len(active_checks), len(code_checks), missing) + ) + raise Exception("Mismatch between number of db checks and code checks.") + + for check_name, check_class in code_checks: + check = getattr(checks, check_name) + + if wiped_out_checks.get(check_name): + request.log.error( + "%s is wiped_out and cannot be synced. Please remove check from \ +codebase." + % check_name + ) + continue + + db_check = active_checks.get(check_name) + if db_check: + if check.version == db_check.version: + request.log.info("%s is unmodified." % check_name) + continue + + request.log.info("Updating existing %s." % check_name) + fields = get_check_fields(check) + + # Migrate the check state to the newest check. + # Then mark the old check state as disabled. + if db_check.state != MalwareCheckState.disabled: + fields["state"] = db_check.state.value + db_check.state = MalwareCheckState.disabled + + request.db.add(MalwareCheck(**fields)) + else: + request.log.info("Adding new %s to the database." % check_name) + fields = get_check_fields(check) + request.db.add(MalwareCheck(**fields)) diff --git a/warehouse/malware/utils.py b/warehouse/malware/utils.py index b0a5bf5b49e7..23af5b0a578e 100644 --- a/warehouse/malware/utils.py +++ b/warehouse/malware/utils.py @@ -27,6 +27,18 @@ def valid_check_types(): return set([t.value for t in MalwareCheckObjectType]) +def get_check_fields(check): + result = {"name": check.__name__} + required_fields = ["short_description", "long_description", "version", "check_type"] + for field in required_fields: + result[field] = getattr(check, field) + + if result["check_type"] == "event_hook": + result["hooked_object"] = check.hooked_object + + return result + + def get_enabled_checks(session): checks = ( session.query(MalwareCheck.name, MalwareCheck.hooked_object) From d4dbeed7c9c4f8dd51775fd52df0f62afc713827 Mon Sep 17 00:00:00 2001 From: Cristina Date: Tue, 7 Jan 2020 12:07:32 -0800 Subject: [PATCH 05/16] Refactor MalwareCheckBase. Fixes #7091. (#7196) * Refactor MalwareCheckBase. Fixes #7091. Add Foreign Keys in MalwareVerdicts for other types of objects (Releases, Projects). * Change verdict dict to kwargs. --- warehouse/malware/checks/base.py | 24 ++++++++++++------- warehouse/malware/checks/example.py | 12 +++------- warehouse/malware/models.py | 6 ++++- ...1ff3d24c22_add_malware_detection_tables.py | 6 ++++- 4 files changed, 28 insertions(+), 20 deletions(-) diff --git a/warehouse/malware/checks/base.py b/warehouse/malware/checks/base.py index b5102bfe6f71..6c94937abbde 100644 --- a/warehouse/malware/checks/base.py +++ b/warehouse/malware/checks/base.py @@ -10,18 +10,29 @@ # See the License for the specific language governing permissions and # limitations under the License. -from warehouse.malware.models import MalwareCheck, MalwareCheckState +from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict class MalwareCheckBase: def __init__(self, db): self.db = db self._name = self.__class__.__name__ - self._load_check() + self._load_check_id() + self._verdicts = [] + + def add_verdict(self, **kwargs): + self._verdicts.append(MalwareVerdict(check_id=self.id, **kwargs)) def run(self, obj_id): """ - Executes the check. + Runs the check and inserts returned verdicts. + """ + self.scan(obj_id) + self.db.add_all(self._verdicts) + + def scan(self, obj_id): + """ + Scans the object and returns a verdict. """ def backfill(self, sample=1): @@ -31,12 +42,7 @@ def backfill(self, sample=1): backfill on the entire corpus. """ - def update(self): - """ - Update the check definition in the database. - """ - - def _load_check(self): + def _load_check_id(self): self.id = ( self.db.query(MalwareCheck.id) .filter(MalwareCheck.name == self._name) diff --git a/warehouse/malware/checks/example.py b/warehouse/malware/checks/example.py index 519edecfd4df..22b91906ffa3 100644 --- a/warehouse/malware/checks/example.py +++ b/warehouse/malware/checks/example.py @@ -11,11 +11,7 @@ # limitations under the License. from warehouse.malware.checks.base import MalwareCheckBase -from warehouse.malware.models import ( - MalwareVerdict, - VerdictClassification, - VerdictConfidence, -) +from warehouse.malware.models import VerdictClassification, VerdictConfidence class ExampleCheck(MalwareCheckBase): @@ -30,12 +26,10 @@ class ExampleCheck(MalwareCheckBase): def __init__(self, db): super().__init__(db) - def run(self, file_id): - verdict = MalwareVerdict( - check_id=self.id, + def scan(self, file_id): + self.add_verdict( file_id=file_id, classification=VerdictClassification.benign, confidence=VerdictConfidence.High, message="Nothing to see here!", ) - self.db.add(verdict) diff --git a/warehouse/malware/models.py b/warehouse/malware/models.py index 257e7bfa2bd5..3e9aa388a701 100644 --- a/warehouse/malware/models.py +++ b/warehouse/malware/models.py @@ -115,7 +115,9 @@ class MalwareVerdict(db.Model): nullable=False, index=True, ) - file_id = Column(ForeignKey("release_files.id"), nullable=False) + file_id = Column(ForeignKey("release_files.id"), nullable=True) + release_id = Column(ForeignKey("releases.id"), nullable=True) + project_id = Column(ForeignKey("projects.id"), nullable=True) classification = Column( Enum(VerdictClassification, values_callable=lambda x: [e.value for e in x]), nullable=False, @@ -135,3 +137,5 @@ class MalwareVerdict(db.Model): check = orm.relationship("MalwareCheck", foreign_keys=[check_id], lazy=True) release_file = orm.relationship("File", foreign_keys=[file_id], lazy=True) + release = orm.relationship("Release", foreign_keys=[release_id], lazy=True) + project = orm.relationship("Project", foreign_keys=[project_id], lazy=True) diff --git a/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py b/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py index e74a9ddabe94..6e23aeb243f8 100644 --- a/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py +++ b/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py @@ -77,7 +77,9 @@ def upgrade(): "run_date", sa.DateTime(), server_default=sa.text("now()"), nullable=False ), sa.Column("check_id", postgresql.UUID(as_uuid=True), nullable=False), - sa.Column("file_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("file_id", postgresql.UUID(as_uuid=True), nullable=True), + sa.Column("project_id", postgresql.UUID(as_uuid=True), nullable=True), + sa.Column("release_id", postgresql.UUID(as_uuid=True), nullable=True), sa.Column("classification", VerdictClassifications, nullable=False,), sa.Column("confidence", VerdictConfidences, nullable=False,), sa.Column("message", sa.Text(), nullable=True), @@ -94,6 +96,8 @@ def upgrade(): ["check_id"], ["malware_checks.id"], onupdate="CASCADE", ondelete="CASCADE" ), sa.ForeignKeyConstraint(["file_id"], ["release_files.id"]), + sa.ForeignKeyConstraint(["release_id"], ["releases.id"]), + sa.ForeignKeyConstraint(["project_id"], ["projects.id"]), sa.PrimaryKeyConstraint("id"), ) op.create_index( From 046dbc13e0212dab7309dca5bcbad6c0768f638f Mon Sep 17 00:00:00 2001 From: Cristina Date: Wed, 8 Jan 2020 07:37:48 -0800 Subject: [PATCH 06/16] Add wipe-out functionality (#7202) * Add wipe-out functionality Related: #7133 * Call list explicitly --- tests/common/db/malware.py | 21 ++++++++-- tests/unit/admin/views/test_checks.py | 22 +++++++++-- tests/unit/malware/test_tasks.py | 56 ++++++++++++++++++++++++++- warehouse/admin/views/checks.py | 3 ++ warehouse/malware/tasks.py | 22 ++++++++++- 5 files changed, 114 insertions(+), 10 deletions(-) diff --git a/tests/common/db/malware.py b/tests/common/db/malware.py index 7b01dc4723d6..8c365f4abb66 100644 --- a/tests/common/db/malware.py +++ b/tests/common/db/malware.py @@ -20,9 +20,13 @@ MalwareCheckObjectType, MalwareCheckState, MalwareCheckType, + MalwareVerdict, + VerdictClassification, + VerdictConfidence, ) from .base import WarehouseFactory +from .packaging import FileFactory class MalwareCheckFactory(WarehouseFactory): @@ -33,9 +37,20 @@ class Meta: version = 1 short_description = factory.fuzzy.FuzzyText(length=80) long_description = factory.fuzzy.FuzzyText(length=300) - check_type = factory.fuzzy.FuzzyChoice([e for e in MalwareCheckType]) - hooked_object = factory.fuzzy.FuzzyChoice([e for e in MalwareCheckObjectType]) - state = factory.fuzzy.FuzzyChoice([e for e in MalwareCheckState]) + check_type = factory.fuzzy.FuzzyChoice(list(MalwareCheckType)) + hooked_object = factory.fuzzy.FuzzyChoice(list(MalwareCheckObjectType)) + state = factory.fuzzy.FuzzyChoice(list(MalwareCheckState)) created = factory.fuzzy.FuzzyNaiveDateTime( datetime.datetime.utcnow() - datetime.timedelta(days=7) ) + + +class MalwareVerdictFactory(WarehouseFactory): + class Meta: + model = MalwareVerdict + + check = factory.SubFactory(MalwareCheckFactory) + release_file = factory.SubFactory(FileFactory) + classification = factory.fuzzy.FuzzyChoice(list(VerdictClassification)) + confidence = factory.fuzzy.FuzzyChoice(list(VerdictConfidence)) + message = factory.fuzzy.FuzzyText(length=80) diff --git a/tests/unit/admin/views/test_checks.py b/tests/unit/admin/views/test_checks.py index 55a45b2d5ac0..601e26e79858 100644 --- a/tests/unit/admin/views/test_checks.py +++ b/tests/unit/admin/views/test_checks.py @@ -67,17 +67,25 @@ def test_get_check_not_found(self, db_request): class TestChangeCheckState: - def test_change_to_enabled(self, db_request): + @pytest.mark.parametrize( + ("final_state"), [MalwareCheckState.disabled, MalwareCheckState.wiped_out] + ) + def test_change_to_valid_state(self, db_request, final_state): check = MalwareCheckFactory.create( name="MyCheck", state=MalwareCheckState.disabled ) - db_request.POST = {"id": check.id, "check_state": "enabled"} + db_request.POST = {"id": check.id, "check_state": final_state.value} db_request.matchdict["check_name"] = check.name db_request.session = pretend.stub( flash=pretend.call_recorder(lambda *a, **kw: None) ) + wipe_out_recorder = pretend.stub( + delay=pretend.call_recorder(lambda *a, **kw: None) + ) + db_request.task = pretend.call_recorder(lambda *a, **kw: wipe_out_recorder) + db_request.route_path = pretend.call_recorder( lambda *a, **kw: "/admin/checks/MyCheck/change_state" ) @@ -85,9 +93,15 @@ def test_change_to_enabled(self, db_request): views.change_check_state(db_request) assert db_request.session.flash.calls == [ - pretend.call("Changed 'MyCheck' check to 'enabled'!", queue="success") + pretend.call( + "Changed 'MyCheck' check to '%s'!" % final_state.value, queue="success" + ) ] - assert check.state == MalwareCheckState.enabled + + assert check.state == final_state + + if final_state == MalwareCheckState.wiped_out: + assert wipe_out_recorder.delay.calls == [pretend.call("MyCheck")] def test_change_to_invalid_state(self, db_request): check = MalwareCheckFactory.create(name="MyCheck") diff --git a/tests/unit/malware/test_tasks.py b/tests/unit/malware/test_tasks.py index 1057af6855a5..5c89cc5fd562 100644 --- a/tests/unit/malware/test_tasks.py +++ b/tests/unit/malware/test_tasks.py @@ -19,9 +19,9 @@ import warehouse.malware.checks as checks from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict -from warehouse.malware.tasks import run_check, sync_checks +from warehouse.malware.tasks import remove_verdicts, run_check, sync_checks -from ...common.db.malware import MalwareCheckFactory +from ...common.db.malware import MalwareCheckFactory, MalwareVerdictFactory from ...common.db.packaging import FileFactory, ProjectFactory, ReleaseFactory @@ -255,3 +255,55 @@ def test_only_wiped_out(self, db_session): from codebase." ), ] + + +class TestRemoveVerdicts: + def test_no_verdicts(self, db_session): + check = MalwareCheckFactory.create() + + request = pretend.stub( + db=db_session, + log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),), + ) + task = pretend.stub() + remove_verdicts(task, request, check.name) + + assert request.log.info.calls == [ + pretend.call( + "Removing 0 malware verdicts associated with %s version 1." % check.name + ), + ] + + @pytest.mark.parametrize(("check_with_verdicts"), [True, False]) + def test_many_verdicts(self, db_session, check_with_verdicts): + check0 = MalwareCheckFactory.create() + check1 = MalwareCheckFactory.create() + project = ProjectFactory.create(name="foo") + release = ReleaseFactory.create(project=project) + file0 = FileFactory.create(release=release, filename="foo.bar") + num_verdicts = 10 + + for i in range(num_verdicts): + MalwareVerdictFactory.create(check=check1, release_file=file0) + + request = pretend.stub( + db=db_session, + log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),), + ) + + task = pretend.stub() + + if check_with_verdicts: + wiped_out_check = check1 + else: + wiped_out_check = check0 + num_verdicts = 0 + + remove_verdicts(task, request, wiped_out_check.name) + + assert request.log.info.calls == [ + pretend.call( + "Removing %d malware verdicts associated with %s version 1." + % (num_verdicts, wiped_out_check.name) + ), + ] diff --git a/warehouse/admin/views/checks.py b/warehouse/admin/views/checks.py index e3d38a88a0c0..f1ad86685d94 100644 --- a/warehouse/admin/views/checks.py +++ b/warehouse/admin/views/checks.py @@ -15,6 +15,7 @@ from sqlalchemy.orm.exc import NoResultFound from warehouse.malware.models import MalwareCheck, MalwareCheckState +from warehouse.malware.tasks import remove_verdicts @view_config( @@ -80,6 +81,8 @@ def change_check_state(request): except (AttributeError, KeyError): request.session.flash("Invalid check state provided.", queue="error") else: + if check.state == MalwareCheckState.wiped_out: + request.task(remove_verdicts).delay(check.name) request.session.flash( f"Changed {check.name!r} check to {check.state.value!r}!", queue="success" ) diff --git a/warehouse/malware/tasks.py b/warehouse/malware/tasks.py index 1548d28e66d7..0d2f570c436f 100644 --- a/warehouse/malware/tasks.py +++ b/warehouse/malware/tasks.py @@ -14,7 +14,7 @@ import warehouse.malware.checks as checks -from warehouse.malware.models import MalwareCheck, MalwareCheckState +from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict from warehouse.malware.utils import get_check_fields from warehouse.tasks import task @@ -86,3 +86,23 @@ def sync_checks(task, request): request.log.info("Adding new %s to the database." % check_name) fields = get_check_fields(check) request.db.add(MalwareCheck(**fields)) + + +@task(bind=True, ignore_result=True, acks_late=True) +def remove_verdicts(task, request, check_name): + check_ids = ( + request.db.query(MalwareCheck.id, MalwareCheck.version) + .filter(MalwareCheck.name == check_name) + .all() + ) + + for check_id, check_version in check_ids: + query = request.db.query(MalwareVerdict).filter( + MalwareVerdict.check_id == check_id + ) + num_verdicts = query.count() + request.log.info( + "Removing %d malware verdicts associated with %s version %d." + % (num_verdicts, check_name, check_version) + ) + query.delete(synchronize_session=False) From 328b0eeb5e16cccb83a62328130705455ebccf0b Mon Sep 17 00:00:00 2001 From: Cristina Date: Fri, 10 Jan 2020 14:46:11 -0800 Subject: [PATCH 07/16] Add rudimentary verdicts view. Progress on #6062. (#7207) * Add rudimentary verdicts view. Progress on #6062. Also, add some better testing logic for wiped_out condition. * Code review changes. - Conditionally show fields that are populated - JSON pretty formatting * Fix unit test bug. - Use `get` instead of `filter` to look up verdict by pkey. * simplify unit tests for verdicts view --- tests/common/db/malware.py | 6 ++ tests/common/db/packaging.py | 1 + tests/unit/admin/test_routes.py | 4 + tests/unit/admin/views/test_verdicts.py | 63 +++++++++++++ tests/unit/malware/test_tasks.py | 9 +- warehouse/admin/routes.py | 4 + warehouse/admin/templates/admin/base.html | 5 + .../admin/malware/verdicts/detail.html | 80 ++++++++++++++++ .../admin/malware/verdicts/index.html | 93 +++++++++++++++++++ .../admin/malware/verdicts/object_link.html | 21 +++++ warehouse/admin/views/verdicts.py | 61 ++++++++++++ warehouse/malware/tasks.py | 7 +- 12 files changed, 350 insertions(+), 4 deletions(-) create mode 100644 tests/unit/admin/views/test_verdicts.py create mode 100644 warehouse/admin/templates/admin/malware/verdicts/detail.html create mode 100644 warehouse/admin/templates/admin/malware/verdicts/index.html create mode 100644 warehouse/admin/templates/admin/malware/verdicts/object_link.html create mode 100644 warehouse/admin/views/verdicts.py diff --git a/tests/common/db/malware.py b/tests/common/db/malware.py index 8c365f4abb66..b6e1bf387b90 100644 --- a/tests/common/db/malware.py +++ b/tests/common/db/malware.py @@ -51,6 +51,12 @@ class Meta: check = factory.SubFactory(MalwareCheckFactory) release_file = factory.SubFactory(FileFactory) + release = None + project = None + manually_reviewed = True + administrator_verdict = factory.fuzzy.FuzzyChoice(list(VerdictClassification)) classification = factory.fuzzy.FuzzyChoice(list(VerdictClassification)) confidence = factory.fuzzy.FuzzyChoice(list(VerdictConfidence)) message = factory.fuzzy.FuzzyText(length=80) + full_report_link = None + details = None diff --git a/tests/common/db/packaging.py b/tests/common/db/packaging.py index 2c55099d0242..ada7eca60831 100644 --- a/tests/common/db/packaging.py +++ b/tests/common/db/packaging.py @@ -83,6 +83,7 @@ class Meta: release = factory.SubFactory(ReleaseFactory) python_version = "source" + filename = factory.fuzzy.FuzzyText(length=12) md5_digest = factory.LazyAttribute( lambda o: hashlib.md5(o.filename.encode("utf8")).hexdigest() ) diff --git a/tests/unit/admin/test_routes.py b/tests/unit/admin/test_routes.py index e10962ac54dc..425d50167529 100644 --- a/tests/unit/admin/test_routes.py +++ b/tests/unit/admin/test_routes.py @@ -132,4 +132,8 @@ def test_includeme(): "/admin/checks/{check_name}/change_state", domain=warehouse, ), + pretend.call("admin.verdicts.list", "/admin/verdicts/", domain=warehouse), + pretend.call( + "admin.verdicts.detail", "/admin/verdicts/{verdict_id}", domain=warehouse + ), ] diff --git a/tests/unit/admin/views/test_verdicts.py b/tests/unit/admin/views/test_verdicts.py new file mode 100644 index 000000000000..7d28820ca9cf --- /dev/null +++ b/tests/unit/admin/views/test_verdicts.py @@ -0,0 +1,63 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import uuid + +from random import randint + +import pretend +import pytest + +from pyramid.httpexceptions import HTTPBadRequest, HTTPNotFound + +from warehouse.admin.views import verdicts as views + +from ....common.db.malware import MalwareVerdictFactory + + +class TestListVerdicts: + def test_none(self, db_request): + assert views.get_verdicts(db_request) == {"verdicts": []} + + def test_some(self, db_request): + verdicts = [MalwareVerdictFactory.create() for _ in range(10)] + + assert views.get_verdicts(db_request) == {"verdicts": verdicts} + + def test_some_with_multipage(self, db_request): + verdicts = [MalwareVerdictFactory.create() for _ in range(60)] + + db_request.GET["page"] = "2" + + assert views.get_verdicts(db_request) == {"verdicts": verdicts[25:50]} + + def test_with_invalid_page(self): + request = pretend.stub(params={"page": "not an integer"}) + + with pytest.raises(HTTPBadRequest): + views.get_verdicts(request) + + +class TestGetVerdict: + def test_found(self, db_request): + verdicts = [MalwareVerdictFactory.create() for _ in range(10)] + index = randint(0, 9) + lookup_id = verdicts[index].id + db_request.matchdict["verdict_id"] = lookup_id + + assert views.get_verdict(db_request) == {"verdict": verdicts[index]} + + def test_not_found(self, db_request): + db_request.matchdict["verdict_id"] = uuid.uuid4() + + with pytest.raises(HTTPNotFound): + views.get_verdict(db_request) diff --git a/tests/unit/malware/test_tasks.py b/tests/unit/malware/test_tasks.py index 5c89cc5fd562..234d05dba5bb 100644 --- a/tests/unit/malware/test_tasks.py +++ b/tests/unit/malware/test_tasks.py @@ -266,13 +266,14 @@ def test_no_verdicts(self, db_session): log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),), ) task = pretend.stub() - remove_verdicts(task, request, check.name) + removed = remove_verdicts(task, request, check.name) assert request.log.info.calls == [ pretend.call( "Removing 0 malware verdicts associated with %s version 1." % check.name ), ] + assert removed == 0 @pytest.mark.parametrize(("check_with_verdicts"), [True, False]) def test_many_verdicts(self, db_session, check_with_verdicts): @@ -286,6 +287,8 @@ def test_many_verdicts(self, db_session, check_with_verdicts): for i in range(num_verdicts): MalwareVerdictFactory.create(check=check1, release_file=file0) + assert db_session.query(MalwareVerdict).count() == num_verdicts + request = pretend.stub( db=db_session, log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),), @@ -299,7 +302,7 @@ def test_many_verdicts(self, db_session, check_with_verdicts): wiped_out_check = check0 num_verdicts = 0 - remove_verdicts(task, request, wiped_out_check.name) + removed = remove_verdicts(task, request, wiped_out_check.name) assert request.log.info.calls == [ pretend.call( @@ -307,3 +310,5 @@ def test_many_verdicts(self, db_session, check_with_verdicts): % (num_verdicts, wiped_out_check.name) ), ] + + assert removed == num_verdicts diff --git a/warehouse/admin/routes.py b/warehouse/admin/routes.py index 4180df8188a6..284932c29190 100644 --- a/warehouse/admin/routes.py +++ b/warehouse/admin/routes.py @@ -139,3 +139,7 @@ def includeme(config): "/admin/checks/{check_name}/change_state", domain=warehouse, ) + config.add_route("admin.verdicts.list", "/admin/verdicts/", domain=warehouse) + config.add_route( + "admin.verdicts.detail", "/admin/verdicts/{verdict_id}", domain=warehouse + ) diff --git a/warehouse/admin/templates/admin/base.html b/warehouse/admin/templates/admin/base.html index cc74b4675b81..fffa7ccfcec5 100644 --- a/warehouse/admin/templates/admin/base.html +++ b/warehouse/admin/templates/admin/base.html @@ -130,6 +130,11 @@ Checks +
  • + + Verdicts + +
  • diff --git a/warehouse/admin/templates/admin/malware/verdicts/detail.html b/warehouse/admin/templates/admin/malware/verdicts/detail.html new file mode 100644 index 000000000000..7702943e8692 --- /dev/null +++ b/warehouse/admin/templates/admin/malware/verdicts/detail.html @@ -0,0 +1,80 @@ +{# + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. +-#} +{% extends "admin/base.html" %} + +{% block title %}Verdict {{ verdict.id }}{% endblock %} + +{% block breadcrumb %} +
  • Verdicts
  • +
  • {{ verdict.id }}
  • +{% endblock %} + +{% block content %} +
    +
    + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + {% if verdict.manually_reviewed %} + + + + + {% endif %} + {% if verdict.full_report_link %} + + + + + {% endif %} + {% if verdict.details %} + + + + + {% endif %} +
    Message{{ verdict.message }}
    Run Date{{ verdict.run_date }}
    Check + + {{ verdict.check.name }} v{{ verdict.check.version }} + +
    Object{% include 'object_link.html' %}
    Verdict Classification{{ verdict.classification.value }}
    Verdict Confidence{{ verdict.confidence.value }}
    Manually Reviewed{{ verdict.manually_reviewed }}
    Administrator Verdict{{ verdict.administrator_verdict }}
    Full Report Link{{ verdict.full_report_link }}
    Details
    {{ verdict.details|tojson(indent=4) }}
    +
    +
    +{% endblock %} diff --git a/warehouse/admin/templates/admin/malware/verdicts/index.html b/warehouse/admin/templates/admin/malware/verdicts/index.html new file mode 100644 index 000000000000..d6ab7ef6b028 --- /dev/null +++ b/warehouse/admin/templates/admin/malware/verdicts/index.html @@ -0,0 +1,93 @@ +{# + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. +-#} +{% extends "admin/base.html" %} + +{% import "admin/utils/pagination.html" as pagination %} + +{% block title %}Malware Verdicts{% endblock %} + +{% block breadcrumb %} +
  • Verdicts
  • +{% endblock %} + +{% block content %} +
    +
    + + + + + + + + + {% for verdict in verdicts %} + + + + + + + + {% else %} + + + + {% endfor %} +
    ObjectCheckClassificationConfidenceDetail
    {% include 'object_link.html' %} + + {{ verdict.check.name }} v{{ verdict.check.version }} + + + + + {% if verdict.classification.value == 'indeterminate' %} + + {% elif verdict.classification.value == 'threat' %} + + + {% endif %} + + + + + {% if verdict.confidence.value == 'medium' %} + + {% elif verdict.confidence.value == 'high' %} + + + {% endif %} + + + + Detail + +
    +
    + No verdicts! +
    +
    + +
    +
    +{% endblock content %} diff --git a/warehouse/admin/templates/admin/malware/verdicts/object_link.html b/warehouse/admin/templates/admin/malware/verdicts/object_link.html new file mode 100644 index 000000000000..c31678ce419c --- /dev/null +++ b/warehouse/admin/templates/admin/malware/verdicts/object_link.html @@ -0,0 +1,21 @@ +{# + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. +-#} + +{% if verdict.project %} + {{ verdict.project.name }} +{% elif verdict.release %} + {{ verdict.release.project.name }}-{{ verdict.release.version }} +{% else %} + {{ verdict.release_file.filename}} +{% endif %} diff --git a/warehouse/admin/views/verdicts.py b/warehouse/admin/views/verdicts.py new file mode 100644 index 000000000000..bd9c2eae68ca --- /dev/null +++ b/warehouse/admin/views/verdicts.py @@ -0,0 +1,61 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paginate_sqlalchemy import SqlalchemyOrmPage as SQLAlchemyORMPage +from pyramid.httpexceptions import HTTPBadRequest, HTTPNotFound +from pyramid.view import view_config + +from warehouse.malware.models import MalwareVerdict +from warehouse.utils.paginate import paginate_url_factory + + +@view_config( + route_name="admin.verdicts.list", + renderer="admin/malware/verdicts/index.html", + permission="moderator", + request_method="GET", + uses_session=True, +) +def get_verdicts(request): + try: + page_num = int(request.params.get("page", 1)) + except ValueError: + raise HTTPBadRequest("'page' must be an integer.") from None + + verdicts_query = request.db.query(MalwareVerdict).order_by( + MalwareVerdict.run_date.desc() + ) + + verdicts = SQLAlchemyORMPage( + verdicts_query, + page=page_num, + items_per_page=25, + url_maker=paginate_url_factory(request), + ) + + return {"verdicts": verdicts} + + +@view_config( + route_name="admin.verdicts.detail", + renderer="admin/malware/verdicts/detail.html", + permission="moderator", + request_method="GET", + uses_session=True, +) +def get_verdict(request): + verdict = request.db.query(MalwareVerdict).get(request.matchdict["verdict_id"]) + + if verdict: + return {"verdict": verdict} + + raise HTTPNotFound diff --git a/warehouse/malware/tasks.py b/warehouse/malware/tasks.py index 0d2f570c436f..a70b7ea344d6 100644 --- a/warehouse/malware/tasks.py +++ b/warehouse/malware/tasks.py @@ -95,7 +95,7 @@ def remove_verdicts(task, request, check_name): .filter(MalwareCheck.name == check_name) .all() ) - + total_deleted = 0 for check_id, check_version in check_ids: query = request.db.query(MalwareVerdict).filter( MalwareVerdict.check_id == check_id @@ -105,4 +105,7 @@ def remove_verdicts(task, request, check_name): "Removing %d malware verdicts associated with %s version %d." % (num_verdicts, check_name, check_version) ) - query.delete(synchronize_session=False) + total_deleted += query.delete(synchronize_session=False) + + # This returned value is only relevant for testing. + return total_deleted From a9572b9955339bbbfb8188e5bb319643179ad8ca Mon Sep 17 00:00:00 2001 From: "Ernest W. Durbin III" Date: Mon, 13 Jan 2020 15:34:43 -0500 Subject: [PATCH 08/16] introduce malware queue (#7227) * introduce malware queue * correct syntax, apparently list of tuples documented doesn't work. --- Procfile | 1 + tests/unit/test_tasks.py | 7 +++++-- warehouse/tasks.py | 7 +++++-- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/Procfile b/Procfile index 2c79f62a1e98..d30253824174 100644 --- a/Procfile +++ b/Procfile @@ -2,4 +2,5 @@ release: bin/release web: bin/start-web python -m gunicorn.app.wsgiapp -c gunicorn.conf.py warehouse.wsgi:application web-uploads: bin/start-web python -m gunicorn.app.wsgiapp -c gunicorn-uploads.conf.py warehouse.wsgi:application worker: bin/start-worker celery -A warehouse worker -Q default -l info --max-tasks-per-child 32 +worker-malware: bin/start-worker celery -A warehouse worker -Q malware -l info --max-tasks-per-child 32 worker-beat: bin/start-worker celery -A warehouse beat -S redbeat.RedBeatScheduler -l info diff --git a/tests/unit/test_tasks.py b/tests/unit/test_tasks.py index d0207088254f..cbe28996d8ff 100644 --- a/tests/unit/test_tasks.py +++ b/tests/unit/test_tasks.py @@ -501,8 +501,11 @@ def test_includeme(env, ssl, broker_url, expected_url, transport_options): "task_serializer": "json", "accept_content": ["json", "msgpack"], "task_queue_ha_policy": "all", - "task_queues": (Queue("default", routing_key="task.#"),), - "task_routes": ([]), + "task_queues": ( + Queue("default", routing_key="task.#"), + Queue("malware", routing_key="malware.#"), + ), + "task_routes": {"warehouse.malware.tasks.*": {"queue": "malware"}}, "REDBEAT_REDIS_URL": (config.registry.settings["celery.scheduler_url"]), }.items(): assert app.conf[key] == value diff --git a/warehouse/tasks.py b/warehouse/tasks.py index 53bf5d16d62d..a7b834b3b535 100644 --- a/warehouse/tasks.py +++ b/warehouse/tasks.py @@ -195,8 +195,11 @@ def includeme(config): task_default_queue="default", task_default_routing_key="task.default", task_queue_ha_policy="all", - task_queues=(Queue("default", routing_key="task.#"),), - task_routes=([]), + task_queues=( + Queue("default", routing_key="task.#"), + Queue("malware", routing_key="malware.#"), + ), + task_routes={"warehouse.malware.tasks.*": {"queue": "malware"}}, task_serializer="json", worker_disable_rate_limits=True, REDBEAT_REDIS_URL=s["celery.scheduler_url"], From 91b749cc5e5db396d90d5e8d16e2573e5568215f Mon Sep 17 00:00:00 2001 From: Cristina Date: Thu, 16 Jan 2020 08:53:10 -0800 Subject: [PATCH 09/16] Add backfill functionality to check admin #7094 (#7232) * Add backfill functionality to check admin #7094 - Add backfill task - Change lookup of checks to check_name instead of id - Load checks that are also in "evaluation" state * Add unit tests for backfill. - Log number of runs executed by backfill - Perform basic validation on sample_rate input - Clean up other testing logic. * Remove superfluous 'all()' * Code review changes. - Set backfill size to a fix number, not configurable via web ui. - Backfill task enqueues run_check tasks - Only retry if `check.run` fails, not if loading the check fails. - Use exponential backoff for retries. * Update warehouse/admin/templates/admin/malware/checks/detail.html Co-Authored-By: Ernest W. Durbin III Co-authored-by: Ernest W. Durbin III --- tests/unit/admin/test_routes.py | 5 + tests/unit/admin/views/test_checks.py | 73 ++++++++++-- tests/unit/malware/test_tasks.py | 112 ++++++++++++------ warehouse/admin/routes.py | 5 + .../admin/malware/checks/detail.html | 17 ++- warehouse/admin/views/checks.py | 78 +++++++++--- warehouse/malware/checks/base.py | 6 +- warehouse/malware/tasks.py | 22 +++- 8 files changed, 247 insertions(+), 71 deletions(-) diff --git a/tests/unit/admin/test_routes.py b/tests/unit/admin/test_routes.py index 425d50167529..28538ad12dac 100644 --- a/tests/unit/admin/test_routes.py +++ b/tests/unit/admin/test_routes.py @@ -132,6 +132,11 @@ def test_includeme(): "/admin/checks/{check_name}/change_state", domain=warehouse, ), + pretend.call( + "admin.checks.run_backfill", + "/admin/checks/{check_name}/run_backfill", + domain=warehouse, + ), pretend.call("admin.verdicts.list", "/admin/verdicts/", domain=warehouse), pretend.call( "admin.verdicts.detail", "/admin/verdicts/{verdict_id}", domain=warehouse diff --git a/tests/unit/admin/views/test_checks.py b/tests/unit/admin/views/test_checks.py index 601e26e79858..273bb8cd3fe3 100644 --- a/tests/unit/admin/views/test_checks.py +++ b/tests/unit/admin/views/test_checks.py @@ -10,8 +10,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import uuid - import pretend import pytest @@ -67,6 +65,12 @@ def test_get_check_not_found(self, db_request): class TestChangeCheckState: + def test_no_check_state(self, db_request): + check = MalwareCheckFactory.create() + db_request.matchdict["check_name"] = check.name + with pytest.raises(HTTPNotFound): + views.change_check_state(db_request) + @pytest.mark.parametrize( ("final_state"), [MalwareCheckState.disabled, MalwareCheckState.wiped_out] ) @@ -75,7 +79,7 @@ def test_change_to_valid_state(self, db_request, final_state): name="MyCheck", state=MalwareCheckState.disabled ) - db_request.POST = {"id": check.id, "check_state": final_state.value} + db_request.POST = {"check_state": final_state.value} db_request.matchdict["check_name"] = check.name db_request.session = pretend.stub( @@ -107,7 +111,7 @@ def test_change_to_invalid_state(self, db_request): check = MalwareCheckFactory.create(name="MyCheck") initial_state = check.state invalid_check_state = "cancelled" - db_request.POST = {"id": check.id, "check_state": invalid_check_state} + db_request.POST = {"check_state": invalid_check_state} db_request.matchdict["check_name"] = check.name db_request.session = pretend.stub( @@ -124,13 +128,62 @@ def test_change_to_invalid_state(self, db_request): ] assert check.state == initial_state - def test_check_not_found(self, db_request): - db_request.POST = {"id": uuid.uuid4(), "check_state": "enabled"} - db_request.matchdict["check_name"] = "DoesNotExist" + +class TestRunBackfill: + @pytest.mark.parametrize( + ("check_state", "message"), + [ + ( + MalwareCheckState.disabled, + "Check must be in 'enabled' or 'evaluation' state to run a backfill.", + ), + ( + MalwareCheckState.wiped_out, + "Check must be in 'enabled' or 'evaluation' state to run a backfill.", + ), + ], + ) + def test_invalid_backfill_parameters(self, db_request, check_state, message): + check = MalwareCheckFactory.create(state=check_state) + db_request.matchdict["check_name"] = check.name + + db_request.session = pretend.stub( + flash=pretend.call_recorder(lambda *a, **kw: None) + ) db_request.route_path = pretend.call_recorder( - lambda *a, **kw: "/admin/checks/DoesNotExist/change_state" + lambda *a, **kw: "/admin/checks/%s/run_backfill" % check.name ) - with pytest.raises(HTTPNotFound): - views.change_check_state(db_request) + views.run_backfill(db_request) + + assert db_request.session.flash.calls == [pretend.call(message, queue="error")] + + def test_sucess(self, db_request): + check = MalwareCheckFactory.create(state=MalwareCheckState.enabled) + db_request.matchdict["check_name"] = check.name + + db_request.session = pretend.stub( + flash=pretend.call_recorder(lambda *a, **kw: None) + ) + + db_request.route_path = pretend.call_recorder( + lambda *a, **kw: "/admin/checks/%s/run_backfill" % check.name + ) + + backfill_recorder = pretend.stub( + delay=pretend.call_recorder(lambda *a, **kw: None) + ) + + db_request.task = pretend.call_recorder(lambda *a, **kw: backfill_recorder) + + views.run_backfill(db_request) + + assert db_request.session.flash.calls == [ + pretend.call( + "Running %s on 10000 %ss!" % (check.name, check.hooked_object.value), + queue="success", + ) + ] + + assert backfill_recorder.delay.calls == [pretend.call(check.name, 10000)] diff --git a/tests/unit/malware/test_tasks.py b/tests/unit/malware/test_tasks.py index 234d05dba5bb..91ce94966438 100644 --- a/tests/unit/malware/test_tasks.py +++ b/tests/unit/malware/test_tasks.py @@ -19,76 +19,112 @@ import warehouse.malware.checks as checks from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict -from warehouse.malware.tasks import remove_verdicts, run_check, sync_checks +from warehouse.malware.tasks import backfill, remove_verdicts, run_check, sync_checks from ...common.db.malware import MalwareCheckFactory, MalwareVerdictFactory from ...common.db.packaging import FileFactory, ProjectFactory, ReleaseFactory class TestRunCheck: - def test_success(self, monkeypatch, db_request): - project = ProjectFactory.create(name="foo") - release = ReleaseFactory.create(project=project) - file0 = FileFactory.create(release=release, filename="foo.bar") + def test_success(self, db_request): + file0 = FileFactory.create() MalwareCheckFactory.create(name="ExampleCheck", state=MalwareCheckState.enabled) - task = pretend.stub() run_check(task, db_request, "ExampleCheck", file0.id) + assert db_request.db.query(MalwareVerdict).one() - def test_missing_check_id(self, monkeypatch, db_session): - exc = NoResultFound("No row was found for one()") + def test_disabled_check(self, db_request): + MalwareCheckFactory.create( + name="ExampleCheck", state=MalwareCheckState.disabled + ) + task = pretend.stub() - class FakeMalwareCheck: - def __init__(self, db): - raise exc + with pytest.raises(NoResultFound): + run_check( + task, + db_request, + "ExampleCheck", + "d03d75d1-2511-4a8b-9759-62294a6fe3a7", + ) - checks.FakeMalwareCheck = FakeMalwareCheck + def test_missing_check(self, db_request): + task = pretend.stub() + with pytest.raises(AttributeError): + run_check( + task, + db_request, + "DoesNotExistCheck", + "d03d75d1-2511-4a8b-9759-62294a6fe3a7", + ) - class Task: - @staticmethod - @pretend.call_recorder - def retry(exc): - raise celery.exceptions.Retry + def test_retry(self, db_session, monkeypatch): + MalwareCheckFactory.create( + name="ExampleCheck", state=MalwareCheckState.evaluation + ) - task = Task() + exc = Exception("Scan failed") + def scan(self, file_id): + raise exc + + monkeypatch.setattr(checks.ExampleCheck, "scan", scan) + + task = pretend.stub( + retry=pretend.call_recorder(pretend.raiser(celery.exceptions.Retry)), + ) request = pretend.stub( db=db_session, - log=pretend.stub( - error=pretend.call_recorder(lambda *args, **kwargs: None), - ), + log=pretend.stub(error=pretend.call_recorder(lambda *args, **kwargs: None)), ) with pytest.raises(celery.exceptions.Retry): run_check( - task, - request, - "FakeMalwareCheck", - "d03d75d1-2511-4a8b-9759-62294a6fe3a7", + task, request, "ExampleCheck", "d03d75d1-2511-4a8b-9759-62294a6fe3a7" ) assert request.log.error.calls == [ - pretend.call( - "Error executing check %s: %s", - "FakeMalwareCheck", - "No row was found for one()", - ) + pretend.call("Error executing check ExampleCheck: Scan failed") ] assert task.retry.calls == [pretend.call(exc=exc)] - del checks.FakeMalwareCheck - def test_missing_check(self, db_request): +class TestBackfill: + def test_invalid_check_name(self, db_request): task = pretend.stub() with pytest.raises(AttributeError): - run_check( - task, - db_request, - "DoesNotExistCheck", - "d03d75d1-2511-4a8b-9759-62294a6fe3a7", - ) + backfill(task, db_request, "DoesNotExist", 1) + + @pytest.mark.parametrize( + ("num_objects", "num_runs"), [(11, 1), (11, 11), (101, 90)], + ) + def test_run(self, db_session, num_objects, num_runs): + files = [] + for i in range(num_objects): + files.append(FileFactory.create()) + + MalwareCheckFactory.create(name="ExampleCheck", state=MalwareCheckState.enabled) + enqueue_recorder = pretend.stub( + delay=pretend.call_recorder(lambda *a, **kw: None) + ) + task = pretend.call_recorder(lambda *args, **kwargs: enqueue_recorder) + + request = pretend.stub( + db=db_session, + log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None)), + task=task, + ) + + backfill(task, request, "ExampleCheck", num_runs) + + assert request.log.info.calls == [ + pretend.call("Running backfill on %d Files." % num_runs), + ] + + assert enqueue_recorder.delay.calls == [ + pretend.call("ExampleCheck", files[i].id) for i in range(num_runs) + ] class TestSyncChecks: diff --git a/warehouse/admin/routes.py b/warehouse/admin/routes.py index 284932c29190..2788f51519bf 100644 --- a/warehouse/admin/routes.py +++ b/warehouse/admin/routes.py @@ -139,6 +139,11 @@ def includeme(config): "/admin/checks/{check_name}/change_state", domain=warehouse, ) + config.add_route( + "admin.checks.run_backfill", + "/admin/checks/{check_name}/run_backfill", + domain=warehouse, + ) config.add_route("admin.verdicts.list", "/admin/verdicts/", domain=warehouse) config.add_route( "admin.verdicts.detail", "/admin/verdicts/{verdict_id}", domain=warehouse diff --git a/warehouse/admin/templates/admin/malware/checks/detail.html b/warehouse/admin/templates/admin/malware/checks/detail.html index 2cf80f20cc8a..77914488bb51 100644 --- a/warehouse/admin/templates/admin/malware/checks/detail.html +++ b/warehouse/admin/templates/admin/malware/checks/detail.html @@ -48,11 +48,10 @@

    Revision History

    Change State

    -
    - {% for state in states %}
    +
    +
    +

    Run Evaluation

    +
    +
    + +
    +

    Run this check against 10,000 {{ check.hooked_object.value }}s, selected at random. This is used to evaluate the efficacy of a check.

    +
    + +
    +
    +
    +
    {% endblock %} diff --git a/warehouse/admin/views/checks.py b/warehouse/admin/views/checks.py index f1ad86685d94..7fd3fa9aacf0 100644 --- a/warehouse/admin/views/checks.py +++ b/warehouse/admin/views/checks.py @@ -15,7 +15,7 @@ from sqlalchemy.orm.exc import NoResultFound from warehouse.malware.models import MalwareCheck, MalwareCheckState -from warehouse.malware.tasks import remove_verdicts +from warehouse.malware.tasks import backfill, remove_verdicts @view_config( @@ -26,7 +26,7 @@ uses_session=True, ) def get_checks(request): - all_checks = request.db.query(MalwareCheck).all() + all_checks = request.db.query(MalwareCheck) active_checks = [] for check in all_checks: if not check.is_stale: @@ -43,19 +43,49 @@ def get_checks(request): uses_session=True, ) def get_check(request): - query = ( + check = get_check_by_name(request.db, request.matchdict["check_name"]) + + all_checks = ( request.db.query(MalwareCheck) .filter(MalwareCheck.name == request.matchdict["check_name"]) .order_by(MalwareCheck.version.desc()) + .all() ) - try: - # Throw an exception if and only if no results are returned. - newest = query.limit(1).one() - except NoResultFound: - raise HTTPNotFound + return {"check": check, "checks": all_checks, "states": MalwareCheckState} - return {"check": newest, "checks": query.all(), "states": MalwareCheckState} + +@view_config( + route_name="admin.checks.run_backfill", + permission="admin", + request_method="POST", + uses_session=True, + require_methods=False, + require_csrf=True, +) +def run_backfill(request): + check = get_check_by_name(request.db, request.matchdict["check_name"]) + num_objects = 10000 + + if check.state not in (MalwareCheckState.enabled, MalwareCheckState.evaluation): + request.session.flash( + f"Check must be in 'enabled' or 'evaluation' state to run a backfill.", + queue="error", + ) + return HTTPSeeOther( + request.route_path("admin.checks.detail", check_name=check.name) + ) + + request.session.flash( + f"Running {check.name} on {num_objects} {check.hooked_object.value}s!", + queue="success", + ) + + request.task(backfill).delay(check.name, num_objects) + + return HTTPSeeOther( + request.route_path("admin.checks.detail", check_name=check.name) + ) @view_config( @@ -67,18 +97,16 @@ def get_check(request): require_csrf=True, ) def change_check_state(request): + check = get_check_by_name(request.db, request.matchdict["check_name"]) + try: - check = ( - request.db.query(MalwareCheck) - .filter(MalwareCheck.id == request.POST["id"]) - .one() - ) - except NoResultFound: + check_state = request.POST["check_state"] + except KeyError: raise HTTPNotFound try: - check.state = getattr(MalwareCheckState, request.POST["check_state"]) - except (AttributeError, KeyError): + check.state = getattr(MalwareCheckState, check_state) + except AttributeError: request.session.flash("Invalid check state provided.", queue="error") else: if check.state == MalwareCheckState.wiped_out: @@ -90,3 +118,19 @@ def change_check_state(request): return HTTPSeeOther( request.route_path("admin.checks.detail", check_name=check.name) ) + + +def get_check_by_name(db, check_name): + try: + # Throw an exception if and only if no results are returned. + newest = ( + db.query(MalwareCheck) + .filter(MalwareCheck.name == check_name) + .order_by(MalwareCheck.version.desc()) + .limit(1) + .one() + ) + except NoResultFound: + raise HTTPNotFound + + return newest diff --git a/warehouse/malware/checks/base.py b/warehouse/malware/checks/base.py index 6c94937abbde..72406954c982 100644 --- a/warehouse/malware/checks/base.py +++ b/warehouse/malware/checks/base.py @@ -46,6 +46,10 @@ def _load_check_id(self): self.id = ( self.db.query(MalwareCheck.id) .filter(MalwareCheck.name == self._name) - .filter(MalwareCheck.state == MalwareCheckState.enabled) + .filter( + MalwareCheck.state.in_( + [MalwareCheckState.enabled, MalwareCheckState.evaluation] + ) + ) .one() ) diff --git a/warehouse/malware/tasks.py b/warehouse/malware/tasks.py index a70b7ea344d6..5028ed00e7ef 100644 --- a/warehouse/malware/tasks.py +++ b/warehouse/malware/tasks.py @@ -13,22 +13,38 @@ import inspect import warehouse.malware.checks as checks +import warehouse.packaging.models as packaging_models from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict from warehouse.malware.utils import get_check_fields from warehouse.tasks import task -@task(bind=True, ignore_result=True, acks_late=True) +@task(bind=True, ignore_result=True, acks_late=True, retry_backoff=True) def run_check(task, request, check_name, obj_id): + check = getattr(checks, check_name)(request.db) try: - check = getattr(checks, check_name)(request.db) check.run(obj_id) except Exception as exc: - request.log.error("Error executing check %s: %s", check_name, str(exc)) + request.log.error("Error executing check %s: %s" % (check_name, str(exc))) raise task.retry(exc=exc) +@task(bind=True, ignore_result=True, acks_late=True) +def backfill(task, request, check_name, num_objects): + """ + Runs a backfill on a fixed number of objects. + """ + check = getattr(checks, check_name)(request.db) + target_object = getattr(packaging_models, check.hooked_object) + query = request.db.query(target_object.id).limit(num_objects) + + request.log.info("Running backfill on %d %ss." % (num_objects, check.hooked_object)) + + for (elem_id,) in query: + request.task(run_check).delay(check_name, elem_id) + + @task(bind=True, ignore_result=True, acks_late=True) def sync_checks(task, request): code_checks = inspect.getmembers(checks, inspect.isclass) From 16baaa789422e2a9ebd15498b9bbf4bbd8530fc9 Mon Sep 17 00:00:00 2001 From: Cristina Date: Fri, 17 Jan 2020 11:30:29 -0800 Subject: [PATCH 10/16] Refactor testing logic #7098 (#7257) - Add `schedule` field to MalwareCheck model #7096 - Move ExampleCheck into tests/common/ to remove test dependency from prod code - Rename functions and classes to differentiate between "hooked" and "scheduled" checks --- tests/common/checks/__init__.py | 14 ++ .../common/checks/hooked.py | 8 +- tests/common/checks/scheduled.py | 37 ++++ tests/common/db/malware.py | 1 + tests/unit/malware/test_checks.py | 19 +- tests/unit/malware/test_init.py | 25 +-- tests/unit/malware/test_tasks.py | 180 +++++++++++------- tests/unit/malware/test_utils.py | 64 ++++--- warehouse/malware/__init__.py | 2 +- warehouse/malware/checks/__init__.py | 2 - warehouse/malware/models.py | 4 +- warehouse/malware/utils.py | 6 +- ...1ff3d24c22_add_malware_detection_tables.py | 1 + 13 files changed, 244 insertions(+), 119 deletions(-) create mode 100644 tests/common/checks/__init__.py rename warehouse/malware/checks/example.py => tests/common/checks/hooked.py (87%) create mode 100644 tests/common/checks/scheduled.py diff --git a/tests/common/checks/__init__.py b/tests/common/checks/__init__.py new file mode 100644 index 000000000000..dfd77b961075 --- /dev/null +++ b/tests/common/checks/__init__.py @@ -0,0 +1,14 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .hooked import ExampleHookedCheck # noqa +from .scheduled import ExampleScheduledCheck # noqa diff --git a/warehouse/malware/checks/example.py b/tests/common/checks/hooked.py similarity index 87% rename from warehouse/malware/checks/example.py rename to tests/common/checks/hooked.py index 22b91906ffa3..6bd6f2e512e4 100644 --- a/warehouse/malware/checks/example.py +++ b/tests/common/checks/hooked.py @@ -14,19 +14,19 @@ from warehouse.malware.models import VerdictClassification, VerdictConfidence -class ExampleCheck(MalwareCheckBase): +class ExampleHookedCheck(MalwareCheckBase): version = 1 short_description = "An example hook-based check" - long_description = """The purpose of this check is to demonstrate the \ -implementation of a hook-based check. This check will generate verdicts if enabled.""" + long_description = "The purpose of this check is to test the \ +implementation of a hook-based check. This check will generate verdicts if enabled." check_type = "event_hook" hooked_object = "File" def __init__(self, db): super().__init__(db) - def scan(self, file_id): + def scan(self, file_id=None): self.add_verdict( file_id=file_id, classification=VerdictClassification.benign, diff --git a/tests/common/checks/scheduled.py b/tests/common/checks/scheduled.py new file mode 100644 index 000000000000..128ce102a83b --- /dev/null +++ b/tests/common/checks/scheduled.py @@ -0,0 +1,37 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from warehouse.malware.checks.base import MalwareCheckBase +from warehouse.malware.models import VerdictClassification, VerdictConfidence +from warehouse.packaging.models import Project + + +class ExampleScheduledCheck(MalwareCheckBase): + + version = 1 + short_description = "An example scheduled check" + long_description = "The purpose of this check is to test the \ +implementation of a scheduled check. This check will generate verdicts if enabled." + check_type = "scheduled" + schedule = {"minute": "0", "hour": "*/8"} + + def __init__(self, db): + super().__init__(db) + + def scan(self): + project = self.db.query(Project).first() + self.add_verdict( + project_id=project.id, + classification=VerdictClassification.benign, + confidence=VerdictConfidence.High, + message="Nothing to see here!", + ) diff --git a/tests/common/db/malware.py b/tests/common/db/malware.py index b6e1bf387b90..4e41a0c23865 100644 --- a/tests/common/db/malware.py +++ b/tests/common/db/malware.py @@ -39,6 +39,7 @@ class Meta: long_description = factory.fuzzy.FuzzyText(length=300) check_type = factory.fuzzy.FuzzyChoice(list(MalwareCheckType)) hooked_object = factory.fuzzy.FuzzyChoice(list(MalwareCheckObjectType)) + schedule = {"minute": "*/10"} state = factory.fuzzy.FuzzyChoice(list(MalwareCheckState)) created = factory.fuzzy.FuzzyNaiveDateTime( datetime.datetime.utcnow() - datetime.timedelta(days=7) diff --git a/tests/unit/malware/test_checks.py b/tests/unit/malware/test_checks.py index 2ce63c624965..9427972ef5e8 100644 --- a/tests/unit/malware/test_checks.py +++ b/tests/unit/malware/test_checks.py @@ -12,26 +12,35 @@ import inspect -import warehouse.malware.checks as checks +import pytest + +import warehouse.malware.checks as prod_checks from warehouse.malware.checks.base import MalwareCheckBase from warehouse.malware.utils import get_check_fields +from ...common import checks as test_checks + def test_checks_subclass_base(): - checks_from_module = inspect.getmembers(checks, inspect.isclass) + prod_checks_from_module = inspect.getmembers(prod_checks, inspect.isclass) + test_checks_from_module = inspect.getmembers(test_checks, inspect.isclass) + all_checks = prod_checks_from_module + test_checks_from_module subclasses_of_malware_base = { cls.__name__: cls for cls in MalwareCheckBase.__subclasses__() } - assert len(checks_from_module) == len(subclasses_of_malware_base) + assert len(all_checks) == len(subclasses_of_malware_base) - for check_name, check in checks_from_module: + for check_name, check in all_checks: assert subclasses_of_malware_base[check_name] == check -def test_checks_fields(): +@pytest.mark.parametrize( + ("checks"), [prod_checks, test_checks], +) +def test_checks_fields(checks): checks_from_module = inspect.getmembers(checks, inspect.isclass) for check_name, check in checks_from_module: diff --git a/tests/unit/malware/test_init.py b/tests/unit/malware/test_init.py index 5db87f2c396e..4d2888aef6a1 100644 --- a/tests/unit/malware/test_init.py +++ b/tests/unit/malware/test_init.py @@ -18,15 +18,16 @@ from warehouse.malware import utils from warehouse.malware.interfaces import IMalwareCheckService +from ...common import checks as test_checks from ...common.db.accounts import UserFactory from ...common.db.packaging import FileFactory, ProjectFactory, ReleaseFactory def test_determine_malware_checks_no_checks(monkeypatch, db_request): - def get_enabled_checks(session): + def get_enabled_hooked_checks(session): return defaultdict(list) - monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks) + monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks) project = ProjectFactory.create(name="foo") release = ReleaseFactory.create(project=project) @@ -39,13 +40,13 @@ def get_enabled_checks(session): def test_determine_malware_checks_nothing_new(monkeypatch, db_request): - def get_enabled_checks(session): + def get_enabled_hooked_checks(session): result = defaultdict(list) result["File"] = ["Check1", "Check2"] result["Release"] = ["Check3"] return result - monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks) + monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks) project = ProjectFactory.create(name="foo") release = ReleaseFactory.create(project=project) @@ -58,13 +59,13 @@ def get_enabled_checks(session): def test_determine_malware_checks_unsupported_object(monkeypatch, db_request): - def get_enabled_checks(session): + def get_enabled_hooked_checks(session): result = defaultdict(list) result["File"] = ["Check1", "Check2"] result["Release"] = ["Check3"] return result - monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks) + monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks) user = UserFactory.create() @@ -75,13 +76,13 @@ def get_enabled_checks(session): def test_determine_malware_checks_file_only(monkeypatch, db_request): - def get_enabled_checks(session): + def get_enabled_hooked_checks(session): result = defaultdict(list) result["File"] = ["Check1", "Check2"] result["Release"] = ["Check3"] return result - monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks) + monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks) project = ProjectFactory.create(name="foo") release = ReleaseFactory.create(project=project) @@ -95,13 +96,13 @@ def get_enabled_checks(session): def test_determine_malware_checks_file_and_release(monkeypatch, db_request): - def get_enabled_checks(session): + def get_enabled_hooked_checks(session): result = defaultdict(list) result["File"] = ["Check1", "Check2"] result["Release"] = ["Check3"] return result - monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks) + monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks) project = ProjectFactory.create(name="foo") release = ReleaseFactory.create(project=project) @@ -149,7 +150,9 @@ def test_enqueue_malware_checks_no_checks(app_config): assert "warehouse.malware.checks" not in session.info -def test_includeme(): +def test_includeme(monkeypatch): + monkeypatch.setattr(malware, "checks", test_checks) + malware_check_class = pretend.stub( create_service=pretend.call_recorder(lambda *a, **kw: pretend.stub()) ) diff --git a/tests/unit/malware/test_tasks.py b/tests/unit/malware/test_tasks.py index 91ce94966438..8fc427e35010 100644 --- a/tests/unit/malware/test_tasks.py +++ b/tests/unit/malware/test_tasks.py @@ -16,42 +16,47 @@ from sqlalchemy.orm.exc import NoResultFound -import warehouse.malware.checks as checks - +from warehouse.malware import tasks from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict -from warehouse.malware.tasks import backfill, remove_verdicts, run_check, sync_checks +from ...common import checks as test_checks from ...common.db.malware import MalwareCheckFactory, MalwareVerdictFactory from ...common.db.packaging import FileFactory, ProjectFactory, ReleaseFactory class TestRunCheck: - def test_success(self, db_request): + def test_success(self, db_request, monkeypatch): + monkeypatch.setattr(tasks, "checks", test_checks) file0 = FileFactory.create() - MalwareCheckFactory.create(name="ExampleCheck", state=MalwareCheckState.enabled) + MalwareCheckFactory.create( + name="ExampleHookedCheck", state=MalwareCheckState.enabled + ) task = pretend.stub() - run_check(task, db_request, "ExampleCheck", file0.id) + tasks.run_check(task, db_request, "ExampleHookedCheck", file0.id) assert db_request.db.query(MalwareVerdict).one() - def test_disabled_check(self, db_request): + def test_disabled_check(self, db_request, monkeypatch): + monkeypatch.setattr(tasks, "checks", test_checks) MalwareCheckFactory.create( - name="ExampleCheck", state=MalwareCheckState.disabled + name="ExampleHookedCheck", state=MalwareCheckState.disabled ) + task = pretend.stub() with pytest.raises(NoResultFound): - run_check( + tasks.run_check( task, db_request, - "ExampleCheck", + "ExampleHookedCheck", "d03d75d1-2511-4a8b-9759-62294a6fe3a7", ) - def test_missing_check(self, db_request): + def test_missing_check(self, db_request, monkeypatch): + monkeypatch.setattr(tasks, "checks", test_checks) task = pretend.stub() with pytest.raises(AttributeError): - run_check( + tasks.run_check( task, db_request, "DoesNotExistCheck", @@ -59,16 +64,17 @@ def test_missing_check(self, db_request): ) def test_retry(self, db_session, monkeypatch): - MalwareCheckFactory.create( - name="ExampleCheck", state=MalwareCheckState.evaluation - ) - exc = Exception("Scan failed") def scan(self, file_id): raise exc - monkeypatch.setattr(checks.ExampleCheck, "scan", scan) + monkeypatch.setattr(tasks, "checks", test_checks) + monkeypatch.setattr(tasks.checks.ExampleHookedCheck, "scan", scan) + + MalwareCheckFactory.create( + name="ExampleHookedCheck", state=MalwareCheckState.evaluation + ) task = pretend.stub( retry=pretend.call_recorder(pretend.raiser(celery.exceptions.Retry)), @@ -79,32 +85,40 @@ def scan(self, file_id): ) with pytest.raises(celery.exceptions.Retry): - run_check( - task, request, "ExampleCheck", "d03d75d1-2511-4a8b-9759-62294a6fe3a7" + tasks.run_check( + task, + request, + "ExampleHookedCheck", + "d03d75d1-2511-4a8b-9759-62294a6fe3a7", ) assert request.log.error.calls == [ - pretend.call("Error executing check ExampleCheck: Scan failed") + pretend.call("Error executing check ExampleHookedCheck: Scan failed") ] assert task.retry.calls == [pretend.call(exc=exc)] class TestBackfill: - def test_invalid_check_name(self, db_request): + def test_invalid_check_name(self, db_request, monkeypatch): + monkeypatch.setattr(tasks, "checks", test_checks) task = pretend.stub() with pytest.raises(AttributeError): - backfill(task, db_request, "DoesNotExist", 1) + tasks.backfill(task, db_request, "DoesNotExist", 1) @pytest.mark.parametrize( ("num_objects", "num_runs"), [(11, 1), (11, 11), (101, 90)], ) - def test_run(self, db_session, num_objects, num_runs): + def test_run(self, db_session, num_objects, num_runs, monkeypatch): + monkeypatch.setattr(tasks, "checks", test_checks) files = [] for i in range(num_objects): files.append(FileFactory.create()) - MalwareCheckFactory.create(name="ExampleCheck", state=MalwareCheckState.enabled) + MalwareCheckFactory.create( + name="ExampleHookedCheck", state=MalwareCheckState.enabled + ) + enqueue_recorder = pretend.stub( delay=pretend.call_recorder(lambda *a, **kw: None) ) @@ -116,21 +130,29 @@ def test_run(self, db_session, num_objects, num_runs): task=task, ) - backfill(task, request, "ExampleCheck", num_runs) + tasks.backfill(task, request, "ExampleHookedCheck", num_runs) assert request.log.info.calls == [ pretend.call("Running backfill on %d Files." % num_runs), ] assert enqueue_recorder.delay.calls == [ - pretend.call("ExampleCheck", files[i].id) for i in range(num_runs) + pretend.call("ExampleHookedCheck", files[i].id) for i in range(num_runs) ] class TestSyncChecks: - def test_no_updates(self, db_session): + def test_no_updates(self, db_session, monkeypatch): + monkeypatch.setattr(tasks, "checks", test_checks) + monkeypatch.setattr(tasks.checks.ExampleScheduledCheck, "version", 2) + MalwareCheckFactory.create( + name="ExampleHookedCheck", state=MalwareCheckState.disabled + ) MalwareCheckFactory.create( - name="ExampleCheck", state=MalwareCheckState.disabled + name="ExampleScheduledCheck", state=MalwareCheckState.disabled + ) + MalwareCheckFactory.create( + name="ExampleScheduledCheck", state=MalwareCheckState.enabled, version=2 ) task = pretend.stub() @@ -140,26 +162,25 @@ def test_no_updates(self, db_session): log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),), ) - sync_checks(task, request) + tasks.sync_checks(task, request) assert request.log.info.calls == [ - pretend.call("1 malware checks found in codebase."), - pretend.call("ExampleCheck is unmodified."), + pretend.call("2 malware checks found in codebase."), + pretend.call("ExampleHookedCheck is unmodified."), + pretend.call("ExampleScheduledCheck is unmodified."), ] @pytest.mark.parametrize( ("final_state"), [MalwareCheckState.enabled, MalwareCheckState.disabled] ) def test_upgrade_check(self, monkeypatch, db_session, final_state): - MalwareCheckFactory.create(name="ExampleCheck", state=final_state) + monkeypatch.setattr(tasks, "checks", test_checks) + monkeypatch.setattr(tasks.checks.ExampleHookedCheck, "version", 2) - class ExampleCheck: - version = 2 - short_description = "This is a short description." - long_description = "This is a longer description." - check_type = "scheduled" - - monkeypatch.setattr(checks, "ExampleCheck", ExampleCheck) + MalwareCheckFactory.create(name="ExampleHookedCheck", state=final_state) + MalwareCheckFactory.create( + name="ExampleScheduledCheck", state=MalwareCheckState.disabled + ) task = pretend.stub() request = pretend.stub( @@ -167,15 +188,16 @@ class ExampleCheck: log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),), ) - sync_checks(task, request) + tasks.sync_checks(task, request) assert request.log.info.calls == [ - pretend.call("1 malware checks found in codebase."), - pretend.call("Updating existing ExampleCheck."), + pretend.call("2 malware checks found in codebase."), + pretend.call("Updating existing ExampleHookedCheck."), + pretend.call("ExampleScheduledCheck is unmodified."), ] db_checks = ( db_session.query(MalwareCheck) - .filter(MalwareCheck.name == "ExampleCheck") + .filter(MalwareCheck.name == "ExampleHookedCheck") .all() ) @@ -193,7 +215,16 @@ class ExampleCheck: else: assert c.version == 1 - def test_one_new_check(self, db_session): + def test_one_new_check(self, db_session, monkeypatch): + monkeypatch.setattr(tasks, "checks", test_checks) + + MalwareCheckFactory.create( + name="ExampleHookedCheck", state=MalwareCheckState.disabled + ) + MalwareCheckFactory.create( + name="ExampleScheduledCheck", state=MalwareCheckState.disabled + ) + task = pretend.stub() class FakeMalwareCheck: @@ -201,26 +232,24 @@ class FakeMalwareCheck: short_description = "This is a short description." long_description = "This is a longer description." check_type = "scheduled" + schedule = {"minute": "0", "hour": "*/8"} - checks.FakeMalwareCheck = FakeMalwareCheck + tasks.checks.FakeMalwareCheck = FakeMalwareCheck request = pretend.stub( db=db_session, log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),), ) - MalwareCheckFactory.create( - name="ExampleCheck", state=MalwareCheckState.evaluation - ) - - sync_checks(task, request) + tasks.sync_checks(task, request) assert request.log.info.calls == [ - pretend.call("2 malware checks found in codebase."), - pretend.call("ExampleCheck is unmodified."), + pretend.call("3 malware checks found in codebase."), + pretend.call("ExampleHookedCheck is unmodified."), + pretend.call("ExampleScheduledCheck is unmodified."), pretend.call("Adding new FakeMalwareCheck to the database."), ] - assert db_session.query(MalwareCheck).count() == 2 + assert db_session.query(MalwareCheck).count() == 3 new_check = ( db_session.query(MalwareCheck) @@ -230,19 +259,23 @@ class FakeMalwareCheck: assert new_check.state == MalwareCheckState.disabled - del checks.FakeMalwareCheck + del tasks.checks.FakeMalwareCheck - def test_too_many_db_checks(self, db_session): - task = pretend.stub() + def test_too_many_db_checks(self, db_session, monkeypatch): + monkeypatch.setattr(tasks, "checks", test_checks) - MalwareCheckFactory.create(name="ExampleCheck", state=MalwareCheckState.enabled) MalwareCheckFactory.create( - name="AnotherCheck", state=MalwareCheckState.disabled + name="ExampleHookedCheck", state=MalwareCheckState.enabled + ) + MalwareCheckFactory.create( + name="ExampleScheduledCheck", state=MalwareCheckState.enabled ) MalwareCheckFactory.create( name="AnotherCheck", state=MalwareCheckState.evaluation, version=2 ) + task = pretend.stub() + request = pretend.stub( db=db_session, log=pretend.stub( @@ -252,25 +285,30 @@ def test_too_many_db_checks(self, db_session): ) with pytest.raises(Exception): - sync_checks(task, request) + tasks.sync_checks(task, request) assert request.log.info.calls == [ - pretend.call("1 malware checks found in codebase."), + pretend.call("2 malware checks found in codebase."), ] assert request.log.error.calls == [ pretend.call( - "Found 2 active checks in the db, but only 1 checks in code. Please \ + "Found 3 active checks in the db, but only 2 checks in code. Please \ manually move superfluous checks to the wiped_out state in the check admin: \ AnotherCheck" ), ] - def test_only_wiped_out(self, db_session): - task = pretend.stub() + def test_only_wiped_out(self, db_session, monkeypatch): + monkeypatch.setattr(tasks, "checks", test_checks) + MalwareCheckFactory.create( + name="ExampleHookedCheck", state=MalwareCheckState.wiped_out + ) MalwareCheckFactory.create( - name="ExampleCheck", state=MalwareCheckState.wiped_out + name="ExampleScheduledCheck", state=MalwareCheckState.wiped_out ) + + task = pretend.stub() request = pretend.stub( db=db_session, log=pretend.stub( @@ -279,15 +317,19 @@ def test_only_wiped_out(self, db_session): ), ) - sync_checks(task, request) + tasks.sync_checks(task, request) assert request.log.info.calls == [ - pretend.call("1 malware checks found in codebase."), + pretend.call("2 malware checks found in codebase."), ] assert request.log.error.calls == [ pretend.call( - "ExampleCheck is wiped_out and cannot be synced. Please remove check \ + "ExampleHookedCheck is wiped_out and cannot be synced. Please remove check \ +from codebase." + ), + pretend.call( + "ExampleScheduledCheck is wiped_out and cannot be synced. Please remove check \ from codebase." ), ] @@ -302,7 +344,7 @@ def test_no_verdicts(self, db_session): log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),), ) task = pretend.stub() - removed = remove_verdicts(task, request, check.name) + removed = tasks.remove_verdicts(task, request, check.name) assert request.log.info.calls == [ pretend.call( @@ -338,7 +380,7 @@ def test_many_verdicts(self, db_session, check_with_verdicts): wiped_out_check = check0 num_verdicts = 0 - removed = remove_verdicts(task, request, wiped_out_check.name) + removed = tasks.remove_verdicts(task, request, wiped_out_check.name) assert request.log.info.calls == [ pretend.call( diff --git a/tests/unit/malware/test_utils.py b/tests/unit/malware/test_utils.py index 0f6ea532debf..c3cc7093ec6a 100644 --- a/tests/unit/malware/test_utils.py +++ b/tests/unit/malware/test_utils.py @@ -15,8 +15,9 @@ import pytest from warehouse.malware.models import MalwareCheckState, MalwareCheckType -from warehouse.malware.utils import get_check_fields, get_enabled_checks +from warehouse.malware.utils import get_check_fields, get_enabled_hooked_checks +from ...common.checks import ExampleHookedCheck, ExampleScheduledCheck from ...common.db.malware import MalwareCheckFactory @@ -27,7 +28,7 @@ def test_one(self, db_session): ) result = defaultdict(list) result[check.hooked_object.value].append(check.name) - checks = get_enabled_checks(db_session) + checks = get_enabled_hooked_checks(db_session) assert checks == result def test_many(self, db_session): @@ -40,36 +41,49 @@ def test_many(self, db_session): ): result[check.hooked_object.value].append(check.name) - checks = get_enabled_checks(db_session) + checks = get_enabled_hooked_checks(db_session) assert checks == result def test_none(self, db_session): - checks = get_enabled_checks(db_session) + checks = get_enabled_hooked_checks(db_session) assert checks == defaultdict(list) class TestGetCheckFields: - def test_success(self): - class MySampleCheck: - version = 6 - foo = "bar" - short_description = "This is the description" - long_description = "This is the description" - check_type = "scheduled" + @pytest.mark.parametrize( + ("check", "result"), + [ + ( + ExampleHookedCheck, + { + "name": "ExampleHookedCheck", + "version": 1, + "short_description": "An example hook-based check", + "long_description": "The purpose of this check is to test the \ +implementation of a hook-based check. This check will generate verdicts if enabled.", + "check_type": "event_hook", + "hooked_object": "File", + }, + ), + ( + ExampleScheduledCheck, + { + "name": "ExampleScheduledCheck", + "version": 1, + "short_description": "An example scheduled check", + "long_description": "The purpose of this check is to test the \ +implementation of a scheduled check. This check will generate verdicts if enabled.", + "check_type": "scheduled", + "schedule": {"minute": "0", "hour": "*/8"}, + }, + ), + ], + ) + def test_success(self, check, result): + assert get_check_fields(check) == result - result = get_check_fields(MySampleCheck) - assert result == { - "name": "MySampleCheck", - "version": 6, - "short_description": "This is the description", - "long_description": "This is the description", - "check_type": "scheduled", - } - - def test_failure(self): - class MySampleCheck: - version = 1 - status = True + def test_failure(self, monkeypatch): + monkeypatch.delattr(ExampleScheduledCheck, "schedule") with pytest.raises(AttributeError): - get_check_fields(MySampleCheck) + get_check_fields(ExampleScheduledCheck) diff --git a/warehouse/malware/__init__.py b/warehouse/malware/__init__.py index ee0e36b808aa..f54a9e89b4f5 100644 --- a/warehouse/malware/__init__.py +++ b/warehouse/malware/__init__.py @@ -30,7 +30,7 @@ def determine_malware_checks(config, session, flush_context): return malware_checks = session.info.setdefault("warehouse.malware.checks", set()) - enabled_checks = utils.get_enabled_checks(session) + enabled_checks = utils.get_enabled_hooked_checks(session) for obj in session.new: for check_name in enabled_checks.get(obj.__class__.__name__, []): malware_checks.update([f"{check_name}:{obj.id}"]) diff --git a/warehouse/malware/checks/__init__.py b/warehouse/malware/checks/__init__.py index a627b7d18159..164f68b09175 100644 --- a/warehouse/malware/checks/__init__.py +++ b/warehouse/malware/checks/__init__.py @@ -9,5 +9,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from .example import ExampleCheck # noqa diff --git a/warehouse/malware/models.py b/warehouse/malware/models.py index 3e9aa388a701..0464ba0d47ce 100644 --- a/warehouse/malware/models.py +++ b/warehouse/malware/models.py @@ -81,12 +81,14 @@ class MalwareCheck(db.Model): Enum(MalwareCheckType, values_callable=lambda x: [e.value for e in x]), nullable=False, ) - # This field contains the object name that check operates on, e.g. + # The object name that hooked-based checks operate on, e.g. # Project, File, Release hooked_object = Column( Enum(MalwareCheckObjectType, values_callable=lambda x: [e.value for e in x]), nullable=True, ) + # The run schedule for schedule-based checks. + schedule = Column(JSONB, nullable=True) state = Column( Enum(MalwareCheckState, values_callable=lambda x: [e.value for e in x]), nullable=False, diff --git a/warehouse/malware/utils.py b/warehouse/malware/utils.py index 23af5b0a578e..6139c3e248a7 100644 --- a/warehouse/malware/utils.py +++ b/warehouse/malware/utils.py @@ -29,6 +29,7 @@ def valid_check_types(): def get_check_fields(check): result = {"name": check.__name__} + required_fields = ["short_description", "long_description", "version", "check_type"] for field in required_fields: result[field] = getattr(check, field) @@ -36,10 +37,13 @@ def get_check_fields(check): if result["check_type"] == "event_hook": result["hooked_object"] = check.hooked_object + if result["check_type"] == "scheduled": + result["schedule"] = check.schedule + return result -def get_enabled_checks(session): +def get_enabled_hooked_checks(session): checks = ( session.query(MalwareCheck.name, MalwareCheck.hooked_object) .filter(MalwareCheck.check_type == MalwareCheckType.event_hook) diff --git a/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py b/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py index 6e23aeb243f8..622660fd042f 100644 --- a/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py +++ b/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py @@ -56,6 +56,7 @@ def upgrade(): sa.Column("long_description", sa.Text(), nullable=False), sa.Column("check_type", MalwareCheckTypes, nullable=False), sa.Column("hooked_object", MalwareCheckObjectTypes, nullable=True), + sa.Column("schedule", postgresql.JSONB(astext_type=sa.Text()), nullable=True), sa.Column( "state", MalwareCheckStates, server_default="disabled", nullable=False, ), From 13f20878eb6e9b3c24e6a1d0df23e9195607bdf2 Mon Sep 17 00:00:00 2001 From: William Woodruff Date: Mon, 27 Jan 2020 10:16:56 -0500 Subject: [PATCH 11/16] Event-based Malware check (#7249) * requirements: Introduce yara * [WIP] malware/check: SetupPatternCheck In progress. Introduces SetupPatternCheck, an implementation of an event-based check that scans the `setup.py`s of release files for suspicious patterns. * malware/checks: Give MalwareCheckBase.run/scan args, kwargs * malware: Add check preparation Fiddle with the check/run signature a bit more. * malware/checks: Unpack file path correctly * docker-compose: Override FILES_BACKEND for worker The worker needs to be able to see the "files" virtual host during development so that malware checks can fetch their underlying release files. * [WIP] malware/checks: setup.py extraction * malware/checks: setup_patterns: Fix enum, seek * malware/checks: setup_patterns: Apply YARA rules Each rule match becomes a verdict. * malware/checks: setup_patterns: Prefer get over filter * warehouse/{admin,malware}: Consistent enum names Also enforce uniqueness for enum values. * warehouse/{admin,malware}: More enum changes * tests: Update admin, malware tests * tests: Fix enum, more test fixes * tests: Add prepare tests * malware/changes: base: Unpack id correctly * tests: Begin adding SetupPatternCheck tests * malware/checks: setup_patterns: Fix enum * tests: More SetupPatternCheck tests * warehouse/malware: setup_patterns: Fix enums * tests: More SetupPatternCheck tests * tests: Add license header * malware/checks: setup_patterns: Add TODO * tests: More SetupPatternCheck tests * tests: More SetupPatternCheck tests * tests: Complete extraction tests for SetupPatternCheck * tests: Fix test * malware/checks: Add docstring for prepare * malware/checks: blacken * malware/checks: Document, expand YARA rules * tests, warehouse: Restructure utilities * malware: Order some enums, reduce SetupPatternCheck verdicts * malware/models: Add missing __lt__ * malware/checks: Always embed the model object in the prepared arguments Use it instead of performing a DB request in the check itself. * malware/checks: Avoid raw bytes * malware/changes: Remove unused import * tests: Fixup malware tests * warehouse/malware: blacken * tests: Fill in malware coverage * tests, warehouse: Add a benign verdict for SetupPatternCheck * tests: blacken --- docker-compose.yml | 1 + requirements/main.in | 1 + requirements/main.txt | 14 ++ tests/common/checks/hooked.py | 8 +- tests/unit/admin/views/test_checks.py | 12 +- tests/unit/malware/checks/__init__.py | 11 ++ .../malware/checks/setup_patterns/__init__.py | 11 ++ .../checks/setup_patterns/test_check.py | 145 +++++++++++++++ tests/unit/malware/checks/test_utils.py | 93 ++++++++++ tests/unit/malware/test_checks.py | 36 +++- tests/unit/malware/test_init.py | 2 +- tests/unit/malware/test_models.py | 40 ++++ tests/unit/malware/test_tasks.py | 101 +++++----- tests/unit/malware/test_utils.py | 6 +- warehouse/admin/views/checks.py | 8 +- warehouse/malware/checks/__init__.py | 2 + warehouse/malware/checks/base.py | 28 ++- .../malware/checks/setup_patterns/__init__.py | 13 ++ .../malware/checks/setup_patterns/check.py | 108 +++++++++++ .../checks/setup_patterns/setup_py_rules.yara | 174 ++++++++++++++++++ warehouse/malware/checks/utils.py | 80 ++++++++ warehouse/malware/models.py | 48 ++++- warehouse/malware/tasks.py | 9 +- warehouse/malware/utils.py | 4 +- 24 files changed, 863 insertions(+), 92 deletions(-) create mode 100644 tests/unit/malware/checks/__init__.py create mode 100644 tests/unit/malware/checks/setup_patterns/__init__.py create mode 100644 tests/unit/malware/checks/setup_patterns/test_check.py create mode 100644 tests/unit/malware/checks/test_utils.py create mode 100644 tests/unit/malware/test_models.py create mode 100644 warehouse/malware/checks/setup_patterns/__init__.py create mode 100644 warehouse/malware/checks/setup_patterns/check.py create mode 100644 warehouse/malware/checks/setup_patterns/setup_py_rules.yara create mode 100644 warehouse/malware/checks/utils.py diff --git a/docker-compose.yml b/docker-compose.yml index aa8bdc135974..864278cc7822 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -93,6 +93,7 @@ services: env_file: dev/environment environment: C_FORCE_ROOT: "1" + FILES_BACKEND: "warehouse.packaging.services.LocalFileStorage path=/var/opt/warehouse/packages/ url=http://files:9001/packages/{path}" links: - db - redis diff --git a/requirements/main.in b/requirements/main.in index 622573ca5cbc..cb4d15376eff 100644 --- a/requirements/main.in +++ b/requirements/main.in @@ -55,5 +55,6 @@ typeguard webauthn whitenoise WTForms>=2.0.0 +yara-python zope.sqlalchemy zxcvbn diff --git a/requirements/main.txt b/requirements/main.txt index 73e640e8061b..dbae4743ecde 100644 --- a/requirements/main.txt +++ b/requirements/main.txt @@ -594,6 +594,20 @@ wired==0.2.1 \ wtforms==2.2.1 \ --hash=sha256:0cdbac3e7f6878086c334aa25dc5a33869a3954e9d1e015130d65a69309b3b61 \ --hash=sha256:e3ee092c827582c50877cdbd49e9ce6d2c5c1f6561f849b3b068c1b8029626f1 +yara-python==3.11.0 \ + --hash=sha256:105d851e050b32951ee577148c7f1b18c0a7c64432fef8159069191d522fba86 \ + --hash=sha256:1d35c7f606465015de02143dfa4e1ad2f4ee85fdb5d5af756b51b2bac62ac7bc \ + --hash=sha256:24cd492d6bf8ecedb128f5b02886770be9df03bd1b84ab06a978d45bb1a8ff92 \ + --hash=sha256:58cfc837e7769811afbfb19b1db952ec01e50cdbf9df576fb587e1e343694526 \ + --hash=sha256:5b8d708751a66d1507d819218d06baccdf5527c147c2bd3062f087e2f367a17d \ + --hash=sha256:6f90bb264470235549e1bb4e355fa82895409cd46f27aceecaddfbf55e66ed71 \ + --hash=sha256:70d39c2238c5854e7cd8f11595317dc4d89417e88035d8acca24bcc58a93150f \ + --hash=sha256:8d255349d69d833bca604b4215bdf499c87357172512273feb934f6442b8e6b2 \ + --hash=sha256:8e44f9600607cb1d74a0f26df5d0a1c06ea54f4601206124f47f1bbb58e6a374 \ + --hash=sha256:9e4fafc327e3a343c545dcf5f173fa8bc712aebffe5f034d205c0bac1f1c5df6 \ + --hash=sha256:c919ee656139ed46a0056e8a3de179bbc98d42a2be6fb85c95b1e2ec65396b34 \ + --hash=sha256:e4124414d3cff9a10669569a89f585f81c8114b283ab48b2e756e0347a89de0a \ + --hash=sha256:f104f0bb21a0867f22e750bb4e05de629ec9f37facc84daf963385a86371b0d9 zipp==2.1.0 \ --hash=sha256:ccc94ed0909b58ffe34430ea5451f07bc0c76467d7081619a454bf5c98b89e28 \ --hash=sha256:feae2f18633c32fc71f2de629bfb3bd3c9325cd4419642b1f1da42ee488d9b98 diff --git a/tests/common/checks/hooked.py b/tests/common/checks/hooked.py index 6bd6f2e512e4..8a3e16a3cbf9 100644 --- a/tests/common/checks/hooked.py +++ b/tests/common/checks/hooked.py @@ -26,10 +26,14 @@ class ExampleHookedCheck(MalwareCheckBase): def __init__(self, db): super().__init__(db) - def scan(self, file_id=None): + def scan(self, **kwargs): + file_id = kwargs.get("obj_id") + if file_id is None: + return + self.add_verdict( file_id=file_id, - classification=VerdictClassification.benign, + classification=VerdictClassification.Benign, confidence=VerdictConfidence.High, message="Nothing to see here!", ) diff --git a/tests/unit/admin/views/test_checks.py b/tests/unit/admin/views/test_checks.py index 273bb8cd3fe3..c8fa6512aeaa 100644 --- a/tests/unit/admin/views/test_checks.py +++ b/tests/unit/admin/views/test_checks.py @@ -72,11 +72,11 @@ def test_no_check_state(self, db_request): views.change_check_state(db_request) @pytest.mark.parametrize( - ("final_state"), [MalwareCheckState.disabled, MalwareCheckState.wiped_out] + ("final_state"), [MalwareCheckState.Disabled, MalwareCheckState.WipedOut] ) def test_change_to_valid_state(self, db_request, final_state): check = MalwareCheckFactory.create( - name="MyCheck", state=MalwareCheckState.disabled + name="MyCheck", state=MalwareCheckState.Disabled ) db_request.POST = {"check_state": final_state.value} @@ -104,7 +104,7 @@ def test_change_to_valid_state(self, db_request, final_state): assert check.state == final_state - if final_state == MalwareCheckState.wiped_out: + if final_state == MalwareCheckState.WipedOut: assert wipe_out_recorder.delay.calls == [pretend.call("MyCheck")] def test_change_to_invalid_state(self, db_request): @@ -134,11 +134,11 @@ class TestRunBackfill: ("check_state", "message"), [ ( - MalwareCheckState.disabled, + MalwareCheckState.Disabled, "Check must be in 'enabled' or 'evaluation' state to run a backfill.", ), ( - MalwareCheckState.wiped_out, + MalwareCheckState.WipedOut, "Check must be in 'enabled' or 'evaluation' state to run a backfill.", ), ], @@ -160,7 +160,7 @@ def test_invalid_backfill_parameters(self, db_request, check_state, message): assert db_request.session.flash.calls == [pretend.call(message, queue="error")] def test_sucess(self, db_request): - check = MalwareCheckFactory.create(state=MalwareCheckState.enabled) + check = MalwareCheckFactory.create(state=MalwareCheckState.Enabled) db_request.matchdict["check_name"] = check.name db_request.session = pretend.stub( diff --git a/tests/unit/malware/checks/__init__.py b/tests/unit/malware/checks/__init__.py new file mode 100644 index 000000000000..164f68b09175 --- /dev/null +++ b/tests/unit/malware/checks/__init__.py @@ -0,0 +1,11 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/malware/checks/setup_patterns/__init__.py b/tests/unit/malware/checks/setup_patterns/__init__.py new file mode 100644 index 000000000000..164f68b09175 --- /dev/null +++ b/tests/unit/malware/checks/setup_patterns/__init__.py @@ -0,0 +1,11 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/malware/checks/setup_patterns/test_check.py b/tests/unit/malware/checks/setup_patterns/test_check.py new file mode 100644 index 000000000000..0dbd5c19d06f --- /dev/null +++ b/tests/unit/malware/checks/setup_patterns/test_check.py @@ -0,0 +1,145 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pretend +import pytest +import yara + +from warehouse.malware.checks.setup_patterns import check as c +from warehouse.malware.models import ( + MalwareCheckState, + VerdictClassification, + VerdictConfidence, +) + +from .....common.db.malware import MalwareCheckFactory +from .....common.db.packaging import FileFactory + + +def test_initializes(db_session): + check_model = MalwareCheckFactory.create( + name="SetupPatternCheck", state=MalwareCheckState.Enabled + ) + check = c.SetupPatternCheck(db_session) + + assert check.id == check_model.id + assert isinstance(check._yara_rules, yara.Rules) + + +@pytest.mark.parametrize( + ("obj", "file_url"), [(None, pretend.stub()), (pretend.stub(), None)] +) +def test_scan_missing_kwargs(db_session, obj, file_url): + MalwareCheckFactory.create( + name="SetupPatternCheck", state=MalwareCheckState.Enabled + ) + check = c.SetupPatternCheck(db_session) + check.scan(obj=obj, file_url=file_url) + + assert check._verdicts == [] + + +def test_scan_non_sdist(db_session): + MalwareCheckFactory.create( + name="SetupPatternCheck", state=MalwareCheckState.Enabled + ) + check = c.SetupPatternCheck(db_session) + + file = FileFactory.create(packagetype="bdist_wheel") + + check.scan(obj=file, file_url=pretend.stub()) + + assert check._verdicts == [] + + +def test_scan_no_setup_contents(db_session, monkeypatch): + monkeypatch.setattr( + c, "fetch_url_content", pretend.call_recorder(lambda *a: pretend.stub()) + ) + monkeypatch.setattr( + c, "extract_file_content", pretend.call_recorder(lambda *a: None) + ) + + MalwareCheckFactory.create( + name="SetupPatternCheck", state=MalwareCheckState.Enabled + ) + check = c.SetupPatternCheck(db_session) + + file = FileFactory.create(packagetype="sdist") + + check.scan(obj=file, file_url=pretend.stub()) + + assert len(check._verdicts) == 1 + assert check._verdicts[0].check_id == check.id + assert check._verdicts[0].file_id == file.id + assert check._verdicts[0].classification == VerdictClassification.Indeterminate + assert check._verdicts[0].confidence == VerdictConfidence.High + assert ( + check._verdicts[0].message + == "sdist does not contain a suitable setup.py for analysis" + ) + + +def test_scan_benign_contents(db_session, monkeypatch): + monkeypatch.setattr( + c, "fetch_url_content", pretend.call_recorder(lambda *a: pretend.stub()) + ) + monkeypatch.setattr( + c, + "extract_file_content", + pretend.call_recorder(lambda *a: b"this is a benign string"), + ) + + MalwareCheckFactory.create( + name="SetupPatternCheck", state=MalwareCheckState.Enabled + ) + check = c.SetupPatternCheck(db_session) + + file = FileFactory.create(packagetype="sdist") + + check.scan(obj=file, file_url=pretend.stub()) + + assert len(check._verdicts) == 1 + assert check._verdicts[0].check_id == check.id + assert check._verdicts[0].file_id == file.id + assert check._verdicts[0].classification == VerdictClassification.Benign + assert check._verdicts[0].confidence == VerdictConfidence.Low + assert check._verdicts[0].message == "No malicious patterns found in setup.py" + + +def test_scan_matched_content(db_session, monkeypatch): + monkeypatch.setattr( + c, "fetch_url_content", pretend.call_recorder(lambda *a: pretend.stub()) + ) + monkeypatch.setattr( + c, + "extract_file_content", + pretend.call_recorder( + lambda *a: b"this looks suspicious: os.system('cat /etc/passwd')" + ), + ) + + MalwareCheckFactory.create( + name="SetupPatternCheck", state=MalwareCheckState.Enabled + ) + check = c.SetupPatternCheck(db_session) + + file = FileFactory.create(packagetype="sdist") + + check.scan(obj=file, file_url=pretend.stub()) + + assert len(check._verdicts) == 1 + assert check._verdicts[0].check_id == check.id + assert check._verdicts[0].file_id == file.id + assert check._verdicts[0].classification == VerdictClassification.Threat + assert check._verdicts[0].confidence == VerdictConfidence.High + assert check._verdicts[0].message == "process_spawn_in_setup" diff --git a/tests/unit/malware/checks/test_utils.py b/tests/unit/malware/checks/test_utils.py new file mode 100644 index 000000000000..bdf91c70a697 --- /dev/null +++ b/tests/unit/malware/checks/test_utils.py @@ -0,0 +1,93 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import tarfile +import zipfile + +import pretend + +from warehouse.malware.checks import utils + + +def test_fetch_url_content(monkeypatch): + response = pretend.stub( + raise_for_status=pretend.call_recorder(lambda: None), content=b"fake content" + ) + requests = pretend.stub(get=pretend.call_recorder(lambda url: response)) + + monkeypatch.setattr(utils, "requests", requests) + + io = utils.fetch_url_content("hxxp://fake_url.com") + + assert requests.get.calls == [pretend.call("hxxp://fake_url.com")] + assert response.raise_for_status.calls == [pretend.call()] + assert io.getvalue() == b"fake content" + + +def test_extract_file_contents_zip(): + zipbuf = io.BytesIO() + with zipfile.ZipFile(zipbuf, mode="w") as zipobj: + zipobj.writestr("toplevelgetsskipped", b"nothing to see here") + zipobj.writestr("foo/setup.py", b"these are some contents") + zipbuf.seek(0) + + assert utils.extract_file_content(zipbuf, "setup.py") == b"these are some contents" + + +def test_extract_file_contents_zip_no_file(): + zipbuf = io.BytesIO() + with zipfile.ZipFile(zipbuf, mode="w") as zipobj: + zipobj.writestr("foo/notsetup.py", b"these are some contents") + zipbuf.seek(0) + + assert utils.extract_file_content(zipbuf, "setup.py") is None + + +def test_extract_file_contents_tar(): + tarbuf = io.BytesIO() + with tarfile.open(fileobj=tarbuf, mode="w:gz") as tarobj: + contents = io.BytesIO(b"these are some contents") + member = tarfile.TarInfo(name="foo/setup.py") + member.size = len(contents.getbuffer()) + tarobj.addfile(member, fileobj=contents) + + contents = io.BytesIO(b"nothing to see here") + member = tarfile.TarInfo(name="toplevelgetsskipped") + member.size = len(contents.getbuffer()) + tarobj.addfile(member, fileobj=contents) + tarbuf.seek(0) + + assert utils.extract_file_content(tarbuf, "setup.py") == b"these are some contents" + + +def test_extract_file_contents_tar_empty(): + tarbuf = io.BytesIO(b"invalid tar contents") + + assert utils.extract_file_content(tarbuf, "setup.py") is None + + +def test_extract_file_contents_tar_no_file(): + tarbuf = io.BytesIO() + with tarfile.open(fileobj=tarbuf, mode="w:gz") as tarobj: + contents = io.BytesIO(b"these are some contents") + member = tarfile.TarInfo(name="foo/notsetup.py") + member.size = len(contents.getbuffer()) + tarobj.addfile(member, fileobj=contents) + + contents = io.BytesIO(b"nothing to see here") + member = tarfile.TarInfo(name="toplevelgetsskipped") + member.size = len(contents.getbuffer()) + tarobj.addfile(member, fileobj=contents) + tarbuf.seek(0) + + assert utils.extract_file_content(tarbuf, "setup.py") is None diff --git a/tests/unit/malware/test_checks.py b/tests/unit/malware/test_checks.py index 9427972ef5e8..8398604a7869 100644 --- a/tests/unit/malware/test_checks.py +++ b/tests/unit/malware/test_checks.py @@ -12,6 +12,7 @@ import inspect +import pretend import pytest import warehouse.malware.checks as prod_checks @@ -20,6 +21,7 @@ from warehouse.malware.utils import get_check_fields from ...common import checks as test_checks +from ...common.db.packaging import FileFactory def test_checks_subclass_base(): @@ -37,9 +39,7 @@ def test_checks_subclass_base(): assert subclasses_of_malware_base[check_name] == check -@pytest.mark.parametrize( - ("checks"), [prod_checks, test_checks], -) +@pytest.mark.parametrize(("checks"), [prod_checks, test_checks]) def test_checks_fields(checks): checks_from_module = inspect.getmembers(checks, inspect.isclass) @@ -47,8 +47,36 @@ def test_checks_fields(checks): elems = inspect.getmembers(check, lambda a: not (inspect.isroutine(a))) inspection_fields = {"name": check_name} for elem_name, value in elems: - if not elem_name.startswith("__"): + # Skip both dunder and "private" (_-prefixed) attributes + if not elem_name.startswith("_"): inspection_fields[elem_name] = value fields = get_check_fields(check) assert inspection_fields == fields + + +def test_base_prepare_file_hooked(db_session): + file = FileFactory.create() + request = pretend.stub( + db=db_session, route_url=pretend.call_recorder(lambda *a, **kw: "fake_url") + ) + + kwargs = test_checks.ExampleHookedCheck.prepare(request, file.id) + + assert request.route_url.calls == [pretend.call("packaging.file", path=file.path)] + assert "file_url" in kwargs + assert kwargs["file_url"] == "fake_url" + + +def test_base_prepare_nonfile_hooked(db_session): + file = FileFactory.create() + request = pretend.stub( + db=db_session, route_url=pretend.call_recorder(lambda *a, **kw: "fake_url") + ) + + class FakeProjectCheck(MalwareCheckBase): + hooked_object = "Project" + + kwargs = FakeProjectCheck.prepare(request, file.id) + assert request.route_url.calls == [] + assert "file_url" not in kwargs diff --git a/tests/unit/malware/test_init.py b/tests/unit/malware/test_init.py index 4d2888aef6a1..c8e1f3d7acf4 100644 --- a/tests/unit/malware/test_init.py +++ b/tests/unit/malware/test_init.py @@ -170,5 +170,5 @@ def test_includeme(monkeypatch): malware.includeme(config) assert config.register_service_factory.calls == [ - pretend.call(malware_check_class.create_service, IMalwareCheckService), + pretend.call(malware_check_class.create_service, IMalwareCheckService) ] diff --git a/tests/unit/malware/test_models.py b/tests/unit/malware/test_models.py new file mode 100644 index 000000000000..43f303357a24 --- /dev/null +++ b/tests/unit/malware/test_models.py @@ -0,0 +1,40 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from warehouse.malware.models import VerdictClassification, VerdictConfidence + + +def test_classification_orderable(): + assert ( + VerdictClassification.Benign + < VerdictClassification.Indeterminate + < VerdictClassification.Threat + ) + assert ( + max( + [ + VerdictClassification.Benign, + VerdictClassification.Indeterminate, + VerdictClassification.Threat, + ] + ) + == VerdictClassification.Threat + ) + + +def test_confidence_orderable(): + assert VerdictConfidence.Low < VerdictConfidence.Medium < VerdictConfidence.High + assert ( + max([VerdictConfidence.Low, VerdictConfidence.Medium, VerdictConfidence.High]) + == VerdictConfidence.High + ) diff --git a/tests/unit/malware/test_tasks.py b/tests/unit/malware/test_tasks.py index 8fc427e35010..de278f4ad825 100644 --- a/tests/unit/malware/test_tasks.py +++ b/tests/unit/malware/test_tasks.py @@ -26,71 +26,68 @@ class TestRunCheck: def test_success(self, db_request, monkeypatch): + db_request.route_url = pretend.call_recorder(lambda *a, **kw: "fake_route") + monkeypatch.setattr(tasks, "checks", test_checks) file0 = FileFactory.create() MalwareCheckFactory.create( - name="ExampleHookedCheck", state=MalwareCheckState.enabled + name="ExampleHookedCheck", state=MalwareCheckState.Enabled ) task = pretend.stub() tasks.run_check(task, db_request, "ExampleHookedCheck", file0.id) + assert db_request.route_url.calls == [ + pretend.call("packaging.file", path=file0.path) + ] assert db_request.db.query(MalwareVerdict).one() def test_disabled_check(self, db_request, monkeypatch): monkeypatch.setattr(tasks, "checks", test_checks) MalwareCheckFactory.create( - name="ExampleHookedCheck", state=MalwareCheckState.disabled + name="ExampleHookedCheck", state=MalwareCheckState.Disabled ) - task = pretend.stub() + file = FileFactory.create() + with pytest.raises(NoResultFound): - tasks.run_check( - task, - db_request, - "ExampleHookedCheck", - "d03d75d1-2511-4a8b-9759-62294a6fe3a7", - ) + tasks.run_check(task, db_request, "ExampleHookedCheck", file.id) def test_missing_check(self, db_request, monkeypatch): monkeypatch.setattr(tasks, "checks", test_checks) task = pretend.stub() + + file = FileFactory.create() + with pytest.raises(AttributeError): - tasks.run_check( - task, - db_request, - "DoesNotExistCheck", - "d03d75d1-2511-4a8b-9759-62294a6fe3a7", - ) + tasks.run_check(task, db_request, "DoesNotExistCheck", file.id) def test_retry(self, db_session, monkeypatch): exc = Exception("Scan failed") - def scan(self, file_id): + def scan(self, **kwargs): raise exc monkeypatch.setattr(tasks, "checks", test_checks) monkeypatch.setattr(tasks.checks.ExampleHookedCheck, "scan", scan) MalwareCheckFactory.create( - name="ExampleHookedCheck", state=MalwareCheckState.evaluation + name="ExampleHookedCheck", state=MalwareCheckState.Evaluation ) task = pretend.stub( - retry=pretend.call_recorder(pretend.raiser(celery.exceptions.Retry)), + retry=pretend.call_recorder(pretend.raiser(celery.exceptions.Retry)) ) request = pretend.stub( db=db_session, log=pretend.stub(error=pretend.call_recorder(lambda *args, **kwargs: None)), + route_url=pretend.call_recorder(lambda *a, **kw: pretend.stub()), ) + file = FileFactory.create() + with pytest.raises(celery.exceptions.Retry): - tasks.run_check( - task, - request, - "ExampleHookedCheck", - "d03d75d1-2511-4a8b-9759-62294a6fe3a7", - ) + tasks.run_check(task, request, "ExampleHookedCheck", file.id) assert request.log.error.calls == [ pretend.call("Error executing check ExampleHookedCheck: Scan failed") @@ -107,7 +104,7 @@ def test_invalid_check_name(self, db_request, monkeypatch): tasks.backfill(task, db_request, "DoesNotExist", 1) @pytest.mark.parametrize( - ("num_objects", "num_runs"), [(11, 1), (11, 11), (101, 90)], + ("num_objects", "num_runs"), [(11, 1), (11, 11), (101, 90)] ) def test_run(self, db_session, num_objects, num_runs, monkeypatch): monkeypatch.setattr(tasks, "checks", test_checks) @@ -116,7 +113,7 @@ def test_run(self, db_session, num_objects, num_runs, monkeypatch): files.append(FileFactory.create()) MalwareCheckFactory.create( - name="ExampleHookedCheck", state=MalwareCheckState.enabled + name="ExampleHookedCheck", state=MalwareCheckState.Enabled ) enqueue_recorder = pretend.stub( @@ -133,7 +130,7 @@ def test_run(self, db_session, num_objects, num_runs, monkeypatch): tasks.backfill(task, request, "ExampleHookedCheck", num_runs) assert request.log.info.calls == [ - pretend.call("Running backfill on %d Files." % num_runs), + pretend.call("Running backfill on %d Files." % num_runs) ] assert enqueue_recorder.delay.calls == [ @@ -146,20 +143,20 @@ def test_no_updates(self, db_session, monkeypatch): monkeypatch.setattr(tasks, "checks", test_checks) monkeypatch.setattr(tasks.checks.ExampleScheduledCheck, "version", 2) MalwareCheckFactory.create( - name="ExampleHookedCheck", state=MalwareCheckState.disabled + name="ExampleHookedCheck", state=MalwareCheckState.Disabled ) MalwareCheckFactory.create( - name="ExampleScheduledCheck", state=MalwareCheckState.disabled + name="ExampleScheduledCheck", state=MalwareCheckState.Disabled ) MalwareCheckFactory.create( - name="ExampleScheduledCheck", state=MalwareCheckState.enabled, version=2 + name="ExampleScheduledCheck", state=MalwareCheckState.Enabled, version=2 ) task = pretend.stub() request = pretend.stub( db=db_session, - log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),), + log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None)), ) tasks.sync_checks(task, request) @@ -171,7 +168,7 @@ def test_no_updates(self, db_session, monkeypatch): ] @pytest.mark.parametrize( - ("final_state"), [MalwareCheckState.enabled, MalwareCheckState.disabled] + ("final_state"), [MalwareCheckState.Enabled, MalwareCheckState.Disabled] ) def test_upgrade_check(self, monkeypatch, db_session, final_state): monkeypatch.setattr(tasks, "checks", test_checks) @@ -179,13 +176,13 @@ def test_upgrade_check(self, monkeypatch, db_session, final_state): MalwareCheckFactory.create(name="ExampleHookedCheck", state=final_state) MalwareCheckFactory.create( - name="ExampleScheduledCheck", state=MalwareCheckState.disabled + name="ExampleScheduledCheck", state=MalwareCheckState.Disabled ) task = pretend.stub() request = pretend.stub( db=db_session, - log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),), + log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None)), ) tasks.sync_checks(task, request) @@ -203,9 +200,9 @@ def test_upgrade_check(self, monkeypatch, db_session, final_state): assert len(db_checks) == 2 - if final_state == MalwareCheckState.disabled: + if final_state == MalwareCheckState.Disabled: assert ( - db_checks[0].state == db_checks[1].state == MalwareCheckState.disabled + db_checks[0].state == db_checks[1].state == MalwareCheckState.Disabled ) else: @@ -219,10 +216,10 @@ def test_one_new_check(self, db_session, monkeypatch): monkeypatch.setattr(tasks, "checks", test_checks) MalwareCheckFactory.create( - name="ExampleHookedCheck", state=MalwareCheckState.disabled + name="ExampleHookedCheck", state=MalwareCheckState.Disabled ) MalwareCheckFactory.create( - name="ExampleScheduledCheck", state=MalwareCheckState.disabled + name="ExampleScheduledCheck", state=MalwareCheckState.Disabled ) task = pretend.stub() @@ -238,7 +235,7 @@ class FakeMalwareCheck: request = pretend.stub( db=db_session, - log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),), + log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None)), ) tasks.sync_checks(task, request) @@ -257,7 +254,7 @@ class FakeMalwareCheck: .one() ) - assert new_check.state == MalwareCheckState.disabled + assert new_check.state == MalwareCheckState.Disabled del tasks.checks.FakeMalwareCheck @@ -265,13 +262,13 @@ def test_too_many_db_checks(self, db_session, monkeypatch): monkeypatch.setattr(tasks, "checks", test_checks) MalwareCheckFactory.create( - name="ExampleHookedCheck", state=MalwareCheckState.enabled + name="ExampleHookedCheck", state=MalwareCheckState.Enabled ) MalwareCheckFactory.create( - name="ExampleScheduledCheck", state=MalwareCheckState.enabled + name="ExampleScheduledCheck", state=MalwareCheckState.Enabled ) MalwareCheckFactory.create( - name="AnotherCheck", state=MalwareCheckState.evaluation, version=2 + name="AnotherCheck", state=MalwareCheckState.Evaluation, version=2 ) task = pretend.stub() @@ -288,7 +285,7 @@ def test_too_many_db_checks(self, db_session, monkeypatch): tasks.sync_checks(task, request) assert request.log.info.calls == [ - pretend.call("2 malware checks found in codebase."), + pretend.call("2 malware checks found in codebase.") ] assert request.log.error.calls == [ @@ -296,16 +293,16 @@ def test_too_many_db_checks(self, db_session, monkeypatch): "Found 3 active checks in the db, but only 2 checks in code. Please \ manually move superfluous checks to the wiped_out state in the check admin: \ AnotherCheck" - ), + ) ] def test_only_wiped_out(self, db_session, monkeypatch): monkeypatch.setattr(tasks, "checks", test_checks) MalwareCheckFactory.create( - name="ExampleHookedCheck", state=MalwareCheckState.wiped_out + name="ExampleHookedCheck", state=MalwareCheckState.WipedOut ) MalwareCheckFactory.create( - name="ExampleScheduledCheck", state=MalwareCheckState.wiped_out + name="ExampleScheduledCheck", state=MalwareCheckState.WipedOut ) task = pretend.stub() @@ -320,7 +317,7 @@ def test_only_wiped_out(self, db_session, monkeypatch): tasks.sync_checks(task, request) assert request.log.info.calls == [ - pretend.call("2 malware checks found in codebase."), + pretend.call("2 malware checks found in codebase.") ] assert request.log.error.calls == [ @@ -341,7 +338,7 @@ def test_no_verdicts(self, db_session): request = pretend.stub( db=db_session, - log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),), + log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None)), ) task = pretend.stub() removed = tasks.remove_verdicts(task, request, check.name) @@ -349,7 +346,7 @@ def test_no_verdicts(self, db_session): assert request.log.info.calls == [ pretend.call( "Removing 0 malware verdicts associated with %s version 1." % check.name - ), + ) ] assert removed == 0 @@ -369,7 +366,7 @@ def test_many_verdicts(self, db_session, check_with_verdicts): request = pretend.stub( db=db_session, - log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),), + log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None)), ) task = pretend.stub() @@ -386,7 +383,7 @@ def test_many_verdicts(self, db_session, check_with_verdicts): pretend.call( "Removing %d malware verdicts associated with %s version 1." % (num_verdicts, wiped_out_check.name) - ), + ) ] assert removed == num_verdicts diff --git a/tests/unit/malware/test_utils.py b/tests/unit/malware/test_utils.py index c3cc7093ec6a..a7e0fae8b66d 100644 --- a/tests/unit/malware/test_utils.py +++ b/tests/unit/malware/test_utils.py @@ -24,7 +24,7 @@ class TestGetEnabledChecks: def test_one(self, db_session): check = MalwareCheckFactory.create( - state=MalwareCheckState.enabled, check_type=MalwareCheckType.event_hook + state=MalwareCheckState.Enabled, check_type=MalwareCheckType.EventHook ) result = defaultdict(list) result[check.hooked_object.value].append(check.name) @@ -36,8 +36,8 @@ def test_many(self, db_session): for i in range(10): check = MalwareCheckFactory.create() if ( - check.state == MalwareCheckState.enabled - and check.check_type == MalwareCheckType.event_hook + check.state == MalwareCheckState.Enabled + and check.check_type == MalwareCheckType.EventHook ): result[check.hooked_object.value].append(check.name) diff --git a/warehouse/admin/views/checks.py b/warehouse/admin/views/checks.py index 7fd3fa9aacf0..b23465fd4a49 100644 --- a/warehouse/admin/views/checks.py +++ b/warehouse/admin/views/checks.py @@ -67,7 +67,7 @@ def run_backfill(request): check = get_check_by_name(request.db, request.matchdict["check_name"]) num_objects = 10000 - if check.state not in (MalwareCheckState.enabled, MalwareCheckState.evaluation): + if check.state not in (MalwareCheckState.Enabled, MalwareCheckState.Evaluation): request.session.flash( f"Check must be in 'enabled' or 'evaluation' state to run a backfill.", queue="error", @@ -105,11 +105,11 @@ def change_check_state(request): raise HTTPNotFound try: - check.state = getattr(MalwareCheckState, check_state) - except AttributeError: + check.state = MalwareCheckState(check_state) + except ValueError: request.session.flash("Invalid check state provided.", queue="error") else: - if check.state == MalwareCheckState.wiped_out: + if check.state == MalwareCheckState.WipedOut: request.task(remove_verdicts).delay(check.name) request.session.flash( f"Changed {check.name!r} check to {check.state.value!r}!", queue="success" diff --git a/warehouse/malware/checks/__init__.py b/warehouse/malware/checks/__init__.py index 164f68b09175..fa0607f15913 100644 --- a/warehouse/malware/checks/__init__.py +++ b/warehouse/malware/checks/__init__.py @@ -9,3 +9,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .setup_patterns import SetupPatternCheck # noqa diff --git a/warehouse/malware/checks/base.py b/warehouse/malware/checks/base.py index 72406954c982..67810a367203 100644 --- a/warehouse/malware/checks/base.py +++ b/warehouse/malware/checks/base.py @@ -11,6 +11,7 @@ # limitations under the License. from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict +from warehouse.packaging import models class MalwareCheckBase: @@ -20,17 +21,34 @@ def __init__(self, db): self._load_check_id() self._verdicts = [] + @classmethod + def prepare(cls, request, obj_id): + """ + Prepares some context for scanning the given object. + """ + kwargs = {} + + model = getattr(models, cls.hooked_object) + kwargs["obj"] = request.db.query(model).get(obj_id) + + if cls.hooked_object == "File": + kwargs["file_url"] = request.route_url( + "packaging.file", path=kwargs["obj"].path + ) + + return kwargs + def add_verdict(self, **kwargs): self._verdicts.append(MalwareVerdict(check_id=self.id, **kwargs)) - def run(self, obj_id): + def run(self, **kwargs): """ Runs the check and inserts returned verdicts. """ - self.scan(obj_id) + self.scan(**kwargs) self.db.add_all(self._verdicts) - def scan(self, obj_id): + def scan(self, **kwargs): """ Scans the object and returns a verdict. """ @@ -43,12 +61,12 @@ def backfill(self, sample=1): """ def _load_check_id(self): - self.id = ( + (self.id,) = ( self.db.query(MalwareCheck.id) .filter(MalwareCheck.name == self._name) .filter( MalwareCheck.state.in_( - [MalwareCheckState.enabled, MalwareCheckState.evaluation] + [MalwareCheckState.Enabled, MalwareCheckState.Evaluation] ) ) .one() diff --git a/warehouse/malware/checks/setup_patterns/__init__.py b/warehouse/malware/checks/setup_patterns/__init__.py new file mode 100644 index 000000000000..dec05468533a --- /dev/null +++ b/warehouse/malware/checks/setup_patterns/__init__.py @@ -0,0 +1,13 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .check import SetupPatternCheck # noqa diff --git a/warehouse/malware/checks/setup_patterns/check.py b/warehouse/malware/checks/setup_patterns/check.py new file mode 100644 index 000000000000..c2127bb292db --- /dev/null +++ b/warehouse/malware/checks/setup_patterns/check.py @@ -0,0 +1,108 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from textwrap import dedent + +import yara + +from warehouse.malware.checks.base import MalwareCheckBase +from warehouse.malware.checks.utils import extract_file_content, fetch_url_content +from warehouse.malware.models import VerdictClassification, VerdictConfidence + + +class SetupPatternCheck(MalwareCheckBase): + _yara_rule_file = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "setup_py_rules.yara" + ) + + version = 1 + short_description = "A check for common malicious patterns in setup.py" + long_description = dedent( + """ + This check uses YARA to search for common malicious patterns in the setup.py + files of uploaded release archives. + """ + ) + check_type = "event_hook" + hooked_object = "File" + + def __init__(self, db): + super().__init__(db) + self._yara_rules = self._load_yara_rules() + + def _load_yara_rules(self): + return yara.compile(filepath=self._yara_rule_file) + + def scan(self, **kwargs): + file = kwargs.get("obj") + file_url = kwargs.get("file_url") + if file is None or file_url is None: + # TODO: Maybe raise here, since the absence of these + # arguments is a use/user error. + return + + if file.packagetype != "sdist": + # Per PEP 491: bdists do not contain setup.py. + # This check only scans dists that contain setup.py, so + # we have nothing to perform. + return + + archive_stream = fetch_url_content(file_url) + setup_py_contents = extract_file_content(archive_stream, "setup.py") + if setup_py_contents is None: + self.add_verdict( + file_id=file.id, + classification=VerdictClassification.Indeterminate, + confidence=VerdictConfidence.High, + message="sdist does not contain a suitable setup.py for analysis", + ) + return + + matches = self._yara_rules.match(data=setup_py_contents) + if len(matches) > 0: + # We reduce N matches into a single verdict by taking the maximum + # classification and confidence. + classification = max( + VerdictClassification(m.meta["classification"]) for m in matches + ) + confidence = max(VerdictConfidence(m.meta["confidence"]) for m in matches) + message = ":".join(m.rule for m in matches) + + details = {} + for match in matches: + details[match.rule] = { + "classification": match.meta["classification"], + "confidence": match.meta["confidence"], + # NOTE: We could include the raw bytes here (s[2]), + # but we'd have to serialize/encode it to make JSON happy. + # It probably suffices to include the offset and identifier + # for triage purposes. + "strings": [[s[0], s[1]] for s in match.strings], + } + + self.add_verdict( + file_id=file.id, + classification=classification, + confidence=confidence, + message=message, + details=details, + ) + else: + # No matches? Report a low-confidence benign verdict. + self.add_verdict( + file_id=file.id, + classification=VerdictClassification.Benign, + confidence=VerdictConfidence.Low, + message="No malicious patterns found in setup.py", + ) diff --git a/warehouse/malware/checks/setup_patterns/setup_py_rules.yara b/warehouse/malware/checks/setup_patterns/setup_py_rules.yara new file mode 100644 index 000000000000..cf67b1377db3 --- /dev/null +++ b/warehouse/malware/checks/setup_patterns/setup_py_rules.yara @@ -0,0 +1,174 @@ +/* Patterns that indicate or suggest an attempt to spawn a process + * using various routines in the `os` module. + * + * These indicators are classified as "threat" to reflect the low + * probability that their presence is legitimate. + */ +rule process_spawn_in_setup { + meta: + confidence = "high" + classification = "threat" + + strings: + // NOTE(ww): We don't detect `import os as ...` + $from_os_import = /from os import / + + // Bare calls to suspicious os methods. + $bare_system = "system" + $bare_exec = /exec.*/ + $bare_spawn = /spawn.*/ + $bare_posix_spawn = /posix_spawn.*/ + $bare_popen = /popen.*/ + + // Fully qualified calls to suspicious os methods. + $fq_system = "os.system(" + $fq_exec = /os\.exec.*/ + $fq_spawn = /os\.spawn.*/ + $fq_posix_spawn = /os\.posix_spawn.*/ + $fq_popen = /os\.popen.*/ + + condition: + (1 of ($fq_*)) or ($from_os_import and (1 of ($bare_*))) +} + +/* Patterns that indicate or suggest an attempt to spawn a process + * using various routines and objects in the `subprocess` module. + * + * These indicators are classified as "threat" to reflect the low + * probability that their presence is legitimate. + */ +rule subprocess_in_setup { + meta: + confidence = "high" + classification = "threat" + + strings: + // NOTE(ww): We don't detect `import subprocess as ...` + $from_subprocess_import = /from subprocess import / + + // Bare calls to suspicious subprocess methods/objects + $bare_run = "run" + $bare_Popen = "Popen" + $bare_call = "call" + $bare_check_call = "check_call" + $bare_check_output = "check_output" + + // Fully qualified calls to suspicious subprocess methods/objects + $fq_run = "subprocess.run" + $fq_Popen = "subprocess.Popen" + $fq_call = "subprocess.call" + $fq_check_call = "subprocess.check_call" + $fq_check_output = "subprocess.check_output" + + condition: + (1 of ($fq_*) or ($from_subprocess_import and (1 of ($bare_*)))) +} + +/* Patterns that indicate or suggest an attempt to access a network resource. + * + * These indicators are classified as "indeterminate" to reflect that some + * legitimate use cases may exist. + */ +rule networking_in_setup { + meta: + confidence = "high" + classification = "indeterminate" + + strings: + // These modules contain frequently-used routines for making network requests + // Other candidates: poplib, imaplib, nntplib, smtplib, telnetlib + $from_socket_import = /from socket(\..+)? import/ + $from_socketserver_import = /from socketserver(\..+)? import/ + $from_ssl_import = /from ssl(\..+)? import/ + $from_ftplib_import = /from ftplib(\..+)? import/ + $from_http_import = /from http(\..+)? import/ + $from_urllib_import = /from urllib(\..+)? import/ + $from_xmlrpc_sub_import = /from xmlrpc(\..+)? import/ + + $import_socket = /import socket(\..+)?/ + $import_socketserver = /import socketserver(\..+)?/ + $import_ssl = /import ssl(\..+)?/ + $import_ftplib = /import ftplib(\..+)?/ + $import_http = /import http(\..+)?/ + $import_http_sub = /import http(\..+)?/ + $import_urllib = /import urllib(\..+)?/ + $import_urllib_sub = /import urllib(\..+)?/ + $import_xmlrpc = /import xmlrpc(\..+)?/ + $import_xmlrpc_sub = /import xmlrpc(\..+)?/ + + condition: + any of them +} + +/* Patterns that indicate or suggest an attempt to deserialize data. + * + * These indicators are clasified as "indeterminate" to reflect that some + * legitimate use cases may exist. + */ +rule deserialization_in_setup { + meta: + confidence = "high" + classification = "indeterminate" + + strings: + // These modules contain frequently-used routines for obfuscating data + // Other candidates: uu, quopri + $from_pickle_import = /from pickle(\..+)? import/ + $from_base64_import = /from base64(\..+)? import/ + $from_binhex_import = /from binhex(\..+)? import/ + + $import_pickle = /import pickle(\..+)?/ + $import_base64 = /import base64(\..+)?/ + $import_binhex = /import binhex(\..+)?/ + + condition: + any of them +} + +/* Patterns that indicate or suggest an attempt to perform metaprogramming. + * + * These indicators are clasified as "indeterminate" to reflect that some + * legitimate use cases may exist. + */ +rule metaprogramming_in_setup { + meta: + confidence = "high" + classification = "indeterminate" + + strings: + // The inspect module contains routines that can be used to obfuscate accesses + $from_inspect_import = /from inspect(\..+)? import/ + $import_inspect = /import inspect(\..+)?/ + + // The compileall module contains routines that can be used to smuggle Python code + $from_compileall_import = /from compileall(\..+)? import/ + $import_compileall = /import compileall(\..+)?/ + + // The py_compile module contains routines that can be used to smuggle Python code + $from_py_compile_import = /from py_compile(\..+)? import/ + $import_py_compile = /import py_compile(\..+)?/ + + // compile can be used to smuggle Python code into exec or eval. + $compile_call = /compile\(/ + + // dir can be used to obfuscate accesses of attributes + $dir_call = /dir\(/ + + // eval can be used to evaluate smuggled code + $eval_call = /eval\(/ + + // exec can be used to evaluate smuggled code + $exec_call = /exec\(/ + + // getattr can be used to obfuscate accesses of attributes + $getattr_call = /getattr\(/ + + // globals can be used to obfuscate accesses of attributes + $globals_call = /globals\(/ + + // locals can be used to obfuscate accesses of attributes + $locals_call = /locals\(/ + + condition: + any of them +} diff --git a/warehouse/malware/checks/utils.py b/warehouse/malware/checks/utils.py new file mode 100644 index 000000000000..5ddda01ccc7c --- /dev/null +++ b/warehouse/malware/checks/utils.py @@ -0,0 +1,80 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import pathlib +import tarfile +import zipfile + +import requests + + +def fetch_url_content(url): + """ + Retrieves the contents of the given (presumed CDN) URL as a BytesIO. + + Performs no error checking; exceptions are handled in the check harness + as part of check retrying behavior. + """ + response = requests.get(url) + response.raise_for_status() + return io.BytesIO(response.content) + + +def extract_file_content(archive_stream, file_path): + """ + Retrieves the content of the given path from the given archive stream + (presumed to be a dist) as bytes. + + Handling of the given path is a little special: since the dist format(s) + don't enforce any naming convention for the base archive directory, + the path is interpreted as {base}/{file_path}. Thus, a call like this: + + extract_file_content(stream, "setup.py") + + will extract and return the contents of {base}/setup.py where {base} + is frequently (but not guaranteed to be) something like $name-$version. + + Returns None on any sort of failure. + """ + if zipfile.is_zipfile(archive_stream): + with zipfile.ZipFile(archive_stream) as zipobj: + for name in zipobj.namelist(): + path_parts = pathlib.Path(name).parts + if len(path_parts) >= 2: + tail = pathlib.Path(*path_parts[1:]) + if str(tail) == file_path: + return zipobj.read(name) + return None + else: + # NOTE: is_zipfile doesn't rewind the fileobj it's given. + archive_stream.seek(0) + + # NOTE: We don't need to perform a sanity check on + # the (presumed) tarfile's compression here, since we're + # extracting from a stream that's already gone through + # upload validation. + # See _is_valid_dist_file in forklift/legacy.py. + try: + with tarfile.open(fileobj=archive_stream) as tarobj: + member = tarobj.next() + while member: + path_parts = pathlib.Path(member.name).parts + if len(path_parts) >= 2: + tail = pathlib.Path(*path_parts[1:]) + if str(tail) == file_path: + return tarobj.extractfile(member).read() + + member = tarobj.next() + return None + except tarfile.TarError: + return None diff --git a/warehouse/malware/models.py b/warehouse/malware/models.py index 0464ba0d47ce..a1ab490fc45e 100644 --- a/warehouse/malware/models.py +++ b/warehouse/malware/models.py @@ -11,6 +11,7 @@ # limitations under the License. import enum +import functools from citext import CIText from sqlalchemy import ( @@ -32,20 +33,23 @@ from warehouse.utils.attrs import make_repr +@enum.unique class MalwareCheckType(enum.Enum): - event_hook = "event_hook" - scheduled = "scheduled" + EventHook = "event_hook" + Scheduled = "scheduled" +@enum.unique class MalwareCheckState(enum.Enum): - enabled = "enabled" - evaluation = "evaluation" - disabled = "disabled" - wiped_out = "wiped_out" + Enabled = "enabled" + Evaluation = "evaluation" + Disabled = "disabled" + WipedOut = "wiped_out" +@enum.unique class MalwareCheckObjectType(enum.Enum): File = "File" @@ -53,19 +57,45 @@ class MalwareCheckObjectType(enum.Enum): Project = "Project" +@enum.unique +@functools.total_ordering class VerdictClassification(enum.Enum): + """ + An enumeration of classification markers for malware verdicts. - threat = "threat" - indeterminate = "indeterminate" - benign = "benign" + Note that the order of declaration is important: it provides + the appropriate ordering behavior when finding the minimum + and maximum classifications for a set of verdicts. + """ + Benign = "benign" + Indeterminate = "indeterminate" + Threat = "threat" + def __lt__(self, other): + members = list(self.__class__) + return members.index(self) < members.index(other) + + +@enum.unique +@functools.total_ordering class VerdictConfidence(enum.Enum): + """ + An enumeration of confidence markers for malware verdicts. + + Note that the order of declaration is important: it provides + the appropriate ordering behavior when finding the minimum + and maximum confidences for a set of verdicts. + """ Low = "low" Medium = "medium" High = "high" + def __lt__(self, other): + members = list(self.__class__) + return members.index(self) < members.index(other) + class MalwareCheck(db.Model): diff --git a/warehouse/malware/tasks.py b/warehouse/malware/tasks.py index 5028ed00e7ef..dc0274cbee98 100644 --- a/warehouse/malware/tasks.py +++ b/warehouse/malware/tasks.py @@ -24,7 +24,8 @@ def run_check(task, request, check_name, obj_id): check = getattr(checks, check_name)(request.db) try: - check.run(obj_id) + kwargs = check.prepare(request, obj_id) + check.run(obj_id=obj_id, **kwargs) except Exception as exc: request.log.error("Error executing check %s: %s" % (check_name, str(exc))) raise task.retry(exc=exc) @@ -55,7 +56,7 @@ def sync_checks(task, request): wiped_out_checks = {} for check in all_checks: if not check.is_stale: - if check.state == MalwareCheckState.wiped_out: + if check.state == MalwareCheckState.WipedOut: wiped_out_checks[check.name] = check else: active_checks[check.name] = check @@ -93,9 +94,9 @@ def sync_checks(task, request): # Migrate the check state to the newest check. # Then mark the old check state as disabled. - if db_check.state != MalwareCheckState.disabled: + if db_check.state != MalwareCheckState.Disabled: fields["state"] = db_check.state.value - db_check.state = MalwareCheckState.disabled + db_check.state = MalwareCheckState.Disabled request.db.add(MalwareCheck(**fields)) else: diff --git a/warehouse/malware/utils.py b/warehouse/malware/utils.py index 6139c3e248a7..879d9e8a3a7b 100644 --- a/warehouse/malware/utils.py +++ b/warehouse/malware/utils.py @@ -46,8 +46,8 @@ def get_check_fields(check): def get_enabled_hooked_checks(session): checks = ( session.query(MalwareCheck.name, MalwareCheck.hooked_object) - .filter(MalwareCheck.check_type == MalwareCheckType.event_hook) - .filter(MalwareCheck.state == MalwareCheckState.enabled) + .filter(MalwareCheck.check_type == MalwareCheckType.EventHook) + .filter(MalwareCheck.state == MalwareCheckState.Enabled) .all() ) results = defaultdict(list) From f4e2e1ece25a7991974990b1f50bb857cd7120bb Mon Sep 17 00:00:00 2001 From: Cristina Date: Thu, 30 Jan 2020 13:13:28 -0800 Subject: [PATCH 12/16] Implement scheduled checks #7093 (#7271) * Implement scheduled checks #7093 - Rename `run_backfill` to `run_evaluation` in admin malware view - Modify `run` and `scan` method signatures to accept `**kwargs` - Extend `run_check` to accomodate scheduled check functionality * Reduce unit test flakiness * Code review changes. Also replace `check.hooked_object` with `check.hooked_object.value` in check detail template. * tests, warehouse: enum fixes * Fix lint error Co-authored-by: William Woodruff --- tests/common/checks/scheduled.py | 4 +- tests/unit/admin/test_routes.py | 4 +- tests/unit/admin/views/test_checks.py | 59 ++++++++++----- tests/unit/malware/test_init.py | 10 +++ tests/unit/malware/test_tasks.py | 71 ++++++++++++++----- warehouse/admin/routes.py | 4 +- .../admin/malware/checks/detail.html | 18 ++++- .../templates/admin/malware/checks/index.html | 4 +- warehouse/admin/views/checks.py | 35 +++++---- warehouse/malware/__init__.py | 16 +++++ warehouse/malware/checks/base.py | 10 +-- warehouse/malware/tasks.py | 29 ++++++-- 12 files changed, 197 insertions(+), 67 deletions(-) diff --git a/tests/common/checks/scheduled.py b/tests/common/checks/scheduled.py index 128ce102a83b..d5c80962e45c 100644 --- a/tests/common/checks/scheduled.py +++ b/tests/common/checks/scheduled.py @@ -27,11 +27,11 @@ class ExampleScheduledCheck(MalwareCheckBase): def __init__(self, db): super().__init__(db) - def scan(self): + def scan(self, **kwargs): project = self.db.query(Project).first() self.add_verdict( project_id=project.id, - classification=VerdictClassification.benign, + classification=VerdictClassification.Benign, confidence=VerdictConfidence.High, message="Nothing to see here!", ) diff --git a/tests/unit/admin/test_routes.py b/tests/unit/admin/test_routes.py index 28538ad12dac..451b6ad5a9e0 100644 --- a/tests/unit/admin/test_routes.py +++ b/tests/unit/admin/test_routes.py @@ -133,8 +133,8 @@ def test_includeme(): domain=warehouse, ), pretend.call( - "admin.checks.run_backfill", - "/admin/checks/{check_name}/run_backfill", + "admin.checks.run_evaluation", + "/admin/checks/{check_name}/run_evaluation", domain=warehouse, ), pretend.call("admin.verdicts.list", "/admin/verdicts/", domain=warehouse), diff --git a/tests/unit/admin/views/test_checks.py b/tests/unit/admin/views/test_checks.py index c8fa6512aeaa..0c3baeb03369 100644 --- a/tests/unit/admin/views/test_checks.py +++ b/tests/unit/admin/views/test_checks.py @@ -16,7 +16,8 @@ from pyramid.httpexceptions import HTTPNotFound from warehouse.admin.views import checks as views -from warehouse.malware.models import MalwareCheckState +from warehouse.malware.models import MalwareCheckState, MalwareCheckType +from warehouse.malware.tasks import backfill, run_check from ....common.db.malware import MalwareCheckFactory @@ -46,6 +47,7 @@ def test_get_check(self, db_request): "check": check, "checks": [check], "states": MalwareCheckState, + "evaluation_run_size": 10000, } def test_get_check_many_versions(self, db_request): @@ -56,6 +58,7 @@ def test_get_check_many_versions(self, db_request): "check": check2, "checks": [check2, check1], "states": MalwareCheckState, + "evaluation_run_size": 10000, } def test_get_check_not_found(self, db_request): @@ -129,17 +132,17 @@ def test_change_to_invalid_state(self, db_request): assert check.state == initial_state -class TestRunBackfill: +class TestRunEvaluation: @pytest.mark.parametrize( ("check_state", "message"), [ ( MalwareCheckState.Disabled, - "Check must be in 'enabled' or 'evaluation' state to run a backfill.", + "Check must be in 'enabled' or 'evaluation' state to manually execute.", ), ( MalwareCheckState.WipedOut, - "Check must be in 'enabled' or 'evaluation' state to run a backfill.", + "Check must be in 'enabled' or 'evaluation' state to manually execute.", ), ], ) @@ -152,15 +155,21 @@ def test_invalid_backfill_parameters(self, db_request, check_state, message): ) db_request.route_path = pretend.call_recorder( - lambda *a, **kw: "/admin/checks/%s/run_backfill" % check.name + lambda *a, **kw: "/admin/checks/%s/run_evaluation" % check.name ) - views.run_backfill(db_request) + views.run_evaluation(db_request) assert db_request.session.flash.calls == [pretend.call(message, queue="error")] - def test_sucess(self, db_request): - check = MalwareCheckFactory.create(state=MalwareCheckState.Enabled) + @pytest.mark.parametrize( + ("check_type"), [MalwareCheckType.EventHook, MalwareCheckType.Scheduled] + ) + def test_success(self, db_request, check_type): + + check = MalwareCheckFactory.create( + check_type=check_type, state=MalwareCheckState.Enabled + ) db_request.matchdict["check_name"] = check.name db_request.session = pretend.stub( @@ -168,7 +177,7 @@ def test_sucess(self, db_request): ) db_request.route_path = pretend.call_recorder( - lambda *a, **kw: "/admin/checks/%s/run_backfill" % check.name + lambda *a, **kw: "/admin/checks/%s/run_evaluation" % check.name ) backfill_recorder = pretend.stub( @@ -177,13 +186,25 @@ def test_sucess(self, db_request): db_request.task = pretend.call_recorder(lambda *a, **kw: backfill_recorder) - views.run_backfill(db_request) - - assert db_request.session.flash.calls == [ - pretend.call( - "Running %s on 10000 %ss!" % (check.name, check.hooked_object.value), - queue="success", - ) - ] - - assert backfill_recorder.delay.calls == [pretend.call(check.name, 10000)] + views.run_evaluation(db_request) + + if check_type == MalwareCheckType.EventHook: + assert db_request.session.flash.calls == [ + pretend.call( + "Running %s on 10000 %ss!" + % (check.name, check.hooked_object.value), + queue="success", + ) + ] + assert db_request.task.calls == [pretend.call(backfill)] + assert backfill_recorder.delay.calls == [pretend.call(check.name, 10000)] + elif check_type == MalwareCheckType.Scheduled: + assert db_request.session.flash.calls == [ + pretend.call("Running %s now!" % check.name, queue="success",) + ] + assert db_request.task.calls == [pretend.call(run_check)] + assert backfill_recorder.delay.calls == [ + pretend.call(check.name, manually_triggered=True) + ] + else: + raise Exception("Invalid check type: %s" % check_type) diff --git a/tests/unit/malware/test_init.py b/tests/unit/malware/test_init.py index c8e1f3d7acf4..32628118398d 100644 --- a/tests/unit/malware/test_init.py +++ b/tests/unit/malware/test_init.py @@ -14,9 +14,12 @@ import pretend +from celery.schedules import crontab + from warehouse import malware from warehouse.malware import utils from warehouse.malware.interfaces import IMalwareCheckService +from warehouse.malware.tasks import run_check from ...common import checks as test_checks from ...common.db.accounts import UserFactory @@ -165,6 +168,7 @@ def test_includeme(monkeypatch): registry=pretend.stub( settings={"malware_check.backend": "TestMalwareCheckService"} ), + add_periodic_task=pretend.call_recorder(lambda *a, **kw: None), ) malware.includeme(config) @@ -172,3 +176,9 @@ def test_includeme(monkeypatch): assert config.register_service_factory.calls == [ pretend.call(malware_check_class.create_service, IMalwareCheckService) ] + + assert config.add_periodic_task.calls == [ + pretend.call( + crontab(minute="0", hour="*/8"), run_check, args=("ExampleScheduledCheck",) + ) + ] diff --git a/tests/unit/malware/test_tasks.py b/tests/unit/malware/test_tasks.py index de278f4ad825..6a333c31b018 100644 --- a/tests/unit/malware/test_tasks.py +++ b/tests/unit/malware/test_tasks.py @@ -14,8 +14,6 @@ import pretend import pytest -from sqlalchemy.orm.exc import NoResultFound - from warehouse.malware import tasks from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict @@ -34,45 +32,86 @@ def test_success(self, db_request, monkeypatch): name="ExampleHookedCheck", state=MalwareCheckState.Enabled ) task = pretend.stub() - tasks.run_check(task, db_request, "ExampleHookedCheck", file0.id) + tasks.run_check(task, db_request, "ExampleHookedCheck", obj_id=file0.id) assert db_request.route_url.calls == [ pretend.call("packaging.file", path=file0.path) ] assert db_request.db.query(MalwareVerdict).one() - def test_disabled_check(self, db_request, monkeypatch): + @pytest.mark.parametrize(("manually_triggered"), [True, False]) + def test_evaluation_run(self, db_session, monkeypatch, manually_triggered): + monkeypatch.setattr(tasks, "checks", test_checks) + MalwareCheckFactory.create( + name="ExampleScheduledCheck", state=MalwareCheckState.Evaluation + ) + ProjectFactory.create() + task = pretend.stub() + + request = pretend.stub( + db=db_session, + log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None)), + ) + + tasks.run_check( + task, + request, + "ExampleScheduledCheck", + manually_triggered=manually_triggered, + ) + + if manually_triggered: + assert db_session.query(MalwareVerdict).one() + else: + assert request.log.info.calls == [ + pretend.call( + "ExampleScheduledCheck is in the `evaluation` state and must be \ +manually triggered to run." + ) + ] + assert db_session.query(MalwareVerdict).all() == [] + + def test_disabled_check(self, db_session, monkeypatch): monkeypatch.setattr(tasks, "checks", test_checks) MalwareCheckFactory.create( name="ExampleHookedCheck", state=MalwareCheckState.Disabled ) task = pretend.stub() + request = pretend.stub( + db=db_session, + log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None)), + ) file = FileFactory.create() - with pytest.raises(NoResultFound): - tasks.run_check(task, db_request, "ExampleHookedCheck", file.id) + tasks.run_check( + task, request, "ExampleHookedCheck", obj_id=file.id, + ) + + assert request.log.info.calls == [ + pretend.call("Check ExampleHookedCheck isn't active. Aborting.") + ] def test_missing_check(self, db_request, monkeypatch): monkeypatch.setattr(tasks, "checks", test_checks) task = pretend.stub() - file = FileFactory.create() - with pytest.raises(AttributeError): - tasks.run_check(task, db_request, "DoesNotExistCheck", file.id) + tasks.run_check( + task, db_request, "DoesNotExistCheck", + ) def test_retry(self, db_session, monkeypatch): + monkeypatch.setattr(tasks, "checks", test_checks) exc = Exception("Scan failed") def scan(self, **kwargs): raise exc - monkeypatch.setattr(tasks, "checks", test_checks) monkeypatch.setattr(tasks.checks.ExampleHookedCheck, "scan", scan) MalwareCheckFactory.create( - name="ExampleHookedCheck", state=MalwareCheckState.Evaluation + name="ExampleHookedCheck", state=MalwareCheckState.Enabled ) task = pretend.stub( @@ -87,7 +126,7 @@ def scan(self, **kwargs): file = FileFactory.create() with pytest.raises(celery.exceptions.Retry): - tasks.run_check(task, request, "ExampleHookedCheck", file.id) + tasks.run_check(task, request, "ExampleHookedCheck", obj_id=file.id) assert request.log.error.calls == [ pretend.call("Error executing check ExampleHookedCheck: Scan failed") @@ -108,9 +147,8 @@ def test_invalid_check_name(self, db_request, monkeypatch): ) def test_run(self, db_session, num_objects, num_runs, monkeypatch): monkeypatch.setattr(tasks, "checks", test_checks) - files = [] for i in range(num_objects): - files.append(FileFactory.create()) + FileFactory.create() MalwareCheckFactory.create( name="ExampleHookedCheck", state=MalwareCheckState.Enabled @@ -133,15 +171,14 @@ def test_run(self, db_session, num_objects, num_runs, monkeypatch): pretend.call("Running backfill on %d Files." % num_runs) ] - assert enqueue_recorder.delay.calls == [ - pretend.call("ExampleHookedCheck", files[i].id) for i in range(num_runs) - ] + assert len(enqueue_recorder.delay.calls) == num_runs class TestSyncChecks: def test_no_updates(self, db_session, monkeypatch): monkeypatch.setattr(tasks, "checks", test_checks) monkeypatch.setattr(tasks.checks.ExampleScheduledCheck, "version", 2) + MalwareCheckFactory.create( name="ExampleHookedCheck", state=MalwareCheckState.Disabled ) diff --git a/warehouse/admin/routes.py b/warehouse/admin/routes.py index 2788f51519bf..2b8ca93a3541 100644 --- a/warehouse/admin/routes.py +++ b/warehouse/admin/routes.py @@ -140,8 +140,8 @@ def includeme(config): domain=warehouse, ) config.add_route( - "admin.checks.run_backfill", - "/admin/checks/{check_name}/run_backfill", + "admin.checks.run_evaluation", + "/admin/checks/{check_name}/run_evaluation", domain=warehouse, ) config.add_route("admin.verdicts.list", "/admin/verdicts/", domain=warehouse) diff --git a/warehouse/admin/templates/admin/malware/checks/detail.html b/warehouse/admin/templates/admin/malware/checks/detail.html index 77914488bb51..c52a5f354530 100644 --- a/warehouse/admin/templates/admin/malware/checks/detail.html +++ b/warehouse/admin/templates/admin/malware/checks/detail.html @@ -30,12 +30,22 @@

    Revision History

    Version State + {% if check.check_type.value == "event_hook" %} + Hooked Object + {% else %} + Schedule + {% endif %} Created {% for c in checks %} {{ c.version }} {{ c.state.value }} + {% if check.check_type.value == "event_hook" %} + {{ c.hooked_object.value }} + {% else %} +
    {{ c.schedule }}
    + {% endif %} {{ c.created }} {% endfor %} @@ -69,10 +79,14 @@

    Change State

    Run Evaluation

    -
    +
    -

    Run this check against 10,000 {{ check.hooked_object.value }}s, selected at random. This is used to evaluate the efficacy of a check.

    + {% if check.check_type.value == "event_hook" %} +

    Run this check against {{ evaluation_run_size }} {{ check.hooked_object.value }}s, selected at random. This is used to evaluate the efficacy of a check.

    + {% else %} +

    Execute this check now.

    + {% endif %}
    diff --git a/warehouse/admin/templates/admin/malware/checks/index.html b/warehouse/admin/templates/admin/malware/checks/index.html index 5717849e2579..2601bfe0f62a 100644 --- a/warehouse/admin/templates/admin/malware/checks/index.html +++ b/warehouse/admin/templates/admin/malware/checks/index.html @@ -26,7 +26,7 @@ Check Name State - Revisions + Type Last Modified Description @@ -38,7 +38,7 @@ {{ check.state.value }} - {{ check.version }} + {{ check.check_type.value }} {{ check.created }} {{ check.short_description }} diff --git a/warehouse/admin/views/checks.py b/warehouse/admin/views/checks.py index b23465fd4a49..0cd1761c8456 100644 --- a/warehouse/admin/views/checks.py +++ b/warehouse/admin/views/checks.py @@ -14,8 +14,10 @@ from pyramid.view import view_config from sqlalchemy.orm.exc import NoResultFound -from warehouse.malware.models import MalwareCheck, MalwareCheckState -from warehouse.malware.tasks import backfill, remove_verdicts +from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareCheckType +from warehouse.malware.tasks import backfill, remove_verdicts, run_check + +EVALUATION_RUN_SIZE = 10000 @view_config( @@ -52,36 +54,45 @@ def get_check(request): .all() ) - return {"check": check, "checks": all_checks, "states": MalwareCheckState} + return { + "check": check, + "checks": all_checks, + "states": MalwareCheckState, + "evaluation_run_size": EVALUATION_RUN_SIZE, + } @view_config( - route_name="admin.checks.run_backfill", + route_name="admin.checks.run_evaluation", permission="admin", request_method="POST", uses_session=True, require_methods=False, require_csrf=True, ) -def run_backfill(request): +def run_evaluation(request): check = get_check_by_name(request.db, request.matchdict["check_name"]) - num_objects = 10000 if check.state not in (MalwareCheckState.Enabled, MalwareCheckState.Evaluation): request.session.flash( - f"Check must be in 'enabled' or 'evaluation' state to run a backfill.", + f"Check must be in 'enabled' or 'evaluation' state to manually execute.", queue="error", ) return HTTPSeeOther( request.route_path("admin.checks.detail", check_name=check.name) ) - request.session.flash( - f"Running {check.name} on {num_objects} {check.hooked_object.value}s!", - queue="success", - ) + if check.check_type == MalwareCheckType.EventHook: + request.session.flash( + f"Running {check.name} on {EVALUATION_RUN_SIZE} {check.hooked_object.value}s\ +!", + queue="success", + ) + request.task(backfill).delay(check.name, EVALUATION_RUN_SIZE) - request.task(backfill).delay(check.name, num_objects) + else: + request.session.flash(f"Running {check.name} now!", queue="success") + request.task(run_check).delay(check.name, manually_triggered=True) return HTTPSeeOther( request.route_path("admin.checks.detail", check_name=check.name) diff --git a/warehouse/malware/__init__.py b/warehouse/malware/__init__.py index f54a9e89b4f5..a1b76b82f2b5 100644 --- a/warehouse/malware/__init__.py +++ b/warehouse/malware/__init__.py @@ -10,9 +10,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect + +from celery.schedules import crontab + +import warehouse.malware.checks as checks + from warehouse import db from warehouse.malware import utils from warehouse.malware.interfaces import IMalwareCheckService +from warehouse.malware.tasks import run_check @db.listens_for(db.Session, "after_flush") @@ -57,3 +64,12 @@ def includeme(config): config.register_service_factory( malware_check_class.create_service, IMalwareCheckService ) + + # Add scheduled tasks for every scheduled Malware Check. + all_checks = inspect.getmembers(checks, inspect.isclass) + for check_obj in all_checks: + check = check_obj[1] + if check.check_type == "scheduled": + config.add_periodic_task( + crontab(**check.schedule), run_check, args=(check_obj[0],) + ) diff --git a/warehouse/malware/checks/base.py b/warehouse/malware/checks/base.py index 67810a367203..44f8230a1c2c 100644 --- a/warehouse/malware/checks/base.py +++ b/warehouse/malware/checks/base.py @@ -18,7 +18,7 @@ class MalwareCheckBase: def __init__(self, db): self.db = db self._name = self.__class__.__name__ - self._load_check_id() + self._load_check_fields() self._verdicts = [] @classmethod @@ -26,7 +26,7 @@ def prepare(cls, request, obj_id): """ Prepares some context for scanning the given object. """ - kwargs = {} + kwargs = {"obj_id": obj_id} model = getattr(models, cls.hooked_object) kwargs["obj"] = request.db.query(model).get(obj_id) @@ -60,9 +60,9 @@ def backfill(self, sample=1): backfill on the entire corpus. """ - def _load_check_id(self): - (self.id,) = ( - self.db.query(MalwareCheck.id) + def _load_check_fields(self): + self.id, self.state = ( + self.db.query(MalwareCheck.id, MalwareCheck.state) .filter(MalwareCheck.name == self._name) .filter( MalwareCheck.state.in_( diff --git a/warehouse/malware/tasks.py b/warehouse/malware/tasks.py index dc0274cbee98..1fbe46b7a638 100644 --- a/warehouse/malware/tasks.py +++ b/warehouse/malware/tasks.py @@ -12,6 +12,8 @@ import inspect +from sqlalchemy.orm.exc import NoResultFound + import warehouse.malware.checks as checks import warehouse.packaging.models as packaging_models @@ -21,11 +23,30 @@ @task(bind=True, ignore_result=True, acks_late=True, retry_backoff=True) -def run_check(task, request, check_name, obj_id): - check = getattr(checks, check_name)(request.db) +def run_check(task, request, check_name, obj_id=None, manually_triggered=False): try: + check = getattr(checks, check_name)(request.db) + except NoResultFound: + request.log.info("Check %s isn't active. Aborting." % check_name) + return + + # Don't run scheduled checks if they are in evaluation mode, unless manually + # triggered. + if check.state == MalwareCheckState.Evaluation and not manually_triggered: + request.log.info( + "%s is in the `evaluation` state and must be manually triggered to run." + % check_name + ) + return + + kwargs = {} + + # Hooked checks require `obj_id`s. + if obj_id is not None: kwargs = check.prepare(request, obj_id) - check.run(obj_id=obj_id, **kwargs) + + try: + check.run(**kwargs) except Exception as exc: request.log.error("Error executing check %s: %s" % (check_name, str(exc))) raise task.retry(exc=exc) @@ -43,7 +64,7 @@ def backfill(task, request, check_name, num_objects): request.log.info("Running backfill on %d %ss." % (num_objects, check.hooked_object)) for (elem_id,) in query: - request.task(run_check).delay(check_name, elem_id) + request.task(run_check).delay(check_name, elem_id, manually_triggered=True) @task(bind=True, ignore_result=True, acks_late=True) From 6fe43a312e6056b1105dd323ca9d14412d9bf3f7 Mon Sep 17 00:00:00 2001 From: Cristina Date: Mon, 3 Feb 2020 08:27:51 -0800 Subject: [PATCH 13/16] Add verdicts view filtering capabilities #6062. (#7322) * Add verdicts view filtering capabilities #6062. * Code review changes. - Refactor tests to be parametrized. - Pass `_query` to `route_path` in template. - Remove `is None` from filter query, it adds nothing. --- tests/unit/admin/views/test_verdicts.py | 163 ++++++++++++++++-- warehouse/admin/templates/admin/base.html | 2 +- .../admin/malware/verdicts/index.html | 43 ++++- warehouse/admin/views/verdicts.py | 67 +++++-- 4 files changed, 249 insertions(+), 26 deletions(-) diff --git a/tests/unit/admin/views/test_verdicts.py b/tests/unit/admin/views/test_verdicts.py index 7d28820ca9cf..1492ce853f6e 100644 --- a/tests/unit/admin/views/test_verdicts.py +++ b/tests/unit/admin/views/test_verdicts.py @@ -14,37 +14,176 @@ from random import randint -import pretend import pytest from pyramid.httpexceptions import HTTPBadRequest, HTTPNotFound from warehouse.admin.views import verdicts as views +from warehouse.malware.models import VerdictClassification, VerdictConfidence -from ....common.db.malware import MalwareVerdictFactory +from ....common.db.malware import MalwareCheckFactory, MalwareVerdictFactory class TestListVerdicts: def test_none(self, db_request): - assert views.get_verdicts(db_request) == {"verdicts": []} + assert views.get_verdicts(db_request) == { + "verdicts": [], + "check_names": set(), + "classifications": set(["threat", "indeterminate", "benign"]), + "confidences": set(["low", "medium", "high"]), + } def test_some(self, db_request): - verdicts = [MalwareVerdictFactory.create() for _ in range(10)] + check = MalwareCheckFactory.create() + verdicts = [MalwareVerdictFactory.create(check=check) for _ in range(10)] - assert views.get_verdicts(db_request) == {"verdicts": verdicts} + assert views.get_verdicts(db_request) == { + "verdicts": verdicts, + "check_names": set([check.name]), + "classifications": set(["threat", "indeterminate", "benign"]), + "confidences": set(["low", "medium", "high"]), + } def test_some_with_multipage(self, db_request): - verdicts = [MalwareVerdictFactory.create() for _ in range(60)] + check1 = MalwareCheckFactory.create() + check2 = MalwareCheckFactory.create() + verdicts = [MalwareVerdictFactory.create(check=check2) for _ in range(60)] db_request.GET["page"] = "2" - assert views.get_verdicts(db_request) == {"verdicts": verdicts[25:50]} - - def test_with_invalid_page(self): - request = pretend.stub(params={"page": "not an integer"}) - + assert views.get_verdicts(db_request) == { + "verdicts": verdicts[25:50], + "check_names": set([check1.name, check2.name]), + "classifications": set(["threat", "indeterminate", "benign"]), + "confidences": set(["low", "medium", "high"]), + } + + @pytest.mark.parametrize( + "check_name", ["check0", "check1", ""], + ) + def test_check_name_filter(self, db_request, check_name): + result_verdicts, all_verdicts = [], [] + for i in range(3): + check = MalwareCheckFactory.create(name="check%d" % i) + verdicts = [MalwareVerdictFactory.create(check=check) for _ in range(5)] + all_verdicts.extend(verdicts) + if check.name == check_name: + result_verdicts = verdicts + + # Emptry string + if not result_verdicts: + result_verdicts = all_verdicts + + response = { + "verdicts": result_verdicts, + "check_names": set(["check0", "check1", "check2"]), + "classifications": set(["threat", "indeterminate", "benign"]), + "confidences": set(["low", "medium", "high"]), + } + + db_request.GET["check_name"] = check_name + assert views.get_verdicts(db_request) == response + + @pytest.mark.parametrize( + "classification", ["benign", "indeterminate", "threat", ""], + ) + def test_classification_filter(self, db_request, classification): + check1 = MalwareCheckFactory.create() + result_verdicts, all_verdicts = [], [] + for c in VerdictClassification: + verdicts = [ + MalwareVerdictFactory.create(check=check1, classification=c) + for _ in range(5) + ] + all_verdicts.extend(verdicts) + if c.value == classification: + result_verdicts = verdicts + + # Emptry string + if not result_verdicts: + result_verdicts = all_verdicts + + db_request.GET["classification"] = classification + response = { + "verdicts": result_verdicts, + "check_names": set([check1.name]), + "classifications": set(["threat", "indeterminate", "benign"]), + "confidences": set(["low", "medium", "high"]), + } + assert views.get_verdicts(db_request) == response + + @pytest.mark.parametrize( + "confidence", ["low", "medium", "high", ""], + ) + def test_confidence_filter(self, db_request, confidence): + check1 = MalwareCheckFactory.create() + result_verdicts, all_verdicts = [], [] + for c in VerdictConfidence: + verdicts = [ + MalwareVerdictFactory.create(check=check1, confidence=c) + for _ in range(5) + ] + all_verdicts.extend(verdicts) + if c.value == confidence: + result_verdicts = verdicts + + # Emptry string + if not result_verdicts: + result_verdicts = all_verdicts + + response = { + "verdicts": result_verdicts, + "check_names": set([check1.name]), + "classifications": set(["threat", "indeterminate", "benign"]), + "confidences": set(["low", "medium", "high"]), + } + + db_request.GET["confidence"] = confidence + assert views.get_verdicts(db_request) == response + + @pytest.mark.parametrize( + "manually_reviewed", [1, 0], + ) + def test_manually_reviewed_filter(self, db_request, manually_reviewed): + check1 = MalwareCheckFactory.create() + result_verdicts = [ + MalwareVerdictFactory.create( + check=check1, manually_reviewed=bool(manually_reviewed) + ) + for _ in range(5) + ] + + # Create other verdicts to ensure filter works properly + for _ in range(10): + MalwareVerdictFactory.create( + check=check1, manually_reviewed=not bool(manually_reviewed) + ) + + db_request.GET["manually_reviewed"] = str(manually_reviewed) + + response = { + "verdicts": result_verdicts, + "check_names": set([check1.name]), + "classifications": set(["threat", "indeterminate", "benign"]), + "confidences": set(["low", "medium", "high"]), + } + + assert views.get_verdicts(db_request) == response + + @pytest.mark.parametrize( + "invalid_param", + [ + ("page", "invalid"), + ("check_name", "NotACheck"), + ("confidence", "NotAConfidence"), + ("classification", "NotAClassification"), + ("manually_reviewed", "False"), + ], + ) + def test_errors(self, db_request, invalid_param): + db_request.GET[invalid_param[0]] = invalid_param[1] with pytest.raises(HTTPBadRequest): - views.get_verdicts(request) + views.get_verdicts(db_request) class TestGetVerdict: diff --git a/warehouse/admin/templates/admin/base.html b/warehouse/admin/templates/admin/base.html index fffa7ccfcec5..18be097d6576 100644 --- a/warehouse/admin/templates/admin/base.html +++ b/warehouse/admin/templates/admin/base.html @@ -131,7 +131,7 @@
  • - + Verdicts
  • diff --git a/warehouse/admin/templates/admin/malware/verdicts/index.html b/warehouse/admin/templates/admin/malware/verdicts/index.html index d6ab7ef6b028..443dd295aa15 100644 --- a/warehouse/admin/templates/admin/malware/verdicts/index.html +++ b/warehouse/admin/templates/admin/malware/verdicts/index.html @@ -17,12 +17,53 @@ {% block title %}Malware Verdicts{% endblock %} +{% set check_name = request.params.get("check_name") %} +{% set classification = request.params.get("classification") %} +{% set confidence = request.params.get("confidence") %} +{% set manually_reviewed = request.params.get("manually_reviewed") %} + {% block breadcrumb %}
  • Verdicts
  • {% endblock %} {% block content %}
    +
    + +
    + +
    +
    + +
    +
    + +
    +
    + +
    + + +
    @@ -30,7 +71,7 @@ - + {% for verdict in verdicts %} diff --git a/warehouse/admin/views/verdicts.py b/warehouse/admin/views/verdicts.py index bd9c2eae68ca..ab8dec6dd514 100644 --- a/warehouse/admin/views/verdicts.py +++ b/warehouse/admin/views/verdicts.py @@ -14,7 +14,12 @@ from pyramid.httpexceptions import HTTPBadRequest, HTTPNotFound from pyramid.view import view_config -from warehouse.malware.models import MalwareVerdict +from warehouse.malware.models import ( + MalwareCheck, + MalwareVerdict, + VerdictClassification, + VerdictConfidence, +) from warehouse.utils.paginate import paginate_url_factory @@ -26,23 +31,23 @@ uses_session=True, ) def get_verdicts(request): - try: - page_num = int(request.params.get("page", 1)) - except ValueError: - raise HTTPBadRequest("'page' must be an integer.") from None - - verdicts_query = request.db.query(MalwareVerdict).order_by( - MalwareVerdict.run_date.desc() + result = {} + result["check_names"] = set( + [name for (name,) in request.db.query(MalwareCheck.name)] ) + result["classifications"] = set([c.value for c in VerdictClassification]) + result["confidences"] = set([c.value for c in VerdictConfidence]) + + validate_fields(request, result) - verdicts = SQLAlchemyORMPage( - verdicts_query, - page=page_num, + result["verdicts"] = SQLAlchemyORMPage( + generate_query(request.db, request.params), + page=int(request.params.get("page", 1)), items_per_page=25, url_maker=paginate_url_factory(request), ) - return {"verdicts": verdicts} + return result @view_config( @@ -59,3 +64,41 @@ def get_verdict(request): return {"verdict": verdict} raise HTTPNotFound + + +def validate_fields(request, validators): + try: + int(request.params.get("page", 1)) + except ValueError: + raise HTTPBadRequest("'page' must be an integer.") from None + + validators = {**validators, **{"manually_revieweds": set(["0", "1"])}} + + for key, possible_values in validators.items(): + # Remove the trailing 's' + value = request.params.get(key[:-1]) + additional_values = set([None, ""]) + if value not in possible_values | additional_values: + raise HTTPBadRequest( + "Invalid value for '%s': %s." % (key[:-1], value) + ) from None + + +def generate_query(db, params): + """ + Returns an SQLAlchemy query wth request params applied as filters. + """ + query = db.query(MalwareVerdict) + if params.get("check_name"): + query = query.join(MalwareCheck) + query = query.filter(MalwareCheck.name == params["check_name"]) + if params.get("confidence"): + query = query.filter(MalwareVerdict.confidence == params["confidence"]) + if params.get("classification"): + query = query.filter(MalwareVerdict.classification == params["classification"]) + if params.get("manually_reviewed"): + query = query.filter( + MalwareVerdict.manually_reviewed == bool(int(params["manually_reviewed"])) + ) + + return query.order_by(MalwareVerdict.run_date.desc()) From 1d3791bc76824e7edbfb13c888ff583cf156479f Mon Sep 17 00:00:00 2001 From: Cristina Date: Thu, 6 Feb 2020 08:54:44 -0800 Subject: [PATCH 14/16] Add verdict administrator review. Fixes #6062. (#7339) * Add verdict administrator review. Fixes #6062. - Add new `admin.verdicts.review` endpoint - Change layout of verdict list and detail view and add forms - Change sort order of the MalwareChecks, and update the tests * Code review changes. - Rename MalwareVerdict field `administrator_verdict` to `reviewer_verdict`. - Change verdict review permission from `admin` to `moderator`. --- tests/common/db/malware.py | 2 +- tests/unit/admin/test_routes.py | 5 ++ tests/unit/admin/views/test_checks.py | 10 +++- tests/unit/admin/views/test_verdicts.py | 48 ++++++++++++++++++- warehouse/admin/routes.py | 3 ++ .../admin/malware/verdicts/detail.html | 39 +++++++++------ .../admin/malware/verdicts/index.html | 18 +++++-- warehouse/admin/views/checks.py | 2 + warehouse/admin/views/verdicts.py | 36 +++++++++++++- warehouse/malware/models.py | 2 +- ...1ff3d24c22_add_malware_detection_tables.py | 2 +- 11 files changed, 142 insertions(+), 25 deletions(-) diff --git a/tests/common/db/malware.py b/tests/common/db/malware.py index 4e41a0c23865..32b57576862c 100644 --- a/tests/common/db/malware.py +++ b/tests/common/db/malware.py @@ -55,7 +55,7 @@ class Meta: release = None project = None manually_reviewed = True - administrator_verdict = factory.fuzzy.FuzzyChoice(list(VerdictClassification)) + reviewer_verdict = factory.fuzzy.FuzzyChoice(list(VerdictClassification)) classification = factory.fuzzy.FuzzyChoice(list(VerdictClassification)) confidence = factory.fuzzy.FuzzyChoice(list(VerdictConfidence)) message = factory.fuzzy.FuzzyText(length=80) diff --git a/tests/unit/admin/test_routes.py b/tests/unit/admin/test_routes.py index 451b6ad5a9e0..c03b772e45a7 100644 --- a/tests/unit/admin/test_routes.py +++ b/tests/unit/admin/test_routes.py @@ -141,4 +141,9 @@ def test_includeme(): pretend.call( "admin.verdicts.detail", "/admin/verdicts/{verdict_id}", domain=warehouse ), + pretend.call( + "admin.verdicts.review", + "/admin/verdicts/{verdict_id}/review", + domain=warehouse, + ), ] diff --git a/tests/unit/admin/views/test_checks.py b/tests/unit/admin/views/test_checks.py index 0c3baeb03369..3edf31232e91 100644 --- a/tests/unit/admin/views/test_checks.py +++ b/tests/unit/admin/views/test_checks.py @@ -28,7 +28,10 @@ def test_get_checks_none(self, db_request): def test_get_checks(self, db_request): checks = [MalwareCheckFactory.create() for _ in range(10)] - assert views.get_checks(db_request) == {"checks": checks} + result = views.get_checks(db_request)["checks"] + assert len(result) == len(checks) + for r in result: + assert r in checks def test_get_checks_different_versions(self, db_request): checks = [MalwareCheckFactory.create() for _ in range(5)] @@ -36,7 +39,10 @@ def test_get_checks_different_versions(self, db_request): MalwareCheckFactory.create(name="MyCheck", version=i) for i in range(1, 6) ] checks.append(checks_same[-1]) - assert views.get_checks(db_request) == {"checks": checks} + result = views.get_checks(db_request)["checks"] + assert len(result) == len(checks) + for r in result: + assert r in checks class TestGetCheck: diff --git a/tests/unit/admin/views/test_verdicts.py b/tests/unit/admin/views/test_verdicts.py index 1492ce853f6e..51ee3587e496 100644 --- a/tests/unit/admin/views/test_verdicts.py +++ b/tests/unit/admin/views/test_verdicts.py @@ -14,6 +14,7 @@ from random import randint +import pretend import pytest from pyramid.httpexceptions import HTTPBadRequest, HTTPNotFound @@ -193,10 +194,55 @@ def test_found(self, db_request): lookup_id = verdicts[index].id db_request.matchdict["verdict_id"] = lookup_id - assert views.get_verdict(db_request) == {"verdict": verdicts[index]} + assert views.get_verdict(db_request) == { + "verdict": verdicts[index], + "classifications": ["Benign", "Indeterminate", "Threat"], + } def test_not_found(self, db_request): db_request.matchdict["verdict_id"] = uuid.uuid4() with pytest.raises(HTTPNotFound): views.get_verdict(db_request) + + +class TestReviewVerdict: + @pytest.mark.parametrize( + "manually_reviewed, reviewer_verdict", + [ + (False, None), # unreviewed verdict + (True, VerdictClassification.Threat), # previously reviewed + ], + ) + def test_set_classification(self, db_request, manually_reviewed, reviewer_verdict): + verdict = MalwareVerdictFactory.create( + manually_reviewed=manually_reviewed, reviewer_verdict=reviewer_verdict, + ) + + db_request.matchdict["verdict_id"] = verdict.id + db_request.POST = {"classification": "Benign"} + db_request.session = pretend.stub( + flash=pretend.call_recorder(lambda *a, **kw: None) + ) + + db_request.route_path = pretend.call_recorder( + lambda *a, **kw: "/admin/verdicts/%s/review" % verdict.id + ) + + views.review_verdict(db_request) + + assert db_request.session.flash.calls == [ + pretend.call("Verdict %s marked as reviewed." % verdict.id, queue="success") + ] + + assert verdict.manually_reviewed + assert verdict.reviewer_verdict == VerdictClassification.Benign + + @pytest.mark.parametrize("post_params", [{}, {"classification": "Nope"}]) + def test_errors(self, db_request, post_params): + verdict = MalwareVerdictFactory.create() + db_request.matchdict["verdict_id"] = verdict.id + db_request.POST = post_params + + with pytest.raises(HTTPBadRequest): + views.review_verdict(db_request) diff --git a/warehouse/admin/routes.py b/warehouse/admin/routes.py index 2b8ca93a3541..adaaf393afee 100644 --- a/warehouse/admin/routes.py +++ b/warehouse/admin/routes.py @@ -148,3 +148,6 @@ def includeme(config): config.add_route( "admin.verdicts.detail", "/admin/verdicts/{verdict_id}", domain=warehouse ) + config.add_route( + "admin.verdicts.review", "/admin/verdicts/{verdict_id}/review", domain=warehouse + ) diff --git a/warehouse/admin/templates/admin/malware/verdicts/detail.html b/warehouse/admin/templates/admin/malware/verdicts/detail.html index 7702943e8692..b31abf3e7221 100644 --- a/warehouse/admin/templates/admin/malware/verdicts/detail.html +++ b/warehouse/admin/templates/admin/malware/verdicts/detail.html @@ -13,7 +13,7 @@ -#} {% extends "admin/base.html" %} -{% block title %}Verdict {{ verdict.id }}{% endblock %} +{% block title %}Verdict Details{% endblock %} {% block breadcrumb %}
  • Verdicts
  • @@ -44,24 +44,20 @@
    + {% if verdict.manually_reviewed %} - - - - - - + + + {% endif %} - - + + - {% if verdict.manually_reviewed %} - - + + - {% endif %} {% if verdict.full_report_link %} @@ -70,10 +66,25 @@ {% endif %} {% if verdict.details %} - + {% endif %} + + + +
    Check Classification ConfidenceDetailReview
    Object {% include 'object_link.html' %}
    Verdict Classification{{ verdict.classification.value }}
    Verdict Confidence{{ verdict.confidence.value }}Reviewer Verdict{{ verdict.reviewer_verdict.value }}
    Manually Reviewed{{ verdict.manually_reviewed }}Check Verdict{{ verdict.classification.value }}
    Administrator Verdict{{ verdict.administrator_verdict }}Confidence{{ verdict.confidence.value }}
    Full Report Link
    DetailsAdditional Details
    {{ verdict.details|tojson(indent=4) }}
    {{ "Change" if verdict.manually_reviewed else "Set"}} Verdict +
    + + + +
    +
    diff --git a/warehouse/admin/templates/admin/malware/verdicts/index.html b/warehouse/admin/templates/admin/malware/verdicts/index.html index 443dd295aa15..f37073ee1c36 100644 --- a/warehouse/admin/templates/admin/malware/verdicts/index.html +++ b/warehouse/admin/templates/admin/malware/verdicts/index.html @@ -67,6 +67,7 @@
    + @@ -75,6 +76,11 @@ {% for verdict in verdicts %} + {% else %} diff --git a/warehouse/admin/views/checks.py b/warehouse/admin/views/checks.py index 0cd1761c8456..6b203d88a20a 100644 --- a/warehouse/admin/views/checks.py +++ b/warehouse/admin/views/checks.py @@ -34,6 +34,8 @@ def get_checks(request): if not check.is_stale: active_checks.append(check) + active_checks.sort(key=lambda check: check.created, reverse=True) + return {"checks": active_checks} diff --git a/warehouse/admin/views/verdicts.py b/warehouse/admin/views/verdicts.py index ab8dec6dd514..e55bc1322b36 100644 --- a/warehouse/admin/views/verdicts.py +++ b/warehouse/admin/views/verdicts.py @@ -11,7 +11,7 @@ # limitations under the License. from paginate_sqlalchemy import SqlalchemyOrmPage as SQLAlchemyORMPage -from pyramid.httpexceptions import HTTPBadRequest, HTTPNotFound +from pyramid.httpexceptions import HTTPBadRequest, HTTPNotFound, HTTPSeeOther from pyramid.view import view_config from warehouse.malware.models import ( @@ -61,11 +61,43 @@ def get_verdict(request): verdict = request.db.query(MalwareVerdict).get(request.matchdict["verdict_id"]) if verdict: - return {"verdict": verdict} + return { + "verdict": verdict, + "classifications": list(VerdictClassification.__members__.keys()), + } raise HTTPNotFound +@view_config( + route_name="admin.verdicts.review", + permission="moderator", + request_method="POST", + uses_session=True, + require_methods=False, + require_csrf=True, +) +def review_verdict(request): + verdict = request.db.query(MalwareVerdict).get(request.matchdict["verdict_id"]) + + try: + classification = getattr(VerdictClassification, request.POST["classification"]) + except (KeyError, AttributeError): + raise HTTPBadRequest("Invalid verdict classification.") from None + + verdict.manually_reviewed = True + verdict.reviewer_verdict = classification + + request.session.flash( + "Verdict %s marked as reviewed." % verdict.id, queue="success" + ) + + # If no query params are provided (e.g. request originating from + # admins.verdicts.detail view), then route to the default list view + query = request.GET or {"classification": "threat", "manually_reviewed": "0"} + return HTTPSeeOther(request.route_path("admin.verdicts.list", _query=query)) + + def validate_fields(request, validators): try: int(request.params.get("page", 1)) diff --git a/warehouse/malware/models.py b/warehouse/malware/models.py index a1ab490fc45e..e4f56fcf851a 100644 --- a/warehouse/malware/models.py +++ b/warehouse/malware/models.py @@ -161,7 +161,7 @@ class MalwareVerdict(db.Model): message = Column(Text, nullable=True) details = Column(JSONB, nullable=True) manually_reviewed = Column(Boolean, nullable=False, server_default=sql.false()) - administrator_verdict = Column( + reviewer_verdict = Column( Enum(VerdictClassification, values_callable=lambda x: [e.value for e in x]), nullable=True, ) diff --git a/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py b/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py index 622660fd042f..1eb2da246e58 100644 --- a/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py +++ b/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py @@ -91,7 +91,7 @@ def upgrade(): server_default=sa.text("false"), nullable=False, ), - sa.Column("administrator_verdict", VerdictClassifications, nullable=True,), + sa.Column("reviewer_verdict", VerdictClassifications, nullable=True,), sa.Column("full_report_link", sa.String(), nullable=True), sa.ForeignKeyConstraint( ["check_id"], ["malware_checks.id"], onupdate="CASCADE", ondelete="CASCADE" From c1370bfc35122046b082ddef02ee2282a04adc09 Mon Sep 17 00:00:00 2001 From: Cristina Date: Fri, 7 Feb 2020 13:27:30 -0800 Subject: [PATCH 15/16] Misc cleanup and TODOs on malware checks. (#7355) * Misc cleanup and TODOs on malware checks. - Change backfill function to invoke `IMalwareCheckService` interface - Add support for `kwargs to `IMalwareCheckService` interface - Rename variable from reserved word `file` to `release_file` - Add `FatalCheckException` for non-retryable exceptions - Replace `MALWARE_CHECK_BACKEND` in dev/environment * Make `IMalwareService` the entrypoint for `run_check` - Add `run_scheduled_check` task that invokes this interface. - Remove useless utility method - Move `FatalCheckException` into warehouse/malware/errors.py. --- dev/environment | 3 +- tests/common/checks/hooked.py | 3 +- tests/unit/admin/views/test_checks.py | 4 +- .../checks/setup_patterns/test_check.py | 5 +- tests/unit/malware/test_init.py | 17 ++-- tests/unit/malware/test_services.py | 40 +++++++-- tests/unit/malware/test_tasks.py | 86 +++++++++++++++++-- warehouse/admin/views/checks.py | 4 +- warehouse/malware/__init__.py | 11 +-- .../malware/checks/setup_patterns/check.py | 19 ++-- warehouse/malware/errors.py | 15 ++++ warehouse/malware/interfaces.py | 2 +- warehouse/malware/services.py | 16 ++-- warehouse/malware/tasks.py | 19 +++- warehouse/malware/utils.py | 14 +-- 15 files changed, 190 insertions(+), 68 deletions(-) create mode 100644 warehouse/malware/errors.py diff --git a/dev/environment b/dev/environment index ec7eeae6d2f3..5d9fe6cc5af6 100644 --- a/dev/environment +++ b/dev/environment @@ -29,8 +29,7 @@ MAIL_BACKEND=warehouse.email.services.SMTPEmailSender host=smtp port=2525 ssl=fa BREACHED_PASSWORDS=warehouse.accounts.NullPasswordBreachedService -#TODO: change this to PrinterMalwareCheckService before deploy -MALWARE_CHECK_BACKEND=warehouse.malware.services.DatabaseMalwareCheckService +MALWARE_CHECK_BACKEND=warehouse.malware.services.PrinterMalwareCheckService METRICS_BACKEND=warehouse.metrics.DataDogMetrics host=notdatadog diff --git a/tests/common/checks/hooked.py b/tests/common/checks/hooked.py index 8a3e16a3cbf9..2aa72a1bb8ae 100644 --- a/tests/common/checks/hooked.py +++ b/tests/common/checks/hooked.py @@ -11,6 +11,7 @@ # limitations under the License. from warehouse.malware.checks.base import MalwareCheckBase +from warehouse.malware.errors import FatalCheckException from warehouse.malware.models import VerdictClassification, VerdictConfidence @@ -29,7 +30,7 @@ def __init__(self, db): def scan(self, **kwargs): file_id = kwargs.get("obj_id") if file_id is None: - return + raise FatalCheckException("Missing required kwarg `obj_id`") self.add_verdict( file_id=file_id, diff --git a/tests/unit/admin/views/test_checks.py b/tests/unit/admin/views/test_checks.py index 3edf31232e91..8aafc0a68319 100644 --- a/tests/unit/admin/views/test_checks.py +++ b/tests/unit/admin/views/test_checks.py @@ -17,7 +17,7 @@ from warehouse.admin.views import checks as views from warehouse.malware.models import MalwareCheckState, MalwareCheckType -from warehouse.malware.tasks import backfill, run_check +from warehouse.malware.tasks import backfill, run_scheduled_check from ....common.db.malware import MalwareCheckFactory @@ -208,7 +208,7 @@ def test_success(self, db_request, check_type): assert db_request.session.flash.calls == [ pretend.call("Running %s now!" % check.name, queue="success",) ] - assert db_request.task.calls == [pretend.call(run_check)] + assert db_request.task.calls == [pretend.call(run_scheduled_check)] assert backfill_recorder.delay.calls == [ pretend.call(check.name, manually_triggered=True) ] diff --git a/tests/unit/malware/checks/setup_patterns/test_check.py b/tests/unit/malware/checks/setup_patterns/test_check.py index 0dbd5c19d06f..c25556cf988d 100644 --- a/tests/unit/malware/checks/setup_patterns/test_check.py +++ b/tests/unit/malware/checks/setup_patterns/test_check.py @@ -43,9 +43,8 @@ def test_scan_missing_kwargs(db_session, obj, file_url): name="SetupPatternCheck", state=MalwareCheckState.Enabled ) check = c.SetupPatternCheck(db_session) - check.scan(obj=obj, file_url=file_url) - - assert check._verdicts == [] + with pytest.raises(c.FatalCheckException): + check.scan(obj=obj, file_url=file_url) def test_scan_non_sdist(db_session): diff --git a/tests/unit/malware/test_init.py b/tests/unit/malware/test_init.py index 32628118398d..fc642a23cd41 100644 --- a/tests/unit/malware/test_init.py +++ b/tests/unit/malware/test_init.py @@ -17,9 +17,8 @@ from celery.schedules import crontab from warehouse import malware -from warehouse.malware import utils from warehouse.malware.interfaces import IMalwareCheckService -from warehouse.malware.tasks import run_check +from warehouse.malware.tasks import run_scheduled_check from ...common import checks as test_checks from ...common.db.accounts import UserFactory @@ -30,7 +29,7 @@ def test_determine_malware_checks_no_checks(monkeypatch, db_request): def get_enabled_hooked_checks(session): return defaultdict(list) - monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks) + monkeypatch.setattr(malware, "get_enabled_hooked_checks", get_enabled_hooked_checks) project = ProjectFactory.create(name="foo") release = ReleaseFactory.create(project=project) @@ -49,7 +48,7 @@ def get_enabled_hooked_checks(session): result["Release"] = ["Check3"] return result - monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks) + monkeypatch.setattr(malware, "get_enabled_hooked_checks", get_enabled_hooked_checks) project = ProjectFactory.create(name="foo") release = ReleaseFactory.create(project=project) @@ -68,7 +67,7 @@ def get_enabled_hooked_checks(session): result["Release"] = ["Check3"] return result - monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks) + monkeypatch.setattr(malware, "get_enabled_hooked_checks", get_enabled_hooked_checks) user = UserFactory.create() @@ -85,7 +84,7 @@ def get_enabled_hooked_checks(session): result["Release"] = ["Check3"] return result - monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks) + monkeypatch.setattr(malware, "get_enabled_hooked_checks", get_enabled_hooked_checks) project = ProjectFactory.create(name="foo") release = ReleaseFactory.create(project=project) @@ -105,7 +104,7 @@ def get_enabled_hooked_checks(session): result["Release"] = ["Check3"] return result - monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks) + monkeypatch.setattr(malware, "get_enabled_hooked_checks", get_enabled_hooked_checks) project = ProjectFactory.create(name="foo") release = ReleaseFactory.create(project=project) @@ -179,6 +178,8 @@ def test_includeme(monkeypatch): assert config.add_periodic_task.calls == [ pretend.call( - crontab(minute="0", hour="*/8"), run_check, args=("ExampleScheduledCheck",) + crontab(minute="0", hour="*/8"), + run_scheduled_check, + args=("ExampleScheduledCheck",), ) ] diff --git a/tests/unit/malware/test_services.py b/tests/unit/malware/test_services.py index 7a9cb636f720..3348bb894d81 100644 --- a/tests/unit/malware/test_services.py +++ b/tests/unit/malware/test_services.py @@ -11,6 +11,7 @@ # limitations under the License. import pretend +import pytest from zope.interface.verify import verifyClass @@ -31,13 +32,14 @@ def test_create_service(self): service = PrinterMalwareCheckService.create_service(None, request) assert service.executor == print - def test_run_checks(self, capfd): + @pytest.mark.parametrize(("kwargs"), [{}, {"manually_triggered": True}]) + def test_run_checks(self, capfd, kwargs): request = pretend.stub() service = PrinterMalwareCheckService.create_service(None, request) checks = ["one", "two", "three"] - service.run_checks(checks) + service.run_checks(checks, **kwargs) out, err = capfd.readouterr() - assert out == "one\ntwo\nthree\n" + assert out == "".join(["%s %s\n" % (check, kwargs) for check in checks]) class TestDatabaseMalwareService: @@ -50,12 +52,36 @@ def test_create_service(self, db_request): service = DatabaseMalwareCheckService.create_service(None, db_request) assert service.executor == db_request.task(run_check).delay - def test_run_checks(self, db_request): - _delay = pretend.call_recorder(lambda *args: None) + def test_run_hooked_check(self, db_request): + _delay = pretend.call_recorder(lambda *args, **kwargs: None) + db_request.task = lambda x: pretend.stub(delay=_delay) + service = DatabaseMalwareCheckService.create_service(None, db_request) + checks = [ + "MyTestCheck:ba70267f-fabf-496f-9ac2-d237a983b187", + "AnotherCheck:44f57b0e-c5b0-47c5-8713-341cf392efe2", + "FinalCheck:e8518a15-8f01-430e-8f5b-87644007c9c0", + ] + service.run_checks(checks) + assert _delay.calls == [ + pretend.call("MyTestCheck", obj_id="ba70267f-fabf-496f-9ac2-d237a983b187"), + pretend.call("AnotherCheck", obj_id="44f57b0e-c5b0-47c5-8713-341cf392efe2"), + pretend.call("FinalCheck", obj_id="e8518a15-8f01-430e-8f5b-87644007c9c0"), + ] + + def test_run_scheduled_check(self, db_request): + _delay = pretend.call_recorder(lambda *args, **kwargs: None) db_request.task = lambda x: pretend.stub(delay=_delay) service = DatabaseMalwareCheckService.create_service(None, db_request) - checks = ["MyTestCheck:ba70267f-fabf-496f-9ac2-d237a983b187"] + checks = ["MyTestScheduledCheck"] service.run_checks(checks) + assert _delay.calls == [pretend.call("MyTestScheduledCheck")] + + def test_run_triggered_check(self, db_request): + _delay = pretend.call_recorder(lambda *args, **kwargs: None) + db_request.task = lambda x: pretend.stub(delay=_delay) + service = DatabaseMalwareCheckService.create_service(None, db_request) + checks = ["MyTriggeredCheck"] + service.run_checks(checks, manually_triggered=True) assert _delay.calls == [ - pretend.call("MyTestCheck", "ba70267f-fabf-496f-9ac2-d237a983b187") + pretend.call("MyTriggeredCheck", manually_triggered=True) ] diff --git a/tests/unit/malware/test_tasks.py b/tests/unit/malware/test_tasks.py index 6a333c31b018..ec8f5e8dac48 100644 --- a/tests/unit/malware/test_tasks.py +++ b/tests/unit/malware/test_tasks.py @@ -16,6 +16,7 @@ from warehouse.malware import tasks from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict +from warehouse.malware.services import PrinterMalwareCheckService from ...common import checks as test_checks from ...common.db.malware import MalwareCheckFactory, MalwareVerdictFactory @@ -101,6 +102,28 @@ def test_missing_check(self, db_request, monkeypatch): task, db_request, "DoesNotExistCheck", ) + def test_missing_obj_id(self, db_session, monkeypatch): + monkeypatch.setattr(tasks, "checks", test_checks) + task = pretend.stub() + + MalwareCheckFactory.create( + name="ExampleHookedCheck", state=MalwareCheckState.Enabled + ) + task = pretend.stub() + + request = pretend.stub( + db=db_session, + log=pretend.stub(error=pretend.call_recorder(lambda *args, **kwargs: None)), + ) + + tasks.run_check(task, request, "ExampleHookedCheck") + + assert request.log.error.calls == [ + pretend.call( + "Fatal exception: ExampleHookedCheck: Missing required kwarg `obj_id`" + ) + ] + def test_retry(self, db_session, monkeypatch): monkeypatch.setattr(tasks, "checks", test_checks) exc = Exception("Scan failed") @@ -135,6 +158,37 @@ def scan(self, **kwargs): assert task.retry.calls == [pretend.call(exc=exc)] +class TestRunScheduledCheck: + def test_invalid_check_name(self, db_request, monkeypatch): + monkeypatch.setattr(tasks, "checks", test_checks) + task = pretend.stub() + with pytest.raises(AttributeError): + tasks.run_scheduled_check(task, db_request, "DoesNotExist") + + def test_run_check(self, db_session, capfd, monkeypatch): + MalwareCheckFactory.create( + name="ExampleScheduledCheck", state=MalwareCheckState.Enabled + ) + + request = pretend.stub( + db=db_session, + find_service_factory=pretend.call_recorder( + lambda interface: PrinterMalwareCheckService.create_service + ), + ) + + task = pretend.stub() + + tasks.run_scheduled_check(task, request, "ExampleScheduledCheck") + + assert request.find_service_factory.calls == [ + pretend.call(tasks.IMalwareCheckService) + ] + + out, err = capfd.readouterr() + assert out == "ExampleScheduledCheck {'manually_triggered': False}\n" + + class TestBackfill: def test_invalid_check_name(self, db_request, monkeypatch): monkeypatch.setattr(tasks, "checks", test_checks) @@ -145,33 +199,47 @@ def test_invalid_check_name(self, db_request, monkeypatch): @pytest.mark.parametrize( ("num_objects", "num_runs"), [(11, 1), (11, 11), (101, 90)] ) - def test_run(self, db_session, num_objects, num_runs, monkeypatch): + def test_run(self, db_session, capfd, num_objects, num_runs, monkeypatch): monkeypatch.setattr(tasks, "checks", test_checks) + + ids = [] for i in range(num_objects): - FileFactory.create() + ids.append(FileFactory.create().id) MalwareCheckFactory.create( name="ExampleHookedCheck", state=MalwareCheckState.Enabled ) - enqueue_recorder = pretend.stub( - delay=pretend.call_recorder(lambda *a, **kw: None) - ) - task = pretend.call_recorder(lambda *args, **kwargs: enqueue_recorder) - request = pretend.stub( db=db_session, log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None)), - task=task, + find_service_factory=pretend.call_recorder( + lambda interface: PrinterMalwareCheckService.create_service + ), ) + task = pretend.stub() + tasks.backfill(task, request, "ExampleHookedCheck", num_runs) assert request.log.info.calls == [ pretend.call("Running backfill on %d Files." % num_runs) ] - assert len(enqueue_recorder.delay.calls) == num_runs + assert request.find_service_factory.calls == [ + pretend.call(tasks.IMalwareCheckService) + ] + + out, err = capfd.readouterr() + num_output_lines = 0 + for file_id in ids: + logged_output = "ExampleHookedCheck:%s %s\n" % ( + file_id, + {"manually_triggered": True}, + ) + num_output_lines += 1 if logged_output in out else 0 + + assert num_output_lines == num_runs class TestSyncChecks: diff --git a/warehouse/admin/views/checks.py b/warehouse/admin/views/checks.py index 6b203d88a20a..817c1fd66b7a 100644 --- a/warehouse/admin/views/checks.py +++ b/warehouse/admin/views/checks.py @@ -15,7 +15,7 @@ from sqlalchemy.orm.exc import NoResultFound from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareCheckType -from warehouse.malware.tasks import backfill, remove_verdicts, run_check +from warehouse.malware.tasks import backfill, remove_verdicts, run_scheduled_check EVALUATION_RUN_SIZE = 10000 @@ -94,7 +94,7 @@ def run_evaluation(request): else: request.session.flash(f"Running {check.name} now!", queue="success") - request.task(run_check).delay(check.name, manually_triggered=True) + request.task(run_scheduled_check).delay(check.name, manually_triggered=True) return HTTPSeeOther( request.route_path("admin.checks.detail", check_name=check.name) diff --git a/warehouse/malware/__init__.py b/warehouse/malware/__init__.py index a1b76b82f2b5..2336c9f11dce 100644 --- a/warehouse/malware/__init__.py +++ b/warehouse/malware/__init__.py @@ -17,9 +17,10 @@ import warehouse.malware.checks as checks from warehouse import db -from warehouse.malware import utils from warehouse.malware.interfaces import IMalwareCheckService -from warehouse.malware.tasks import run_check +from warehouse.malware.models import MalwareCheckObjectType +from warehouse.malware.tasks import run_scheduled_check +from warehouse.malware.utils import get_enabled_hooked_checks @db.listens_for(db.Session, "after_flush") @@ -31,13 +32,13 @@ def determine_malware_checks(config, session, flush_context): [ obj.__class__.__name__ for obj in session.new - if obj.__class__.__name__ in utils.valid_check_types() + if obj.__class__.__name__ in MalwareCheckObjectType.__members__ ] ): return malware_checks = session.info.setdefault("warehouse.malware.checks", set()) - enabled_checks = utils.get_enabled_hooked_checks(session) + enabled_checks = get_enabled_hooked_checks(session) for obj in session.new: for check_name in enabled_checks.get(obj.__class__.__name__, []): malware_checks.update([f"{check_name}:{obj.id}"]) @@ -71,5 +72,5 @@ def includeme(config): check = check_obj[1] if check.check_type == "scheduled": config.add_periodic_task( - crontab(**check.schedule), run_check, args=(check_obj[0],) + crontab(**check.schedule), run_scheduled_check, args=(check_obj[0],) ) diff --git a/warehouse/malware/checks/setup_patterns/check.py b/warehouse/malware/checks/setup_patterns/check.py index c2127bb292db..2a92a36ed9a7 100644 --- a/warehouse/malware/checks/setup_patterns/check.py +++ b/warehouse/malware/checks/setup_patterns/check.py @@ -18,6 +18,7 @@ from warehouse.malware.checks.base import MalwareCheckBase from warehouse.malware.checks.utils import extract_file_content, fetch_url_content +from warehouse.malware.errors import FatalCheckException from warehouse.malware.models import VerdictClassification, VerdictConfidence @@ -45,14 +46,14 @@ def _load_yara_rules(self): return yara.compile(filepath=self._yara_rule_file) def scan(self, **kwargs): - file = kwargs.get("obj") + release_file = kwargs.get("obj") file_url = kwargs.get("file_url") - if file is None or file_url is None: - # TODO: Maybe raise here, since the absence of these - # arguments is a use/user error. - return + if release_file is None or file_url is None: + raise FatalCheckException( + "Release file or file url is None, indicating user error." + ) - if file.packagetype != "sdist": + if release_file.packagetype != "sdist": # Per PEP 491: bdists do not contain setup.py. # This check only scans dists that contain setup.py, so # we have nothing to perform. @@ -62,7 +63,7 @@ def scan(self, **kwargs): setup_py_contents = extract_file_content(archive_stream, "setup.py") if setup_py_contents is None: self.add_verdict( - file_id=file.id, + file_id=release_file.id, classification=VerdictClassification.Indeterminate, confidence=VerdictConfidence.High, message="sdist does not contain a suitable setup.py for analysis", @@ -92,7 +93,7 @@ def scan(self, **kwargs): } self.add_verdict( - file_id=file.id, + file_id=release_file.id, classification=classification, confidence=confidence, message=message, @@ -101,7 +102,7 @@ def scan(self, **kwargs): else: # No matches? Report a low-confidence benign verdict. self.add_verdict( - file_id=file.id, + file_id=release_file.id, classification=VerdictClassification.Benign, confidence=VerdictConfidence.Low, message="No malicious patterns found in setup.py", diff --git a/warehouse/malware/errors.py b/warehouse/malware/errors.py new file mode 100644 index 000000000000..837c079cef18 --- /dev/null +++ b/warehouse/malware/errors.py @@ -0,0 +1,15 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class FatalCheckException(Exception): + pass diff --git a/warehouse/malware/interfaces.py b/warehouse/malware/interfaces.py index f179aa374d55..482907735f33 100644 --- a/warehouse/malware/interfaces.py +++ b/warehouse/malware/interfaces.py @@ -20,7 +20,7 @@ def create_service(context, request): created for. """ - def run_checks(checks): + def run_checks(checks, **kwargs): """ Run a given set of Checks """ diff --git a/warehouse/malware/services.py b/warehouse/malware/services.py index f2f454b964e2..5566250bbad8 100644 --- a/warehouse/malware/services.py +++ b/warehouse/malware/services.py @@ -25,9 +25,9 @@ def __init__(self, executor): def create_service(cls, context, request): return cls(print) - def run_checks(self, checks): + def run_checks(self, checks, **kwargs): for check in checks: - self.executor(check) + self.executor(check, kwargs) @implementer(IMalwareCheckService) @@ -39,7 +39,13 @@ def __init__(self, executor): def create_service(cls, context, request): return cls(request.task(run_check).delay) - def run_checks(self, checks): + def run_checks(self, checks, **kwargs): for check_info in checks: - check_name, obj_id = check_info.split(":") - self.executor(check_name, obj_id) + # Hooked checks + if ":" in check_info: + check_name, obj_id = check_info.split(":") + kwargs["obj_id"] = obj_id + # Scheduled checks + else: + check_name = check_info + self.executor(check_name, **kwargs) diff --git a/warehouse/malware/tasks.py b/warehouse/malware/tasks.py index 1fbe46b7a638..ac27f1577b0b 100644 --- a/warehouse/malware/tasks.py +++ b/warehouse/malware/tasks.py @@ -17,6 +17,8 @@ import warehouse.malware.checks as checks import warehouse.packaging.models as packaging_models +from warehouse.malware.errors import FatalCheckException +from warehouse.malware.interfaces import IMalwareCheckService from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict from warehouse.malware.utils import get_check_fields from warehouse.tasks import task @@ -47,11 +49,21 @@ def run_check(task, request, check_name, obj_id=None, manually_triggered=False): try: check.run(**kwargs) + except FatalCheckException as exc: + request.log.error("Fatal exception: %s: %s" % (check_name, str(exc))) + return except Exception as exc: request.log.error("Error executing check %s: %s" % (check_name, str(exc))) raise task.retry(exc=exc) +@task(bind=True, ignore_result=True, acks_late=True) +def run_scheduled_check(task, request, check_name, manually_triggered=False): + malware_check_service = request.find_service_factory(IMalwareCheckService) + malware_check = malware_check_service(None, request) + malware_check.run_checks([check_name], manually_triggered=manually_triggered) + + @task(bind=True, ignore_result=True, acks_late=True) def backfill(task, request, check_name, num_objects): """ @@ -63,8 +75,13 @@ def backfill(task, request, check_name, num_objects): request.log.info("Running backfill on %d %ss." % (num_objects, check.hooked_object)) + runs = set() for (elem_id,) in query: - request.task(run_check).delay(check_name, elem_id, manually_triggered=True) + runs.update([f"{check_name}:{elem_id}"]) + + malware_check_service = request.find_service_factory(IMalwareCheckService) + malware_check = malware_check_service(None, request) + malware_check.run_checks(runs, manually_triggered=True) @task(bind=True, ignore_result=True, acks_late=True) diff --git a/warehouse/malware/utils.py b/warehouse/malware/utils.py index 879d9e8a3a7b..95ebf4b796b5 100644 --- a/warehouse/malware/utils.py +++ b/warehouse/malware/utils.py @@ -10,21 +10,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import functools - from collections import defaultdict -from warehouse.malware.models import ( - MalwareCheck, - MalwareCheckObjectType, - MalwareCheckState, - MalwareCheckType, -) - - -@functools.lru_cache() -def valid_check_types(): - return set([t.value for t in MalwareCheckObjectType]) +from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareCheckType def get_check_fields(check): From f614bad9dcec98d095c6c1efa9bdebb61e78a930 Mon Sep 17 00:00:00 2001 From: William Woodruff Date: Tue, 11 Feb 2020 14:48:08 -0500 Subject: [PATCH 16/16] malware/checks: PackageTurnover skeleton (#7321) * malware/checks: PackageTurnover skeleton * malware/checks: PackageTurnover: Add NOTE * malware/checks: PackageTurnoverCheck: more work * tests: blacken * malware/checks: More PackageTurnoverCheck work * malware/checks: Blacken * malware/checks: Blacken * package_turnover: Promote from indeterminate to threat * tests: Begin adding package_turnover tests * tests: Add remaining package_turnover tests * tests: Drop unused imports * warehouse: Drop (ww) from NOTE * checks/package_turnover: Drop NOTE --- .../checks/package_turnover/__init__.py | 11 ++ .../checks/package_turnover/test_check.py | 177 ++++++++++++++++++ tests/unit/malware/test_tasks.py | 8 +- warehouse/malware/checks/__init__.py | 1 + .../checks/package_turnover/__init__.py | 13 ++ .../malware/checks/package_turnover/check.py | 112 +++++++++++ 6 files changed, 316 insertions(+), 6 deletions(-) create mode 100644 tests/unit/malware/checks/package_turnover/__init__.py create mode 100644 tests/unit/malware/checks/package_turnover/test_check.py create mode 100644 warehouse/malware/checks/package_turnover/__init__.py create mode 100644 warehouse/malware/checks/package_turnover/check.py diff --git a/tests/unit/malware/checks/package_turnover/__init__.py b/tests/unit/malware/checks/package_turnover/__init__.py new file mode 100644 index 000000000000..164f68b09175 --- /dev/null +++ b/tests/unit/malware/checks/package_turnover/__init__.py @@ -0,0 +1,11 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/malware/checks/package_turnover/test_check.py b/tests/unit/malware/checks/package_turnover/test_check.py new file mode 100644 index 000000000000..4215b89d0d66 --- /dev/null +++ b/tests/unit/malware/checks/package_turnover/test_check.py @@ -0,0 +1,177 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pretend + +from warehouse.malware.checks.package_turnover import check as c +from warehouse.malware.models import ( + MalwareCheckState, + VerdictClassification, + VerdictConfidence, +) + +from .....common.db.accounts import UserFactory +from .....common.db.malware import MalwareCheckFactory +from .....common.db.packaging import ProjectFactory, ReleaseFactory + + +def test_initializes(db_session): + check_model = MalwareCheckFactory.create( + name="PackageTurnoverCheck", state=MalwareCheckState.Enabled, + ) + check = c.PackageTurnoverCheck(db_session) + + assert check.id == check_model.id + + +def test_user_posture_verdicts(db_session): + user = UserFactory.create() + project = pretend.stub(users=[user], id=pretend.stub()) + + MalwareCheckFactory.create( + name="PackageTurnoverCheck", state=MalwareCheckState.Enabled, + ) + check = c.PackageTurnoverCheck(db_session) + + user.record_event( + tag="account:two_factor:method_removed", ip_address="0.0.0.0", additional={} + ) + + check.user_posture_verdicts(project) + assert len(check._verdicts) == 1 + assert check._verdicts[0].check_id == check.id + assert check._verdicts[0].project_id == project.id + assert check._verdicts[0].classification == VerdictClassification.Threat + assert check._verdicts[0].confidence == VerdictConfidence.High + assert ( + check._verdicts[0].message + == "User with control over this package has disabled 2FA" + ) + + +def test_user_posture_verdicts_hasnt_removed_2fa(db_session): + user = UserFactory.create() + project = pretend.stub(users=[user], id=pretend.stub()) + + MalwareCheckFactory.create( + name="PackageTurnoverCheck", state=MalwareCheckState.Enabled, + ) + check = c.PackageTurnoverCheck(db_session) + + check.user_posture_verdicts(project) + assert len(check._verdicts) == 0 + + +def test_user_posture_verdicts_has_2fa(db_session): + user = UserFactory.create(totp_secret=b"fake secret") + project = pretend.stub(users=[user], id=pretend.stub()) + + MalwareCheckFactory.create( + name="PackageTurnoverCheck", state=MalwareCheckState.Enabled, + ) + check = c.PackageTurnoverCheck(db_session) + + user.record_event( + tag="account:two_factor:method_removed", ip_address="0.0.0.0", additional={} + ) + + check.user_posture_verdicts(project) + assert len(check._verdicts) == 0 + + +def test_user_turnover_verdicts(db_session): + user = UserFactory.create() + project = ProjectFactory.create(users=[user]) + + project.record_event( + tag="project:role:add", + ip_address="0.0.0.0", + additional={"target_user": user.username}, + ) + + MalwareCheckFactory.create( + name="PackageTurnoverCheck", state=MalwareCheckState.Enabled, + ) + check = c.PackageTurnoverCheck(db_session) + + check.user_turnover_verdicts(project) + assert len(check._verdicts) == 1 + assert check._verdicts[0].check_id == check.id + assert check._verdicts[0].project_id == project.id + assert check._verdicts[0].classification == VerdictClassification.Threat + assert check._verdicts[0].confidence == VerdictConfidence.High + assert ( + check._verdicts[0].message + == "Suspicious user turnover; all current maintainers are new" + ) + + +def test_user_turnover_verdicts_no_turnover(db_session): + user = UserFactory.create() + project = ProjectFactory.create(users=[user]) + + MalwareCheckFactory.create( + name="PackageTurnoverCheck", state=MalwareCheckState.Enabled, + ) + check = c.PackageTurnoverCheck(db_session) + + check.user_turnover_verdicts(project) + assert len(check._verdicts) == 0 + + +def test_scan(db_session, monkeypatch): + user = UserFactory.create() + project = ProjectFactory.create(users=[user]) + + for _ in range(3): + ReleaseFactory.create(project=project) + + MalwareCheckFactory.create( + name="PackageTurnoverCheck", state=MalwareCheckState.Enabled, + ) + check = c.PackageTurnoverCheck(db_session) + + monkeypatch.setattr( + check, "user_posture_verdicts", pretend.call_recorder(lambda project: None) + ) + monkeypatch.setattr( + check, "user_turnover_verdicts", pretend.call_recorder(lambda project: None) + ) + + check.scan() + + # Each verdict rendering method is only called once per project, + # thanks to deduplication. + assert check.user_posture_verdicts.calls == [pretend.call(project)] + assert check.user_turnover_verdicts.calls == [pretend.call(project)] + + +def test_scan_too_few_releases(db_session, monkeypatch): + user = UserFactory.create() + project = ProjectFactory.create(users=[user]) + ReleaseFactory.create(project=project) + + MalwareCheckFactory.create( + name="PackageTurnoverCheck", state=MalwareCheckState.Enabled, + ) + check = c.PackageTurnoverCheck(db_session) + + monkeypatch.setattr( + check, "user_posture_verdicts", pretend.call_recorder(lambda project: None) + ) + monkeypatch.setattr( + check, "user_turnover_verdicts", pretend.call_recorder(lambda project: None) + ) + + check.scan() + assert check.user_posture_verdicts.calls == [] + assert check.user_turnover_verdicts.calls == [] diff --git a/tests/unit/malware/test_tasks.py b/tests/unit/malware/test_tasks.py index ec8f5e8dac48..f2e77afe25f5 100644 --- a/tests/unit/malware/test_tasks.py +++ b/tests/unit/malware/test_tasks.py @@ -85,9 +85,7 @@ def test_disabled_check(self, db_session, monkeypatch): file = FileFactory.create() - tasks.run_check( - task, request, "ExampleHookedCheck", obj_id=file.id, - ) + tasks.run_check(task, request, "ExampleHookedCheck", obj_id=file.id) assert request.log.info.calls == [ pretend.call("Check ExampleHookedCheck isn't active. Aborting.") @@ -98,9 +96,7 @@ def test_missing_check(self, db_request, monkeypatch): task = pretend.stub() with pytest.raises(AttributeError): - tasks.run_check( - task, db_request, "DoesNotExistCheck", - ) + tasks.run_check(task, db_request, "DoesNotExistCheck") def test_missing_obj_id(self, db_session, monkeypatch): monkeypatch.setattr(tasks, "checks", test_checks) diff --git a/warehouse/malware/checks/__init__.py b/warehouse/malware/checks/__init__.py index fa0607f15913..ea686b348b5e 100644 --- a/warehouse/malware/checks/__init__.py +++ b/warehouse/malware/checks/__init__.py @@ -10,4 +10,5 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .package_turnover import PackageTurnoverCheck # noqa from .setup_patterns import SetupPatternCheck # noqa diff --git a/warehouse/malware/checks/package_turnover/__init__.py b/warehouse/malware/checks/package_turnover/__init__.py new file mode 100644 index 000000000000..e3d7d35259ee --- /dev/null +++ b/warehouse/malware/checks/package_turnover/__init__.py @@ -0,0 +1,13 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .check import PackageTurnoverCheck # noqa diff --git a/warehouse/malware/checks/package_turnover/check.py b/warehouse/malware/checks/package_turnover/check.py new file mode 100644 index 000000000000..91fdb513958a --- /dev/null +++ b/warehouse/malware/checks/package_turnover/check.py @@ -0,0 +1,112 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime, timedelta +from textwrap import dedent + +from warehouse.accounts.models import UserEvent +from warehouse.malware.checks.base import MalwareCheckBase +from warehouse.malware.models import ( + MalwareVerdict, + VerdictClassification, + VerdictConfidence, +) +from warehouse.packaging.models import ProjectEvent, Release + + +class PackageTurnoverCheck(MalwareCheckBase): + version = 1 + short_description = "A check for unusual changes in package ownership" + long_description = dedent( + """ + This check looks at recently uploaded releases and determines + whether their owners have recently changed or decreased the security + of their accounts (e.g., by disabling 2FA). + """ + ) + check_type = "scheduled" + schedule = {"minute": 0, "hour": 0} + + def __init__(self, db): + super().__init__(db) + self._scan_interval = datetime.utcnow() - timedelta(hours=24) + + def user_posture_verdicts(self, project): + for user in project.users: + has_removed_2fa_method = self.db.query( + self.db.query(UserEvent) + .filter(UserEvent.user_id == user.id) + .filter(UserEvent.time >= self._scan_interval) + .filter(UserEvent.tag == "account:two_factor:method_removed") + .exists() + ).scalar() + + if has_removed_2fa_method and not user.has_two_factor: + self.add_verdict( + project_id=project.id, + classification=VerdictClassification.Threat, + confidence=VerdictConfidence.High, + message="User with control over this package has disabled 2FA", + ) + + def user_turnover_verdicts(self, project): + # NOTE: This could probably be more involved to check for the case + # where someone adds themself, removes the real maintainers, pushes a malicious + # release, then reverts the ownership to the original maintainers and removes + # themself again. + recent_role_adds = ( + self.db.query(ProjectEvent.additional) + .filter(ProjectEvent.project_id == project.id) + .filter(ProjectEvent.time >= self._scan_interval) + .filter(ProjectEvent.tag == "project:role:add") + .all() + ) + + added_users = {role_add["target_user"] for role_add, in recent_role_adds} + current_users = {user.username for user in project.users} + + if added_users == current_users: + self.add_verdict( + project_id=project.id, + classification=VerdictClassification.Threat, + confidence=VerdictConfidence.High, + message="Suspicious user turnover; all current maintainers are new", + ) + + def scan(self, **kwargs): + prior_verdicts = ( + self.db.query(MalwareVerdict.release_id).filter( + MalwareVerdict.check_id == self.id + ) + ).subquery() + + releases = ( + self.db.query(Release) + .filter(Release.created >= self._scan_interval) + .filter(~Release.id.in_(prior_verdicts)) + .all() + ) + + visited_project_ids = set() + for release in releases: + # Skip projects for which this is the first release, + # since we need a baseline to compare against + if len(release.project.releases) < 2: + continue + + if release.project.id in visited_project_ids: + continue + + visited_project_ids.add(release.project.id) + + self.user_posture_verdicts(release.project) + self.user_turnover_verdicts(release.project)
    Investigate Object Check Classification
    + + Detail + + {% include 'object_link.html' %} @@ -104,9 +110,15 @@ - - Detail - + {% if verdict.manually_reviewed %} + Marked as {{ verdict.reviewer_verdict.value }} + {% else %} +
    + + + +
    + {% endif %}