diff --git a/tuf/client_rework/fetcher.py b/tuf/client_rework/fetcher.py
new file mode 100644
index 0000000000..8a6cae34d7
--- /dev/null
+++ b/tuf/client_rework/fetcher.py
@@ -0,0 +1,41 @@
+# Copyright 2021, New York University and the TUF contributors
+# SPDX-License-Identifier: MIT OR Apache-2.0
+
+"""Provides an interface for network IO abstraction.
+"""
+
+# Imports
+import abc
+
+
+# Classes
+class FetcherInterface:
+    """Defines an interface for abstract network download.
+
+    By providing a concrete implementation of the abstract interface,
+    users of the framework can plug-in their preferred/customized
+    network stack.
+    """
+
+    __metaclass__ = abc.ABCMeta
+
+    @abc.abstractmethod
+    def fetch(self, url, required_length):
+        """Fetches the contents of HTTP/HTTPS url from a remote server.
+
+        Ensures the length of the downloaded data is up to 'required_length'.
+
+        Arguments:
+            url: A URL string that represents a file location.
+            required_length: An integer value representing the file length in
+                bytes.
+
+        Raises:
+            tuf.exceptions.SlowRetrievalError: A timeout occurs while receiving
+                data.
+            tuf.exceptions.FetcherHTTPError: An HTTP error code is received.
+
+        Returns:
+            A bytes iterator
+        """
+        raise NotImplementedError  # pragma: no cover
diff --git a/tuf/client_rework/metadata_wrapper.py b/tuf/client_rework/metadata_wrapper.py
index 6f182dc336..18f0d6d9aa 100644
--- a/tuf/client_rework/metadata_wrapper.py
+++ b/tuf/client_rework/metadata_wrapper.py
@@ -9,7 +9,7 @@
 
 from securesystemslib.keys import format_metadata_to_key
 
-import tuf.exceptions
+from tuf import exceptions, formats
 from tuf.api import metadata
 
 
@@ -64,7 +64,7 @@ def verify(self, keys, threshold):
             verified += 1
 
         if verified < threshold:
-            raise tuf.exceptions.InsufficientKeysError
+            raise exceptions.InsufficientKeysError
 
     def persist(self, filename):
         """
@@ -77,13 +77,13 @@ def expires(self, reference_time=None):
         TODO
         """
         if reference_time is None:
-            expires_timestamp = tuf.formats.datetime_to_unix_timestamp(
+            expires_timestamp = formats.datetime_to_unix_timestamp(
                 self._meta.signed.expires
             )
             reference_time = int(time.time())
 
         if expires_timestamp < reference_time:
-            raise tuf.exceptions.ExpiredMetadataError
+            raise exceptions.ExpiredMetadataError
 
 
 class RootWrapper(MetadataWrapper):
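FetcherInterface is the seam this change introduces for swapping out the network stack: any object whose fetch() returns a bytes iterator bounded by required_length can be plugged into the client. As a rough illustration (not part of this diff), a standard-library fetcher might look like the sketch below; the class name UrllibFetcher and the 8 KiB read size are assumptions made for the example.

```python
# Illustrative sketch only -- not part of this change. Shows one way a user
# of the framework could satisfy FetcherInterface with urllib instead of
# Requests, following the same chunked-read contract.
import urllib.error
import urllib.request

from tuf import exceptions
from tuf.client_rework.fetcher import FetcherInterface


class UrllibFetcher(FetcherInterface):  # hypothetical example class
    def fetch(self, url, required_length):
        try:
            response = urllib.request.urlopen(url)
        except urllib.error.HTTPError as e:
            # Translate HTTP errors the same way RequestsFetcher does.
            raise exceptions.FetcherHTTPError(str(e), e.code)

        def chunks():
            bytes_received = 0
            try:
                while bytes_received < required_length:
                    # Never read past the required_length upper bound.
                    data = response.read(
                        min(8192, required_length - bytes_received)
                    )
                    if not data:
                        break
                    bytes_received += len(data)
                    yield data
            finally:
                response.close()

        return chunks()
```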
diff --git a/tuf/client_rework/mirrors_download.py b/tuf/client_rework/mirrors_download.py
new file mode 100644
index 0000000000..03f6006795
--- /dev/null
+++ b/tuf/client_rework/mirrors_download.py
@@ -0,0 +1,430 @@
+#!/usr/bin/env python
+
+# Copyright 2012 - 2017, New York University and the TUF contributors
+# SPDX-License-Identifier: MIT OR Apache-2.0
+
+"""
+<Program Name>
+  mirrors.py
+
+<Author>
+  Konstantin Andrianov.
+  Derived from original mirrors.py written by Geremy Condra.
+
+<Started>
+  March 12, 2012.
+
+<Copyright>
+  See LICENSE-MIT OR LICENSE for licensing information.
+
+<Purpose>
+  Extract a list of mirror urls corresponding to the file type and the location
+  of the file with respect to the base url.
+"""
+
+
+# Help with Python 3 compatibility, where the print statement is a function, an
+# implicit relative import is invalid, and the '/' operator performs true
+# division.  Example:  print 'hello world' raises a 'SyntaxError' exception.
+from __future__ import (
+    absolute_import,
+    division,
+    print_function,
+    unicode_literals,
+)
+
+import logging
+import os
+import tempfile
+import timeit
+from typing import BinaryIO, Dict, Optional, TextIO
+
+import securesystemslib
+import six
+from securesystemslib import exceptions, util
+
+import tuf
+from tuf import formats
+from tuf.requests_fetcher import RequestsFetcher
+
+# See 'log.py' to learn how logging is handled in TUF.
+logger = logging.getLogger(__name__)
+
+# The type of file to be downloaded from a repository.  The
+# 'get_list_of_mirrors' function supports these file types.
+_SUPPORTED_FILE_TYPES = ["meta", "target"]
+
+
+class Mirrors:
+    """TODO"""
+
+    def __init__(
+        self, mirrors_dict: Dict, fetcher: Optional["FetcherInterface"] = None
+    ):
+        formats.MIRRORDICT_SCHEMA.check_match(mirrors_dict)
+        self._config = mirrors_dict
+
+        if fetcher is None:
+            self._fetcher = RequestsFetcher()
+        else:
+            self._fetcher = fetcher
+
+    def _get_list_of_mirrors(self, file_type, file_path):
+        """
+        <Purpose>
+          Get a list of mirror urls from a mirrors dictionary, provided the
+          type and the path of the file with respect to the base url.
+
+        <Arguments>
+          file_type:
+            Type of data needed for download, must correspond to one of the
+            strings in the list ['meta', 'target'].  'meta' for metadata file
+            type or 'target' for target file type.  It should correspond to
+            NAME_SCHEMA format.
+
+          file_path:
+            A relative path to the file that corresponds to RELPATH_SCHEMA
+            format.  Ex: 'http://url_prefix/targets_path/file_path'
+
+        <Exceptions>
+          securesystemslib.exceptions.Error, on unsupported 'file_type'.
+
+          securesystemslib.exceptions.FormatError, on bad argument.
+
+        <Returns>
+          List of mirror urls corresponding to the file_type and file_path.
+          If no match is found, empty list is returned.
+        """
+
+        # Checking if all the arguments have appropriate format.
+        formats.RELPATH_SCHEMA.check_match(file_path)
+        securesystemslib.formats.NAME_SCHEMA.check_match(file_type)
+
+        # Verify 'file_type' is supported.
+        if file_type not in _SUPPORTED_FILE_TYPES:
+            raise exceptions.Error(
+                "Invalid file_type argument."
+                " Supported file types: " + repr(_SUPPORTED_FILE_TYPES)
+            )
+        path_key = "metadata_path" if file_type == "meta" else "targets_path"
+
+        list_of_mirrors = []
+        for dummy, mirror_info in six.iteritems(self._config):
+            # Does mirror serve this file type at all?
+            path = mirror_info.get(path_key)
+            if path is None:
+                continue
+
+            # for targets, ensure directory confinement
+            if path_key == "targets_path":
+                full_filepath = os.path.join(path, file_path)
+                confined_target_dirs = mirror_info.get("confined_target_dirs")
+                # confined_target_dirs is optional and can be used to confine
+                # the client to certain paths on a repository mirror when
+                # fetching target files.
+                if (
+                    confined_target_dirs
+                    and not util.file_in_confined_directories(
+                        full_filepath, confined_target_dirs
+                    )
+                ):
+                    continue
+
+            # urllib.quote(string) replaces special characters in string using
+            # the %xx escape.  This is done to avoid parsing issues of the URL
+            # on the server side. Do *NOT* pass URLs with Unicode characters
+            # without first encoding the URL as UTF-8.  We need a long-term
+            # solution with #61.  http://bugs.python.org/issue1712522
+            file_path = six.moves.urllib.parse.quote(file_path)
+            url = os.path.join(mirror_info["url_prefix"], path, file_path)
+
+            # The above os.path.join() result as well as input file_path may be
+            # invalid on windows (might contain both separator types),
+            # see #1077.
+            # Make sure the URL doesn't contain backward slashes on Windows.
+            list_of_mirrors.append(url.replace("\\", "/"))
+
+        return list_of_mirrors
+
+    def meta_download(self, filename: str, upper_length: int) -> TextIO:
+        """
+        Download metadata file from the list of metadata mirrors
+        """
+        file_mirrors = self._get_list_of_mirrors("meta", filename)
+
+        file_mirror_errors = {}
+        for file_mirror in file_mirrors:
+            try:
+                temp_obj = self._download_file(
+                    file_mirror,
+                    upper_length,
+                    strict_required_length=False,
+                )
+
+                temp_obj.seek(0)
+                yield temp_obj
+
+            # pylint cannot figure out that we store the exceptions
+            # in a dictionary to raise them later so we disable
+            # the warning. This should be reviewed in the future still.
+            except Exception as exception:  # pylint: disable=broad-except
+                file_mirror_errors[file_mirror] = exception
+
+            finally:
+                if file_mirror_errors:
+                    raise tuf.exceptions.NoWorkingMirrorError(
+                        file_mirror_errors
+                    )
+
+    def target_download(self, filename: str, strict_length: int) -> BinaryIO:
+        """
+        Download target file from the list of target mirrors
+        """
+        file_mirrors = self._get_list_of_mirrors("target", filename)
+
+        file_mirror_errors = {}
+        for file_mirror in file_mirrors:
+            try:
+                temp_obj = self._download_file(file_mirror, strict_length)
+
+                temp_obj.seek(0)
+                yield temp_obj
+
+            # pylint cannot figure out that we store the exceptions
+            # in a dictionary to raise them later so we disable
+            # the warning. This should be reviewed in the future still.
+            except Exception as exception:  # pylint: disable=broad-except
+                file_mirror_errors[file_mirror] = exception
+
+            finally:
+                if file_mirror_errors:
+                    raise tuf.exceptions.NoWorkingMirrorError(
+                        file_mirror_errors
+                    )
+
+    def _download_file(self, url, required_length, strict_required_length=True):
+        """
+        <Purpose>
+          Given the url and length of the desired file, this function opens a
+          connection to 'url' and downloads the file while ensuring its length
+          matches 'required_length' if 'strict_required_length' is True (If False,
+          the file's length is not checked and a slow retrieval exception is
+          raised if the downloaded rate falls below the acceptable rate).
+
+        <Arguments>
+          url:
+            A URL string that represents the location of the file.
+
+          required_length:
+            An integer value representing the length of the file.
+
+          strict_required_length:
+            A Boolean indicator used to signal whether we should perform strict
+            checking of required_length. True by default. We explicitly set this
+            to False when we know that we want to turn this off for downloading
+            the timestamp metadata, which has no signed required_length.
+
+        <Side Effects>
+          A file object is created on disk to store the contents of 'url'.
+
+        <Exceptions>
+          tuf.exceptions.DownloadLengthMismatchError, if there was a
+          mismatch of observed vs expected lengths while downloading the file.
+
+          securesystemslib.exceptions.FormatError, if any of the arguments are
+          improperly formatted.
+
+          Any other unforeseen runtime exception.
+
+        <Returns>
+          A file object that points to the contents of 'url'.
+        """
+        # Do all of the arguments have the appropriate format?
+        # Raise 'securesystemslib.exceptions.FormatError' if there is
+        # a mismatch.
+        securesystemslib.formats.URL_SCHEMA.check_match(url)
+        formats.LENGTH_SCHEMA.check_match(required_length)
+
+        # 'url.replace('\\', '/')' is needed for compatibility with
+        # Windows-based systems, because they might use back-slashes in place
+        # of forward-slashes. This converts it to the common format. unquote()
+        # replaces %xx escapes in a url with their single-character equivalent.
+        # A back-slash may be encoded as %5c in the url, which should also be
+        # replaced with a forward slash.
+        url = six.moves.urllib.parse.unquote(url).replace("\\", "/")
+        msg = f"Downloading: {url}"
+        logger.info(msg)
+
+        # This is the temporary file that we will return to contain the
+        # contents of the downloaded file.
+        temp_file = tempfile.TemporaryFile()
+
+        average_download_speed = 0
+        number_of_bytes_received = 0
+
+        try:
+            chunks = self._fetcher.fetch(url, required_length)
+            start_time = timeit.default_timer()
+            for chunk in chunks:
+
+                stop_time = timeit.default_timer()
+                temp_file.write(chunk)
+
+                # Measure the average download speed.
+                number_of_bytes_received += len(chunk)
+                seconds_spent_receiving = stop_time - start_time
+                average_download_speed = (
+                    number_of_bytes_received / seconds_spent_receiving
+                )
+
+                if (
+                    average_download_speed
+                    < tuf.settings.MIN_AVERAGE_DOWNLOAD_SPEED
+                ):
+                    logger.debug(
+                        "The average download speed dropped below the minimum"
+                        " average download speed set in tuf.settings.py. "
+                        " Stopping the download!"
+                    )
+                    break
+
+                logger.debug(
+                    "The average download speed has not dipped below the"
+                    " minimum average download speed set"
+                    " in tuf.settings.py."
+                )
+
+            # Does the total number of downloaded bytes match the required
+            # length?
+            self._check_downloaded_length(
+                number_of_bytes_received,
+                required_length,
+                strict_required_length=strict_required_length,
+                average_download_speed=average_download_speed,
+            )
+
+        except Exception:
+            # Close 'temp_file'. Any written data is lost.
+            temp_file.close()
+            msg = f"Could not download URL: {url}"
+            logger.debug(msg)
+            raise
+
+        else:
+            return temp_file
+
+    @staticmethod
+    def _check_downloaded_length(
+        total_downloaded,
+        required_length,
+        strict_required_length=True,
+        average_download_speed=None,
+    ):
+        """
+        <Purpose>
+          A helper function which checks whether the total number of downloaded
+          bytes matches our expectation.
+
+        <Arguments>
+          total_downloaded:
+            The total number of bytes supposedly downloaded for the file in
+            question.
+
+          required_length:
+            The total number of bytes expected of the file as seen from its
+            metadata. The Timestamp role is always downloaded without a known
+            file length, and the Root role when the client cannot download any
+            of the required top-level roles. In both cases, 'required_length'
+            is actually an upper limit on the length of the downloaded file.
+
+          strict_required_length:
+            A Boolean indicator used to signal whether we should perform strict
+            checking of required_length. True by default. We explicitly set this
+            to False when we know that we want to turn this off for downloading
+            the timestamp metadata, which has no signed required_length.
+
+          average_download_speed:
+            The average download speed for the downloaded file.
+
+        <Side Effects>
+          None.
+
+        <Exceptions>
+          securesystemslib.exceptions.DownloadLengthMismatchError, if
+          strict_required_length is True and total_downloaded is not equal
+          to required_length.
+
+          tuf.exceptions.SlowRetrievalError, if the total downloaded was
+          done in less than the acceptable download speed (as set in
+          tuf.settings.py).
+
+        <Returns>
+          None.
+        """
+
+        if total_downloaded == required_length:
+            msg = (
+                f"Downloaded {total_downloaded} bytes out of the"
+                f" expected {required_length} bytes."
+            )
+            logger.info(msg)
+
+        else:
+            difference_in_bytes = abs(total_downloaded - required_length)
+
+            # What we downloaded is not equal to the required length, but did
+            # we ask for strict checking of required length?
+            if strict_required_length:
+                msg = (
+                    f"Downloaded {total_downloaded} bytes, but expected"
+                    f" {required_length} bytes. There is a difference of"
+                    f" {difference_in_bytes} bytes."
+                )
+                logger.info(msg)
+
+                # If the average download speed is below a certain threshold,
+                # we flag this as a possible slow-retrieval attack.
+                msg = (
+                    f"Average download speed: {average_download_speed}\n"
+                    f"Minimum average download speed: "
+                    f"{tuf.settings.MIN_AVERAGE_DOWNLOAD_SPEED}"
+                )
+                logger.debug(msg)
+
+                if (
+                    average_download_speed
+                    < tuf.settings.MIN_AVERAGE_DOWNLOAD_SPEED
+                ):
+                    raise tuf.exceptions.SlowRetrievalError(
+                        average_download_speed
+                    )
+
+                msg = (
+                    f"Good average download speed: "
+                    f"{average_download_speed} bytes per second"
+                )
+                logger.debug(msg)
+
+                raise tuf.exceptions.DownloadLengthMismatchError(
+                    required_length, total_downloaded
+                )
+
+            # We specifically disabled strict checking of required length,
+            # but we will log a warning anyway. This is useful when we wish
+            # to download the Timestamp or Root metadata, for which we have
+            # no signed metadata; so, we must guess a reasonable
+            # required_length for it.
+            if average_download_speed < tuf.settings.MIN_AVERAGE_DOWNLOAD_SPEED:
+                raise tuf.exceptions.SlowRetrievalError(average_download_speed)
+
+            msg = (
+                f"Good average download speed: "
+                f"{average_download_speed} bytes per second"
+            )
+            logger.debug(msg)
+
+            msg = (
+                f"Downloaded {total_downloaded} bytes out of an "
+                f"upper limit of {required_length} bytes."
+            )
+            logger.info(msg)
diff --git a/tuf/client_rework/requests_fetcher.py b/tuf/client_rework/requests_fetcher.py
new file mode 100644
index 0000000000..37fed1f1ce
--- /dev/null
+++ b/tuf/client_rework/requests_fetcher.py
@@ -0,0 +1,188 @@
+# Copyright 2021, New York University and the TUF contributors
+# SPDX-License-Identifier: MIT OR Apache-2.0
+
+"""Provides an implementation of FetcherInterface using the Requests HTTP
+   library.
+"""
+
+import logging
+import time
+
+# Imports
+import requests
+import six
+import urllib3.exceptions as urllib3_ex
+
+import tuf
+from tuf import exceptions, settings
+from tuf.client_rework.fetcher import FetcherInterface
+
+# Globals
+logger = logging.getLogger(__name__)
+
+# Classes
+class RequestsFetcher(FetcherInterface):
+    """A concrete implementation of FetcherInterface based on the Requests
+    library.
+
+    Attributes:
+        _sessions: A dictionary of Requests.Session objects storing a separate
+            session per scheme+hostname combination.
+    """
+
+    def __init__(self):
+        # http://docs.python-requests.org/en/master/user/advanced/#session-objects:
+        #
+        # "The Session object allows you to persist certain parameters across
+        # requests. It also persists cookies across all requests made from the
+        # Session instance, and will use urllib3's connection pooling. So if
+        # you're making several requests to the same host, the underlying TCP
+        # connection will be reused, which can result in a significant
+        # performance increase (see HTTP persistent connection)."
+        #
+        # NOTE: We use a separate requests.Session per scheme+hostname
+        # combination, in order to reuse connections to the same hostname to
+        # improve efficiency, but avoid sharing state between different
+        # host+scheme combinations to minimize subtle security issues.
+        # Some cookies may not be HTTP-safe.
+        self._sessions = {}
+
+    def fetch(self, url, required_length):
+        """Fetches the contents of HTTP/HTTPS url from a remote server.
+
+        Ensures the length of the downloaded data is up to 'required_length'.
+
+        Arguments:
+            url: A URL string that represents a file location.
+            required_length: An integer value representing the file length in
+                bytes.
+
+        Raises:
+            exceptions.SlowRetrievalError: A timeout occurs while receiving
+                data.
+            exceptions.FetcherHTTPError: An HTTP error code is received.
+
+        Returns:
+            A bytes iterator
+        """
+        # Get a customized session for each new scheme+hostname combination.
+        session = self._get_session(url)
+
+        # Get the requests.Response object for this URL.
+        #
+        # Defer downloading the response body with stream=True.
+        # Always set the timeout. This timeout value is interpreted
+        # by requests as:
+        #  - connect timeout (max delay before first byte is received)
+        #  - read (gap) timeout (max delay between bytes received)
+        response = session.get(
+            url, stream=True, timeout=settings.SOCKET_TIMEOUT
+        )
+        # Check response status.
+        try:
+            response.raise_for_status()
+        except requests.HTTPError as e:
+            response.close()
+            status = e.response.status_code
+            raise exceptions.FetcherHTTPError(str(e), status)
+
+        # Define a generator function to be returned by fetch. This way the
+        # caller of fetch can differentiate between connection and actual data
+        # download and measure download times accordingly.
+        def chunks():
+            try:
+                bytes_received = 0
+                while True:
+                    # We download a fixed chunk of data in every round. This
+                    # is so that we can defend against slow retrieval attacks.
+                    # Furthermore, we do not wish to download an extremely
+                    # large file in one shot. Before beginning the round, sleep
+                    # (if set) for a short amount of time so that the CPU is
+                    # not hogged in the while loop.
+                    if settings.SLEEP_BEFORE_ROUND:
+                        time.sleep(settings.SLEEP_BEFORE_ROUND)
+
+                    read_amount = min(
+                        settings.CHUNK_SIZE,
+                        required_length - bytes_received,
+                    )
+
+                    # NOTE: This may not handle some servers adding a
+                    # Content-Encoding header, which may cause urllib3
+                    # to misbehave:
+                    # https://github.com/pypa/pip/blob/404838abcca467648180b358598c597b74d568c9/src/pip/_internal/download.py#L547-L582
+                    data = response.raw.read(read_amount)
+                    bytes_received += len(data)
+
+                    # We might have no more data to read. Check number of bytes
+                    # downloaded.
+                    msg = (
+                        f"Downloaded {bytes_received}/{required_length}"
+                        f" bytes."
+                    )
+                    if not data:
+                        logger.debug(msg)
+                        # Finally, we signal that the download is complete.
+                        break
+
+                    yield data
+
+                    if bytes_received >= required_length:
+                        break
+
+            except urllib3_ex.ReadTimeoutError as e:
+                raise exceptions.SlowRetrievalError(str(e))
+
+            finally:
+                response.close()
+
+        return chunks()
+
+    def _get_session(self, url):
+        """Returns a different customized requests.Session per scheme+hostname
+        combination.
+        """
+        # Use a different requests.Session per scheme+hostname combination, to
+        # reuse connections while minimizing subtle security issues.
+        parsed_url = six.moves.urllib.parse.urlparse(url)
+
+        if not parsed_url.scheme or not parsed_url.hostname:
+            raise exceptions.URLParsingError(
+                "Could not get scheme and hostname from URL: " + url
+            )
+
+        session_index = parsed_url.scheme + "+" + parsed_url.hostname
+
+        msg = f"""url: {url}
+            session index: {session_index}"""
+        logger.debug(msg)
+
+        session = self._sessions.get(session_index)
+
+        if not session:
+            session = requests.Session()
+            self._sessions[session_index] = session
+
+            # Attach some default headers to every Session.
+            requests_user_agent = session.headers["User-Agent"]
+            # Follows the RFC: https://tools.ietf.org/html/rfc7231#section-5.5.3
+            tuf_user_agent = (
+                "tuf/" + tuf.__version__ + " " + requests_user_agent
+            )
+            session.headers.update(
+                {
+                    # Tell the server not to compress or modify anything.
+                    # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding#Directives
+                    "Accept-Encoding": "identity",
+                    # The TUF user agent.
+                    "User-Agent": tuf_user_agent,
+                }
+            )
+            msg = f"Made new session for {session_index}"
+
+        else:
+            msg = f"Reusing session for {session_index}"
+
+        logger.debug(msg)
+
+        return session
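One consequence of returning a generator from fetch() is that connection setup and raise_for_status() happen when fetch() is called, while the body is only read as the iterator is consumed; that is what lets _download_file above time the data transfer on its own. A minimal consumption sketch, assuming a reachable placeholder URL and an arbitrary 16384-byte upper bound:

```python
# Illustrative sketch only -- the URL and length below are placeholders.
import tempfile
import timeit

from tuf.client_rework.requests_fetcher import RequestsFetcher

fetcher = RequestsFetcher()
temp_file = tempfile.TemporaryFile()

# The connection is opened and the HTTP status checked here;
# FetcherHTTPError would be raised before any chunk is yielded.
chunks = fetcher.fetch("https://example.com/metadata/timestamp.json", 16384)

# The body is streamed in CHUNK_SIZE reads only as the iterator is consumed,
# so this timing measures data transfer rather than connection setup.
start_time = timeit.default_timer()
for chunk in chunks:
    temp_file.write(chunk)
seconds_spent_receiving = timeit.default_timer() - start_time
```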
diff --git a/tuf/client_rework/updater_rework.py b/tuf/client_rework/updater_rework.py
index c7820a4ab4..20d790bbc1 100644
--- a/tuf/client_rework/updater_rework.py
+++ b/tuf/client_rework/updater_rework.py
@@ -16,8 +16,8 @@
 from securesystemslib import hash as sslib_hash
 from securesystemslib import util as sslib_util
 
-from tuf import download, exceptions, mirrors, requests_fetcher, settings
-from tuf.client.fetcher import FetcherInterface
+from tuf import exceptions, settings
+from tuf.client_rework.mirrors_download import Mirrors
 
 from .metadata_wrapper import (
     RootWrapper,
@@ -50,7 +50,7 @@ def __init__(
         self,
         repository_name: str,
         repository_mirrors: Dict,
-        fetcher: Optional[FetcherInterface] = None,
+        fetcher: Optional["FetcherInterface"] = None,
     ):
 
         self._repository_name = repository_name
@@ -58,10 +58,7 @@ def __init__(
         self._consistent_snapshot = False
         self._metadata = {}
 
-        if fetcher is None:
-            self._fetcher = requests_fetcher.RequestsFetcher()
-        else:
-            self._fetcher = fetcher
+        self._mirrors = Mirrors(repository_mirrors, fetcher)
 
     def refresh(self) -> None:
         """
@@ -147,8 +144,11 @@ def download_target(self, target: Dict, destination_directory: str):
 
         The file is saved to the 'destination_directory' argument.
         """
-        try:
-            for temp_obj in self._mirror_target_download(target):
+        for temp_obj in self._mirrors.target_download(
+            target["filepath"], target["fileinfo"]["length"]
+        ):
+
+            try:
                 self._verify_target_file(temp_obj, target)
                 # break? should we break after first successful download?
 
@@ -156,62 +156,10 @@ def download_target(self, target: Dict, destination_directory: str):
                     destination_directory, target["filepath"]
                 )
                 sslib_util.persist_temp_file(temp_obj, filepath)
-        # pylint: disable=try-except-raise
-        except Exception:
-            # TODO: do something with exceptions
-            raise
-
-    def _mirror_meta_download(self, filename: str, upper_length: int) -> TextIO:
-        """
-        Download metadata file from the list of metadata mirrors
-        """
-        file_mirrors = mirrors.get_list_of_mirrors(
-            "meta", filename, self._mirrors
-        )
-
-        file_mirror_errors = {}
-        for file_mirror in file_mirrors:
-            try:
-                temp_obj = download.unsafe_download(
-                    file_mirror, upper_length, self._fetcher
-                )
-
-                temp_obj.seek(0)
-                yield temp_obj
-
-            # pylint: disable=broad-except
-            except Exception as exception:
-                file_mirror_errors[file_mirror] = exception
-
-            finally:
-                if file_mirror_errors:
-                    raise exceptions.NoWorkingMirrorError(file_mirror_errors)
-
-    def _mirror_target_download(self, fileinfo: str) -> BinaryIO:
-        """
-        Download target file from the list of target mirrors
-        """
-        # full_filename = _get_full_name(filename)
-        file_mirrors = mirrors.get_list_of_mirrors(
-            "target", fileinfo["filepath"], self._mirrors
-        )
-
-        file_mirror_errors = {}
-        for file_mirror in file_mirrors:
-            try:
-                temp_obj = download.safe_download(
-                    file_mirror, fileinfo["fileinfo"]["length"], self._fetcher
-                )
-
-                temp_obj.seek(0)
-                yield temp_obj
-            # pylint: disable=broad-except
-            except Exception as exception:
-                file_mirror_errors[file_mirror] = exception
-
-            finally:
-                if file_mirror_errors:
-                    raise exceptions.NoWorkingMirrorError(file_mirror_errors)
+            # pylint: disable=try-except-raise
+            except Exception:
+                # TODO: do something with exceptions
+                raise
 
     def _get_full_meta_name(
         self, role: str, extension: str = ".json", version: int = None
@@ -266,7 +214,7 @@ def _load_root(self) -> None:
         verified_root = None
         for next_version in range(lower_bound, upper_bound):
             try:
-                mirror_download = self._mirror_meta_download(
+                mirror_download = self._mirrors.meta_download(
                     self._get_relative_meta_name("root", version=next_version),
                     settings.DEFAULT_ROOT_REQUIRED_LENGTH,
                 )
@@ -327,9 +275,10 @@ def _load_timestamp(self) -> None:
         TODO
         """
         # TODO Check if timestamp exists locally
-        for temp_obj in self._mirror_meta_download(
+        for temp_obj in self._mirrors.meta_download(
             "timestamp.json", settings.DEFAULT_TIMESTAMP_REQUIRED_LENGTH
         ):
+
             try:
                 verified_tampstamp = self._verify_timestamp(temp_obj)
                 # break? should we break after first successful download?
@@ -364,7 +313,8 @@ def _load_snapshot(self) -> None:
 
         # Check if exists locally
         # self.loadLocal('snapshot', snapshotVerifier)
-        for temp_obj in self._mirror_meta_download("snapshot.json", length):
+        for temp_obj in self._mirrors.meta_download("snapshot.json", length):
+
             try:
                 verified_snapshot = self._verify_snapshot(temp_obj)
                 # break? should we break after first successful download?
@@ -400,9 +350,10 @@ def _load_targets(self, targets_role: str, parent_role: str) -> None:
 
         # Check if exists locally
         # self.loadLocal('snapshot', targetsVerifier)
-        for temp_obj in self._mirror_meta_download(
+        for temp_obj in self._mirrors.meta_download(
             targets_role + ".json", length
         ):
+
             try:
                 verified_targets = self._verify_targets(
                     temp_obj, targets_role, parent_role
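With these pieces in place, the updater no longer calls tuf.download or tuf.mirrors directly: it hands its mirror configuration (and, optionally, a custom fetcher) to Mirrors and iterates the generators Mirrors returns. A hedged end-to-end sketch follows; the mirror name, URL prefix and paths are placeholders, not values taken from this diff.

```python
# Illustrative sketch only -- placeholder mirror configuration in the shape
# checked by tuf.formats.MIRRORDICT_SCHEMA.
from tuf import settings
from tuf.client_rework.mirrors_download import Mirrors
from tuf.client_rework.requests_fetcher import RequestsFetcher

mirror_config = {
    "mirror1": {
        "url_prefix": "https://example.com/repo",
        "metadata_path": "metadata",
        "targets_path": "targets",
        # "confined_target_dirs" is optional and only affects target files.
    }
}

# Passing fetcher=None (or omitting it) falls back to RequestsFetcher.
mirrors = Mirrors(mirror_config, fetcher=RequestsFetcher())

# meta_download() yields one temporary file object per mirror that answers;
# if every mirror fails it raises tuf.exceptions.NoWorkingMirrorError.
for temp_obj in mirrors.meta_download(
    "timestamp.json", settings.DEFAULT_TIMESTAMP_REQUIRED_LENGTH
):
    metadata_bytes = temp_obj.read()
    break  # stop after the first successful download, as the TODOs above ask
```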