|
21 | 21 |
|
22 | 22 |
|
23 | 23 | _MODELS_URLS = {
|
24 |
| - "raft_large": "https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth", |
| 24 | + "raft_large": "https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth", |
25 | 25 | "raft_small": "https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth",
|
26 | 26 | }
|
27 | 27 |
|
@@ -587,10 +587,16 @@ def raft_large(*, pretrained=False, progress=True, **kwargs):
|
587 | 587 | `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
|
588 | 588 |
|
589 | 589 | Args:
|
590 |
| - pretrained (bool): Whether to use pretrained weights. |
591 |
| - progress (bool): If True, displays a progress bar of the download to stderr |
592 |
| - kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class |
593 |
| - to override any default. |
| 590 | + pretrained (bool): Whether to use weights that have been pre-trained on |
| 591 | + :class:`~torchvsion.datasets.FlyingChairs` + :class:`~torchvsion.datasets.FlyingThings3D` |
| 592 | + with two fine-tuning steps: |
| 593 | +
|
| 594 | + - one on :class:`~torchvsion.datasets.Sintel` + :class:`~torchvsion.datasets.FlyingThings3D` |
| 595 | + - one on :class:`~torchvsion.datasets.KittiFlow`. |
| 596 | +
|
| 597 | + This corresponds to the ``C+T+S/K`` strategy in the paper. |
| 598 | +
|
| 599 | + progress (bool): If True, displays a progress bar of the download to stderr. |
594 | 600 |
|
595 | 601 | Returns:
|
596 | 602 | nn.Module: The model.
|
@@ -632,10 +638,9 @@ def raft_small(*, pretrained=False, progress=True, **kwargs):
|
632 | 638 | `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
|
633 | 639 |
|
634 | 640 | Args:
|
635 |
| - pretrained (bool): Whether to use pretrained weights. |
| 641 | + pretrained (bool): Whether to use weights that have been pre-trained on |
| 642 | + :class:`~torchvsion.datasets.FlyingChairs` + :class:`~torchvsion.datasets.FlyingThings3D`. |
636 | 643 | progress (bool): If True, displays a progress bar of the download to stderr
|
637 |
| - kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class |
638 |
| - to override any default. |
639 | 644 |
|
640 | 645 | Returns:
|
641 | 646 | nn.Module: The model.
|
|
0 commit comments