24
24
25
25
26
26
class FlowDataset (ABC , VisionDataset ):
27
- # Some datasets like Kitti have a built-in valid mask , indicating which flow values are valid
28
- # For those we return (img1, img2, flow, valid ), and for the rest we return (img1, img2, flow),
29
- # and it's up to whatever consumes the dataset to decide what `valid` should be.
27
+ # Some datasets like Kitti have a built-in valid_flow_mask , indicating which flow values are valid
28
+ # For those we return (img1, img2, flow, valid_flow_mask ), and for the rest we return (img1, img2, flow),
29
+ # and it's up to whatever consumes the dataset to decide what valid_flow_mask should be.
30
30
_has_builtin_flow_mask = False
31
31
32
32
def __init__ (self , root , transforms = None ):
@@ -38,11 +38,14 @@ def __init__(self, root, transforms=None):
38
38
self ._image_list = []
39
39
40
40
def _read_img (self , file_name ):
41
- return Image .open (file_name )
41
+ img = Image .open (file_name )
42
+ if img .mode != "RGB" :
43
+ img = img .convert ("RGB" )
44
+ return img
42
45
43
46
@abstractmethod
44
47
def _read_flow (self , file_name ):
45
- # Return the flow or a tuple with the flow and the valid mask if _has_builtin_flow_mask is True
48
+ # Return the flow or a tuple with the flow and the valid_flow_mask if _has_builtin_flow_mask is True
46
49
pass
47
50
48
51
def __getitem__ (self , index ):
@@ -53,23 +56,27 @@ def __getitem__(self, index):
53
56
if self ._flow_list : # it will be empty for some dataset when split="test"
54
57
flow = self ._read_flow (self ._flow_list [index ])
55
58
if self ._has_builtin_flow_mask :
56
- flow , valid = flow
59
+ flow , valid_flow_mask = flow
57
60
else :
58
- valid = None
61
+ valid_flow_mask = None
59
62
else :
60
- flow = valid = None
63
+ flow = valid_flow_mask = None
61
64
62
65
if self .transforms is not None :
63
- img1 , img2 , flow , valid = self .transforms (img1 , img2 , flow , valid )
66
+ img1 , img2 , flow , valid_flow_mask = self .transforms (img1 , img2 , flow , valid_flow_mask )
64
67
65
- if self ._has_builtin_flow_mask :
66
- return img1 , img2 , flow , valid
68
+ if self ._has_builtin_flow_mask or valid_flow_mask is not None :
69
+ # The `or valid_flow_mask is not None` part is here because the mask can be generated within a transform
70
+ return img1 , img2 , flow , valid_flow_mask
67
71
else :
68
72
return img1 , img2 , flow
69
73
70
74
def __len__ (self ):
71
75
return len (self ._image_list )
72
76
77
+ def __rmul__ (self , v ):
78
+ return torch .utils .data .ConcatDataset ([self ] * v )
79
+
73
80
74
81
class Sintel (FlowDataset ):
75
82
"""`Sintel <http://sintel.is.tue.mpg.de/>`_ Dataset for optical flow.
@@ -107,8 +114,8 @@ class Sintel(FlowDataset):
107
114
pass_name (string, optional): The pass to use, either "clean" (default), "final", or "both". See link above for
108
115
details on the different passes.
109
116
transforms (callable, optional): A function/transform that takes in
110
- ``img1, img2, flow, valid `` and returns a transformed version.
111
- ``valid `` is expected for consistency with other datasets which
117
+ ``img1, img2, flow, valid_flow_mask `` and returns a transformed version.
118
+ ``valid_flow_mask `` is expected for consistency with other datasets which
112
119
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
113
120
"""
114
121
@@ -140,9 +147,11 @@ def __getitem__(self, index):
140
147
index(int): The index of the example to retrieve
141
148
142
149
Returns:
143
- tuple: If ``split="train"`` a 3-tuple with ``(img1, img2, flow)``.
144
- The flow is a numpy array of shape (2, H, W) and the images are PIL images. If `split="test"`, a
145
- 3-tuple with ``(img1, img2, None)`` is returned.
150
+ tuple: A 3-tuple with ``(img1, img2, flow)``.
151
+ The flow is a numpy array of shape (2, H, W) and the images are PIL images.
152
+ ``flow`` is None if ``split="test"``.
153
+ If a valid flow mask is generated within the ``transforms`` parameter,
154
+ a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
146
155
"""
147
156
return super ().__getitem__ (index )
148
157
@@ -167,7 +176,7 @@ class KittiFlow(FlowDataset):
167
176
root (string): Root directory of the KittiFlow Dataset.
168
177
split (string, optional): The dataset split, either "train" (default) or "test"
169
178
transforms (callable, optional): A function/transform that takes in
170
- ``img1, img2, flow, valid `` and returns a transformed version.
179
+ ``img1, img2, flow, valid_flow_mask `` and returns a transformed version.
171
180
"""
172
181
173
182
_has_builtin_flow_mask = True
@@ -199,11 +208,11 @@ def __getitem__(self, index):
199
208
index(int): The index of the example to retrieve
200
209
201
210
Returns:
202
- tuple: If ``split="train"`` a 4-tuple with ``(img1, img2, flow,
203
- valid)`` where ``valid `` is a numpy boolean mask of shape (H, W)
211
+ tuple: A 4-tuple with ``(img1, img2, flow, valid_flow_mask)``
212
+ where ``valid_flow_mask `` is a numpy boolean mask of shape (H, W)
204
213
indicating which flow values are valid. The flow is a numpy array of
205
- shape (2, H, W) and the images are PIL images. If `split="test"`, a
206
- 4-tuple with ``(img1, img2, None, None)`` is returned .
214
+ shape (2, H, W) and the images are PIL images. ``flow`` and ``valid_flow_mask`` are None if
215
+ ``split="test"`` .
207
216
"""
208
217
return super ().__getitem__ (index )
209
218
@@ -232,8 +241,8 @@ class FlyingChairs(FlowDataset):
232
241
root (string): Root directory of the FlyingChairs Dataset.
233
242
split (string, optional): The dataset split, either "train" (default) or "val"
234
243
transforms (callable, optional): A function/transform that takes in
235
- ``img1, img2, flow, valid `` and returns a transformed version.
236
- ``valid `` is expected for consistency with other datasets which
244
+ ``img1, img2, flow, valid_flow_mask `` and returns a transformed version.
245
+ ``valid_flow_mask `` is expected for consistency with other datasets which
237
246
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
238
247
"""
239
248
@@ -269,6 +278,9 @@ def __getitem__(self, index):
269
278
Returns:
270
279
tuple: A 3-tuple with ``(img1, img2, flow)``.
271
280
The flow is a numpy array of shape (2, H, W) and the images are PIL images.
281
+ ``flow`` is None if ``split="val"``.
282
+ If a valid flow mask is generated within the ``transforms`` parameter,
283
+ a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
272
284
"""
273
285
return super ().__getitem__ (index )
274
286
@@ -300,8 +312,8 @@ class FlyingThings3D(FlowDataset):
300
312
details on the different passes.
301
313
camera (string, optional): Which camera to return images from. Can be either "left" (default) or "right" or "both".
302
314
transforms (callable, optional): A function/transform that takes in
303
- ``img1, img2, flow, valid `` and returns a transformed version.
304
- ``valid `` is expected for consistency with other datasets which
315
+ ``img1, img2, flow, valid_flow_mask `` and returns a transformed version.
316
+ ``valid_flow_mask `` is expected for consistency with other datasets which
305
317
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
306
318
"""
307
319
@@ -357,6 +369,9 @@ def __getitem__(self, index):
357
369
Returns:
358
370
tuple: A 3-tuple with ``(img1, img2, flow)``.
359
371
The flow is a numpy array of shape (2, H, W) and the images are PIL images.
372
+ ``flow`` is None if ``split="test"``.
373
+ If a valid flow mask is generated within the ``transforms`` parameter,
374
+ a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
360
375
"""
361
376
return super ().__getitem__ (index )
362
377
@@ -382,7 +397,7 @@ class HD1K(FlowDataset):
382
397
root (string): Root directory of the HD1K Dataset.
383
398
split (string, optional): The dataset split, either "train" (default) or "test"
384
399
transforms (callable, optional): A function/transform that takes in
385
- ``img1, img2, flow, valid `` and returns a transformed version.
400
+ ``img1, img2, flow, valid_flow_mask `` and returns a transformed version.
386
401
"""
387
402
388
403
_has_builtin_flow_mask = True
@@ -422,11 +437,11 @@ def __getitem__(self, index):
422
437
index(int): The index of the example to retrieve
423
438
424
439
Returns:
425
- tuple: If ``split="train"`` a 4-tuple with ``(img1, img2, flow,
426
- valid)`` where ``valid`` is a numpy boolean mask of shape (H, W)
440
+ tuple: A 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` where ``valid_flow_mask``
441
+ is a numpy boolean mask of shape (H, W)
427
442
indicating which flow values are valid. The flow is a numpy array of
428
- shape (2, H, W) and the images are PIL images. If `split="test"`, a
429
- 4-tuple with ``(img1, img2, None, None)`` is returned .
443
+ shape (2, H, W) and the images are PIL images. ``flow`` and ``valid_flow_mask`` are None if
444
+ ``split="test"`` .
430
445
"""
431
446
return super ().__getitem__ (index )
432
447
@@ -451,11 +466,12 @@ def _read_flo(file_name):
451
466
def _read_16bits_png_with_flow_and_valid_mask (file_name ):
452
467
453
468
flow_and_valid = _read_png_16 (file_name ).to (torch .float32 )
454
- flow , valid = flow_and_valid [:2 , :, :], flow_and_valid [2 , :, :]
469
+ flow , valid_flow_mask = flow_and_valid [:2 , :, :], flow_and_valid [2 , :, :]
455
470
flow = (flow - 2 ** 15 ) / 64 # This conversion is explained somewhere on the kitti archive
471
+ valid_flow_mask = valid_flow_mask .bool ()
456
472
457
473
# For consistency with other datasets, we convert to numpy
458
- return flow .numpy (), valid .numpy ()
474
+ return flow .numpy (), valid_flow_mask .numpy ()
459
475
460
476
461
477
def _read_pfm (file_name ):
0 commit comments