Skip to content

Fix some celery-queue-related CI failures. #8404

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions dvc/repo/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import time
from typing import Dict, Iterable, Optional

from funcy import cached_property, first
from funcy import cached_property, chain, first

from dvc.exceptions import DvcException
from dvc.ui import ui
Expand Down Expand Up @@ -171,7 +171,18 @@ def reproduce_celery(
) -> Dict[str, str]:
results: Dict[str, str] = {}
if entries is None:
entries = list(self.celery_queue.iter_queued())
entries = list(
chain(
self.celery_queue.iter_active(),
self.celery_queue.iter_queued(),
)
)

logger.debug(
"reproduce all these entries '%s'",
entries,
)

if not entries:
return results

Expand All @@ -189,7 +200,10 @@ def reproduce_celery(
time.sleep(1)
self.celery_queue.follow(entry)
# wait for task collection to complete
result = self.celery_queue.get_result(entry)
try:
result = self.celery_queue.get_result(entry)
except FileNotFoundError:
result = None
if result is None or result.exp_hash is None:
name = entry.name or entry.stash_rev[:7]
failed.append(name)
Expand Down
11 changes: 2 additions & 9 deletions dvc/repo/experiments/executor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,29 +602,22 @@ def _repro_dvc(
logger.debug("Running repro in '%s'", os.getcwd())
yield dvc
info.status = TaskStatus.SUCCESS
if infofile is not None:
info.dump_json(infofile)

except CheckpointKilledError:
info.status = TaskStatus.FAILED
if infofile is not None:
info.dump_json(infofile)
raise
except DvcException:
if log_errors:
logger.exception("")
info.status = TaskStatus.FAILED
if infofile is not None:
info.dump_json(infofile)
raise
except Exception:
if log_errors:
logger.exception("unexpected error")
info.status = TaskStatus.FAILED
if infofile is not None:
info.dump_json(infofile)
raise
finally:
if infofile is not None:
info.dump_json(infofile)
dvc.close()
os.chdir(old_cwd)

Expand Down
2 changes: 1 addition & 1 deletion dvc/repo/experiments/executor/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ def scm(self):
return SCM(self.root_dir)

def cleanup(self, infofile: str):
super().cleanup(infofile)
self.scm.close()
del self.scm
super().cleanup(infofile)

def collect_cache(
self, repo: "Repo", exp_ref: "ExpRefInfo", run_cache: bool = True
Expand Down
64 changes: 32 additions & 32 deletions dvc/repo/experiments/queue/celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,7 @@
from dvc.ui import ui

from ..exceptions import UnresolvedQueueExpNamesError
from ..executor.base import (
EXEC_TMP_DIR,
ExecutorInfo,
ExecutorResult,
TaskStatus,
)
from ..executor.base import EXEC_TMP_DIR, ExecutorInfo, ExecutorResult
from .base import BaseStashQueue, QueueDoneResult, QueueEntry, QueueGetResult
from .tasks import run_exp
from .utils import fetch_running_exp_from_temp_dir
Expand Down Expand Up @@ -187,6 +182,7 @@ def _iter_queued(self) -> Generator[_MessageEntry, None, None]:
continue
args, kwargs, _embed = msg.decode()
entry_dict = kwargs.get("entry_dict", args[0])
logger.debug("Found queued task %s", entry_dict["stash_rev"])
yield _MessageEntry(msg, QueueEntry.from_dict(entry_dict))

def _iter_processed(self) -> Generator[_MessageEntry, None, None]:
Expand All @@ -203,6 +199,7 @@ def _iter_active_tasks(self) -> Generator[_TaskEntry, None, None]:
task_id = msg.headers["id"]
result: AsyncResult = AsyncResult(task_id)
if not result.ready():
logger.debug("Found active task %s", entry.stash_rev)
yield _TaskEntry(result, entry)

def _iter_done_tasks(self) -> Generator[_TaskEntry, None, None]:
Expand All @@ -211,6 +208,7 @@ def _iter_done_tasks(self) -> Generator[_TaskEntry, None, None]:
task_id = msg.headers["id"]
result: AsyncResult = AsyncResult(task_id)
if result.ready():
logger.debug("Found done task %s", entry.stash_rev)
yield _TaskEntry(result, entry)

def iter_active(self) -> Generator[QueueEntry, None, None]:
Expand Down Expand Up @@ -243,48 +241,50 @@ def iter_failed(self) -> Generator[QueueDoneResult, None, None]:
def reproduce(self) -> Mapping[str, Mapping[str, str]]:
raise NotImplementedError

def get_result(
def _load_info(self, rev: str) -> ExecutorInfo:
infofile = self.get_infofile_path(rev)
return ExecutorInfo.load_json(infofile)

def _get_done_result(
self, entry: QueueEntry, timeout: Optional[float] = None
) -> Optional[ExecutorResult]:
) -> Optional["ExecutorResult"]:
from celery.exceptions import TimeoutError as _CeleryTimeout

def _load_info(rev: str) -> ExecutorInfo:
infofile = self.get_infofile_path(rev)
return ExecutorInfo.load_json(infofile)

def _load_collected(rev: str) -> Optional[ExecutorResult]:
executor_info = _load_info(rev)
if executor_info.status > TaskStatus.SUCCESS:
for msg, processed_entry in self._iter_processed():
if entry.stash_rev == processed_entry.stash_rev:
task_id = msg.headers["id"]
result: AsyncResult = AsyncResult(task_id)
if not result.ready():
logger.debug(
"Waiting for exp task '%s' to complete", result.id
)
try:
result.get(timeout=timeout)
except _CeleryTimeout as exc:
raise DvcException(
"Timed out waiting for exp to finish."
) from exc
executor_info = self._load_info(entry.stash_rev)
return executor_info.result
raise FileNotFoundError
raise FileNotFoundError

def get_result(
self, entry: QueueEntry, timeout: Optional[float] = None
) -> Optional["ExecutorResult"]:

try:
return _load_collected(entry.stash_rev)
return self._get_done_result(entry, timeout)
except FileNotFoundError:
# Infofile will not be created until execution begins
pass

for queue_entry in self.iter_queued():
if entry.stash_rev == queue_entry.stash_rev:
raise DvcException("Experiment has not been started.")
for result, active_entry in self._iter_active_tasks():
if entry.stash_rev == active_entry.stash_rev:
logger.debug(
"Waiting for exp task '%s' to complete", result.id
)
try:
result.get(timeout=timeout)
except _CeleryTimeout as exc:
raise DvcException(
"Timed out waiting for exp to finish."
) from exc
executor_info = _load_info(entry.stash_rev)
return executor_info.result

# NOTE: It's possible for an exp to complete while iterating through
# other queued and active tasks, in which case the exp will get moved
# out of the active task list, and needs to be loaded here.
return _load_collected(entry.stash_rev)
return self._get_done_result(entry, timeout)

def kill(self, revs: Collection[str]) -> None:
to_kill: Set[QueueEntry] = set()
Expand Down
3 changes: 1 addition & 2 deletions dvc/repo/experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ def run(
of `repro` for that experiment.
"""
if run_all:
entries = list(repo.experiments.celery_queue.iter_queued())
return repo.experiments.reproduce_celery(entries, jobs=jobs)
return repo.experiments.reproduce_celery(jobs=jobs)

hydra_sweep = None
if params:
Expand Down
1 change: 0 additions & 1 deletion dvc/repo/experiments/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,6 @@ def show(
repo,
found_revs,
running,
sha_only=sha_only,
param_deps=param_deps,
onerror=onerror,
)
Expand Down
2 changes: 1 addition & 1 deletion tests/func/experiments/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ def test_run_celery(tmp_dir, scm, dvc, exp_stage, mocker):
repro_spy = mocker.spy(dvc.experiments, "reproduce_celery")
results = dvc.experiments.run(run_all=True)
assert len(results) == 2
repro_spy.assert_called_once_with(entries=mocker.ANY, jobs=1)
repro_spy.assert_called_once_with(jobs=1)

expected = {"foo: 2", "foo: 3"}
metrics = set()
Expand Down
96 changes: 40 additions & 56 deletions tests/func/experiments/test_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,92 +11,76 @@ def to_dict(tasks):
return status_dict


@pytest.fixture
def success_tasks(tmp_dir, dvc, scm, test_queue, exp_stage):
queue_length = 3
name_list = []
for i in range(queue_length):
name = f"success{i}"
name_list.append(name)
dvc.experiments.run(
exp_stage.addressing, params=[f"foo={i}"], queue=True, name=name
)
dvc.experiments.run(run_all=True)
return name_list


@pytest.fixture
def failed_tasks(tmp_dir, dvc, scm, test_queue, failed_exp_stage):
queue_length = 3
name_list = []
for i in range(queue_length):
name = f"failed{i}"
name_list.append(name)
dvc.experiments.run(
failed_exp_stage.addressing,
params=[f"foo={i+queue_length}"],
queue=True,
name=name,
)
dvc.experiments.run(run_all=True)
return name_list


@pytest.mark.xfail(strict=False, reason="pytest-celery flaky")
@pytest.mark.parametrize("follow", [True, False])
def test_celery_logs(
tmp_dir,
scm,
dvc,
failed_exp_stage,
test_queue,
follow,
capsys,
):
celery_queue = dvc.experiments.celery_queue
dvc.experiments.run(failed_exp_stage.addressing, queue=True)
dvc.experiments.run(run_all=True)

queue = dvc.experiments.celery_queue
done_result = first(queue.iter_done())
done_result = first(celery_queue.iter_done())

name = done_result.entry.stash_rev
captured = capsys.readouterr()
queue.logs(name, follow=follow)
celery_queue.logs(name, follow=follow)
captured = capsys.readouterr()
assert "failed to reproduce 'failed-copy-file'" in captured.out


@pytest.mark.xfail(strict=False, reason="pytest-celery flaky")
def test_queue_remove_done(dvc, failed_tasks, success_tasks):
assert len(dvc.experiments.celery_queue.failed_stash) == 3
status = to_dict(dvc.experiments.celery_queue.status())
def test_queue_remove_done(
dvc,
exp_stage,
failed_exp_stage,
):
queue_length = 3
success_tasks = []
failed_tasks = []
celery_queue = dvc.experiments.celery_queue
for i in range(queue_length):
name = f"success{i}"
success_tasks.append(name)
dvc.experiments.run(
exp_stage.addressing, params=[f"foo={i}"], queue=True, name=name
)
name_fail = f"failed{i}"
failed_tasks.append(name_fail)
dvc.experiments.run(
failed_exp_stage.addressing,
params=[f"foo={i+queue_length}"],
queue=True,
name=name_fail,
)
dvc.experiments.run(run_all=True)
assert len(celery_queue.failed_stash) == 3
status = to_dict(celery_queue.status())
assert len(status) == 6
for name in failed_tasks:
assert status[name] == "Failed"
for name in success_tasks:
assert status[name] == "Success"

with pytest.raises(InvalidArgumentError):
dvc.experiments.celery_queue.remove(failed_tasks[:2] + ["non-exist"])
assert len(dvc.experiments.celery_queue.status()) == 6
celery_queue.remove(failed_tasks[:2] + ["non-exist"])
assert len(celery_queue.status()) == 6

to_remove = [failed_tasks[0], success_tasks[2]]
assert set(dvc.experiments.celery_queue.remove(to_remove)) == set(
to_remove
)
assert set(celery_queue.remove(to_remove)) == set(to_remove)

assert len(dvc.experiments.celery_queue.failed_stash) == 2
status = to_dict(dvc.experiments.celery_queue.status())
assert len(celery_queue.failed_stash) == 2
status = to_dict(celery_queue.status())
assert set(status) == set(failed_tasks[1:] + success_tasks[:2])

assert dvc.experiments.celery_queue.clear(failed=True) == failed_tasks[1:]
assert set(celery_queue.clear(failed=True)) == set(failed_tasks[1:])

assert len(dvc.experiments.celery_queue.failed_stash) == 0
assert set(to_dict(dvc.experiments.celery_queue.status())) == set(
success_tasks[:2]
)
assert len(celery_queue.failed_stash) == 0
assert set(to_dict(celery_queue.status())) == set(success_tasks[:2])

assert (
dvc.experiments.celery_queue.clear(success=True) == success_tasks[:2]
)
assert set(celery_queue.clear(success=True)) == set(success_tasks[:2])

assert dvc.experiments.celery_queue.status() == []
assert celery_queue.status() == []
4 changes: 3 additions & 1 deletion tests/func/experiments/test_show.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,6 @@ def test_show_queued(tmp_dir, scm, dvc, exp_stage):


@pytest.mark.vscode
@pytest.mark.xfail(strict=False, reason="pytest-celery flaky")
def test_show_failed_experiment(tmp_dir, scm, dvc, failed_exp_stage):
baseline_rev = scm.get_rev()
timestamp = datetime.fromtimestamp(
Expand Down Expand Up @@ -647,6 +646,9 @@ def _get_rev_isotimestamp(rev):
result1 = dvc.experiments.run(exp_stage.addressing, params=["foo=2"])
rev1 = first(result1)
ref_info1 = first(exp_refs_by_rev(scm, rev1))

# leave at least a 1 second gap between these experiments to make sure
# the previous experiment is regarded as branch_base
time.sleep(1)
result2 = dvc.experiments.run(exp_stage.addressing, params=["foo=3"])
rev2 = first(result2)
Expand Down
Loading