Skip to content

Commit 9426ade

Browse files
committed
fs.upload: add support for optional callback
Also removes the `_upload` methods from the other filesystems and replaces them with an fsspec-compatible `put_file`, which supports fsspec callbacks even when the filesystem is not fsspec-based.
1 parent 6b582bd commit 9426ade

File tree

11 files changed

+130
-105
lines changed

11 files changed

+130
-105
lines changed

dvc/fs/base.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
1+
import contextlib
12
import logging
23
from concurrent.futures import ThreadPoolExecutor, as_completed
34
from functools import partial, partialmethod
45
from multiprocessing import cpu_count
56
from typing import Any, ClassVar, Dict, FrozenSet, Optional
67

8+
from funcy import cached_property
9+
710
from dvc.exceptions import DvcException
811
from dvc.path_info import URLInfo
912
from dvc.progress import Tqdm
13+
from dvc.ui import ui
1014
from dvc.utils import tmp_fname
1115
from dvc.utils.fs import makedirs, move
1216

@@ -217,8 +221,22 @@ def is_dir_hash(cls, hash_):
217221
return False
218222
return hash_.endswith(cls.CHECKSUM_DIR_SUFFIX)
219223

220-
def upload(self, from_info, to_info, name=None, no_progress_bar=False):
221-
if not hasattr(self, "_upload"):
224+
@cached_property
225+
def _local_fs(self):
226+
from dvc.fs import LocalFileSystem
227+
228+
return LocalFileSystem()
229+
230+
def upload(
231+
self,
232+
from_info,
233+
to_info,
234+
name=None,
235+
callback=None,
236+
no_progress_bar=False,
237+
**pbar_kw,
238+
):
239+
if not hasattr(self, "put_file"):
222240
raise RemoteActionNotImplemented("upload", self.scheme)
223241

224242
if to_info.scheme != self.scheme:
@@ -231,12 +249,23 @@ def upload(self, from_info, to_info, name=None, no_progress_bar=False):
231249

232250
name = name or from_info.name
233251

234-
self._upload( # noqa, pylint: disable=no-member
235-
from_info.fspath,
236-
to_info,
237-
name=name,
238-
no_progress_bar=no_progress_bar,
239-
)
252+
stack = contextlib.ExitStack()
253+
fs = self._local_fs
254+
if not callback:
255+
pbar = ui.progress(
256+
desc=name,
257+
disable=no_progress_bar,
258+
bytes=True,
259+
total=-1,
260+
**pbar_kw,
261+
)
262+
stack.enter_context(pbar)
263+
callback = pbar.as_callback(fs, from_info)
264+
265+
with stack:
266+
src = from_info.fspath
267+
# pylint: disable=no-member
268+
return self.put_file(src, to_info, callback=callback)
240269

241270
def upload_fobj(
242271
self, fobj, to_info, no_progress_bar=False, size=None, **pbar_args

dvc/fs/fsspec_wrapper.py

Lines changed: 9 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from funcy import cached_property
66

7-
from dvc.progress import Tqdm
7+
from dvc.progress import DEFAULT_CALLBACK, Tqdm
88

99
from .base import BaseFileSystem
1010
from .local import LocalFileSystem
@@ -130,26 +130,19 @@ def makedirs(self, path_info, **kwargs):
130130
self._with_bucket(path_info), exist_ok=kwargs.pop("exist_ok", True)
131131
)
132132

133+
def put_file(
134+
self, from_file, to_info, callback=DEFAULT_CALLBACK, **kwargs
135+
):
136+
self.fs.put_file(
137+
from_file, self._with_bucket(to_info), callback=callback, **kwargs
138+
)
139+
self.fs.invalidate_cache(self._with_bucket(to_info.parent))
140+
133141
def _upload_fobj(self, fobj, to_info, size=None):
134142
self.makedirs(to_info.parent)
135143
with self.open(to_info, "wb") as fdest:
136144
shutil.copyfileobj(fobj, fdest, length=fdest.blocksize)
137145

