Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions dvc/repo/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,11 +468,11 @@ def _fetch_running_exp(
info = ExecutorInfo.from_dict(load_json(infofile))
except OSError:
return result
if info.result is None:
if info.status < TaskStatus.FAILED:
if rev == "workspace":
# If we are appending to a checkpoint branch in a workspace
# run, show the latest checkpoint as running.
if info.status > TaskStatus.RUNNING:
if info.status == TaskStatus.SUCCESS:
return result
last_rev = self.scm.get_ref(EXEC_BRANCH)
if last_rev:
Expand All @@ -481,7 +481,11 @@ def _fetch_running_exp(
result[rev] = info.asdict()
else:
result[rev] = info.asdict()
if info.git_url and fetch_refs:
if (
info.git_url
and fetch_refs
and info.status > TaskStatus.PREPARING
):

def on_diverged(_ref: str, _checkpoint: bool):
return False
Expand Down
28 changes: 17 additions & 11 deletions dvc/repo/experiments/executor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
Union,
)

from scmrepo.exceptions import SCMError

from dvc.env import DVC_EXP_AUTO_PUSH, DVC_EXP_GIT_REMOTE
from dvc.exceptions import DvcException
from dvc.stage.serialize import to_lockfile
Expand Down Expand Up @@ -361,21 +363,25 @@ def on_diverged_ref(orig_ref: str, new_rev: str):
return False

# fetch experiments
dest_scm.fetch_refspecs(
self.git_url,
[f"{ref}:{ref}" for ref in refs],
on_diverged=on_diverged_ref,
force=force,
**kwargs,
)
# update last run checkpoint (if it exists)
if has_checkpoint:
try:
dest_scm.fetch_refspecs(
self.git_url,
[f"{EXEC_CHECKPOINT}:{EXEC_CHECKPOINT}"],
force=True,
[f"{ref}:{ref}" for ref in refs],
on_diverged=on_diverged_ref,
force=force,
**kwargs,
)
# update last run checkpoint (if it exists)
if has_checkpoint:
dest_scm.fetch_refspecs(
self.git_url,
[f"{EXEC_CHECKPOINT}:{EXEC_CHECKPOINT}"],
force=True,
**kwargs,
)
except SCMError:
pass

return refs

@classmethod
Expand Down
150 changes: 116 additions & 34 deletions dvc/repo/experiments/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from itertools import chain
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

from dvc.repo.experiments.queue.base import QueueDoneResult
from dvc.repo.metrics.show import _gather_metrics
from dvc.repo.params.show import _gather_params
from dvc.scm import iter_revs
Expand Down Expand Up @@ -76,7 +75,7 @@ def _collect_experiment_commit(

res["status"] = status.name
if status == ExpStatus.Running:
res["executor"] = running[exp_rev].get("location")
res["executor"] = running.get(exp_rev, {}).get("location", None)
else:
res["executor"] = None

Expand Down Expand Up @@ -180,6 +179,81 @@ def get_names(repo: "Repo", result: Dict[str, Dict[str, Any]]):


# flake8: noqa: C901
def _collect_active_experiment(
repo: "Repo",
found_revs: Dict[str, List[str]],
running: Dict[str, Any],
**kwargs,
) -> Dict[str, Dict[str, Any]]:
result: Dict[str, Dict] = defaultdict(OrderedDict)
for entry in chain(
repo.experiments.tempdir_queue.iter_active(),
repo.experiments.celery_queue.iter_active(),
):
stash_rev = entry.stash_rev
if entry.baseline_rev in found_revs and (
stash_rev not in running or not running[stash_rev].get("last")
):
result[entry.baseline_rev][stash_rev] = _collect_experiment_commit(
repo,
stash_rev,
status=ExpStatus.Running,
running=running,
**kwargs,
)
return result


def _collect_queued_experiment(
repo: "Repo",
found_revs: Dict[str, List[str]],
running: Dict[str, Any],
**kwargs,
) -> Dict[str, Dict[str, Any]]:
result: Dict[str, Dict] = defaultdict(OrderedDict)
for entry in repo.experiments.celery_queue.iter_queued():
stash_rev = entry.stash_rev
if entry.baseline_rev in found_revs:
result[entry.baseline_rev][stash_rev] = _collect_experiment_commit(
repo,
stash_rev,
status=ExpStatus.Queued,
running=running,
**kwargs,
)
return result


def _collect_failed_experiment(
repo: "Repo",
found_revs: Dict[str, List[str]],
running: Dict[str, Any],
**kwargs,
) -> Dict[str, Dict[str, Any]]:
result: Dict[str, Dict] = defaultdict(OrderedDict)
for queue_done_result in repo.experiments.celery_queue.iter_failed():
entry = queue_done_result.entry
stash_rev = entry.stash_rev
if entry.baseline_rev in found_revs:
experiment = _collect_experiment_commit(
repo,
stash_rev,
status=ExpStatus.Failed,
running=running,
**kwargs,
)
result[entry.baseline_rev][stash_rev] = experiment
return result


def update_new(
    to_dict: Dict[str, Dict[str, Any]], from_dict: Dict[str, Dict[str, Any]]
):
    """Add entries from ``from_dict`` to ``to_dict`` without overwriting.

    For every ``baseline -> {rev: experiment}`` group in ``from_dict``, each
    ``rev`` missing from ``to_dict[baseline]`` is inserted; revs that are
    already present keep their current value. ``to_dict`` is modified in
    place and nothing is returned.

    ``to_dict[baseline]`` is looked up by plain indexing, so ``to_dict``
    must already hold every baseline that has experiments in ``from_dict``
    or create missing keys on access (e.g. a ``defaultdict``).
    """
    for baseline, incoming in from_dict.items():
        for rev, experiment in incoming.items():
            # Looked up per item so that an empty group never touches
            # (or creates) ``to_dict[baseline]``.
            to_dict[baseline].setdefault(rev, experiment)


def show(
repo: "Repo",
all_branches=False,
Expand Down Expand Up @@ -215,6 +289,39 @@ def show(

running = repo.experiments.get_running_exps(fetch_refs=fetch_running)

queued_experiment = (
_collect_queued_experiment(
repo,
found_revs,
running,
param_deps=param_deps,
onerror=onerror,
)
if not hide_queued
else {}
)

active_experiment = _collect_active_experiment(
repo,
found_revs,
running,
param_deps=param_deps,
onerror=onerror,
)

failed_experiments = (
_collect_failed_experiment(
repo,
found_revs,
running,
sha_only=sha_only,
param_deps=param_deps,
onerror=onerror,
)
if not hide_failed
else {}
)

for rev in found_revs:
status = ExpStatus.Running if rev in running else ExpStatus.Success
res[rev]["baseline"] = _collect_experiment_commit(
Expand Down Expand Up @@ -249,38 +356,13 @@ def show(
onerror=onerror,
)

# collect standalone & celery experiments
for entry in chain(
repo.experiments.tempdir_queue.iter_active(),
repo.experiments.celery_queue.iter_active(),
repo.experiments.celery_queue.iter_queued(),
repo.experiments.celery_queue.iter_failed(),
):
if isinstance(entry, QueueDoneResult):
entry = entry.entry
if hide_failed:
continue
status = ExpStatus.Failed
elif entry.stash_rev in running:
status = ExpStatus.Running
else:
if hide_queued:
continue
status = ExpStatus.Queued
stash_rev = entry.stash_rev
if entry.baseline_rev in found_revs:
if stash_rev not in running or not running[stash_rev].get(
"last"
):
experiment = _collect_experiment_commit(
repo,
stash_rev,
status=status,
param_deps=param_deps,
running=running,
onerror=onerror,
)
res[entry.baseline_rev][stash_rev] = experiment
if not hide_failed:
update_new(res, failed_experiments)

update_new(res, active_experiment)

if not hide_queued:
update_new(res, queued_experiment)

if not sha_only:
get_names(repo, res)
Expand Down
4 changes: 2 additions & 2 deletions tests/func/experiments/test_show.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def test_show_queued(tmp_dir, scm, dvc, exp_stage):


@pytest.mark.vscode
@pytest.mark.xfail(strict=False, reason="pytest-celery flaky")
def test_show_failed_experiment(tmp_dir, scm, dvc, failed_exp_stage):
baseline_rev = scm.get_rev()
timestamp = datetime.fromtimestamp(
Expand Down Expand Up @@ -423,8 +424,6 @@ def test_show_running_workspace(tmp_dir, scm, dvc, exp_stage, capsys, status):
makedirs(os.path.dirname(pidfile), True)
(tmp_dir / pidfile).dump_json(info.asdict())

print(dvc.experiments.show().get("workspace"))

assert dvc.experiments.show().get("workspace") == {
"baseline": {
"data": {
Expand Down Expand Up @@ -558,6 +557,7 @@ def test_show_running_checkpoint(tmp_dir, scm, dvc, checkpoint_stage, mocker):
git_url="foo.git",
baseline_rev=baseline_rev,
location=TempDirExecutor.DEFAULT_LOCATION,
status=TaskStatus.RUNNING,
)
makedirs(os.path.dirname(pidfile), True)
(tmp_dir / pidfile).dump_json(info.asdict())
Expand Down