@@ -1112,6 +1112,150 @@ def test_set_columns_metadata_schema(self):
1112
1112
table .set_columns (** table2 .asdict ())
1113
1113
assert table .metadata_schema == ms
1114
1114
1115
+ def verify_metadata_vector (self , table , key , dtype , default_value = 9999 ):
1116
+ # this is just a hack for testing; the actual method
1117
+ # does this more elegantly
1118
+ has_default = default_value != 9999
1119
+ if has_default :
1120
+ md_vec = table .metadata_vector (
1121
+ key , default_value = default_value , dtype = dtype
1122
+ )
1123
+ else :
1124
+ md_vec = table .metadata_vector (key , dtype = dtype )
1125
+ assert isinstance (md_vec , np .ndarray )
1126
+ if dtype is not None :
1127
+ assert md_vec .dtype == np .dtype (dtype )
1128
+ assert len (md_vec ) == table .num_rows
1129
+ if not isinstance (key , list ):
1130
+ key = [key ]
1131
+ for x , row in zip (md_vec , table ):
1132
+ md = row .metadata
1133
+ for k in key :
1134
+ if k in md or not has_default :
1135
+ md = md [k ]
1136
+ else :
1137
+ md = default_value
1138
+ break
1139
+ assert np .all (np .cast [dtype ](md ) == x )
1140
+
1141
def test_metadata_vector_errors(self):
    """``metadata_vector`` without a default raises ``KeyError`` on a missing key.

    Three situations are exercised: a single row added with ``None``
    metadata, additional rows none of which contain the key, and (after
    clearing the table) a nested key path whose last component is absent.
    """
    table = self.table_class()
    table.metadata_schema = tskit.MetadataSchema({"codec": "json"})

    def append_row(metadata):
        # Fetch fresh row data for every insertion, as each row needs its own.
        table.add_row(**{**self.input_data_for_add_row(), "metadata": metadata})

    append_row(None)
    with pytest.raises(KeyError):
        table.metadata_vector("x")

    for metadata in ({"a": 4, "u": [1, 2]}, {}):
        append_row(metadata)
    with pytest.raises(KeyError):
        table.metadata_vector("x")

    table.clear()
    for metadata in ({"a": {"c": 5}, "u": [1, 2]}, {"a": {"b": 6}}):
        append_row(metadata)
    with pytest.raises(KeyError):
        table.metadata_vector(["a", "x"])
1181
+
1182
def test_metadata_vector_nodefault(self):
    """``metadata_vector`` with no default, where every row contains the key."""
    table = self.table_class()
    table.metadata_schema = tskit.MetadataSchema({"codec": "json"})
    metadata_list = [
        {"abc": 4, "u": [1, 2]},
        {"abc": 10, "u": [3, 4]},
        {"abc": -3, "b": {"c": 1}, "u": [5, 6]},
        {"abc": 1},
    ]
    for metadata in metadata_list:
        table.add_row(**{**self.input_data_for_add_row(), "metadata": metadata})

    # Direct comparison against the values we just inserted.
    expected = [metadata["abc"] for metadata in metadata_list]
    md_vec = table.metadata_vector("abc")
    assert np.all(np.equal(md_vec, expected))

    # Automated checks: 9999 is the helper's "no default" sentinel, and the
    # key is tried both as a bare string and as a one-element path.
    for dtype in [None, "int", "float", "object"]:
        for key in ("abc", ["abc"]):
            self.verify_metadata_vector(
                table, key=key, dtype=dtype, default_value=9999
            )
1210
+
1211
def test_metadata_vector(self):
    """``metadata_vector`` with a default, over rows with patchy metadata.

    The rows mix complete entries, entries missing the key, nested entries,
    an empty dict and ``None``.
    """
    table = self.table_class()
    table.metadata_schema = tskit.MetadataSchema({"codec": "json"})
    metadata_list = [
        {"abc": 4, "u": [1, 2]},
        {"abc": 10, "u": [3, 4]},
        {"abc": -3, "b": {"c": 1}, "u": [5, 6]},
        {"b": {"c": 3.2}, "u": [7, 8]},
        {"b": {"x": 8.2}},
        {},
        None,
    ]
    for metadata in metadata_list:
        table.add_row(**{**self.input_data_for_add_row(), "metadata": metadata})

    # Direct comparison: rows without "abc" (including the None row) give 0.
    expected = [(metadata or {}).get("abc", 0) for metadata in metadata_list]
    md_vec = table.metadata_vector("abc", default_value=0)
    assert np.all(np.equal(md_vec, expected))

    # Automated checks across dtypes and key forms.
    for dtype in [None, "int", "float", "object"]:
        for key in ("abc", ["abc"], ["x"], ["b", "c"]):
            self.verify_metadata_vector(
                table, key=key, dtype=dtype, default_value=-1
            )
    # These two are only checked for one fixed dtype each.
    self.verify_metadata_vector(table, key=["b"], dtype="object", default_value=-1)
    self.verify_metadata_vector(table, key=["u"], dtype="int", default_value=[0, 0])

    # Equal-length list values should come back as a rectangular array.
    md_vec = table.metadata_vector("u", default_value=[0, 0])
    assert md_vec.shape == (table.num_rows, 2)
1258
+
1115
1259
1116
1260
class AssertEqualsMixin :
1117
1261
def test_equal (self , table_5row , test_rows ):
0 commit comments