138-
def _upload(
139-
self, from_file, to_info, name=None, no_progress_bar=False, **kwargs
140-
):
141-
self.makedirs(to_info.parent)
142-
size = os.path.getsize(from_file)
143-
with open(from_file, "rb") as fobj:
144-
self.upload_fobj(
145-
fobj,
146-
to_info,
147-
size=size,
148-
desc=name,
149-
no_progress_bar=no_progress_bar,
150-
)
151-
self.fs.invalidate_cache(self._with_bucket(to_info.parent))
152-
153146
def _download(
154147
self, from_info, to_file, name=None, no_progress_bar=False, **pbar_args
155148
):
@@ -261,23 +254,6 @@ class CallbackMixin:
261254
"""Use the native ``get_file()``/``put_file()`` APIs
262255
if the target filesystem supports callbacks."""
263256

264-
def _upload(
265-
self, from_file, to_info, name=None, no_progress_bar=False, **pbar_args
266-
):
267-
with Tqdm(
268-
desc=name,
269-
disable=no_progress_bar,
270-
bytes=True,
271-
total=-1,
272-
**pbar_args,
273-
) as pbar:
274-
self.fs.put_file(
275-
os.fspath(from_file),
276-
self._with_bucket(to_info),
277-
callback=pbar.as_callback(_LOCAL_FS, from_file),
278-
)
279-
self.fs.invalidate_cache(self._with_bucket(to_info.parent))
280-
281257
def _download(
282258
self, from_info, to_file, name=None, no_progress_bar=False, **pbar_args
283259
):

dvc/fs/hdfs.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from contextlib import closing, contextmanager
1010

1111
from dvc.hash_info import HashInfo
12-
from dvc.progress import Tqdm
12+
from dvc.progress import DEFAULT_CALLBACK, Tqdm
1313
from dvc.scheme import Schemes
1414
from dvc.utils import fix_env, tmp_fname
1515

@@ -249,22 +249,22 @@ def _upload_fobj(self, fobj, to_info, **kwargs):
249249
with hdfs.open_output_stream(to_info.path) as fdest:
250250
shutil.copyfileobj(fobj, fdest, self.BLOCK_SIZE)
251251

