diff --git a/dvc/fs/fsspec_wrapper.py b/dvc/fs/fsspec_wrapper.py index cef7c6e517..e784b799b9 100644 --- a/dvc/fs/fsspec_wrapper.py +++ b/dvc/fs/fsspec_wrapper.py @@ -233,6 +233,27 @@ def find(self, path_info, detail=False, prefix=None): yield from self._strip_buckets(files, detail=detail) +# pylint: disable=arguments-differ +class NoDirectoriesMixin: + def isdir(self, *args, **kwargs): + return False + + def isfile(self, *args, **kwargs): + return True + + def find(self, *args, **kwargs): + raise NotImplementedError + + def walk(self, *args, **kwargs): + raise NotImplementedError + + def walk_files(self, *args, **kwargs): + raise NotImplementedError + + def ls(self, *args, **kwargs): + raise NotImplementedError + + _LOCAL_FS = LocalFileSystem() diff --git a/dvc/fs/http.py b/dvc/fs/http.py index 06579fcaef..c024100c45 100644 --- a/dvc/fs/http.py +++ b/dvc/fs/http.py @@ -1,20 +1,12 @@ -import logging -import os.path import threading -from typing import Optional -from urllib.parse import urlparse -from funcy import cached_property, memoize, wrap_prop, wrap_with +from funcy import cached_property, memoize, wrap_with from dvc import prompt -from dvc.exceptions import DvcException, HTTPError from dvc.path_info import HTTPURLInfo -from dvc.progress import Tqdm from dvc.scheme import Schemes -from .base import BaseFileSystem - -logger = logging.getLogger(__name__) +from .fsspec_wrapper import CallbackMixin, FSSpecWrapper, NoDirectoriesMixin @wrap_with(threading.Lock()) @@ -26,179 +18,98 @@ def ask_password(host, user): ) -class HTTPFileSystem(BaseFileSystem): # pylint:disable=abstract-method +def make_context(ssl_verify): + if isinstance(ssl_verify, bool) or ssl_verify is None: + return ssl_verify + + # If this is a path, then we will create an + # SSL context for it, and load the given certificate. + import ssl + + context = ssl.create_default_context() + context.load_verify_locations(ssl_verify) + return context + + +# pylint: disable=abstract-method +class HTTPFileSystem(CallbackMixin, NoDirectoriesMixin, FSSpecWrapper): scheme = Schemes.HTTP PATH_CLS = HTTPURLInfo - PARAM_CHECKSUM = "etag" + PARAM_CHECKSUM = "checksum" + REQUIRES = {"aiohttp": "aiohttp", "aiohttp-retry": "aiohttp_retry"} CAN_TRAVERSE = False - REQUIRES = {"requests": "requests"} SESSION_RETRIES = 5 SESSION_BACKOFF_FACTOR = 0.1 REQUEST_TIMEOUT = 60 - CHUNK_SIZE = 2 ** 16 - - def __init__(self, **config): - super().__init__(**config) - - self.user = config.get("user", None) - - self.auth = config.get("auth", None) - self.custom_auth_header = config.get("custom_auth_header", None) - self.password = config.get("password", None) - self.ask_password = config.get("ask_password", False) - self.headers = {} - self.ssl_verify = config.get("ssl_verify", True) - self.method = config.get("method", "POST") - - def _auth_method(self, url): - from requests.auth import HTTPBasicAuth, HTTPDigestAuth - - if self.auth: - if self.ask_password and self.password is None: - self.password = ask_password(urlparse(url).hostname, self.user) - if self.auth == "basic": - return HTTPBasicAuth(self.user, self.password) - if self.auth == "digest": - return HTTPDigestAuth(self.user, self.password) - if self.auth == "custom" and self.custom_auth_header: - self.headers.update({self.custom_auth_header: self.password}) - return None - - @wrap_prop(threading.Lock()) - @cached_property - def _session(self): - import requests - from requests.adapters import HTTPAdapter - from urllib3.util.retry import Retry - session = requests.Session() + def _prepare_credentials(self, **config): + import aiohttp + from fsspec.asyn import fsspec_loop + + from dvc.config import ConfigError + + credentials = {} + client_args = credentials.setdefault("client_args", {}) + + if config.get("auth"): + user = config.get("user") + password = config.get("password") + custom_auth_header = config.get("custom_auth_header") + + if password is None and config.get("ask_password"): + password = ask_password(config.get("url"), user or "custom") + + auth_method = config["auth"] + if auth_method == "basic": + if user is None or password is None: + raise ConfigError( + "HTTP 'basic' authentication require both " + "'user' and 'password'" + ) + + client_args["auth"] = aiohttp.BasicAuth(user, password) + elif auth_method == "custom": + if custom_auth_header is None or password is None: + raise ConfigError( + "HTTP 'custom' authentication require both " + "'custom_auth_header' and 'password'" + ) + credentials["headers"] = {custom_auth_header: password} + else: + raise NotImplementedError( + f"Auth method {auth_method!r} is not supported." + ) + + if "ssl_verify" in config: + with fsspec_loop(): + client_args["connector"] = aiohttp.TCPConnector( + ssl=make_context(config["ssl_verify"]) + ) + + credentials["get_client"] = self.get_client + self.upload_method = config.get("method", "POST") + return credentials + + async def get_client(self, **kwargs): + from aiohttp_retry import ExponentialRetry, RetryClient + + kwargs["retry_options"] = ExponentialRetry( + attempts=self.SESSION_RETRIES, + factor=self.SESSION_BACKOFF_FACTOR, + max_timeout=self.REQUEST_TIMEOUT, + ) - session.verify = self.ssl_verify + return RetryClient(**kwargs) - retries = Retry( - total=self.SESSION_RETRIES, - backoff_factor=self.SESSION_BACKOFF_FACTOR, + @cached_property + def fs(self): + from fsspec.implementations.http import ( + HTTPFileSystem as _HTTPFileSystem, ) - session.mount("http://", HTTPAdapter(max_retries=retries)) - session.mount("https://", HTTPAdapter(max_retries=retries)) - - return session - - def request(self, method, url, **kwargs): - import requests - - kwargs.setdefault("allow_redirects", True) - kwargs.setdefault("timeout", self.REQUEST_TIMEOUT) - - try: - res = self._session.request( - method, - url, - auth=self._auth_method(url), - headers=self.headers, - **kwargs, - ) - - redirect_no_location = ( - kwargs["allow_redirects"] - and res.status_code in (301, 302) - and "location" not in res.headers - ) - - if redirect_no_location: - # AWS s3 doesn't like to add a location header to its redirects - # from https://s3.amazonaws.com//* type URLs. - # This should be treated as an error - raise requests.exceptions.RequestException - - return res - - except requests.exceptions.RequestException: - raise DvcException(f"could not perform a {method} request") - - def _head(self, url): - response = self.request("HEAD", url) - if response.ok: - return response - - # Sometimes servers are configured to forbid HEAD requests - # Context: https://github.com/iterative/dvc/issues/4131 - with self.request("GET", url, stream=True) as r: - if r.ok: - return r - - return response - - def exists(self, path_info) -> bool: - res = self._head(path_info.url) - if res.status_code == 404: - return False - if bool(res): - return True - raise HTTPError(res.status_code, res.reason) - - def info(self, path_info): - resp = self._head(path_info.url) - etag = resp.headers.get("ETag") or resp.headers.get("Content-MD5") - size = self._content_length(resp) - return {"etag": etag, "size": size, "type": "file"} - - def _upload_fobj(self, fobj, to_info, **kwargs): - def chunks(fobj): - while True: - chunk = fobj.read(self.CHUNK_SIZE) - if not chunk: - break - yield chunk - - response = self.request(self.method, to_info.url, data=chunks(fobj)) - if response.status_code not in (200, 201): - raise HTTPError(response.status_code, response.reason) - - def _download(self, from_info, to_file, name=None, no_progress_bar=False): - response = self.request("GET", from_info.url, stream=True) - if response.status_code != 200: - raise HTTPError(response.status_code, response.reason) - with open(to_file, "wb") as fd: - with Tqdm.wrapattr( - fd, - "write", - total=None - if no_progress_bar - else self._content_length(response), - leave=False, - desc=from_info.url if name is None else name, - disable=no_progress_bar, - ) as fd_wrapped: - for chunk in response.iter_content(chunk_size=self.CHUNK_SIZE): - fd_wrapped.write(chunk) - - def _upload( - self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs - ): - with open(from_file, "rb") as fobj: - self.upload_fobj( - fobj, - to_info, - size=None if no_progress_bar else os.path.getsize(from_file), - no_progress_bar=no_progress_bar, - desc=name or to_info.url, - ) - - def open(self, path_info, mode: str = "r", encoding: str = None, **kwargs): - from dvc.utils.http import open_url - - return open_url( - path_info.url, - mode=mode, - encoding=encoding, - auth=self._auth_method(path_info), - **kwargs, - ) + return _HTTPFileSystem(timeout=self.REQUEST_TIMEOUT) - @staticmethod - def _content_length(response) -> Optional[int]: - res = response.headers.get("Content-Length") - return int(res) if res else None + def _entry_hook(self, entry): + entry["checksum"] = entry.get("ETag") or entry.get("Content-MD5") + return entry diff --git a/setup.py b/setup.py index 0da595dfbe..555f244b8b 100644 --- a/setup.py +++ b/setup.py @@ -90,7 +90,8 @@ def run(self): "typing_extensions>=3.7.4; python_version < '3.10'", # until https://github.com/python/typing/issues/865 is fixed for python3.10 "typing_extensions==3.10.0.0; python_version >= '3.10'", - "fsspec>=2021.8.1", + "fsspec[http]>=2021.8.1", + "aiohttp-retry==2.4.5", "diskcache>=5.2.1", ] diff --git a/tests/func/test_fs.py b/tests/func/test_fs.py index 82b914949a..c8034f2534 100644 --- a/tests/func/test_fs.py +++ b/tests/func/test_fs.py @@ -260,7 +260,6 @@ def test_fs_getsize(dvc, cloud): pytest.lazy_fixture("gs"), pytest.lazy_fixture("gdrive"), pytest.lazy_fixture("hdfs"), - pytest.lazy_fixture("http"), pytest.lazy_fixture("local_cloud"), pytest.lazy_fixture("oss"), pytest.lazy_fixture("s3"), diff --git a/tests/unit/remote/test_http.py b/tests/unit/remote/test_http.py index cfefff2a81..d28ad381a6 100644 --- a/tests/unit/remote/test_http.py +++ b/tests/unit/remote/test_http.py @@ -1,17 +1,15 @@ -import io +import ssl import pytest -import requests +from mock import patch -from dvc.exceptions import HTTPError from dvc.fs.http import HTTPFileSystem -from dvc.path_info import URLInfo def test_download_fails_on_error_code(dvc, http): fs = HTTPFileSystem(**http.config) - with pytest.raises(HTTPError): + with pytest.raises(FileNotFoundError): fs._download(http / "missing.txt", "missing.txt") @@ -25,15 +23,13 @@ def test_public_auth_method(dvc): fs = HTTPFileSystem(**config) - assert fs._auth_method(config["url"]) is None + assert "auth" not in fs.fs_args["client_args"] + assert "headers" not in fs.fs_args def test_basic_auth_method(dvc): - from requests.auth import HTTPBasicAuth - user = "username" password = "password" - auth = HTTPBasicAuth(user, password) config = { "url": "http://example.com/", "path_info": "file.html", @@ -44,28 +40,8 @@ def test_basic_auth_method(dvc): fs = HTTPFileSystem(**config) - assert fs._auth_method(config["url"]) == auth - assert isinstance(fs._auth_method(config["url"]), HTTPBasicAuth) - - -def test_digest_auth_method(dvc): - from requests.auth import HTTPDigestAuth - - user = "username" - password = "password" - auth = HTTPDigestAuth(user, password) - config = { - "url": "http://example.com/", - "path_info": "file.html", - "auth": "digest", - "user": user, - "password": password, - } - - fs = HTTPFileSystem(**config) - - assert fs._auth_method(config["url"]) == auth - assert isinstance(fs._auth_method(config["url"]), HTTPDigestAuth) + assert fs.fs_args["client_args"]["auth"].login == user + assert fs.fs_args["client_args"]["auth"].password == password def test_custom_auth_method(dvc): @@ -81,17 +57,9 @@ def test_custom_auth_method(dvc): fs = HTTPFileSystem(**config) - assert fs._auth_method(config["url"]) is None - assert header in fs.headers - assert fs.headers[header] == password - - -def test_ssl_verify_is_enabled_by_default(dvc): - config = {"url": "http://example.com/", "path_info": "file.html"} - - fs = HTTPFileSystem(**config) - - assert fs._session.verify is True + headers = fs.fs_args["headers"] + assert header in headers + assert headers[header] == password def test_ssl_verify_disable(dvc): @@ -102,11 +70,11 @@ def test_ssl_verify_disable(dvc): } fs = HTTPFileSystem(**config) - - assert fs._session.verify is False + assert not fs.fs_args["client_args"]["connector"]._ssl -def test_ssl_verify_custom_cert(dvc): +@patch("ssl.SSLContext.load_verify_locations") +def test_ssl_verify_custom_cert(dvc, mocker): config = { "url": "http://example.com/", "path_info": "file.html", @@ -115,69 +83,19 @@ def test_ssl_verify_custom_cert(dvc): fs = HTTPFileSystem(**config) - assert fs._session.verify == "/path/to/custom/cabundle.pem" + assert isinstance( + fs.fs_args["client_args"]["connector"]._ssl, ssl.SSLContext + ) def test_http_method(dvc): - from requests.auth import HTTPBasicAuth - - user = "username" - password = "password" - auth = HTTPBasicAuth(user, password) config = { "url": "http://example.com/", "path_info": "file.html", - "auth": "basic", - "user": user, - "password": password, - "method": "PUT", } - fs = HTTPFileSystem(**config) - - assert fs._auth_method(config["url"]) == auth - assert fs.method == "PUT" - assert isinstance(fs._auth_method(config["url"]), HTTPBasicAuth) - - -def test_exists(mocker): - res = requests.Response() - # need to add `raw`, as `exists()` fallbacks to a streaming GET requests - # on HEAD request failure. - res.raw = io.StringIO("foo") + fs = HTTPFileSystem(**config, method="PUT") + assert fs.upload_method == "PUT" - fs = HTTPFileSystem() - mocker.patch.object(fs, "request", return_value=res) - - url = URLInfo("https://example.org/file.txt") - - res.status_code = 200 - assert fs.exists(url) is True - - res.status_code = 404 - assert fs.exists(url) is False - - res.status_code = 403 - with pytest.raises(HTTPError): - fs.exists(url) - - -@pytest.mark.parametrize( - "headers, expected_size", [({"Content-Length": "3"}, 3), ({}, None)] -) -def test_content_length(mocker, headers, expected_size): - res = requests.Response() - res.headers.update(headers) - res.status_code = 200 - - fs = HTTPFileSystem() - mocker.patch.object(fs, "request", return_value=res) - - url = URLInfo("https://example.org/file.txt") - - assert fs.info(url) == { - "etag": None, - "size": expected_size, - "type": "file", - } - assert fs._content_length(res) == expected_size + fs = HTTPFileSystem(**config, method="POST") + assert fs.upload_method == "POST"