diff --git a/dvc/cache/base.py b/dvc/cache/base.py index 4ec3683d26..759d66a48d 100644 --- a/dvc/cache/base.py +++ b/dvc/cache/base.py @@ -3,6 +3,7 @@ import json import logging import os +from concurrent import futures from concurrent.futures import ThreadPoolExecutor from copy import copy from typing import Optional @@ -247,6 +248,116 @@ def _save_file(self, path_info, tree, hash_info, save_link=True, **kwargs): self.tree.state.save(cache_info, hash_info) + def _transfer_file_as_whole(self, from_tree, from_info): + from dvc.utils import tmp_fname + + # When we can't use the chunked upload, we have to first download + # and then calculate the hash as if it were a local file and then + # upload it. + local_tree = self.repo.cache.local.tree + local_info = local_tree.path_info / tmp_fname() + + from_tree.download(from_info, local_info) + hash_info = local_tree.get_file_hash(local_info) + + self.tree.upload( + local_info, + self.tree.hash_to_path_info(hash_info.value), + name=from_info.name, + ) + return hash_info + + def _transfer_file_as_chunked(self, from_tree, from_info): + from dvc.utils import tmp_fname + from dvc.utils.stream import HashedStreamReader + + tmp_info = self.tree.path_info / tmp_fname() + with from_tree.open( + from_info, mode="rb", chunk_size=from_tree.CHUNK_SIZE + ) as stream: + stream_reader = HashedStreamReader(stream) + # Since we don't know the hash beforehand, we'll + # upload it to a temporary location and then move + # it. + self.tree.upload_fobj( + stream_reader, + tmp_info, + total=from_tree.getsize(from_info), + desc=from_info.name, + ) + + hash_info = stream_reader.hash_info + self.tree.move(tmp_info, self.tree.hash_to_path_info(hash_info.value)) + return hash_info + + def _transfer_file(self, from_tree, from_info): + try: + hash_info = self._transfer_file_as_chunked(from_tree, from_info) + except RemoteActionNotImplemented: + hash_info = self._transfer_file_as_whole(from_tree, from_info) + + return hash_info + + def _transfer_directory_contents(self, from_tree, from_info, jobs, pbar): + rel_path_infos = {} + from_infos = from_tree.walk_files(from_info) + + def create_tasks(executor, amount): + for entry_info in itertools.islice(from_infos, amount): + pbar.total += 1 + task = executor.submit( + pbar.wrap_fn(self._transfer_file), from_tree, entry_info + ) + rel_path_infos[task] = entry_info.relative_to(from_info) + yield task + + pbar.total = 0 + with ThreadPoolExecutor(max_workers=jobs) as executor: + tasks = set(create_tasks(executor, jobs * 5)) + + while tasks: + done, tasks = futures.wait( + tasks, return_when=futures.FIRST_COMPLETED + ) + tasks.update(create_tasks(executor, len(done))) + for task in done: + yield rel_path_infos.pop(task), task.result() + + def _transfer_directory( + self, from_tree, from_info, jobs, no_progress_bar=False + ): + dir_info = DirInfo() + + with Tqdm(total=1, unit="Files", disable=no_progress_bar) as pbar: + for entry_info, entry_hash in self._transfer_directory_contents( + from_tree, from_info, jobs, pbar + ): + dir_info.trie[entry_info.parts] = entry_hash + + local_cache = self.repo.cache.local + ( + hash_info, + to_info, + ) = local_cache._get_dir_info_hash( # pylint: disable=protected-access + dir_info + ) + + self.tree.upload(to_info, self.tree.hash_to_path_info(hash_info.value)) + return hash_info + + def transfer(self, from_tree, from_info, jobs=None, no_progress_bar=False): + jobs = jobs or min((from_tree.jobs, self.tree.jobs)) + + if from_tree.isdir(from_info): + return self._transfer_directory( + from_tree, + 
from_info, + jobs=jobs, + no_progress_bar=no_progress_bar, + ) + else: + return self._transfer_file(from_tree, from_info) + def _cache_is_copy(self, path_info): """Checks whether cache uses copies.""" if self.cache_type_confirmed: diff --git a/dvc/command/add.py b/dvc/command/add.py index db112162f9..af359b6af7 100644 --- a/dvc/command/add.py +++ b/dvc/command/add.py @@ -26,6 +26,10 @@ def run(self): external=self.args.external, glob=self.args.glob, desc=self.args.desc, + out=self.args.out, + remote=self.args.remote, + to_remote=self.args.to_remote, + jobs=self.args.jobs, ) except DvcException: @@ -74,6 +78,35 @@ def add_parser(subparsers, parent_parser): help="Specify name of the DVC-file this command will generate.", metavar="", ) + parser.add_argument( + "-o", + "--out", + help="Destination path to put files to.", + metavar="", + ) + parser.add_argument( + "--to-remote", + action="store_true", + default=False, + help="Download it directly to the remote", + ) + parser.add_argument( + "-r", + "--remote", + help="Remote storage to download to", + metavar="", + ) + parser.add_argument( + "-j", + "--jobs", + type=int, + help=( + "Number of jobs to run simultaneously. " + "The default value is 4 * cpu_count(). " + "For SSH remotes, the default is 4. " + ), + metavar="", + ) parser.add_argument( "--desc", type=str, diff --git a/dvc/command/imp_url.py b/dvc/command/imp_url.py index a5b5767f78..b11245da6c 100644 --- a/dvc/command/imp_url.py +++ b/dvc/command/imp_url.py @@ -16,7 +16,10 @@ def run(self): out=self.args.out, fname=self.args.file, no_exec=self.args.no_exec, + remote=self.args.remote, + to_remote=self.args.to_remote, desc=self.args.desc, + jobs=self.args.jobs, ) except DvcException: logger.exception( @@ -68,6 +71,29 @@ def add_parser(subparsers, parent_parser): default=False, help="Only create DVC-file without actually downloading it.", ) + import_parser.add_argument( + "--to-remote", + action="store_true", + default=False, + help="Download it directly to the remote", + ) + import_parser.add_argument( + "-r", + "--remote", + help="Remote storage to download to", + metavar="", + ) + import_parser.add_argument( + "-j", + "--jobs", + type=int, + help=( + "Number of jobs to run simultaneously. " + "The default value is 4 * cpu_count(). " + "For SSH remotes, the default is 4. " + ), + metavar="", + ) import_parser.add_argument( "--desc", type=str, diff --git a/dvc/data_cloud.py b/dvc/data_cloud.py index c4dab8e8cd..93539d7621 100644 --- a/dvc/data_cloud.py +++ b/dvc/data_cloud.py @@ -92,6 +92,24 @@ def pull( show_checksums=show_checksums, ) + def transfer(self, source, jobs=None, remote=None, command=None): + """Transfer data items in a cloud-agnostic way. + + Args: + source (str): url for the source location. + jobs (int): number of jobs that can be running simultaneously. + remote (dvc.remote.base.BaseRemote): optional remote to compare + cache to. By default remote from core.remote config option + is used. + command (bool): the command which is benefitting from this function + (to be used for reporting better error messages). 
+ """ + from dvc.tree import get_cloud_tree + + from_tree = get_cloud_tree(self.repo, url=source) + remote = self.get_remote(remote, command) + return remote.transfer(from_tree, from_tree.path_info, jobs=jobs) + def status( self, cache, diff --git a/dvc/istextfile.py b/dvc/istextfile.py index e4c29fdcf1..917a6e3541 100644 --- a/dvc/istextfile.py +++ b/dvc/istextfile.py @@ -7,19 +7,7 @@ TEXT_CHARS = bytes(range(32, 127)) + b"\n\r\t\f\b" -def istextfile(fname, blocksize=512, tree=None): - """ Uses heuristics to guess whether the given file is text or binary, - by reading a single block of bytes from the file. - If more than 30% of the chars in the block are non-text, or there - are NUL ('\x00') bytes in the block, assume this is a binary file. - """ - if tree: - open_func = tree.open - else: - open_func = open - with open_func(fname, "rb") as fobj: - block = fobj.read(blocksize) - +def istextblock(block): if not block: # An empty file is considered a valid text file return True @@ -32,3 +20,18 @@ def istextfile(fname, blocksize=512, tree=None): # occurrences of TEXT_CHARS from the block nontext = block.translate(None, TEXT_CHARS) return float(len(nontext)) / len(block) <= 0.30 + + +def istextfile(fname, blocksize=512, tree=None): + """ Uses heuristics to guess whether the given file is text or binary, + by reading a single block of bytes from the file. + If more than 30% of the chars in the block are non-text, or there + are NUL ('\x00') bytes in the block, assume this is a binary file. + """ + if tree: + open_func = tree.open + else: + open_func = open + with open_func(fname, "rb") as fobj: + block = fobj.read(blocksize) + return istextblock(block) diff --git a/dvc/remote/base.py b/dvc/remote/base.py index eb2d2f86b2..773cc53e2a 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -472,6 +472,11 @@ def pull(self, cache, named_cache, jobs=None, show_checksums=False): return ret + def transfer(self, from_tree, from_info, jobs=None, no_progress_bar=False): + return self.cache.transfer( + from_tree, from_info, jobs=jobs, no_progress_bar=no_progress_bar + ) + @staticmethod def _log_missing_caches(hash_info_dict): missing_caches = [ diff --git a/dvc/repo/add.py b/dvc/repo/add.py index 5e59114f55..ac93f35ad9 100644 --- a/dvc/repo/add.py +++ b/dvc/repo/add.py @@ -6,13 +6,14 @@ from ..exceptions import ( CacheLinkError, + InvalidArgumentError, OutputDuplicationError, OverlappingOutputPathsError, RecursiveAddingWhileUsingFilename, ) from ..progress import Tqdm from ..repo.scm_context import scm_context -from ..utils import LARGE_DIR_SIZE, glob_targets, resolve_paths +from ..utils import LARGE_DIR_SIZE, glob_targets, resolve_output, resolve_paths from . 
import locked if TYPE_CHECKING: @@ -24,12 +25,13 @@ @locked @scm_context -def add( +def add( # noqa: C901 repo, targets: "TargetType", recursive=False, no_commit=False, fname=None, + to_remote=False, **kwargs, ): from dvc.utils.collections import ensure_list @@ -38,6 +40,28 @@ def add( raise RecursiveAddingWhileUsingFilename() targets = ensure_list(targets) + + invalid_opt = None + if to_remote: + message = "{option} can't be used with --to-remote" + if len(targets) != 1: + invalid_opt = "multiple targets" + elif no_commit: + invalid_opt = "--no-commit option" + elif recursive: + invalid_opt = "--recursive option" + else: + message = "{option} can't be used without --to-remote" + if kwargs.get("out"): + invalid_opt = "--out" + elif kwargs.get("remote"): + invalid_opt = "--remote" + elif kwargs.get("jobs"): + invalid_opt = "--jobs" + + if invalid_opt is not None: + raise InvalidArgumentError(message.format(option=invalid_opt)) + link_failures = [] stages_list = [] num_targets = len(targets) @@ -64,7 +88,12 @@ def add( ) stages = _create_stages( - repo, sub_targets, fname, pbar=pbar, **kwargs, + repo, + sub_targets, + fname, + pbar=pbar, + to_remote=to_remote, + **kwargs, ) try: @@ -89,7 +118,15 @@ def add( ) link_failures.extend( - _process_stages(repo, stages, no_commit, pbar) + _process_stages( + repo, + sub_targets, + stages, + no_commit, + pbar, + to_remote, + **kwargs, + ) ) stages_list += stages @@ -110,12 +147,29 @@ def add( return stages_list -def _process_stages(repo, stages, no_commit, pbar): +def _process_stages( + repo, sub_targets, stages, no_commit, pbar, to_remote, **kwargs +): link_failures = [] from dvc.dvcfile import Dvcfile from ..output.base import OutputDoesNotExistError + if to_remote: + # Already verified in the add() + assert len(stages) == 1 + assert len(sub_targets) == 1 + + [stage] = stages + stage.outs[0].hash_info = repo.cloud.transfer( + sub_targets[0], + jobs=kwargs.get("jobs"), + remote=kwargs.get("remote"), + command="add", + ) + Dvcfile(repo, stage.path).dump(stage) + return link_failures + with Tqdm( total=len(stages), desc="Processing", @@ -162,7 +216,15 @@ def _find_all_targets(repo, target, recursive): def _create_stages( - repo, targets, fname, pbar=None, external=False, glob=False, desc=None, + repo, + targets, + fname, + to_remote=False, + pbar=None, + external=False, + glob=False, + desc=None, + **kwargs, ): from dvc.dvcfile import Dvcfile from dvc.stage import Stage, create_stage, restore_meta @@ -176,6 +238,8 @@ def _create_stages( disable=len(expanded_targets) < LARGE_DIR_SIZE, unit="file", ): + if to_remote: + out = resolve_output(out, kwargs.get("out")) path, wdir, out = resolve_paths(repo, out) stage = create_stage( Stage, diff --git a/dvc/repo/imp_url.py b/dvc/repo/imp_url.py index 71838b1e2f..dc607efec4 100644 --- a/dvc/repo/imp_url.py +++ b/dvc/repo/imp_url.py @@ -4,7 +4,7 @@ from dvc.utils import relpath, resolve_output, resolve_paths from dvc.utils.fs import path_isin -from ..exceptions import OutputDuplicationError +from ..exceptions import InvalidArgumentError, OutputDuplicationError from . 
import locked @@ -18,6 +18,8 @@ def imp_url( erepo=None, frozen=True, no_exec=False, + remote=None, + to_remote=False, desc=None, jobs=None, ): @@ -27,6 +29,16 @@ def imp_url( out = resolve_output(url, out) path, wdir, out = resolve_paths(self, out) + if to_remote and no_exec: + raise InvalidArgumentError( + "--no-exec can't be combined with --to-remote" + ) + + if not to_remote and remote: + raise InvalidArgumentError( + "--remote can't be used without --to-remote" + ) + # NOTE: when user is importing something from within their own repository if ( erepo is None @@ -61,6 +73,10 @@ def imp_url( if no_exec: stage.ignore_outs() + elif to_remote: + stage.outs[0].hash_info = self.cloud.transfer( + url, jobs=jobs, remote=remote, command="import-url" + ) else: stage.run(jobs=jobs) diff --git a/dvc/tree/base.py b/dvc/tree/base.py index 311b9263ac..7b4ed1be3f 100644 --- a/dvc/tree/base.py +++ b/dvc/tree/base.py @@ -59,6 +59,7 @@ class BaseTree: TRAVERSE_PREFIX_LEN = 3 TRAVERSE_THRESHOLD_SIZE = 500000 CAN_TRAVERSE = True + CHUNK_SIZE = 64 * 1024 * 1024 # 64 MiB SHARED_MODE_MAP = {None: (None, None), "group": (None, None)} PARAM_CHECKSUM: ClassVar[Optional[str]] = None @@ -168,12 +169,12 @@ def dir_mode(self): def cache(self): return getattr(self.repo.cache, self.scheme) - def open(self, path_info, mode: str = "r", encoding: str = None): + def open(self, path_info, mode: str = "r", encoding: str = None, **kwargs): if hasattr(self, "_generate_download_url"): # pylint:disable=no-member func = self._generate_download_url # type: ignore[attr-defined] get_url = partial(func, path_info) - return open_url(get_url, mode=mode, encoding=encoding) + return open_url(get_url, mode=mode, encoding=encoding, **kwargs) raise RemoteActionNotImplemented("open", self.scheme) @@ -238,7 +239,7 @@ def move(self, from_info, to_info, mode=None): def copy(self, from_info, to_info): raise RemoteActionNotImplemented("copy", self.scheme) - def copy_fobj(self, fobj, to_info): + def copy_fobj(self, fobj, to_info, chunk_size=None): raise RemoteActionNotImplemented("copy_fobj", self.scheme) def symlink(self, from_info, to_info): @@ -382,6 +383,9 @@ def upload( file_mode=file_mode, ) + def upload_fobj(self, fobj, to_info, no_progress_bar=False): + raise RemoteActionNotImplemented("upload_fobj", self.scheme) + def download( self, from_info, diff --git a/dvc/tree/dvc.py b/dvc/tree/dvc.py index e2e628acd5..14a094dbc5 100644 --- a/dvc/tree/dvc.py +++ b/dvc/tree/dvc.py @@ -65,8 +65,8 @@ def _get_granular_hash( return file_hash raise FileNotFoundError - def open( - self, path: PathInfo, mode="r", encoding="utf-8", remote=None + def open( # type: ignore + self, path: PathInfo, mode="r", encoding="utf-8", remote=None, **kwargs ): # pylint: disable=arguments-differ try: outs = self._find_outs(path, strict=False) diff --git a/dvc/tree/gdrive.py b/dvc/tree/gdrive.py index 00fbcfbd81..2007f8d156 100644 --- a/dvc/tree/gdrive.py +++ b/dvc/tree/gdrive.py @@ -389,7 +389,7 @@ def _gdrive_download_file( @contextmanager @_gdrive_retry - def open(self, path_info, mode="r", encoding=None): + def open(self, path_info, mode="r", encoding=None, **kwargs): assert mode in {"r", "rt", "rb"} item_id = self._get_item_id(path_info) diff --git a/dvc/tree/hdfs.py b/dvc/tree/hdfs.py index 2c2db68c37..c086539ca7 100644 --- a/dvc/tree/hdfs.py +++ b/dvc/tree/hdfs.py @@ -63,7 +63,7 @@ def close(self): ) @contextmanager - def open(self, path_info, mode="r", encoding=None): + def open(self, path_info, mode="r", encoding=None, **kwargs): assert mode in {"r", "rt", "rb"} 
try: diff --git a/dvc/tree/local.py b/dvc/tree/local.py index f8aa4e0690..d68b61cd9f 100644 --- a/dvc/tree/local.py +++ b/dvc/tree/local.py @@ -56,7 +56,7 @@ def dvcignore(self): return cls(self, root) @staticmethod - def open(path_info, mode="r", encoding=None): + def open(path_info, mode="r", encoding=None, **kwargs): return open(path_info, mode=mode, encoding=encoding) def exists(self, path_info, use_dvcignore=True): @@ -180,11 +180,11 @@ def copy(self, from_info, to_info): self.remove(tmp_info) raise - def copy_fobj(self, fobj, to_info): + def copy_fobj(self, fobj, to_info, chunk_size=None): self.makedirs(to_info.parent) tmp_info = to_info.parent / tmp_fname("") try: - copy_fobj_to_file(fobj, tmp_info) + copy_fobj_to_file(fobj, tmp_info, chunk_size=chunk_size) os.chmod(tmp_info, self.file_mode) os.rename(tmp_info, to_info) except Exception: @@ -282,6 +282,13 @@ def _upload( self.chmod(tmp_file, file_mode) os.replace(tmp_file, to_info) + def upload_fobj(self, fobj, to_info, no_progress_bar=False, **pbar_args): + from dvc.progress import Tqdm + + with Tqdm(bytes=True, disable=no_progress_bar, **pbar_args) as pbar: + with pbar.wrapattr(fobj, "read") as fobj: + self.copy_fobj(fobj, to_info, chunk_size=self.CHUNK_SIZE) + @staticmethod def _download( from_info, to_file, name=None, no_progress_bar=False, **_kwargs diff --git a/dvc/tree/s3.py b/dvc/tree/s3.py index 7b3eb0c10f..1aee7de6de 100644 --- a/dvc/tree/s3.py +++ b/dvc/tree/s3.py @@ -80,7 +80,10 @@ def s3(self): session = boto3.session.Session(**session_opts) return session.resource( - "s3", endpoint_url=self.endpoint_url, use_ssl=self.use_ssl + "s3", + endpoint_url=self.endpoint_url, + use_ssl=self.use_ssl, + config=boto3.session.Config(signature_version="s3v4"), ) @contextmanager @@ -362,6 +365,27 @@ def _upload( from_file, Callback=pbar.update, ExtraArgs=self.extra_args, ) + def upload_fobj(self, fobj, to_info, no_progress_bar=False, **pbar_args): + from boto3.s3.transfer import TransferConfig + + config = TransferConfig( + multipart_threshold=self.CHUNK_SIZE, + multipart_chunksize=self.CHUNK_SIZE, + max_concurrency=1, + use_threads=False, + ) + with self._get_s3() as s3: + with Tqdm( + disable=no_progress_bar, bytes=True, **pbar_args + ) as pbar: + s3.meta.client.upload_fileobj( + fobj, + to_info.bucket, + to_info.path, + Config=config, + Callback=pbar.update, + ) + def _download(self, from_info, to_file, name=None, no_progress_bar=False): with self._get_obj(from_info) as obj: with Tqdm( diff --git a/dvc/tree/ssh/__init__.py b/dvc/tree/ssh/__init__.py index b2391e4c90..0ce4088ae4 100644 --- a/dvc/tree/ssh/__init__.py +++ b/dvc/tree/ssh/__init__.py @@ -2,6 +2,7 @@ import io import logging import os +import shutil import threading from contextlib import closing, contextmanager from urllib.parse import urlparse @@ -147,7 +148,7 @@ def ssh(self, path_info): ) @contextmanager - def open(self, path_info, mode="r", encoding=None): + def open(self, path_info, mode="r", encoding=None, **kwargs): assert mode in {"r", "rt", "rb", "wb"} with self.ssh(path_info) as ssh, closing( @@ -260,6 +261,14 @@ def _download(self, from_info, to_file, name=None, no_progress_bar=False): no_progress_bar=no_progress_bar, ) + def upload_fobj(self, fobj, to_info, no_progress_bar=False, **pbar_args): + from dvc.progress import Tqdm + + with Tqdm(bytes=True, disable=no_progress_bar, **pbar_args) as pbar: + with pbar.wrapattr(fobj, "read") as fobj: + with self.open(to_info, mode="wb") as fdest: + shutil.copyfileobj(fobj, fdest, length=self.CHUNK_SIZE) + def _upload( 
self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs ): diff --git a/dvc/tree/webhdfs.py b/dvc/tree/webhdfs.py index a65568f090..cc102c9428 100644 --- a/dvc/tree/webhdfs.py +++ b/dvc/tree/webhdfs.py @@ -91,7 +91,7 @@ def hdfs_client(self): return client @contextmanager - def open(self, path_info, mode="r", encoding=None): + def open(self, path_info, mode="r", encoding=None, **kwargs): assert mode in {"r", "rt", "rb"} with self.hdfs_client.read( diff --git a/dvc/utils/__init__.py b/dvc/utils/__init__.py index f888b76b17..49122a53ae 100644 --- a/dvc/utils/__init__.py +++ b/dvc/utils/__init__.py @@ -226,7 +226,7 @@ def fix_env(env=None): return env -def tmp_fname(fname): +def tmp_fname(fname=""): """ Temporary name for a partial download """ from shortuuid import uuid diff --git a/dvc/utils/fs.py b/dvc/utils/fs.py index 4e37601105..578048434b 100644 --- a/dvc/utils/fs.py +++ b/dvc/utils/fs.py @@ -201,10 +201,10 @@ def copyfile(src, dest, no_progress_bar=False, name=None): fdest_wrapped.write(buf) -def copy_fobj_to_file(fsrc, dest): +def copy_fobj_to_file(fsrc, dest, chunk_size=None): """Copy contents of open file object to destination path.""" with open(dest, "wb+") as fdest: - shutil.copyfileobj(fsrc, fdest) + shutil.copyfileobj(fsrc, fdest, length=chunk_size) def walk_files(directory): diff --git a/dvc/utils/http.py b/dvc/utils/http.py index 4472b80912..c372db561d 100644 --- a/dvc/utils/http.py +++ b/dvc/utils/http.py @@ -5,7 +5,7 @@ @contextmanager -def open_url(url, mode="r", encoding=None): +def open_url(url, mode="r", encoding=None, **iter_opts): """Opens an url as a readable stream. Resumes on connection error. @@ -13,7 +13,7 @@ def open_url(url, mode="r", encoding=None): """ assert mode in {"r", "rt", "rb"} - with iter_url(url) as (response, it): + with iter_url(url, **iter_opts) as (response, it): bytes_stream = IterStream(it) if mode == "rb": diff --git a/dvc/utils/stream.py b/dvc/utils/stream.py index c2d3fe33ea..c93f2756da 100644 --- a/dvc/utils/stream.py +++ b/dvc/utils/stream.py @@ -1,5 +1,10 @@ +import hashlib import io +from dvc.hash_info import HashInfo +from dvc.istextfile import istextblock +from dvc.utils import dos2unix + class IterStream(io.RawIOBase): """Wraps an iterator yielding bytes as a file object""" @@ -54,3 +59,30 @@ def peek(self, n): except StopIteration: break return self.leftover[:n] + + +class HashedStreamReader: + + PARAM_CHECKSUM = "md5" + + def __init__(self, fobj): + self.md5 = hashlib.md5() + self.is_text_file = None + self.reader = fobj.read1 if hasattr(fobj, "read1") else fobj.read + + def read(self, n=-1): + chunk = self.reader(n) + if self.is_text_file is None: + self.is_text_file = istextblock(chunk) + + if self.is_text_file: + data = dos2unix(chunk) + else: + data = chunk + self.md5.update(data) + + return chunk + + @property + def hash_info(self): + return HashInfo(self.PARAM_CHECKSUM, self.md5.hexdigest(), nfiles=1) diff --git a/tests/dir_helpers.py b/tests/dir_helpers.py index 34e6dbf467..aa5f14637f 100644 --- a/tests/dir_helpers.py +++ b/tests/dir_helpers.py @@ -238,6 +238,9 @@ def read_text(self, *args, **kwargs): # pylint: disable=signature-differs } return super().read_text(*args, **kwargs) + def hash_to_path_info(self, hash_): + return self / hash_[0:2] / hash_[2:] + def _coerce_filenames(filenames): if isinstance(filenames, (str, bytes, pathlib.PurePath)): diff --git a/tests/func/test_add.py b/tests/func/test_add.py index 554efd8b22..77dcd9e3be 100644 --- a/tests/func/test_add.py +++ b/tests/func/test_add.py @@ -16,6 
+16,7 @@ from dvc.dvcfile import DVC_FILE_SUFFIX from dvc.exceptions import ( DvcException, + InvalidArgumentError, OutputDuplicationError, OverlappingOutputPathsError, RecursiveAddingWhileUsingFilename, @@ -991,3 +992,32 @@ def test_add_long_fname(tmp_dir, dvc): dvc.add("data") assert (tmp_dir / "data").read_text() == {name: "foo"} + + +def test_add_to_remote(tmp_dir, dvc, local_cloud, local_remote): + local_cloud.gen("foo", "foo") + + url = "remote://upstream/foo" + [stage] = dvc.add(url, to_remote=True) + + assert not (tmp_dir / "foo").exists() + assert (tmp_dir / "foo.dvc").exists() + + assert len(stage.deps) == 0 + assert len(stage.outs) == 1 + + hash_info = stage.outs[0].hash_info + assert local_remote.hash_to_path_info(hash_info.value).read_text() == "foo" + + +@pytest.mark.parametrize( + "invalid_opt, kwargs", + [ + ("multiple targets", {"targets": ["foo", "bar", "baz"]}), + ("--no-commit", {"targets": ["foo"], "no_commit": True}), + ("--recursive", {"targets": ["foo"], "recursive": True},), + ], +) +def test_add_to_remote_invalid_combinations(dvc, invalid_opt, kwargs): + with pytest.raises(InvalidArgumentError, match=invalid_opt): + dvc.add(to_remote=True, **kwargs) diff --git a/tests/func/test_import_url.py b/tests/func/test_import_url.py index a314de0354..925675c6de 100644 --- a/tests/func/test_import_url.py +++ b/tests/func/test_import_url.py @@ -1,3 +1,4 @@ +import json import os import textwrap from uuid import uuid4 @@ -6,6 +7,7 @@ from dvc.cache import Cache from dvc.dependency.base import DependencyDoesNotExistError +from dvc.exceptions import InvalidArgumentError from dvc.main import main from dvc.stage import Stage from dvc.utils.fs import makedirs @@ -245,3 +247,99 @@ def test_import_url_preserve_meta(tmp_dir, dvc): frozen: true """ ) + + +@pytest.mark.parametrize( + "workspace", + [ + pytest.lazy_fixture("local_cloud"), + pytest.lazy_fixture("s3"), + pytest.lazy_fixture("gs"), + pytest.lazy_fixture("hdfs"), + pytest.param( + pytest.lazy_fixture("ssh"), + marks=pytest.mark.skipif( + os.name == "nt", reason="disabled on windows" + ), + ), + pytest.lazy_fixture("http"), + ], + indirect=True, +) +def test_import_url_to_remote_single_file( + tmp_dir, dvc, workspace, local_remote +): + workspace.gen("foo", "foo") + + url = "remote://workspace/foo" + stage = dvc.imp_url(url, to_remote=True) + + assert not (tmp_dir / "foo").exists() + assert (tmp_dir / "foo.dvc").exists() + + assert len(stage.deps) == 1 + assert stage.deps[0].def_path == url + assert len(stage.outs) == 1 + + hash_info = stage.outs[0].hash_info + assert local_remote.hash_to_path_info(hash_info.value).read_text() == "foo" + + +@pytest.mark.parametrize( + "workspace", + [ + pytest.lazy_fixture("local_cloud"), + pytest.lazy_fixture("s3"), + pytest.lazy_fixture("gs"), + pytest.lazy_fixture("hdfs"), + pytest.param( + pytest.lazy_fixture("ssh"), + marks=pytest.mark.skipif( + os.name == "nt", reason="disabled on windows" + ), + ), + ], + indirect=True, +) +def test_import_url_to_remote_directory(tmp_dir, dvc, workspace, local_remote): + workspace.gen( + { + "data": { + "foo": "foo", + "bar": "bar", + "sub_dir": {"baz": "sub_dir/baz"}, + } + } + ) + + url = "remote://workspace/data" + stage = dvc.imp_url(url, to_remote=True) + + assert not (tmp_dir / "data").exists() + assert (tmp_dir / "data.dvc").exists() + + assert len(stage.deps) == 1 + assert stage.deps[0].def_path == url + assert len(stage.outs) == 1 + + hash_info = stage.outs[0].hash_info + with open(local_remote.hash_to_path_info(hash_info.value)) as stream: + 
file_parts = json.load(stream) + + assert len(file_parts) == 3 + assert {file_part["relpath"] for file_part in file_parts} == { + "foo", + "bar", + "sub_dir/baz", + } + + for file_part in file_parts: + assert ( + local_remote.hash_to_path_info(file_part["md5"]).read_text() + == file_part["relpath"] + ) + + +def test_import_url_to_remote_invalid_combinations(dvc): + with pytest.raises(InvalidArgumentError, match="--no-exec"): + dvc.imp_url("s3://bucket/foo", no_exec=True, to_remote=True) diff --git a/tests/func/test_s3.py b/tests/func/test_s3.py index 9e52aecd33..8f6f4c680c 100644 --- a/tests/func/test_s3.py +++ b/tests/func/test_s3.py @@ -119,3 +119,14 @@ def test_s3_isdir(tmp_dir, dvc, s3): assert not tree.isdir(s3 / "data" / "foo") assert tree.isdir(s3 / "data") + + +def test_s3_upload_fobj(tmp_dir, dvc, s3): + s3.gen({"data": {"foo": "foo"}}) + tree = S3Tree(dvc, s3.config) + + to_info = s3 / "data" / "bar" + with tree.open(s3 / "data" / "foo", "rb") as stream: + tree.upload_fobj(stream, to_info, 1) + + assert to_info.read_text() == "foo" diff --git a/tests/unit/command/test_add.py b/tests/unit/command/test_add.py index 22b57c6cb9..1c4f5266af 100644 --- a/tests/unit/command/test_add.py +++ b/tests/unit/command/test_add.py @@ -1,3 +1,5 @@ +import logging + from dvc.cli import parse_args from dvc.command.add import CmdAdd @@ -31,5 +33,69 @@ def test_add(mocker, dvc): glob=True, fname="file", external=True, + out=None, + remote=None, + to_remote=False, desc="stage description", + jobs=None, + ) + + +def test_add_to_remote(mocker): + cli_args = parse_args( + [ + "add", + "s3://bucket/foo", + "--to-remote", + "--out", + "bar", + "--remote", + "remote", + ] + ) + assert cli_args.func == CmdAdd + + cmd = cli_args.func(cli_args) + m = mocker.patch.object(cmd.repo, "add", autospec=True) + + assert cmd.run() == 0 + + m.assert_called_once_with( + ["s3://bucket/foo"], + recursive=False, + no_commit=False, + glob=False, + fname=None, + external=False, + out="bar", + remote="remote", + to_remote=True, + desc=None, + jobs=None, ) + + +def test_add_to_remote_invalid_combinations(mocker, caplog): + cli_args = parse_args( + ["add", "s3://bucket/foo", "s3://bucket/bar", "--to-remote"] + ) + assert cli_args.func == CmdAdd + + cmd = cli_args.func(cli_args) + with caplog.at_level(logging.ERROR, logger="dvc"): + assert cmd.run() == 1 + expected_msg = "multiple targets can't be used with --to-remote" + assert expected_msg in caplog.text + + for option, value in ( + ("--remote", "remote"), + ("--out", "bar"), + ("--jobs", "4"), + ): + cli_args = parse_args(["add", "foo", option, value]) + + cmd = cli_args.func(cli_args) + with caplog.at_level(logging.ERROR, logger="dvc"): + assert cmd.run() == 1 + expected_msg = f"{option} can't be used without --to-remote" + assert expected_msg in caplog.text diff --git a/tests/unit/command/test_imp_url.py b/tests/unit/command/test_imp_url.py index 3eb6dcc0d5..ef82af8c78 100644 --- a/tests/unit/command/test_imp_url.py +++ b/tests/unit/command/test_imp_url.py @@ -7,7 +7,17 @@ def test_import_url(mocker): cli_args = parse_args( - ["import-url", "src", "out", "--file", "file", "--desc", "description"] + [ + "import-url", + "src", + "out", + "--file", + "file", + "--jobs", + "4", + "--desc", + "description", + ] ) assert cli_args.func == CmdImportUrl @@ -17,7 +27,14 @@ def test_import_url(mocker): assert cmd.run() == 0 m.assert_called_once_with( - "src", out="out", fname="file", no_exec=False, desc="description" + "src", + out="out", + fname="file", + no_exec=False, + remote=None, 
+ to_remote=False, + desc="description", + jobs=4, ) @@ -57,5 +74,75 @@ def test_import_url_no_exec(mocker): assert cmd.run() == 0 m.assert_called_once_with( - "src", out="out", fname="file", no_exec=True, desc="description" + "src", + out="out", + fname="file", + no_exec=True, + remote=None, + to_remote=False, + desc="description", + jobs=None, + ) + + +def test_import_url_to_remote(mocker): + cli_args = parse_args( + [ + "import-url", + "s3://bucket/foo", + "bar", + "--to-remote", + "--remote", + "remote", + "--desc", + "description", + ] ) + assert cli_args.func == CmdImportUrl + + cmd = cli_args.func(cli_args) + m = mocker.patch.object(cmd.repo, "imp_url", autospec=True) + + assert cmd.run() == 0 + + m.assert_called_once_with( + "s3://bucket/foo", + out="bar", + fname=None, + no_exec=False, + remote="remote", + to_remote=True, + desc="description", + jobs=None, + ) + + +def test_import_url_to_remote_invalid_combination(mocker, caplog): + cli_args = parse_args( + [ + "import-url", + "s3://bucket/foo", + "bar", + "--to-remote", + "--remote", + "remote", + "--no-exec", + ] + ) + assert cli_args.func == CmdImportUrl + + cmd = cli_args.func(cli_args) + with caplog.at_level(logging.ERROR, logger="dvc"): + assert cmd.run() == 1 + expected_msg = "--no-exec can't be combined with --to-remote" + assert expected_msg in caplog.text + + cli_args = parse_args( + ["import-url", "s3://bucket/foo", "bar", "--remote", "remote"] + ) + + cmd = cli_args.func(cli_args) + with caplog.at_level(logging.ERROR, logger="dvc"): + assert cmd.run() == 1 + expected_msg = "--remote can't be used without --to-remote" + assert expected_msg in caplog.text diff --git a/tests/unit/utils/test_stream.py b/tests/unit/utils/test_stream.py new file mode 100644 index 0000000000..bde29cdb2f --- /dev/null +++ b/tests/unit/utils/test_stream.py @@ -0,0 +1,31 @@ +from dvc.utils import file_md5 +from dvc.utils.stream import HashedStreamReader + + +def test_hashed_stream_reader(tmp_dir): + tmp_dir.gen({"foo": "foo"}) + + foo = tmp_dir / "foo" + with open(foo, "rb") as fobj: + stream_reader = HashedStreamReader(fobj) + assert stream_reader.read(3) == b"foo" + + hex_digest, _ = file_md5(foo) + assert stream_reader.is_text_file + assert hex_digest == stream_reader.hash_info.value + + +def test_hashed_stream_reader_as_chunks(tmp_dir): + tmp_dir.gen({"foo": b"foo \x00" * 16}) + + foo = tmp_dir / "foo" + with open(foo, "rb") as fobj: + stream_reader = HashedStreamReader(fobj) + while True: + chunk = stream_reader.read(16) + if not chunk: + break + + hex_digest, _ = file_md5(foo) + assert not stream_reader.is_text_file + assert hex_digest == stream_reader.hash_info.value
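
For reference, a minimal usage sketch of the to_remote code paths introduced above, written against the repo API whose signatures are added in this diff (Repo.add and Repo.imp_url). This is a sketch under assumptions, not part of the change itself: the S3 URLs are placeholders, and a remote named "upstream" is assumed to be configured in the repo.

    from dvc.repo import Repo

    repo = Repo(".")

    # Track an external file without pulling it into the workspace: the data is
    # streamed into the remote cache (chunked via upload_fobj when the tree
    # supports it, falling back to a whole-file download/upload otherwise).
    # "s3://bucket/data.csv" and the remote name "upstream" are placeholders.
    repo.add(
        "s3://bucket/data.csv",
        to_remote=True,
        out="data.csv",
        remote="upstream",
        jobs=4,
    )

    # Same idea through import-url: the generated .dvc file keeps the source URL
    # as a dependency, so the import can be refreshed later.
    repo.imp_url(
        "s3://bucket/other.csv",
        out="other.csv",
        to_remote=True,
        remote="upstream",
        jobs=4,
    )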