diff --git a/dvc/repo/__init__.py b/dvc/repo/__init__.py index 9040e0dffc..d2bb0eac4e 100644 --- a/dvc/repo/__init__.py +++ b/dvc/repo/__init__.py @@ -72,6 +72,7 @@ def __init__(self, root_dir=None): from dvc.scm.tree import WorkingTree from dvc.repo.tag import Tag from dvc.utils.fs import makedirs + from dvc.stage.cache import StageCache root_dir = self.find_root(root_dir) @@ -104,6 +105,8 @@ def __init__(self, root_dir=None): self.cache = Cache(self) self.cloud = DataCloud(self) + self.stage_cache = StageCache(self.cache.local.cache_dir) + self.metrics = Metrics(self) self.params = Params(self) self.tag = Tag(self) diff --git a/dvc/repo/reproduce.py b/dvc/repo/reproduce.py index e77934e89e..ac02135154 100644 --- a/dvc/repo/reproduce.py +++ b/dvc/repo/reproduce.py @@ -97,12 +97,7 @@ def reproduce( def _reproduce_stages( - G, - stages, - downstream=False, - ignore_build_cache=False, - single_item=False, - **kwargs + G, stages, downstream=False, single_item=False, **kwargs ): r"""Derive the evaluation of the given node for the given graph. @@ -170,7 +165,7 @@ def _reproduce_stages( try: ret = _reproduce_stage(stage, **kwargs) - if len(ret) != 0 and ignore_build_cache: + if len(ret) != 0 and kwargs.get("ignore_build_cache", False): # NOTE: we are walking our pipeline from the top to the # bottom. If one stage is changed, it will be reproduced, # which tells us that we should force reproducing all of diff --git a/dvc/repo/run.py b/dvc/repo/run.py index 63e76b5c0f..14042b00e4 100644 --- a/dvc/repo/run.py +++ b/dvc/repo/run.py @@ -68,6 +68,9 @@ def run(self, fname=None, no_exec=False, **kwargs): raise OutputDuplicationError(exc.output, set(exc.stages) - {stage}) if not no_exec: - stage.run(no_commit=kwargs.get("no_commit", False)) + stage.run( + no_commit=kwargs.get("no_commit", False), + ignore_build_cache=kwargs.get("ignore_build_cache", False), + ) dvcfile.dump(stage, update_pipeline=True) return stage diff --git a/dvc/stage/__init__.py b/dvc/stage/__init__.py index 91e8f46cf6..21823b1597 100644 --- a/dvc/stage/__init__.py +++ b/dvc/stage/__init__.py @@ -494,6 +494,8 @@ def save(self): self.md5 = self._compute_md5() + self.repo.stage_cache.save(self) + @staticmethod def _changed_entries(entries): return [ @@ -620,7 +622,9 @@ def _run(self): raise StageCmdFailedError(self) @rwlocked(read=["deps"], write=["outs"]) - def run(self, dry=False, no_commit=False, force=False): + def run( + self, dry=False, no_commit=False, force=False, ignore_build_cache=False + ): if (self.cmd or self.is_import) and not self.locked and not dry: self.remove_outs(ignore_remove=False, force=False) @@ -653,16 +657,20 @@ def run(self, dry=False, no_commit=False, force=False): self.check_missing_outputs() else: - logger.info("Running command:\n\t{}".format(self.cmd)) if not dry: + if not force and not ignore_build_cache: + self.repo.stage_cache.restore(self) + if ( not force and not self.is_callback and not self.always_changed and self._already_cached() ): + logger.info("Stage is cached, skipping.") self.checkout() else: + logger.info("Running command:\n\t{}".format(self.cmd)) self._run() if not dry: diff --git a/dvc/stage/cache.py b/dvc/stage/cache.py new file mode 100644 index 0000000000..fcb0006898 --- /dev/null +++ b/dvc/stage/cache.py @@ -0,0 +1,134 @@ +import os +import yaml +import logging +import hashlib + +from voluptuous import Schema, Required, Invalid + +from dvc.utils.fs import makedirs +from dvc.utils import relpath + +logger = logging.getLogger(__name__) + +SCHEMA = Schema( + { + Required("cmd"): str, + Required("deps"): {str: str}, + Required("outs"): {str: str}, + } +) + + +def _sha256(string): + return hashlib.sha256(string.encode()).hexdigest() + + +def _get_cache_hash(cache, key=False): + string = _sha256(cache["cmd"]) + + for path, checksum in cache["deps"].items(): + string += _sha256(path) + string += _sha256(checksum) + + for path, checksum in cache["outs"].items(): + string += _sha256(path) + if not key: + string += _sha256(checksum) + + return _sha256(string) + + +def _get_stage_hash(stage): + if not stage.cmd or not stage.deps or not stage.outs: + return None + + for dep in stage.deps: + if dep.scheme != "local" or not dep.def_path or not dep.get_checksum(): + return None + + for out in stage.outs: + if out.scheme != "local" or not out.def_path or out.persist: + return None + + return _get_cache_hash(_create_cache(stage), key=True) + + +def _create_cache(stage): + return { + "cmd": stage.cmd, + "deps": {dep.def_path: dep.get_checksum() for dep in stage.deps}, + "outs": {out.def_path: out.get_checksum() for out in stage.outs}, + } + + +class StageCache: + def __init__(self, cache_dir): + self.cache_dir = os.path.join(cache_dir, "stages") + + def _get_cache_dir(self, key): + return os.path.join(self.cache_dir, key[:2], key) + + def _get_cache_path(self, key, value): + return os.path.join(self._get_cache_dir(key), value) + + def _load_cache(self, key, value): + path = self._get_cache_path(key, value) + + try: + with open(path, "r") as fobj: + return SCHEMA(yaml.load(fobj)) + except FileNotFoundError: + return None + except (yaml.error.YAMLError, Invalid): + logger.warning("corrupted cache file '%s'.", relpath(path)) + os.unlink(path) + return None + + def _load(self, stage): + key = _get_stage_hash(stage) + if not key: + return None + + cache_dir = self._get_cache_dir(key) + if not os.path.exists(cache_dir): + return None + + for value in os.listdir(cache_dir): + cache = self._load_cache(key, value) + if cache: + return cache + + return None + + def save(self, stage): + cache_key = _get_stage_hash(stage) + if not cache_key: + return + + cache = _create_cache(stage) + cache_value = _get_cache_hash(cache) + + if self._load_cache(cache_key, cache_value): + return + + # sanity check + SCHEMA(cache) + + path = self._get_cache_path(cache_key, cache_value) + dpath = os.path.dirname(path) + makedirs(dpath, exist_ok=True) + with open(path, "w+") as fobj: + yaml.dump(cache, fobj) + + def restore(self, stage): + cache = self._load(stage) + if not cache: + return + + deps = {dep.def_path: dep for dep in stage.deps} + for def_path, checksum in cache["deps"].items(): + deps[def_path].checksum = checksum + + outs = {out.def_path: out for out in stage.outs} + for def_path, checksum in cache["outs"].items(): + outs[def_path].checksum = checksum diff --git a/tests/func/__pycache__/tmpvpl_3i8b b/tests/func/__pycache__/tmpvpl_3i8b new file mode 100644 index 0000000000..7bc8bc6420 Binary files /dev/null and b/tests/func/__pycache__/tmpvpl_3i8b differ diff --git a/tests/func/test_repro.py b/tests/func/test_repro.py index 186c0a8290..2db5467a37 100644 --- a/tests/func/test_repro.py +++ b/tests/func/test_repro.py @@ -1295,7 +1295,9 @@ def test(self): ["repro", self._get_stage_target(self.stage), "--no-commit"] ) self.assertEqual(ret, 0) - self.assertFalse(os.path.exists(self.dvc.cache.local.cache_dir)) + self.assertEqual( + os.listdir(self.dvc.cache.local.cache_dir), ["stages"] + ) class TestReproAlreadyCached(TestRepro): diff --git a/tests/unit/test_stage.py b/tests/unit/test_stage.py index 02924029e6..3d90d116c2 100644 --- a/tests/unit/test_stage.py +++ b/tests/unit/test_stage.py @@ -1,3 +1,4 @@ +import os import signal import subprocess import threading @@ -103,3 +104,39 @@ def test_always_changed(dvc): with dvc.lock: assert stage.changed() assert stage.status()["path"] == ["always changed"] + + +def test_stage_cache(tmp_dir, dvc, mocker): + tmp_dir.gen("dep", "dep") + stage = dvc.run(deps=["dep"], outs=["out"], cmd="echo content > out",) + + with dvc.lock, dvc.state: + stage.remove(remove_outs=True, force=True) + + assert not (tmp_dir / "out").exists() + assert not (tmp_dir / "out.dvc").exists() + + cache_dir = os.path.join( + dvc.stage_cache.cache_dir, + "dc", + "dc512ad947c7fd4df1037dc9c46efd83d8d5f88297a1c71baad81081cd216c34", + ) + cache_file = os.path.join( + cache_dir, + "fcbbdb34bfa75a1f821931b4714baf50e88d272a35c866597200bb2aac79621b", + ) + + assert os.path.isdir(cache_dir) + assert os.listdir(cache_dir) == [os.path.basename(cache_file)] + assert os.path.isfile(cache_file) + + run_spy = mocker.spy(stage, "_run") + checkout_spy = mocker.spy(stage, "checkout") + with dvc.lock, dvc.state: + stage.run() + + assert not run_spy.called + assert checkout_spy.call_count == 1 + + assert (tmp_dir / "out").exists() + assert (tmp_dir / "out").read_text() == "content\n"