Skip to content

[Cherry-pick for 0.15] add docstring for dataset wrapper (#7333) #7344

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,12 @@ Base classes for custom datasets
DatasetFolder
ImageFolder
VisionDataset

Transforms v2
-------------

.. autosummary::
:toctree: generated/
:template: function.rst

wrap_dataset_for_transforms_v2
65 changes: 64 additions & 1 deletion torchvision/datapoints/_dataset_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,71 @@
__all__ = ["wrap_dataset_for_transforms_v2"]


# TODO: naming!
def wrap_dataset_for_transforms_v2(dataset):
"""[BETA] Wrap a ``torchvision.dataset`` for usage with :mod:`torchvision.transforms.v2`.

.. v2betastatus:: wrap_dataset_for_transforms_v2 function

Example:
>>> dataset = torchvision.datasets.CocoDetection(...)
>>> dataset = wrap_dataset_for_transforms_v2(dataset)

.. note::

For now, only the most popular datasets are supported. Furthermore, the wrapper only supports dataset
configurations that are fully supported by ``torchvision.transforms.v2``. If you encounter an error prompting you
to raise an issue to ``torchvision`` for a dataset or configuration that you need, please do so.

The dataset samples are wrapped according to the description below.

Special cases:

* :class:`~torchvision.datasets.CocoDetection`: Instead of returning the target as list of dicts, the wrapper
returns a dict of lists. In addition, the key-value-pairs ``"boxes"`` (in ``XYXY`` coordinate format),
``"masks"`` and ``"labels"`` are added and wrap the data in the corresponding ``torchvision.datapoints``.
The original keys are preserved.
* :class:`~torchvision.datasets.VOCDetection`: The key-value-pairs ``"boxes"`` and ``"labels"`` are added to
the target and wrap the data in the corresponding ``torchvision.datapoints``. The original keys are
preserved.
* :class:`~torchvision.datasets.CelebA`: The target for ``target_type="bbox"`` is converted to the ``XYXY``
coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBox` datapoint.
* :class:`~torchvision.datasets.Kitti`: Instead returning the target as list of dictsthe wrapper returns a dict
of lists. In addition, the key-value-pairs ``"boxes"`` and ``"labels"`` are added and wrap the data
in the corresponding ``torchvision.datapoints``. The original keys are preserved.
* :class:`~torchvision.datasets.OxfordIIITPet`: The target for ``target_type="segmentation"`` is wrapped into a
:class:`~torchvision.datapoints.Mask` datapoint.
* :class:`~torchvision.datasets.Cityscapes`: The target for ``target_type="semantic"`` is wrapped into a
:class:`~torchvision.datapoints.Mask` datapoint. The target for ``target_type="instance"`` is *replaced* by
a dictionary with the key-value-pairs ``"masks"`` (as :class:`~torchvision.datapoints.Mask` datapoint) and
``"labels"``.
* :class:`~torchvision.datasets.WIDERFace`: The value for key ``"bbox"`` in the target is converted to ``XYXY``
coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBox` datapoint.

Image classification datasets

This wrapper is a no-op for image classification datasets, since they were already fully supported by
:mod:`torchvision.transforms` and thus no change is needed for :mod:`torchvision.transforms.v2`.

Segmentation datasets

Segmentation datasets, e.g. :class:`~torchvision.datasets.VOCSegmentation` return a two-tuple of
:class:`PIL.Image.Image`'s. This wrapper leaves the image as is (first item), while wrapping the
segmentation mask into a :class:`~torchvision.datapoints.Mask` (second item).

Video classification datasets

Video classification datasets, e.g. :class:`~torchvision.datasets.Kinetics` return a three-tuple containing a
:class:`torch.Tensor` for the video and audio and a :class:`int` as label. This wrapper wraps the video into a
:class:`~torchvision.datapoints.Video` while leaving the other items as is.

.. note::

Only datasets constructed with ``output_format="TCHW"`` are supported, since the alternative
``output_format="THWC"`` is not supported by :mod:`torchvision.transforms.v2`.

Args:
dataset: the dataset instance to wrap for compatibility with transforms v2.
"""
return VisionDatasetDatapointWrapper(dataset)


Expand Down