Skip to content

Commit f450823

Browse files
committed
http: migrate to fsspec & aiohttp
1 parent b8b6ace commit f450823

File tree

5 files changed

+127
-277
lines changed

5 files changed

+127
-277
lines changed

dvc/fs/fsspec_wrapper.py

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -233,6 +233,27 @@ def find(self, path_info, detail=False, prefix=None):
233233
yield from self._strip_buckets(files, detail=detail)
234234

235235

236+
# pylint: disable=arguments-differ
class NoDirectoriesMixin:
    """Mixin for filesystems that expose no directory structure.

    Every path is reported as a file and never as a directory, and all
    listing/traversal operations (``find``, ``walk``, ``walk_files``,
    ``ls``) raise ``NotImplementedError``.
    """

    def isdir(self, *args, **kwargs):
        """Return ``False``: no path is ever treated as a directory."""
        return False

    def isfile(self, *args, **kwargs):
        """Return ``True``: every path is treated as a file."""
        return True

    def find(self, *args, **kwargs):
        """Always raise ``NotImplementedError``."""
        raise NotImplementedError(
            f"{type(self).__name__} has no directories; "
            "find() is not supported"
        )

    def walk(self, *args, **kwargs):
        """Always raise ``NotImplementedError``."""
        raise NotImplementedError(
            f"{type(self).__name__} has no directories; "
            "walk() is not supported"
        )

    def walk_files(self, *args, **kwargs):
        """Always raise ``NotImplementedError``."""
        raise NotImplementedError(
            f"{type(self).__name__} has no directories; "
            "walk_files() is not supported"
        )

    def ls(self, *args, **kwargs):
        """Always raise ``NotImplementedError``."""
        raise NotImplementedError(
            f"{type(self).__name__} has no directories; "
            "ls() is not supported"
        )
255+
256+
236257
_LOCAL_FS = LocalFileSystem()
237258

238259

dvc/fs/http.py

Lines changed: 84 additions & 173 deletions
Original file line number | Diff line number | Diff line change
@@ -1,20 +1,12 @@
1-
import logging
2-
import os.path
31
import threading
4-
from typing import Optional
5-
from urllib.parse import urlparse
62

7-
from funcy import cached_property, memoize, wrap_prop, wrap_with
3+
from funcy import cached_property, memoize, wrap_with
84

95
from dvc import prompt
10-
from dvc.exceptions import DvcException, HTTPError
116
from dvc.path_info import HTTPURLInfo
12-
from dvc.progress import Tqdm
137
from dvc.scheme import Schemes
148

15-
from .base import BaseFileSystem
16-
17-
logger = logging.getLogger(__name__)
9+
from .fsspec_wrapper import CallbackMixin, FSSpecWrapper, NoDirectoriesMixin
1810

1911

2012
@wrap_with(threading.Lock())
@@ -26,179 +18,98 @@ def ask_password(host, user):
2618
)
2719

2820

29-
class HTTPFileSystem(BaseFileSystem): # pylint:disable=abstract-method
21+
def make_context(ssl_verify):
    """Translate an ``ssl_verify`` setting into an SSL argument.

    Args:
        ssl_verify: ``True``, ``False`` or ``None`` to be passed through
            untouched, or a path to a certificate file to trust.

    Returns:
        ``ssl_verify`` itself when it is a bool or ``None``; otherwise a
        default ``ssl.SSLContext`` with the given certificate loaded via
        ``load_verify_locations``.

    Raises:
        OSError: propagated from ``load_verify_locations`` when the
            certificate file cannot be read.
    """
    if isinstance(ssl_verify, bool) or ssl_verify is None:
        return ssl_verify

    # Anything else is treated as a path: build a default context and
    # load the given certificate into it.
    import ssl

    context = ssl.create_default_context()
    context.load_verify_locations(ssl_verify)
    return context
