Skip to content

Commit 6a89b6d

Browse files
authored
Merge branch 'main' into raft_proto
2 parents aa66031 + a57e45c commit 6a89b6d

File tree

2 files changed

+57
-41
lines changed

2 files changed

+57
-41
lines changed

test/test_prototype_models.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -217,27 +217,27 @@ def test_smoke():
217217
# With this filter, every unexpected warning will be turned into an error
218218
@pytest.mark.filterwarnings("error")
219219
class TestHandleLegacyInterface:
220-
class TestWeights(WeightsEnum):
220+
class ModelWeights(WeightsEnum):
221221
Sentinel = Weights(url="https://pytorch.org", transforms=lambda x: x, meta=dict())
222222

223223
@pytest.mark.parametrize(
224224
"kwargs",
225225
[
226226
pytest.param(dict(), id="empty"),
227227
pytest.param(dict(weights=None), id="None"),
228-
pytest.param(dict(weights=TestWeights.Sentinel), id="Weights"),
228+
pytest.param(dict(weights=ModelWeights.Sentinel), id="Weights"),
229229
],
230230
)
231231
def test_no_warn(self, kwargs):
232-
@handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel))
232+
@handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
233233
def builder(*, weights=None):
234234
pass
235235

236236
builder(**kwargs)
237237

238238
@pytest.mark.parametrize("pretrained", (True, False))
239239
def test_pretrained_pos(self, pretrained):
240-
@handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel))
240+
@handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
241241
def builder(*, weights=None):
242242
pass
243243

@@ -246,7 +246,7 @@ def builder(*, weights=None):
246246

247247
@pytest.mark.parametrize("pretrained", (True, False))
248248
def test_pretrained_kw(self, pretrained):
249-
@handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel))
249+
@handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
250250
def builder(*, weights=None):
251251
pass
252252

@@ -256,12 +256,12 @@ def builder(*, weights=None):
256256
@pytest.mark.parametrize("pretrained", (True, False))
257257
@pytest.mark.parametrize("positional", (True, False))
258258
def test_equivalent_behavior_weights(self, pretrained, positional):
259-
@handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel))
259+
@handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
260260
def builder(*, weights=None):
261261
pass
262262

263263
args, kwargs = ((pretrained,), dict()) if positional else ((), dict(pretrained=pretrained))
264-
with pytest.warns(UserWarning, match=f"weights={self.TestWeights.Sentinel if pretrained else None}"):
264+
with pytest.warns(UserWarning, match=f"weights={self.ModelWeights.Sentinel if pretrained else None}"):
265265
builder(*args, **kwargs)
266266

267267
def test_multi_params(self):
@@ -270,7 +270,7 @@ def test_multi_params(self):
270270

271271
@handle_legacy_interface(
272272
**{
273-
weights_param: (pretrained_param, self.TestWeights.Sentinel)
273+
weights_param: (pretrained_param, self.ModelWeights.Sentinel)
274274
for weights_param, pretrained_param in zip(weights_params, pretrained_params)
275275
}
276276
)
@@ -285,7 +285,7 @@ def test_default_callable(self):
285285
@handle_legacy_interface(
286286
weights=(
287287
"pretrained",
288-
lambda kwargs: self.TestWeights.Sentinel if kwargs["flag"] else None,
288+
lambda kwargs: self.ModelWeights.Sentinel if kwargs["flag"] else None,
289289
)
290290
)
291291
def builder(*, weights=None, flag):

torchvision/datasets/_optical_flow.py

Lines changed: 48 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424

2525

2626
class FlowDataset(ABC, VisionDataset):
27-
# Some datasets like Kitti have a built-in valid mask, indicating which flow values are valid
28-
# For those we return (img1, img2, flow, valid), and for the rest we return (img1, img2, flow),
29-
# and it's up to whatever consumes the dataset to decide what `valid` should be.
27+
# Some datasets like Kitti have a built-in valid_flow_mask, indicating which flow values are valid.
28+
# For those we return (img1, img2, flow, valid_flow_mask), and for the rest we return (img1, img2, flow),
29+
# and it's up to whatever consumes the dataset to decide what valid_flow_mask should be.
3030
_has_builtin_flow_mask = False
3131

3232
def __init__(self, root, transforms=None):
@@ -38,11 +38,14 @@ def __init__(self, root, transforms=None):
3838
self._image_list = []
3939

