diff --git a/dvc/fs/utils.py b/dvc/fs/utils.py index c585e6fa4c..6ef75a31cb 100644 --- a/dvc/fs/utils.py +++ b/dvc/fs/utils.py @@ -1,6 +1,7 @@ import logging import os -from typing import TYPE_CHECKING +from io import BytesIO +from typing import TYPE_CHECKING, BinaryIO, Union from .local import LocalFileSystem @@ -18,9 +19,25 @@ def transfer( to_fs: "BaseFileSystem", to_info: "DvcPath", move: bool = False, + content: Union[bytes, BinaryIO] = None, ) -> None: use_move = isinstance(from_fs, type(to_fs)) and move try: + if content: + if isinstance(content, bytes): + fobj: BinaryIO = BytesIO(content) + size = len(content) + else: + fobj = content + size = from_fs.getsize(from_info) + + desc = ( + from_info.name + if isinstance(from_info, from_fs.PATH_CLS) + else from_info + ) + return to_fs.upload_fobj(fobj, to_info, size=size, desc=desc) + if use_move: return to_fs.move(from_info, to_info) @@ -33,7 +50,8 @@ def transfer( return from_fs.download_file(from_info, to_info) with from_fs.open(from_info, mode="rb") as fobj: - return to_fs.upload_fobj(fobj, to_info) + size = from_fs.getsize(from_info) + return to_fs.upload_fobj(fobj, to_info, size=size) except OSError as exc: # If the target file already exists, we are going to simply # ignore the exception (#4992). diff --git a/dvc/objects/db/reference.py b/dvc/objects/db/reference.py index 0c1654ac03..36e2cf8320 100644 --- a/dvc/objects/db/reference.py +++ b/dvc/objects/db/reference.py @@ -1,6 +1,4 @@ -import io import logging -import os from typing import TYPE_CHECKING, Dict from dvc.scheme import Schemes @@ -60,25 +58,19 @@ def _add_file( hash_info: "HashInfo", move: bool = False, ): + from dvc import fs + self.makedirs(to_info.parent) if hash_info.isdir: return super()._add_file( - from_fs, from_info, to_info, hash_info, move + from_fs, from_info, to_info, hash_info, move=move ) + ref_file = ReferenceHashFile(from_info, from_fs, hash_info) self._obj_cache[hash_info] = ref_file - ref_fobj = io.BytesIO(ref_file.to_bytes()) - ref_fobj.seek(0) - try: - self.fs.upload_fobj(ref_fobj, to_info) - except OSError as exc: - if isinstance(exc, FileExistsError) or ( - os.name == "nt" - and exc.__context__ - and isinstance(exc.__context__, FileExistsError) - ): - logger.debug("'%s' file already exists, skipping", to_info) - else: - raise + content = ref_file.to_bytes() + fs.utils.transfer( + from_fs, from_info, self.fs, to_info, move=move, content=content + ) if from_fs.scheme != Schemes.LOCAL: self._fs_cache[ReferenceHashFile.config_tuple(from_fs)] = from_fs diff --git a/dvc/objects/stage.py b/dvc/objects/stage.py index c86fc1b485..0342ce5bbd 100644 --- a/dvc/objects/stage.py +++ b/dvc/objects/stage.py @@ -28,18 +28,17 @@ def _upload_file(path_info, fs, odb, upload_odb): + from dvc.fs.utils import transfer from dvc.utils import tmp_fname from dvc.utils.stream import HashedStreamReader tmp_info = upload_odb.path_info / tmp_fname() with fs.open(path_info, mode="rb", chunk_size=fs.CHUNK_SIZE) as stream: - stream = HashedStreamReader(stream) - upload_odb.fs.upload_fobj( - stream, tmp_info, desc=path_info.name, size=fs.getsize(path_info) - ) + wrapped = HashedStreamReader(stream) + transfer(fs, path_info, upload_odb.fs, tmp_info, content=wrapped) - odb.add(tmp_info, upload_odb.fs, stream.hash_info) - return path_info, odb.get(stream.hash_info) + odb.add(tmp_info, upload_odb.fs, wrapped.hash_info) + return path_info, odb.get(wrapped.hash_info) def _get_file_hash(path_info, fs, name):