diff --git a/test/test_datasets_download.py b/test/test_datasets_download.py index ac189f4f635..55ec7f38268 100644 --- a/test/test_datasets_download.py +++ b/test/test_datasets_download.py @@ -7,6 +7,7 @@ from urllib.error import HTTPError, URLError from urllib.parse import urlparse from urllib.request import urlopen, Request +import warnings import pytest @@ -45,6 +46,31 @@ def inner_wrapper(request, *args, **kwargs): urlopen = limit_requests_per_time()(urlopen) +def resolve_redirects(max_redirects=3): + def outer_wrapper(fn): + def inner_wrapper(request, *args, **kwargs): + url = initial_url = request.full_url if isinstance(request, Request) else request + + for _ in range(max_redirects + 1): + response = fn(request, *args, **kwargs) + + if response.url == url or response.url is None: + if url != initial_url: + warnings.warn(f"The URL {initial_url} ultimately redirects to {url}.") + return response + + url = response.url + else: + raise RecursionError(f"Request to {initial_url} exceeded {max_redirects} redirects.") + + return inner_wrapper + + return outer_wrapper + + +urlopen = resolve_redirects()(urlopen) + + @contextlib.contextmanager def log_download_attempts( urls_and_md5s=None, @@ -117,19 +143,22 @@ def assert_server_response_ok(): raise AssertionError("The request timed out.") from error except HTTPError as error: raise AssertionError(f"The server returned {error.code}: {error.reason}.") from error + except RecursionError as error: + raise AssertionError(str(error)) from error -def assert_url_is_accessible(url): +def assert_url_is_accessible(url, timeout=5.0): request = Request(url, headers=dict(method="HEAD")) with assert_server_response_ok(): - urlopen(request, timeout=5.0) + urlopen(request, timeout=timeout) -def assert_file_downloads_correctly(url, md5): +def assert_file_downloads_correctly(url, md5, timeout=5.0): with get_tmp_dir() as root: file = path.join(root, path.basename(url)) with assert_server_response_ok(): - with urlopen(url, timeout=5.0) as response, open(file, "wb") as fh: + with open(file, "wb") as fh: + response = urlopen(url, timeout=timeout) fh.write(response.read()) assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"