
Commit 0c6aabb

YosuaMichael authored and facebook-github-bot committed
[fbsync] New schema for metrics in weights meta-data (#6047)
Summary:

* Classif models
* Detection
* Segmentation
* Quantization
* Video
* Optical flow
* Tests
* Fix docs
* Fix Video dataset
* Consistency for RAFT dataset names
* Use ImageNet-1K
* Use COCO-val2017-VOC-labels for segmentation
* Formatting

Reviewed By: NicolasHug
Differential Revision: D36760921
fbshipit-source-id: 6efc259c6b8b922510a35be0ed1ca071e2a53a6f
1 parent 1301c87 · commit 0c6aabb

36 files changed: +792 −483 lines

docs/source/conf.py

Lines changed: 32 additions & 19 deletions
@@ -334,25 +334,22 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
         lines.append("")

         for field in obj:
-            lines += [f"**{str(field)}**:", ""]
-
-            table = []
-
-            # the `meta` dict contains another embedded `metrics` dict. To
-            # simplify the table generation below, we create the
-            # `meta_with_metrics` dict, where the metrics dict has been "flattened"
             meta = copy(field.meta)
-            metrics = meta.pop("metrics", {})
-            meta_with_metrics = dict(meta, **metrics)

-            lines += [meta_with_metrics.pop("_docs")]
+            lines += [f"**{str(field)}**:", ""]
+            lines += [meta.pop("_docs")]

             if field == obj.DEFAULT:
                 lines += [f"Also available as ``{obj.__name__}.DEFAULT``."]
-
             lines += [""]

-            for k, v in meta_with_metrics.items():
+            table = []
+            metrics = meta.pop("_metrics")
+            for dataset, dataset_metrics in metrics.items():
+                for metric_name, metric_value in dataset_metrics.items():
+                    table.append((f"{metric_name} (on {dataset})", str(metric_value)))
+
+            for k, v in meta.items():
                 if k in {"recipe", "license"}:
                     v = f"`link <{v}>`__"
                 elif k == "min_size":
@@ -374,7 +371,7 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
         lines.append("")


-def generate_weights_table(module, table_name, metrics, include_patterns=None, exclude_patterns=None):
+def generate_weights_table(module, table_name, metrics, dataset, include_patterns=None, exclude_patterns=None):
     weights_endswith = "_QuantizedWeights" if module.__name__.split(".")[-1] == "quantization" else "_Weights"
     weight_enums = [getattr(module, name) for name in dir(module) if name.endswith(weights_endswith)]
     weights = [w for weight_enum in weight_enums for w in weight_enum]
@@ -391,7 +388,7 @@ def generate_weights_table(module, table_name, metrics, include_patterns=None, e
     content = [
         (
             f":class:`{w} <{type(w).__name__}>`",
-            *(w.meta["metrics"][metric] for metric in metrics_keys),
+            *(w.meta["_metrics"][dataset][metric] for metric in metrics_keys),
             f"{w.meta['num_params']/1e6:.1f}M",
             f"`link <{w.meta['recipe']}>`__",
         )
@@ -408,29 +405,45 @@ def generate_weights_table(module, table_name, metrics, include_patterns=None, e
         table_file.write(f"{textwrap.indent(table, ' ' * 4)}\n\n")


-generate_weights_table(module=M, table_name="classification", metrics=[("acc@1", "Acc@1"), ("acc@5", "Acc@5")])
 generate_weights_table(
-    module=M.quantization, table_name="classification_quant", metrics=[("acc@1", "Acc@1"), ("acc@5", "Acc@5")]
+    module=M, table_name="classification", metrics=[("acc@1", "Acc@1"), ("acc@5", "Acc@5")], dataset="ImageNet-1K"
+)
+generate_weights_table(
+    module=M.quantization,
+    table_name="classification_quant",
+    metrics=[("acc@1", "Acc@1"), ("acc@5", "Acc@5")],
+    dataset="ImageNet-1K",
 )
 generate_weights_table(
-    module=M.detection, table_name="detection", metrics=[("box_map", "Box MAP")], exclude_patterns=["Mask", "Keypoint"]
+    module=M.detection,
+    table_name="detection",
+    metrics=[("box_map", "Box MAP")],
+    exclude_patterns=["Mask", "Keypoint"],
+    dataset="COCO-val2017",
 )
 generate_weights_table(
     module=M.detection,
     table_name="instance_segmentation",
     metrics=[("box_map", "Box MAP"), ("mask_map", "Mask MAP")],
+    dataset="COCO-val2017",
     include_patterns=["Mask"],
 )
 generate_weights_table(
     module=M.detection,
     table_name="detection_keypoint",
     metrics=[("box_map", "Box MAP"), ("kp_map", "Keypoint MAP")],
+    dataset="COCO-val2017",
     include_patterns=["Keypoint"],
 )
 generate_weights_table(
-    module=M.segmentation, table_name="segmentation", metrics=[("miou", "Mean IoU"), ("pixel_acc", "pixelwise Acc")]
+    module=M.segmentation,
+    table_name="segmentation",
+    metrics=[("miou", "Mean IoU"), ("pixel_acc", "pixelwise Acc")],
+    dataset="COCO-val2017-VOC-labels",
+)
+generate_weights_table(
+    module=M.video, table_name="video", metrics=[("acc@1", "Acc@1"), ("acc@5", "Acc@5")], dataset="Kinetics-400"
 )
-generate_weights_table(module=M.video, table_name="video", metrics=[("acc@1", "Acc@1"), ("acc@5", "Acc@5")])


 def setup(app):
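
For context, a minimal runnable sketch of the new unrolling logic above: the nested `_metrics` dict, keyed by dataset, becomes one table row per (dataset, metric) pair. The `meta` dict here is a hypothetical example in the new schema (the recipe URL is a placeholder), not a real torchvision entry:

from copy import copy

# Hypothetical weight meta entry following the new nested "_metrics" schema.
meta = {
    "num_params": 61100840,
    "recipe": "https://example.com/recipe",  # placeholder, not a real recipe
    "_metrics": {"ImageNet-1K": {"acc@1": 56.522, "acc@5": 79.066}},
}

m = copy(meta)

# Same unrolling as the new inject_weight_metadata: one table row per
# (dataset, metric) pair, with the dataset name folded into the row label.
table = []
for dataset, dataset_metrics in m.pop("_metrics").items():
    for metric_name, metric_value in dataset_metrics.items():
        table.append((f"{metric_name} (on {dataset})", str(metric_value)))

print(table)
# [('acc@1 (on ImageNet-1K)', '56.522'), ('acc@5 (on ImageNet-1K)', '79.066')]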

test/test_extended_models.py

Lines changed: 18 additions & 8 deletions
@@ -85,27 +85,31 @@ def test_schema_meta_validation(model_fn):
         "categories",
         "keypoint_names",
         "license",
-        "metrics",
+        "_metrics",
         "min_size",
         "num_params",
         "recipe",
         "unquantized",
         "_docs",
     }
     # mandatory fields for each computer vision task
-    classification_fields = {"categories", ("metrics", "acc@1"), ("metrics", "acc@5")}
+    classification_fields = {"categories", ("_metrics", "ImageNet-1K", "acc@1"), ("_metrics", "ImageNet-1K", "acc@5")}
     defaults = {
-        "all": {"metrics", "min_size", "num_params", "recipe", "_docs"},
+        "all": {"_metrics", "min_size", "num_params", "recipe", "_docs"},
         "models": classification_fields,
-        "detection": {"categories", ("metrics", "box_map")},
+        "detection": {"categories", ("_metrics", "COCO-val2017", "box_map")},
         "quantization": classification_fields | {"backend", "unquantized"},
-        "segmentation": {"categories", ("metrics", "miou"), ("metrics", "pixel_acc")},
-        "video": classification_fields,
+        "segmentation": {
+            "categories",
+            ("_metrics", "COCO-val2017-VOC-labels", "miou"),
+            ("_metrics", "COCO-val2017-VOC-labels", "pixel_acc"),
+        },
+        "video": {"categories", ("_metrics", "Kinetics-400", "acc@1"), ("_metrics", "Kinetics-400", "acc@5")},
         "optical_flow": set(),
     }
     model_name = model_fn.__name__
     module_name = model_fn.__module__.split(".")[-2]
-    fields = defaults["all"] | defaults[module_name]
+    expected_fields = defaults["all"] | defaults[module_name]

     weights_enum = _get_model_weights(model_fn)
     if len(weights_enum) == 0:
@@ -115,7 +119,13 @@ def test_schema_meta_validation(model_fn):
     incorrect_params = []
     bad_names = []
     for w in weights_enum:
-        missing_fields = fields - (set(w.meta.keys()) | set(("metrics", x) for x in w.meta.get("metrics", {}).keys()))
+        actual_fields = set(w.meta.keys())
+        actual_fields |= set(
+            ("_metrics", dataset, metric_key)
+            for dataset in w.meta.get("_metrics", {}).keys()
+            for metric_key in w.meta.get("_metrics", {}).get(dataset, {}).keys()
+        )
+        missing_fields = expected_fields - actual_fields
         unsupported_fields = set(w.meta.keys()) - permitted_fields
         if missing_fields or unsupported_fields:
             problematic_weights[w] = {"missing": missing_fields, "unsupported": unsupported_fields}
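
The key idea of the updated test: a nested metric is identified by a ("_metrics", dataset, metric) triple, so the missing-field check reduces to plain set subtraction. A small sketch with a hypothetical meta dict, not a real torchvision entry:

# Hypothetical meta dict in the new schema.
meta = {
    "categories": ["person", "bicycle"],
    "_metrics": {"COCO-val2017": {"box_map": 37.0}},
}

# Flatten nested metrics into ("_metrics", dataset, metric) triples,
# mirroring the updated test_schema_meta_validation logic.
actual_fields = set(meta.keys())
actual_fields |= {
    ("_metrics", dataset, metric_key)
    for dataset, dataset_metrics in meta["_metrics"].items()
    for metric_key in dataset_metrics
}

expected_fields = {"categories", ("_metrics", "COCO-val2017", "box_map")}
print(expected_fields - actual_fields)  # set() -> nothing is missing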

torchvision/models/alexnet.py

Lines changed: 5 additions & 3 deletions
@@ -61,9 +61,11 @@ class AlexNet_Weights(WeightsEnum):
             "min_size": (63, 63),
             "categories": _IMAGENET_CATEGORIES,
             "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
-            "metrics": {
-                "acc@1": 56.522,
-                "acc@5": 79.066,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 56.522,
+                    "acc@5": 79.066,
+                }
             },
             "_docs": """
                 These weights reproduce closely the results of the paper using a simplified training recipe.
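
For downstream readers of `meta`, the metrics now live one level deeper, keyed by the dataset they were measured on. A hedged before/after sketch, assuming a torchvision build that includes this commit:

from torchvision.models import AlexNet_Weights

w = AlexNet_Weights.IMAGENET1K_V1

# Before this commit: acc1 = w.meta["metrics"]["acc@1"]
# After this commit, the dataset key is explicit:
acc1 = w.meta["_metrics"]["ImageNet-1K"]["acc@1"]
print(acc1)  # 56.522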

torchvision/models/convnext.py

Lines changed: 20 additions & 12 deletions
@@ -222,9 +222,11 @@ class ConvNeXt_Tiny_Weights(WeightsEnum):
         meta={
             **_COMMON_META,
             "num_params": 28589128,
-            "metrics": {
-                "acc@1": 82.520,
-                "acc@5": 96.146,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 82.520,
+                    "acc@5": 96.146,
+                }
             },
         },
     )
@@ -238,9 +240,11 @@ class ConvNeXt_Small_Weights(WeightsEnum):
         meta={
             **_COMMON_META,
             "num_params": 50223688,
-            "metrics": {
-                "acc@1": 83.616,
-                "acc@5": 96.650,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 83.616,
+                    "acc@5": 96.650,
+                }
             },
         },
     )
@@ -254,9 +258,11 @@ class ConvNeXt_Base_Weights(WeightsEnum):
         meta={
             **_COMMON_META,
             "num_params": 88591464,
-            "metrics": {
-                "acc@1": 84.062,
-                "acc@5": 96.870,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 84.062,
+                    "acc@5": 96.870,
+                }
             },
         },
     )
@@ -270,9 +276,11 @@ class ConvNeXt_Large_Weights(WeightsEnum):
         meta={
             **_COMMON_META,
             "num_params": 197767336,
-            "metrics": {
-                "acc@1": 84.414,
-                "acc@5": 96.976,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 84.414,
+                    "acc@5": 96.976,
+                }
             },
         },
     )

torchvision/models/densenet.py

Lines changed: 20 additions & 12 deletions
@@ -272,9 +272,11 @@ class DenseNet121_Weights(WeightsEnum):
         meta={
             **_COMMON_META,
             "num_params": 7978856,
-            "metrics": {
-                "acc@1": 74.434,
-                "acc@5": 91.972,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 74.434,
+                    "acc@5": 91.972,
+                }
             },
         },
     )
@@ -288,9 +290,11 @@ class DenseNet161_Weights(WeightsEnum):
         meta={
             **_COMMON_META,
             "num_params": 28681000,
-            "metrics": {
-                "acc@1": 77.138,
-                "acc@5": 93.560,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 77.138,
+                    "acc@5": 93.560,
+                }
             },
         },
     )
@@ -304,9 +308,11 @@ class DenseNet169_Weights(WeightsEnum):
         meta={
             **_COMMON_META,
             "num_params": 14149480,
-            "metrics": {
-                "acc@1": 75.600,
-                "acc@5": 92.806,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 75.600,
+                    "acc@5": 92.806,
+                }
             },
         },
     )
@@ -320,9 +326,11 @@ class DenseNet201_Weights(WeightsEnum):
         meta={
             **_COMMON_META,
             "num_params": 20013928,
-            "metrics": {
-                "acc@1": 76.896,
-                "acc@5": 93.370,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 76.896,
+                    "acc@5": 93.370,
+                }
             },
         },
     )

torchvision/models/detection/faster_rcnn.py

Lines changed: 16 additions & 8 deletions
@@ -383,8 +383,10 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
             **_COMMON_META,
             "num_params": 41755286,
             "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
-            "metrics": {
-                "box_map": 37.0,
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 37.0,
+                }
             },
             "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
         },
@@ -400,8 +402,10 @@ class FasterRCNN_ResNet50_FPN_V2_Weights(WeightsEnum):
             **_COMMON_META,
             "num_params": 43712278,
             "recipe": "https://github.com/pytorch/vision/pull/5763",
-            "metrics": {
-                "box_map": 46.7,
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 46.7,
+                }
             },
             "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
         },
@@ -417,8 +421,10 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
             **_COMMON_META,
             "num_params": 19386354,
             "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
-            "metrics": {
-                "box_map": 32.8,
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 32.8,
+                }
             },
             "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
         },
@@ -434,8 +440,10 @@ class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
             **_COMMON_META,
             "num_params": 19386354,
             "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
-            "metrics": {
-                "box_map": 22.8,
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 22.8,
+                }
             },
             "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
         },

torchvision/models/detection/fcos.py

Lines changed: 4 additions & 2 deletions
@@ -658,8 +658,10 @@ class FCOS_ResNet50_FPN_Weights(WeightsEnum):
             "categories": _COCO_CATEGORIES,
             "min_size": (1, 1),
             "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#fcos-resnet-50-fpn",
-            "metrics": {
-                "box_map": 39.2,
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 39.2,
+                }
             },
             "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
         },

torchvision/models/detection/keypoint_rcnn.py

Lines changed: 10 additions & 6 deletions
@@ -322,9 +322,11 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
             **_COMMON_META,
             "num_params": 59137258,
             "recipe": "https://github.com/pytorch/vision/issues/1606",
-            "metrics": {
-                "box_map": 50.6,
-                "kp_map": 61.1,
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 50.6,
+                    "kp_map": 61.1,
+                }
             },
             "_docs": """
                 These weights were produced by following a similar training recipe as on the paper but use a checkpoint
@@ -339,9 +341,11 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
             **_COMMON_META,
             "num_params": 59137258,
             "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn",
-            "metrics": {
-                "box_map": 54.6,
-                "kp_map": 65.0,
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 54.6,
+                    "kp_map": 65.0,
+                }
             },
             "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
         },
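
Detection weights can carry several metrics under a single dataset key, as here with box_map and kp_map under COCO-val2017. A small hypothetical helper (not part of this commit) that walks the nested dict for every entry of a weights enum, assuming a build that includes this change:

from torchvision.models.detection import KeypointRCNN_ResNet50_FPN_Weights

def iter_metrics(weights_enum):
    # Yield (weight, dataset, metric, value) for every nested metric entry.
    for w in weights_enum:
        for dataset, dataset_metrics in w.meta["_metrics"].items():
            for name, value in dataset_metrics.items():
                yield w, dataset, name, value

for w, dataset, name, value in iter_metrics(KeypointRCNN_ResNet50_FPN_Weights):
    print(f"{w}: {name} (on {dataset}) = {value}")
# e.g. KeypointRCNN_ResNet50_FPN_Weights.COCO_V1: box_map (on COCO-val2017) = 54.6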