252-
def _upload(
253-
self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs
252+
def put_file(
253+
self, from_file, to_info, callback=DEFAULT_CALLBACK, **kwargs
254254
):
255+
from tqdm.utils import CallbackIOWrapper
256+
255257
with self.hdfs(to_info) as hdfs:
258+
hdfs.create_dir(to_info.parent.path)
259+
256260
tmp_file = tmp_fname(to_info.path)
257-
total = os.path.getsize(from_file)
261+
total: int = os.path.getsize(from_file)
262+
callback.set_size(total)
263+
258264
with open(from_file, "rb") as fobj:
259-
with Tqdm.wrapattr(
260-
fobj,
261-
"read",
262-
desc=name,
263-
total=total,
264-
disable=no_progress_bar,
265-
) as wrapped:
266-
with hdfs.open_output_stream(tmp_file) as sobj:
267-
shutil.copyfileobj(wrapped, sobj, self.BLOCK_SIZE)
265+
wrapped = CallbackIOWrapper(callback.relative_update, fobj)
266+
with hdfs.open_output_stream(tmp_file) as sobj:
267+
shutil.copyfileobj(wrapped, sobj, self.BLOCK_SIZE)
268268
hdfs.move(tmp_file, to_info.path)
269269

270270
def _download(

dvc/fs/local.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from dvc.utils import is_exec, tmp_fname
88
from dvc.utils.fs import copy_fobj_to_file, copyfile, makedirs, move, remove
99

10+
from ..progress import DEFAULT_CALLBACK
1011
from .base import BaseFileSystem
1112

1213
logger = logging.getLogger(__name__)
@@ -158,15 +159,12 @@ def reflink(self, from_info, to_info):
158159
def info(self, path_info):
159160
return self.fs.info(path_info)
160161

161-
def _upload(
162-
self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs
162+
def put_file(
163+
self, from_file, to_info, callback=DEFAULT_CALLBACK, **kwargs
163164
):
164165
makedirs(to_info.parent, exist_ok=True)
165-
166166
tmp_file = tmp_fname(to_info)
167-
copyfile(
168-
from_file, tmp_file, name=name, no_progress_bar=no_progress_bar
169-
)
167+
copyfile(from_file, tmp_file, callback=callback)
170168
os.replace(tmp_file, to_info)
171169

172170
@staticmethod

dvc/fs/s3.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from funcy import cached_property, wrap_prop
77

88
from dvc.path_info import CloudURLInfo
9-
from dvc.progress import Tqdm
9+
from dvc.progress import DEFAULT_CALLBACK, Tqdm
1010
from dvc.scheme import Schemes
1111

1212
from .fsspec_wrapper import ObjectFSWrapper
@@ -204,24 +204,18 @@ def _get_obj(self, path_info):
204204
return bucket.Object(path_info.path)
205205

206206
@_translate_exceptions
207-
def _upload(
208-
self, from_file, to_info, name=None, no_progress_bar=False, **pbar_args
207+
def put_file(
208+
self, from_file, to_info, callback=DEFAULT_CALLBACK, **kwargs
209209
):
210-
total = os.path.getsize(from_file)
211-
with Tqdm(
212-
disable=no_progress_bar,
213-
total=total,
214-
bytes=True,
215-
desc=name,
216-
**pbar_args,
217-
) as pbar:
218-
obj = self._get_obj(to_info)
219-
obj.upload_file(
220-
from_file,
221-
Callback=pbar.update,
222-
ExtraArgs=self.fs_args.get("s3_additional_kwargs"),
223-
Config=self._transfer_config,
224-
)
210+
callback.set_size(os.path.getsize(from_file))
211+
212+
obj = self._get_obj(to_info)
213+
obj.upload_file(
214+
from_file,
215+
Callback=callback.relative_update,
216+
ExtraArgs=self.fs_args.get("s3_additional_kwargs"),
217+
Config=self._transfer_config,
218+
)
225219
self.fs.invalidate_cache(self._with_bucket(to_info.parent))
226220

227221
@_translate_exceptions

dvc/fs/ssh.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from dvc.scheme import Schemes
99
from dvc.utils.fs import as_atomic
1010

11+
from ..progress import DEFAULT_CALLBACK
1112
from .fsspec_wrapper import CallbackMixin, FSSpecWrapper
1213

1314
_SSH_TIMEOUT = 60 * 30
@@ -123,6 +124,8 @@ def _upload_fobj(self, fobj, to_info, *args, **kwargs):
123124
with as_atomic(self, to_info) as tmp_file:
124125
super()._upload_fobj(fobj, tmp_file, *args, **kwargs)
125126

126-
def _upload(self, from_file, to_info, *args, **kwargs):
127+
def put_file(
128+
self, from_file, to_info, callback=DEFAULT_CALLBACK, **kwargs
129+
):
127130
with as_atomic(self, to_info) as tmp_file:
128-
super()._upload(from_file, tmp_file, *args, **kwargs)
131+
super().put_file(from_file, tmp_file, callback=callback, **kwargs)

dvc/fs/webhdfs.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from dvc.hash_info import HashInfo
1111
from dvc.path_info import CloudURLInfo
12-
from dvc.progress import Tqdm
12+
from dvc.progress import DEFAULT_CALLBACK, Tqdm
1313
from dvc.scheme import Schemes
1414

1515
from .base import BaseFileSystem
@@ -29,6 +29,15 @@ def update(_, bytes_transfered):
2929
return update
3030

3131

32+
def update_callback(callback, total):
33+
def update(_, bytes_transfered):
34+
if bytes_transfered == -1:
35+
return callback.absolute_update(total)
36+
return callback.relative_update(bytes_transfered)
37+
38+
return update
39+
40+
3241
class WebHDFSFileSystem(BaseFileSystem): # pylint:disable=abstract-method
3342
scheme = Schemes.WEBHDFS
3443
PATH_CLS = CloudURLInfo
@@ -145,19 +154,19 @@ def _upload_fobj(self, fobj, to_info, **kwargs):
145154
with self.hdfs_client.write(to_info.path) as fdest:
146155
shutil.copyfileobj(fobj, fdest)
147156

148-
def _upload(
149-
self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs
157+
def put_file(
158+
self, from_file, to_info, callback=DEFAULT_CALLBACK, **kwargs
150159
):
151160
total = os.path.getsize(from_file)
152-
with Tqdm(
153-
desc=name, total=total, disable=no_progress_bar, bytes=True
154-
) as pbar:
155-
self.hdfs_client.upload(
156-
to_info.path,
157-
from_file,
158-
overwrite=True,
159-
progress=update_pbar(pbar, total),
160-
)
161+
callback.set_size(total)
162+
163+
self.hdfs_client.makedirs(to_info.parent.path)
164+
return self.hdfs_client.upload(
165+
to_info.path,
166+
from_file,
167+
overwrite=True,
168+
progress=update_callback(callback, total),
169+
)
161170

162171
def _download(
163172
self, from_info, to_file, name=None, no_progress_bar=False, **_kwargs

dvc/progress.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
"""Manages progress bars for DVC repo."""
2-
32
import logging
43
import sys
54
from threading import RLock
@@ -188,3 +187,20 @@ def relative_update(self, inc=1):
188187

189188
def absolute_update(self, value):
190189
self.progress_bar.update_to(value)
190+
191+
192+
def tdqm_or_callback_wrapped(
193+
fobj, method, total, callback=None, **pbar_kwargs
194+
):
195+
if callback:
196+
from funcy import nullcontext
197+
from tqdm.utils import CallbackIOWrapper
198+
199+
callback.set_size(total)
200+
wrapper = CallbackIOWrapper(callback.relative_update, fobj, method)
201+
return nullcontext(wrapper)
202+
203+
return Tqdm.wrapattr(fobj, method, total=total, bytes=True, **pbar_kwargs)
204+
205+
206+
# Default no-op progress callback for the ``put_file`` implementations.
# Callers invoke methods such as ``set_size()`` and ``relative_update()`` on
# the callback they receive, so this must be a NoOpCallback *instance*; binding
# the class itself would make those calls fail with a missing-``self`` TypeError.
DEFAULT_CALLBACK = fsspec.callbacks.NoOpCallback()

dvc/utils/fs.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -193,10 +193,8 @@ def makedirs(path, exist_ok=False, mode=None):
193193
logger.trace("failed to chmod '%o' '%s'", mode, path, exc_info=True)
194194

195195

196-
def copyfile(src, dest, no_progress_bar=False, name=None):
196+
def copyfile(src, dest, callback=None, no_progress_bar=False, name=None):
197197
"""Copy file with progress bar"""
198-
from dvc.progress import Tqdm
199-
200198
name = name if name else os.path.basename(dest)
201199
total = os.stat(src).st_size
202200

@@ -206,20 +204,22 @@ def copyfile(src, dest, no_progress_bar=False, name=None):
206204
try:
207205
System.reflink(src, dest)
208206
except DvcException:
207+
from dvc.progress import tdqm_or_callback_wrapped
208+
209209
with open(src, "rb") as fsrc, open(dest, "wb+") as fdest:
210-
with Tqdm.wrapattr(
210+
with tdqm_or_callback_wrapped(
211211
fdest,
212212
"write",
213-
desc=name,
213+
total,
214+
callback=callback,
214215
disable=no_progress_bar,
215-
total=total,
216-
bytes=True,
217-
) as fdest_wrapped:
216+
desc=name,
217+
) as wrapped:
218218
while True:
219219
buf = fsrc.read(LOCAL_CHUNK_SIZE)
220220
if not buf:
221221
break
222-
fdest_wrapped.write(buf)
222+
wrapped.write(buf)
223223

224224

225225
def copy_fobj_to_file(fsrc, dest):

0 commit comments

Comments
 (0)