Commit 57aff36

Address comments
1 parent f186973 commit 57aff36

3 files changed: 28 additions & 75 deletions


references/optical_flow/README.md

Lines changed: 7 additions & 5 deletions
@@ -16,6 +16,7 @@ recipe from https://github.com/princeton-vl/RAFT.
 torchrun --nproc_per_node 8 --nnodes 1 train.py \
     --dataset-root $dataset_root \
     --name $name_chairs \
+    --model raft_large \
     --train-dataset chairs \
     --batch-size 2 \
     --lr 0.0004 \
@@ -28,6 +29,7 @@ torchrun --nproc_per_node 8 --nnodes 1 train.py \
 torchrun --nproc_per_node 8 --nnodes 1 train.py \
     --dataset-root $dataset_root \
     --name $name_things \
+    --model raft_large \
     --train-dataset things \
     --batch-size 2 \
     --lr 0.000125 \
@@ -42,14 +44,14 @@ torchrun --nproc_per_node 8 --nnodes 1 train.py \
 ### Evaluation
 
 ```
-torchrun --nproc_per_node 8 --nnodes 1 train.py --val-dataset sintel --batch-size 10 --dataset-root $dataset_root --model raft_large --pretrained
+torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset sintel --batch-size 1 --dataset-root $dataset_root --model raft_large --pretrained
 ```
 
-This should give an epe of about 1.3825 on the clean pass and 2.7148 on the
+This should give an epe of about 1.3822 on the clean pass and 2.7161 on the
 final pass of Sintel. Results may vary slightly depending on the batch size and
-the number of GPUs. For the most accurate resuts use 1 GPU and `--batch-size 1`.
+the number of GPUs. For the most accurate resuts use 1 GPU and `--batch-size 1`:
 
 ```
-Sintel val clean epe: 1.3825 1px: 0.9028 3px: 0.9573 5px: 0.9697 per_image_epe: 1.3782 f1: 4.0234
-Sintel val final epe: 2.7148 1px: 0.8526 3px: 0.9203 5px: 0.9392 per_image_epe: 2.7199 f1: 7.6100
+Sintel val clean epe: 1.3822 1px: 0.9028 3px: 0.9573 5px: 0.9697 per_image_epe: 1.3822 f1: 4.0248
+Sintel val final epe: 2.7161 1px: 0.8528 3px: 0.9204 5px: 0.9392 per_image_epe: 2.7161 f1: 7.5964
 ```
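For orientation, here is a minimal sketch of what the evaluated checkpoint looks like from the Python side, assuming `raft_large` and `Raft_Large_Weights` are exported from `torchvision.prototype.models.optical_flow` as in the file changed further below; the input tensors and their sizes are placeholders, not part of the reference script.

```
# Hedged sketch: load the C_T_V2 checkpoint exercised by the evaluation command
# above and run a single forward pass on dummy frames (dimensions divisible by 8).
import torch
from torchvision.prototype.models.optical_flow import raft_large, Raft_Large_Weights

model = raft_large(weights=Raft_Large_Weights.C_T_V2).eval()

img1 = torch.rand(1, 3, 368, 768)  # placeholder frame pair; real evaluation uses Sintel
img2 = torch.rand(1, 3, 368, 768)
with torch.no_grad():
    flow_predictions = model(img1, img2)  # one flow estimate per refinement iteration

print(flow_predictions[-1].shape)  # final prediction: (1, 2, 368, 768)
```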

test/test_prototype_models.py

Lines changed: 3 additions & 1 deletion
@@ -91,7 +91,8 @@ def test_naming_conventions(model_fn):
     + TM.get_models_from_module(models.detection)
     + TM.get_models_from_module(models.quantization)
     + TM.get_models_from_module(models.segmentation)
-    + TM.get_models_from_module(models.video),
+    + TM.get_models_from_module(models.video)
+    + TM.get_models_from_module(models.optical_flow),
 )
 def test_schema_meta_validation(model_fn):
     classification_fields = ["size", "categories", "acc@1", "acc@5"]
@@ -102,6 +103,7 @@ def test_schema_meta_validation(model_fn):
         "quantization": classification_fields + ["backend", "quantization", "unquantized"],
         "segmentation": ["categories", "mIoU", "acc"],
         "video": classification_fields,
+        "optical_flow": [],
     }
     module_name = model_fn.__module__.split(".")[-2]
     fields = set(defaults["all"] + defaults[module_name])
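The schema check above amounts to comparing each weight entry's `meta` keys against the fields expected for its module; `"optical_flow": []` means no module-specific fields are required beyond the common ones. A rough illustration follows, where the `expected` set is an assumption made for the sketch rather than the test's actual `defaults["all"]`:

```
from torchvision.prototype.models.optical_flow import Raft_Large_Weights

# Assumed common fields, for illustration only.
expected = {"recipe", "interpolation"}

for weights in Raft_Large_Weights:
    missing = expected - weights.meta.keys()
    metrics = sorted(k for k in weights.meta if k.startswith("sintel_"))
    print(weights.name, "missing:", missing or "none", "| reported metrics:", metrics)
```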

torchvision/prototype/models/optical_flow/raft.py

Lines changed: 18 additions & 69 deletions
@@ -5,6 +5,7 @@
 from torchvision.models.optical_flow import RAFT
 from torchvision.models.optical_flow.raft import _raft, BottleneckBlock, ResidualBlock
 from torchvision.prototype.transforms import RaftEval
+from torchvision.transforms.functional import InterpolationMode
 
 from .._api import WeightsEnum
 from .._api import Weights
@@ -20,12 +21,16 @@
 )
 
 
+_COMMON_META = {"interpolation": InterpolationMode.BILINEAR}
+
+
 class Raft_Large_Weights(WeightsEnum):
     C_T_V1 = Weights(
         # Chairs + Things, ported from original paper repo (raft-things.pth)
         url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth",
         transforms=RaftEval,
         meta={
+            **_COMMON_META,
             "recipe": "https://github.com/princeton-vl/RAFT",
             "sintel_train_cleanpass_epe": 1.4411,
             "sintel_train_finalpass_epe": 2.7894,
@@ -37,7 +42,8 @@ class Raft_Large_Weights(WeightsEnum):
         url="https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth",
         transforms=RaftEval,
         meta={
-            "recipe": "",  # TODO
+            **_COMMON_META,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
             "sintel_train_cleanpass_epe": 1.3822,
             "sintel_train_finalpass_epe": 2.7161,
         },
@@ -84,68 +90,6 @@ class Raft_Small_Weights(WeightsEnum):
     # default = C_T_V1
 
 
-def _raft_builder(
-    *,
-    weights,
-    progress,
-    # Feature encoder
-    feature_encoder_layers,
-    feature_encoder_block,
-    feature_encoder_norm_layer,
-    # Context encoder
-    context_encoder_layers,
-    context_encoder_block,
-    context_encoder_norm_layer,
-    # Correlation block
-    corr_block_num_levels,
-    corr_block_radius,
-    # Motion encoder
-    motion_encoder_corr_layers,
-    motion_encoder_flow_layers,
-    motion_encoder_out_channels,
-    # Recurrent block
-    recurrent_block_hidden_state_size,
-    recurrent_block_kernel_size,
-    recurrent_block_padding,
-    # Flow Head
-    flow_head_hidden_size,
-    # Mask predictor
-    use_mask_predictor,
-    **kwargs,
-):
-    model = _raft(
-        # Feature encoder
-        feature_encoder_layers=feature_encoder_layers,
-        feature_encoder_block=feature_encoder_block,
-        feature_encoder_norm_layer=feature_encoder_norm_layer,
-        # Context encoder
-        context_encoder_layers=context_encoder_layers,
-        context_encoder_block=context_encoder_block,
-        context_encoder_norm_layer=context_encoder_norm_layer,
-        # Correlation block
-        corr_block_num_levels=corr_block_num_levels,
-        corr_block_radius=corr_block_radius,
-        # Motion encoder
-        motion_encoder_corr_layers=motion_encoder_corr_layers,
-        motion_encoder_flow_layers=motion_encoder_flow_layers,
-        motion_encoder_out_channels=motion_encoder_out_channels,
-        # Recurrent block
-        recurrent_block_hidden_state_size=recurrent_block_hidden_state_size,
-        recurrent_block_kernel_size=recurrent_block_kernel_size,
-        recurrent_block_padding=recurrent_block_padding,
-        # Flow head
-        flow_head_hidden_size=flow_head_hidden_size,
-        # Mask predictor
-        use_mask_predictor=use_mask_predictor,
-        **kwargs,
-    )
-
-    if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
-
-    return model
-
-
 @handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_V2))
 def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs):
     """RAFT model from
@@ -163,9 +107,7 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs):
 
     weights = Raft_Large_Weights.verify(weights)
 
-    return _raft_builder(
-        weights=weights,
-        progress=progress,
+    model = _raft(
         # Feature encoder
         feature_encoder_layers=(64, 64, 96, 128, 256),
         feature_encoder_block=ResidualBlock,
@@ -192,6 +134,11 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs):
         **kwargs,
     )
 
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress))
+
+    return model
+
 
 @handle_legacy_interface(weights=("pretrained", None))
 def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs):
@@ -211,9 +158,7 @@ def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs):
 
     weights = Raft_Small_Weights.verify(weights)
 
-    return _raft_builder(
-        weights=weights,
-        progress=progress,
+    model = _raft(
         # Feature encoder
         feature_encoder_layers=(32, 32, 64, 96, 128),
         feature_encoder_block=BottleneckBlock,
@@ -239,3 +184,7 @@ def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs):
         use_mask_predictor=False,
         **kwargs,
     )
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress))
+    return model
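With `_raft_builder` removed, each public builder now calls `_raft` directly and loads the state dict itself, while `handle_legacy_interface` keeps the old boolean flag working. A short usage sketch, assuming both names are exported from `torchvision.prototype.models.optical_flow`:

```
from torchvision.prototype.models.optical_flow import raft_large, Raft_Large_Weights

# New enum-based API from this file.
model_new = raft_large(weights=Raft_Large_Weights.C_T_V2)

# Legacy call: the decorator maps pretrained=True onto Raft_Large_Weights.C_T_V2,
# so both calls load the same checkpoint.
model_legacy = raft_large(pretrained=True)
```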
