diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index 01e56c7a108..65c85922e2a 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -349,17 +349,22 @@ def __init__(self, in_channels, dim_reduced, num_classes): # nn.init.constant_(param, 0) +_COMMON_META = { + "task": "image_object_detection", + "architecture": "MaskRCNN", + "categories": _COCO_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, +} + + class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth", transforms=ObjectDetection, meta={ - "task": "image_object_detection", - "architecture": "MaskRCNN", + **_COMMON_META, "publication_year": 2017, "num_params": 44401393, - "categories": _COCO_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#mask-r-cnn", "map": 37.9, "map_mask": 34.6, @@ -369,7 +374,19 @@ class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum): class MaskRCNN_ResNet50_FPN_V2_Weights(WeightsEnum): - pass + COCO_V1 = Weights( + url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_v2_coco-73cbd019.pth", + transforms=ObjectDetection, + meta={ + **_COMMON_META, + "publication_year": 2021, + "num_params": 46359409, + "recipe": "https://github.com/pytorch/vision/pull/5773", + "map": 47.4, + "map_mask": 41.8, + }, + ) + DEFAULT = COCO_V1 @handle_legacy_interface(