@@ -79,20 +79,32 @@ def test_naming_conventions(model_fn):
79
79
)
80
80
@run_if_test_with_extended
81
81
def test_schema_meta_validation (model_fn ):
82
- # TODO: add list of permitted fields
83
- classification_fields = ["categories" , "acc@1" , "acc@5" ]
82
+ # list of all possible supported high-level fields for weights meta-data
83
+ permitted_fields = {
84
+ "backend" ,
85
+ "categories" ,
86
+ "keypoint_names" ,
87
+ "license" ,
88
+ "metrics" ,
89
+ "min_size" ,
90
+ "num_params" ,
91
+ "recipe" ,
92
+ "unquantized" ,
93
+ }
94
+ # mandatory fields for each computer vision task
95
+ classification_fields = {"categories" , ("metrics" , "acc@1" ), ("metrics" , "acc@5" )}
84
96
defaults = {
85
- "all" : [ "recipe " , "num_params" , "min_size" ] ,
97
+ "all" : { "metrics " , "min_size" , " num_params" , "recipe" } ,
86
98
"models" : classification_fields ,
87
- "detection" : [ "categories" , "map" ] ,
88
- "quantization" : classification_fields + [ "backend" , "unquantized" ] ,
89
- "segmentation" : [ "categories" , "mIoU " , "acc" ] ,
99
+ "detection" : { "categories" , ( "metrics" , "box_map" )} ,
100
+ "quantization" : classification_fields | { "backend" , "unquantized" } ,
101
+ "segmentation" : { "categories" , ( "metrics " , "miou" ), ( "metrics" , "pixel_acc" )} ,
90
102
"video" : classification_fields ,
91
- "optical_flow" : [] ,
103
+ "optical_flow" : set () ,
92
104
}
93
105
model_name = model_fn .__name__
94
106
module_name = model_fn .__module__ .split ("." )[- 2 ]
95
- fields = set ( defaults ["all" ] + defaults [module_name ])
107
+ fields = defaults ["all" ] | defaults [module_name ]
96
108
97
109
weights_enum = _get_model_weights (model_fn )
98
110
if len (weights_enum ) == 0 :
@@ -102,9 +114,10 @@ def test_schema_meta_validation(model_fn):
102
114
incorrect_params = []
103
115
bad_names = []
104
116
for w in weights_enum :
105
- missing_fields = fields - set (w .meta .keys ())
106
- if missing_fields :
107
- problematic_weights [w ] = missing_fields
117
+ missing_fields = fields - (set (w .meta .keys ()) | set (("metrics" , x ) for x in w .meta .get ("metrics" , {}).keys ()))
118
+ unsupported_fields = set (w .meta .keys ()) - permitted_fields
119
+ if missing_fields or unsupported_fields :
120
+ problematic_weights [w ] = {"missing" : missing_fields , "unsupported" : unsupported_fields }
108
121
if w == weights_enum .DEFAULT :
109
122
if module_name == "quantization" :
110
123
# parameters() count doesn't work well with quantization, so we check against the non-quantized
0 commit comments