@@ -334,25 +334,22 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
334
334
lines .append ("" )
335
335
336
336
for field in obj :
337
- lines += [f"**{ str (field )} **:" , "" ]
338
-
339
- table = []
340
-
341
- # the `meta` dict contains another embedded `metrics` dict. To
342
- # simplify the table generation below, we create the
343
- # `meta_with_metrics` dict, where the metrics dict has been "flattened"
344
337
meta = copy (field .meta )
345
- metrics = meta .pop ("metrics" , {})
346
- meta_with_metrics = dict (meta , ** metrics )
347
338
348
- lines += [meta_with_metrics .pop ("_docs" )]
339
+ lines += [f"**{ str (field )} **:" , "" ]
340
+ lines += [meta .pop ("_docs" )]
349
341
350
342
if field == obj .DEFAULT :
351
343
lines += [f"Also available as ``{ obj .__name__ } .DEFAULT``." ]
352
-
353
344
lines += ["" ]
354
345
355
- for k , v in meta_with_metrics .items ():
346
+ table = []
347
+ metrics = meta .pop ("_metrics" )
348
+ for dataset , dataset_metrics in metrics .items ():
349
+ for metric_name , metric_value in dataset_metrics .items ():
350
+ table .append ((f"{ metric_name } (on { dataset } )" , str (metric_value )))
351
+
352
+ for k , v in meta .items ():
356
353
if k in {"recipe" , "license" }:
357
354
v = f"`link <{ v } >`__"
358
355
elif k == "min_size" :
@@ -374,7 +371,7 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
374
371
lines .append ("" )
375
372
376
373
377
- def generate_weights_table (module , table_name , metrics , include_patterns = None , exclude_patterns = None ):
374
+ def generate_weights_table (module , table_name , metrics , dataset , include_patterns = None , exclude_patterns = None ):
378
375
weights_endswith = "_QuantizedWeights" if module .__name__ .split ("." )[- 1 ] == "quantization" else "_Weights"
379
376
weight_enums = [getattr (module , name ) for name in dir (module ) if name .endswith (weights_endswith )]
380
377
weights = [w for weight_enum in weight_enums for w in weight_enum ]
@@ -391,7 +388,7 @@ def generate_weights_table(module, table_name, metrics, include_patterns=None, e
391
388
content = [
392
389
(
393
390
f":class:`{ w } <{ type (w ).__name__ } >`" ,
394
- * (w .meta ["metrics" ][metric ] for metric in metrics_keys ),
391
+ * (w .meta ["_metrics" ][ dataset ][metric ] for metric in metrics_keys ),
395
392
f"{ w .meta ['num_params' ]/ 1e6 :.1f} M" ,
396
393
f"`link <{ w .meta ['recipe' ]} >`__" ,
397
394
)
@@ -408,29 +405,45 @@ def generate_weights_table(module, table_name, metrics, include_patterns=None, e
408
405
table_file .write (f"{ textwrap .indent (table , ' ' * 4 )} \n \n " )
409
406
410
407
411
- generate_weights_table (module = M , table_name = "classification" , metrics = [("acc@1" , "Acc@1" ), ("acc@5" , "Acc@5" )])
412
408
generate_weights_table (
413
- module = M .quantization , table_name = "classification_quant" , metrics = [("acc@1" , "Acc@1" ), ("acc@5" , "Acc@5" )]
409
+ module = M , table_name = "classification" , metrics = [("acc@1" , "Acc@1" ), ("acc@5" , "Acc@5" )], dataset = "ImageNet-1K"
410
+ )
411
+ generate_weights_table (
412
+ module = M .quantization ,
413
+ table_name = "classification_quant" ,
414
+ metrics = [("acc@1" , "Acc@1" ), ("acc@5" , "Acc@5" )],
415
+ dataset = "ImageNet-1K" ,
414
416
)
415
417
generate_weights_table (
416
- module = M .detection , table_name = "detection" , metrics = [("box_map" , "Box MAP" )], exclude_patterns = ["Mask" , "Keypoint" ]
418
+ module = M .detection ,
419
+ table_name = "detection" ,
420
+ metrics = [("box_map" , "Box MAP" )],
421
+ exclude_patterns = ["Mask" , "Keypoint" ],
422
+ dataset = "COCO-val2017" ,
417
423
)
418
424
generate_weights_table (
419
425
module = M .detection ,
420
426
table_name = "instance_segmentation" ,
421
427
metrics = [("box_map" , "Box MAP" ), ("mask_map" , "Mask MAP" )],
428
+ dataset = "COCO-val2017" ,
422
429
include_patterns = ["Mask" ],
423
430
)
424
431
generate_weights_table (
425
432
module = M .detection ,
426
433
table_name = "detection_keypoint" ,
427
434
metrics = [("box_map" , "Box MAP" ), ("kp_map" , "Keypoint MAP" )],
435
+ dataset = "COCO-val2017" ,
428
436
include_patterns = ["Keypoint" ],
429
437
)
430
438
generate_weights_table (
431
- module = M .segmentation , table_name = "segmentation" , metrics = [("miou" , "Mean IoU" ), ("pixel_acc" , "pixelwise Acc" )]
439
+ module = M .segmentation ,
440
+ table_name = "segmentation" ,
441
+ metrics = [("miou" , "Mean IoU" ), ("pixel_acc" , "pixelwise Acc" )],
442
+ dataset = "COCO-val2017-VOC-labels" ,
443
+ )
444
+ generate_weights_table (
445
+ module = M .video , table_name = "video" , metrics = [("acc@1" , "Acc@1" ), ("acc@5" , "Acc@5" )], dataset = "Kinetics-400"
432
446
)
433
- generate_weights_table (module = M .video , table_name = "video" , metrics = [("acc@1" , "Acc@1" ), ("acc@5" , "Acc@5" )])
434
447
435
448
436
449
def setup (app ):
0 commit comments