diff --git a/dvc/remote/local.py b/dvc/remote/local.py index 44cdb06f0c..5553f3710e 100644 --- a/dvc/remote/local.py +++ b/dvc/remote/local.py @@ -20,7 +20,15 @@ STATUS_DELETED, STATUS_MISSING, ) -from dvc.utils import remove, move, copyfile, dict_md5, to_chunks, tmp_fname +from dvc.utils import ( + remove, + move, + copyfile, + dict_md5, + to_chunks, + tmp_fname, + walk_files, +) from dvc.utils import LARGE_DIR_SIZE from dvc.config import Config from dvc.exceptions import DvcException @@ -353,11 +361,7 @@ def already_cached(self, path_info): return not self.changed_cache(current_md5) def _discard_working_directory_changes(self, path, dir_info, force=False): - working_dir_files = set( - os.path.join(root, file) - for root, _, files in os.walk(str(path)) - for file in files - ) + working_dir_files = set(path for path in walk_files(path)) cached_files = set( os.path.join(path, file["relpath"]) for file in dir_info diff --git a/dvc/repo/unprotect.py b/dvc/repo/unprotect.py index c34dd1c0cf..5220e471f8 100644 --- a/dvc/repo/unprotect.py +++ b/dvc/repo/unprotect.py @@ -6,7 +6,7 @@ import dvc.logger as logger from dvc.system import System -from dvc.utils import copyfile, remove +from dvc.utils import copyfile, remove, walk_files from dvc.exceptions import DvcException @@ -33,10 +33,8 @@ def _unprotect_file(path): def _unprotect_dir(path): - for root, dirs, files in os.walk(str(path)): - for f in files: - path = os.path.join(root, f) - _unprotect_file(path) + for path in walk_files(path): + _unprotect_file(path) def unprotect(path): diff --git a/dvc/utils/__init__.py b/dvc/utils/__init__.py index d45b13cbd5..58ae4cd5b1 100644 --- a/dvc/utils/__init__.py +++ b/dvc/utils/__init__.py @@ -242,6 +242,6 @@ def load_stage_file(path): def walk_files(directory): - for root, _, files in os.walk(directory): + for root, _, files in os.walk(str(directory)): for f in files: yield os.path.join(root, f) diff --git a/tests/test_checkout.py b/tests/test_checkout.py index d1a2b5dc78..a10e5cc3be 100644 --- a/tests/test_checkout.py +++ b/tests/test_checkout.py @@ -8,6 +8,7 @@ from dvc.main import main from dvc.repo import Repo as DvcRepo from dvc.system import System +from dvc.utils import walk_files from tests.basic_env import TestDvc from tests.test_repro import TestRepro from dvc.stage import Stage @@ -121,10 +122,7 @@ def outs_info(self, stage): FileInfo = collections.namedtuple("FileInfo", "path inode") paths = [ - os.path.join(root, file) - for output in stage.outs - for root, _, files in os.walk(output.path) - for file in files + path for output in stage.outs for path in walk_files(output.path) ] return [