Merged
Changes from all commits (63 commits)
2dd7d6c
initial commit of widerface dataset
jgbradley1 Oct 21, 2020
edc3a9a
Merge branch 'master' into add-widerface-dataset
jgbradley1 Oct 21, 2020
f9e31c9
comment out old code
jgbradley1 Oct 21, 2020
e4ee45f
improve parsing of annotation files
jgbradley1 Oct 22, 2020
9eae696
code cleanup and fix docstring comments
jgbradley1 Oct 22, 2020
1fbd0b7
speed up check for quota exceeded
jgbradley1 Oct 22, 2020
75c620d
cleanup print statements
jgbradley1 Oct 22, 2020
9637659
Merge branch 'master' into add-widerface-dataset
jgbradley1 Oct 22, 2020
a82d7b5
reformat code and remove print statements
jgbradley1 Oct 23, 2020
bba0db2
minor code cleanup and reformatting
jgbradley1 Oct 23, 2020
0c33f5f
add more comments
jgbradley1 Oct 23, 2020
a7c0b30
reuse variable
jgbradley1 Oct 23, 2020
40cde34
reverse formatting changes
jgbradley1 Oct 23, 2020
bb50718
Merge branch 'master' into add-widerface-dataset
jgbradley1 Oct 23, 2020
48a620f
fix flake8 errors
jgbradley1 Oct 26, 2020
bc8c35b
add type annotations
jgbradley1 Oct 26, 2020
e0b8664
fix mypy errors
Oct 26, 2020
40e9823
Merge branch 'master' into add-widerface-dataset
Oct 26, 2020
2e73130
add a base_folder to root directory
jgbradley1 Oct 26, 2020
4de06aa
some formatting fixes
jgbradley1 Oct 27, 2020
c28966f
Merge branch 'master' into add-widerface-dataset
jgbradley1 Oct 27, 2020
70dc752
GDrive threshold does not throw 403 error
jgbradley1 Oct 29, 2020
4d2506f
testing new download logic
jgbradley1 Oct 29, 2020
6f76fd7
cleanup logic for download and integrity check
jgbradley1 Oct 29, 2020
5a55195
Merge branch 'master' into add-widerface-dataset
jgbradley1 Oct 29, 2020
9c6d02c
use a better variable name
jgbradley1 Oct 29, 2020
57f3777
Merge branch 'add-widerface-dataset' of github.com:jgbradley1/vision …
jgbradley1 Oct 29, 2020
2f76d94
format fix
jgbradley1 Oct 29, 2020
515edd4
Merge branch 'master' into add-widerface-dataset
jgbradley1 Oct 31, 2020
a7f021c
reorder list in docstring
jgbradley1 Nov 1, 2020
35b6834
initial widerface unit test - fails on MD5 check
jgbradley1 Nov 1, 2020
f0f47c1
use list of dictionaries to store dataset
jgbradley1 Nov 1, 2020
463bde0
fix docstring formatting
jgbradley1 Nov 1, 2020
6ef5379
remove unnecessary error checking
jgbradley1 Nov 1, 2020
e844078
fix type checker error
jgbradley1 Nov 1, 2020
7a36e89
Merge branch 'master' into add-widerface-dataset
jgbradley1 Nov 9, 2020
da96b84
revert typo fix
jgbradley1 Nov 9, 2020
9d3cac7
rename var constants, use file context manager, verify str args
jgbradley1 Nov 10, 2020
fb846a2
fix flake8 error
jgbradley1 Nov 10, 2020
c11858f
fix checking target_type argument values
jgbradley1 Nov 10, 2020
7a2a2e7
Merge branch 'add-widerface-dataset' into widerface-unittest
jgbradley1 Nov 10, 2020
2e45680
create uncompressed dataset folders
jgbradley1 Nov 10, 2020
c8f3f37
cleanup unit tests for widerface
jgbradley1 Nov 10, 2020
ea09dab
use correct os function
jgbradley1 Nov 10, 2020
1f0223c
add more info to docstring
jgbradley1 Nov 10, 2020
2813d4e
disable unittests for windows
jgbradley1 Nov 11, 2020
9984146
Merge branch 'master' into add-widerface-dataset
jgbradley1 Nov 11, 2020
f5981ed
fix _check_integrity logic
jgbradley1 Nov 11, 2020
a4d3051
update docstring
jgbradley1 Nov 12, 2020
7f6c327
Merge branch 'master' into add-widerface-dataset
jgbradley1 Nov 21, 2020
6513f7f
Merge branch 'master' into add-widerface-dataset
jgbradley1 Dec 1, 2020
871088d
Merge branch 'master' into add-widerface-dataset
jgbradley1 Dec 27, 2020
1beba85
Merge branch 'master' into add-widerface-dataset
jgbradley1 Jan 3, 2021
7845d45
Merge branch 'add-widerface-dataset' of github.com:jgbradley1/vision …
jgbradley1 Jan 3, 2021
2dcd8c8
remove citation
jgbradley1 Jan 3, 2021
95d6708
remove target_type option
jgbradley1 Jan 4, 2021
00448e9
fix formatting issue
jgbradley1 Jan 4, 2021
752ed0d
remove comment and add more info to docstring
jgbradley1 Jan 4, 2021
31b0122
update type annotations
jgbradley1 Jan 7, 2021
f8ef3d3
Merge branch 'master' into add-widerface-dataset
jgbradley1 Jan 7, 2021
02ae27c
restart CI jobs
jgbradley1 Jan 7, 2021
6c4a1e8
Merge branch 'master' into add-widerface-dataset
jgbradley1 Jan 8, 2021
00f24a8
Merge branch 'master' into add-widerface-dataset
vfdev-5 Jan 11, 2021
67 changes: 67 additions & 0 deletions test/fakedata_generation.py
@@ -171,6 +171,73 @@ def _make_devkit_archive(root):
yield root


@contextlib.contextmanager
def widerface_root():
"""
Generates a dataset with the following folder structure and yields the root path:
<root>
└── widerface
├── wider_face_split
├── WIDER_train
├── WIDER_val
└── WIDER_test

The dataset consists of
1 image for each dataset split (train, val, test) and annotation files
for each split
"""

def _make_image(file):
PIL.Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8)).save(file)

def _make_train_archive(root):
extracted_dir = os.path.join(root, 'WIDER_train', 'images', '0--Parade')
os.makedirs(extracted_dir)
_make_image(os.path.join(extracted_dir, '0_Parade_marchingband_1_1.jpg'))

def _make_val_archive(root):
extracted_dir = os.path.join(root, 'WIDER_val', 'images', '0--Parade')
os.makedirs(extracted_dir)
_make_image(os.path.join(extracted_dir, '0_Parade_marchingband_1_2.jpg'))

def _make_test_archive(root):
extracted_dir = os.path.join(root, 'WIDER_test', 'images', '0--Parade')
os.makedirs(extracted_dir)
_make_image(os.path.join(extracted_dir, '0_Parade_marchingband_1_3.jpg'))

def _make_annotations_archive(root):
train_bbox_contents = '0--Parade/0_Parade_marchingband_1_1.jpg\n1\n449 330 122 149 0 0 0 0 0 0\n'
val_bbox_contents = '0--Parade/0_Parade_marchingband_1_2.jpg\n1\n501 160 285 443 0 0 0 0 0 0\n'
test_filelist_contents = '0--Parade/0_Parade_marchingband_1_3.jpg\n'
extracted_dir = os.path.join(root, 'wider_face_split')
os.mkdir(extracted_dir)

# bbox training file
bbox_file = os.path.join(extracted_dir, "wider_face_train_bbx_gt.txt")
with open(bbox_file, "w") as txt_file:
txt_file.write(train_bbox_contents)

# bbox validation file
bbox_file = os.path.join(extracted_dir, "wider_face_val_bbx_gt.txt")
with open(bbox_file, "w") as txt_file:
txt_file.write(val_bbox_contents)

# test filelist file
filelist_file = os.path.join(extracted_dir, "wider_face_test_filelist.txt")
with open(filelist_file, "w") as txt_file:
txt_file.write(test_filelist_contents)

with get_tmp_dir() as root:
root_base = os.path.join(root, "widerface")
os.mkdir(root_base)
_make_train_archive(root_base)
_make_val_archive(root_base)
_make_test_archive(root_base)
_make_annotations_archive(root_base)

yield root


@contextlib.contextmanager
def cityscapes_root():

22 changes: 21 additions & 1 deletion test/test_datasets.py
@@ -9,7 +9,7 @@
import torchvision
from common_utils import get_tmp_dir
from fakedata_generation import mnist_root, cifar_root, imagenet_root, \
-    cityscapes_root, svhn_root, voc_root, ucf101_root, places365_root
+    cityscapes_root, svhn_root, voc_root, ucf101_root, places365_root, widerface_root
import xml.etree.ElementTree as ET
from urllib.request import Request, urlopen
import itertools
@@ -139,6 +139,26 @@ def test_imagenet(self, mock_verify):
dataset = torchvision.datasets.ImageNet(root, split='val')
self.generic_classification_dataset_test(dataset)

@mock.patch('torchvision.datasets.WIDERFace._check_integrity')
@unittest.skipIf(sys.platform.startswith('win'), 'temporarily disabled on Windows')  # 'win' in sys.platform would also match 'darwin'
def test_widerface(self, mock_check_integrity):
mock_check_integrity.return_value = True
with widerface_root() as root:
dataset = torchvision.datasets.WIDERFace(root, split='train')
self.assertEqual(len(dataset), 1)
img, target = dataset[0]
self.assertTrue(isinstance(img, PIL.Image.Image))

dataset = torchvision.datasets.WIDERFace(root, split='val')
self.assertEqual(len(dataset), 1)
img, target = dataset[0]
self.assertTrue(isinstance(img, PIL.Image.Image))

dataset = torchvision.datasets.WIDERFace(root, split='test')
self.assertEqual(len(dataset), 1)
img, target = dataset[0]
self.assertTrue(isinstance(img, PIL.Image.Image))

@mock.patch('torchvision.datasets.cifar.check_integrity')
@mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity')
def test_cifar10(self, mock_ext_check, mock_int_check):
6 changes: 4 additions & 2 deletions torchvision/datasets/__init__.py
@@ -16,6 +16,7 @@
from .imagenet import ImageNet
from .caltech import Caltech101, Caltech256
from .celeba import CelebA
from .widerface import WIDERFace
from .sbd import SBDataset
from .vision import VisionDataset
from .usps import USPS
@@ -31,5 +32,6 @@
'MNIST', 'KMNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k',
'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet',
-    'Caltech101', 'Caltech256', 'CelebA', 'SBDataset', 'VisionDataset',
-    'USPS', 'Kinetics400', 'HMDB51', 'UCF101', 'Places365')
+    'Caltech101', 'Caltech256', 'CelebA', 'WIDERFace', 'SBDataset',
+    'VisionDataset', 'USPS', 'Kinetics400', 'HMDB51', 'UCF101',
+    'Places365')
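
With the export above, the new class is reachable from the top-level datasets namespace; a quick sanity check (a sketch, assuming this branch is installed):

import torchvision.datasets as datasets
print(datasets.WIDERFace)  # <class 'torchvision.datasets.widerface.WIDERFace'>
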
183 changes: 183 additions & 0 deletions torchvision/datasets/widerface.py
@@ -0,0 +1,183 @@
from PIL import Image
import os
from os.path import abspath, expanduser
import torch
from typing import Any, Callable, List, Dict, Optional, Tuple, Union
from .utils import check_integrity, download_file_from_google_drive, \
download_and_extract_archive, extract_archive, verify_str_arg
from .vision import VisionDataset


class WIDERFace(VisionDataset):
"""`WIDERFace <http://shuoyang1213.me/WIDERFACE/>`_ Dataset.

Args:
root (string): Root directory where images and annotations are downloaded to.
Expects the following folder structure if download=False:
<root>
└── widerface
├── wider_face_split ('wider_face_split.zip' if compressed)
├── WIDER_train ('WIDER_train.zip' if compressed)
├── WIDER_val ('WIDER_val.zip' if compressed)
└── WIDER_test ('WIDER_test.zip' if compressed)
split (string): The dataset split to use. One of {``train``, ``val``, ``test``}.
Defaults to ``train``.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g., ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in the root directory. If the dataset is already downloaded, it is not
downloaded again.
"""

BASE_FOLDER = "widerface"
FILE_LIST = [
# File ID MD5 Hash Filename
("0B6eKvaijfFUDQUUwd21EckhUbWs", "3fedf70df600953d25982bcd13d91ba2", "WIDER_train.zip"),
("0B6eKvaijfFUDd3dIRmpvSk8tLUk", "dfa7d7e790efa35df3788964cf0bbaea", "WIDER_val.zip"),
("0B6eKvaijfFUDbW4tdGpaYjgzZkU", "e5d8f4248ed24c334bbd12f49c29dd40", "WIDER_test.zip")
]
ANNOTATIONS_FILE = (
"http://mmlab.ie.cuhk.edu.hk/projects/WIDERFace/support/bbx_annotation/wider_face_split.zip",
"0e3767bcf0e326556d407bf5bff5d27c",
"wider_face_split.zip"
)

def __init__(
self,
root: str,
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(WIDERFace, self).__init__(root=os.path.join(root, self.BASE_FOLDER),
transform=transform,
target_transform=target_transform)
# check arguments
self.split = verify_str_arg(split, "split", ("train", "val", "test"))

if download:
self.download()

if not self._check_integrity():
raise RuntimeError("Dataset not found or corrupted. " +
"You can use download=True to download and prepare it")

self.img_info: List[Dict[str, Union[str, Dict[str, torch.Tensor]]]] = []
if self.split in ("train", "val"):
self.parse_train_val_annotations_file()
else:
self.parse_test_annotations_file()

def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index

Returns:
tuple: (image, target) where target is a dict of annotations for all faces in the image.
target=None for the test split.
"""

# stay consistent with other datasets and return a PIL Image
img = Image.open(self.img_info[index]["img_path"])

if self.transform is not None:
img = self.transform(img)

target = None if self.split == "test" else self.img_info[index]["annotations"]
if self.target_transform is not None:
target = self.target_transform(target)

return img, target

def __len__(self) -> int:
return len(self.img_info)

def extra_repr(self) -> str:
lines = ["Split: {split}"]
return '\n'.join(lines).format(**self.__dict__)

def parse_train_val_annotations_file(self) -> None:
filename = "wider_face_train_bbx_gt.txt" if self.split == "train" else "wider_face_val_bbx_gt.txt"
filepath = os.path.join(self.root, "wider_face_split", filename)

with open(filepath, "r") as f:
lines = f.readlines()
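        # Layout of each annotation entry, as consumed by the state machine
        # below (a summary of the parsing logic, not the official devkit text):
        #   <image file path, relative to WIDER_<split>/images>
        #   <number of bounding boxes>
        #   <10 integers per box: x, y, w, h, then six attribute flags>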
file_name_line, num_boxes_line, box_annotation_line = True, False, False
num_boxes, box_counter = 0, 0
labels = []
for line in lines:
line = line.rstrip()
if file_name_line:
img_path = os.path.join(self.root, "WIDER_" + self.split, "images", line)
img_path = abspath(expanduser(img_path))
file_name_line = False
num_boxes_line = True
elif num_boxes_line:
num_boxes = int(line)
num_boxes_line = False
box_annotation_line = True
elif box_annotation_line:
box_counter += 1
line_split = line.split(" ")
line_values = [int(x) for x in line_split]
labels.append(line_values)
if box_counter >= num_boxes:
box_annotation_line = False
file_name_line = True
labels_tensor = torch.tensor(labels)
self.img_info.append({
"img_path": img_path,
"annotations": {"bbox": labels_tensor[:, 0:4], # x, y, width, height
"blur": labels_tensor[:, 4],
"expression": labels_tensor[:, 5],
"illumination": labels_tensor[:, 6],
"occlusion": labels_tensor[:, 7],
"pose": labels_tensor[:, 8],
"invalid": labels_tensor[:, 9]}
})
box_counter = 0
labels.clear()
else:
raise RuntimeError("Error parsing annotation file {}".format(filepath))

def parse_test_annotations_file(self) -> None:
filepath = os.path.join(self.root, "wider_face_split", "wider_face_test_filelist.txt")
filepath = abspath(expanduser(filepath))
with open(filepath, "r") as f:
lines = f.readlines()
for line in lines:
line = line.rstrip()
img_path = os.path.join(self.root, "WIDER_test", "images", line)
img_path = abspath(expanduser(img_path))
self.img_info.append({"img_path": img_path})

def _check_integrity(self) -> bool:
# Allow original archive to be deleted (zip). Only need the extracted images
all_files = self.FILE_LIST.copy()
all_files.append(self.ANNOTATIONS_FILE)
for (_, md5, filename) in all_files:
file, ext = os.path.splitext(filename)
extracted_dir = os.path.join(self.root, file)
if not os.path.exists(extracted_dir):
return False
return True

def download(self) -> None:
if self._check_integrity():
print('Files already downloaded and verified')
return

# download and extract image data
for (file_id, md5, filename) in self.FILE_LIST:
download_file_from_google_drive(file_id, self.root, filename, md5)
filepath = os.path.join(self.root, filename)
extract_archive(filepath)

# download and extract annotation files
download_and_extract_archive(url=self.ANNOTATIONS_FILE[0],
download_root=self.root,
md5=self.ANNOTATIONS_FILE[1])
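
Usage note: a minimal sketch of the new dataset API, assuming this branch is installed. The "data" root is arbitrary, download=True is subject to the Google Drive quota behavior discussed in the commit history, and the target keys come from the annotations dict built in parse_train_val_annotations_file:

import torchvision

dataset = torchvision.datasets.WIDERFace(root="data", split="train", download=True)
print(len(dataset))                       # number of annotated images in the split

img, target = dataset[0]                  # PIL image and a dict of annotation tensors
print(img.size)
print(target["bbox"].shape)               # (num_faces, 4): x, y, width, height
print(target["blur"], target["invalid"])

# The test split ships without annotations, so target is None.
test_set = torchvision.datasets.WIDERFace(root="data", split="test", download=True)
img, target = test_set[0]
assert target is None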