5
5
from torchvision .models .optical_flow import RAFT
6
6
from torchvision .models .optical_flow .raft import _raft , BottleneckBlock , ResidualBlock
7
7
from torchvision .prototype .transforms import RaftEval
8
+ from torchvision .transforms .functional import InterpolationMode
8
9
9
10
from .._api import WeightsEnum
10
11
from .._api import Weights
20
21
)
21
22
22
23
24
+ _COMMON_META = {"interpolation" : InterpolationMode .BILINEAR }
25
+
26
+
23
27
class Raft_Large_Weights (WeightsEnum ):
24
28
C_T_V1 = Weights (
25
29
# Chairs + Things, ported from original paper repo (raft-things.pth)
26
30
url = "https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth" ,
27
31
transforms = RaftEval ,
28
32
meta = {
33
+ ** _COMMON_META ,
29
34
"recipe" : "https://github.com/princeton-vl/RAFT" ,
30
35
"sintel_train_cleanpass_epe" : 1.4411 ,
31
36
"sintel_train_finalpass_epe" : 2.7894 ,
@@ -37,7 +42,8 @@ class Raft_Large_Weights(WeightsEnum):
37
42
url = "https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth" ,
38
43
transforms = RaftEval ,
39
44
meta = {
40
- "recipe" : "" , # TODO
45
+ ** _COMMON_META ,
46
+ "recipe" : "https://github.com/pytorch/vision/tree/main/references/optical_flow" ,
41
47
"sintel_train_cleanpass_epe" : 1.3822 ,
42
48
"sintel_train_finalpass_epe" : 2.7161 ,
43
49
},
@@ -84,68 +90,6 @@ class Raft_Small_Weights(WeightsEnum):
84
90
# default = C_T_V1
85
91
86
92
87
- def _raft_builder (
88
- * ,
89
- weights ,
90
- progress ,
91
- # Feature encoder
92
- feature_encoder_layers ,
93
- feature_encoder_block ,
94
- feature_encoder_norm_layer ,
95
- # Context encoder
96
- context_encoder_layers ,
97
- context_encoder_block ,
98
- context_encoder_norm_layer ,
99
- # Correlation block
100
- corr_block_num_levels ,
101
- corr_block_radius ,
102
- # Motion encoder
103
- motion_encoder_corr_layers ,
104
- motion_encoder_flow_layers ,
105
- motion_encoder_out_channels ,
106
- # Recurrent block
107
- recurrent_block_hidden_state_size ,
108
- recurrent_block_kernel_size ,
109
- recurrent_block_padding ,
110
- # Flow Head
111
- flow_head_hidden_size ,
112
- # Mask predictor
113
- use_mask_predictor ,
114
- ** kwargs ,
115
- ):
116
- model = _raft (
117
- # Feature encoder
118
- feature_encoder_layers = feature_encoder_layers ,
119
- feature_encoder_block = feature_encoder_block ,
120
- feature_encoder_norm_layer = feature_encoder_norm_layer ,
121
- # Context encoder
122
- context_encoder_layers = context_encoder_layers ,
123
- context_encoder_block = context_encoder_block ,
124
- context_encoder_norm_layer = context_encoder_norm_layer ,
125
- # Correlation block
126
- corr_block_num_levels = corr_block_num_levels ,
127
- corr_block_radius = corr_block_radius ,
128
- # Motion encoder
129
- motion_encoder_corr_layers = motion_encoder_corr_layers ,
130
- motion_encoder_flow_layers = motion_encoder_flow_layers ,
131
- motion_encoder_out_channels = motion_encoder_out_channels ,
132
- # Recurrent block
133
- recurrent_block_hidden_state_size = recurrent_block_hidden_state_size ,
134
- recurrent_block_kernel_size = recurrent_block_kernel_size ,
135
- recurrent_block_padding = recurrent_block_padding ,
136
- # Flow head
137
- flow_head_hidden_size = flow_head_hidden_size ,
138
- # Mask predictor
139
- use_mask_predictor = use_mask_predictor ,
140
- ** kwargs ,
141
- )
142
-
143
- if weights is not None :
144
- model .load_state_dict (weights .get_state_dict (progress = progress ))
145
-
146
- return model
147
-
148
-
149
93
@handle_legacy_interface (weights = ("pretrained" , Raft_Large_Weights .C_T_V2 ))
150
94
def raft_large (* , weights : Optional [Raft_Large_Weights ] = None , progress = True , ** kwargs ):
151
95
"""RAFT model from
@@ -163,9 +107,7 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, *
163
107
164
108
weights = Raft_Large_Weights .verify (weights )
165
109
166
- return _raft_builder (
167
- weights = weights ,
168
- progress = progress ,
110
+ model = _raft (
169
111
# Feature encoder
170
112
feature_encoder_layers = (64 , 64 , 96 , 128 , 256 ),
171
113
feature_encoder_block = ResidualBlock ,
@@ -192,6 +134,11 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, *
192
134
** kwargs ,
193
135
)
194
136
137
+ if weights is not None :
138
+ model .load_state_dict (weights .get_state_dict (progress = progress ))
139
+
140
+ return model
141
+
195
142
196
143
@handle_legacy_interface (weights = ("pretrained" , None ))
197
144
def raft_small (* , weights : Optional [Raft_Small_Weights ] = None , progress = True , ** kwargs ):
@@ -211,9 +158,7 @@ def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, *
211
158
212
159
weights = Raft_Small_Weights .verify (weights )
213
160
214
- return _raft_builder (
215
- weights = weights ,
216
- progress = progress ,
161
+ model = _raft (
217
162
# Feature encoder
218
163
feature_encoder_layers = (32 , 32 , 64 , 96 , 128 ),
219
164
feature_encoder_block = BottleneckBlock ,
@@ -239,3 +184,7 @@ def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, *
239
184
use_mask_predictor = False ,
240
185
** kwargs ,
241
186
)
187
+
188
+ if weights is not None :
189
+ model .load_state_dict (weights .get_state_dict (progress = progress ))
190
+ return model
0 commit comments