|
6 | 6 |
|
7 | 7 | # Imports
|
8 | 8 | import abc
|
9 |
| -from typing import Iterator |
| 9 | +import logging |
| 10 | +import tempfile |
| 11 | +from contextlib import contextmanager |
| 12 | +from typing import IO, Iterator |
| 13 | +from urllib import parse |
| 14 | + |
| 15 | +from tuf import exceptions |
| 16 | + |
| 17 | +logger = logging.getLogger(__name__) |
10 | 18 |
|
11 | 19 |
|
12 | 20 | # Classes
|
@@ -40,3 +48,51 @@ def fetch(self, url: str, required_length: int) -> Iterator[bytes]:
|
40 | 48 | A bytes iterator
|
41 | 49 | """
|
42 | 50 | raise NotImplementedError # pragma: no cover
|
| 51 | + |
| 52 | + @contextmanager |
| 53 | + def download_file(self, url: str, required_length: int) -> Iterator[IO]: |
| 54 | + """Opens a connection to 'url' and downloads the content |
| 55 | + up to 'required_length'. |
| 56 | +
|
| 57 | + Args: |
| 58 | + url: a URL string that represents the location of the file. |
| 59 | + required_length: an integer value representing the length of |
| 60 | + the file or an upper boundary. |
| 61 | +
|
| 62 | + Raises: |
| 63 | + DownloadLengthMismatchError: a mismatch of observed vs expected |
| 64 | + lengths while downloading the file. |
| 65 | +
|
| 66 | + Yields: |
| 67 | + A file object that points to the contents of 'url'. |
| 68 | + """ |
| 69 | + # 'url.replace('\\', '/')' is needed for compatibility with |
| 70 | + # Windows-based systems, because they might use back-slashes in place |
| 71 | + # of forward-slashes. This converts it to the common format. |
| 72 | + # unquote() replaces %xx escapes in a url with their single-character |
| 73 | + # equivalent. A back-slash may beencoded as %5c in the url, which |
| 74 | + # should also be replaced with a forward slash. |
| 75 | + url = parse.unquote(url).replace("\\", "/") |
| 76 | + logger.debug("Downloading: %s", url) |
| 77 | + |
| 78 | + number_of_bytes_received = 0 |
| 79 | + |
| 80 | + with tempfile.TemporaryFile() as temp_file: |
| 81 | + chunks = self.fetch(url, required_length) |
| 82 | + for chunk in chunks: |
| 83 | + temp_file.write(chunk) |
| 84 | + number_of_bytes_received += len(chunk) |
| 85 | + if number_of_bytes_received > required_length: |
| 86 | + raise exceptions.DownloadLengthMismatchError( |
| 87 | + required_length, number_of_bytes_received |
| 88 | + ) |
| 89 | + temp_file.seek(0) |
| 90 | + yield temp_file |
| 91 | + |
| 92 | + def download_bytes(self, url: str, required_length: int) -> bytes: |
| 93 | + """Download bytes from given url |
| 94 | +
|
| 95 | + Returns the downloaded bytes, otherwise like download_file() |
| 96 | + """ |
| 97 | + with self.download_file(url, required_length) as dl_file: |
| 98 | + return dl_file.read() |
0 commit comments