From 85b26c2063ae0f523514a61219848511f15426a3 Mon Sep 17 00:00:00 2001 From: pared Date: Thu, 21 Mar 2019 09:48:14 +0100 Subject: [PATCH] run: add --outs-persist option --- dvc/command/run.py | 16 +++++++++++ dvc/output/__init__.py | 39 ++++++++++++++++++++++---- dvc/output/base.py | 16 ++++++++++- dvc/output/hdfs.py | 17 ++++++++++-- dvc/output/local.py | 17 ++++++++++-- dvc/output/s3.py | 17 ++++++++++-- dvc/output/ssh.py | 17 ++++++++++-- dvc/remote/base.py | 4 +++ dvc/remote/local.py | 41 +++++++++++++++++++++++++++ dvc/repo/__init__.py | 8 ++---- dvc/repo/run.py | 8 ++++++ dvc/repo/unprotect.py | 49 -------------------------------- dvc/stage.py | 63 ++++++++++++++++++++++++++++++++---------- tests/test_output.py | 2 +- tests/test_run.py | 54 +++++++++++++++++++++++++++++++++++- 15 files changed, 282 insertions(+), 86 deletions(-) delete mode 100644 dvc/repo/unprotect.py diff --git a/dvc/command/run.py b/dvc/command/run.py index cc9bc636e1..2ed9e68620 100644 --- a/dvc/command/run.py +++ b/dvc/command/run.py @@ -43,6 +43,8 @@ def run(self): ignore_build_cache=self.args.ignore_build_cache, remove_outs=self.args.remove_outs, no_commit=self.args.no_commit, + outs_persist=self.args.outs_persist, + outs_persist_no_cache=self.args.outs_persist_no_cache, ) except DvcException: logger.error("failed to run command") @@ -175,6 +177,20 @@ def add_parser(subparsers, parent_parser): default=False, help="Don't put files/directories into cache.", ) + run_parser.add_argument( + "--outs-persist", + action="append", + default=[], + help="Declare output file or directory that will not be " + "removed upon repro.", + ) + run_parser.add_argument( + "--outs-persist-no-cache", + action="append", + default=[], + help="Declare output file or directory that will not be " + "removed upon repro (do not put into DVC cache).", + ) run_parser.add_argument( "command", nargs=argparse.REMAINDER, help="Command to execute." ) diff --git a/dvc/output/__init__.py b/dvc/output/__init__.py index ede71b5622..4f7058782b 100644 --- a/dvc/output/__init__.py +++ b/dvc/output/__init__.py @@ -48,23 +48,38 @@ schema.Optional(RemoteHDFS.PARAM_CHECKSUM): schema.Or(str, None), schema.Optional(OutputBase.PARAM_CACHE): bool, schema.Optional(OutputBase.PARAM_METRIC): OutputBase.METRIC_SCHEMA, + schema.Optional(OutputBase.PARAM_PERSIST): bool, } -def _get(stage, p, info, cache, metric): +def _get(stage, p, info, cache, metric, persist): parsed = urlparse(p) if parsed.scheme == "remote": name = Config.SECTION_REMOTE_FMT.format(parsed.netloc) sect = stage.repo.config.config[name] remote = Remote(stage.repo, sect) return OUTS_MAP[remote.scheme]( - stage, p, info, cache=cache, remote=remote, metric=metric + stage, + p, + info, + cache=cache, + remote=remote, + metric=metric, + persist=persist, ) for o in OUTS: if o.supported(p): return o(stage, p, info, cache=cache, remote=None, metric=metric) - return OutputLOCAL(stage, p, info, cache=cache, remote=None, metric=metric) + return OutputLOCAL( + stage, + p, + info, + cache=cache, + remote=None, + metric=metric, + persist=persist, + ) def loadd_from(stage, d_list): @@ -73,12 +88,24 @@ def loadd_from(stage, d_list): p = d.pop(OutputBase.PARAM_PATH) cache = d.pop(OutputBase.PARAM_CACHE, True) metric = d.pop(OutputBase.PARAM_METRIC, False) - ret.append(_get(stage, p, info=d, cache=cache, metric=metric)) + persist = d.pop(OutputBase.PARAM_PERSIST, False) + ret.append( + _get(stage, p, info=d, cache=cache, metric=metric, persist=persist) + ) return ret -def loads_from(stage, s_list, use_cache=True, metric=False): +def loads_from(stage, s_list, use_cache=True, metric=False, persist=False): ret = [] for s in s_list: - ret.append(_get(stage, s, info={}, cache=use_cache, metric=metric)) + ret.append( + _get( + stage, + s, + info={}, + cache=use_cache, + metric=metric, + persist=persist, + ) + ) return ret diff --git a/dvc/output/base.py b/dvc/output/base.py index 53dfe6f7e2..1aa7e14884 100644 --- a/dvc/output/base.py +++ b/dvc/output/base.py @@ -36,6 +36,7 @@ class OutputBase(object): PARAM_METRIC = "metric" PARAM_METRIC_TYPE = "type" PARAM_METRIC_XPATH = "xpath" + PARAM_PERSIST = "persist" METRIC_SCHEMA = Or( None, @@ -50,7 +51,14 @@ class OutputBase(object): IsNotFileOrDirError = OutputIsNotFileOrDirError def __init__( - self, stage, path, info=None, remote=None, cache=True, metric=False + self, + stage, + path, + info=None, + remote=None, + cache=True, + metric=False, + persist=False, ): self.stage = stage self.repo = stage.repo @@ -59,6 +67,7 @@ def __init__( self.remote = remote or self.REMOTE(self.repo, {}) self.use_cache = False if self.IS_DEPENDENCY else cache self.metric = False if self.IS_DEPENDENCY else metric + self.persist = persist if ( self.use_cache @@ -186,6 +195,7 @@ def dumpd(self): del self.metric[self.PARAM_METRIC_XPATH] ret[self.PARAM_METRIC] = self.metric + ret[self.PARAM_PERSIST] = self.persist return ret @@ -231,3 +241,7 @@ def get_files_number(self): return 1 return 0 + + def unprotect(self): + if self.exists: + self.remote.unprotect(self.path_info) diff --git a/dvc/output/hdfs.py b/dvc/output/hdfs.py index 9bf8f6a4a8..4e6b7b8b3e 100644 --- a/dvc/output/hdfs.py +++ b/dvc/output/hdfs.py @@ -11,10 +11,23 @@ class OutputHDFS(OutputBase): REMOTE = RemoteHDFS def __init__( - self, stage, path, info=None, remote=None, cache=True, metric=False + self, + stage, + path, + info=None, + remote=None, + cache=True, + metric=False, + persist=False, ): super(OutputHDFS, self).__init__( - stage, path, info=info, remote=remote, cache=cache, metric=metric + stage, + path, + info=info, + remote=remote, + cache=cache, + metric=metric, + persist=persist, ) if remote: path = posixpath.join(remote.url, urlparse(path).path.lstrip("/")) diff --git a/dvc/output/local.py b/dvc/output/local.py index 0bfe28d8e2..70728eb049 100644 --- a/dvc/output/local.py +++ b/dvc/output/local.py @@ -14,10 +14,23 @@ class OutputLOCAL(OutputBase): REMOTE = RemoteLOCAL def __init__( - self, stage, path, info=None, remote=None, cache=True, metric=False + self, + stage, + path, + info=None, + remote=None, + cache=True, + metric=False, + persist=False, ): super(OutputLOCAL, self).__init__( - stage, path, info, remote=remote, cache=cache, metric=metric + stage, + path, + info, + remote=remote, + cache=cache, + metric=metric, + persist=persist, ) if remote: p = os.path.join( diff --git a/dvc/output/s3.py b/dvc/output/s3.py index 48546b1a4c..6eebef7b0d 100644 --- a/dvc/output/s3.py +++ b/dvc/output/s3.py @@ -11,10 +11,23 @@ class OutputS3(OutputBase): REMOTE = RemoteS3 def __init__( - self, stage, path, info=None, remote=None, cache=True, metric=False + self, + stage, + path, + info=None, + remote=None, + cache=True, + metric=False, + persist=False, ): super(OutputS3, self).__init__( - stage, path, info=info, remote=remote, cache=cache, metric=metric + stage, + path, + info=info, + remote=remote, + cache=cache, + metric=metric, + persist=persist, ) bucket = remote.bucket if remote else urlparse(path).netloc path = urlparse(path).path.lstrip("/") diff --git a/dvc/output/ssh.py b/dvc/output/ssh.py index fba970a59b..c0ed25dfa7 100644 --- a/dvc/output/ssh.py +++ b/dvc/output/ssh.py @@ -12,10 +12,23 @@ class OutputSSH(OutputBase): REMOTE = RemoteSSH def __init__( - self, stage, path, info=None, remote=None, cache=True, metric=False + self, + stage, + path, + info=None, + remote=None, + cache=True, + metric=False, + persist=False, ): super(OutputSSH, self).__init__( - stage, path, info=info, remote=remote, cache=cache, metric=metric + stage, + path, + info=info, + remote=remote, + cache=cache, + metric=metric, + persist=persist, ) parsed = urlparse(path) host = remote.host if remote else parsed.hostname diff --git a/dvc/remote/base.py b/dvc/remote/base.py index 42f069e49b..7144219685 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -362,3 +362,7 @@ def checkout(self, output, force=False, progress_callback=None): self.do_checkout( output, force=force, progress_callback=progress_callback ) + + @staticmethod + def unprotect(path_info): + pass diff --git a/dvc/remote/local.py b/dvc/remote/local.py index 99d6d79fb3..adb847a045 100644 --- a/dvc/remote/local.py +++ b/dvc/remote/local.py @@ -743,3 +743,44 @@ def _log_missing_caches(self, checksum_info_dict): "nor on remote. Missing cache files: {}".format(missing_desc) ) logger.warning(msg) + + @staticmethod + def _unprotect_file(path): + if System.is_symlink(path) or System.is_hardlink(path): + logger.debug("Unprotecting '{}'".format(path)) + tmp = os.path.join(os.path.dirname(path), "." + str(uuid.uuid4())) + + # The operations order is important here - if some application + # would access the file during the process of copyfile then it + # would get only the part of file. So, at first, the file should be + # copied with the temporary name, and then original file should be + # replaced by new. + copyfile(path, tmp) + remove(path) + os.rename(tmp, path) + + else: + logger.debug( + "Skipping copying for '{}', since it is not " + "a symlink or a hardlink.".format(path) + ) + + os.chmod(path, os.stat(path).st_mode | stat.S_IWRITE) + + @staticmethod + def _unprotect_dir(path): + for path in walk_files(path): + RemoteLOCAL._unprotect_file(path) + + @staticmethod + def unprotect(path_info): + path = path_info["path"] + if not os.path.exists(path): + raise DvcException( + "can't unprotect non-existing data '{}'".format(path) + ) + + if os.path.isdir(path): + RemoteLOCAL._unprotect_dir(path) + else: + RemoteLOCAL._unprotect_file(path) diff --git a/dvc/repo/__init__.py b/dvc/repo/__init__.py index 24ad3ea00a..f758609fef 100644 --- a/dvc/repo/__init__.py +++ b/dvc/repo/__init__.py @@ -101,11 +101,9 @@ def init(root_dir=os.curdir, no_scm=False, force=False): init(root_dir=root_dir, no_scm=no_scm, force=force) return Repo(root_dir) - @staticmethod - def unprotect(target): - from dvc.repo.unprotect import unprotect - - return unprotect(target) + def unprotect(self, target): + path_info = {"schema": "local", "path": target} + return self.cache.local.unprotect(path_info) def _ignore(self): flist = [ diff --git a/dvc/repo/run.py b/dvc/repo/run.py index 24cc8283d8..7f1759fdcc 100644 --- a/dvc/repo/run.py +++ b/dvc/repo/run.py @@ -20,6 +20,8 @@ def run( ignore_build_cache=False, remove_outs=False, no_commit=False, + outs_persist=None, + outs_persist_no_cache=None, ): from dvc.stage import Stage @@ -33,6 +35,10 @@ def run( metrics = [] if metrics_no_cache is None: metrics_no_cache = [] + if outs_persist is None: + outs_persist = [] + if outs_persist_no_cache is None: + outs_persist_no_cache = [] with self.state: stage = Stage.create( @@ -49,6 +55,8 @@ def run( overwrite=overwrite, ignore_build_cache=ignore_build_cache, remove_outs=remove_outs, + outs_persist=outs_persist, + outs_persist_no_cache=outs_persist_no_cache, ) if stage is None: diff --git a/dvc/repo/unprotect.py b/dvc/repo/unprotect.py deleted file mode 100644 index 5220e471f8..0000000000 --- a/dvc/repo/unprotect.py +++ /dev/null @@ -1,49 +0,0 @@ -from __future__ import unicode_literals - -import os -import stat -import uuid - -import dvc.logger as logger -from dvc.system import System -from dvc.utils import copyfile, remove, walk_files -from dvc.exceptions import DvcException - - -def _unprotect_file(path): - if System.is_symlink(path) or System.is_hardlink(path): - logger.debug("Unprotecting '{}'".format(path)) - tmp = os.path.join(os.path.dirname(path), "." + str(uuid.uuid4())) - - # The operations order is important here - if some application would - # access the file during the process of copyfile then it would get - # only the part of file. So, at first, the file should be copied with - # the temporary name, and then original file should be replaced by new. - copyfile(path, tmp) - remove(path) - os.rename(tmp, path) - - else: - logger.debug( - "Skipping copying for '{}', since it is not " - "a symlink or a hardlink.".format(path) - ) - - os.chmod(path, os.stat(path).st_mode | stat.S_IWRITE) - - -def _unprotect_dir(path): - for path in walk_files(path): - _unprotect_file(path) - - -def unprotect(path): - if not os.path.exists(path): - raise DvcException( - "can't unprotect non-existing data '{}'".format(path) - ) - - if os.path.isdir(path): - _unprotect_dir(path) - else: - _unprotect_file(path) diff --git a/dvc/stage.py b/dvc/stage.py index 9c1be9339e..ca338bd683 100644 --- a/dvc/stage.py +++ b/dvc/stage.py @@ -262,18 +262,19 @@ def changed(self): def remove_outs(self, ignore_remove=False): """Used mainly for `dvc remove --outs` and :func:`Stage.reproduce`.""" for out in self.outs: - logger.debug( - "Removing output '{out}' of '{stage}'.".format( - out=out, stage=self.relpath + if out.persist: + out.unprotect() + else: + logger.debug( + "Removing output '{out}' of '{stage}'.".format( + out=out, stage=self.relpath + ) ) - ) - out.remove(ignore_remove=ignore_remove) + out.remove(ignore_remove=ignore_remove) def unprotect_outs(self): for out in self.outs: - if out.scheme != "local" or not out.exists: - continue - self.repo.unprotect(out.path) + out.unprotect() def remove(self): self.remove_outs(ignore_remove=True) @@ -407,6 +408,8 @@ def create( ignore_build_cache=False, remove_outs=False, validate_state=True, + outs_persist=None, + outs_persist_no_cache=None, ): if outs is None: outs = [] @@ -418,6 +421,10 @@ def create( metrics = [] if metrics_no_cache is None: metrics_no_cache = [] + if outs_persist is None: + outs_persist = [] + if outs_persist_no_cache is None: + outs_persist_no_cache = [] # Backward compatibility for `cwd` option if wdir is None and cwd is not None: @@ -434,13 +441,14 @@ def create( stage = Stage(repo=repo, wdir=wdir, cmd=cmd, locked=locked) - stage.outs = output.loads_from(stage, outs, use_cache=True) - stage.outs += output.loads_from( - stage, metrics, use_cache=True, metric=True - ) - stage.outs += output.loads_from(stage, outs_no_cache, use_cache=False) - stage.outs += output.loads_from( - stage, metrics_no_cache, use_cache=False, metric=True + Stage._fill_stage_outputs( + stage, + outs, + outs_no_cache, + metrics, + metrics_no_cache, + outs_persist, + outs_persist_no_cache, ) stage.deps = dependency.loads_from(stage, deps) @@ -487,6 +495,31 @@ def create( return stage + @staticmethod + def _fill_stage_outputs( + stage, + outs, + outs_no_cache, + metrics, + metrics_no_cache, + outs_persist, + outs_persist_no_cache, + ): + stage.outs = output.loads_from(stage, outs, use_cache=True) + stage.outs += output.loads_from( + stage, metrics, use_cache=True, metric=True + ) + stage.outs += output.loads_from( + stage, outs_persist, use_cache=True, persist=True + ) + stage.outs += output.loads_from(stage, outs_no_cache, use_cache=False) + stage.outs += output.loads_from( + stage, metrics_no_cache, use_cache=False, metric=True + ) + stage.outs += output.loads_from( + stage, outs_persist_no_cache, use_cache=False, persist=True + ) + @staticmethod def _check_dvc_filename(fname): if not Stage.is_valid_filename(fname): diff --git a/tests/test_output.py b/tests/test_output.py index 9ecb9d0db2..cf96dd0318 100644 --- a/tests/test_output.py +++ b/tests/test_output.py @@ -23,7 +23,7 @@ class TestOutScheme(TestDvc): } def _get(self, path): - return _get(Stage(self.dvc), path, None, None, None) + return _get(Stage(self.dvc), path, None, None, None, None) def test(self): for path, cls in self.TESTS.items(): diff --git a/tests/test_run.py b/tests/test_run.py index 1cd56d28af..8c4e21a2a3 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -8,7 +8,8 @@ import yaml from dvc.main import main -from dvc.utils import file_md5 +from dvc.output import OutputBase +from dvc.utils import file_md5, load_stage_file from dvc.system import System from dvc.stage import Stage, StagePathNotFoundError, StagePathNotDirectoryError from dvc.stage import StageFileBadNameError, MissingDep @@ -768,3 +769,54 @@ def test(self): self.assertEqual(ret, 0) self.assertTrue(os.path.isfile(fname)) self.assertEqual(len(os.listdir(self.dvc.cache.local.cache_dir)), 1) + + +class TestRunPersist(TestDvc): + @property + def outs_command(self): + raise NotImplementedError + + def _test(self): + file = "file.txt" + file_content = "content" + stage_file = file + Stage.STAGE_FILE_SUFFIX + + ret = main( + [ + "run", + self.outs_command, + file, + "echo {} >> {}".format(file_content, file), + ] + ) + self.assertEqual(0, ret) + + stage_file_content = load_stage_file(stage_file) + self.assertEqual( + True, stage_file_content["outs"][0][OutputBase.PARAM_PERSIST] + ) + + ret = main(["repro", stage_file]) + self.assertEqual(0, ret) + + with open(file, "r") as fobj: + lines = fobj.readlines() + self.assertEqual(2, len(lines)) + + +class TestRunPersistOuts(TestRunPersist): + @property + def outs_command(self): + return "--outs-persist" + + def test(self): + self._test() + + +class TestRunPersistOutsNoCache(TestRunPersist): + @property + def outs_command(self): + return "--outs-persist-no-cache" + + def test(self): + self._test()