Add pretrained weights on Chairs and Things for raft_large #5060
Changes from all commits: dce22b3, 228a17c, d244401, 0406c83, f186973, 57aff36
**New file: `references/optical_flow/README.md`** (+57 lines)

# Optical flow reference training scripts

This folder contains reference training scripts for optical flow.
They serve as a log of how to train specific models, so as to provide baseline
training and evaluation scripts to quickly bootstrap research.


### RAFT Large

The RAFT large model was trained on Flying Chairs and then on Flying Things.
Both used 8 A100 GPUs and a batch size of 2 (so the effective batch size is 16). The
rest of the hyper-parameters are exactly the same as the original RAFT training
recipe from https://github.com/princeton-vl/RAFT.

```
torchrun --nproc_per_node 8 --nnodes 1 train.py \
    --dataset-root $dataset_root \
    --name $name_chairs \
    --model raft_large \
    --train-dataset chairs \
    --batch-size 2 \
    --lr 0.0004 \
    --weight-decay 0.0001 \
    --num-steps 100000 \
    --output-dir $chairs_dir
```

```
torchrun --nproc_per_node 8 --nnodes 1 train.py \
    --dataset-root $dataset_root \
    --name $name_things \
    --model raft_large \
    --train-dataset things \
    --batch-size 2 \
    --lr 0.000125 \
    --weight-decay 0.0001 \
    --num-steps 100000 \
    --freeze-batch-norm \
    --output-dir $things_dir \
    --resume $chairs_dir/$name_chairs.pth
```


### Evaluation

```
torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset sintel --batch-size 1 --dataset-root $dataset_root --model raft_large --pretrained
```

This should give an epe of about 1.3822 on the clean pass and 2.7161 on the
final pass of Sintel. Results may vary slightly depending on the batch size and
the number of GPUs. For the most accurate results use 1 GPU and `--batch-size 1`:

```
Sintel val clean epe: 1.3822 1px: 0.9028 3px: 0.9573 5px: 0.9697 per_image_epe: 1.3822 f1: 4.0248
Sintel val final epe: 2.7161 1px: 0.8528 3px: 0.9204 5px: 0.9392 per_image_epe: 2.7161 f1: 7.5964
```
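For context on the numbers above: epe is the average end-point error, i.e. the mean Euclidean distance between predicted and ground-truth flow vectors; the 1px/3px/5px columns are the fraction of pixels whose end-point error is below that threshold; f1 is the outlier rate used in the RAFT paper (roughly, pixels whose error exceeds both 3 px and 5% of the ground-truth flow magnitude). A rough sketch of how the thresholded metrics can be computed, with a hypothetical helper name and without the f1 term:

```python
import torch

def flow_metrics(flow_pred: torch.Tensor, flow_gt: torch.Tensor) -> dict:
    """Illustrative epe / n-px metrics for flow fields of shape (N, 2, H, W)."""
    # Per-pixel end-point error: Euclidean distance between flow vectors.
    epe = torch.sum((flow_pred - flow_gt) ** 2, dim=1).sqrt()  # (N, H, W)
    return {
        "epe": epe.mean().item(),
        "1px": (epe < 1).float().mean().item(),
        "3px": (epe < 3).float().mean().item(),
        "5px": (epe < 5).float().mean().item(),
    }
```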
**`torchvision/models/optical_flow/raft.py`**
```diff
@@ -8,6 +8,7 @@
 from torch.nn.modules.instancenorm import InstanceNorm2d
 from torchvision.ops import ConvNormActivation

+from ..._internally_replaced_utils import load_state_dict_from_url
 from ...utils import _log_api_usage_once
 from ._utils import grid_sample, make_coords_grid, upsample_flow
```
```diff
@@ -19,6 +20,9 @@
 )


+_MODELS_URLS = {"raft_large": "https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth"}
+
+
 class ResidualBlock(nn.Module):
     """Slightly modified Residual block with extra relu and biases."""
```

Review comments on `_MODELS_URLS`:

> Once PR is merged I will upload this to manifold

> FYI: all current models use …
```diff
@@ -474,8 +478,8 @@ def forward(self, image1, image2, num_flow_updates: int = 12):
         hidden_state = torch.tanh(hidden_state)
         context = F.relu(context)

-        coords0 = make_coords_grid(batch_size, h // 8, w // 8).cuda()
-        coords1 = make_coords_grid(batch_size, h // 8, w // 8).cuda()
+        coords0 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device)
+        coords1 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device)

         flow_predictions = []
         for _ in range(num_flow_updates):
```
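Replacing the hard-coded `.cuda()` calls with `.to(fmap1.device)` makes the coordinate grids follow the device of the feature maps, so the forward pass no longer assumes a GPU. A quick sanity check of that behavior on CPU (a sketch, not part of the diff; random inputs are used only to exercise shapes and devices):

```python
import torch
from torchvision.models.optical_flow import raft_small

model = raft_small().eval()  # randomly initialized, stays on CPU

# Input spatial dims are assumed to be multiples of 8, as required by the
# 1/8-resolution coordinate grids built in forward().
img1 = torch.rand(1, 3, 128, 128)
img2 = torch.rand(1, 3, 128, 128)

with torch.no_grad():
    flows = model(img1, img2, num_flow_updates=6)  # one flow per refinement step

print(flows[-1].shape, flows[-1].device)  # torch.Size([1, 2, 128, 128]) cpu
```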
```diff
@@ -496,6 +500,9 @@ def forward(self, image1, image2, num_flow_updates: int = 12):


 def _raft(
     *,
+    arch=None,
+    pretrained=False,
+    progress=False,
     # Feature encoder
     feature_encoder_layers,
     feature_encoder_block,
```
```diff
@@ -560,14 +567,19 @@ def _raft(
         multiplier=0.25,  # See comment in MaskPredictor about this
     )

-    return RAFT(
+    model = RAFT(
         feature_encoder=feature_encoder,
         context_encoder=context_encoder,
         corr_block=corr_block,
         update_block=update_block,
         mask_predictor=mask_predictor,
         **kwargs,  # not really needed, all params should be consumed by now
     )
+    if pretrained:
+        state_dict = load_state_dict_from_url(_MODELS_URLS[arch], progress=progress)
+        model.load_state_dict(state_dict)
+
+    return model


 def raft_large(*, pretrained=False, progress=True, **kwargs):
```
```diff
@@ -584,10 +596,10 @@ def raft_large(*, pretrained=False, progress=True, **kwargs):
         nn.Module: The model.
     """

-    if pretrained:
-        raise ValueError("No checkpoint is available for raft_large")
-
     return _raft(
+        arch="raft_large",
+        pretrained=pretrained,
+        progress=progress,
         # Feature encoder
         feature_encoder_layers=(64, 64, 96, 128, 256),
         feature_encoder_block=ResidualBlock,
```
```diff
@@ -629,11 +641,13 @@ def raft_small(*, pretrained=False, progress=True, **kwargs):
         nn.Module: The model.

    """

-    if pretrained:
-        raise ValueError("No checkpoint is available for raft_small")
-
     return _raft(
+        arch="raft_small",
+        pretrained=pretrained,
+        progress=progress,
         # Feature encoder
         feature_encoder_layers=(32, 32, 64, 96, 128),
         feature_encoder_block=BottleneckBlock,
```
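With these changes, the stable `raft_large` builder can download and load the Chairs + Things checkpoint. A minimal usage sketch (scaling the frames to [-1, 1] follows the original RAFT recipe and is an assumption here, not something this diff enforces):

```python
import torch
from torchvision.models.optical_flow import raft_large

model = raft_large(pretrained=True, progress=True).eval()

# Two RGB frames with H and W divisible by 8, scaled to [-1, 1].
frame1 = torch.rand(1, 3, 368, 496) * 2 - 1
frame2 = torch.rand(1, 3, 368, 496) * 2 - 1

with torch.no_grad():
    flow_predictions = model(frame1, frame2, num_flow_updates=12)

flow = flow_predictions[-1]  # final refined estimate, shape (1, 2, 368, 496)
```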
**`torchvision/prototype/models/optical_flow/raft.py`**
```diff
@@ -4,12 +4,11 @@
 from torch.nn.modules.instancenorm import InstanceNorm2d
 from torchvision.models.optical_flow import RAFT
 from torchvision.models.optical_flow.raft import _raft, BottleneckBlock, ResidualBlock
-
-# from torchvision.prototype.transforms import RaftEval
+from torchvision.prototype.transforms import RaftEval
 from torchvision.transforms.functional import InterpolationMode

 from .._api import WeightsEnum
-# from .._api import Weights
+from .._api import Weights
 from .._utils import handle_legacy_interface
```
```diff
@@ -22,17 +21,33 @@
 )


+_COMMON_META = {"interpolation": InterpolationMode.BILINEAR}


 class Raft_Large_Weights(WeightsEnum):
-    pass
-    # C_T_V1 = Weights(
-    #     # Chairs + Things
-    #     url="",
-    #     transforms=RaftEval,
-    #     meta={
-    #         "recipe": "",
-    #         "epe": -1234,
-    #     },
-    # )
+    C_T_V1 = Weights(
+        # Chairs + Things, ported from original paper repo (raft-things.pth)
+        url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth",
+        transforms=RaftEval,
+        meta={
+            **_COMMON_META,
+            "recipe": "https://github.com/princeton-vl/RAFT",
+            "sintel_train_cleanpass_epe": 1.4411,
+            "sintel_train_finalpass_epe": 2.7894,
+        },
+    )
+
+    C_T_V2 = Weights(
+        # Chairs + Things
+        url="https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth",
+        transforms=RaftEval,
+        meta={
+            **_COMMON_META,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
+            "sintel_train_cleanpass_epe": 1.3822,
+            "sintel_train_finalpass_epe": 2.7161,
+        },
+    )

     # C_T_SKHT_V1 = Weights(
     #     # Chairs + Things + Sintel fine-tuning, i.e.:
```

Review discussion on the epe meta entries:

> Does it make sense to rename one of them as the default?

> Unfortunately no, because the rest of the weights will be trained on sintel, so reporting the epe on the trainset would not be relevant.

> I'm happy to have a dict or something else to properly keep track of the other metrics though - ultimately I think it would make sense to also have 1px, 3px etc. I think we'll have a better idea of what it should look like once the rest of the weights are available.

> Sounds good, no strong opinions. You could dump all the metrics in an …
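Whatever final shape the meta dict takes, the enum as defined above already exposes each checkpoint's URL, recipe, and reported metrics. A small sketch of inspecting it (the `torchvision.prototype.models.optical_flow` import path is assumed from the prototype package layout):

```python
from torchvision.prototype.models.optical_flow import Raft_Large_Weights

weights = Raft_Large_Weights.C_T_V2
print(weights.url)                                   # checkpoint URL
print(weights.meta["recipe"])                        # training recipe link
print(weights.meta["sintel_train_cleanpass_epe"])    # 1.3822
print(weights.meta["sintel_train_finalpass_epe"])    # 2.7161
```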
```diff
@@ -59,7 +74,7 @@ class Raft_Large_Weights(WeightsEnum):
     #     },
     # )

-    # default = C_T_V1
+    default = C_T_V2


 class Raft_Small_Weights(WeightsEnum):
```
```diff
@@ -75,13 +90,13 @@ class Raft_Small_Weights(WeightsEnum):
     # default = C_T_V1


-@handle_legacy_interface(weights=("pretrained", None))
+@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_V2))
 def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs):
     """RAFT model from
     `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.

     Args:
-        weights(Raft_Large_weights, optinal): TODO not implemented yet
+        weights(Raft_Large_weights, optional): pretrained weights to use.
         progress (bool): If True, displays a progress bar of the download to stderr
         kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
             to override any default.
```
```diff
@@ -92,7 +107,7 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs):

     weights = Raft_Large_Weights.verify(weights)

-    return _raft(
+    model = _raft(
         # Feature encoder
         feature_encoder_layers=(64, 64, 96, 128, 256),
         feature_encoder_block=ResidualBlock,
```
```diff
@@ -119,6 +134,11 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs):
         **kwargs,
     )

+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress))
+
+    return model
+

 @handle_legacy_interface(weights=("pretrained", None))
 def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs):
```
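With the decorator default switched to `C_T_V2` and the state-dict loading added, both the new weights argument and the legacy `pretrained=True` flag resolve to the same checkpoint. A sketch under the same import-path assumption as above:

```python
from torchvision.prototype.models.optical_flow import raft_large, Raft_Large_Weights

# New-style call: pass the weights enum explicitly.
model = raft_large(weights=Raft_Large_Weights.C_T_V2, progress=True).eval()

# Legacy call: @handle_legacy_interface maps pretrained=True to
# Raft_Large_Weights.C_T_V2 (typically with a deprecation warning).
model_legacy = raft_large(pretrained=True).eval()

# weights=None keeps a randomly initialized model.
model_random = raft_large(weights=None)
```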
```diff
@@ -138,7 +158,7 @@ def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs):

     weights = Raft_Small_Weights.verify(weights)

-    return _raft(
+    model = _raft(
         # Feature encoder
         feature_encoder_layers=(32, 32, 64, 96, 128),
         feature_encoder_block=BottleneckBlock,
```
```diff
@@ -164,3 +184,7 @@ def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs):
         use_mask_predictor=False,
         **kwargs,
     )
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress))
+    return model
```