diff --git a/dvc/repo/experiments/__init__.py b/dvc/repo/experiments/__init__.py index bebf87626b..287490fcbd 100644 --- a/dvc/repo/experiments/__init__.py +++ b/dvc/repo/experiments/__init__.py @@ -4,7 +4,7 @@ import time from typing import Dict, Iterable, Optional -from funcy import cached_property, first +from funcy import cached_property, chain, first from dvc.exceptions import DvcException from dvc.ui import ui @@ -171,7 +171,18 @@ def reproduce_celery( ) -> Dict[str, str]: results: Dict[str, str] = {} if entries is None: - entries = list(self.celery_queue.iter_queued()) + entries = list( + chain( + self.celery_queue.iter_active(), + self.celery_queue.iter_queued(), + ) + ) + + logger.debug( + "reproduce all these entries '%s'", + entries, + ) + if not entries: return results @@ -189,7 +200,10 @@ def reproduce_celery( time.sleep(1) self.celery_queue.follow(entry) # wait for task collection to complete - result = self.celery_queue.get_result(entry) + try: + result = self.celery_queue.get_result(entry) + except FileNotFoundError: + result = None if result is None or result.exp_hash is None: name = entry.name or entry.stash_rev[:7] failed.append(name) diff --git a/dvc/repo/experiments/executor/base.py b/dvc/repo/experiments/executor/base.py index c93b047914..a8b9134b04 100644 --- a/dvc/repo/experiments/executor/base.py +++ b/dvc/repo/experiments/executor/base.py @@ -602,29 +602,22 @@ def _repro_dvc( logger.debug("Running repro in '%s'", os.getcwd()) yield dvc info.status = TaskStatus.SUCCESS - if infofile is not None: - info.dump_json(infofile) - except CheckpointKilledError: info.status = TaskStatus.FAILED - if infofile is not None: - info.dump_json(infofile) raise except DvcException: if log_errors: logger.exception("") info.status = TaskStatus.FAILED - if infofile is not None: - info.dump_json(infofile) raise except Exception: if log_errors: logger.exception("unexpected error") info.status = TaskStatus.FAILED - if infofile is not None: - info.dump_json(infofile) raise finally: + if infofile is not None: + info.dump_json(infofile) dvc.close() os.chdir(old_cwd) diff --git a/dvc/repo/experiments/executor/local.py b/dvc/repo/experiments/executor/local.py index e797429ef2..a62d03024a 100644 --- a/dvc/repo/experiments/executor/local.py +++ b/dvc/repo/experiments/executor/local.py @@ -47,9 +47,9 @@ def scm(self): return SCM(self.root_dir) def cleanup(self, infofile: str): - super().cleanup(infofile) self.scm.close() del self.scm + super().cleanup(infofile) def collect_cache( self, repo: "Repo", exp_ref: "ExpRefInfo", run_cache: bool = True diff --git a/dvc/repo/experiments/queue/celery.py b/dvc/repo/experiments/queue/celery.py index c509780885..12cf3f5925 100644 --- a/dvc/repo/experiments/queue/celery.py +++ b/dvc/repo/experiments/queue/celery.py @@ -23,12 +23,7 @@ from dvc.ui import ui from ..exceptions import UnresolvedQueueExpNamesError -from ..executor.base import ( - EXEC_TMP_DIR, - ExecutorInfo, - ExecutorResult, - TaskStatus, -) +from ..executor.base import EXEC_TMP_DIR, ExecutorInfo, ExecutorResult from .base import BaseStashQueue, QueueDoneResult, QueueEntry, QueueGetResult from .tasks import run_exp from .utils import fetch_running_exp_from_temp_dir @@ -187,6 +182,7 @@ def _iter_queued(self) -> Generator[_MessageEntry, None, None]: continue args, kwargs, _embed = msg.decode() entry_dict = kwargs.get("entry_dict", args[0]) + logger.debug("Found queued task %s", entry_dict["stash_rev"]) yield _MessageEntry(msg, QueueEntry.from_dict(entry_dict)) def _iter_processed(self) -> Generator[_MessageEntry, None, None]: @@ -203,6 +199,7 @@ def _iter_active_tasks(self) -> Generator[_TaskEntry, None, None]: task_id = msg.headers["id"] result: AsyncResult = AsyncResult(task_id) if not result.ready(): + logger.debug("Found active task %s", entry.stash_rev) yield _TaskEntry(result, entry) def _iter_done_tasks(self) -> Generator[_TaskEntry, None, None]: @@ -211,6 +208,7 @@ def _iter_done_tasks(self) -> Generator[_TaskEntry, None, None]: task_id = msg.headers["id"] result: AsyncResult = AsyncResult(task_id) if result.ready(): + logger.debug("Found done task %s", entry.stash_rev) yield _TaskEntry(result, entry) def iter_active(self) -> Generator[QueueEntry, None, None]: @@ -243,48 +241,50 @@ def iter_failed(self) -> Generator[QueueDoneResult, None, None]: def reproduce(self) -> Mapping[str, Mapping[str, str]]: raise NotImplementedError - def get_result( + def _load_info(self, rev: str) -> ExecutorInfo: + infofile = self.get_infofile_path(rev) + return ExecutorInfo.load_json(infofile) + + def _get_done_result( self, entry: QueueEntry, timeout: Optional[float] = None - ) -> Optional[ExecutorResult]: + ) -> Optional["ExecutorResult"]: from celery.exceptions import TimeoutError as _CeleryTimeout - def _load_info(rev: str) -> ExecutorInfo: - infofile = self.get_infofile_path(rev) - return ExecutorInfo.load_json(infofile) - - def _load_collected(rev: str) -> Optional[ExecutorResult]: - executor_info = _load_info(rev) - if executor_info.status > TaskStatus.SUCCESS: + for msg, processed_entry in self._iter_processed(): + if entry.stash_rev == processed_entry.stash_rev: + task_id = msg.headers["id"] + result: AsyncResult = AsyncResult(task_id) + if not result.ready(): + logger.debug( + "Waiting for exp task '%s' to complete", result.id + ) + try: + result.get(timeout=timeout) + except _CeleryTimeout as exc: + raise DvcException( + "Timed out waiting for exp to finish." + ) from exc + executor_info = self._load_info(entry.stash_rev) return executor_info.result - raise FileNotFoundError + raise FileNotFoundError + + def get_result( + self, entry: QueueEntry, timeout: Optional[float] = None + ) -> Optional["ExecutorResult"]: try: - return _load_collected(entry.stash_rev) + return self._get_done_result(entry, timeout) except FileNotFoundError: - # Infofile will not be created until execution begins pass for queue_entry in self.iter_queued(): if entry.stash_rev == queue_entry.stash_rev: raise DvcException("Experiment has not been started.") - for result, active_entry in self._iter_active_tasks(): - if entry.stash_rev == active_entry.stash_rev: - logger.debug( - "Waiting for exp task '%s' to complete", result.id - ) - try: - result.get(timeout=timeout) - except _CeleryTimeout as exc: - raise DvcException( - "Timed out waiting for exp to finish." - ) from exc - executor_info = _load_info(entry.stash_rev) - return executor_info.result # NOTE: It's possible for an exp to complete while iterating through # other queued and active tasks, in which case the exp will get moved # out of the active task list, and needs to be loaded here. - return _load_collected(entry.stash_rev) + return self._get_done_result(entry, timeout) def kill(self, revs: Collection[str]) -> None: to_kill: Set[QueueEntry] = set() diff --git a/dvc/repo/experiments/run.py b/dvc/repo/experiments/run.py index 38361f0637..d03f2ea277 100644 --- a/dvc/repo/experiments/run.py +++ b/dvc/repo/experiments/run.py @@ -29,8 +29,7 @@ def run( of `repro` for that experiment. """ if run_all: - entries = list(repo.experiments.celery_queue.iter_queued()) - return repo.experiments.reproduce_celery(entries, jobs=jobs) + return repo.experiments.reproduce_celery(jobs=jobs) hydra_sweep = None if params: diff --git a/dvc/repo/experiments/show.py b/dvc/repo/experiments/show.py index 09653e6369..4e3a6adaf9 100644 --- a/dvc/repo/experiments/show.py +++ b/dvc/repo/experiments/show.py @@ -398,7 +398,6 @@ def show( repo, found_revs, running, - sha_only=sha_only, param_deps=param_deps, onerror=onerror, ) diff --git a/tests/func/experiments/test_experiments.py b/tests/func/experiments/test_experiments.py index c0b2665c7b..3eb7d64190 100644 --- a/tests/func/experiments/test_experiments.py +++ b/tests/func/experiments/test_experiments.py @@ -536,7 +536,7 @@ def test_run_celery(tmp_dir, scm, dvc, exp_stage, mocker): repro_spy = mocker.spy(dvc.experiments, "reproduce_celery") results = dvc.experiments.run(run_all=True) assert len(results) == 2 - repro_spy.assert_called_once_with(entries=mocker.ANY, jobs=1) + repro_spy.assert_called_once_with(jobs=1) expected = {"foo: 2", "foo: 3"} metrics = set() diff --git a/tests/func/experiments/test_queue.py b/tests/func/experiments/test_queue.py index 7cb4c5c531..0aac40706b 100644 --- a/tests/func/experiments/test_queue.py +++ b/tests/func/experiments/test_queue.py @@ -11,64 +11,54 @@ def to_dict(tasks): return status_dict -@pytest.fixture -def success_tasks(tmp_dir, dvc, scm, test_queue, exp_stage): - queue_length = 3 - name_list = [] - for i in range(queue_length): - name = f"success{i}" - name_list.append(name) - dvc.experiments.run( - exp_stage.addressing, params=[f"foo={i}"], queue=True, name=name - ) - dvc.experiments.run(run_all=True) - return name_list - - -@pytest.fixture -def failed_tasks(tmp_dir, dvc, scm, test_queue, failed_exp_stage): - queue_length = 3 - name_list = [] - for i in range(queue_length): - name = f"failed{i}" - name_list.append(name) - dvc.experiments.run( - failed_exp_stage.addressing, - params=[f"foo={i+queue_length}"], - queue=True, - name=name, - ) - dvc.experiments.run(run_all=True) - return name_list - - -@pytest.mark.xfail(strict=False, reason="pytest-celery flaky") @pytest.mark.parametrize("follow", [True, False]) def test_celery_logs( tmp_dir, scm, dvc, failed_exp_stage, - test_queue, follow, capsys, ): + celery_queue = dvc.experiments.celery_queue dvc.experiments.run(failed_exp_stage.addressing, queue=True) dvc.experiments.run(run_all=True) - queue = dvc.experiments.celery_queue - done_result = first(queue.iter_done()) + done_result = first(celery_queue.iter_done()) + name = done_result.entry.stash_rev captured = capsys.readouterr() - queue.logs(name, follow=follow) + celery_queue.logs(name, follow=follow) captured = capsys.readouterr() assert "failed to reproduce 'failed-copy-file'" in captured.out -@pytest.mark.xfail(strict=False, reason="pytest-celery flaky") -def test_queue_remove_done(dvc, failed_tasks, success_tasks): - assert len(dvc.experiments.celery_queue.failed_stash) == 3 - status = to_dict(dvc.experiments.celery_queue.status()) +def test_queue_remove_done( + dvc, + exp_stage, + failed_exp_stage, +): + queue_length = 3 + success_tasks = [] + failed_tasks = [] + celery_queue = dvc.experiments.celery_queue + for i in range(queue_length): + name = f"success{i}" + success_tasks.append(name) + dvc.experiments.run( + exp_stage.addressing, params=[f"foo={i}"], queue=True, name=name + ) + name_fail = f"failed{i}" + failed_tasks.append(name_fail) + dvc.experiments.run( + failed_exp_stage.addressing, + params=[f"foo={i+queue_length}"], + queue=True, + name=name_fail, + ) + dvc.experiments.run(run_all=True) + assert len(celery_queue.failed_stash) == 3 + status = to_dict(celery_queue.status()) assert len(status) == 6 for name in failed_tasks: assert status[name] == "Failed" @@ -76,27 +66,21 @@ def test_queue_remove_done(dvc, failed_tasks, success_tasks): assert status[name] == "Success" with pytest.raises(InvalidArgumentError): - dvc.experiments.celery_queue.remove(failed_tasks[:2] + ["non-exist"]) - assert len(dvc.experiments.celery_queue.status()) == 6 + celery_queue.remove(failed_tasks[:2] + ["non-exist"]) + assert len(celery_queue.status()) == 6 to_remove = [failed_tasks[0], success_tasks[2]] - assert set(dvc.experiments.celery_queue.remove(to_remove)) == set( - to_remove - ) + assert set(celery_queue.remove(to_remove)) == set(to_remove) - assert len(dvc.experiments.celery_queue.failed_stash) == 2 - status = to_dict(dvc.experiments.celery_queue.status()) + assert len(celery_queue.failed_stash) == 2 + status = to_dict(celery_queue.status()) assert set(status) == set(failed_tasks[1:] + success_tasks[:2]) - assert dvc.experiments.celery_queue.clear(failed=True) == failed_tasks[1:] + assert set(celery_queue.clear(failed=True)) == set(failed_tasks[1:]) - assert len(dvc.experiments.celery_queue.failed_stash) == 0 - assert set(to_dict(dvc.experiments.celery_queue.status())) == set( - success_tasks[:2] - ) + assert len(celery_queue.failed_stash) == 0 + assert set(to_dict(celery_queue.status())) == set(success_tasks[:2]) - assert ( - dvc.experiments.celery_queue.clear(success=True) == success_tasks[:2] - ) + assert set(celery_queue.clear(success=True)) == set(success_tasks[:2]) - assert dvc.experiments.celery_queue.status() == [] + assert celery_queue.status() == [] diff --git a/tests/func/experiments/test_show.py b/tests/func/experiments/test_show.py index 0171e7b6aa..92f3aabbbc 100644 --- a/tests/func/experiments/test_show.py +++ b/tests/func/experiments/test_show.py @@ -174,7 +174,6 @@ def test_show_queued(tmp_dir, scm, dvc, exp_stage): @pytest.mark.vscode -@pytest.mark.xfail(strict=False, reason="pytest-celery flaky") def test_show_failed_experiment(tmp_dir, scm, dvc, failed_exp_stage): baseline_rev = scm.get_rev() timestamp = datetime.fromtimestamp( @@ -647,6 +646,9 @@ def _get_rev_isotimestamp(rev): result1 = dvc.experiments.run(exp_stage.addressing, params=["foo=2"]) rev1 = first(result1) ref_info1 = first(exp_refs_by_rev(scm, rev1)) + + # at least 1 second gap between these experiments to make sure + # the previous experiment to be regarded as branch_base time.sleep(1) result2 = dvc.experiments.run(exp_stage.addressing, params=["foo=3"]) rev2 = first(result2) diff --git a/tests/unit/repo/experiments/test_executor_status.py b/tests/unit/repo/experiments/test_executor_status.py index 06af2a925a..9d91d7223e 100644 --- a/tests/unit/repo/experiments/test_executor_status.py +++ b/tests/unit/repo/experiments/test_executor_status.py @@ -82,21 +82,34 @@ def test_workspace_executor_success_status(dvc, scm, exp_stage, queue_type): assert not os.path.exists(infofile) -@pytest.mark.parametrize("queue_type", ["workspace_queue", "tempdir_queue"]) +@pytest.mark.parametrize( + "queue_type", + ["workspace_queue", "tempdir_queue"], +) def test_workspace_executor_failed_status( dvc, scm, failed_exp_stage, queue_type ): - workspace_queue = getattr(dvc.experiments, queue_type) - queue_entry = workspace_queue.put( + queue = getattr(dvc.experiments, queue_type) + queue.put( params={"params.yaml": ["foo=1"]}, targets=failed_exp_stage.addressing, name="failed", ) - name = workspace_queue._EXEC_NAME or queue_entry.stash_rev + entry, executor = queue.get() + name = queue._EXEC_NAME or entry.stash_rev + infofile = queue.get_infofile_path(name) + rev = entry.stash_rev - infofile = workspace_queue.get_infofile_path(name) with pytest.raises(DvcException): - workspace_queue.reproduce() + executor.reproduce( + info=executor.info, + rev=rev, + infofile=infofile, + ) + executor_info = ExecutorInfo.load_json(infofile) + assert executor_info.status == TaskStatus.FAILED + + cleanup_exp.s(executor, infofile)() if queue_type == "workspace_queue": assert not os.path.exists(infofile) else: