@@ -1112,15 +1112,34 @@ def test_set_columns_metadata_schema(self):
1112
1112
table .set_columns (** table2 .asdict ())
1113
1113
assert table .metadata_schema == ms
1114
1114
1115
+ def verify_metadata_vector (self , table , key , dtype , default_value ):
1116
+ md_vec = table .metadata_vector (key , default_value = default_value , dtype = dtype )
1117
+ assert isinstance (md_vec , np .ndarray )
1118
+ if dtype is not None :
1119
+ assert md_vec .dtype == np .dtype (dtype )
1120
+ assert len (md_vec ) == table .num_rows
1121
+ if not isinstance (key , list ):
1122
+ key = [key ]
1123
+ for x , row in zip (md_vec , table ):
1124
+ md = row .metadata
1125
+ for k in key :
1126
+ if k in md :
1127
+ md = md [k ]
1128
+ else :
1129
+ md = default_value
1130
+ break
1131
+ assert np .all (np .cast [dtype ](md ) == x )
1132
+
1115
1133
def test_metadata_vector (self ):
1116
1134
table = self .table_class ()
1117
1135
ms = tskit .MetadataSchema ({"codec" : "json" })
1118
1136
table .metadata_schema = ms
1119
1137
metadata_list = [
1120
- {"a" : 4 },
1121
- {"a" : 10 },
1122
- {"a" : - 3 , "b" : {"c" : 1 }},
1123
- {"b" : {"c" : 3.2 }},
1138
+ {"a" : 4 , "u" : [1 , 2 ]},
1139
+ {"a" : 10 , "u" : [3 , 4 ]},
1140
+ {"a" : - 3 , "b" : {"c" : 1 }, "u" : [5 , 6 ]},
1141
+ {"b" : {"c" : 3.2 }, "u" : [7 , 8 ]},
1142
+ {"b" : {"x" : 8.2 }},
1124
1143
{},
1125
1144
]
1126
1145
for md in metadata_list :
@@ -1130,12 +1149,18 @@ def test_metadata_vector(self):
1130
1149
"metadata" : md ,
1131
1150
}
1132
1151
)
1133
- default_value = - 1
1134
- for key in ["a" , ["b" , "c" ]]:
1135
- assert np .equal (
1136
- [md .get (key , default_value ) for md in metadata_list ],
1137
- table .metadata_vector (key , default_value = default_value ),
1152
+ for dtype in [None , "int" , "float" , "object" ]:
1153
+ self .verify_metadata_vector (table , key = "a" , dtype = dtype , default_value = - 1 )
1154
+ self .verify_metadata_vector (table , key = ["a" ], dtype = dtype , default_value = - 1 )
1155
+ self .verify_metadata_vector (table , key = ["x" ], dtype = dtype , default_value = - 1 )
1156
+ self .verify_metadata_vector (
1157
+ table , key = ["b" , "c" ], dtype = dtype , default_value = - 1
1138
1158
)
1159
+ self .verify_metadata_vector (table , key = ["b" ], dtype = "object" , default_value = - 1 )
1160
+ self .verify_metadata_vector (table , key = ["u" ], dtype = "int" , default_value = [0 , 0 ])
1161
+ md_vec = table .metadata_vector ("u" , default_value = [0 , 0 ], dtype = "int" )
1162
+ # and finally we should get rectangular arrays when it makes sense
1163
+ assert md_vec .shape == (table .num_rows , 2 )
1139
1164
1140
1165
1141
1166
class AssertEqualsMixin :
0 commit comments