Skip to content

params, plots, metrics: unify collection #4603

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions dvc/repo/collect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import logging
import os
from typing import Iterable

from dvc.path_info import PathInfo
from dvc.repo import Repo
from dvc.tree.repo import RepoTree

logger = logging.getLogger(__name__)


def _collect_outs(
repo: Repo, output_filter: callable = None, deps: bool = False
):
outs = {
out
for stage in repo.stages
for out in (stage.deps if deps else stage.outs)
}
return set(filter(output_filter, outs)) if output_filter else outs


def _collect_paths(
repo: Repo, targets: Iterable, recursive: bool = False, rev: str = None
):
path_infos = {PathInfo(os.path.abspath(target)) for target in targets}
tree = RepoTree(repo)

target_infos = set()
for path_info in path_infos:

if recursive and tree.isdir(path_info):
target_infos.update(set(tree.walk_files(path_info)))

if not tree.isfile(path_info):
if not recursive:
logger.warning(
"'%s' was not found at: '%s'.", path_info, rev,
)
continue
target_infos.add(path_info)
return target_infos


def _filter_duplicates(outs: Iterable, path_infos: Iterable):
res_outs = set()
res_infos = set(path_infos)

for out in outs:
if out.path_info in path_infos:
res_outs.add(out)
res_infos.remove(out.path_info)

return res_outs, res_infos


def collect(
repo: Repo,
deps: bool = False,
targets: Iterable = None,
output_filter: callable = None,
rev: str = None,
recursive: bool = False,
):
assert targets or output_filter

outs = _collect_outs(repo, output_filter=output_filter, deps=deps)

if not targets:
return outs, set()

target_infos = _collect_paths(repo, targets, recursive=recursive, rev=rev)

return _filter_duplicates(outs, target_infos)
2 changes: 1 addition & 1 deletion dvc/repo/experiments/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def _collect_experiment(repo, rev, stash=False, sha_only=True):
commit = _resolve_commit(repo, rev)
res["timestamp"] = datetime.fromtimestamp(commit.committed_date)

configs = _collect_configs(repo)
configs = _collect_configs(repo, rev=rev)
params = _read_params(repo, configs, rev)
if params:
res["params"] = params
Expand Down
36 changes: 9 additions & 27 deletions dvc/repo/metrics/show.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,23 @@
import logging
import os

from dvc.exceptions import NoMetricsError
from dvc.path_info import PathInfo
from dvc.repo import locked
from dvc.repo.collect import collect
from dvc.tree.repo import RepoTree
from dvc.utils.serialize import YAMLFileCorruptedError, load_yaml

logger = logging.getLogger(__name__)


def _collect_metrics(repo, targets, recursive):
def _is_metric(out):
return bool(out.metric)


if targets:
target_infos = [
PathInfo(os.path.abspath(target)) for target in targets
]
tree = RepoTree(repo)

rec_files = []
if recursive:
for target_info in target_infos:
if tree.isdir(target_info):
rec_files.extend(list(tree.walk_files(target_info)))

result = [t for t in target_infos if tree.isfile(t)]
result.extend(rec_files)

return result

metrics = set()
for stage in repo.stages:
for out in stage.outs:
if not out.metric:
continue
metrics.add(out.path_info)
return list(metrics)
def _collect_metrics(repo, targets, recursive):
metrics, path_infos = collect(
repo, targets=targets, output_filter=_is_metric, recursive=recursive
)
return [m.path_info for m in metrics] + list(path_infos)


def _extract_metrics(metrics, path, rev):
Expand Down
18 changes: 9 additions & 9 deletions dvc/repo/params/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dvc.exceptions import DvcException
from dvc.path_info import PathInfo
from dvc.repo import locked
from dvc.repo.collect import collect
from dvc.utils.serialize import LOADERS, ParseError

logger = logging.getLogger(__name__)
Expand All @@ -13,15 +14,14 @@ class NoParamsError(DvcException):
pass


def _collect_configs(repo):
configs = set()
configs.add(PathInfo(repo.root_dir) / ParamsDependency.DEFAULT_PARAMS_FILE)
for stage in repo.stages:
for dep in stage.deps:
if not isinstance(dep, ParamsDependency):
continue
def _is_params(dep):
return isinstance(dep, ParamsDependency)


configs.add(dep.path_info)
def _collect_configs(repo, rev):
params, _ = collect(repo, deps=True, output_filter=_is_params, rev=rev)
configs = {p.path_info for p in params}
configs.add(PathInfo(repo.root_dir) / ParamsDependency.DEFAULT_PARAMS_FILE)
return list(configs)


Expand Down Expand Up @@ -49,7 +49,7 @@ def show(repo, revs=None):
res = {}

for branch in repo.brancher(revs=revs):
configs = _collect_configs(repo)
configs = _collect_configs(repo, branch)
params = _read_params(repo, configs, branch)

if params:
Expand Down
37 changes: 9 additions & 28 deletions dvc/repo/plots/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import logging
import os

from funcy import cached_property, first, project

from dvc.exceptions import DvcException, NoPlotsError
from dvc.path_info import PathInfo
from dvc.repo.collect import collect
from dvc.schema import PLOT_PROPS
from dvc.tree.repo import RepoTree
from dvc.utils import relpath
Expand Down Expand Up @@ -149,34 +148,16 @@ def templates(self):
return PlotTemplates(self.repo.dvc_dir)


def _collect_plots(repo, targets=None, rev=None):
plots = {out for stage in repo.stages for out in stage.outs if out.plot}

def to_result(plots):
return {plot.path_info: _plot_props(plot) for plot in plots}

if not targets:
return to_result(plots)

target_infos = {PathInfo(os.path.abspath(target)) for target in targets}
def _is_plot(out):
return bool(out.plot)

target_plots = set()
for p in plots:
if p.path_info in target_infos:
target_plots.add(p)
target_infos.remove(p.path_info)

tree = RepoTree(repo)
result = to_result(target_plots)

for t in target_infos:
if tree.isfile(t):
result[t] = {}
else:
logger.warning(
"'%s' was not found at: '%s'. It will not be plotted.", t, rev,
)

def _collect_plots(repo, targets=None, rev=None):
plots, path_infos = collect(
repo, output_filter=_is_plot, targets=targets, rev=rev
)
result = {plot.path_info: _plot_props(plot) for plot in plots}
result.update({path_info: {} for path_info in path_infos})
return result


Expand Down
5 changes: 1 addition & 4 deletions tests/func/plots/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,10 +330,7 @@ def test_plot_even_if_metric_missing(
caplog.clear()
with caplog.at_level(logging.WARNING, "dvc"):
plots = dvc.plots.show(revs=["v1", "v2"], targets=["metric.json"])
assert (
"'metric.json' was not found at: 'v1'. "
"It will not be plotted." in caplog.text
)
assert "'metric.json' was not found at: 'v1'." in caplog.text

plot_content = json.loads(plots["metric.json"])
assert plot_content["data"]["values"] == [
Expand Down