Skip to content

Commit 69f489d

Browse files
authored
refactor: add type hints in version_scanner.py (#1581)
* part of #1539
1 parent ca08460 commit 69f489d

File tree

2 files changed

+58
-42
lines changed

2 files changed

+58
-42
lines changed

cve_bin_tool/util.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,11 @@ class ProductInfo(NamedTuple):
6565
version: str
6666

6767

68+
class ScanInfo(NamedTuple):
69+
product_info: ProductInfo
70+
file_path: str
71+
72+
6873
class VersionInfo(NamedTuple):
6974
start_including: str
7075
start_excluding: str

cve_bin_tool/version_scanner.py

Lines changed: 53 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,26 @@
11
# Copyright (C) 2021 Intel Corporation
22
# SPDX-License-Identifier: GPL-3.0-or-later
3+
from __future__ import annotations
34

45
import json
56
import os
67
import subprocess
78
import sys
9+
from logging import Logger
810
from re import MULTILINE, compile, search
9-
from typing import List
11+
from typing import Iterator
1012

1113
import defusedxml.ElementTree as ET
1214

15+
from cve_bin_tool.checkers import Checker
1316
from cve_bin_tool.cvedb import CVEDB
1417
from cve_bin_tool.egg_updater import IS_DEVELOP, update_egg
1518
from cve_bin_tool.error_handler import ErrorMode
16-
from cve_bin_tool.extractor import Extractor
19+
from cve_bin_tool.extractor import Extractor, TempDirExtractorContext
1720
from cve_bin_tool.file import is_binary
1821
from cve_bin_tool.log import LOGGER
1922
from cve_bin_tool.strings import Strings
20-
from cve_bin_tool.util import DirWalk, ProductInfo, inpath
23+
from cve_bin_tool.util import DirWalk, ProductInfo, ScanInfo, inpath
2124

2225
if sys.version_info >= (3, 8):
2326
from importlib import metadata as importlib_metadata
@@ -36,12 +39,12 @@ class VersionScanner:
3639

3740
def __init__(
3841
self,
39-
should_extract=False,
40-
exclude_folders=[],
41-
checkers=None,
42-
logger=None,
43-
error_mode=ErrorMode.TruncTrace,
44-
score=0,
42+
should_extract: bool = False,
43+
exclude_folders: list[str] = [],
44+
checkers: dict[str, type[Checker]] | None = None,
45+
logger: Logger | None = None,
46+
error_mode: ErrorMode = ErrorMode.TruncTrace,
47+
score: int = 0,
4548
):
4649
self.logger = logger or LOGGER.getChild(self.__class__.__name__)
4750
# Update egg if installed in development mode
@@ -62,13 +65,13 @@ def __init__(
6265
)
6366
).walk
6467
self.should_extract = should_extract
65-
self.file_stack = []
68+
self.file_stack: list[str] = []
6669
self.error_mode = error_mode
6770
self.cve_db = CVEDB()
6871
# self.logger.info("Checkers loaded: %s" % (", ".join(self.checkers.keys())))
6972

