Skip to content

Commit 5b98c64

Browse files
committed
fix resource loading
1 parent afd8bc1 commit 5b98c64

File tree

2 files changed

+23
-12
lines changed

torchvision/prototype/datasets/_builtin/ucf101.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,10 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
5151
sha256="5c0d1a53b8ed364a2ac830a73f405e51bece7d98ce1254fd19ed4a36b224bd27",
5252
)
5353

54+
# The SSL certificate of the server is currently invalid, but downloading "unsafe" data is not supported yet
5455
videos = HttpResource(
5556
f"{url_root}/UCF101.rar",
5657
sha256="ca8dfadb4c891cb11316f94d52b6b0ac2a11994e67a0cae227180cd160bd8e55",
57-
extract=True,
5858
)
5959
videos._preprocess = self._extract_videos_archive
6060

torchvision/prototype/datasets/utils/_resource.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
extract_archive,
2121
_decompress,
2222
download_file_from_google_drive,
23+
tqdm,
2324
)
2425

2526

@@ -86,20 +87,30 @@ def load(
8687
root = pathlib.Path(root)
8788
path = root / self.file_name
8889
# Instead of the raw file, there might also be files with fewer suffixes after decompression or directories
89-
# with no suffixes at all. Thus, we look for all paths that share the same name without suffixes as the raw
90-
# file.
91-
path_candidates = {file for file in path.parent.glob(path.name.replace("".join(path.suffixes), "") + "*")}
92-
# If we don't find anything, we try to download the raw file.
93-
if not path_candidates:
94-
path_candidates = {self.download(root, skip_integrity_check=skip_integrity_check)}
90+
# with no suffixes at all.
91+
stem = path.name.replace("".join(path.suffixes), "")
92+
93+
# In a first step, we check for a folder with the same stem as the raw file. If it exists, we use it since
94+
# extracted files give the best I/O performance. Note that OnlineResource._extract() makes sure that an archive
95+
# is always extracted in a folder with the corresponding file name.
96+
folder_candidate = path.parent / stem
97+
if folder_candidate.exists() and folder_candidate.is_dir():
98+
return self._loader(path)
99+
100+
# If there is no folder, we look for all files that share the same stem as the raw file, but might have a
101+
# different suffix.
102+
file_candidates = {file for file in path.parent.glob(stem + ".*")}
103+
# If we don't find anything, we download the raw file.
104+
if not file_candidates:
105+
file_candidates = {self.download(root, skip_integrity_check=skip_integrity_check)}
95106
# If the only thing we find is the raw file, we use it and optionally perform some preprocessing steps.
96-
if path_candidates == {path}:
107+
if file_candidates == {path}:
97108
if self._preprocess is not None:
98109
path = self._preprocess(path)
99-
# Otherwise we use the path with the fewest suffixes. This gives us the extracted > decompressed > raw priority
100-
# that we want.
110+
# Otherwise, we use the path with the fewest suffixes. This gives us the decompressed > raw priority that we
111+
# want for the best I/O performance.
101112
else:
102-
path = min(path_candidates, key=lambda path: len(path.suffixes))
113+
path = min(file_candidates, key=lambda path: len(path.suffixes))
103114
return self._loader(path)
104115

105116
@abc.abstractmethod
@@ -117,7 +128,7 @@ def download(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool
117128
def _check_sha256(self, path: pathlib.Path, *, chunk_size: int = 1024 * 1024) -> None:
118129
hash = hashlib.sha256()
119130
with open(path, "rb") as file:
120-
for chunk in iter(lambda: file.read(chunk_size), b""):
131+
for chunk in tqdm(iter(lambda: file.read(chunk_size), b"")):
121132
hash.update(chunk)
122133
sha256 = hash.hexdigest()
123134
if sha256 != self.sha256:

0 commit comments

Comments (0)