Skip to content

Commit cc2aa8c

Browse files
authored
Merge branch 'main' into start_places365
2 parents e7ec167 + 97eddc5 commit cc2aa8c

File tree

2 files changed

+15
-10
lines changed
  • torchvision

2 files changed

+15
-10
lines changed

torchvision/models/optical_flow/raft.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222

2323
_MODELS_URLS = {
24-
"raft_large": "https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth",
24+
"raft_large": "https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth",
2525
"raft_small": "https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth",
2626
}
2727

@@ -587,10 +587,16 @@ def raft_large(*, pretrained=False, progress=True, **kwargs):
587587
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
588588
589589
Args:
590-
pretrained (bool): Whether to use pretrained weights.
591-
progress (bool): If True, displays a progress bar of the download to stderr
592-
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
593-
to override any default.
590+
pretrained (bool): Whether to use weights that have been pre-trained on
591+
:class:`~torchvsion.datasets.FlyingChairs` + :class:`~torchvsion.datasets.FlyingThings3D`
592+
with two fine-tuning steps:
593+
594+
- one on :class:`~torchvsion.datasets.Sintel` + :class:`~torchvsion.datasets.FlyingThings3D`
595+
- one on :class:`~torchvsion.datasets.KittiFlow`.
596+
597+
This corresponds to the ``C+T+S/K`` strategy in the paper.
598+
599+
progress (bool): If True, displays a progress bar of the download to stderr.
594600
595601
Returns:
596602
nn.Module: The model.
@@ -632,10 +638,9 @@ def raft_small(*, pretrained=False, progress=True, **kwargs):
632638
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
633639
634640
Args:
635-
pretrained (bool): Whether to use pretrained weights.
641+
pretrained (bool): Whether to use weights that have been pre-trained on
642+
:class:`~torchvsion.datasets.FlyingChairs` + :class:`~torchvsion.datasets.FlyingThings3D`.
636643
progress (bool): If True, displays a progress bar of the download to stderr
637-
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
638-
to override any default.
639644
640645
Returns:
641646
nn.Module: The model.

torchvision/prototype/models/optical_flow/raft.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ class Raft_Large_Weights(WeightsEnum):
115115
},
116116
)
117117

118-
DEFAULT = C_T_V2
118+
DEFAULT = C_T_SKHT_V2
119119

120120

121121
class Raft_Small_Weights(WeightsEnum):
@@ -151,7 +151,7 @@ class Raft_Small_Weights(WeightsEnum):
151151
DEFAULT = C_T_V2
152152

153153

154-
@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_V2))
154+
@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_SKHT_V2))
155155
def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs):
156156
"""RAFT model from
157157
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.

0 commit comments

Comments
 (0)