Skip to content

Commit d45ba7a

Browse files
committed
Merge branch 'master' into kinetics400-test
2 parents 855e122 + 2e8c124 commit d45ba7a

File tree

2 files changed

+93
-72
lines changed

2 files changed

+93
-72
lines changed

test/datasets_utils.py

Lines changed: 82 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import itertools
77
import os
88
import pathlib
9+
import random
10+
import string
911
import unittest
1012
import unittest.mock
1113
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union
@@ -32,6 +34,7 @@
3234
"create_image_folder",
3335
"create_video_file",
3436
"create_video_folder",
37+
"create_random_string",
3538
]
3639

3740

@@ -93,14 +96,6 @@ def inner_wrapper(*args, **kwargs):
9396
return outer_wrapper
9497

9598

96-
# As of Python 3.7 this is provided by contextlib
97-
# https://docs.python.org/3.7/library/contextlib.html#contextlib.nullcontext
98-
# TODO: If the minimum Python requirement is >= 3.7, replace this
99-
@contextlib.contextmanager
100-
def nullcontext(enter_result=None):
101-
yield enter_result
102-
103-
10499
def test_all_configs(test):
105100
"""Decorator to run test against all configurations.
106101
@@ -116,7 +111,7 @@ def test_foo(self, config):
116111

117112
@functools.wraps(test)
118113
def wrapper(self):
119-
for config in self.CONFIGS:
114+
for config in self.CONFIGS or (self._DEFAULT_CONFIG,):
120115
with self.subTest(**config):
121116
test(self, config)
122117

@@ -207,6 +202,8 @@ def test_baz(self):
207202
CONFIGS = None
208203
REQUIRED_PACKAGES = None
209204

205+
_DEFAULT_CONFIG = None
206+
210207
_TRANSFORM_KWARGS = {
211208
"transform",
212209
"target_transform",
@@ -268,7 +265,7 @@ def create_dataset(
268265
self,
269266
config: Optional[Dict[str, Any]] = None,
270267
inject_fake_data: bool = True,
271-
disable_download_extract: Optional[bool] = None,
268+
patch_checks: Optional[bool] = None,
272269
**kwargs: Any,
273270
) -> Iterator[Tuple[torchvision.datasets.VisionDataset, Dict[str, Any]]]:
274271
r"""Create the dataset in a temporary directory.
@@ -278,8 +275,8 @@ def create_dataset(
278275
default configuration is used.
279276
inject_fake_data (bool): If ``True`` (default) inject the fake data with :meth:`.inject_fake_data` before
280277
creating the dataset.
281-
disable_download_extract (Optional[bool]): If ``True`` disable download and extract logic while creating
282-
the dataset. If ``None`` (default) this takes the same value as ``inject_fake_data``.
278+
patch_checks (Optional[bool]): If ``True`` disable integrity check logic while creating the dataset. If
279+
omitted defaults to the same value as ``inject_fake_data``.
283280
**kwargs (Any): Additional parameters passed to the dataset. These parameters take precedence in case they
284281
overlap with ``config``.
285282
@@ -288,43 +285,28 @@ def create_dataset(
288285
info (Dict[str, Any]): Additional information about the injected fake data. See :meth:`.inject_fake_data`
289286
for details.
290287
"""
291-
if config is None:
292-
config = self.CONFIGS[0].copy()
288+
default_config = self._DEFAULT_CONFIG.copy()
289+
if config is not None:
290+
default_config.update(config)
291+
config = default_config
292+
293+
if patch_checks is None:
294+
patch_checks = inject_fake_data
293295

294296
special_kwargs, other_kwargs = self._split_kwargs(kwargs)
297+
if "download" in self._HAS_SPECIAL_KWARG:
298+
special_kwargs["download"] = False
295299
config.update(other_kwargs)
296300

297-
if disable_download_extract is None:
298-
disable_download_extract = inject_fake_data
301+
patchers = self._patch_download_extract()
302+
if patch_checks:
303+
patchers.update(self._patch_checks())
299304

300305
with get_tmp_dir() as tmpdir:
301306
args = self.dataset_args(tmpdir, config)
307+
info = self._inject_fake_data(tmpdir, config) if inject_fake_data else None
302308

303-
if inject_fake_data:
304-
info = self.inject_fake_data(tmpdir, config)
305-
if info is None:
306-
raise UsageError(
307-
"The method 'inject_fake_data' needs to return at least an integer indicating the number of "
308-
"examples for the current configuration."
309-
)
310-
elif isinstance(info, int):
311-
info = dict(num_examples=info)
312-
elif not isinstance(info, dict):
313-
raise UsageError(
314-
f"The additional information returned by the method 'inject_fake_data' must be either an "
315-
f"integer indicating the number of examples for the current configuration or a dictionary with "
316-
f"the same content. Got {type(info)} instead."
317-
)
318-
elif "num_examples" not in info:
319-
raise UsageError(
320-
"The information dictionary returned by the method 'inject_fake_data' must contain a "
321-
"'num_examples' field that holds the number of examples for the current configuration."
322-
)
323-
else:
324-
info = None
325-
326-
cm = self._disable_download_extract if disable_download_extract else nullcontext
327-
with cm(special_kwargs), disable_console_output():
309+
with self._maybe_apply_patches(patchers), disable_console_output():
328310
dataset = self.DATASET_CLASS(*args, **config, **special_kwargs)
329311

330312
yield dataset, info
@@ -352,19 +334,17 @@ def _verify_required_public_class_attributes(cls):
352334
@classmethod
353335
def _populate_private_class_attributes(cls):
354336
argspec = inspect.getfullargspec(cls.DATASET_CLASS.__init__)
337+
338+
cls._DEFAULT_CONFIG = {
339+
kwarg: default
340+
for kwarg, default in zip(argspec.args[-len(argspec.defaults):], argspec.defaults)
341+
if kwarg not in cls._SPECIAL_KWARGS
342+
}
343+
355344
cls._HAS_SPECIAL_KWARG = {name for name in cls._SPECIAL_KWARGS if name in argspec.args}
356345

357346
@classmethod
358347
def _process_optional_public_class_attributes(cls):
359-
argspec = inspect.getfullargspec(cls.DATASET_CLASS.__init__)
360-
if cls.CONFIGS is None:
361-
config = {
362-
kwarg: default
363-
for kwarg, default in zip(argspec.args[-len(argspec.defaults):], argspec.defaults)
364-
if kwarg not in cls._SPECIAL_KWARGS
365-
}
366-
cls.CONFIGS = (config,)
367-
368348
if cls.REQUIRED_PACKAGES is not None:
369349
try:
370350
for pkg in cls.REQUIRED_PACKAGES:
@@ -380,28 +360,44 @@ def _split_kwargs(self, kwargs):
380360
other_kwargs = {key: special_kwargs.pop(key) for key in set(special_kwargs.keys()) - self._SPECIAL_KWARGS}
381361
return special_kwargs, other_kwargs
382362

383-
@contextlib.contextmanager
384-
def _disable_download_extract(self, special_kwargs):
385-
inject_download_kwarg = "download" in self._HAS_SPECIAL_KWARG and "download" not in special_kwargs
386-
if inject_download_kwarg:
387-
special_kwargs["download"] = False
363+
def _inject_fake_data(self, tmpdir, config):
364+
info = self.inject_fake_data(tmpdir, config)
365+
if info is None:
366+
raise UsageError(
367+
"The method 'inject_fake_data' needs to return at least an integer indicating the number of "
368+
"examples for the current configuration."
369+
)
370+
elif isinstance(info, int):
371+
info = dict(num_examples=info)
372+
elif not isinstance(info, dict):
373+
raise UsageError(
374+
f"The additional information returned by the method 'inject_fake_data' must be either an "
375+
f"integer indicating the number of examples for the current configuration or a dictionary with "
376+
f"the same content. Got {type(info)} instead."
377+
)
378+
elif "num_examples" not in info:
379+
raise UsageError(
380+
"The information dictionary returned by the method 'inject_fake_data' must contain a "
381+
"'num_examples' field that holds the number of examples for the current configuration."
382+
)
383+
return info
384+
385+
def _patch_download_extract(self):
386+
module = inspect.getmodule(self.DATASET_CLASS).__name__
387+
return {unittest.mock.patch(f"{module}.{function}") for function in self._DOWNLOAD_EXTRACT_FUNCTIONS}
388388

389+
def _patch_checks(self):
389390
module = inspect.getmodule(self.DATASET_CLASS).__name__
391+
return {unittest.mock.patch(f"{module}.{function}", return_value=True) for function in self._CHECK_FUNCTIONS}
392+
393+
@contextlib.contextmanager
394+
def _maybe_apply_patches(self, patchers):
390395
with contextlib.ExitStack() as stack:
391396
mocks = {}
392-
for function, kwargs in itertools.chain(
393-
zip(self._CHECK_FUNCTIONS, [dict(return_value=True)] * len(self._CHECK_FUNCTIONS)),
394-
zip(self._DOWNLOAD_EXTRACT_FUNCTIONS, [dict()] * len(self._DOWNLOAD_EXTRACT_FUNCTIONS)),
395-
):
397+
for patcher in patchers:
396398
with contextlib.suppress(AttributeError):
397-
patcher = unittest.mock.patch(f"{module}.{function}", **kwargs)
398-
mocks[function] = stack.enter_context(patcher)
399-
400-
try:
401-
yield mocks
402-
finally:
403-
if inject_download_kwarg:
404-
del special_kwargs["download"]
399+
mocks[patcher.target] = stack.enter_context(patcher)
400+
yield mocks
405401

406402
def test_not_found_or_corrupted(self):
407403
with self.assertRaises((FileNotFoundError, RuntimeError)):
@@ -469,13 +465,13 @@ def create_dataset(
469465
self,
470466
config: Optional[Dict[str, Any]] = None,
471467
inject_fake_data: bool = True,
472-
disable_download_extract: Optional[bool] = None,
468+
patch_checks: Optional[bool] = None,
473469
**kwargs: Any,
474470
) -> Iterator[Tuple[torchvision.datasets.VisionDataset, Dict[str, Any]]]:
475471
with super().create_dataset(
476472
config=config,
477473
inject_fake_data=inject_fake_data,
478-
disable_download_extract=disable_download_extract,
474+
patch_checks=patch_checks,
479475
**kwargs,
480476
) as (dataset, info):
481477
# PIL.Image.open() only loads the image meta data upfront and keeps the file open until the first access
@@ -572,7 +568,7 @@ def create_image_file(
572568

573569
image = create_image_or_video_tensor(size)
574570
file = pathlib.Path(root) / name
575-
PIL.Image.fromarray(image.permute(2, 1, 0).numpy()).save(file)
571+
PIL.Image.fromarray(image.permute(2, 1, 0).numpy()).save(file, **kwargs)
576572
return file
577573

578574

@@ -708,6 +704,21 @@ def size(idx):
708704
os.makedirs(root)
709705

710706
return [
711-
create_video_file(root, file_name_fn(idx), size=size(idx) if callable(size) else size)
707+
create_video_file(root, file_name_fn(idx), size=size(idx) if callable(size) else size, **kwargs)
712708
for idx in range(num_examples)
713709
]
710+
711+
712+
def create_random_string(length: int, *digits: str) -> str:
713+
"""Create a random string.
714+
715+
Args:
716+
length (int): Number of characters in the generated string.
717+
*characters (str): Characters to sample from. If omitted defaults to :attr:`string.ascii_lowercase`.
718+
"""
719+
if not digits:
720+
digits = string.ascii_lowercase
721+
else:
722+
digits = "".join(itertools.chain(*digits))
723+
724+
return "".join(random.choice(digits) for _ in range(length))

torchvision/datasets/phototour.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,17 @@
1010

1111

1212
class PhotoTour(VisionDataset):
13-
"""`Learning Local Image Descriptors Data <http://phototour.cs.washington.edu/patches/default.htm>`_ Dataset.
13+
"""`Multi-view Stereo Correspondence <http://matthewalunbrown.com/patchdata/patchdata.html>`_ Dataset.
14+
15+
.. note::
16+
17+
We only provide the newer version of the dataset, since the authors state that it
18+
19+
is more suitable for training descriptors based on difference of Gaussian, or Harris corners, as the
20+
patches are centred on real interest point detections, rather than being projections of 3D points as is the
21+
case in the old dataset.
22+
23+
The original dataset is available under http://phototour.cs.washington.edu/patches/default.htm.
1424
1525
1626
Args:

0 commit comments

Comments
 (0)