Skip to content

Commit e489abc

Browse files
bpinayafmassa
authored and committed
VOCSegmentation, VOCDetection, linting passing, examples. (#663)
* VOC Dataset, linted, flak8 passing, samples on gist. * Double backtick on values. * Apply suggestions from code review Add suggestions from @ellisbrown, using dict of dicts instead of array index. Co-Authored-By: bpinaya <[email protected]> * Fixed errors with the new comments. * Added documentation on RST * Removed getBB, added parse_voc_xml, variable naming change. * Removed unused variable, removed VOC_CLASSES, two new gists for test.
1 parent 4cc6f45 commit e489abc

File tree

3 files changed

+259
-1
lines changed

3 files changed

+259
-1
lines changed

docs/source/datasets.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,3 +149,15 @@ Flickr
149149
.. autoclass:: Flickr30k
150150
:members: __getitem__
151151
:special-members:
152+
153+
VOC
154+
~~~~~~
155+
156+
157+
.. autoclass:: VOCSegmentation
158+
:members: __getitem__
159+
:special-members:
160+
161+
.. autoclass:: VOCDetection
162+
:members: __getitem__
163+
:special-members:

torchvision/datasets/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@
1111
from .omniglot import Omniglot
1212
from .sbu import SBU
1313
from .flickr import Flickr8k, Flickr30k
14+
from .voc import VOCSegmentation, VOCDetection
1415

1516
__all__ = ('LSUN', 'LSUNClass',
1617
'ImageFolder', 'DatasetFolder', 'FakeData',
1718
'CocoCaptions', 'CocoDetection',
1819
'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST',
1920
'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
20-
'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k')
21+
'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k',
22+
'VOCSegmentation', 'VOCDetection')

