
repro: support glob/foreach-group to run at once through CLI #4976

Merged
7 changes: 7 additions & 0 deletions dvc/command/repro.py
@@ -63,6 +63,7 @@ def _repro_kwargs(self):
"recursive": self.args.recursive,
"force_downstream": self.args.force_downstream,
"pull": self.args.pull,
"glob": self.args.glob,
}


@@ -175,6 +176,12 @@ def add_arguments(repro_parser):
"from the run-cache."
),
)
repro_parser.add_argument(
"--glob",
action="store_true",
default=False,
help="Allows targets containing shell-style wildcards.",
)


def add_parser(subparsers, parent_parser):
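The new --glob option flows straight from argparse into the keyword arguments passed to reproduce. A minimal sketch of that flow, assuming a stand-alone parser and a hypothetical wildcard target (the real wiring lives in CmdRepro above):

```python
import argparse

# Mirror of the parser setup in the diff above; the "targets" positional
# and the sample values are illustrative only.
parser = argparse.ArgumentParser(prog="dvc repro")
parser.add_argument("targets", nargs="*", default=[])
parser.add_argument(
    "--glob",
    action="store_true",
    default=False,
    help="Allows targets containing shell-style wildcards.",
)

args = parser.parse_args(["build-*", "--glob"])
repro_kwargs = {"glob": args.glob}  # merged with the other keys in _repro_kwargs
print(args.targets, repro_kwargs)   # ['build-*'] {'glob': True}
```
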
132 changes: 76 additions & 56 deletions dvc/repo/__init__.py
@@ -1,14 +1,14 @@
import logging
import os
from contextlib import contextmanager
from contextlib import contextmanager, suppress
from functools import wraps
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, List

from funcy import cached_property, cat
from git import InvalidGitRepositoryError

from dvc.config import Config
from dvc.dvcfile import PIPELINE_FILE, Dvcfile, is_valid_filename
from dvc.dvcfile import PIPELINE_FILE, is_valid_filename
from dvc.exceptions import FileMissingError
from dvc.exceptions import IsADirectoryError as DvcIsADirectoryError
from dvc.exceptions import (
@@ -17,6 +17,7 @@
OutputNotFoundError,
)
from dvc.path_info import PathInfo
from dvc.repo.stage import StageLoad
from dvc.scm import Base
from dvc.scm.base import SCMError
from dvc.tree.repo import RepoTree
@@ -28,6 +29,9 @@
from .trie import build_outs_trie

if TYPE_CHECKING:
from networkx import DiGraph

from dvc.stage import Stage
from dvc.tree.base import BaseTree


@@ -165,6 +169,7 @@ def __init__(

self.cache = Cache(self)
self.cloud = DataCloud(self)
self.stage = StageLoad(self)
Collaborator Author (@skshetry):
Suggestion for better naming would be appreciated.

Contributor (@pared):
StageResolver? Also, I think that naming the object stage while we have a class of this name will get confusing at some point. Maybe stage_load or stage_resolver will be ok, depending on what name we choose in the end.

Collaborator Author (@skshetry, Nov 27, 2020):
@pared, will make it StageLoader, and change the other one to MultiStageLoader.

> Maybe stage_load or stage_resolver will be ok, depending on what name we choose in the end.

stage_load.load_one() looks a bit odd to me. :) But, I understand your concern.
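
For context on the naming discussion above, here is a rough sketch of the StageLoad facade as it is used elsewhere in this PR. Only the method names and call shapes (get_target, from_target, load_all, load_file) are taken from the diff; the bodies are placeholders, not the real implementation in dvc/repo/stage.py:

```python
from typing import List, Optional, TYPE_CHECKING

if TYPE_CHECKING:
    from dvc.repo import Repo
    from dvc.stage import Stage


class StageLoad:
    def __init__(self, repo: "Repo") -> None:
        self.repo = repo

    def get_target(self, target: str) -> "Stage":
        """Return a single stage for a '<dvcfile>:<name>'-style target."""
        raise NotImplementedError

    def from_target(
        self, target: str, accept_group: bool = False, glob: bool = False
    ) -> List["Stage"]:
        """Return all stages matching the target, optionally expanding a
        foreach-group name or a shell-style wildcard."""
        raise NotImplementedError

    def load_all(
        self, path: str, name: Optional[str] = None, accept_group: bool = False
    ) -> List["Stage"]:
        """Load stages from a given dvcfile, optionally restricted to one
        name or group."""
        raise NotImplementedError

    def load_file(self, path: str) -> List["Stage"]:
        """Load every stage defined in a dvcfile."""
        raise NotImplementedError
```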


if scm or not self.dvc_dir:
self.lock = LockNoop()
@@ -270,25 +275,6 @@ def _ignore(self):

self.scm.ignore_list(flist)

def get_stage(self, path=None, name=None):
if not path:
path = PIPELINE_FILE
logger.debug("Assuming '%s' to be a stage inside '%s'", name, path)

dvcfile = Dvcfile(self, path)
return dvcfile.stages[name]

def get_stages(self, path=None, name=None):
if not path:
path = PIPELINE_FILE
logger.debug("Assuming '%s' to be a stage inside '%s'", name, path)

if name:
return [self.get_stage(path, name)]

dvcfile = Dvcfile(self, path)
return list(dvcfile.stages.values())

def check_modified_graph(self, new_stages):
"""Generate graph including the new stage to check for errors"""
# Building graph might be costly for the ones with many DVC-files,
@@ -306,79 +292,105 @@ def check_modified_graph(self, new_stages):
if not getattr(self, "_skip_graph_checks", False):
build_graph(self.stages + new_stages)

def _collect_inside(self, path, graph):
@staticmethod
def _collect_inside(path: str, graph: "DiGraph"):
import networkx as nx

stages = nx.dfs_postorder_nodes(graph)
return [stage for stage in stages if path_isin(stage.path, path)]

def collect(
self, target=None, with_deps=False, recursive=False, graph=None
self,
target: str = None,
with_deps: bool = False,
recursive: bool = False,
graph: "DiGraph" = None,
accept_group: bool = False,
glob: bool = False,
):
if not target:
return list(graph) if graph else self.stages

if recursive and os.path.isdir(target):
if recursive and self.tree.isdir(target):
return self._collect_inside(
os.path.abspath(target), graph or self.graph
)

path, name = parse_target(target)
stages = self.get_stages(path, name)
stages = self.stage.from_target(
target, accept_group=accept_group, glob=glob
)
if not with_deps:
return stages

return self._collect_stages_with_deps(stages, graph=graph)

def _collect_stages_with_deps(
self, stages: List["Stage"], graph: "DiGraph" = None
):
res = set()
for stage in stages:
res.update(self._collect_pipeline(stage, graph=graph))
return res

def _collect_pipeline(self, stage, graph=None):
def _collect_pipeline(self, stage: "Stage", graph: "DiGraph" = None):
import networkx as nx

pipeline = get_pipeline(get_pipelines(graph or self.graph), stage)
return nx.dfs_postorder_nodes(pipeline, stage)

def _collect_from_default_dvcfile(self, target):
dvcfile = Dvcfile(self, PIPELINE_FILE)
if dvcfile.exists():
return dvcfile.stages.get(target)

def collect_granular(
self, target=None, with_deps=False, recursive=False, graph=None
def _collect_specific_target(
self, target: str, with_deps: bool, recursive: bool, accept_group: bool
):
"""
Priority is in the order of following in case of ambiguity:
- .dvc file or .yaml file
- dir if recursive and directory exists
- stage_name
- output file
"""
if not target:
return [(stage, None) for stage in self.stages]

# Optimization: do not collect the graph for a specific target
file, name = parse_target(target)
stages = []

# Optimization: do not collect the graph for a specific target
if not file:
# parsing is ambiguous when it does not have a colon
# or if it's not a dvcfile, as it can be a stage name
# in `dvc.yaml` or, an output in a stage.
logger.debug(
"Checking if stage '%s' is in '%s'", target, PIPELINE_FILE
)
if not (recursive and os.path.isdir(target)):
Collaborator Author (@skshetry, Nov 26, 2020):
os.path.isdir might have been a bug. Though, we don't collect one single file when using brancher.
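
To illustrate the comment above: under the brancher, the repo may be iterating the stages of an older revision, so directory checks should go through that revision's tree rather than the local filesystem. A toy sketch, with a made-up in-memory tree standing in for DVC's real tree objects:

```python
import os


class RevisionTree:
    """Pretend snapshot of a committed revision (not DVC's real API)."""

    def __init__(self, dirs):
        self._dirs = set(dirs)

    def isdir(self, path):
        return path in self._dirs


tree = RevisionTree({"pipelines/train"})  # exists in the old revision...
target = "pipelines/train"                # ...but was removed from the workspace

print(os.path.isdir(target))  # False: checks the current working directory
print(tree.isdir(target))     # True: checks the revision being walked
```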

stage = self._collect_from_default_dvcfile(target)
if stage:
stages = (
self._collect_pipeline(stage) if with_deps else [stage]
if not (
recursive and self.tree.isdir(target)
) and self.tree.exists(PIPELINE_FILE):
with suppress(StageNotFound):
stages = self.stage.load_all(
PIPELINE_FILE, target, accept_group=accept_group
)
if with_deps:
stages = self._collect_stages_with_deps(stages)

elif not with_deps and is_valid_filename(file):
stages = self.get_stages(file, name)
stages = self.stage.load_all(
file, name, accept_group=accept_group,
)

return stages, file, name

def collect_granular(
self,
target: str = None,
with_deps: bool = False,
recursive: bool = False,
graph: "DiGraph" = None,
accept_group: bool = False,
):
"""
Priority is in the order of following in case of ambiguity:
- .dvc file or .yaml file
- dir if recursive and directory exists
- stage_name
- output file
"""
if not target:
return [(stage, None) for stage in self.stages]

stages, file, _ = self._collect_specific_target(
target, with_deps, recursive, accept_group
)
if not stages:
if not (recursive and os.path.isdir(target)):
if not (recursive and self.tree.isdir(target)):
try:
(out,) = self.find_outs_by_path(target, strict=False)
filter_info = PathInfo(os.path.abspath(target))
@@ -387,7 +399,13 @@ def collect_granular(
pass

try:
stages = self.collect(target, with_deps, recursive, graph)
stages = self.collect(
target,
with_deps,
recursive,
graph,
accept_group=accept_group,
)
except StageFileDoesNotExistError as exc:
# collect() might try to use `target` as a stage name
# and throw error that dvc.yaml does not exist, whereas it
@@ -498,7 +516,9 @@ def _collect_stages(self):

for root, dirs, files in self.tree.walk(self.root_dir):
for file_name in filter(is_valid_filename, files):
new_stages = self.get_stages(os.path.join(root, file_name))
new_stages = self.stage.load_file(
os.path.join(root, file_name)
)
stages.extend(new_stages)
outs.update(
out.fspath
12 changes: 7 additions & 5 deletions dvc/repo/freeze.py
@@ -1,12 +1,14 @@
import typing

from . import locked

if typing.TYPE_CHECKING:
from . import Repo

@locked
def _set(repo, target, frozen):
from dvc.utils import parse_target

path, name = parse_target(target)
stage = repo.get_stage(path, name)
@locked
def _set(repo: "Repo", target, frozen):
stage = repo.stage.get_target(target)
stage.frozen = frozen
stage.dvcfile.dump(stage, update_lock=False)

10 changes: 6 additions & 4 deletions dvc/repo/remove.py
@@ -1,15 +1,17 @@
import logging
import typing

from ..utils import parse_target
from . import locked

if typing.TYPE_CHECKING:
from dvc.repo import Repo

logger = logging.getLogger(__name__)


@locked
def remove(self, target, outs=False):
path, name = parse_target(target)
stages = self.get_stages(path, name)
def remove(self: "Repo", target: str, outs: bool = False):
stages = self.stage.from_target(target)

for stage in stages:
stage.remove(remove_outs=outs, force=outs)
21 changes: 15 additions & 6 deletions dvc/repo/reproduce.py
@@ -1,4 +1,5 @@
import logging
import typing
from functools import partial

from dvc.exceptions import InvalidArgumentError, ReproductionError
@@ -8,6 +9,9 @@
from . import locked
from .graph import get_pipeline, get_pipelines

if typing.TYPE_CHECKING:
from . import Repo

logger = logging.getLogger(__name__)


@@ -75,15 +79,15 @@ def _get_active_graph(G):
@locked
@scm_context
def reproduce(
self,
self: "Repo",
target=None,
recursive=False,
pipeline=False,
all_pipelines=False,
**kwargs,
):
from dvc.utils import parse_target

glob = kwargs.pop("glob", False)
accept_group = not glob
assert target is None or isinstance(target, str)
if not target and not all_pipelines:
raise InvalidArgumentError(
@@ -97,12 +101,11 @@ def reproduce(
active_graph = _get_active_graph(self.graph)
active_pipelines = get_pipelines(active_graph)

path, name = parse_target(target)
if pipeline or all_pipelines:
if all_pipelines:
pipelines = active_pipelines
else:
stage = self.get_stage(path, name)
stage = self.stage.get_target(target)
pipelines = [get_pipeline(active_pipelines, stage)]

targets = []
@@ -111,7 +114,13 @@
if pipeline.in_degree(stage) == 0:
targets.append(stage)
else:
targets = self.collect(target, recursive=recursive, graph=active_graph)
targets = self.collect(
target,
recursive=recursive,
graph=active_graph,
accept_group=accept_group,
glob=glob,
)

return _reproduce_stages(active_graph, targets, **kwargs)

Expand Down
Loading