4040
def _read_img(self, file_name):
41-
return Image.open(file_name)
41+
img = Image.open(file_name)
42+
if img.mode != "RGB":
43+
img = img.convert("RGB")
44+
return img
4245

4346
@abstractmethod
4447
def _read_flow(self, file_name):
45-
# Return the flow or a tuple with the flow and the valid mask if _has_builtin_flow_mask is True
48+
# Return the flow or a tuple with the flow and the valid_flow_mask if _has_builtin_flow_mask is True
4649
pass
4750

4851
def __getitem__(self, index):
@@ -53,23 +56,27 @@ def __getitem__(self, index):
5356
if self._flow_list: # it will be empty for some datasets when split="test"
5457
flow = self._read_flow(self._flow_list[index])
5558
if self._has_builtin_flow_mask:
56-
flow, valid = flow
59+
flow, valid_flow_mask = flow
5760
else:
58-
valid = None
61+
valid_flow_mask = None
5962
else:
60-
flow = valid = None
63+
flow = valid_flow_mask = None
6164

6265
if self.transforms is not None:
63-
img1, img2, flow, valid = self.transforms(img1, img2, flow, valid)
66+
img1, img2, flow, valid_flow_mask = self.transforms(img1, img2, flow, valid_flow_mask)
6467

65-
if self._has_builtin_flow_mask:
66-
return img1, img2, flow, valid
68+
if self._has_builtin_flow_mask or valid_flow_mask is not None:
69+
# The `or valid_flow_mask is not None` part is here because the mask can be generated within a transform
70+
return img1, img2, flow, valid_flow_mask
6771
else:
6872
return img1, img2, flow
6973

7074
def __len__(self):
7175
return len(self._image_list)
7276

77+
def __rmul__(self, v):
78+
return torch.utils.data.ConcatDataset([self] * v)
79+
7380

7481
class Sintel(FlowDataset):
7582
"""`Sintel <http://sintel.is.tue.mpg.de/>`_ Dataset for optical flow.
@@ -107,8 +114,8 @@ class Sintel(FlowDataset):
107114
pass_name (string, optional): The pass to use, either "clean" (default), "final", or "both". See link above for
108115
details on the different passes.
109116
transforms (callable, optional): A function/transform that takes in
110-
``img1, img2, flow, valid`` and returns a transformed version.
111-
``valid`` is expected for consistency with other datasets which
117+
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
118+
``valid_flow_mask`` is expected for consistency with other datasets which
112119
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
113120
"""
114121

@@ -140,9 +147,11 @@ def __getitem__(self, index):
140147
index (int): The index of the example to retrieve
141148
142149
Returns:
143-
tuple: If ``split="train"`` a 3-tuple with ``(img1, img2, flow)``.
144-
The flow is a numpy array of shape (2, H, W) and the images are PIL images. If `split="test"`, a
145-
3-tuple with ``(img1, img2, None)`` is returned.
150+
tuple: A 3-tuple with ``(img1, img2, flow)``.
151+
The flow is a numpy array of shape (2, H, W) and the images are PIL images.
152+
``flow`` is None if ``split="test"``.
153+
If a valid flow mask is generated within the ``transforms`` parameter,
154+
a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
146155
"""
147156
return super().__getitem__(index)
148157

@@ -167,7 +176,7 @@ class KittiFlow(FlowDataset):
167176
root (string): Root directory of the KittiFlow Dataset.
168177
split (string, optional): The dataset split, either "train" (default) or "test"
169178
transforms (callable, optional): A function/transform that takes in
170-
``img1, img2, flow, valid`` and returns a transformed version.
179+
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
171180
"""
172181

