Skip to content

Refactor testing logic #7098 #7257

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions tests/common/checks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .hooked import ExampleHookedCheck # noqa
from .scheduled import ExampleScheduledCheck # noqa
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,19 @@
from warehouse.malware.models import VerdictClassification, VerdictConfidence


class ExampleCheck(MalwareCheckBase):
class ExampleHookedCheck(MalwareCheckBase):

version = 1
short_description = "An example hook-based check"
long_description = """The purpose of this check is to demonstrate the \
implementation of a hook-based check. This check will generate verdicts if enabled."""
long_description = "The purpose of this check is to test the \
implementation of a hook-based check. This check will generate verdicts if enabled."
check_type = "event_hook"
hooked_object = "File"

def __init__(self, db):
super().__init__(db)

def scan(self, file_id):
def scan(self, file_id=None):
self.add_verdict(
file_id=file_id,
classification=VerdictClassification.benign,
Expand Down
37 changes: 37 additions & 0 deletions tests/common/checks/scheduled.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from warehouse.malware.checks.base import MalwareCheckBase
from warehouse.malware.models import VerdictClassification, VerdictConfidence
from warehouse.packaging.models import Project


class ExampleScheduledCheck(MalwareCheckBase):

version = 1
short_description = "An example scheduled check"
long_description = "The purpose of this check is to test the \
implementation of a scheduled check. This check will generate verdicts if enabled."
check_type = "scheduled"
schedule = {"minute": "0", "hour": "*/8"}

def __init__(self, db):
super().__init__(db)

def scan(self):
project = self.db.query(Project).first()
self.add_verdict(
project_id=project.id,
classification=VerdictClassification.benign,
confidence=VerdictConfidence.High,
message="Nothing to see here!",
)
1 change: 1 addition & 0 deletions tests/common/db/malware.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class Meta:
long_description = factory.fuzzy.FuzzyText(length=300)
check_type = factory.fuzzy.FuzzyChoice(list(MalwareCheckType))
hooked_object = factory.fuzzy.FuzzyChoice(list(MalwareCheckObjectType))
schedule = {"minute": "*/10"}
state = factory.fuzzy.FuzzyChoice(list(MalwareCheckState))
created = factory.fuzzy.FuzzyNaiveDateTime(
datetime.datetime.utcnow() - datetime.timedelta(days=7)
Expand Down
19 changes: 14 additions & 5 deletions tests/unit/malware/test_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,35 @@

import inspect

import warehouse.malware.checks as checks
import pytest

import warehouse.malware.checks as prod_checks

from warehouse.malware.checks.base import MalwareCheckBase
from warehouse.malware.utils import get_check_fields

from ...common import checks as test_checks


def test_checks_subclass_base():
checks_from_module = inspect.getmembers(checks, inspect.isclass)
prod_checks_from_module = inspect.getmembers(prod_checks, inspect.isclass)
test_checks_from_module = inspect.getmembers(test_checks, inspect.isclass)
all_checks = prod_checks_from_module + test_checks_from_module

subclasses_of_malware_base = {
cls.__name__: cls for cls in MalwareCheckBase.__subclasses__()
}

assert len(checks_from_module) == len(subclasses_of_malware_base)
assert len(all_checks) == len(subclasses_of_malware_base)

for check_name, check in checks_from_module:
for check_name, check in all_checks:
assert subclasses_of_malware_base[check_name] == check


def test_checks_fields():
@pytest.mark.parametrize(
("checks"), [prod_checks, test_checks],
)
def test_checks_fields(checks):
checks_from_module = inspect.getmembers(checks, inspect.isclass)

for check_name, check in checks_from_module:
Expand Down
25 changes: 14 additions & 11 deletions tests/unit/malware/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,16 @@
from warehouse.malware import utils
from warehouse.malware.interfaces import IMalwareCheckService

from ...common import checks as test_checks
from ...common.db.accounts import UserFactory
from ...common.db.packaging import FileFactory, ProjectFactory, ReleaseFactory


def test_determine_malware_checks_no_checks(monkeypatch, db_request):
def get_enabled_checks(session):
def get_enabled_hooked_checks(session):
return defaultdict(list)

monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks)
monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks)

project = ProjectFactory.create(name="foo")
release = ReleaseFactory.create(project=project)
Expand All @@ -39,13 +40,13 @@ def get_enabled_checks(session):


def test_determine_malware_checks_nothing_new(monkeypatch, db_request):
def get_enabled_checks(session):
def get_enabled_hooked_checks(session):
result = defaultdict(list)
result["File"] = ["Check1", "Check2"]
result["Release"] = ["Check3"]
return result

monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks)
monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks)

project = ProjectFactory.create(name="foo")
release = ReleaseFactory.create(project=project)
Expand All @@ -58,13 +59,13 @@ def get_enabled_checks(session):


def test_determine_malware_checks_unsupported_object(monkeypatch, db_request):
def get_enabled_checks(session):
def get_enabled_hooked_checks(session):
result = defaultdict(list)
result["File"] = ["Check1", "Check2"]
result["Release"] = ["Check3"]
return result

monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks)
monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks)

user = UserFactory.create()

Expand All @@ -75,13 +76,13 @@ def get_enabled_checks(session):


def test_determine_malware_checks_file_only(monkeypatch, db_request):
def get_enabled_checks(session):
def get_enabled_hooked_checks(session):
result = defaultdict(list)
result["File"] = ["Check1", "Check2"]
result["Release"] = ["Check3"]
return result

monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks)
monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks)

project = ProjectFactory.create(name="foo")
release = ReleaseFactory.create(project=project)
Expand All @@ -95,13 +96,13 @@ def get_enabled_checks(session):


def test_determine_malware_checks_file_and_release(monkeypatch, db_request):
def get_enabled_checks(session):
def get_enabled_hooked_checks(session):
result = defaultdict(list)
result["File"] = ["Check1", "Check2"]
result["Release"] = ["Check3"]
return result

monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks)
monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks)

project = ProjectFactory.create(name="foo")
release = ReleaseFactory.create(project=project)
Expand Down Expand Up @@ -149,7 +150,9 @@ def test_enqueue_malware_checks_no_checks(app_config):
assert "warehouse.malware.checks" not in session.info


def test_includeme():
def test_includeme(monkeypatch):
monkeypatch.setattr(malware, "checks", test_checks)

malware_check_class = pretend.stub(
create_service=pretend.call_recorder(lambda *a, **kw: pretend.stub())
)
Expand Down
Loading