diff --git a/.gitignore b/.gitignore
index 11689d25a98..c02a6ab80e3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,4 +4,5 @@ torchvision.egg-info/
*/**/__pycache__
*/**/*.pyc
*/**/*~
-*~
\ No newline at end of file
+*~
+docs/build
\ No newline at end of file
diff --git a/docs/Makefile b/docs/Makefile
new file mode 100644
index 00000000000..2ca4b0d71a2
--- /dev/null
+++ b/docs/Makefile
@@ -0,0 +1,27 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line.
+SPHINXOPTS =
+SPHINXBUILD = sphinx-build
+SPHINXPROJ = torchvision
+SOURCEDIR = source
+BUILDDIR = build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+docset: html
+	doc2dash --name $(SPHINXPROJ) --icon $(SOURCEDIR)/_static/img/pytorch-logo-flame.png --enable-js --online-redirect-url http://pytorch.org/vision/ --force $(BUILDDIR)/html/
+
+	# Manually fix because Zeal doesn't deal well with `icon.png`-only at 2x resolution.
+	cp $(SPHINXPROJ).docset/icon.png $(SPHINXPROJ).docset/icon@2x.png
+	convert $(SPHINXPROJ).docset/icon@2x.png -resize 16x16 $(SPHINXPROJ).docset/icon.png
+
+.PHONY: help Makefile docset
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/make.bat b/docs/make.bat
new file mode 100644
index 00000000000..6429a151515
--- /dev/null
+++ b/docs/make.bat
@@ -0,0 +1,36 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+	set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=source
+set BUILDDIR=build
+set SPHINXPROJ=torchvision
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+	echo.
+	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+	echo.installed, then set the SPHINXBUILD environment variable to point
+	echo.to the full path of the 'sphinx-build' executable. Alternatively you
+	echo.may add the Sphinx directory to PATH.
+	echo.
+	echo.If you don't have Sphinx installed, grab it from
+	echo.http://sphinx-doc.org/
+	exit /b 1
+)
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
+
+:end
+popd
diff --git a/docs/requirements.txt b/docs/requirements.txt
new file mode 100644
index 00000000000..09a5dd7ae4b
--- /dev/null
+++ b/docs/requirements.txt
@@ -0,0 +1,3 @@
+sphinx
+sphinxcontrib-googleanalytics
+-e git://github.com/snide/sphinx_rtd_theme.git#egg=sphinx_rtd_theme
diff --git a/docs/source/_static/css/pytorch_theme.css b/docs/source/_static/css/pytorch_theme.css
new file mode 100644
index 00000000000..0e54497643c
--- /dev/null
+++ b/docs/source/_static/css/pytorch_theme.css
@@ -0,0 +1,118 @@
+body {
+ font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;
+}
+
+/* Default header fonts are ugly */
+h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend, p.caption {
+ font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;
+}
+
+/* Use white for docs background */
+.wy-side-nav-search {
+ background-color: #fff;
+}
+
+.wy-nav-content-wrap, .wy-menu li.current > a {
+ background-color: #fff;
+}
+
+@media screen and (min-width: 1400px) {
+ .wy-nav-content-wrap {
+ background-color: rgba(0, 0, 0, 0.0470588);
+ }
+
+ .wy-nav-content {
+ background-color: #fff;
+ }
+}
+
+/* Fixes for mobile */
+.wy-nav-top {
+ background-color: #fff;
+ background-image: url('../img/pytorch-logo-dark.svg');
+ background-repeat: no-repeat;
+ background-position: center;
+ padding: 0;
+ margin: 0.4045em 0.809em;
+ color: #333;
+}
+
+.wy-nav-top > a {
+ display: none;
+}
+
+@media screen and (max-width: 768px) {
+ .wy-side-nav-search>a img.logo {
+ height: 60px;
+ }
+}
+
+/* This is needed to ensure that logo above search scales properly */
+.wy-side-nav-search a {
+ display: block;
+}
+
+/* This ensures that multiple constructors will remain in separate lines. */
+.rst-content dl:not(.docutils) dt {
+ display: table;
+}
+
+/* Use our red for literals (it's very similar to the original color) */
+.rst-content tt.literal, .rst-content code.literal {
+ color: #F05732;
+}
+
+.rst-content tt.xref, a .rst-content tt,
+.rst-content code.xref, a .rst-content code {
+ color: #404040;
+}
+
+/* Change link colors (except for the menu) */
+
+a {
+ color: #F05732;
+}
+
+a:hover {
+ color: #F05732;
+}
+
+
+a:visited {
+ color: #D44D2C;
+}
+
+.wy-menu a {
+ color: #b3b3b3;
+}
+
+.wy-menu a:hover {
+ color: #b3b3b3;
+}
+
+/* Default footer text is quite big */
+footer {
+ font-size: 80%;
+}
+
+footer .rst-footer-buttons {
+ font-size: 125%; /* revert footer settings - 1/80% = 125% */
+}
+
+footer p {
+ font-size: 100%;
+}
+
+/* For hidden headers that appear in TOC tree */
+/* see http://stackoverflow.com/a/32363545/3343043 */
+.rst-content .hidden-section {
+ display: none;
+}
+
+nav .hidden-section {
+ display: inherit;
+}
+
+.wy-side-nav-search>div.version {
+ color: #000;
+}
diff --git a/docs/source/_static/img/pytorch-logo-dark.png b/docs/source/_static/img/pytorch-logo-dark.png
new file mode 100644
index 00000000000..0288a564e22
Binary files /dev/null and b/docs/source/_static/img/pytorch-logo-dark.png differ
diff --git a/docs/source/_static/img/pytorch-logo-dark.svg b/docs/source/_static/img/pytorch-logo-dark.svg
new file mode 100644
index 00000000000..717a3ce942f
--- /dev/null
+++ b/docs/source/_static/img/pytorch-logo-dark.svg
@@ -0,0 +1,24 @@
+
+
+
diff --git a/docs/source/_static/img/pytorch-logo-flame.png b/docs/source/_static/img/pytorch-logo-flame.png
new file mode 100644
index 00000000000..370633f2ec2
Binary files /dev/null and b/docs/source/_static/img/pytorch-logo-flame.png differ
diff --git a/docs/source/_static/img/pytorch-logo-flame.svg b/docs/source/_static/img/pytorch-logo-flame.svg
new file mode 100644
index 00000000000..22d7228b4fa
--- /dev/null
+++ b/docs/source/_static/img/pytorch-logo-flame.svg
@@ -0,0 +1,33 @@
+
+
\ No newline at end of file
diff --git a/docs/source/conf.py b/docs/source/conf.py
new file mode 100644
index 00000000000..3ca7882296a
--- /dev/null
+++ b/docs/source/conf.py
@@ -0,0 +1,250 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+#
+# PyTorch documentation build configuration file, created by
+# sphinx-quickstart on Fri Dec 23 13:31:47 2016.
+#
+# This file is execfile()d with the current directory set to its
+# containing dir.
+#
+# Note that not all possible configuration values are present in this
+# autogenerated file.
+#
+# All configuration values have a default; values that are commented out
+# serve to show the default.
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+#
+# import os
+# import sys
+# sys.path.insert(0, os.path.abspath('.'))
+import torch
+import torchvision
+import sphinx_rtd_theme
+
+
+# -- General configuration ------------------------------------------------
+
+# If your documentation needs a minimal Sphinx version, state it here.
+#
+# needs_sphinx = '1.0'
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+ 'sphinx.ext.autodoc',
+ 'sphinx.ext.autosummary',
+ 'sphinx.ext.doctest',
+ 'sphinx.ext.intersphinx',
+ 'sphinx.ext.todo',
+ 'sphinx.ext.coverage',
+ 'sphinx.ext.mathjax',
+ 'sphinx.ext.napoleon',
+ 'sphinx.ext.viewcode',
+ 'sphinxcontrib.googleanalytics',
+]
+
+napoleon_use_ivar = True
+
+googleanalytics_id = 'UA-90545585-1'
+googleanalytics_enabled = True
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ['_templates']
+
+# The suffix(es) of source filenames.
+# You can specify multiple suffix as a list of string:
+#
+# source_suffix = ['.rst', '.md']
+source_suffix = '.rst'
+
+# The master toctree document.
+master_doc = 'index'
+
+# General information about the project.
+project = 'Torchvision'
+copyright = '2017, Torch Contributors'
+author = 'Torch Contributors'
+
+# The version info for the project you're documenting, acts as replacement for
+# |version| and |release|, also used in various other places throughout the
+# built documents.
+#
+# The short X.Y version.
+# TODO: change to [:2] at v1.0
+version = 'master (' + torchvision.__version__ + ')'
+# The full version, including alpha/beta/rc tags.
+# TODO: verify this works as expected
+release = 'master'
+
+# The language for content autogenerated by Sphinx. Refer to documentation
+# for a list of supported languages.
+#
+# This is also used if you do content translation via gettext catalogs.
+# Usually you set "language" from the command line for these cases.
+language = None
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# These patterns also affect html_static_path and html_extra_path
+exclude_patterns = []
+
+# The name of the Pygments (syntax highlighting) style to use.
+pygments_style = 'sphinx'
+
+# If true, `todo` and `todoList` produce output, else they produce nothing.
+todo_include_todos = True
+
+
+# -- Options for HTML output ----------------------------------------------
+
+# The theme to use for HTML and HTML Help pages. See the documentation for
+# a list of builtin themes.
+#
+html_theme = 'sphinx_rtd_theme'
+html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
+
+# Theme options are theme-specific and customize the look and feel of a theme
+# further. For a list of options available for each theme, see the
+# documentation.
+#
+html_theme_options = {
+ 'collapse_navigation': False,
+ 'display_version': True,
+ 'logo_only': True,
+}
+
+html_logo = '_static/img/pytorch-logo-dark.svg'
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+html_static_path = ['_static']
+
+# html_style_path = 'css/pytorch_theme.css'
+html_context = {
+ 'css_files': [
+ 'https://fonts.googleapis.com/css?family=Lato',
+ '_static/css/pytorch_theme.css'
+ ],
+}
+
+
+# -- Options for HTMLHelp output ------------------------------------------
+
+# Output file base name for HTML help builder.
+htmlhelp_basename = 'PyTorchdoc'
+
+
+# -- Options for LaTeX output ---------------------------------------------
+
+latex_elements = {
+ # The paper size ('letterpaper' or 'a4paper').
+ #
+ # 'papersize': 'letterpaper',
+
+ # The font size ('10pt', '11pt' or '12pt').
+ #
+ # 'pointsize': '10pt',
+
+ # Additional stuff for the LaTeX preamble.
+ #
+ # 'preamble': '',
+
+ # Latex figure (float) alignment
+ #
+ # 'figure_align': 'htbp',
+}
+
+# Grouping the document tree into LaTeX files. List of tuples
+# (source start file, target name, title,
+# author, documentclass [howto, manual, or own class]).
+latex_documents = [
+ (master_doc, 'pytorch.tex', 'torchvision Documentation',
+ 'Torch Contributors', 'manual'),
+]
+
+
+# -- Options for manual page output ---------------------------------------
+
+# One entry per manual page. List of tuples
+# (source start file, name, description, authors, manual section).
+man_pages = [
+ (master_doc, 'torchvision', 'torchvision Documentation',
+ [author], 1)
+]
+
+
+# -- Options for Texinfo output -------------------------------------------
+
+# Grouping the document tree into Texinfo files. List of tuples
+# (source start file, target name, title, author,
+# dir menu entry, description, category)
+texinfo_documents = [
+ (master_doc, 'torchvision', 'torchvision Documentation',
+ author, 'torchvision', 'One line description of project.',
+ 'Miscellaneous'),
+]
+
+
+# Example configuration for intersphinx: refer to the Python standard library.
+intersphinx_mapping = {
+ 'python': ('https://docs.python.org/', None),
+ 'numpy': ('http://docs.scipy.org/doc/numpy/', None),
+}
+
+# -- A patch that prevents Sphinx from cross-referencing ivar tags -------
+# See http://stackoverflow.com/a/41184353/3343043
+
+from docutils import nodes
+from sphinx.util.docfields import TypedField
+from sphinx import addnodes
+
+
+def patched_make_field(self, types, domain, items, **kw):
+ # `kw` catches `env=None` needed for newer sphinx while maintaining
+ # backwards compatibility when passed along further down!
+
+ # type: (List, unicode, Tuple) -> nodes.field
+ def handle_item(fieldarg, content):
+ par = nodes.paragraph()
+ par += addnodes.literal_strong('', fieldarg) # Patch: this line added
+ # par.extend(self.make_xrefs(self.rolename, domain, fieldarg,
+ # addnodes.literal_strong))
+ if fieldarg in types:
+ par += nodes.Text(' (')
+ # NOTE: using .pop() here to prevent a single type node to be
+ # inserted twice into the doctree, which leads to
+ # inconsistencies later when references are resolved
+ fieldtype = types.pop(fieldarg)
+ if len(fieldtype) == 1 and isinstance(fieldtype[0], nodes.Text):
+ typename = u''.join(n.astext() for n in fieldtype)
+ typename = typename.replace('int', 'python:int')
+ typename = typename.replace('long', 'python:long')
+ typename = typename.replace('float', 'python:float')
+ typename = typename.replace('type', 'python:type')
+ par.extend(self.make_xrefs(self.typerolename, domain, typename,
+ addnodes.literal_emphasis, **kw))
+ else:
+ par += fieldtype
+ par += nodes.Text(')')
+ par += nodes.Text(' -- ')
+ par += content
+ return par
+
+ fieldname = nodes.field_name('', self.label)
+ if len(items) == 1 and self.can_collapse:
+ fieldarg, content = items[0]
+ bodynode = handle_item(fieldarg, content)
+ else:
+ bodynode = self.list_type()
+ for fieldarg, content in items:
+ bodynode += nodes.list_item('', handle_item(fieldarg, content))
+ fieldbody = nodes.field_body('', bodynode)
+ return nodes.field('', fieldname, fieldbody)
+
+
+TypedField.make_field = patched_make_field
diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst
new file mode 100644
index 00000000000..3152a82365c
--- /dev/null
+++ b/docs/source/datasets.rst
@@ -0,0 +1,112 @@
+torchvision.datasets
+====================
+
+All datasets are subclasses of :class:`torch.utils.data.Dataset`
+i.e, they have ``__getitem__`` and ``__len__`` methods implemented.
+Hence, they can all be passed to a :class:`torch.utils.data.DataLoader`
+which can load multiple samples in parallel using ``torch.multiprocessing`` workers.
+For example: ::
+
+    imagenet_data = torchvision.datasets.ImageFolder('path/to/imagenet_root/')
+    data_loader = torch.utils.data.DataLoader(imagenet_data,
+                                              batch_size=4,
+                                              shuffle=True,
+                                              num_workers=args.nThreads)
+
+The following datasets are available:
+
+.. contents:: Datasets
+ :local:
+
+All the datasets have a similar API. They all have two common arguments:
+``transform`` and ``target_transform``, to transform the input and the target respectively.
+
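+For example, MNIST can be constructed so that every image is converted to a
+tensor as it is loaded (a minimal sketch; any transform can be substituted)::
+
+    mnist_data = torchvision.datasets.MNIST('path/to/mnist_root/',
+                                            transform=torchvision.transforms.ToTensor())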
+
+.. currentmodule:: torchvision.datasets
+
+
+MNIST
+~~~~~
+
+.. autoclass:: MNIST
+
+COCO
+~~~~
+
+.. note ::
+ These require the `COCO API to be installed`_
+
+.. _COCO API to be installed: https://github.com/pdollar/coco/tree/master/PythonAPI
+
+
+Captions
+^^^^^^^^
+
+.. autoclass:: CocoCaptions
+ :members: __getitem__
+ :special-members:
+
+
+Detection
+^^^^^^^^^
+
+.. autoclass:: CocoDetection
+ :members: __getitem__
+ :special-members:
+
+LSUN
+~~~~
+
+.. autoclass:: LSUN
+ :members: __getitem__
+ :special-members:
+
+ImageFolder
+~~~~~~~~~~~
+
+.. autoclass:: ImageFolder
+ :members: __getitem__
+ :special-members:
+
+
+Imagenet-12
+~~~~~~~~~~~
+
+This should simply be implemented with an ``ImageFolder`` dataset.
+The data is preprocessed `as described
+here `__
+
+`Here is an
+example `__.
+
+CIFAR
+~~~~~
+
+.. autoclass:: CIFAR10
+ :members: __getitem__
+ :special-members:
+
+STL10
+~~~~~
+
+
+.. autoclass:: STL10
+ :members: __getitem__
+ :special-members:
+
+SVHN
+~~~~~
+
+
+.. autoclass:: SVHN
+ :members: __getitem__
+ :special-members:
+
+PhotoTour
+~~~~~~~~~
+
+
+.. autoclass:: PhotoTour
+ :members: __getitem__
+ :special-members:
+
diff --git a/docs/source/index.rst b/docs/source/index.rst
new file mode 100644
index 00000000000..f8f89f92629
--- /dev/null
+++ b/docs/source/index.rst
@@ -0,0 +1,17 @@
+torchvision
+===========
+
+The :mod:`torchvision` package consists of popular datasets, model
+architectures, and common image transformations for computer vision.
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Package Reference
+
+ datasets
+ models
+ transforms
+ utils
+
+.. automodule:: torchvision
+ :members:
diff --git a/docs/source/models.rst b/docs/source/models.rst
new file mode 100644
index 00000000000..fd5471561b8
--- /dev/null
+++ b/docs/source/models.rst
@@ -0,0 +1,135 @@
+torchvision.models
+==================
+
+The models subpackage contains definitions for the following model
+architectures:
+
+- `AlexNet`_
+- `VGG`_
+- `ResNet`_
+- `SqueezeNet`_
+- `DenseNet`_
+- `Inception`_ v3
+
+You can construct a model with random weights by calling its constructor:
+
+.. code:: python
+
+    import torchvision.models as models
+    resnet18 = models.resnet18()
+    alexnet = models.alexnet()
+    vgg16 = models.vgg16()
+    squeezenet = models.squeezenet1_0()
+    densenet = models.densenet161()
+    inception = models.inception_v3()
+
+We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
+These can be constructed by passing ``pretrained=True``:
+
+.. code:: python
+
+    import torchvision.models as models
+    resnet18 = models.resnet18(pretrained=True)
+    alexnet = models.alexnet(pretrained=True)
+    squeezenet = models.squeezenet1_0(pretrained=True)
+    vgg16 = models.vgg16(pretrained=True)
+    densenet = models.densenet161(pretrained=True)
+    inception = models.inception_v3(pretrained=True)
+
+All pre-trained models expect input images normalized in the same way,
+i.e. mini-batches of 3-channel RGB images of shape (3 x H x W),
+where H and W are expected to be at least 224.
+The images have to be loaded into a range of [0, 1] and then normalized
+using ``mean = [0.485, 0.456, 0.406]`` and ``std = [0.229, 0.224, 0.225]``.
+You can use the following transform to normalize::
+
+    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                     std=[0.229, 0.224, 0.225])
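+
+Putting this together, one possible preprocessing pipeline for these models
+(a sketch; 256/224 are a common convention, not a requirement) is::
+
+    preprocess = transforms.Compose([
+        transforms.Resize(256),
+        transforms.CenterCrop(224),
+        transforms.ToTensor(),
+        normalize,
+    ])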
+
+An example of such normalization can be found in the imagenet example
+`here `_
+
+ImageNet 1-crop error rates (224x224)
+
+================================ ============= =============
+Network                          Top-1 error   Top-5 error
+================================ ============= =============
+AlexNet                          43.45         20.91
+VGG-11                           30.98         11.37
+VGG-13                           30.07         10.75
+VGG-16                           28.41         9.62
+VGG-19                           27.62         9.12
+VGG-11 with batch normalization  29.62         10.19
+VGG-13 with batch normalization  28.45         9.63
+VGG-16 with batch normalization  26.63         8.50
+VGG-19 with batch normalization  25.76         8.15
+ResNet-18                        30.24         10.92
+ResNet-34                        26.70         8.58
+ResNet-50                        23.85         7.13
+ResNet-101                       22.63         6.44
+ResNet-152                       21.69         5.94
+SqueezeNet 1.0                   41.90         19.58
+SqueezeNet 1.1                   41.81         19.38
+Densenet-121                     25.35         7.83
+Densenet-169                     24.00         7.00
+Densenet-201                     22.80         6.43
+Densenet-161                     22.35         6.20
+Inception v3                     22.55         6.44
+================================ ============= =============
+
+
+.. _AlexNet: https://arxiv.org/abs/1404.5997
+.. _VGG: https://arxiv.org/abs/1409.1556
+.. _ResNet: https://arxiv.org/abs/1512.03385
+.. _SqueezeNet: https://arxiv.org/abs/1602.07360
+.. _DenseNet: https://arxiv.org/abs/1608.06993
+.. _Inception: https://arxiv.org/abs/1512.00567
+
+.. currentmodule:: torchvision.models
+
+AlexNet
+-------
+
+.. autofunction:: alexnet
+
+VGG
+---
+
+.. autofunction:: vgg11
+.. autofunction:: vgg11_bn
+.. autofunction:: vgg13
+.. autofunction:: vgg13_bn
+.. autofunction:: vgg16
+.. autofunction:: vgg16_bn
+.. autofunction:: vgg19
+.. autofunction:: vgg19_bn
+
+
+ResNet
+------
+
+.. autofunction:: resnet18
+.. autofunction:: resnet34
+.. autofunction:: resnet50
+.. autofunction:: resnet101
+.. autofunction:: resnet152
+
+SqueezeNet
+----------
+
+.. autofunction:: squeezenet1_0
+.. autofunction:: squeezenet1_1
+
+DenseNet
+---------
+
+.. autofunction:: densenet121
+.. autofunction:: densenet169
+.. autofunction:: densenet161
+.. autofunction:: densenet201
+
+Inception v3
+------------
+
+.. autofunction:: inception_v3
+
diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
new file mode 100644
index 00000000000..baf36b9b0e2
--- /dev/null
+++ b/docs/source/transforms.rst
@@ -0,0 +1,60 @@
+torchvision.transforms
+======================
+
+.. currentmodule:: torchvision.transforms
+
+Transforms are common image transforms. They can be chained together using :class:`Compose`
+
+.. autoclass:: Compose
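+
+For example, a minimal sketch of a composed pipeline, built from transforms
+documented below::
+
+    transform = transforms.Compose([
+        transforms.CenterCrop(10),
+        transforms.ToTensor(),
+    ])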
+
+Transforms on PIL Image
+-----------------------
+
+.. autoclass:: Resize
+
+.. autoclass:: Scale
+
+.. autoclass:: CenterCrop
+
+.. autoclass:: RandomCrop
+
+.. autoclass:: RandomHorizontalFlip
+
+.. autoclass:: RandomVerticalFlip
+
+.. autoclass:: RandomResizedCrop
+
+.. autoclass:: RandomSizedCrop
+
+.. autoclass:: FiveCrop
+
+.. autoclass:: TenCrop
+
+.. autoclass:: Pad
+
+.. autoclass:: ColorJitter
+
+Transforms on torch.\*Tensor
+----------------------------
+
+.. autoclass:: Normalize
+ :members: __call__
+ :special-members:
+
+
+Conversion Transforms
+---------------------
+
+.. autoclass:: ToTensor
+ :members: __call__
+ :special-members:
+
+.. autoclass:: ToPILImage
+ :members: __call__
+ :special-members:
+
+Generic Transforms
+------------------
+
+.. autoclass:: Lambda
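+
+For example, an arbitrary user-defined function can be wrapped as a transform
+(a minimal sketch)::
+
+    to_grayscale = transforms.Lambda(lambda img: img.convert('L'))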
+
diff --git a/docs/source/utils.rst b/docs/source/utils.rst
new file mode 100644
index 00000000000..ad2fc91c897
--- /dev/null
+++ b/docs/source/utils.rst
@@ -0,0 +1,9 @@
+torchvision.utils
+=================
+
+.. currentmodule:: torchvision.utils
+
+.. autofunction:: make_grid
+
+.. autofunction:: save_image
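+
+For example, a mini-batch of image tensors can be tiled into a grid and saved
+to disk (a minimal sketch; ``batch`` is assumed to be a B x C x H x W tensor)::
+
+    grid = torchvision.utils.make_grid(batch, nrow=8)
+    torchvision.utils.save_image(batch, 'batch.png', nrow=8)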
+
diff --git a/setup.py b/setup.py
index 42aca1ff095..0f46586deec 100644
--- a/setup.py
+++ b/setup.py
@@ -1,13 +1,32 @@
#!/usr/bin/env python
import os
+import io
+import re
import shutil
import sys
from setuptools import setup, find_packages
+def read(*names, **kwargs):
+ with io.open(
+ os.path.join(os.path.dirname(__file__), *names),
+ encoding=kwargs.get("encoding", "utf8")
+ ) as fp:
+ return fp.read()
+
+
+def find_version(*file_paths):
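+    # Read the package source and pull out the __version__ string with a regex,
+    # so setup.py does not have to import torchvision to learn its version.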
+ version_file = read(*file_paths)
+ version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
+ version_file, re.M)
+ if version_match:
+ return version_match.group(1)
+ raise RuntimeError("Unable to find version string.")
+
+
readme = open('README.rst').read()
-VERSION = '0.1.9'
+VERSION = find_version('torchvision', '__init__.py')
requirements = [
'numpy',
diff --git a/torchvision/__init__.py b/torchvision/__init__.py
index 50d3dcf4fa9..2133197e1af 100644
--- a/torchvision/__init__.py
+++ b/torchvision/__init__.py
@@ -3,6 +3,7 @@
from torchvision import transforms
from torchvision import utils
+__version__ = '0.1.9'
_image_backend = 'PIL'
diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py
index 492714aaebc..079992e0269 100644
--- a/torchvision/models/__init__.py
+++ b/torchvision/models/__init__.py
@@ -1,88 +1,3 @@
-"""The models subpackage contains definitions for the following model
-architectures:
-
-- `AlexNet`_
-- `VGG`_
-- `ResNet`_
-- `SqueezeNet`_
-- `DenseNet`_
-- `Inception`_ v3
-
-You can construct a model with random weights by calling its constructor:
-
-.. code:: python
-
- import torchvision.models as models
- resnet18 = models.resnet18()
- alexnet = models.alexnet()
- vgg16 = models.vgg16()
- squeezenet = models.squeezenet1_0()
- densenet = models.densenet_161()
- inception = models.inception_v3()
-
-We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
-These can be constructed by passing ``pretrained=True``:
-
-.. code:: python
-
- import torchvision.models as models
- resnet18 = models.resnet18(pretrained=True)
- alexnet = models.alexnet(pretrained=True)
- squeezenet = models.squeezenet1_0(pretrained=True)
- vgg16 = models.vgg16(pretrained=True)
- densenet = models.densenet_161(pretrained=True)
- inception = models.inception_v3(pretrained=True)
-
-All pre-trained models expect input images normalized in the same way,
-i.e. mini-batches of 3-channel RGB images of shape (3 x H x W),
-where H and W are expected to be at least 224.
-The images have to be loaded in to a range of [0, 1] and then normalized
-using ``mean = [0.485, 0.456, 0.406]`` and ``std = [0.229, 0.224, 0.225]``.
-You can use the following transform to normalize::
-
-    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                                     std=[0.229, 0.224, 0.225])
-
-An example of such normalization can be found in the imagenet example
-`here `_
-
-ImageNet 1-crop error rates (224x224)
-
-================================ ============= =============
-Network                          Top-1 error   Top-5 error
-================================ ============= =============
-ResNet-18                        30.24         10.92
-ResNet-34                        26.70         8.58
-ResNet-50                        23.85         7.13
-ResNet-101                       22.63         6.44
-ResNet-152                       21.69         5.94
-Inception v3                     22.55         6.44
-AlexNet                          43.45         20.91
-VGG-11                           30.98         11.37
-VGG-13                           30.07         10.75
-VGG-16                           28.41         9.62
-VGG-19                           27.62         9.12
-VGG-11 with batch normalization  29.62         10.19
-VGG-13 with batch normalization  28.45         9.63
-VGG-16 with batch normalization  26.63         8.50
-VGG-19 with batch normalization  25.76         8.15
-SqueezeNet 1.0                   41.90         19.58
-SqueezeNet 1.1                   41.81         19.38
-Densenet-121                     25.35         7.83
-Densenet-169                     24.00         7.00
-Densenet-201                     22.80         6.43
-Densenet-161                     22.35         6.20
-================================ ============= =============
-
-
-.. _AlexNet: https://arxiv.org/abs/1404.5997
-.. _VGG: https://arxiv.org/abs/1409.1556
-.. _ResNet: https://arxiv.org/abs/1512.03385
-.. _SqueezeNet: https://arxiv.org/abs/1602.07360
-.. _DenseNet: https://arxiv.org/abs/1608.06993
-.. _Inception: https://arxiv.org/abs/1512.00567
-"""
-
from .alexnet import *
from .resnet import *
from .vgg import *
diff --git a/torchvision/models/densenet.py b/torchvision/models/densenet.py
index 59a951a65b1..dfdff67aee6 100644
--- a/torchvision/models/densenet.py
+++ b/torchvision/models/densenet.py
@@ -17,7 +17,7 @@
def densenet121(pretrained=False, **kwargs):
r"""Densenet-121 model from
- `"Densely Connected Convolutional Networks" `
+ `"Densely Connected Convolutional Networks" `_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
@@ -31,7 +31,7 @@ def densenet121(pretrained=False, **kwargs):
def densenet169(pretrained=False, **kwargs):
r"""Densenet-169 model from
- `"Densely Connected Convolutional Networks" `
+ `"Densely Connected Convolutional Networks" `_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
@@ -45,7 +45,7 @@ def densenet169(pretrained=False, **kwargs):
def densenet201(pretrained=False, **kwargs):
r"""Densenet-201 model from
- `"Densely Connected Convolutional Networks" `
+ `"Densely Connected Convolutional Networks" `_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
@@ -59,7 +59,7 @@ def densenet201(pretrained=False, **kwargs):
def densenet161(pretrained=False, **kwargs):
r"""Densenet-161 model from
- `"Densely Connected Convolutional Networks" `
+ `"Densely Connected Convolutional Networks" `_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
@@ -111,7 +111,7 @@ def __init__(self, num_input_features, num_output_features):
class DenseNet(nn.Module):
r"""Densenet-BC model class, based on
- `"Densely Connected Convolutional Networks" `
+ `"Densely Connected Convolutional Networks" `_
Args:
growth_rate (int) - how many filters to add each layer (`k` in paper)
diff --git a/torchvision/transforms.py b/torchvision/transforms.py
index 69b1d6fe3a8..81aebb2cc97 100644
--- a/torchvision/transforms.py
+++ b/torchvision/transforms.py
@@ -30,12 +30,12 @@ def _is_numpy_image(img):
def to_tensor(pic):
- """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.
+ """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
See ``ToTensor`` for more details.
Args:
- pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.
+ pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
Returns:
Tensor: Converted image.
@@ -84,10 +84,10 @@ def to_pil_image(pic):
See ``ToPIlImage`` for more details.
Args:
- pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image.
+ pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
Returns:
- PIL.Image: Image converted to PIL.Image.
+ PIL Image: Image converted to PIL Image.
"""
if not(_is_numpy_image(pic) or _is_tensor_image(pic)):
raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))
@@ -143,10 +143,10 @@ def normalize(tensor, mean, std):
def resize(img, size, interpolation=Image.BILINEAR):
- """Resize the input PIL.Image to the given size.
+ """Resize the input PIL Image to the given size.
Args:
- img (PIL.Image): Image to be resized.
+ img (PIL Image): Image to be resized.
size (sequence or int): Desired output size. If size is a sequence like
(h, w), the output size will be matched to this. If size is an int,
the smaller edge of the image will be matched to this number maintaing
@@ -156,7 +156,7 @@ def resize(img, size, interpolation=Image.BILINEAR):
``PIL.Image.BILINEAR``
Returns:
- PIL.Image: Resized image.
+ PIL Image: Resized image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
@@ -186,10 +186,10 @@ def scale(*args, **kwargs):
def pad(img, padding, fill=0):
- """Pad the given PIL.Image on all sides with the given "pad" value.
+ """Pad the given PIL Image on all sides with the given "pad" value.
Args:
- img (PIL.Image): Image to be padded.
+ img (PIL Image): Image to be padded.
padding (int or tuple): Padding on each border. If a single int is provided this
is used to pad all borders. If tuple of length 2 is provided this is the padding
on left/right and top/bottom respectively. If a tuple of length 4 is provided
@@ -199,7 +199,7 @@ def pad(img, padding, fill=0):
length 3, it is used to fill R, G, B channels respectively.
Returns:
- PIL.Image: Padded image.
+ PIL Image: Padded image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
@@ -217,17 +217,17 @@ def pad(img, padding, fill=0):
def crop(img, i, j, h, w):
- """Crop the given PIL.Image.
+ """Crop the given PIL Image.
Args:
- img (PIL.Image): Image to be cropped.
+ img (PIL Image): Image to be cropped.
i: Upper pixel coordinate.
j: Left pixel coordinate.
h: Height of the cropped image.
w: Width of the cropped image.
Returns:
- PIL.Image: Cropped image.
+ PIL Image: Cropped image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
@@ -236,12 +236,12 @@ def crop(img, i, j, h, w):
def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR):
- """Crop the given PIL.Image and resize it to desired size.
+ """Crop the given PIL Image and resize it to desired size.
Notably used in RandomResizedCrop.
Args:
- img (PIL.Image): Image to be cropped.
+ img (PIL Image): Image to be cropped.
i: Upper pixel coordinate.
j: Left pixel coordinate.
h: Height of the cropped image.
@@ -250,7 +250,7 @@ def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR):
interpolation (int, optional): Desired interpolation. Default is
``PIL.Image.BILINEAR``.
Returns:
- PIL.Image: Cropped image.
+ PIL Image: Cropped image.
"""
assert _is_pil_image(img), 'img should be PIL Image'
img = crop(img, i, j, h, w)
@@ -259,13 +259,13 @@ def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR):
def hflip(img):
- """Horizontally flip the given PIL.Image.
+ """Horizontally flip the given PIL Image.
Args:
- img (PIL.Image): Image to be flipped.
+ img (PIL Image): Image to be flipped.
Returns:
- PIL.Image: Horizontall flipped image.
+        PIL Image: Horizontally flipped image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
@@ -274,13 +274,13 @@ def hflip(img):
def vflip(img):
- """Vertically flip the given PIL.Image.
+ """Vertically flip the given PIL Image.
Args:
- img (PIL.Image): Image to be flipped.
+ img (PIL Image): Image to be flipped.
Returns:
- PIL.Image: Vertically flipped image.
+ PIL Image: Vertically flipped image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
@@ -289,10 +289,11 @@ def vflip(img):
def five_crop(img, size):
- """Crop the given PIL.Image into four corners and the central crop.
+ """Crop the given PIL Image into four corners and the central crop.
- Note: this transform returns a tuple of images and there may be a mismatch in the number of
- inputs and targets your `Dataset` returns.
+ .. Note::
+ This transform returns a tuple of images and there may be a
+ mismatch in the number of inputs and targets your ``Dataset`` returns.
Args:
size (sequence or int): Desired output size of the crop. If size is an
@@ -321,11 +322,12 @@ def five_crop(img, size):
def ten_crop(img, size, vertical_flip=False):
- """Crop the given PIL.Image into four corners and the central crop plus the
+ """Crop the given PIL Image into four corners and the central crop plus the
flipped version of these (horizontal flipping is used by default).
- Note: this transform returns a tuple of images and there may be a mismatch in the number of
- inputs and targets your `Dataset` returns.
+ .. Note::
+ This transform returns a tuple of images and there may be a
+ mismatch in the number of inputs and targets your ``Dataset`` returns.
Args:
size (sequence or int): Desired output size of the crop. If size is an
@@ -359,13 +361,13 @@ def adjust_brightness(img, brightness_factor):
"""Adjust brightness of an Image.
Args:
- img (PIL.Image): PIL Image to be adjusted.
+ img (PIL Image): PIL Image to be adjusted.
brightness_factor (float): How much to adjust the brightness. Can be
any non negative number. 0 gives a black image, 1 gives the
original image while 2 increases the brightness by a factor of 2.
Returns:
- PIL.Image: Brightness adjusted image.
+ PIL Image: Brightness adjusted image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
@@ -379,13 +381,13 @@ def adjust_contrast(img, contrast_factor):
"""Adjust contrast of an Image.
Args:
- img (PIL.Image): PIL Image to be adjusted.
+ img (PIL Image): PIL Image to be adjusted.
contrast_factor (float): How much to adjust the contrast. Can be any
non negative number. 0 gives a solid gray image, 1 gives the
original image while 2 increases the contrast by a factor of 2.
Returns:
- PIL.Image: Contrast adjusted image.
+ PIL Image: Contrast adjusted image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
@@ -399,13 +401,13 @@ def adjust_saturation(img, saturation_factor):
"""Adjust color saturation of an image.
Args:
- img (PIL.Image): PIL Image to be adjusted.
+ img (PIL Image): PIL Image to be adjusted.
saturation_factor (float): How much to adjust the saturation. 0 will
give a black and white image, 1 will give the original image while
2 will enhance the saturation by a factor of 2.
Returns:
- PIL.Image: Saturation adjusted image.
+ PIL Image: Saturation adjusted image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
@@ -428,7 +430,7 @@ def adjust_hue(img, hue_factor):
See https://en.wikipedia.org/wiki/Hue for more details on Hue.
Args:
- img (PIL.Image): PIL Image to be adjusted.
+ img (PIL Image): PIL Image to be adjusted.
hue_factor (float): How much to shift the hue channel. Should be in
[-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
HSV space in positive and negative direction respectively.
@@ -436,7 +438,7 @@ def adjust_hue(img, hue_factor):
with complementary colors while 0 gives the original image.
Returns:
- PIL.Image: Hue adjusted image.
+ PIL Image: Hue adjusted image.
"""
if not(-0.5 <= hue_factor <= 0.5):
raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor))
@@ -471,7 +473,7 @@ def adjust_gamma(img, gamma, gain=1):
See https://en.wikipedia.org/wiki/Gamma_correction for more details.
Args:
- img (PIL.Image): PIL Image to be adjusted.
+ img (PIL Image): PIL Image to be adjusted.
gamma (float): Non negative real number. gamma larger than 1 make the
shadows darker, while gamma smaller than 1 make dark regions
lighter.
@@ -517,16 +519,16 @@ def __call__(self, img):
class ToTensor(object):
- """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.
+ """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
- Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
+ Converts a PIL Image or numpy.ndarray (H x W x C) in the range
[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
"""
def __call__(self, pic):
"""
Args:
- pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.
+ pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
Returns:
Tensor: Converted image.
@@ -538,16 +540,16 @@ class ToPILImage(object):
"""Convert a tensor or an ndarray to PIL Image.
Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
- H x W x C to a PIL.Image while preserving the value range.
+ H x W x C to a PIL Image while preserving the value range.
"""
def __call__(self, pic):
"""
Args:
- pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image.
+ pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
Returns:
- PIL.Image: Image converted to PIL.Image.
+ PIL Image: Image converted to PIL Image.
"""
return to_pil_image(pic)
@@ -582,7 +584,7 @@ def __call__(self, tensor):
class Resize(object):
- """Resize the input PIL.Image to the given size.
+ """Resize the input PIL Image to the given size.
Args:
size (sequence or int): Desired output size. If size is a sequence like
@@ -602,15 +604,18 @@ def __init__(self, size, interpolation=Image.BILINEAR):
def __call__(self, img):
"""
Args:
- img (PIL.Image): Image to be scaled.
+ img (PIL Image): Image to be scaled.
Returns:
- PIL.Image: Rescaled image.
+ PIL Image: Rescaled image.
"""
return resize(img, self.size, self.interpolation)
class Scale(Resize):
+ """
+ Note: This transform is deprecated in favor of Resize.
+ """
def __init__(self, *args, **kwargs):
warnings.warn("The use of the transforms.Scale transform is deprecated, " +
"please use transforms.Resize instead.")
@@ -618,7 +623,7 @@ def __init__(self, *args, **kwargs):
class CenterCrop(object):
- """Crops the given PIL.Image at the center.
+ """Crops the given PIL Image at the center.
Args:
size (sequence or int): Desired output size of the crop. If size is an
@@ -637,7 +642,7 @@ def get_params(img, output_size):
"""Get parameters for ``crop`` for center crop.
Args:
- img (PIL.Image): Image to be cropped.
+ img (PIL Image): Image to be cropped.
output_size (tuple): Expected output size of the crop.
Returns:
@@ -652,17 +657,17 @@ def get_params(img, output_size):
def __call__(self, img):
"""
Args:
- img (PIL.Image): Image to be cropped.
+ img (PIL Image): Image to be cropped.
Returns:
- PIL.Image: Cropped image.
+ PIL Image: Cropped image.
"""
i, j, h, w = self.get_params(img, self.size)
return crop(img, i, j, h, w)
class Pad(object):
- """Pad the given PIL.Image on all sides with the given "pad" value.
+ """Pad the given PIL Image on all sides with the given "pad" value.
Args:
padding (int or tuple): Padding on each border. If a single int is provided this
@@ -687,10 +692,10 @@ def __init__(self, padding, fill=0):
def __call__(self, img):
"""
Args:
- img (PIL.Image): Image to be padded.
+ img (PIL Image): Image to be padded.
Returns:
- PIL.Image: Padded image.
+ PIL Image: Padded image.
"""
return pad(img, self.padding, self.fill)
@@ -711,7 +716,7 @@ def __call__(self, img):
class RandomCrop(object):
- """Crop the given PIL.Image at a random location.
+ """Crop the given PIL Image at a random location.
Args:
size (sequence or int): Desired output size of the crop. If size is an
@@ -735,7 +740,7 @@ def get_params(img, output_size):
"""Get parameters for ``crop`` for a random crop.
Args:
- img (PIL.Image): Image to be cropped.
+ img (PIL Image): Image to be cropped.
output_size (tuple): Expected output size of the crop.
Returns:
@@ -753,10 +758,10 @@ def get_params(img, output_size):
def __call__(self, img):
"""
Args:
- img (PIL.Image): Image to be cropped.
+ img (PIL Image): Image to be cropped.
Returns:
- PIL.Image: Cropped image.
+ PIL Image: Cropped image.
"""
if self.padding > 0:
img = pad(img, self.padding)
@@ -767,15 +772,15 @@ def __call__(self, img):
class RandomHorizontalFlip(object):
- """Horizontally flip the given PIL.Image randomly with a probability of 0.5."""
+ """Horizontally flip the given PIL Image randomly with a probability of 0.5."""
def __call__(self, img):
"""
Args:
- img (PIL.Image): Image to be flipped.
+ img (PIL Image): Image to be flipped.
Returns:
- PIL.Image: Randomly flipped image.
+ PIL Image: Randomly flipped image.
"""
if random.random() < 0.5:
return hflip(img)
@@ -783,15 +788,15 @@ def __call__(self, img):
class RandomVerticalFlip(object):
- """Vertically flip the given PIL.Image randomly with a probability of 0.5."""
+ """Vertically flip the given PIL Image randomly with a probability of 0.5."""
def __call__(self, img):
"""
Args:
- img (PIL.Image): Image to be flipped.
+ img (PIL Image): Image to be flipped.
Returns:
- PIL.Image: Randomly flipped image.
+ PIL Image: Randomly flipped image.
"""
if random.random() < 0.5:
return vflip(img)
@@ -799,7 +804,7 @@ def __call__(self, img):
class RandomResizedCrop(object):
- """Crop the given PIL.Image to random size and aspect ratio.
+ """Crop the given PIL Image to random size and aspect ratio.
A crop of random size of (0.08 to 1.0) of the original size and a random
aspect ratio of 3/4 to 4/3 of the original aspect ratio is made. This crop
@@ -820,7 +825,7 @@ def get_params(img):
"""Get parameters for ``crop`` for a random sized crop.
Args:
- img (PIL.Image): Image to be cropped.
+ img (PIL Image): Image to be cropped.
Returns:
tuple: params (i, j, h, w) to be passed to ``crop`` for a random
@@ -851,16 +856,19 @@ def get_params(img):
def __call__(self, img):
"""
Args:
- img (PIL.Image): Image to be flipped.
+ img (PIL Image): Image to be flipped.
Returns:
- PIL.Image: Randomly cropped and resize image.
+            PIL Image: Randomly cropped and resized image.
"""
i, j, h, w = self.get_params(img)
return resized_crop(img, i, j, h, w, self.size, self.interpolation)
class RandomSizedCrop(RandomResizedCrop):
+ """
+ Note: This transform is deprecated in favor of RandomResizedCrop.
+ """
def __init__(self, *args, **kwargs):
warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " +
"please use transforms.RandomResizedCrop instead.")
@@ -868,7 +876,7 @@ def __init__(self, *args, **kwargs):
class FiveCrop(object):
- """Crop the given PIL.Image into four corners and the central crop.abs
+    """Crop the given PIL Image into four corners and the central crop.
Note: this transform returns a tuple of images and there may be a mismatch in the number of
inputs and targets your `Dataset` returns.
@@ -892,7 +900,7 @@ def __call__(self, img):
class TenCrop(object):
- """Crop the given PIL.Image into four corners and the central crop plus the
+ """Crop the given PIL Image into four corners and the central crop plus the
flipped version of these (horizontal flipping is used by default)
Note: this transform returns a tuple of images and there may be a mismatch in the number of
@@ -972,10 +980,10 @@ def get_params(brightness, contrast, saturation, hue):
def __call__(self, img):
"""
Args:
- img (PIL.Image): Input image.
+ img (PIL Image): Input image.
Returns:
- PIL.Image: Color jittered image.
+ PIL Image: Color jittered image.
"""
transform = self.get_params(self.brightness, self.contrast,
self.saturation, self.hue)