1
- import logging
2
- import os .path
3
1
import threading
4
- from typing import Optional
5
- from urllib .parse import urlparse
6
2
7
- from funcy import cached_property , memoize , wrap_prop , wrap_with
3
+ from funcy import cached_property , memoize , wrap_with
8
4
9
5
from dvc import prompt
10
- from dvc .exceptions import DvcException , HTTPError
11
6
from dvc .path_info import HTTPURLInfo
12
- from dvc .progress import Tqdm
13
7
from dvc .scheme import Schemes
14
8
15
- from .base import BaseFileSystem
16
-
17
- logger = logging .getLogger (__name__ )
9
+ from .fsspec_wrapper import CallbackMixin , FSSpecWrapper , NoDirectoriesMixin
18
10
19
11
20
12
@wrap_with (threading .Lock ())
@@ -26,179 +18,98 @@ def ask_password(host, user):
26
18
)
27
19
28
20
29
- class HTTPFileSystem (BaseFileSystem ): # pylint:disable=abstract-method
21
def make_context(ssl_verify):
    """Turn the ``ssl_verify`` setting into a value usable as an ``ssl=`` option.

    Args:
        ssl_verify: ``True``, ``False`` or ``None`` (handed back untouched),
            or a filesystem path to CA certificates. The path may point
            either to a certificate bundle file or to a directory of
            certificates.

    Returns:
        ``ssl_verify`` itself when it is a bool or ``None``; otherwise an
        ``ssl.SSLContext`` built with ``ssl.create_default_context()`` that
        has additionally loaded the certificates found at the given path.

    Raises:
        OSError: propagated from ``load_verify_locations`` when the path
            cannot be read (e.g. ``FileNotFoundError`` for a missing file).
    """
    if ssl_verify is None or isinstance(ssl_verify, bool):
        return ssl_verify

    # Anything else is treated as a path: create an SSL context for it
    # and load the given certificate(s). Imports stay local so the module
    # does not pay for them unless a custom CA location is configured.
    import os
    import ssl

    context = ssl.create_default_context()
    if os.path.isdir(ssl_verify):
        # A directory has to be passed as ``capath``; handing it over as
        # ``cafile`` (the first positional argument) fails.
        context.load_verify_locations(capath=ssl_verify)
    else:
        context.load_verify_locations(cafile=ssl_verify)
    return context
32
+
33
+
34
+ # pylint: disable=abstract-method
35
class HTTPFileSystem(CallbackMixin, NoDirectoriesMixin, FSSpecWrapper):
    """HTTP(S) filesystem backed by fsspec's ``HTTPFileSystem``.

    Requests are issued through an ``aiohttp_retry.RetryClient`` (see
    ``get_client``); authentication and TLS settings are derived from the
    remote config in ``_prepare_credentials``.
    """

    scheme = Schemes.HTTP
    PATH_CLS = HTTPURLInfo
    # Key under which ``_entry_hook`` stores the ETag / Content-MD5 value.
    PARAM_CHECKSUM = "checksum"
    REQUIRES = {"aiohttp": "aiohttp", "aiohttp-retry": "aiohttp_retry"}
    CAN_TRAVERSE = False

    # Retry/back-off settings fed to ``ExponentialRetry`` in ``get_client``.
    SESSION_RETRIES = 5
    SESSION_BACKOFF_FACTOR = 0.1
    REQUEST_TIMEOUT = 60

    def _prepare_credentials(self, **config):
        """Build the option dict for the HTTP backend from the remote config.

        Recognised config keys: ``auth`` (``"basic"`` or ``"custom"``),
        ``user``, ``password``, ``ask_password``, ``custom_auth_header``,
        ``url``, ``ssl_verify`` and ``method``.

        Returns:
            A dict with ``client_args`` (optionally holding ``auth`` and
            ``connector``), optionally ``headers``, and ``get_client``.

        Raises:
            ConfigError: a setting required by the chosen auth method is
                missing.
            NotImplementedError: ``auth`` names an unsupported method.

        Side effects:
            Sets ``self.upload_method`` (defaults to ``"POST"``).
        """
        import aiohttp
        from fsspec.asyn import fsspec_loop

        from dvc.config import ConfigError

        credentials = {}
        # NOTE(review): fsspec's HTTPFileSystem documents this option as
        # ``client_kwargs`` -- confirm that whatever consumes this dict
        # expects the ``client_args`` spelling.
        client_args = credentials.setdefault("client_args", {})

        if config.get("auth"):
            user = config.get("user")
            password = config.get("password")
            custom_auth_header = config.get("custom_auth_header")

            if password is None and config.get("ask_password"):
                from urllib.parse import urlparse

                # ``ask_password(host, user)`` expects a host, so prompt
                # for the hostname rather than the complete remote URL.
                # Fall back to the raw value if no hostname can be parsed.
                url = config.get("url")
                host = (urlparse(url).hostname if url else None) or url
                password = ask_password(host, user or "custom")

            auth_method = config["auth"]
            if auth_method == "basic":
                if user is None or password is None:
                    raise ConfigError(
                        "HTTP 'basic' authentication require both "
                        "'user' and 'password'"
                    )

                client_args["auth"] = aiohttp.BasicAuth(user, password)
            elif auth_method == "custom":
                if custom_auth_header is None or password is None:
                    raise ConfigError(
                        "HTTP 'custom' authentication require both "
                        "'custom_auth_header' and 'password'"
                    )
                credentials["headers"] = {custom_auth_header: password}
            else:
                raise NotImplementedError(
                    f"Auth method {auth_method!r} is not supported."
                )

        if "ssl_verify" in config:
            # The connector is instantiated inside fsspec's loop context.
            with fsspec_loop():
                client_args["connector"] = aiohttp.TCPConnector(
                    ssl=make_context(config["ssl_verify"])
                )

        credentials["get_client"] = self.get_client
        # NOTE(review): ``upload_method`` is not read anywhere in this
        # class; presumably it is consumed by a base class or mixin.
        self.upload_method = config.get("method", "POST")
        return credentials

    async def get_client(self, **kwargs):
        """Create the retrying HTTP client.

        ``kwargs`` are forwarded to ``aiohttp_retry.RetryClient`` after
        ``retry_options`` has been set (overwriting any caller-supplied
        value) from the class-level ``SESSION_RETRIES``,
        ``SESSION_BACKOFF_FACTOR`` and ``REQUEST_TIMEOUT`` constants.
        """
        from aiohttp_retry import ExponentialRetry, RetryClient

        kwargs["retry_options"] = ExponentialRetry(
            attempts=self.SESSION_RETRIES,
            factor=self.SESSION_BACKOFF_FACTOR,
            max_timeout=self.REQUEST_TIMEOUT,
        )

        return RetryClient(**kwargs)

    @cached_property
    def fs(self):
        """Underlying fsspec ``HTTPFileSystem``, created once per instance."""
        from fsspec.implementations.http import (
            HTTPFileSystem as _HTTPFileSystem,
        )

        # NOTE(review): only the timeout is forwarded here. Nothing in this
        # class hands the dict built by ``_prepare_credentials`` (auth,
        # headers, connector, get_client) to fsspec -- confirm that the
        # FSSpecWrapper base takes care of it, otherwise those settings are
        # silently ignored.
        return _HTTPFileSystem(timeout=self.REQUEST_TIMEOUT)

    def _entry_hook(self, entry):
        """Store the entry's ``ETag`` (or, failing that, ``Content-MD5``)
        under ``"checksum"`` and return the same, mutated, entry."""
        entry["checksum"] = entry.get("ETag") or entry.get("Content-MD5")
        return entry
0 commit comments