|
5 | 5 | from itertools import chain
|
6 | 6 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
|
7 | 7 |
|
8 |
| -from dvc.repo.experiments.queue.base import QueueDoneResult |
9 | 8 | from dvc.repo.metrics.show import _gather_metrics
|
10 | 9 | from dvc.repo.params.show import _gather_params
|
11 | 10 | from dvc.scm import iter_revs
|
@@ -76,7 +75,7 @@ def _collect_experiment_commit(
|
76 | 75 |
|
77 | 76 | res["status"] = status.name
|
78 | 77 | if status == ExpStatus.Running:
|
79 |
| - res["executor"] = running[exp_rev].get("location") |
| 78 | + res["executor"] = running.get(exp_rev, {}).get("location", None) |
80 | 79 | else:
|
81 | 80 | res["executor"] = None
|
82 | 81 |
|
@@ -180,6 +179,81 @@ def get_names(repo: "Repo", result: Dict[str, Dict[str, Any]]):
|
180 | 179 |
|
181 | 180 |
|
182 | 181 | # flake8: noqa: C901
|
| 182 | +def _collect_active_experiment( |
| 183 | + repo: "Repo", |
| 184 | + found_revs: Dict[str, List[str]], |
| 185 | + running: Dict[str, Any], |
| 186 | + **kwargs, |
| 187 | +) -> Dict[str, Dict[str, Any]]: |
| 188 | + result: Dict[str, Dict] = defaultdict(OrderedDict) |
| 189 | + for entry in chain( |
| 190 | + repo.experiments.tempdir_queue.iter_active(), |
| 191 | + repo.experiments.celery_queue.iter_active(), |
| 192 | + ): |
| 193 | + stash_rev = entry.stash_rev |
| 194 | + if entry.baseline_rev in found_revs and ( |
| 195 | + stash_rev not in running or not running[stash_rev].get("last") |
| 196 | + ): |
| 197 | + result[entry.baseline_rev][stash_rev] = _collect_experiment_commit( |
| 198 | + repo, |
| 199 | + stash_rev, |
| 200 | + status=ExpStatus.Running, |
| 201 | + running=running, |
| 202 | + **kwargs, |
| 203 | + ) |
| 204 | + return result |
| 205 | + |
| 206 | + |
| 207 | +def _collect_queued_experiment( |
| 208 | + repo: "Repo", |
| 209 | + found_revs: Dict[str, List[str]], |
| 210 | + running: Dict[str, Any], |
| 211 | + **kwargs, |
| 212 | +) -> Dict[str, Dict[str, Any]]: |
| 213 | + result: Dict[str, Dict] = defaultdict(OrderedDict) |
| 214 | + for entry in repo.experiments.celery_queue.iter_queued(): |
| 215 | + stash_rev = entry.stash_rev |
| 216 | + if entry.baseline_rev in found_revs: |
| 217 | + result[entry.baseline_rev][stash_rev] = _collect_experiment_commit( |
| 218 | + repo, |
| 219 | + stash_rev, |
| 220 | + status=ExpStatus.Queued, |
| 221 | + running=running, |
| 222 | + **kwargs, |
| 223 | + ) |
| 224 | + return result |
| 225 | + |
| 226 | + |
| 227 | +def _collect_failed_experiment( |
| 228 | + repo: "Repo", |
| 229 | + found_revs: Dict[str, List[str]], |
| 230 | + running: Dict[str, Any], |
| 231 | + **kwargs, |
| 232 | +) -> Dict[str, Dict[str, Any]]: |
| 233 | + result: Dict[str, Dict] = defaultdict(OrderedDict) |
| 234 | + for queue_done_result in repo.experiments.celery_queue.iter_failed(): |
| 235 | + entry = queue_done_result.entry |
| 236 | + stash_rev = entry.stash_rev |
| 237 | + if entry.baseline_rev in found_revs: |
| 238 | + experiment = _collect_experiment_commit( |
| 239 | + repo, |
| 240 | + stash_rev, |
| 241 | + status=ExpStatus.Failed, |
| 242 | + running=running, |
| 243 | + **kwargs, |
| 244 | + ) |
| 245 | + result[entry.baseline_rev][stash_rev] = experiment |
| 246 | + return result |
| 247 | + |
| 248 | + |
| 249 | +def update_new( |
| 250 | + to_dict: Dict[str, Dict[str, Any]], from_dict: Dict[str, Dict[str, Any]] |
| 251 | +): |
| 252 | + for baseline, experiments in from_dict.items(): |
| 253 | + for rev, experiment in experiments.items(): |
| 254 | + to_dict[baseline][rev] = to_dict[baseline].get(rev, experiment) |
| 255 | + |
| 256 | + |
183 | 257 | def show(
|
184 | 258 | repo: "Repo",
|
185 | 259 | all_branches=False,
|
@@ -215,6 +289,39 @@ def show(
|
215 | 289 |
|
216 | 290 | running = repo.experiments.get_running_exps(fetch_refs=fetch_running)
|
217 | 291 |
|
| 292 | + queued_experiment = ( |
| 293 | + _collect_queued_experiment( |
| 294 | + repo, |
| 295 | + found_revs, |
| 296 | + running, |
| 297 | + param_deps=param_deps, |
| 298 | + onerror=onerror, |
| 299 | + ) |
| 300 | + if not hide_queued |
| 301 | + else {} |
| 302 | + ) |
| 303 | + |
| 304 | + active_experiment = _collect_active_experiment( |
| 305 | + repo, |
| 306 | + found_revs, |
| 307 | + running, |
| 308 | + param_deps=param_deps, |
| 309 | + onerror=onerror, |
| 310 | + ) |
| 311 | + |
| 312 | + failed_experiments = ( |
| 313 | + _collect_failed_experiment( |
| 314 | + repo, |
| 315 | + found_revs, |
| 316 | + running, |
| 317 | + sha_only=sha_only, |
| 318 | + param_deps=param_deps, |
| 319 | + onerror=onerror, |
| 320 | + ) |
| 321 | + if not hide_failed |
| 322 | + else {} |
| 323 | + ) |
| 324 | + |
218 | 325 | for rev in found_revs:
|
219 | 326 | status = ExpStatus.Running if rev in running else ExpStatus.Success
|
220 | 327 | res[rev]["baseline"] = _collect_experiment_commit(
|
@@ -249,38 +356,13 @@ def show(
|
249 | 356 | onerror=onerror,
|
250 | 357 | )
|
251 | 358 |
|
252 |
| - # collect standalone & celery experiments |
253 |
| - for entry in chain( |
254 |
| - repo.experiments.tempdir_queue.iter_active(), |
255 |
| - repo.experiments.celery_queue.iter_active(), |
256 |
| - repo.experiments.celery_queue.iter_queued(), |
257 |
| - repo.experiments.celery_queue.iter_failed(), |
258 |
| - ): |
259 |
| - if isinstance(entry, QueueDoneResult): |
260 |
| - entry = entry.entry |
261 |
| - if hide_failed: |
262 |
| - continue |
263 |
| - status = ExpStatus.Failed |
264 |
| - elif entry.stash_rev in running: |
265 |
| - status = ExpStatus.Running |
266 |
| - else: |
267 |
| - if hide_queued: |
268 |
| - continue |
269 |
| - status = ExpStatus.Queued |
270 |
| - stash_rev = entry.stash_rev |
271 |
| - if entry.baseline_rev in found_revs: |
272 |
| - if stash_rev not in running or not running[stash_rev].get( |
273 |
| - "last" |
274 |
| - ): |
275 |
| - experiment = _collect_experiment_commit( |
276 |
| - repo, |
277 |
| - stash_rev, |
278 |
| - status=status, |
279 |
| - param_deps=param_deps, |
280 |
| - running=running, |
281 |
| - onerror=onerror, |
282 |
| - ) |
283 |
| - res[entry.baseline_rev][stash_rev] = experiment |
| 359 | + if not hide_failed: |
| 360 | + update_new(res, failed_experiments) |
| 361 | + |
| 362 | + update_new(res, active_experiment) |
| 363 | + |
| 364 | + if not hide_queued: |
| 365 | + update_new(res, queued_experiment) |
284 | 366 |
|
285 | 367 | if not sha_only:
|
286 | 368 | get_names(repo, res)
|
|
0 commit comments