Skip to content

Commit 6274080

Browse files
authored
Document Keypoint RCNN separately (#5933)
* Document Keypoint RCNN separately * Move Keypoint detection into its own section * ufmt
1 parent a5c86ff commit 6274080

File tree

4 files changed

+73
-7
lines changed

4 files changed

+73
-7
lines changed

docs/source/conf.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -348,10 +348,15 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
348348
lines.append("")
349349

350350

351-
def generate_weights_table(module, table_name, metrics):
351+
def generate_weights_table(module, table_name, metrics, include_pattern=None, exclude_pattern=None):
352352
weight_enums = [getattr(module, name) for name in dir(module) if name.endswith("_Weights")]
353353
weights = [w for weight_enum in weight_enums for w in weight_enum]
354354

355+
if include_pattern is not None:
356+
weights = [w for w in weights if include_pattern in str(w)]
357+
if exclude_pattern is not None:
358+
weights = [w for w in weights if exclude_pattern not in str(w)]
359+
355360
metrics_keys, metrics_names = zip(*metrics)
356361
column_names = ["Weight"] + list(metrics_names) + ["Params", "Recipe"]
357362
column_names = [f"**{name}**" for name in column_names] # Add bold
@@ -377,7 +382,15 @@ def generate_weights_table(module, table_name, metrics):
377382

378383

379384
generate_weights_table(module=M, table_name="classification", metrics=[("acc@1", "Acc@1"), ("acc@5", "Acc@5")])
380-
generate_weights_table(module=M.detection, table_name="detection", metrics=[("box_map", "Box MAP")])
385+
generate_weights_table(
386+
module=M.detection, table_name="detection", metrics=[("box_map", "Box MAP")], exclude_pattern="Keypoint"
387+
)
388+
generate_weights_table(
389+
module=M.detection,
390+
table_name="detection_keypoint",
391+
metrics=[("box_map", "Box MAP"), ("kp_map", "Keypoint MAP")],
392+
include_pattern="Keypoint",
393+
)
381394
generate_weights_table(
382395
module=M.segmentation, table_name="segmentation", metrics=[("miou", "Mean IoU"), ("pixel_acc", "pixelwise Acc")]
383396
)

docs/source/models/keypoint_rcnn.rst

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
Keypoint R-CNN
2+
==============
3+
4+
.. currentmodule:: torchvision.models.detection
5+
6+
The Keypoint R-CNN model is based on the `Mask R-CNN
7+
<https://arxiv.org/abs/1703.06870>`__ paper.
8+
9+
10+
Model builders
11+
--------------
12+
13+
The following model builders can be used to instantiate a Keypoint R-CNN model,
14+
with or without pre-trained weights. All the model builders internally rely on
15+
the ``torchvision.models.detection.KeypointRCNN`` base class. Please refer to the `source
16+
code
17+
<https://github.com/pytorch/vision/blob/main/torchvision/models/detection/keypoint_rcnn.py>`__
18+
for more details about this class.
19+
20+
.. autosummary::
21+
:toctree: generated/
22+
:template: function.rst
23+
24+
keypointrcnn_resnet50_fpn

docs/source/models_new.rst

+23-2
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ All models are evaluated on COCO val2017:
8989

9090

9191

92-
Object Detection, Instance Segmentation and Person Keypoint Detection
93-
=====================================================================
92+
Object Detection
93+
================
9494

9595
.. currentmodule:: torchvision.models.detection
9696

@@ -114,6 +114,27 @@ Box MAPs are reported on COCO
114114
.. include:: generated/detection_table.rst
115115

116116

117+
Keypoint Detection
118+
==================
119+
120+
.. currentmodule:: torchvision.models.detection
121+
122+
The following keypoint detection models are available, with or without
123+
pre-trained weights:
124+
125+
.. toctree::
126+
:maxdepth: 1
127+
128+
models/keypoint_rcnn
129+
130+
Table of all available Keypoint detection weights
131+
-------------------------------------------------
132+
133+
Box and Keypoint MAPs are reported on COCO:
134+
135+
.. include:: generated/detection_keypoint_table.rst
136+
137+
117138
Video Classification
118139
====================
119140

torchvision/models/detection/keypoint_rcnn.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ def keypointrcnn_resnet50_fpn(
366366
"""
367367
Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone.
368368
369-
Reference: `"Mask R-CNN" <https://arxiv.org/abs/1703.06870>`_.
369+
Reference: `Mask R-CNN <https://arxiv.org/abs/1703.06870>`__.
370370
371371
The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
372372
image, and should be in ``0-1`` range. Different images can have different sizes.
@@ -410,14 +410,22 @@ def keypointrcnn_resnet50_fpn(
410410
>>> torch.onnx.export(model, x, "keypoint_rcnn.onnx", opset_version = 11)
411411
412412
Args:
413-
weights (KeypointRCNN_ResNet50_FPN_Weights, optional): The pretrained weights for the model
413+
weights (:class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`, optional): The
414+
pretrained weights to use. See
415+
:class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`
416+
below for more details and possible values. By default, no
417+
pre-trained weights are used.
414418
progress (bool): If True, displays a progress bar of the download to stderr
415419
num_classes (int, optional): number of output classes of the model (including the background)
416420
num_keypoints (int, optional): number of keypoints
417-
weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone
421+
weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
422+
pretrained weights for the backbone.
418423
trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
419424
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
420425
passed (the default) this value is set to 3.
426+
427+
.. autoclass:: torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights
428+
:members:
421429
"""
422430
weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
423431
weights_backbone = ResNet50_Weights.verify(weights_backbone)

0 commit comments

Comments
 (0)