Commit 01ffb3a

Add RAFT model for optical flow (#5022)

1 parent 9b57de6

File tree

8 files changed: +752 −4 lines changed

docs/source/models.rst

Lines changed: 14 additions & 1 deletion

@@ -7,7 +7,7 @@ Models and pre-trained weights
 The ``torchvision.models`` subpackage contains definitions of models for addressing
 different tasks, including: image classification, pixelwise semantic
 segmentation, object detection, instance segmentation, person
-keypoint detection and video classification.
+keypoint detection, video classification, and optical flow.

 .. note ::
     Backward compatibility is guaranteed for loading a serialized
@@ -798,3 +798,16 @@ ResNet (2+1)D
    :template: function.rst

    torchvision.models.video.r2plus1d_18
+
+Optical flow
+============
+
+Raft
+----
+
+.. autosummary::
+    :toctree: generated/
+    :template: function.rst
+
+    torchvision.models.optical_flow.raft_large
+    torchvision.models.optical_flow.raft_small
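For context, here is a minimal sketch (not part of the diff) of how the newly documented builders are called; the 128×128 input size is an assumption, chosen to be divisible by 8 to match RAFT's downsampling factor:

import torch
from torchvision.models.optical_flow import raft_large

model = raft_large().eval()

img1 = torch.rand(1, 3, 128, 128)  # first frame, batch size 1
img2 = torch.rand(1, 3, 128, 128)  # second frame
with torch.no_grad():
    # RAFT returns one flow estimate per recurrent refinement iteration;
    # the last one is the most refined.
    predictions = model(img1, img2)
flow = predictions[-1]  # shape (1, 2, 128, 128): per-pixel (dx, dy) displacements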
2 binary files changed (not shown).

test/test_models.py

Lines changed: 32 additions & 3 deletions

@@ -93,7 +93,7 @@ def _get_expected_file(name=None):
     return expected_file


-def _assert_expected(output, name, prec):
+def _assert_expected(output, name, prec=None, atol=None, rtol=None):
     """Test that a python value matches the recorded contents of a file
     based on a "check" name. The value must be
     picklable with `torch.save`. This file
@@ -110,10 +110,11 @@ def _assert_expected(output, name, prec):
     MAX_PICKLE_SIZE = 50 * 1000  # 50 KB
     binary_size = os.path.getsize(expected_file)
     if binary_size > MAX_PICKLE_SIZE:
-        raise RuntimeError(f"The output for {filename}, is larger than 50kb")
+        raise RuntimeError(f"The output for {filename} is larger than 50KB, got {binary_size} bytes")
     else:
         expected = torch.load(expected_file)
-        rtol = atol = prec
+        rtol = rtol or prec  # keeping the prec param for legacy reasons, but ideally it could be removed
+        atol = atol or prec
         torch.testing.assert_close(output, expected, rtol=rtol, atol=atol, check_dtype=False)
@@ -818,5 +819,33 @@ def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_load
     assert n_trainable_params == _model_tests_values[model_name]["n_trn_params_per_layer"]


+@needs_cuda
+@pytest.mark.parametrize("model_builder", (models.optical_flow.raft_large, models.optical_flow.raft_small))
+@pytest.mark.parametrize("scripted", (False, True))
+def test_raft(model_builder, scripted):
+
+    torch.manual_seed(0)
+
+    # We need very small images, otherwise the pickle size would exceed the
+    # 50KB limit. As a result we need to override the correlation pyramid to not
+    # downsample too much, otherwise we would get nan values (effective H and W
+    # would be reduced to 1).
+    corr_block = models.optical_flow.raft.CorrBlock(num_levels=2, radius=2)
+
+    model = model_builder(corr_block=corr_block).eval().to("cuda")
+    if scripted:
+        model = torch.jit.script(model)
+
+    bs = 1
+    img1 = torch.rand(bs, 3, 80, 72).cuda()
+    img2 = torch.rand(bs, 3, 80, 72).cuda()
+
+    preds = model(img1, img2)
+    flow_pred = preds[-1]
+    # Tolerance is fairly high, but there are 2 * H * W outputs to check.
+    # The .pkl files were generated on the AWS cluster; on the CI the results are slightly different.
+    _assert_expected(flow_pred, name=model_builder.__name__, atol=1e-2, rtol=1)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
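The prec fallback above is easy to misread; this hypothetical helper (resolve_tolerances is not part of the diff) restates the rule in isolation: explicit atol/rtol win, and the legacy prec fills in whichever is missing.

def resolve_tolerances(prec=None, atol=None, rtol=None):
    # Mirrors _assert_expected: `x or prec` keeps an explicit value and
    # falls back to the legacy `prec` when x is None.
    return (rtol or prec), (atol or prec)

assert resolve_tolerances(prec=1e-3) == (1e-3, 1e-3)       # legacy call style
assert resolve_tolerances(atol=1e-2, rtol=1) == (1, 1e-2)  # new style, as used in test_raft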

torchvision/models/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -12,6 +12,7 @@
 from .regnet import *
 from . import detection
 from . import feature_extraction
+from . import optical_flow
 from . import quantization
 from . import segmentation
 from . import video
torchvision/models/optical_flow/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+from .raft import RAFT, raft_large, raft_small
torchvision/models/optical_flow/_utils.py

Lines changed: 45 additions & 0 deletions

@@ -0,0 +1,45 @@
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+
+def grid_sample(img: Tensor, absolute_grid: Tensor, mode: str = "bilinear", align_corners: Optional[bool] = None):
+    """Same as torch's grid_sample, with absolute pixel coordinates instead of normalized coordinates."""
+    h, w = img.shape[-2:]
+
+    xgrid, ygrid = absolute_grid.split([1, 1], dim=-1)
+    xgrid = 2 * xgrid / (w - 1) - 1
+    ygrid = 2 * ygrid / (h - 1) - 1
+    normalized_grid = torch.cat([xgrid, ygrid], dim=-1)
+
+    return F.grid_sample(img, normalized_grid, mode=mode, align_corners=align_corners)
+
+
+def make_coords_grid(batch_size: int, h: int, w: int):
+    coords = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
+    coords = torch.stack(coords[::-1], dim=0).float()
+    return coords[None].repeat(batch_size, 1, 1, 1)
+
+
+def upsample_flow(flow, up_mask: Optional[Tensor] = None):
+    """Upsample flow by a factor of 8.
+
+    If up_mask is None we just interpolate.
+    If up_mask is specified, we upsample using a convex combination of its weights. See paper page 8 and appendix B.
+    Note that in appendix B the picture assumes a downsample factor of 4 instead of 8.
+    """
+    batch_size, _, h, w = flow.shape
+    new_h, new_w = h * 8, w * 8
+
+    if up_mask is None:
+        return 8 * F.interpolate(flow, size=(new_h, new_w), mode="bilinear", align_corners=True)
+
+    up_mask = up_mask.view(batch_size, 1, 9, 8, 8, h, w)
+    up_mask = torch.softmax(up_mask, dim=2)  # "convex" == weights sum to 1
+
+    upsampled_flow = F.unfold(8 * flow, kernel_size=3, padding=1).view(batch_size, 2, 9, 1, 1, h, w)
+    upsampled_flow = torch.sum(up_mask * upsampled_flow, dim=2)
+
+    return upsampled_flow.permute(0, 1, 4, 2, 5, 3).reshape(batch_size, 2, new_h, new_w)
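To make the convex-upsampling shapes concrete, here is a quick self-check (assuming the helpers above are importable; the sizes are arbitrary): each coarse pixel expands into an 8×8 block of fine pixels, each a convex combination of the 3×3 coarse neighborhood.

import torch

batch_size, h, w = 2, 5, 6
flow = torch.rand(batch_size, 2, h, w)
# 9 weights per fine pixel, and 8x8 fine pixels per coarse pixel
up_mask = torch.rand(batch_size, 9 * 8 * 8, h, w)

out = upsample_flow(flow, up_mask)
assert out.shape == (batch_size, 2, 8 * h, 8 * w)

# Without a mask, plain bilinear interpolation is used; the flow values are
# also scaled by 8 because displacements are measured in pixels.
out = upsample_flow(flow, up_mask=None)
assert out.shape == (batch_size, 2, 8 * h, 8 * w)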
