diff --git a/dvc/command/dag.py b/dvc/command/dag.py index 2cf38e9304..e408be644a 100644 --- a/dvc/command/dag.py +++ b/dvc/command/dag.py @@ -30,7 +30,7 @@ def _show_dot(G): return dot_file.getvalue() -def _build(G, target=None, full=False): +def _build(G, target=None, full=False, outs=False): import networkx as nx from dvc.repo.graph import get_pipeline, get_pipelines @@ -44,8 +44,25 @@ def _build(G, target=None, full=False): else: H = G - def _relabel(stage): - return stage.addressing + if outs: + G = nx.DiGraph() + for stage in H.nodes: + G.add_nodes_from(stage.outs) + + for from_stage, to_stage in nx.edge_dfs(H): + G.add_edges_from( + [ + (from_out, to_out) + for from_out in from_stage.outs + for to_out in to_stage.outs + ] + ) + H = G + + def _relabel(node): + from dvc.stage import Stage + + return node.addressing if isinstance(node, Stage) else str(node) return nx.relabel_nodes(H, _relabel, copy=False) @@ -64,7 +81,12 @@ def run(self): return 1 target = stages[0] - G = _build(self.repo.graph, target=target, full=self.args.full,) + G = _build( + self.repo.graph, + target=target, + full=self.args.full, + outs=self.args.outs, + ) if self.args.dot: logger.info(_show_dot(G)) @@ -108,6 +130,13 @@ def add_parser(subparsers, parent_parser): "showing DAG consisting only of ancestors." ), ) + dag_parser.add_argument( + "-o", + "--outs", + action="store_true", + default=False, + help="Print output files instead of stages.", + ) dag_parser.add_argument( "target", nargs="?", diff --git a/tests/unit/command/test_dag.py b/tests/unit/command/test_dag.py index d854fabe57..5c30825fc0 100644 --- a/tests/unit/command/test_dag.py +++ b/tests/unit/command/test_dag.py @@ -58,6 +58,20 @@ def test_build_target(graph): assert set(G.edges()) == {("3", "a.dvc"), ("3", "b.dvc")} +def test_build_target_with_outs(graph): + (stage,) = filter( + lambda s: hasattr(s, "name") and s.name == "3", graph.nodes() + ) + G = _build(graph, target=stage, outs=True) + assert set(G.nodes()) == {"a", "b", "h", "i"} + assert set(G.edges()) == { + ("h", "a"), + ("h", "b"), + ("i", "a"), + ("i", "b"), + } + + def test_build_full(graph): (stage,) = filter( lambda s: hasattr(s, "name") and s.name == "3", graph.nodes()