@@ -1112,6 +1112,135 @@ def test_set_columns_metadata_schema(self):
1112
1112
table .set_columns (** table2 .asdict ())
1113
1113
assert table .metadata_schema == ms
1114
1114
1115
+ def verify_metadata_vector (self , table , key , dtype , default_value = 9999 ):
1116
+ # this is just a hack for testing; the actual method
1117
+ # does this more elegantly
1118
+ has_default = default_value != 9999
1119
+ if has_default :
1120
+ md_vec = table .metadata_vector (
1121
+ key , default_value = default_value , dtype = dtype
1122
+ )
1123
+ else :
1124
+ md_vec = table .metadata_vector (key , dtype = dtype )
1125
+ assert isinstance (md_vec , np .ndarray )
1126
+ if dtype is not None :
1127
+ assert md_vec .dtype == np .dtype (dtype )
1128
+ assert len (md_vec ) == table .num_rows
1129
+ if not isinstance (key , list ):
1130
+ key = [key ]
1131
+ for x , row in zip (md_vec , table ):
1132
+ md = row .metadata
1133
+ for k in key :
1134
+ if k in md or not has_default :
1135
+ md = md [k ]
1136
+ else :
1137
+ md = default_value
1138
+ break
1139
+ assert np .all (np .cast [dtype ](md ) == x )
1140
+
1141
def test_metadata_vector_errors(self):
    """Asking for a key that no row provides must raise ``KeyError``."""
    table = self.table_class()
    table.metadata_schema = tskit.MetadataSchema({"codec": "json"})

    # Start with a single row whose metadata was supplied as None.
    table.add_row(**dict(self.input_data_for_add_row(), metadata=None))
    with pytest.raises(KeyError):
        table.metadata_vector("x")

    # Adding rows with unrelated keys, or with no keys, changes nothing.
    for extra_md in ({"a": 4, "u": [1, 2]}, {}):
        table.add_row(**dict(self.input_data_for_add_row(), metadata=extra_md))
    with pytest.raises(KeyError):
        table.metadata_vector("x")
1166
+
1167
def test_metadata_vector_nodefault(self):
    """``metadata_vector`` with no default, when every row has the key."""
    rows_md = [
        {"abc": 4, "u": [1, 2]},
        {"abc": 10, "u": [3, 4]},
        {"abc": -3, "b": {"c": 1}, "u": [5, 6]},
        {"abc": 1},
    ]
    table = self.table_class()
    table.metadata_schema = tskit.MetadataSchema({"codec": "json"})
    for entry in rows_md:
        table.add_row(**dict(self.input_data_for_add_row(), metadata=entry))

    # Direct comparison against values read straight from the input dicts.
    expected = [entry["abc"] for entry in rows_md]
    assert np.equal(table.metadata_vector("abc"), expected).all()

    # Helper-driven checks: the key is given both bare and as a
    # one-element path, and 9999 is the helper's "no default" marker.
    for dtype in (None, "int", "float", "object"):
        for key in ("abc", ["abc"]):
            self.verify_metadata_vector(
                table, key=key, dtype=dtype, default_value=9999
            )
1195
+
1196
def test_metadata_vector(self):
    """``metadata_vector`` with a default, over rows that may lack the key."""
    rows_md = [
        {"abc": 4, "u": [1, 2]},
        {"abc": 10, "u": [3, 4]},
        {"abc": -3, "b": {"c": 1}, "u": [5, 6]},
        {"b": {"c": 3.2}, "u": [7, 8]},
        {"b": {"x": 8.2}},
        {},
        None,
    ]
    table = self.table_class()
    table.metadata_schema = tskit.MetadataSchema({"codec": "json"})
    for entry in rows_md:
        table.add_row(**dict(self.input_data_for_add_row(), metadata=entry))

    # Direct comparison: rows without "abc" (or without metadata at all)
    # should come back as the supplied default of 0.
    expected = [
        0 if entry is None or "abc" not in entry else entry["abc"]
        for entry in rows_md
    ]
    assert np.equal(table.metadata_vector("abc", default_value=0), expected).all()

    # Helper-driven checks across dtypes: a bare key, the same key as a
    # path, a key no row has, and a nested path.
    for dtype in (None, "int", "float", "object"):
        for key in ("abc", ["abc"], ["x"], ["b", "c"]):
            self.verify_metadata_vector(
                table, key=key, dtype=dtype, default_value=-1
            )
    # Dict-valued and list-valued entries, each with one fixed dtype.
    self.verify_metadata_vector(table, key=["b"], dtype="object", default_value=-1)
    self.verify_metadata_vector(table, key=["u"], dtype="int", default_value=[0, 0])

    # Equal-length list values should stack into a rectangular array.
    u_vec = table.metadata_vector("u", default_value=[0, 0])
    assert u_vec.shape == (table.num_rows, 2)
1243
+
1115
1244
1116
1245
class AssertEqualsMixin :
1117
1246
def test_equal (self , table_5row , test_rows ):
0 commit comments