32+
33+
34+
# pylint: disable=abstract-method
class HTTPFileSystem(CallbackMixin, NoDirectoriesMixin, FSSpecWrapper):
    """HTTP filesystem built on fsspec's ``HTTPFileSystem`` and aiohttp.

    Credentials are assembled in ``_prepare_credentials`` and the HTTP
    client is created by ``get_client``, which wraps it in an
    ``aiohttp_retry.RetryClient``.
    """

    scheme = Schemes.HTTP
    PATH_CLS = HTTPURLInfo
    # Key written by `_entry_hook` into each entry.
    PARAM_CHECKSUM = "checksum"
    REQUIRES = {"aiohttp": "aiohttp", "aiohttp-retry": "aiohttp_retry"}
    CAN_TRAVERSE = False

    # Used by `get_client` to build the retry options.
    SESSION_RETRIES = 5
    SESSION_BACKOFF_FACTOR = 0.1
    # Passed as `timeout` to the fsspec filesystem and as `max_timeout`
    # to the retry options.
    REQUEST_TIMEOUT = 60

    def _prepare_credentials(self, **config):
        """Build the credentials dict from the remote ``config``.

        Reads ``auth``, ``user``, ``password``, ``custom_auth_header``,
        ``ask_password``, ``url``, ``ssl_verify`` and ``method`` from
        ``config``. As a side effect, stores ``config["method"]``
        (default ``"POST"``) on ``self.upload_method``.

        Returns:
            A dict with ``client_args`` (optionally holding ``auth`` and
            ``connector``), optionally ``headers``, and ``get_client``.

        Raises:
            ConfigError: when the selected auth method lacks the values
                it needs.
            NotImplementedError: for any auth method other than
                ``"basic"`` or ``"custom"``.
        """
        import aiohttp
        from fsspec.asyn import fsspec_loop

        from dvc.config import ConfigError

        credentials = {}
        client_args = credentials.setdefault("client_args", {})

        if config.get("auth"):
            user = config.get("user")
            password = config.get("password")
            custom_auth_header = config.get("custom_auth_header")

            # Prompt only when no password was configured.
            if password is None and config.get("ask_password"):
                password = ask_password(config.get("url"), user or "custom")

            auth_method = config["auth"]
            if auth_method == "basic":
                if user is None or password is None:
                    raise ConfigError(
                        "HTTP 'basic' authentication require both "
                        "'user' and 'password'"
                    )

                client_args["auth"] = aiohttp.BasicAuth(user, password)
            elif auth_method == "custom":
                if custom_auth_header is None or password is None:
                    raise ConfigError(
                        "HTTP 'custom' authentication require both "
                        "'custom_auth_header' and 'password'"
                    )
                # The password is sent as the value of the custom header.
                credentials["headers"] = {custom_auth_header: password}
            else:
                raise NotImplementedError(
                    f"Auth method {auth_method!r} is not supported."
                )

        if "ssl_verify" in config:
            # NOTE(review): presumably the connector must be created
            # inside fsspec's loop context so it binds to that event
            # loop -- confirm against fsspec.asyn.
            with fsspec_loop():
                client_args["connector"] = aiohttp.TCPConnector(
                    ssl=make_context(config["ssl_verify"])
                )

        credentials["get_client"] = self.get_client
        self.upload_method = config.get("method", "POST")
        return credentials

    async def get_client(self, **kwargs):
        """Create the HTTP client with retry options added to ``kwargs``.

        Sets ``kwargs["retry_options"]`` to an ``ExponentialRetry`` built
        from ``SESSION_RETRIES``, ``SESSION_BACKOFF_FACTOR`` and
        ``REQUEST_TIMEOUT``, then returns ``RetryClient(**kwargs)``.
        """
        from aiohttp_retry import ExponentialRetry, RetryClient

        kwargs["retry_options"] = ExponentialRetry(
            attempts=self.SESSION_RETRIES,
            factor=self.SESSION_BACKOFF_FACTOR,
            max_timeout=self.REQUEST_TIMEOUT,
        )

        return RetryClient(**kwargs)

    @cached_property
    def fs(self):
        """Return the underlying fsspec ``HTTPFileSystem`` (cached)."""
        from fsspec.implementations.http import (
            HTTPFileSystem as _HTTPFileSystem,
        )

        # NOTE(review): only `timeout` is passed here; nothing visible in
        # this class forwards the dict returned by `_prepare_credentials`
        # (auth, headers, connector, get_client) to this constructor.
        # Confirm the base wrapper applies it, otherwise those settings
        # are not used by this filesystem.
        return _HTTPFileSystem(timeout=self.REQUEST_TIMEOUT)

    def _entry_hook(self, entry):
        """Set ``entry["checksum"]`` from ``ETag`` or ``Content-MD5``.

        ``ETag`` is preferred; ``Content-MD5`` is used when ``ETag`` is
        missing or falsy. The same ``entry`` mapping is returned.
        """
        entry["checksum"] = entry.get("ETag") or entry.get("Content-MD5")
        return entry

setup.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -90,7 +90,8 @@ def run(self):
9090
"typing_extensions>=3.7.4; python_version < '3.10'",
9191
# until https://github.com/python/typing/issues/865 is fixed for python3.10
9292
"typing_extensions==3.10.0.0; python_version >= '3.10'",
93-
"fsspec>=2021.8.1",
93+
"fsspec[http]>=2021.8.1",
94+
"aiohttp-retry==2.4.5",
9495
"diskcache>=5.2.1",
9596
]
9697

tests/func/test_fs.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -260,7 +260,6 @@ def test_fs_getsize(dvc, cloud):
260260
pytest.lazy_fixture("gs"),
261261
pytest.lazy_fixture("gdrive"),
262262
pytest.lazy_fixture("hdfs"),
263-
pytest.lazy_fixture("http"),
264263
pytest.lazy_fixture("local_cloud"),
265264
pytest.lazy_fixture("oss"),
266265
pytest.lazy_fixture("s3"),

0 commit comments

Comments
 (0)