diff --git a/dvc/command/experiments.py b/dvc/command/experiments.py index 1530f16487..5c941205fa 100644 --- a/dvc/command/experiments.py +++ b/dvc/command/experiments.py @@ -14,6 +14,7 @@ from dvc.command.repro import CmdRepro from dvc.command.repro import add_arguments as add_repro_arguments from dvc.exceptions import DvcException, InvalidArgumentError +from dvc.repo.experiments import Experiments from dvc.utils.flatten import flatten logger = logging.getLogger(__name__) @@ -443,14 +444,6 @@ def run(self): elif not self.args.targets: self.args.targets = self.default_targets - if ( - self.args.checkpoint_reset - and self.args.checkpoint_continue is not None - ): - raise InvalidArgumentError( - "--continue and --reset cannot be used together" - ) - ret = 0 for target in self.args.targets: try: @@ -460,13 +453,7 @@ def run(self): run_all=self.args.run_all, jobs=self.args.jobs, params=self.args.params, - checkpoint=( - self.args.checkpoint - or self.args.checkpoint_continue is not None - or self.args.checkpoint_reset - ), - checkpoint_continue=self.args.checkpoint_continue, - checkpoint_reset=self.args.checkpoint_reset, + checkpoint_resume=self.args.checkpoint_resume, **self._repro_kwargs, ) except DvcException: @@ -738,65 +725,38 @@ def add_parser(subparsers, parent_parser): help=EXPERIMENTS_RUN_HELP, formatter_class=argparse.RawDescriptionHelpFormatter, ) - # inherit arguments from `dvc repro` - add_repro_arguments(experiments_run_parser) - experiments_run_parser.add_argument( - "--params", - action="append", - default=[], - help="Use the specified param values when reproducing pipelines.", - metavar="[:]", - ) - experiments_run_parser.add_argument( - "--queue", - action="store_true", - default=False, - help="Stage this experiment in the run queue for future execution.", - ) + _add_run_common(experiments_run_parser) experiments_run_parser.add_argument( - "--run-all", - action="store_true", - default=False, - help="Execute all experiments in the run queue.", - ) - experiments_run_parser.add_argument( - "-j", - "--jobs", - type=int, - help="Run the specified number of experiments at a time in parallel.", - metavar="", + "--checkpoint-resume", type=str, default=None, help=argparse.SUPPRESS, ) - experiments_run_parser.add_argument( - "--checkpoint", - action="store_true", - default=False, - help="Reproduce pipelines as a checkpoint experiment.", + experiments_run_parser.set_defaults(func=CmdExperimentsRun) + + EXPERIMENTS_RESUME_HELP = "Resume checkpoint experiments." + experiments_resume_parser = experiments_subparsers.add_parser( + "resume", + parents=[parent_parser], + aliases=["res"], + description=append_doc_link( + EXPERIMENTS_RESUME_HELP, "experiments/resume" + ), + help=EXPERIMENTS_RESUME_HELP, + formatter_class=argparse.RawDescriptionHelpFormatter, ) - experiments_run_parser.add_argument( - "--continue", + _add_run_common(experiments_resume_parser) + experiments_resume_parser.add_argument( + "-r", + "--rev", type=str, - nargs="?", - default=None, - const=":last", - dest="checkpoint_continue", + default=Experiments.LAST_CHECKPOINT, + dest="checkpoint_resume", help=( - "Continue from the specified checkpoint experiment " - "(implies --checkpoint). If no experiment revision is provided, " + "Continue the specified checkpoint experiment. " + "If no experiment revision is provided, " "the most recently run checkpoint experiment will be used." ), metavar="", ) - experiments_run_parser.add_argument( - "--reset", - action="store_true", - default=False, - dest="checkpoint_reset", - help=( - "Reset checkpoint experiment if it already exists " - "(implies --checkpoint)." - ), - ) - experiments_run_parser.set_defaults(func=CmdExperimentsRun) + experiments_resume_parser.set_defaults(func=CmdExperimentsRun) EXPERIMENTS_GC_HELP = "Garbage collect unneeded experiments." EXPERIMENTS_GC_DESCRIPTION = ( @@ -856,3 +816,35 @@ def add_parser(subparsers, parent_parser): help="Force garbage collection - automatically agree to all prompts.", ) experiments_gc_parser.set_defaults(func=CmdExperimentsGC) + + +def _add_run_common(parser): + """Add common args for 'exp run' and 'exp resume'.""" + # inherit arguments from `dvc repro` + add_repro_arguments(parser) + parser.add_argument( + "--params", + action="append", + default=[], + help="Use the specified param values when reproducing pipelines.", + metavar="[:]", + ) + parser.add_argument( + "--queue", + action="store_true", + default=False, + help="Stage this experiment in the run queue for future execution.", + ) + parser.add_argument( + "--run-all", + action="store_true", + default=False, + help="Execute all experiments in the run queue.", + ) + parser.add_argument( + "-j", + "--jobs", + type=int, + help="Run the specified number of experiments at a time in parallel.", + metavar="", + ) diff --git a/dvc/repo/experiments/__init__.py b/dvc/repo/experiments/__init__.py index bd3002aa3a..fa7e70c229 100644 --- a/dvc/repo/experiments/__init__.py +++ b/dvc/repo/experiments/__init__.py @@ -2,13 +2,10 @@ import os import re import tempfile +import threading from collections import namedtuple from collections.abc import Mapping -from concurrent.futures import ( - ProcessPoolExecutor, - ThreadPoolExecutor, - as_completed, -) +from concurrent.futures import ThreadPoolExecutor, as_completed from contextlib import contextmanager from functools import partial, wraps from typing import Iterable, Optional @@ -21,6 +18,7 @@ from dvc.repo.experiments.executor import ExperimentExecutor, LocalExecutor from dvc.scm.git import Git from dvc.stage import PipelineStage +from dvc.stage.run import CheckpointKilledError from dvc.stage.serialize import to_lockfile from dvc.tree.repo import RepoTree from dvc.utils import dict_sha256, env2bool, relpath @@ -71,9 +69,9 @@ def __init__(self, rev, continue_rev): msg = ( f"Checkpoint experiment containing '{rev[:7]}' already exists." " To restart the experiment run:\n\n" - "\tdvc exp run --reset ...\n\n" + "\tdvc exp run -f ...\n\n" "To resume the experiment, run:\n\n" - f"\tdvc exp run --continue {continue_rev[:7]}\n" + f"\tdvc exp resume {continue_rev[:7]}\n" ) super().__init__(msg) self.rev = rev @@ -106,6 +104,7 @@ class Experiments: r"^(?P[a-f0-9]{7})-(?P[a-f0-9]+)" r"(?P-checkpoint)?$" ) + LAST_CHECKPOINT = ":last" StashEntry = namedtuple("StashEntry", ["index", "baseline_rev", "branch"]) @@ -200,7 +199,7 @@ def _scm_checkout(self, rev): self.scm.repo.git.clean(force=True) if self.scm.repo.head.is_detached: self._checkout_default_branch() - if not Git.is_sha(rev) or not self.scm.has_rev(rev): + if not self.scm.has_rev(rev): self.scm.pull() logger.debug("Checking out experiment commit '%s'", rev) self.scm.checkout(rev) @@ -240,7 +239,7 @@ def _stash_exp( *args, params: Optional[dict] = None, branch: Optional[str] = None, - allow_unchanged: Optional[bool] = False, + allow_unchanged: Optional[bool] = True, apply_workspace: Optional[bool] = True, **kwargs, ): @@ -391,16 +390,13 @@ def _reset_checkpoint_branch(self, branch, rev, branch_tip, reset): def reproduce_one(self, queue=False, **kwargs): """Reproduce and checkout a single experiment.""" - checkpoint = kwargs.get("checkpoint", False) stash_rev = self.new(**kwargs) if queue: logger.info( "Queued experiment '%s' for future execution.", stash_rev[:7] ) return [stash_rev] - results = self.reproduce( - [stash_rev], keep_stash=False, checkpoint=checkpoint - ) + results = self.reproduce([stash_rev], keep_stash=False) exp_rev = first(results) if exp_rev is not None: self.checkout_exp(exp_rev) @@ -420,49 +416,15 @@ def reproduce_queued(self, **kwargs): @scm_locked def new( - self, - *args, - checkpoint: Optional[bool] = False, - checkpoint_continue: Optional[str] = None, - checkpoint_reset: Optional[bool] = False, - branch: Optional[str] = None, - **kwargs, + self, *args, branch: Optional[str] = None, **kwargs, ): """Create a new experiment. Experiment will be reproduced and checked out into the user's workspace. """ - if checkpoint_continue: - branch = None - if checkpoint_continue == ":last": - # Continue from most recently committed checkpoint - for head in sorted( - self.scm.repo.heads, - key=lambda h: h.commit.committed_date, - reverse=True, - ): - exp_branch = head.name - m = self.BRANCH_RE.match(exp_branch) - if m and m.group("checkpoint"): - branch = exp_branch - break - if not branch: - raise DvcException( - "No existing checkpoint experiment to continue" - ) - else: - rev = self.scm.resolve_rev(checkpoint_continue) - branch = self._get_branch_containing(rev) - if not branch: - raise DvcException( - "Could not find checkpoint experiment " - f"'{checkpoint_continue}'" - ) - logger.debug( - "Continuing checkpoint experiment '%s'", checkpoint_continue - ) - kwargs["apply_workspace"] = False + if kwargs.get("checkpoint_resume", None) is not None: + return self._resume_checkpoint(*args, **kwargs) if branch: rev = self.scm.resolve_rev(branch) @@ -473,13 +435,10 @@ def new( rev = self.repo.scm.get_rev() self._scm_checkout(rev) + force = kwargs.get("force", False) try: stash_rev = self._stash_exp( - *args, - branch=branch, - allow_unchanged=checkpoint, - checkpoint_reset=checkpoint_reset, - **kwargs, + *args, branch=branch, checkpoint_reset=force, **kwargs, ) except UnchangedExperimentError as exc: logger.info("Reproducing existing experiment '%s'.", rev[:7]) @@ -489,12 +448,63 @@ def new( ) return stash_rev + def _resume_checkpoint( + self, *args, checkpoint_resume: Optional[str] = None, **kwargs, + ): + """Resume an existing (checkpoint) experiment. + + Experiment will be reproduced and checked out into the user's + workspace. + """ + assert checkpoint_resume + + branch = None + if checkpoint_resume == self.LAST_CHECKPOINT: + # Continue from most recently committed checkpoint + for head in sorted( + self.scm.repo.heads, + key=lambda h: h.commit.committed_date, + reverse=True, + ): + exp_branch = head.name + m = self.BRANCH_RE.match(exp_branch) + if m and m.group("checkpoint"): + branch = exp_branch + break + if not branch: + raise DvcException( + "No existing checkpoint experiment to continue" + ) + else: + rev = self.scm.resolve_rev(checkpoint_resume) + branch = self._get_branch_containing(rev) + if not branch: + raise DvcException( + "Could not find checkpoint experiment " + f"'{checkpoint_resume}'" + ) + logger.debug( + "Continuing checkpoint experiment '%s'", checkpoint_resume + ) + kwargs["apply_workspace"] = False + + rev = self.scm.resolve_rev(branch) + logger.debug( + "Using '%s' (tip of branch '%s') as baseline", rev, branch + ) + self._scm_checkout(rev) + + stash_rev = self._stash_exp(*args, branch=branch, **kwargs) + logger.debug( + "Stashed experiment '%s' for future execution.", stash_rev[:7] + ) + return stash_rev + @scm_locked def reproduce( self, revs: Optional[Iterable] = None, keep_stash: Optional[bool] = True, - checkpoint: Optional[bool] = False, **kwargs, ): """Reproduce the specified experiments. @@ -547,10 +557,7 @@ def reproduce( self._collect_input(executor) executors[rev] = executor - if checkpoint: - exec_results = self._reproduce_checkpoint(executors) - else: - exec_results = self._reproduce(executors, **kwargs) + exec_results = self._reproduce(executors, **kwargs) if keep_stash: # only drop successfully run stashed experiments @@ -583,86 +590,107 @@ def _reproduce(self, executors: dict, jobs: Optional[int] = 1) -> dict: """ result = {} - with ProcessPoolExecutor(max_workers=jobs) as workers: + collect_lock = threading.Lock() + + with ThreadPoolExecutor(max_workers=jobs) as workers: futures = {} for rev, executor in executors.items(): + checkpoint_func = partial( + self._checkpoint_callback, + result, + collect_lock, + rev, + executor, + ) future = workers.submit( executor.reproduce, executor.dvc_dir, cwd=executor.dvc.root_dir, + checkpoint_func=checkpoint_func, **executor.repro_kwargs, ) futures[future] = (rev, executor) + for future in as_completed(futures): rev, executor = futures[future] exc = future.exception() - if exc is None: - exp_hash = future.result() - if executor.branch: - self._scm_checkout(executor.branch) + + try: + if exc is None: + stages = future.result() + self._collect_executor( + rev, executor, stages, result, collect_lock + ) else: - self._scm_checkout(executor.baseline_rev) - exp_rev = self._collect_and_commit(rev, executor, exp_hash) - if exp_rev: - logger.info("Reproduced experiment '%s'.", exp_rev[:7]) - result[rev] = {exp_rev: exp_hash} - else: - logger.exception( - "Failed to reproduce experiment '%s'", rev[:7] - ) - executor.cleanup() + # Checkpoint errors have already been logged + if not isinstance(exc, CheckpointKilledError): + logger.exception( + "Failed to reproduce experiment '%s'", + rev[:7], + exc_info=exc, + ) + finally: + executor.cleanup() return result - def _reproduce_checkpoint(self, executors): - result = {} - for rev, executor in executors.items(): - logger.debug("Reproducing checkpoint experiment '%s'", rev[:7]) + def _collect_executor(self, rev, executor, stages, result, lock): + exp_hash = hash_exp(stages) + checkpoint = any(stage.is_checkpoint for stage in stages) + + lock.acquire() + try: + # NOTE: GitPython Repo instances cannot be re-used + # after process has received SIGINT or SIGTERM, so we + # need this hack to re-instantiate git instances after + # checkpoint runs. See: + # https://github.com/gitpython-developers/GitPython/issues/427 + del self.repo.scm + del self.scm if executor.branch: self._scm_checkout(executor.branch) else: self._scm_checkout(executor.baseline_rev) - - def _checkpoint_callback(rev, executor, unchanged, stages): - exp_hash = hash_exp(stages + unchanged) - exp_rev = self._collect_and_commit( - rev, executor, exp_hash, checkpoint=True - ) - if exp_rev: - if not executor.branch: - branch = self._get_branch_containing(exp_rev) - executor.branch = branch - logger.info( - "Checkpoint experiment iteration '%s'.", exp_rev[:7] - ) - result[rev] = {exp_rev: exp_hash} - - checkpoint_func = partial(_checkpoint_callback, rev, executor) - - exp_hash = executor.reproduce( - executor.dvc_dir, - cwd=executor.dvc.root_dir, - checkpoint=True, - checkpoint_func=checkpoint_func, - **executor.repro_kwargs, + exp_rev = self._collect_and_commit( + rev, executor, exp_hash, checkpoint=checkpoint ) + if exp_rev: + logger.info("Reproduced experiment '%s'.", exp_rev[:7]) + result[rev] = {exp_rev: exp_hash} + finally: + lock.release() - # NOTE: GitPython Repo instances cannot be re-used after - # process has received SIGINT or SIGTERM, so we need this hack - # to re-instantiate git instances after checkpoint runs. See: - # https://github.com/gitpython-developers/GitPython/issues/427 - del self.repo.scm - del self.scm + def _checkpoint_callback( + self, + result: Mapping, + lock: threading.Lock, + rev: str, + executor: LocalExecutor, + unchanged: Iterable, + stages: Iterable, + ): + lock.acquire() + try: + if executor.branch: + self._scm_checkout(executor.branch) + else: + self._scm_checkout(executor.baseline_rev) - # Create final checkpoint commit if needed + exp_hash = hash_exp(stages + unchanged) exp_rev = self._collect_and_commit( rev, executor, exp_hash, checkpoint=True ) - if exp_rev not in result[rev]: + if exp_rev: + if not executor.branch: + branch = self._get_branch_containing(exp_rev) + executor.branch = branch + logger.info( + "Checkpoint experiment iteration '%s'.", exp_rev[:7] + ) result[rev] = {exp_rev: exp_hash} - - return result + finally: + lock.release() def _collect_and_commit(self, rev, executor, exp_hash, **kwargs): try: diff --git a/dvc/repo/experiments/executor.py b/dvc/repo/experiments/executor.py index 5cdacd54c6..3869926605 100644 --- a/dvc/repo/experiments/executor.py +++ b/dvc/repo/experiments/executor.py @@ -110,6 +110,7 @@ def __init__( # to run repro self._tree.CACHE_MODE = 0o644 self.checkpoint_reset = checkpoint_reset + self.checkpoint = False def _config(self, cache_dir): local_config = os.path.join(self.dvc_dir, "config.local") @@ -136,7 +137,6 @@ def tree(self): def reproduce(dvc_dir, cwd=None, **kwargs): """Run dvc repro and return the result.""" from dvc.repo import Repo - from dvc.repo.experiments import hash_exp unchanged = [] @@ -167,15 +167,8 @@ def filter_pipeline(stages): # be removed/does not yet exist) so that our executor workspace # is not polluted with the (persistent) out from an unrelated # experiment run - checkpoint = kwargs.pop("checkpoint", False) - dvc.checkout( - allow_missing=checkpoint, force=checkpoint, quiet=checkpoint - ) - stages = dvc.reproduce( - on_unchanged=filter_pipeline, - allow_missing=checkpoint, - **kwargs, - ) + dvc.checkout(force=True, quiet=True) + stages = dvc.reproduce(on_unchanged=filter_pipeline, **kwargs) finally: if old_cwd is not None: os.chdir(old_cwd) @@ -183,7 +176,7 @@ def filter_pipeline(stages): # ideally we would return stages here like a normal repro() call, but # stages is not currently picklable and cannot be returned across # multiprocessing calls - return hash_exp(stages + unchanged) + return stages + unchanged def collect_output(self) -> Iterable["PathInfo"]: repo_tree = RepoTree(self.dvc) diff --git a/dvc/repo/reproduce.py b/dvc/repo/reproduce.py index e7a61f6127..bac335d397 100644 --- a/dvc/repo/reproduce.py +++ b/dvc/repo/reproduce.py @@ -3,6 +3,7 @@ from dvc.exceptions import InvalidArgumentError, ReproductionError from dvc.repo.scm_context import scm_context +from dvc.stage.run import CheckpointKilledError from . import locked from .graph import get_pipeline, get_pipelines @@ -13,7 +14,6 @@ def _reproduce_stage(stage, **kwargs): def _run_callback(repro_callback): _dump_stage(stage) - logger.debug(f"{repro_callback} ([{stage}])") repro_callback([stage]) checkpoint_func = kwargs.pop("checkpoint_func", None) @@ -154,33 +154,7 @@ def _reproduce_stages( The derived evaluation of _downstream_ B would be: [B, D, E] """ - import networkx as nx - - if single_item: - all_pipelines = stages - else: - all_pipelines = [] - for stage in stages: - if downstream: - # NOTE (py3 only): - # Python's `deepcopy` defaults to pickle/unpickle the object. - # Stages are complex objects (with references to `repo`, - # `outs`, and `deps`) that cause struggles when you try - # to serialize them. We need to create a copy of the graph - # itself, and then reverse it, instead of using - # graph.reverse() directly because it calls `deepcopy` - # underneath -- unless copy=False is specified. - nodes = nx.dfs_postorder_nodes( - G.copy().reverse(copy=False), stage - ) - all_pipelines += reversed(list(nodes)) - else: - all_pipelines += nx.dfs_postorder_nodes(G, stage) - - pipeline = [] - for stage in all_pipelines: - if stage not in pipeline: - pipeline.append(stage) + pipeline = _get_pipeline(G, stages, downstream, single_item) force_downstream = kwargs.pop("force_downstream", False) result = [] @@ -212,6 +186,8 @@ def _reproduce_stages( if ret: result.extend(ret) + except CheckpointKilledError: + raise except Exception as exc: raise ReproductionError(stage.relpath) from exc @@ -220,5 +196,37 @@ def _reproduce_stages( return result +def _get_pipeline(G, stages, downstream, single_item): + import networkx as nx + + if single_item: + all_pipelines = stages + else: + all_pipelines = [] + for stage in stages: + if downstream: + # NOTE (py3 only): + # Python's `deepcopy` defaults to pickle/unpickle the object. + # Stages are complex objects (with references to `repo`, + # `outs`, and `deps`) that cause struggles when you try + # to serialize them. We need to create a copy of the graph + # itself, and then reverse it, instead of using + # graph.reverse() directly because it calls `deepcopy` + # underneath -- unless copy=False is specified. + nodes = nx.dfs_postorder_nodes( + G.copy().reverse(copy=False), stage + ) + all_pipelines += reversed(list(nodes)) + else: + all_pipelines += nx.dfs_postorder_nodes(G, stage) + + pipeline = [] + for stage in all_pipelines: + if stage not in pipeline: + pipeline.append(stage) + + return pipeline + + def _repro_callback(experiments_callback, unchanged, stages): experiments_callback(unchanged, stages) diff --git a/dvc/stage/__init__.py b/dvc/stage/__init__.py index 5ac54a5a46..abbdb61336 100644 --- a/dvc/stage/__init__.py +++ b/dvc/stage/__init__.py @@ -485,6 +485,7 @@ def _func(o): @rwlocked(write=["outs"]) def checkout(self, **kwargs): stats = defaultdict(list) + kwargs["allow_missing"] = self.is_checkpoint for out in self.filter_outs(kwargs.get("filter_info")): key, outs = self._checkout(out, **kwargs) if key: diff --git a/dvc/stage/run.py b/dvc/stage/run.py index 9c5cfb88c3..591350b647 100644 --- a/dvc/stage/run.py +++ b/dvc/stage/run.py @@ -16,6 +16,10 @@ CHECKPOINT_SIGNAL_FILE = "DVC_CHECKPOINT" +class CheckpointKilledError(StageCmdFailedError): + pass + + def _nix_cmd(executable, cmd): opts = {"zsh": ["--no-rcs"], "bash": ["--noprofile", "--norc"]} name = os.path.basename(executable).lower() @@ -79,7 +83,8 @@ def cmd_run(stage, *args, checkpoint_func=None, **kwargs): if main_thread: old_handler = signal.signal(signal.SIGINT, signal.SIG_IGN) - with checkpoint_monitor(stage, checkpoint_func, p): + killed = threading.Event() + with checkpoint_monitor(stage, checkpoint_func, p, killed): p.communicate() finally: if old_handler: @@ -87,6 +92,8 @@ def cmd_run(stage, *args, checkpoint_func=None, **kwargs): retcode = None if not p else p.returncode if retcode != 0: + if killed.is_set(): + raise CheckpointKilledError(stage.cmd, retcode) raise StageCmdFailedError(stage.cmd, retcode) @@ -111,44 +118,32 @@ def run_stage(stage, dry=False, force=False, checkpoint_func=None, **kwargs): cmd_run(stage, checkpoint_func=checkpoint_func) -class CheckpointCond: - def __init__(self): - self.done = False - self.cond = threading.Condition() - - def notify(self): - with self.cond: - self.done = True - self.cond.notify() - - def wait(self, timeout=None): - with self.cond: - return self.cond.wait(timeout) or self.done - - @contextmanager -def checkpoint_monitor(stage, callback_func, proc): +def checkpoint_monitor(stage, callback_func, proc, killed): if not callback_func: yield None return logger.debug( - "Monitoring checkpoint stage '%s' with cmd process '%s'", stage, proc + "Monitoring checkpoint stage '%s' with cmd process '%d'", + stage, + proc.pid, ) - done_cond = CheckpointCond() + done = threading.Event() monitor_thread = threading.Thread( - target=_checkpoint_run, args=(stage, callback_func, done_cond, proc), + target=_checkpoint_run, + args=(stage, callback_func, done, proc, killed), ) try: monitor_thread.start() yield monitor_thread finally: - done_cond.notify() + done.set() monitor_thread.join() -def _checkpoint_run(stage, callback_func, done_cond, proc): +def _checkpoint_run(stage, callback_func, done, proc, killed): """Run callback_func whenever checkpoint signal file is present.""" signal_path = os.path.join(stage.repo.tmp_dir, CHECKPOINT_SIGNAL_FILE) while True: @@ -159,17 +154,31 @@ def _checkpoint_run(stage, callback_func, done_cond, proc): logger.exception( "Error generating checkpoint, %s will be aborted", stage ) - proc.terminate() + _kill(proc) + killed.set() finally: logger.debug("Remove checkpoint signal file") os.remove(signal_path) - if done_cond.wait(1): + if done.wait(1): return +def _kill(proc): + if os.name == "nt": + return _kill_nt(proc) + proc.terminate() + proc.wait() + + +def _kill_nt(proc): + # windows stages are spawned with shell=True, proc is the shell process and + # not the actual stage process - we have to kill the entire tree + subprocess.call(["taskkill", "/F", "/T", "/PID", str(proc.pid)]) + + @relock_repo def _run_callback(stage, callback_func): stage.save(allow_missing=True) - # TODO: do we need commit() (and check for --no-commit) here + stage.commit(allow_missing=True) logger.debug("Running checkpoint callback for stage '%s'", stage) callback_func() diff --git a/tests/func/experiments/test_checkpoints.py b/tests/func/experiments/test_checkpoints.py new file mode 100644 index 0000000000..3d1c2d33cf --- /dev/null +++ b/tests/func/experiments/test_checkpoints.py @@ -0,0 +1,119 @@ +import logging +from textwrap import dedent + +import pytest +from funcy import first + +from dvc.exceptions import DvcException +from dvc.repo.experiments import Experiments + +CHECKPOINT_SCRIPT_FORMAT = dedent( + """\ + import os + import sys + import shutil + from time import sleep + + from dvc.api import make_checkpoint + + checkpoint_file = {} + checkpoint_iterations = int({}) + if os.path.exists(checkpoint_file): + with open(checkpoint_file) as fobj: + try: + value = int(fobj.read()) + except ValueError: + value = 0 + else: + with open(checkpoint_file, "w"): + pass + value = 0 + + shutil.copyfile({}, {}) + + if os.getenv("DVC_CHECKPOINT"): + for _ in range(checkpoint_iterations): + value += 1 + with open(checkpoint_file, "w") as fobj: + fobj.write(str(value)) + make_checkpoint() +""" +) +CHECKPOINT_SCRIPT = CHECKPOINT_SCRIPT_FORMAT.format( + "sys.argv[1]", "sys.argv[2]", "sys.argv[3]", "sys.argv[4]" +) + + +@pytest.fixture +def checkpoint_stage(tmp_dir, scm, dvc): + tmp_dir.gen("checkpoint.py", CHECKPOINT_SCRIPT) + tmp_dir.gen("params.yaml", "foo: 1") + stage = dvc.run( + cmd="python checkpoint.py foo 5 params.yaml metrics.yaml", + metrics_no_cache=["metrics.yaml"], + params=["foo"], + checkpoints=["foo"], + no_exec=True, + name="checkpoint-file", + ) + scm.add(["dvc.yaml", "checkpoint.py", "params.yaml"]) + scm.commit("init") + return stage + + +def test_new_checkpoint(tmp_dir, scm, dvc, checkpoint_stage, mocker): + new_mock = mocker.spy(dvc.experiments, "new") + dvc.experiments.run(checkpoint_stage.addressing, params=["foo=2"]) + + new_mock.assert_called_once() + assert (tmp_dir / "foo").read_text() == "5" + assert ( + tmp_dir / ".dvc" / "experiments" / "metrics.yaml" + ).read_text().strip() == "foo: 2" + + +@pytest.mark.parametrize("last", [True, False]) +def test_resume_checkpoint(tmp_dir, scm, dvc, checkpoint_stage, mocker, last): + with pytest.raises(DvcException): + if last: + dvc.experiments.run( + checkpoint_stage.addressing, + checkpoint_resume=Experiments.LAST_CHECKPOINT, + ) + else: + dvc.experiments.run( + checkpoint_stage.addressing, checkpoint_resume="foo" + ) + + results = dvc.experiments.run( + checkpoint_stage.addressing, params=["foo=2"] + ) + if last: + exp_rev = Experiments.LAST_CHECKPOINT + else: + exp_rev = first(results) + + dvc.experiments.run(checkpoint_stage.addressing, checkpoint_resume=exp_rev) + + assert (tmp_dir / "foo").read_text() == "10" + assert ( + tmp_dir / ".dvc" / "experiments" / "metrics.yaml" + ).read_text().strip() == "foo: 2" + + +def test_reset_checkpoint(tmp_dir, scm, dvc, checkpoint_stage, mocker, caplog): + dvc.experiments.run(checkpoint_stage.addressing) + scm.repo.git.reset(hard=True) + scm.repo.git.clean(force=True) + + with caplog.at_level(logging.ERROR): + results = dvc.experiments.run(checkpoint_stage.addressing) + assert len(results) == 0 + assert "already exists" in caplog.text + + dvc.experiments.run(checkpoint_stage.addressing, force=True) + + assert (tmp_dir / "foo").read_text() == "5" + assert ( + tmp_dir / ".dvc" / "experiments" / "metrics.yaml" + ).read_text().strip() == "foo: 1" diff --git a/tests/func/experiments/test_experiments.py b/tests/func/experiments/test_experiments.py index 9beed5ea18..49578a09d6 100644 --- a/tests/func/experiments/test_experiments.py +++ b/tests/func/experiments/test_experiments.py @@ -1,4 +1,4 @@ -from textwrap import dedent +import logging import pytest from funcy import first @@ -6,44 +6,9 @@ from dvc.utils.serialize import PythonFileCorruptedError from tests.func.test_repro_multistage import COPY_SCRIPT -CHECKPOINT_SCRIPT_FORMAT = dedent( - """\ - import os - import sys - import shutil - from time import sleep - - from dvc.api import make_checkpoint - - checkpoint_file = {} - checkpoint_iterations = int({}) - if os.path.exists(checkpoint_file): - with open(checkpoint_file) as fobj: - try: - value = int(fobj.read()) - except ValueError: - value = 0 - else: - with open(checkpoint_file, "w"): - pass - value = 0 - - shutil.copyfile({}, {}) - - if os.getenv("DVC_CHECKPOINT"): - for _ in range(checkpoint_iterations): - value += 1 - with open(checkpoint_file, "w") as fobj: - fobj.write(str(value)) - make_checkpoint() -""" -) -CHECKPOINT_SCRIPT = CHECKPOINT_SCRIPT_FORMAT.format( - "sys.argv[1]", "sys.argv[2]", "sys.argv[3]", "sys.argv[4]" -) - -def test_new_simple(tmp_dir, scm, dvc, mocker): +@pytest.fixture +def exp_stage(tmp_dir, scm, dvc): tmp_dir.gen("copy.py", COPY_SCRIPT) tmp_dir.gen("params.yaml", "foo: 1") stage = dvc.run( @@ -54,11 +19,14 @@ def test_new_simple(tmp_dir, scm, dvc, mocker): ) scm.add(["dvc.yaml", "dvc.lock", "copy.py", "params.yaml", "metrics.yaml"]) scm.commit("init") + return stage + +def test_new_simple(tmp_dir, scm, dvc, exp_stage, mocker): tmp_dir.gen("params.yaml", "foo: 2") new_mock = mocker.spy(dvc.experiments, "new") - dvc.experiments.run(stage.addressing) + dvc.experiments.run(exp_stage.addressing) new_mock.assert_called_once() assert ( @@ -66,27 +34,31 @@ def test_new_simple(tmp_dir, scm, dvc, mocker): ).read_text() == "foo: 2" -def test_update_with_pull(tmp_dir, scm, dvc, mocker): - tmp_dir.gen("copy.py", COPY_SCRIPT) - tmp_dir.gen("params.yaml", "foo: 1") - stage = dvc.run( - cmd="python copy.py params.yaml metrics.yaml", - metrics_no_cache=["metrics.yaml"], - params=["foo"], - name="copy-file", +def test_failed_exp(tmp_dir, scm, dvc, exp_stage, mocker, caplog): + from dvc.stage.exceptions import StageCmdFailedError + + tmp_dir.gen("params.yaml", "foo: 2") + + mocker.patch( + "dvc.stage.run.cmd_run", + side_effect=StageCmdFailedError(exp_stage.cmd, -1), ) - scm.add(["dvc.yaml", "dvc.lock", "copy.py", "params.yaml", "metrics.yaml"]) - scm.commit("init") + with caplog.at_level(logging.ERROR): + dvc.experiments.run(exp_stage.addressing) + assert "Failed to reproduce experiment" in caplog.text + + +def test_update_with_pull(tmp_dir, scm, dvc, exp_stage, mocker): expected_revs = [scm.get_rev()] tmp_dir.gen("params.yaml", "foo: 2") - dvc.experiments.run(stage.addressing) + dvc.experiments.run(exp_stage.addressing) scm.add(["dvc.yaml", "dvc.lock", "params.yaml", "metrics.yaml"]) scm.commit("promote experiment") expected_revs.append(scm.get_rev()) tmp_dir.gen("params.yaml", "foo: 3") - dvc.experiments.run(stage.addressing) + dvc.experiments.run(exp_stage.addressing) exp_scm = dvc.experiments.scm for rev in expected_revs: @@ -122,22 +94,11 @@ def test_modify_list_param(tmp_dir, scm, dvc, mocker, change, expected): ).read_text().strip() == expected -def test_checkout(tmp_dir, scm, dvc): - tmp_dir.gen("copy.py", COPY_SCRIPT) - tmp_dir.gen("params.yaml", "foo: 1") - stage = dvc.run( - cmd="python copy.py params.yaml metrics.yaml", - metrics_no_cache=["metrics.yaml"], - params=["foo"], - name="copy-file", - ) - scm.add(["dvc.yaml", "dvc.lock", "copy.py", "params.yaml", "metrics.yaml"]) - scm.commit("init") - - results = dvc.experiments.run(stage.addressing, params=["foo=2"]) +def test_checkout(tmp_dir, scm, dvc, exp_stage): + results = dvc.experiments.run(exp_stage.addressing, params=["foo=2"]) exp_a = first(results) - results = dvc.experiments.run(stage.addressing, params=["foo=3"]) + results = dvc.experiments.run(exp_stage.addressing, params=["foo=3"]) exp_b = first(results) dvc.experiments.checkout(exp_a) @@ -149,25 +110,15 @@ def test_checkout(tmp_dir, scm, dvc): assert (tmp_dir / "metrics.yaml").read_text().strip() == "foo: 3" -def test_get_baseline(tmp_dir, scm, dvc): - tmp_dir.gen("copy.py", COPY_SCRIPT) - tmp_dir.gen("params.yaml", "foo: 1") - stage = dvc.run( - cmd="python copy.py params.yaml metrics.yaml", - metrics_no_cache=["metrics.yaml"], - params=["foo"], - name="copy-file", - ) - scm.add(["dvc.yaml", "dvc.lock", "copy.py", "params.yaml", "metrics.yaml"]) - scm.commit("init") +def test_get_baseline(tmp_dir, scm, dvc, exp_stage): init_rev = scm.get_rev() assert dvc.experiments.get_baseline(init_rev) is None - results = dvc.experiments.run(stage.addressing, params=["foo=2"]) + results = dvc.experiments.run(exp_stage.addressing, params=["foo=2"]) exp_rev = first(results) assert dvc.experiments.get_baseline(exp_rev) == init_rev - dvc.experiments.run(stage.addressing, params=["foo=3"], queue=True) + dvc.experiments.run(exp_stage.addressing, params=["foo=3"], queue=True) assert dvc.experiments.get_baseline("stash@{0}") == init_rev dvc.experiments.checkout(exp_rev) @@ -175,12 +126,12 @@ def test_get_baseline(tmp_dir, scm, dvc): scm.commit("promote exp") promote_rev = scm.get_rev() - results = dvc.experiments.run(stage.addressing, params=["foo=4"]) + results = dvc.experiments.run(exp_stage.addressing, params=["foo=4"]) exp_rev = first(results) assert dvc.experiments.get_baseline(promote_rev) is None assert dvc.experiments.get_baseline(exp_rev) == promote_rev - dvc.experiments.run(stage.addressing, params=["foo=5"], queue=True) + dvc.experiments.run(exp_stage.addressing, params=["foo=5"], queue=True) assert dvc.experiments.get_baseline("stash@{0}") == promote_rev assert dvc.experiments.get_baseline("stash@{1}") == init_rev @@ -249,24 +200,13 @@ def test_update_py_params(tmp_dir, scm, dvc): dvc.experiments.run(stage.addressing, params=["params.py:INT=2a"]) -def test_extend_branch(tmp_dir, scm, dvc): - tmp_dir.gen("copy.py", COPY_SCRIPT) - tmp_dir.gen("params.yaml", "foo: 1") - stage = dvc.run( - cmd="python copy.py params.yaml metrics.yaml", - metrics_no_cache=["metrics.yaml"], - params=["foo"], - name="copy-file", - ) - scm.add(["dvc.yaml", "dvc.lock", "copy.py", "params.yaml", "metrics.yaml"]) - scm.commit("init") - - results = dvc.experiments.run(stage.addressing, params=["foo=2"]) +def test_extend_branch(tmp_dir, scm, dvc, exp_stage): + results = dvc.experiments.run(exp_stage.addressing, params=["foo=2"]) exp_a = first(results) exp_branch = dvc.experiments._get_branch_containing(exp_a) results = dvc.experiments.run( - stage.addressing, + exp_stage.addressing, params=["foo=3"], branch=exp_branch, apply_workspace=False, @@ -285,119 +225,19 @@ def test_extend_branch(tmp_dir, scm, dvc): assert (tmp_dir / "metrics.yaml").read_text().strip() == "foo: 3" -def test_detached_parent(tmp_dir, scm, dvc, mocker): - tmp_dir.gen("copy.py", COPY_SCRIPT) - tmp_dir.gen("params.yaml", "foo: 1") - stage = dvc.run( - cmd="python copy.py params.yaml metrics.yaml", - metrics_no_cache=["metrics.yaml"], - params=["foo"], - name="copy-file", - ) - scm.add(["dvc.yaml", "dvc.lock", "copy.py", "params.yaml", "metrics.yaml"]) - scm.commit("v1") +def test_detached_parent(tmp_dir, scm, dvc, exp_stage, mocker): detached_rev = scm.get_rev() tmp_dir.gen("params.yaml", "foo: 2") - dvc.reproduce(stage.addressing) + dvc.reproduce(exp_stage.addressing) scm.add(["dvc.yaml", "dvc.lock", "copy.py", "params.yaml", "metrics.yaml"]) scm.commit("v2") scm.checkout(detached_rev) assert scm.repo.head.is_detached - results = dvc.experiments.run(stage.addressing, params=["foo=3"]) + results = dvc.experiments.run(exp_stage.addressing, params=["foo=3"]) exp_rev = first(results) assert dvc.experiments.get_baseline(exp_rev) == detached_rev assert (tmp_dir / "params.yaml").read_text().strip() == "foo: 3" assert (tmp_dir / "metrics.yaml").read_text().strip() == "foo: 3" - - -def test_new_checkpoint(tmp_dir, scm, dvc, mocker): - tmp_dir.gen("checkpoint.py", CHECKPOINT_SCRIPT) - tmp_dir.gen("params.yaml", "foo: 1") - stage = dvc.run( - cmd="python checkpoint.py foo 5 params.yaml metrics.yaml", - metrics_no_cache=["metrics.yaml"], - params=["foo"], - checkpoints=["foo"], - no_exec=True, - name="checkpoint-file", - ) - scm.add(["dvc.yaml", "checkpoint.py", "params.yaml"]) - scm.commit("init") - - new_mock = mocker.spy(dvc.experiments, "new") - dvc.experiments.run(stage.addressing, checkpoint=True, params=["foo=2"]) - - new_mock.assert_called_once() - assert (tmp_dir / "foo").read_text() == "5" - assert ( - tmp_dir / ".dvc" / "experiments" / "metrics.yaml" - ).read_text().strip() == "foo: 2" - - -@pytest.mark.parametrize("last", [True, False]) -def test_continue_checkpoint(tmp_dir, scm, dvc, mocker, last): - tmp_dir.gen("checkpoint.py", CHECKPOINT_SCRIPT) - tmp_dir.gen("params.yaml", "foo: 1") - stage = dvc.run( - cmd="python checkpoint.py foo 5 params.yaml metrics.yaml", - metrics_no_cache=["metrics.yaml"], - params=["foo"], - checkpoints=["foo"], - no_exec=True, - name="checkpoint-file", - ) - scm.add(["dvc.yaml", "checkpoint.py", "params.yaml"]) - scm.commit("init") - - results = dvc.experiments.run( - stage.addressing, checkpoint=True, params=["foo=2"] - ) - if last: - exp_rev = ":last" - else: - exp_rev = first(results) - - dvc.experiments.run( - stage.addressing, checkpoint=True, checkpoint_continue=exp_rev, - ) - - assert (tmp_dir / "foo").read_text() == "10" - assert ( - tmp_dir / ".dvc" / "experiments" / "metrics.yaml" - ).read_text().strip() == "foo: 2" - - -def test_reset_checkpoint(tmp_dir, scm, dvc, mocker): - from dvc.exceptions import ReproductionError - - tmp_dir.gen("checkpoint.py", CHECKPOINT_SCRIPT) - tmp_dir.gen("params.yaml", "foo: 1") - stage = dvc.run( - cmd="python checkpoint.py foo 5 params.yaml metrics.yaml", - metrics_no_cache=["metrics.yaml"], - params=["foo"], - checkpoints=["foo"], - no_exec=True, - name="checkpoint-file", - ) - scm.add(["dvc.yaml", "checkpoint.py", "params.yaml"]) - scm.commit("init") - - dvc.experiments.run(stage.addressing, checkpoint=True) - scm.repo.git.reset(hard=True) - scm.repo.git.clean(force=True) - - with pytest.raises(ReproductionError): - dvc.experiments.run(stage.addressing, checkpoint=True) - - dvc.experiments.run( - stage.addressing, checkpoint=True, checkpoint_reset=True - ) - - assert (tmp_dir / "foo").read_text() == "5" - assert ( - tmp_dir / ".dvc" / "experiments" / "metrics.yaml" - ).read_text().strip() == "foo: 1" diff --git a/tests/unit/command/test_experiments.py b/tests/unit/command/test_experiments.py index 22fb6a0ef5..c48cc48f11 100644 --- a/tests/unit/command/test_experiments.py +++ b/tests/unit/command/test_experiments.py @@ -67,20 +67,21 @@ def test_experiments_show(dvc, mocker): ) -def test_experiments_run(dvc, mocker): +@pytest.mark.parametrize( + "args, resume", [(["exp", "run"], None), (["exp", "resume"], ":last")] +) +def test_experiments_run(dvc, mocker, args, resume): default_arguments = { "params": [], "queue": False, "run_all": False, "jobs": None, - "checkpoint": False, - "checkpoint_continue": None, - "checkpoint_reset": False, + "checkpoint_resume": resume, } default_arguments.update(repro_arguments) - cmd = CmdExperimentsRun(parse_args(["exp", "run"])) + cmd = CmdExperimentsRun(parse_args(args)) mocker.patch.object(cmd.repo, "reproduce") mocker.patch.object(cmd.repo.experiments, "run") cmd.run()