173182
_has_builtin_flow_mask = True
@@ -199,11 +208,11 @@ def __getitem__(self, index):
199208
index (int): The index of the example to retrieve
200209
201210
Returns:
202-
tuple: If ``split="train"`` a 4-tuple with ``(img1, img2, flow,
203-
valid)`` where ``valid`` is a numpy boolean mask of shape (H, W)
211+
tuple: A 4-tuple with ``(img1, img2, flow, valid_flow_mask)``
212+
where ``valid_flow_mask`` is a numpy boolean mask of shape (H, W)
204213
indicating which flow values are valid. The flow is a numpy array of
205-
shape (2, H, W) and the images are PIL images. If `split="test"`, a
206-
4-tuple with ``(img1, img2, None, None)`` is returned.
214+
shape (2, H, W) and the images are PIL images. ``flow`` and ``valid_flow_mask`` are None if
215+
``split="test"``.
207216
"""
208217
return super().__getitem__(index)
209218

@@ -232,8 +241,8 @@ class FlyingChairs(FlowDataset):
232241
root (string): Root directory of the FlyingChairs Dataset.
233242
split (string, optional): The dataset split, either "train" (default) or "val"
234243
transforms (callable, optional): A function/transform that takes in
235-
``img1, img2, flow, valid`` and returns a transformed version.
236-
``valid`` is expected for consistency with other datasets which
244+
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
245+
``valid_flow_mask`` is expected for consistency with other datasets which
237246
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
238247
"""
239248

@@ -269,6 +278,9 @@ def __getitem__(self, index):
269278
Returns:
270279
tuple: A 3-tuple with ``(img1, img2, flow)``.
271280
The flow is a numpy array of shape (2, H, W) and the images are PIL images.
281+
``flow`` is None if ``split="val"``.
282+
If a valid flow mask is generated within the ``transforms`` parameter,
283+
a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
272284
"""
273285
return super().__getitem__(index)
274286

@@ -300,8 +312,8 @@ class FlyingThings3D(FlowDataset):
300312
details on the different passes.
301313
camera (string, optional): Which camera to return images from. Can be either "left" (default), "right", or "both".
302314
transforms (callable, optional): A function/transform that takes in
303-
``img1, img2, flow, valid`` and returns a transformed version.
304-
``valid`` is expected for consistency with other datasets which
315+
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
316+
``valid_flow_mask`` is expected for consistency with other datasets which
305317
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
306318
"""
307319

@@ -357,6 +369,9 @@ def __getitem__(self, index):
357369
Returns:
358370
tuple: A 3-tuple with ``(img1, img2, flow)``.
359371
The flow is a numpy array of shape (2, H, W) and the images are PIL images.
372+
``flow`` is None if ``split="test"``.
373+
If a valid flow mask is generated within the ``transforms`` parameter,
374+
a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
360375
"""
361376
return super().__getitem__(index)
362377

@@ -382,7 +397,7 @@ class HD1K(FlowDataset):
382397
root (string): Root directory of the HD1K Dataset.
383398
split (string, optional): The dataset split, either "train" (default) or "test"
384399
transforms (callable, optional): A function/transform that takes in
385-
``img1, img2, flow, valid`` and returns a transformed version.
400+
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
386401
"""
387402

388403
_has_builtin_flow_mask = True
@@ -422,11 +437,11 @@ def __getitem__(self, index):
422437
index (int): The index of the example to retrieve
423438
424439
Returns:
425-
tuple: If ``split="train"`` a 4-tuple with ``(img1, img2, flow,
426-
valid)`` where ``valid`` is a numpy boolean mask of shape (H, W)
440+
tuple: A 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` where ``valid_flow_mask``
441+
is a numpy boolean mask of shape (H, W)
427442
indicating which flow values are valid. The flow is a numpy array of
428-
shape (2, H, W) and the images are PIL images. If `split="test"`, a
429-
4-tuple with ``(img1, img2, None, None)`` is returned.
443+
shape (2, H, W) and the images are PIL images. ``flow`` and ``valid_flow_mask`` are None if
444+
``split="test"``.
430445
"""
431446
return super().__getitem__(index)
432447

@@ -451,11 +466,12 @@ def _read_flo(file_name):
451466
def _read_16bits_png_with_flow_and_valid_mask(file_name):
452467

453468
flow_and_valid = _read_png_16(file_name).to(torch.float32)
454-
flow, valid = flow_and_valid[:2, :, :], flow_and_valid[2, :, :]
469+
flow, valid_flow_mask = flow_and_valid[:2, :, :], flow_and_valid[2, :, :]
455470
flow = (flow - 2 ** 15) / 64 # This conversion is explained somewhere on the kitti archive
471+
valid_flow_mask = valid_flow_mask.bool()
456472

457473
# For consistency with other datasets, we convert to numpy
458-
return flow.numpy(), valid.numpy()
474+
return flow.numpy(), valid_flow_mask.numpy()
459475

460476

461477
def _read_pfm(file_name):

0 commit comments

Comments
 (0)