diff --git a/docs/source/conf.py b/docs/source/conf.py index db7b2ef14a2..e4db34c3889 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -345,9 +345,7 @@ def inject_weight_metadata(app, what, name, obj, options, lines): metrics = meta.pop("metrics", {}) meta_with_metrics = dict(meta, **metrics) - custom_docs = meta_with_metrics.pop("_docs", None) # Custom per-Weights docs - if custom_docs is not None: - lines += [custom_docs] + lines += [meta_with_metrics.pop("_docs")] if field == obj.DEFAULT: lines += [f"Also available as ``{obj.__name__}.DEFAULT``."] diff --git a/test/test_extended_models.py b/test/test_extended_models.py index e3f79e28af4..7acdd1c0ca5 100644 --- a/test/test_extended_models.py +++ b/test/test_extended_models.py @@ -95,8 +95,8 @@ def test_schema_meta_validation(model_fn): # mandatory fields for each computer vision task classification_fields = {"categories", ("metrics", "acc@1"), ("metrics", "acc@5")} defaults = { - "all": {"metrics", "min_size", "num_params", "recipe"}, - "models": classification_fields | {"_docs"}, + "all": {"metrics", "min_size", "num_params", "recipe", "_docs"}, + "models": classification_fields, "detection": {"categories", ("metrics", "box_map")}, "quantization": classification_fields | {"backend", "unquantized"}, "segmentation": {"categories", ("metrics", "miou"), ("metrics", "pixel_acc")}, diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index afe66ead646..f768089666a 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -386,6 +386,7 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum): "metrics": { "box_map": 37.0, }, + "_docs": """These weights were produced by following a similar training recipe as on the paper.""", }, ) DEFAULT = COCO_V1 @@ -402,6 +403,7 @@ class FasterRCNN_ResNet50_FPN_V2_Weights(WeightsEnum): "metrics": { "box_map": 46.7, }, + "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""", }, ) DEFAULT = COCO_V1 @@ -418,6 +420,7 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum): "metrics": { "box_map": 32.8, }, + "_docs": """These weights were produced by following a similar training recipe as on the paper.""", }, ) DEFAULT = COCO_V1 @@ -434,6 +437,7 @@ class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum): "metrics": { "box_map": 22.8, }, + "_docs": """These weights were produced by following a similar training recipe as on the paper.""", }, ) DEFAULT = COCO_V1 @@ -454,7 +458,7 @@ def fasterrcnn_resnet50_fpn( ) -> FasterRCNN: """ Faster R-CNN model with a ResNet-50-FPN backbone from the `Faster R-CNN: Towards Real-Time Object - Detection with Region Proposal Networks `__ + Detection with Region Proposal Networks `__ paper. The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 3544ea3117e..4780a93f731 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -661,6 +661,7 @@ class FCOS_ResNet50_FPN_Weights(WeightsEnum): "metrics": { "box_map": 39.2, }, + "_docs": """These weights were produced by following a similar training recipe as on the paper.""", }, ) DEFAULT = COCO_V1 diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index 410c53d60b7..4932e21b474 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -326,6 +326,10 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): "box_map": 50.6, "kp_map": 61.1, }, + "_docs": """ + These weights were produced by following a similar training recipe as on the paper but use a checkpoint + from an early epoch. + """, }, ) COCO_V1 = Weights( @@ -339,6 +343,7 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): "box_map": 54.6, "kp_map": 65.0, }, + "_docs": """These weights were produced by following a similar training recipe as on the paper.""", }, ) DEFAULT = COCO_V1 diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index 6b1ba04a195..71450e287e4 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -368,6 +368,7 @@ class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum): "box_map": 37.9, "mask_map": 34.6, }, + "_docs": """These weights were produced by following a similar training recipe as on the paper.""", }, ) DEFAULT = COCO_V1 @@ -385,6 +386,7 @@ class MaskRCNN_ResNet50_FPN_V2_Weights(WeightsEnum): "box_map": 47.4, "mask_map": 41.8, }, + "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""", }, ) DEFAULT = COCO_V1 diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 0cb4979d332..0f44a482cde 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -690,6 +690,7 @@ class RetinaNet_ResNet50_FPN_Weights(WeightsEnum): "metrics": { "box_map": 36.4, }, + "_docs": """These weights were produced by following a similar training recipe as on the paper.""", }, ) DEFAULT = COCO_V1 @@ -706,6 +707,7 @@ class RetinaNet_ResNet50_FPN_V2_Weights(WeightsEnum): "metrics": { "box_map": 41.5, }, + "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""", }, ) DEFAULT = COCO_V1 diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index 7e5625329be..f0bd01a5879 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -37,6 +37,7 @@ class SSD300_VGG16_Weights(WeightsEnum): "metrics": { "box_map": 25.1, }, + "_docs": """These weights were produced by following a similar training recipe as on the paper.""", }, ) DEFAULT = COCO_V1 diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py index f94758cb166..7fdcb70b673 100644 --- a/torchvision/models/detection/ssdlite.py +++ b/torchvision/models/detection/ssdlite.py @@ -196,6 +196,7 @@ class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum): "metrics": { "box_map": 21.3, }, + "_docs": """These weights were produced by following a similar training recipe as on the paper.""", }, ) DEFAULT = COCO_V1 diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py index 869477f0d81..6c194d6e6ab 100644 --- a/torchvision/models/optical_flow/raft.py +++ b/torchvision/models/optical_flow/raft.py @@ -518,7 +518,7 @@ def forward(self, image1, image2, num_flow_updates: int = 12): class Raft_Large_Weights(WeightsEnum): C_T_V1 = Weights( - # Chairs + Things, ported from original paper repo (raft-things.pth) + # Weights ported from https://github.com/princeton-vl/RAFT url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth", transforms=OpticalFlow, meta={ @@ -531,11 +531,11 @@ class Raft_Large_Weights(WeightsEnum): "kitti_train_per_image_epe": 5.0172, "kitti_train_fl_all": 17.4506, }, + "_docs": """These weights were ported from the original paper. They are trained on Chairs + Things.""", }, ) C_T_V2 = Weights( - # Chairs + Things url="https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth", transforms=OpticalFlow, meta={ @@ -548,11 +548,12 @@ class Raft_Large_Weights(WeightsEnum): "kitti_train_per_image_epe": 4.5118, "kitti_train_fl_all": 16.0679, }, + "_docs": """These weights were trained from scratch on Chairs + Things.""", }, ) C_T_SKHT_V1 = Weights( - # Chairs + Things + Sintel fine-tuning, ported from original paper repo (raft-sintel.pth) + # Weights ported from https://github.com/princeton-vl/RAFT url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V1-0b8c9e55.pth", transforms=OpticalFlow, meta={ @@ -563,13 +564,14 @@ class Raft_Large_Weights(WeightsEnum): "sintel_test_cleanpass_epe": 1.94, "sintel_test_finalpass_epe": 3.18, }, + "_docs": """ + These weights were ported from the original paper. They are trained on Chairs + Things and fine-tuned on + Sintel (C+T+S+K+H). + """, }, ) C_T_SKHT_V2 = Weights( - # Chairs + Things + Sintel fine-tuning, i.e.: - # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) - # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth", transforms=OpticalFlow, meta={ @@ -580,11 +582,14 @@ class Raft_Large_Weights(WeightsEnum): "sintel_test_cleanpass_epe": 1.819, "sintel_test_finalpass_epe": 3.067, }, + "_docs": """ + These weights were trained from scratch on Chairs + Things and fine-tuned on Sintel (C+T+S+K+H). + """, }, ) C_T_SKHT_K_V1 = Weights( - # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning, ported from the original repo (sintel-kitti.pth) + # Weights ported from https://github.com/princeton-vl/RAFT url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V1-4a6a5039.pth", transforms=OpticalFlow, meta={ @@ -594,14 +599,14 @@ class Raft_Large_Weights(WeightsEnum): "metrics": { "kitti_test_fl_all": 5.10, }, + "_docs": """ + These weights were ported from the original paper. They are trained on Chairs + Things, fine-tuned on + Sintel and then on Kitti. + """, }, ) C_T_SKHT_K_V2 = Weights( - # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning i.e.: - # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) + Kitti - # Same as CT_SKHT with extra fine-tuning on Kitti - # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel and then on Kitti url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V2-b5c70766.pth", transforms=OpticalFlow, meta={ @@ -611,6 +616,9 @@ class Raft_Large_Weights(WeightsEnum): "metrics": { "kitti_test_fl_all": 5.19, }, + "_docs": """ + These weights were trained from scratch on Chairs + Things, fine-tuned on Sintel and then on Kitti. + """, }, ) @@ -619,7 +627,7 @@ class Raft_Large_Weights(WeightsEnum): class Raft_Small_Weights(WeightsEnum): C_T_V1 = Weights( - # Chairs + Things, ported from original paper repo (raft-small.pth) + # Weights ported from https://github.com/princeton-vl/RAFT url="https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth", transforms=OpticalFlow, meta={ @@ -632,10 +640,10 @@ class Raft_Small_Weights(WeightsEnum): "kitti_train_per_image_epe": 7.6557, "kitti_train_fl_all": 25.2801, }, + "_docs": """These weights were ported from the original paper. They are trained on Chairs + Things.""", }, ) C_T_V2 = Weights( - # Chairs + Things url="https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth", transforms=OpticalFlow, meta={ @@ -648,6 +656,7 @@ class Raft_Small_Weights(WeightsEnum): "kitti_train_per_image_epe": 7.5978, "kitti_train_fl_all": 25.2369, }, + "_docs": """These weights were trained from scratch on Chairs + Things.""", }, ) diff --git a/torchvision/models/quantization/googlenet.py b/torchvision/models/quantization/googlenet.py index 5df391044ff..c95e5ec0a9b 100644 --- a/torchvision/models/quantization/googlenet.py +++ b/torchvision/models/quantization/googlenet.py @@ -121,6 +121,10 @@ class GoogLeNet_QuantizedWeights(WeightsEnum): "acc@1": 69.826, "acc@5": 89.404, }, + "_docs": """ + These weights were produced by doing Post Training Quantization (eager mode) on top of the unquantized + weights listed below. + """, }, ) DEFAULT = IMAGENET1K_FBGEMM_V1 diff --git a/torchvision/models/quantization/inception.py b/torchvision/models/quantization/inception.py index 1fbfb00fe75..e535c32e3d8 100644 --- a/torchvision/models/quantization/inception.py +++ b/torchvision/models/quantization/inception.py @@ -187,6 +187,10 @@ class Inception_V3_QuantizedWeights(WeightsEnum): "acc@1": 77.176, "acc@5": 93.354, }, + "_docs": """ + These weights were produced by doing Post Training Quantization (eager mode) on top of the unquantized + weights listed below. + """, }, ) DEFAULT = IMAGENET1K_FBGEMM_V1 diff --git a/torchvision/models/quantization/mobilenetv2.py b/torchvision/models/quantization/mobilenetv2.py index 0d2e35c8566..5169609aeba 100644 --- a/torchvision/models/quantization/mobilenetv2.py +++ b/torchvision/models/quantization/mobilenetv2.py @@ -79,6 +79,10 @@ class MobileNet_V2_QuantizedWeights(WeightsEnum): "acc@1": 71.658, "acc@5": 90.150, }, + "_docs": """ + These weights were produced by doing Quantization Aware Training (eager mode) on top of the unquantized + weights listed below. + """, }, ) DEFAULT = IMAGENET1K_QNNPACK_V1 diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index 804e0c77bc9..1f3edb05f91 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -173,6 +173,10 @@ class MobileNet_V3_Large_QuantizedWeights(WeightsEnum): "acc@1": 73.004, "acc@5": 90.858, }, + "_docs": """ + These weights were produced by doing Quantization Aware Training (eager mode) on top of the unquantized + weights listed below. + """, }, ) DEFAULT = IMAGENET1K_QNNPACK_V1 diff --git a/torchvision/models/quantization/resnet.py b/torchvision/models/quantization/resnet.py index dc3ee4c35c5..70dd5d92f0e 100644 --- a/torchvision/models/quantization/resnet.py +++ b/torchvision/models/quantization/resnet.py @@ -154,6 +154,10 @@ def _resnet( "categories": _IMAGENET_CATEGORIES, "backend": "fbgemm", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", + "_docs": """ + These weights were produced by doing Post Training Quantization (eager mode) on top of the unquantized + weights listed below. + """, } diff --git a/torchvision/models/quantization/shufflenetv2.py b/torchvision/models/quantization/shufflenetv2.py index 5672d850cf2..523f8739b2e 100644 --- a/torchvision/models/quantization/shufflenetv2.py +++ b/torchvision/models/quantization/shufflenetv2.py @@ -118,6 +118,10 @@ def _shufflenetv2( "categories": _IMAGENET_CATEGORIES, "backend": "fbgemm", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", + "_docs": """ + These weights were produced by doing Post Training Quantization (eager mode) on top of the unquantized + weights listed below. + """, } diff --git a/torchvision/models/segmentation/deeplabv3.py b/torchvision/models/segmentation/deeplabv3.py index 57ccc377c2d..fc7cdb36cb0 100644 --- a/torchvision/models/segmentation/deeplabv3.py +++ b/torchvision/models/segmentation/deeplabv3.py @@ -131,6 +131,10 @@ def _deeplabv3_resnet( _COMMON_META = { "categories": _VOC_CATEGORIES, "min_size": (1, 1), + "_docs": """ + These weights were trained on a subset of COCO, using only the 20 categories that are present in the Pascal VOC + dataset. + """, } diff --git a/torchvision/models/segmentation/fcn.py b/torchvision/models/segmentation/fcn.py index fe8c82d8d4d..27c931bfc18 100644 --- a/torchvision/models/segmentation/fcn.py +++ b/torchvision/models/segmentation/fcn.py @@ -50,6 +50,10 @@ def __init__(self, in_channels: int, channels: int) -> None: _COMMON_META = { "categories": _VOC_CATEGORIES, "min_size": (1, 1), + "_docs": """ + These weights were trained on a subset of COCO, using only the 20 categories that are present in the Pascal VOC + dataset. + """, } diff --git a/torchvision/models/segmentation/lraspp.py b/torchvision/models/segmentation/lraspp.py index bcbba7f14fe..ed36d881ee9 100644 --- a/torchvision/models/segmentation/lraspp.py +++ b/torchvision/models/segmentation/lraspp.py @@ -106,6 +106,10 @@ class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum): "miou": 57.9, "pixel_acc": 91.2, }, + "_docs": """ + These weights were trained on a subset of COCO, using only the 20 categories that are present in the + Pascal VOC dataset. + """, }, ) DEFAULT = COCO_WITH_VOC_LABELS_V1 diff --git a/torchvision/models/video/resnet.py b/torchvision/models/video/resnet.py index 320df6576ac..ec45092f532 100644 --- a/torchvision/models/video/resnet.py +++ b/torchvision/models/video/resnet.py @@ -312,6 +312,7 @@ def _video_resnet( "min_size": (1, 1), "categories": _KINETICS400_CATEGORIES, "recipe": "https://github.com/pytorch/vision/tree/main/references/video_classification", + "_docs": """These weights reproduce closely the accuracy of the paper for 16-frame clip inputs.""", }