7073
@classmethod
71-
def load_checkers(cls):
74+
def load_checkers(cls) -> dict[str, type[Checker]]:
7275
"""Loads CVE checkers"""
7376
checkers = dict(
7477
map(
@@ -79,12 +82,12 @@ def load_checkers(cls):
7982
return checkers
8083

8184
@classmethod
82-
def available_checkers(cls):
85+
def available_checkers(cls) -> list[str]:
8386
checkers = importlib_metadata.entry_points()[cls.CHECKER_ENTRYPOINT]
8487
checker_list = [item.name for item in checkers]
8588
return checker_list
8689

87-
def remove_skiplist(self, skips):
90+
def remove_skiplist(self, skips: list[str]) -> None:
8891
# Take out any checkers that are on the skip list
8992
# (string of comma-delimited checker names)
9093
skiplist = skips
@@ -95,20 +98,21 @@ def remove_skiplist(self, skips):
9598
else:
9699
self.logger.error(f"Checker {skipme} is not a valid checker name")
97100

98-
def print_checkers(self):
101+
def print_checkers(self) -> None:
99102
self.logger.info(f'Checkers: {", ".join(self.checkers.keys())}')
100103

101-
def number_of_checkers(self):
104+
def number_of_checkers(self) -> int:
102105
return len(self.checkers)
103106

104-
def is_executable(self, filename):
107+
def is_executable(self, filename: str) -> tuple[bool, str | None]:
105108
"""check if file is an ELF binary file"""
106109

107-
output = None
110+
output: str | None = None
108111
if inpath("file"):
109112
# use system file if available (for performance reasons)
110-
output = subprocess.check_output(["file", filename])
111-
output = output.decode(sys.stdout.encoding)
113+
output = subprocess.check_output(["file", filename]).decode(
114+
sys.stdout.encoding
115+
)
112116

113117
if "cannot open" in output:
114118
self.logger.warning(f"Unopenable file {filename} cannot be scanned")
@@ -133,7 +137,7 @@ def is_executable(self, filename):
133137

134138
return True, output
135139

136-
def parse_strings(self, filename):
140+
def parse_strings(self, filename: str) -> str:
137141
"""parse binary file's strings"""
138142

139143
if inpath("strings"):
@@ -145,7 +149,7 @@ def parse_strings(self, filename):
145149
lines = s.parse()
146150
return lines
147151

148-
def scan_file(self, filename):
152+
def scan_file(self, filename: str) -> Iterator[ScanInfo]:
149153
"""Scans a file to see if it contains any of the target libraries,
150154
and whether any of those contain CVEs"""
151155

@@ -185,7 +189,9 @@ def scan_file(self, filename):
185189

186190
yield from self.run_checkers(filename, lines)
187191

188-
def find_java_vendor(self, product, version):
192+
def find_java_vendor(
193+
self, product: str, version: str
194+
) -> tuple[ProductInfo, str] | tuple[None, None]:
189195
"""Find vendor for Java product"""
190196
vendor_package_pair = self.cve_db.get_vendor_product_pairs(product)
191197
# If no match, try alternative product name.
@@ -205,7 +211,7 @@ def find_java_vendor(self, product, version):
205211
return ProductInfo(vendor, product, version), file_path
206212
return None, None
207213

208-
def run_java_checker(self, filename: str) -> None:
214+
def run_java_checker(self, filename: str) -> Iterator[ScanInfo]:
209215
"""Process maven pom.xml file and extract product and dependency details"""
210216
tree = ET.parse(filename)
211217
# Find root element
@@ -231,7 +237,7 @@ def run_java_checker(self, filename: str) -> None:
231237
if product is not None and version is not None:
232238
product_info, file_path = self.find_java_vendor(product, version)
233239
if file_path is not None:
234-
yield product_info, file_path
240+
yield ScanInfo(product_info, file_path)
235241

236242
# Scan for any dependencies referenced in file
237243
dependencies = root.find(schema + "dependencies")
@@ -249,16 +255,16 @@ def run_java_checker(self, filename: str) -> None:
249255
product.text, version
250256
)
251257
if file_path is not None:
252-
yield product_info, file_path
258+
yield ScanInfo(product_info, file_path)
253259

254260
self.logger.debug(f"Done scanning file: {filename}")
255261

256-
def find_js_vendor(self, product: str, version: str) -> List[List[str]]:
262+
def find_js_vendor(self, product: str, version: str) -> list[ScanInfo] | None:
257263
"""Find vendor for Javascript product"""
258264
if version == "*":
259265
return None
260266
vendor_package_pair = self.cve_db.get_vendor_product_pairs(product)
261-
vendorlist: List[List[str]] = []
267+
vendorlist: list[ScanInfo] = []
262268
if vendor_package_pair != []:
263269
# To handle multiple vendors, return all combinations of product/vendor mappings
264270
for v in vendor_package_pair:
@@ -268,20 +274,21 @@ def find_js_vendor(self, product: str, version: str) -> List[List[str]]:
268274
if "^" in version:
269275
version = version[1:]
270276
self.logger.debug(f"{file_path} {product} {version} by {vendor}")
271-
vendorlist.append([ProductInfo(vendor, product, version), file_path])
277+
vendorlist.append(
278+
ScanInfo(ProductInfo(vendor, product, version), file_path)
279+
)
272280
return vendorlist if len(vendorlist) > 0 else None
273281
return None
274282

275-
def run_js_checker(self, filename: str) -> None:
283+
def run_js_checker(self, filename: str) -> Iterator[ScanInfo]:
276284
"""Process package-lock.json file and extract product and dependency details"""
277285
fh = open(filename)
278286
data = json.load(fh)
279287
product = data["name"]
280288
version = data["version"]
281289
vendor = self.find_js_vendor(product, version)
282290
if vendor is not None:
283-
for v in vendor:
284-
yield v[0], v[1] # product_info, file_path
291+
yield from vendor
285292
# Now process dependencies
286293
for i in data["dependencies"]:
287294
# To handle @actions/<product>: lines, extract product name from line
@@ -299,20 +306,20 @@ def run_js_checker(self, filename: str) -> None:
299306
version = data["dependencies"][i]
300307
vendor = self.find_js_vendor(product, version)
301308
if vendor is not None:
302-
for v in vendor:
303-
yield v[0], v[1] # product_info, file_path
309+
yield from vendor
304310
if "requires" in data["dependencies"][i]:
305311
for r in data["dependencies"][i]["requires"]:
306312
# To handle @actions/<product>: lines, extract product name from line
307313
product = r.split("/")[1] if "/" in r else r
308314
version = data["dependencies"][i]["requires"][r]
309315
vendor = self.find_js_vendor(product, version)
310316
if vendor is not None:
311-
for v in vendor:
312-
yield v[0], v[1] # product_info, file_path
317+
yield from vendor
313318
self.logger.debug(f"Done scanning file: {filename}")
314319

315-
def run_python_package_checkers(self, filename, lines):
320+
def run_python_package_checkers(
321+
self, filename: str, lines: str
322+
) -> Iterator[ScanInfo]:
316323
"""
317324
This generator runs only for python packages.
318325
There are no actual checkers.
@@ -331,15 +338,15 @@ def run_python_package_checkers(self, filename, lines):
331338

332339
self.logger.info(f"{file_path} is {product} {version}")
333340

334-
yield ProductInfo(vendor, product, version), file_path
341+
yield ScanInfo(ProductInfo(vendor, product, version), file_path)
335342

336343
# There are packages with a METADATA file in them containing different data from what the tool expects
337344
except AttributeError:
338345
self.logger.debug(f"{filename} is an invalid METADATA/PKG-INFO")
339346

340347
self.logger.debug(f"Done scanning file: {filename}")
341348

342-
def run_checkers(self, filename, lines):
349+
def run_checkers(self, filename: str, lines: str) -> Iterator[ScanInfo]:
343350
# tko
344351
for (dummy_checker_name, checker) in self.checkers.items():
345352
checker = checker()
@@ -370,12 +377,14 @@ def run_checkers(self, filename, lines):
370377
f'{file_path} {result["is_or_contains"]} {dummy_checker_name} {version}'
371378
)
372379
for vendor, product in checker.VENDOR_PRODUCT:
373-
yield ProductInfo(vendor, product, version), file_path
380+
yield ScanInfo(
381+
ProductInfo(vendor, product, version), file_path
382+
)
374383

375384
self.logger.debug(f"Done scanning file: {filename}")
376385

377386
@staticmethod
378-
def clean_file_path(filepath):
387+
def clean_file_path(filepath: str) -> str:
379388
"""Returns a cleaner filepath by removing temp path from filepath"""
380389

381390
# we'll recieve a filepath similar to
@@ -387,7 +396,9 @@ def clean_file_path(filepath):
387396
start_point = filepath.find("extracted") + 9
388397
return filepath[start_point:]
389398

390-
def scan_and_or_extract_file(self, ectx, filepath):
399+
def scan_and_or_extract_file(
400+
self, ectx: TempDirExtractorContext, filepath: str
401+
) -> Iterator[ScanInfo]:
391402
"""Runs extraction if possible and desired otherwise scans."""
392403
# Scan the file
393404
yield from self.scan_file(filepath)
@@ -404,7 +415,7 @@ def scan_and_or_extract_file(self, ectx, filepath):
404415
yield from self.scan_and_or_extract_file(ectx, filename)
405416
self.file_stack.pop()
406417

407-
def recursive_scan(self, scan_path):
418+
def recursive_scan(self, scan_path: str) -> Iterator[ScanInfo]:
408419
with Extractor(logger=self.logger, error_mode=self.error_mode) as ectx:
409420
if os.path.isdir(scan_path):
410421
for filepath in self.walker([scan_path]):

0 commit comments

Comments
 (0)