torchvision/datasets/voc.py

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
import os
2+
import sys
3+
import tarfile
4+
import collections
5+
import torch.utils.data as data
6+
if sys.version_info[0] == 2:
7+
import xml.etree.cElementTree as ET
8+
else:
9+
import xml.etree.ElementTree as ET
10+
11+
from PIL import Image
12+
from .utils import download_url, check_integrity
13+
14+
# Per-year metadata for the Pascal VOC trainval archives: download URL, the
# local archive filename, its MD5 checksum, and the directory the tarball
# unpacks into (2011 nests an extra ``TrainVal/`` level).
DATASET_YEAR_DICT = {
    '2012': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
        'filename': 'VOCtrainval_11-May-2012.tar',
        'md5': '6cd6e144f989b92b3379bac3b3de84fd',
        'base_dir': 'VOCdevkit/VOC2012'
    },
    '2011': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar',
        'filename': 'VOCtrainval_25-May-2011.tar',
        'md5': '6c3384ef61512963050cb5d687e5bf1e',
        'base_dir': 'TrainVal/VOCdevkit/VOC2011'
    },
    '2010': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar',
        'filename': 'VOCtrainval_03-May-2010.tar',
        'md5': 'da459979d0c395079b5c75ee67908abb',
        'base_dir': 'VOCdevkit/VOC2010'
    },
    '2009': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar',
        'filename': 'VOCtrainval_11-May-2009.tar',
        'md5': '59065e4b188729180974ef6572f6a212',
        'base_dir': 'VOCdevkit/VOC2009'
    },
    '2008': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar',
        # Fixed: filename previously said 'VOCtrainval_11-May-2012.tar'
        # (copied from the 2012 entry), which did not match this URL.
        'filename': 'VOCtrainval_14-Jul-2008.tar',
        'md5': '2629fa636546599198acfcfbfcf1904a',
        'base_dir': 'VOCdevkit/VOC2008'
    },
    '2007': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
        'filename': 'VOCtrainval_06-Nov-2007.tar',
        'md5': 'c52e279531787c972589f7e41ab4ae64',
        'base_dir': 'VOCdevkit/VOC2007'
    }
}
52+
53+
54+
class VOCSegmentation(data.Dataset):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.

    Each sample pairs a JPEG image with its PNG class-segmentation mask.

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years 2007 to 2012.
        image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(self,
                 root,
                 year='2012',
                 image_set='train',
                 download=False,
                 transform=None,
                 target_transform=None):
        self.root = root
        self.year = year
        year_info = DATASET_YEAR_DICT[year]
        self.url = year_info['url']
        self.filename = year_info['filename']
        self.md5 = year_info['md5']
        self.transform = transform
        self.target_transform = target_transform
        self.image_set = image_set

        voc_root = os.path.join(self.root, year_info['base_dir'])
        image_dir = os.path.join(voc_root, 'JPEGImages')
        mask_dir = os.path.join(voc_root, 'SegmentationClass')

        if download:
            download_extract(self.url, self.root, self.filename, self.md5)

        if not os.path.isdir(voc_root):
            raise RuntimeError('Dataset not found or corrupted.'
                               ' You can use download=True to download it')

        # The split file lists one image basename per line.
        splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
        split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')

        if not os.path.exists(split_f):
            raise ValueError(
                'Wrong image_set entered! Please use image_set="train" '
                'or image_set="trainval" or image_set="val"')

        with open(split_f, "r") as f:
            file_names = [line.strip() for line in f.readlines()]

        self.images = [os.path.join(image_dir, name + ".jpg") for name in file_names]
        self.masks = [os.path.join(mask_dir, name + ".png") for name in file_names]
        assert (len(self.images) == len(self.masks))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the image segmentation.
        """
        img = Image.open(self.images[index]).convert('RGB')
        # Masks are kept in their native palette mode; no RGB conversion.
        target = Image.open(self.masks[index])

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.images)
134+
135+
136+
class VOCDetection(data.Dataset):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years 2007 to 2012.
        image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, required): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(self,
                 root,
                 year='2012',
                 image_set='train',
                 download=False,
                 transform=None,
                 target_transform=None):
        self.root = root
        self.year = year
        self.url = DATASET_YEAR_DICT[year]['url']
        self.filename = DATASET_YEAR_DICT[year]['filename']
        self.md5 = DATASET_YEAR_DICT[year]['md5']
        self.transform = transform
        self.target_transform = target_transform
        self.image_set = image_set

        base_dir = DATASET_YEAR_DICT[year]['base_dir']
        voc_root = os.path.join(self.root, base_dir)
        image_dir = os.path.join(voc_root, 'JPEGImages')
        annotation_dir = os.path.join(voc_root, 'Annotations')

        if download:
            download_extract(self.url, self.root, self.filename, self.md5)

        if not os.path.isdir(voc_root):
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        splits_dir = os.path.join(voc_root, 'ImageSets/Main')

        split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')

        if not os.path.exists(split_f):
            # Note: fixed missing space ("a validimage_set") in the message.
            raise ValueError(
                'Wrong image_set entered! Please use image_set="train" '
                'or image_set="trainval" or image_set="val" or a valid '
                'image_set from the VOC ImageSets/Main folder.')

        with open(os.path.join(split_f), "r") as f:
            file_names = [x.strip() for x in f.readlines()]

        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
        self.annotations = [os.path.join(annotation_dir, x + ".xml") for x in file_names]
        assert (len(self.images) == len(self.annotations))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is a dictionary of the XML tree.
        """
        img = Image.open(self.images[index]).convert('RGB')
        target = self.parse_voc_xml(
            ET.parse(self.annotations[index]).getroot())

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.images)

    def parse_voc_xml(self, node):
        """Recursively convert an ElementTree node into a nested dict.

        A leaf element maps its tag to its stripped text.  An element with
        children maps its tag to a dict of child tags; a tag that occurs
        more than once among the children yields a list of values.
        """
        voc_dict = {}
        children = list(node)
        if children:
            # Group child results by tag so repeated tags become lists.
            def_dic = collections.defaultdict(list)
            for dc in map(self.parse_voc_xml, children):
                for ind, v in dc.items():
                    def_dic[ind].append(v)
            voc_dict = {
                node.tag:
                    {ind: v[0] if len(v) == 1 else v
                     for ind, v in def_dic.items()}
            }
        if node.text:
            text = node.text.strip()
            if not children:
                voc_dict[node.tag] = text
        return voc_dict
239+
240+
241+
def download_extract(url, root, filename, md5):
    """Download the archive at ``url`` into ``root`` and unpack it there.

    The download is checksum-verified against ``md5`` and skipped when the
    file is already present (handled by :func:`download_url`).
    """
    download_url(url, root, filename, md5)
    archive = os.path.join(root, filename)
    with tarfile.open(archive, "r") as tar:
        tar.extractall(path=root)

0 commit comments

Comments
 (0)