From 3d07ea33a6cf2cd24aa1fd0d1f0c536d17bbc6b1 Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Thu, 2 Jul 2020 14:15:38 +0900 Subject: [PATCH 1/3] s3,gs: don't append "/" to prefixed path when listing objects --- dvc/remote/gs.py | 2 +- dvc/remote/s3.py | 2 +- tests/unit/remote/test_remote_tree.py | 6 ++++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/dvc/remote/gs.py b/dvc/remote/gs.py index 0775f0af51..d39ceddcd4 100644 --- a/dvc/remote/gs.py +++ b/dvc/remote/gs.py @@ -144,7 +144,7 @@ def _list_paths(self, path_info, max_items=None): yield blob.name def walk_files(self, path_info, **kwargs): - for fname in self._list_paths(path_info / "", **kwargs): + for fname in self._list_paths(path_info, **kwargs): # skip nested empty directories if fname.endswith("/"): continue diff --git a/dvc/remote/s3.py b/dvc/remote/s3.py index 9a1911da95..71c5ff0852 100644 --- a/dvc/remote/s3.py +++ b/dvc/remote/s3.py @@ -187,7 +187,7 @@ def _list_paths(self, path_info, max_items=None): ) def walk_files(self, path_info, **kwargs): - for fname in self._list_paths(path_info / "", **kwargs): + for fname in self._list_paths(path_info, **kwargs): if fname.endswith("/"): continue diff --git a/tests/unit/remote/test_remote_tree.py b/tests/unit/remote/test_remote_tree.py index 0c01ab89b6..a5b3b1dc6f 100644 --- a/tests/unit/remote/test_remote_tree.py +++ b/tests/unit/remote/test_remote_tree.py @@ -81,7 +81,9 @@ def test_walk_files(remote): remote.path_info / "data/subdir/empty_file", ] - assert list(remote.tree.walk_files(remote.path_info / "data")) == files + assert ( + list(remote.tree.walk_files(remote.path_info / "data" / "")) == files + ) @pytest.mark.parametrize("remote", [pytest.lazy_fixture("s3")], indirect=True) @@ -139,7 +141,7 @@ def test_isfile(remote): def test_download_dir(remote, tmpdir): path = str(tmpdir / "data") to_info = PathInfo(path) - remote.tree.download(remote.path_info / "data", to_info) + remote.tree.download(remote.path_info / "data" / "", to_info) assert os.path.isdir(path) data_dir = tmpdir / "data" assert len(list(walk_files(path))) == 7 From 6b329fb0a9842102aee77e2ff0d10d25a98bd0cd Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Fri, 3 Jul 2020 01:16:10 +0900 Subject: [PATCH 2/3] remote: use optional prefix kwarg for walk_files * if prefix is True, path will be walked as a prefix * if prefix is False, path will be treated as a directory (trailing slash will be appended for S3-like remotes) --- dvc/remote/azure.py | 2 ++ dvc/remote/base.py | 13 ++++++++++--- dvc/remote/gdrive.py | 2 +- dvc/remote/gs.py | 2 ++ dvc/remote/oss.py | 2 ++ dvc/remote/s3.py | 2 ++ tests/unit/remote/test_remote_tree.py | 6 ++---- 7 files changed, 21 insertions(+), 8 deletions(-) diff --git a/dvc/remote/azure.py b/dvc/remote/azure.py index 817fcc6ed0..77b831fc19 100644 --- a/dvc/remote/azure.py +++ b/dvc/remote/azure.py @@ -121,6 +121,8 @@ def _list_paths(self, bucket, prefix): next_marker = blobs.next_marker def walk_files(self, path_info, **kwargs): + if not kwargs.pop("prefix", False): + path_info = path_info / "" for fname in self._list_paths( path_info.bucket, path_info.path, **kwargs ): diff --git a/dvc/remote/base.py b/dvc/remote/base.py index b7d5cafa8e..91b116e700 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -213,7 +213,12 @@ def iscopy(self, path_info): return False # We can't be sure by default def walk_files(self, path_info, **kwargs): - """Return a generator with `PathInfo`s to all the files""" + """Return a generator with `PathInfo`s to all the files. + + Optional kwargs: + prefix (bool): If true `path_info` will be treated as a prefix + rather than directory path. + """ raise NotImplementedError def is_empty(self, path_info): @@ -522,14 +527,16 @@ def list_paths(self, prefix=None, progress_callback=None): path_info = self.path_info / prefix[:2] / prefix[2:] else: path_info = self.path_info / prefix[:2] + prefix = True else: path_info = self.path_info + prefix = False if progress_callback: - for file_info in self.walk_files(path_info): + for file_info in self.walk_files(path_info, prefix=prefix): progress_callback() yield file_info.path else: - yield from self.walk_files(path_info) + yield from self.walk_files(path_info, prefix=prefix) def list_hashes(self, prefix=None, progress_callback=None): """Iterate over hashes in this tree. diff --git a/dvc/remote/gdrive.py b/dvc/remote/gdrive.py index 874622f404..c6df9a886c 100644 --- a/dvc/remote/gdrive.py +++ b/dvc/remote/gdrive.py @@ -553,7 +553,7 @@ def _list_paths(self, prefix=None): ) def walk_files(self, path_info, **kwargs): - if path_info == self.path_info: + if path_info == self.path_info or not kwargs.pop("prefix", False): prefix = None else: prefix = path_info.path diff --git a/dvc/remote/gs.py b/dvc/remote/gs.py index d39ceddcd4..fc703231d2 100644 --- a/dvc/remote/gs.py +++ b/dvc/remote/gs.py @@ -144,6 +144,8 @@ def _list_paths(self, path_info, max_items=None): yield blob.name def walk_files(self, path_info, **kwargs): + if not kwargs.pop("prefix", False): + path_info = path_info / "" for fname in self._list_paths(path_info, **kwargs): # skip nested empty directories if fname.endswith("/"): diff --git a/dvc/remote/oss.py b/dvc/remote/oss.py index 31feaed6ad..5471169afb 100644 --- a/dvc/remote/oss.py +++ b/dvc/remote/oss.py @@ -100,6 +100,8 @@ def _list_paths(self, path_info): yield blob.key def walk_files(self, path_info, **kwargs): + if not kwargs.pop("prefix", False): + path_info = path_info / "" for fname in self._list_paths(path_info): if fname.endswith("/"): continue diff --git a/dvc/remote/s3.py b/dvc/remote/s3.py index 71c5ff0852..2c01fd800e 100644 --- a/dvc/remote/s3.py +++ b/dvc/remote/s3.py @@ -187,6 +187,8 @@ def _list_paths(self, path_info, max_items=None): ) def walk_files(self, path_info, **kwargs): + if not kwargs.pop("prefix", False): + path_info = path_info / "" for fname in self._list_paths(path_info, **kwargs): if fname.endswith("/"): continue diff --git a/tests/unit/remote/test_remote_tree.py b/tests/unit/remote/test_remote_tree.py index a5b3b1dc6f..0c01ab89b6 100644 --- a/tests/unit/remote/test_remote_tree.py +++ b/tests/unit/remote/test_remote_tree.py @@ -81,9 +81,7 @@ def test_walk_files(remote): remote.path_info / "data/subdir/empty_file", ] - assert ( - list(remote.tree.walk_files(remote.path_info / "data" / "")) == files - ) + assert list(remote.tree.walk_files(remote.path_info / "data")) == files @pytest.mark.parametrize("remote", [pytest.lazy_fixture("s3")], indirect=True) @@ -141,7 +139,7 @@ def test_isfile(remote): def test_download_dir(remote, tmpdir): path = str(tmpdir / "data") to_info = PathInfo(path) - remote.tree.download(remote.path_info / "data" / "", to_info) + remote.tree.download(remote.path_info / "data", to_info) assert os.path.isdir(path) data_dir = tmpdir / "data" assert len(list(walk_files(path))) == 7 From e4c3a901da2015e84abbd1b831f6b6a98538c273 Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Fri, 3 Jul 2020 01:17:31 +0900 Subject: [PATCH 3/3] tests: add unit test for BaseRemoteTree.list_paths --- tests/unit/remote/test_base.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/unit/remote/test_base.py b/tests/unit/remote/test_base.py index c2025f1c24..016d71271a 100644 --- a/tests/unit/remote/test_base.py +++ b/tests/unit/remote/test_base.py @@ -142,6 +142,20 @@ def test_list_hashes(dvc): assert hashes == ["123456"] +def test_list_paths(dvc): + tree = BaseRemoteTree(dvc, {}) + tree.path_info = PathInfo("foo") + + with mock.patch.object(tree, "walk_files", return_value=[]) as walk_mock: + for _ in tree.list_paths(): + pass + walk_mock.assert_called_with(tree.path_info, prefix=False) + + for _ in tree.list_paths(prefix="000"): + pass + walk_mock.assert_called_with(tree.path_info / "00" / "0", prefix=True) + + @pytest.mark.parametrize( "hash_, result", [(None, False), ("", False), ("3456.dir", True), ("3456", False)],