
Commit 25257fa

Add initial structures folder for bounding boxes
1 parent 627dcfd · commit 25257fa

4 files changed: 194 additions & 6 deletions

torchvision/layers/nms.py

Lines changed: 2 additions & 2 deletions
@@ -2,5 +2,5 @@
 from torchvision import _C

 nms = _C.nms
-nms.__doc__ = """
-This function performs Non-maximum suppression"""
+# nms.__doc__ = """
+# This function performs Non-maximum suppression"""
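With the docstring commented out, a minimal usage sketch for orientation. This assumes the compiled _C.nms binding takes (boxes, scores, iou_threshold) with boxes in xyxy layout and returns the indices of the kept boxes, as the later torchvision.ops.nms does; the exact signature is not shown in this diff, and the import path is likewise an assumption.

    import torch
    from torchvision.layers.nms import nms  # assumed import path for this commit

    # Two heavily overlapping boxes plus one far away, in (xmin, ymin, xmax, ymax).
    boxes = torch.tensor([[0., 0., 10., 10.],
                          [1., 1., 11., 11.],
                          [50., 50., 60., 60.]])
    scores = torch.tensor([0.9, 0.8, 0.7])

    # Assumed signature: nms(boxes, scores, iou_threshold) -> indices of kept boxes.
    keep = nms(boxes, scores, 0.5)
    print(keep)  # expected: boxes 0 and 2 kept; box 1 suppressed by box 0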

torchvision/structures/__init__.py

Whitespace-only changes.
torchvision/structures/bounding_box.py

Lines changed: 178 additions & 0 deletions

@@ -0,0 +1,178 @@
import torch


# transpose
FLIP_LEFT_RIGHT = 0
FLIP_TOP_BOTTOM = 1


class BBox(object):
    """
    This class represents a set of bounding boxes.
    The bounding boxes are represented as a Nx4 Tensor.
    In order to uniquely determine the bounding boxes with respect
    to an image, we also store the corresponding image dimensions.
    They can contain extra information that is specific to each bounding box, such as
    labels.
    """
    def __init__(self, bbox, image_size, mode='xyxy'):
        device = bbox.device if isinstance(bbox, torch.Tensor) else torch.device('cpu')
        bbox = torch.tensor(bbox, dtype=torch.float32, device=device)
        if bbox.ndimension() != 2:
            raise ValueError(
                "bbox should have 2 dimensions, got {}".format(bbox.ndimension()))
        if bbox.size(-1) != 4:
            raise ValueError(
                "last dimension of bbox should have a "
                "size of 4, got {}".format(bbox.size(-1)))
        if mode not in ('xyxy', 'xywh'):
            raise ValueError(
                "mode should be 'xyxy' or 'xywh'")

        self.bbox = bbox
        self.size = image_size  # (image_width, image_height)
        self.mode = mode
        self.extra_fields = {}

    def add_field(self, field, field_data):
        self.extra_fields[field] = field_data

    def get_field(self, field):
        return self.extra_fields[field]

    def fields(self):
        return list(self.extra_fields.keys())

    def _copy_extra_fields(self, bbox):
        for k, v in bbox.extra_fields.items():
            self.extra_fields[k] = v

    def convert(self, mode):
        if mode not in ('xyxy', 'xywh'):
            raise ValueError(
                "mode should be 'xyxy' or 'xywh'")
        if mode == self.mode:
            return self
        # we only have two modes, so don't need to check
        # self.mode
        xmin, ymin, xmax, ymax = self._split()
        if mode == 'xyxy':
            bbox = torch.cat(
                (xmin, ymin, xmax, ymax), dim=-1)
            bbox = BBox(bbox, self.size, mode=mode)
        else:
            bbox = torch.cat(
                (xmin, ymin, xmax - xmin, ymax - ymin), dim=-1)
            bbox = BBox(bbox, self.size, mode=mode)
        bbox._copy_extra_fields(self)
        return bbox

    def _split(self):
        if self.mode == 'xyxy':
            xmin, ymin, xmax, ymax = self.bbox.split(1, dim=-1)
            return xmin, ymin, xmax, ymax
        elif self.mode == 'xywh':
            xmin, ymin, w, h = self.bbox.split(1, dim=-1)
            return xmin, ymin, xmin + w, ymin + h
        else:
            raise RuntimeError('Should not be here')

    def resize(self, size, *args, **kwargs):
        """
        Returns a resized copy of this bounding box

        :param size: The requested size in pixels, as a 2-tuple:
            (width, height).
        """

        ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size))
        if ratios[0] == ratios[1]:
            ratio = ratios[0]
            scaled_box = self.bbox * ratio
            bbox = BBox(scaled_box, size, mode=self.mode)
            bbox._copy_extra_fields(self)
            return bbox

        ratio_width, ratio_height = ratios
        xmin, ymin, xmax, ymax = self._split()
        scaled_xmin = xmin * ratio_width
        scaled_xmax = xmax * ratio_width
        scaled_ymin = ymin * ratio_height
        scaled_ymax = ymax * ratio_height
        scaled_box = torch.cat(
            (scaled_xmin, scaled_ymin, scaled_xmax, scaled_ymax), dim=-1)
        bbox = BBox(scaled_box, size, mode='xyxy')
        bbox._copy_extra_fields(self)
        return bbox.convert(self.mode)

    def transpose(self, method):
        """
        Transpose bounding box (flip or rotate in 90 degree steps)
        :param method: One of :py:attr:`PIL.Image.FLIP_LEFT_RIGHT`,
          :py:attr:`PIL.Image.FLIP_TOP_BOTTOM`, :py:attr:`PIL.Image.ROTATE_90`,
          :py:attr:`PIL.Image.ROTATE_180`, :py:attr:`PIL.Image.ROTATE_270`,
          :py:attr:`PIL.Image.TRANSPOSE` or :py:attr:`PIL.Image.TRANSVERSE`.
        """
        if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
            raise NotImplementedError(
                "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented")
        image_width, image_height = self.size
        xmin, ymin, xmax, ymax = self._split()
        if method == FLIP_LEFT_RIGHT:
            transposed_xmin = image_width - xmax
            transposed_xmax = image_width - xmin
            transposed_ymin = ymin
            transposed_ymax = ymax
        elif method == FLIP_TOP_BOTTOM:
            transposed_xmin = xmin
            transposed_xmax = xmax
            transposed_ymin = image_height - ymax
            transposed_ymax = image_height - ymin

        transposed_boxes = torch.cat(
            (transposed_xmin, transposed_ymin, transposed_xmax, transposed_ymax), dim=-1)
        bbox = BBox(transposed_boxes, self.size, mode='xyxy')
        bbox._copy_extra_fields(self)
        return bbox.convert(self.mode)

    def crop(self, box):
        """
        Crops a rectangular region from this bounding box. The box is a
        4-tuple defining the left, upper, right, and lower pixel
        coordinate.
        """
        xmin, ymin, xmax, ymax = self._split()
        w, h = box[2] - box[0], box[3] - box[1]
        cropped_xmin = (xmin - box[0]).clamp(min=0, max=w)
        cropped_ymin = (ymin - box[1]).clamp(min=0, max=h)
        cropped_xmax = (xmax - box[0]).clamp(min=0, max=w)
        cropped_ymax = (ymax - box[1]).clamp(min=0, max=h)

        # TODO should I filter empty boxes here?
        if False:
            is_empty = (cropped_xmin == cropped_xmax) | (cropped_ymin == cropped_ymax)

        cropped_box = torch.cat(
            (cropped_xmin, cropped_ymin, cropped_xmax, cropped_ymax), dim=-1)
        bbox = BBox(cropped_box, (w, h), mode='xyxy')
        bbox._copy_extra_fields(self)
        return bbox.convert(self.mode)

    def __repr__(self):
        s = self.__class__.__name__ + '('
        s += 'num_boxes={}, '.format(self.bbox.size(0))
        s += 'image_width={}, '.format(self.size[0])
        s += 'image_height={}, '.format(self.size[1])
        s += 'mode={})'.format(self.mode)
        return s


if __name__ == '__main__':
    bbox = BBox([[0, 0, 10, 10], [0, 0, 5, 5]], (10, 10))
    s_bbox = bbox.resize((5, 5))
    print(s_bbox)
    print(s_bbox.bbox)

    t_bbox = bbox.transpose(0)
    print(t_bbox)
    print(t_bbox.bbox)
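Beyond the built-in __main__ demo, a short sketch of the rest of the new container's API, exercising add_field(), convert(), and crop(); the box values are illustrative:

    from torchvision.structures.bounding_box import BBox

    bbox = BBox([[10, 10, 50, 50], [0, 0, 20, 20]], image_size=(100, 100))
    bbox.add_field('labels', [1, 2])

    # xyxy -> xywh: (xmin, ymin, xmax, ymax) becomes (xmin, ymin, width, height).
    xywh = bbox.convert('xywh')
    print(xywh.bbox)  # tensor([[10., 10., 40., 40.], [ 0.,  0., 20., 20.]])

    # Crop to the region (5, 5, 60, 60): coordinates are shifted by the crop
    # origin and clamped to the 55x55 crop, and extra fields carry over.
    cropped = bbox.crop((5, 5, 60, 60))
    print(cropped)                      # BBox(num_boxes=2, image_width=55, image_height=55, mode=xyxy)
    print(cropped.get_field('labels'))  # [1, 2]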

torchvision/transforms/functional.py

Lines changed: 14 additions & 4 deletions
@@ -13,6 +13,8 @@
 import collections
 import warnings

+from ..structures.bounding_box import BBox
+

 def _is_pil_image(img):
     if accimage is not None:
@@ -167,7 +169,7 @@ def normalize(tensor, mean, std):
     return tensor


-def resize(img, size, interpolation=Image.BILINEAR):
+def resize(img, size, interpolation=Image.BILINEAR, max_size=None):
     """Resize the input PIL Image to the given size.

     Args:
@@ -183,15 +185,23 @@ def resize(img, size, interpolation=Image.BILINEAR):
     Returns:
         PIL Image: Resized image.
     """
-    if not _is_pil_image(img):
+    if not (_is_pil_image(img) or isinstance(img, BBox)):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
     if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)):
         raise TypeError('Got inappropriate size arg: {}'.format(size))

     if isinstance(size, int):
         w, h = img.size
+
+        if max_size is not None:
+            min_original_size = float(min((w, h)))
+            max_original_size = float(max((w, h)))
+            if max_original_size / min_original_size * size > max_size:
+                size = int(round(max_size * min_original_size / max_original_size))
+
         if (w <= h and w == size) or (h <= w and h == size):
             return img
+
         if w < h:
             ow = size
             oh = int(size * h / w)
@@ -291,7 +301,7 @@ def crop(img, i, j, h, w):
     Returns:
         PIL Image: Cropped image.
     """
-    if not _is_pil_image(img):
+    if not (_is_pil_image(img) or isinstance(img, BBox)):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

     return img.crop((j, i, j + w, i + h))
@@ -339,7 +349,7 @@ def hflip(img):
     Returns:
         PIL Image: Horizontally flipped image.
     """
-    if not _is_pil_image(img):
+    if not (_is_pil_image(img) or isinstance(img, BBox)):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

     return img.transpose(Image.FLIP_LEFT_RIGHT)
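To make the new max_size behavior in resize() concrete, here is the capping arithmetic from the hunk above worked through standalone (the numbers are illustrative):

    # With a 400x1000 image, a target shorter side of 800 would make the longer
    # side 1000 / 400 * 800 = 2000; with max_size=1333 the target is shrunk so
    # the longer side lands at (about) max_size instead.
    w, h = 400, 1000
    size, max_size = 800, 1333

    min_original_size = float(min((w, h)))  # 400.0
    max_original_size = float(max((w, h)))  # 1000.0
    if max_original_size / min_original_size * size > max_size:
        size = int(round(max_size * min_original_size / max_original_size))

    print(size)  # 533 -> the image would be resized to roughly 533 x 1332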
