Skip to content

http: migrate to fsspec & aiohttp #6525

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions dvc/fs/fsspec_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,27 @@ def find(self, path_info, detail=False, prefix=None):
yield from self._strip_buckets(files, detail=detail)


# pylint: disable=arguments-differ
class NoDirectoriesMixin:
    """Mixin for filesystems that have no notion of directories.

    Every path is reported as a file and never as a directory, and every
    listing or traversal operation is left unimplemented.
    """

    def isfile(self, *args, **kwargs):
        """Return ``True`` unconditionally; all arguments are ignored."""
        return True

    def isdir(self, *args, **kwargs):
        """Return ``False`` unconditionally; all arguments are ignored."""
        return False

    def ls(self, *args, **kwargs):
        """Listing is not available; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def find(self, *args, **kwargs):
        """Searching is not available; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def walk(self, *args, **kwargs):
        """Walking is not available; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def walk_files(self, *args, **kwargs):
        """Walking is not available; always raises ``NotImplementedError``."""
        raise NotImplementedError()


_LOCAL_FS = LocalFileSystem()


Expand Down
257 changes: 84 additions & 173 deletions dvc/fs/http.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,12 @@
import logging
import os.path
import threading
from typing import Optional
from urllib.parse import urlparse

from funcy import cached_property, memoize, wrap_prop, wrap_with
from funcy import cached_property, memoize, wrap_with

from dvc import prompt
from dvc.exceptions import DvcException, HTTPError
from dvc.path_info import HTTPURLInfo
from dvc.progress import Tqdm
from dvc.scheme import Schemes

from .base import BaseFileSystem

logger = logging.getLogger(__name__)
from .fsspec_wrapper import CallbackMixin, FSSpecWrapper, NoDirectoriesMixin


@wrap_with(threading.Lock())
Expand All @@ -26,179 +18,98 @@ def ask_password(host, user):
)


class HTTPFileSystem(BaseFileSystem): # pylint:disable=abstract-method
def make_context(ssl_verify):
    """Turn an ``ssl_verify`` setting into a value usable as an ``ssl`` option.

    Args:
        ssl_verify: ``True``, ``False`` or ``None``, or a filesystem path to
            either a CA bundle file or a directory of CA certificates.

    Returns:
        ``ssl_verify`` itself when it is a bool or ``None``; otherwise a
        default ``ssl.SSLContext`` with the given certificates loaded in
        addition to the system defaults.

    Raises:
        OSError: if a CA bundle file is given but cannot be read.
        ssl.SSLError: if the CA bundle file is not a valid certificate file.
    """
    if isinstance(ssl_verify, bool) or ssl_verify is None:
        return ssl_verify

    # Anything else is treated as a path: build an SSL context and load the
    # certificates it points to.
    import os
    import ssl

    context = ssl.create_default_context()
    if os.path.isdir(ssl_verify):
        # A directory of certificates must be passed as ``capath``; handing
        # it over as ``cafile`` would fail when OpenSSL tries to read it.
        context.load_verify_locations(capath=ssl_verify)
    else:
        context.load_verify_locations(cafile=ssl_verify)
    return context


# pylint: disable=abstract-method
class HTTPFileSystem(CallbackMixin, NoDirectoriesMixin, FSSpecWrapper):
scheme = Schemes.HTTP
PATH_CLS = HTTPURLInfo
PARAM_CHECKSUM = "etag"
PARAM_CHECKSUM = "checksum"
REQUIRES = {"aiohttp": "aiohttp", "aiohttp-retry": "aiohttp_retry"}
CAN_TRAVERSE = False
REQUIRES = {"requests": "requests"}

SESSION_RETRIES = 5
SESSION_BACKOFF_FACTOR = 0.1
REQUEST_TIMEOUT = 60
CHUNK_SIZE = 2 ** 16

def __init__(self, **config):
super().__init__(**config)

self.user = config.get("user", None)

self.auth = config.get("auth", None)
self.custom_auth_header = config.get("custom_auth_header", None)
self.password = config.get("password", None)
self.ask_password = config.get("ask_password", False)
self.headers = {}
self.ssl_verify = config.get("ssl_verify", True)
self.method = config.get("method", "POST")

def _auth_method(self, url):
from requests.auth import HTTPBasicAuth, HTTPDigestAuth

if self.auth:
if self.ask_password and self.password is None:
self.password = ask_password(urlparse(url).hostname, self.user)
if self.auth == "basic":
return HTTPBasicAuth(self.user, self.password)
if self.auth == "digest":
return HTTPDigestAuth(self.user, self.password)
if self.auth == "custom" and self.custom_auth_header:
self.headers.update({self.custom_auth_header: self.password})
return None

@wrap_prop(threading.Lock())
@cached_property
def _session(self):
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

session = requests.Session()
def _prepare_credentials(self, **config):
    """Build filesystem keyword arguments from the remote's configuration.

    Reads ``auth``, ``user``, ``password``, ``ask_password``,
    ``custom_auth_header``, ``url``, ``ssl_verify`` and ``method`` from
    *config*. As a side effect, stores the upload HTTP method (default
    ``"POST"``) on ``self.upload_method``.

    Returns:
        A dict with:
          * ``"client_kwargs"`` - session options (``auth`` for basic
            authentication, ``connector`` when ``ssl_verify`` is set);
          * ``"headers"`` - only for ``custom`` authentication;
          * ``"get_client"`` - this instance's client factory.

    Raises:
        ConfigError: if the selected auth method lacks a required setting.
        NotImplementedError: if ``auth`` is neither ``basic`` nor ``custom``.
    """
    import aiohttp
    from fsspec.asyn import fsspec_loop

    from dvc.config import ConfigError

    credentials = {}
    # fsspec's HTTPFileSystem names this keyword ``client_kwargs``; any
    # other key would not be forwarded to the client factory.
    client_kwargs = credentials.setdefault("client_kwargs", {})

    if config.get("auth"):
        user = config.get("user")
        password = config.get("password")
        custom_auth_header = config.get("custom_auth_header")

        if password is None and config.get("ask_password"):
            # NOTE(review): ask_password's first parameter is named
            # ``host`` but the full URL is passed here - confirm intended.
            password = ask_password(config.get("url"), user or "custom")

        auth_method = config["auth"]
        if auth_method == "basic":
            if user is None or password is None:
                raise ConfigError(
                    "HTTP 'basic' authentication require both "
                    "'user' and 'password'"
                )

            client_kwargs["auth"] = aiohttp.BasicAuth(user, password)
        elif auth_method == "custom":
            if custom_auth_header is None or password is None:
                raise ConfigError(
                    "HTTP 'custom' authentication require both "
                    "'custom_auth_header' and 'password'"
                )
            credentials["headers"] = {custom_auth_header: password}
        else:
            raise NotImplementedError(
                f"Auth method {auth_method!r} is not supported."
            )

    if "ssl_verify" in config:
        # Presumably the connector has to be created with fsspec's event
        # loop active so that it is bound to that loop - TODO confirm.
        with fsspec_loop():
            client_kwargs["connector"] = aiohttp.TCPConnector(
                ssl=make_context(config["ssl_verify"])
            )

    credentials["get_client"] = self.get_client
    self.upload_method = config.get("method", "POST")
    return credentials

async def get_client(self, **kwargs):
from aiohttp_retry import ExponentialRetry, RetryClient

kwargs["retry_options"] = ExponentialRetry(
attempts=self.SESSION_RETRIES,
factor=self.SESSION_BACKOFF_FACTOR,
max_timeout=self.REQUEST_TIMEOUT,
)

session.verify = self.ssl_verify
return RetryClient(**kwargs)

retries = Retry(
total=self.SESSION_RETRIES,
backoff_factor=self.SESSION_BACKOFF_FACTOR,
@cached_property
def fs(self):
from fsspec.implementations.http import (
HTTPFileSystem as _HTTPFileSystem,
)

session.mount("http://", HTTPAdapter(max_retries=retries))
session.mount("https://", HTTPAdapter(max_retries=retries))

return session

def request(self, method, url, **kwargs):
import requests

kwargs.setdefault("allow_redirects", True)
kwargs.setdefault("timeout", self.REQUEST_TIMEOUT)

try:
res = self._session.request(
method,
url,
auth=self._auth_method(url),
headers=self.headers,
**kwargs,
)

redirect_no_location = (
kwargs["allow_redirects"]
and res.status_code in (301, 302)
and "location" not in res.headers
)

if redirect_no_location:
# AWS s3 doesn't like to add a location header to its redirects
# from https://s3.amazonaws.com/<bucket name>/* type URLs.
# This should be treated as an error
raise requests.exceptions.RequestException

return res

except requests.exceptions.RequestException:
raise DvcException(f"could not perform a {method} request")

def _head(self, url):
response = self.request("HEAD", url)
if response.ok:
return response

# Sometimes servers are configured to forbid HEAD requests
# Context: https://github.com/iterative/dvc/issues/4131
with self.request("GET", url, stream=True) as r:
if r.ok:
return r

return response

def exists(self, path_info) -> bool:
res = self._head(path_info.url)
if res.status_code == 404:
return False
if bool(res):
return True
raise HTTPError(res.status_code, res.reason)

def info(self, path_info):
resp = self._head(path_info.url)
etag = resp.headers.get("ETag") or resp.headers.get("Content-MD5")
size = self._content_length(resp)
return {"etag": etag, "size": size, "type": "file"}

def _upload_fobj(self, fobj, to_info, **kwargs):
def chunks(fobj):
while True:
chunk = fobj.read(self.CHUNK_SIZE)
if not chunk:
break
yield chunk

response = self.request(self.method, to_info.url, data=chunks(fobj))
if response.status_code not in (200, 201):
raise HTTPError(response.status_code, response.reason)

def _download(self, from_info, to_file, name=None, no_progress_bar=False):
response = self.request("GET", from_info.url, stream=True)
if response.status_code != 200:
raise HTTPError(response.status_code, response.reason)
with open(to_file, "wb") as fd:
with Tqdm.wrapattr(
fd,
"write",
total=None
if no_progress_bar
else self._content_length(response),
leave=False,
desc=from_info.url if name is None else name,
disable=no_progress_bar,
) as fd_wrapped:
for chunk in response.iter_content(chunk_size=self.CHUNK_SIZE):
fd_wrapped.write(chunk)

def _upload(
self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs
):
with open(from_file, "rb") as fobj:
self.upload_fobj(
fobj,
to_info,
size=None if no_progress_bar else os.path.getsize(from_file),
no_progress_bar=no_progress_bar,
desc=name or to_info.url,
)

def open(self, path_info, mode: str = "r", encoding: str = None, **kwargs):
from dvc.utils.http import open_url

return open_url(
path_info.url,
mode=mode,
encoding=encoding,
auth=self._auth_method(path_info),
**kwargs,
)
return _HTTPFileSystem(timeout=self.REQUEST_TIMEOUT)

@staticmethod
def _content_length(response) -> Optional[int]:
res = response.headers.get("Content-Length")
return int(res) if res else None
def _entry_hook(self, entry):
entry["checksum"] = entry.get("ETag") or entry.get("Content-MD5")
return entry
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ def run(self):
"typing_extensions>=3.7.4; python_version < '3.10'",
# until https://github.com/python/typing/issues/865 is fixed for python3.10
"typing_extensions==3.10.0.0; python_version >= '3.10'",
"fsspec>=2021.8.1",
"fsspec[http]>=2021.8.1",
"aiohttp-retry==2.4.5",
"diskcache>=5.2.1",
]

Expand Down
1 change: 0 additions & 1 deletion tests/func/test_fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,6 @@ def test_fs_getsize(dvc, cloud):
pytest.lazy_fixture("gs"),
pytest.lazy_fixture("gdrive"),
pytest.lazy_fixture("hdfs"),
pytest.lazy_fixture("http"),
pytest.lazy_fixture("local_cloud"),
pytest.lazy_fixture("oss"),
pytest.lazy_fixture("s3"),
Expand Down
Loading