diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 1757ce3794..9f543894d5 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -9,6 +9,10 @@ **Features** +- Add fancy indexing to tables. E.g. ``table[6:86]`` returns a new table with the + specified rows. Supports slices, index arrays and boolean masks + (:user:`benjeffery`, :issue:`1221`, :pr:`1348`, :pr:`1342`). + - Add ``Table.append`` method for adding rows from classes such as ``SiteTableRow`` and ``Site`` (:user:`benjeffery`, :issue:`1111`, :pr:`1254`). diff --git a/python/tests/test_tables.py b/python/tests/test_tables.py index 5a266eeda5..2eff9d0517 100644 --- a/python/tests/test_tables.py +++ b/python/tests/test_tables.py @@ -115,6 +115,31 @@ def make_transposed_input_data(self, num_rows): for j in range(num_rows) ] + @pytest.fixture + def test_rows(self, scope="session"): + test_rows = self.make_transposed_input_data(10) + # Annoyingly we have to tweak some types as once added to a row and then put in + # an error message things come out differently + for n in range(10): + for col in test_rows[n].keys(): + if col in ["timestamp", "record", "ancestral_state", "derived_state"]: + test_rows[n][col] = bytes(test_rows[n][col]).decode("ascii") + return test_rows + + @pytest.fixture + def table(self, test_rows): + table = self.table_class() + for row in test_rows: + table.add_row(**row) + return table + + @pytest.fixture + def table_5row(self, test_rows): + table_5row = self.table_class() + for row in test_rows[:5]: + table_5row.add_row(**row) + return table_5row + def test_max_rows_increment(self): for bad_value in [-1, -(2 ** 10)]: with pytest.raises(ValueError): @@ -1031,52 +1056,34 @@ def test_set_columns_metadata_schema(self): class AssertEqualsMixin: - @pytest.fixture - def test_rows(self, scope="session"): - test_rows = self.make_transposed_input_data(10) - # Annoyingly we have to tweak some types as once added to a row and then put in - # an error message things come out differently - for n in range(10): - for col in test_rows[n].keys(): - if col in ["timestamp", "record", "ancestral_state", "derived_state"]: - test_rows[n][col] = bytes(test_rows[n][col]).decode("ascii") - return test_rows - - @pytest.fixture - def table1(self, test_rows): - table1 = self.table_class() - for row in test_rows[:5]: - table1.add_row(**row) - return table1 - - def test_equal(self, table1, test_rows): + def test_equal(self, table_5row, test_rows): table2 = self.table_class() for row in test_rows[:5]: table2.add_row(**row) - table1.assert_equals(table2) + table_5row.assert_equals(table2) - def test_type(self, table1): + def test_type(self, table_5row): with pytest.raises( AssertionError, - match=f"Types differ: self={type(table1)} other=", + match=f"Types differ: self={type(table_5row)} other=", ): - table1.assert_equals(42) + table_5row.assert_equals(42) - def test_metadata_schema(self, table1): - if hasattr(table1, "metadata_schema"): - assert table1.metadata_schema == tskit.MetadataSchema(None) - table2 = table1.copy() + def test_metadata_schema(self, table_5row): + if hasattr(table_5row, "metadata_schema"): + assert table_5row.metadata_schema == tskit.MetadataSchema(None) + table2 = table_5row.copy() table2.metadata_schema = tskit.MetadataSchema({"codec": "json"}) with pytest.raises( AssertionError, - match=f"{type(table1).__name__} metadata schemas differ: self=None " + match=f"{type(table_5row).__name__} metadata schemas differ: self=None " f"other=OrderedDict([('codec', " "'json')])", ): - table1.assert_equals(table2) - table1.assert_equals(table2, ignore_metadata=True) + table_5row.assert_equals(table2) + table_5row.assert_equals(table2, ignore_metadata=True) - def test_row_changes(self, table1, test_rows): + def test_row_changes(self, table_5row, test_rows): for column_name in test_rows[0].keys(): table2 = self.table_class() for row in test_rows[:4]: @@ -1089,16 +1096,16 @@ def test_row_changes(self, table1, test_rows): with pytest.raises( AssertionError, match=re.escape( - f"{type(table1).__name__} row 4 differs:\n" + f"{type(table_5row).__name__} row 4 differs:\n" f"self.{column_name}={test_rows[4][column_name]} " f"other.{column_name}={test_rows[5][column_name]}" ), ): - table1.assert_equals(table2) + table_5row.assert_equals(table2) if column_name == "metadata": - table1.assert_equals(table2, ignore_metadata=True) + table_5row.assert_equals(table2, ignore_metadata=True) if column_name == "timestamp": - table1.assert_equals(table2, ignore_timestamps=True) + table_5row.assert_equals(table2, ignore_timestamps=True) # Two columns differ, as we don't know the order in the error message # test for both independently @@ -1123,7 +1130,7 @@ def test_row_changes(self, table1, test_rows): f"other.{column_name}={test_rows[5][column_name]}" ), ): - table1.assert_equals(table2) + table_5row.assert_equals(table2) with pytest.raises( AssertionError, match=re.escape( @@ -1131,19 +1138,19 @@ def test_row_changes(self, table1, test_rows): f"other.{column_name2}={test_rows[5][column_name2]}" ), ): - table1.assert_equals(table2) + table_5row.assert_equals(table2) - def test_num_rows(self, table1, test_rows): + def test_num_rows(self, table_5row, test_rows): table2 = self.table_class() for row in test_rows[:4]: table2.add_row(**row) with pytest.raises( AssertionError, - match=f"{type(table1).__name__} number of rows differ: self=5 other=4", + match=f"{type(table_5row).__name__} number of rows differ: self=5 other=4", ): - table1.assert_equals(table2) + table_5row.assert_equals(table2) - def test_metadata(self, table1, test_rows): + def test_metadata(self, table_5row, test_rows): if "metadata" in test_rows[0].keys(): table2 = self.table_class() for row in test_rows[:4]: @@ -1156,15 +1163,15 @@ def test_metadata(self, table1, test_rows): with pytest.raises( AssertionError, match=re.escape( - f"{type(table1).__name__} row 4 differs:\n" + f"{type(table_5row).__name__} row 4 differs:\n" f"self.metadata={test_rows[4]['metadata']} " f"other.metadata={test_rows[5]['metadata']}" ), ): - table1.assert_equals(table2) - table1.assert_equals(table2, ignore_metadata=True) + table_5row.assert_equals(table2) + table_5row.assert_equals(table2, ignore_metadata=True) - def test_timestamp(self, table1, test_rows): + def test_timestamp(self, table_5row, test_rows): if "timestamp" in test_rows[0].keys(): table2 = self.table_class() for row in test_rows[:4]: @@ -1177,16 +1184,185 @@ def test_timestamp(self, table1, test_rows): with pytest.raises( AssertionError, match=re.escape( - f"{type(table1).__name__} row 4 differs:\n" + f"{type(table_5row).__name__} row 4 differs:\n" f"self.timestamp={test_rows[4]['timestamp']} " f"other.timestamp={test_rows[5]['timestamp']}" ), ): - table1.assert_equals(table2) - table1.assert_equals(table2, ignore_timestamps=True) + table_5row.assert_equals(table2) + table_5row.assert_equals(table2, ignore_timestamps=True) + + +class FancyIndexingMixin: + @pytest.mark.parametrize( + "slic", + [ + slice(None, None), + slice(None, 3), + slice(2, None), + slice(1, 4), + slice(4, 1), + slice(1, 4, 2), + slice(4, 1, 2), + slice(4, 1, -1), + slice(1, 4, -1), + slice(3, None, -1), + slice(None, 3, -1), + slice(None, None, -2), + ], + ) + def test_slice(self, table, test_rows, slic): + assert table.num_rows >= 5 + expected = table.copy() + expected.truncate(0) + for row in test_rows[slic]: + expected.add_row(**row) + table[slic].assert_equals(expected) + @pytest.mark.parametrize( + "mask", + [ + [False] * 5, + [True] * 5, + [True] + [False] * 4, + [False, True, False, True, True], + ], + ) + def test_boolean_array(self, table_5row, test_rows, mask): + assert table_5row.num_rows >= 5 + expected = table_5row.copy() + expected.truncate(0) + for flag, row in zip(mask, test_rows[:5]): + if flag: + expected.add_row(**row) + table_5row[mask].assert_equals(expected) + + @pytest.mark.parametrize( + "index_array", + [ + [], + [0], + [4], + random.choices(range(5), k=100), + np.array([0, 0, 0, 2], dtype=np.uint64), + np.array([2, 4, 4, 0], dtype=np.int64), + np.array([0, 0, 0, 2], dtype=np.uint32), + np.array([2, 4, 4, 0], dtype=np.int32), + np.array([4, 3, 4, 1], dtype=np.uint8), + np.array([4, 3, 4, 1], dtype=np.int8), + ], + ) + def test_index_array(self, table_5row, index_array): + assert table_5row.num_rows >= 5 + expected = table_5row.copy() + expected.truncate(0) + for index in index_array: + expected.append(table_5row[index]) + table_5row[index_array].assert_equals(expected) + table_5row[tuple(index_array)].assert_equals(expected) + + def test_index_range(self, table_5row): + expected = table_5row.copy() + expected.truncate(0) + for index in range(2, 4): + expected.append(table_5row[index]) + table_5row[range(2, 4)].assert_equals(expected) + + @pytest.mark.parametrize( + "dtype", + [ + np.float32, + np.float64, + object, + str, + ], + ) + def test_bad_dtypes(self, table, dtype): + with pytest.raises(TypeError): + table[np.zeros((10,), dtype=np.float32)] -class TestIndividualTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin): + @pytest.mark.parametrize( + "dtype", + [ + np.uint32, + np.int64, + np.uint64, + ], + ) + def test_bad_casts(self, table, dtype): + with pytest.raises(OverflowError, match="Cannot convert safely"): + table[np.asarray([np.iinfo(np.int32).max + 1], dtype=dtype)] + + def test_extrema(self, table): + max_ = np.iinfo(np.int32).max + with pytest.raises(OverflowError, match="Cannot convert safely"): + table[[max_ + 1]] + + # Slice gets clipped to valid range + copy = table.copy() + copy.clear() + table[max_ + 1 : max_ + 2].assert_equals(copy) + + with pytest.raises(OverflowError, match="Cannot convert safely"): + table[range(max_ + 1, max_ + 2)] + + @pytest.mark.parametrize( + "bad_shape", + [ + [[0]], + [[1, 2], [3, 4]], + ], + ) + def test_bad_shapes(self, table, bad_shape): + with pytest.raises(ValueError, match="object too deep"): + table[bad_shape] + + def test_bad_bool_length(self, table): + with pytest.raises( + IndexError, match="Boolean index must be same length as table" + ): + table[[False] * (len(table) + 1)] + with pytest.raises( + IndexError, match="Boolean index must be same length as table" + ): + table[[False]] + + def test_bad_indexes(self, table): + with pytest.raises(_tskit.LibraryError, match="out of bounds"): + table[[-1]] + with pytest.raises(_tskit.LibraryError, match="out of bounds"): + table[range(-5, 0)] + with pytest.raises(_tskit.LibraryError, match="out of bounds"): + table[[len(table)]] + with pytest.raises(TypeError, match="Cannot cast"): + table[[5.5]] + with pytest.raises(TypeError, match="Cannot convert"): + table[[None]] + with pytest.raises(TypeError, match="not supported between instances"): + table[["foobar"]] + with pytest.raises(TypeError, match="Index must be integer, slice or iterable"): + table[5.5] + with pytest.raises(TypeError, match="Cannot convert to a rectangular array"): + table[None] + with pytest.raises(TypeError, match="not supported between instances"): + table["foobar"] + + def test_not_writable(self, table): + with pytest.raises(TypeError, match="object does not support item assignment"): + table[5] = 5 + with pytest.raises(TypeError, match="object does not support item assignment"): + table[[5]] = 5 + + +common_tests = [ + CommonTestsMixin, + MetadataTestsMixin, + AssertEqualsMixin, + FancyIndexingMixin, +] + + +class TestIndividualTable(*common_tests): columns = [UInt32Column("flags")] ragged_list_columns = [ (DoubleColumn("location"), UInt32Column("location_offset")), @@ -1315,7 +1491,7 @@ def test_various_not_equals(self): assert a == b -class TestNodeTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin): +class TestNodeTable(*common_tests): columns = [ UInt32Column("flags"), @@ -1396,7 +1572,7 @@ def test_add_row_bad_data(self): t.add_row(metadata=123) -class TestEdgeTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin): +class TestEdgeTable(*common_tests): columns = [ DoubleColumn("left"), @@ -1444,7 +1620,7 @@ def test_add_row_bad_data(self): t.add_row(0, 0, 0, 0, metadata=123) -class TestSiteTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin): +class TestSiteTable(*common_tests): columns = [DoubleColumn("position")] ragged_list_columns = [ (CharColumn("ancestral_state"), UInt32Column("ancestral_state_offset")), @@ -1499,7 +1675,7 @@ def test_packset_ancestral_state(self): assert np.array_equal(table.ancestral_state_offset, ancestral_state_offset) -class TestMutationTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin): +class TestMutationTable(*common_tests): columns = [ Int32Column("site"), Int32Column("node"), @@ -1571,7 +1747,7 @@ def test_packset_derived_state(self): assert np.array_equal(table.derived_state_offset, derived_state_offset) -class TestMigrationTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin): +class TestMigrationTable(*common_tests): columns = [ DoubleColumn("left"), DoubleColumn("right"), @@ -1673,7 +1849,7 @@ def test_packset_record(self): assert t[1].record == "BBBB" -class TestPopulationTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin): +class TestPopulationTable(*common_tests): metadata_mandatory = True columns = [] ragged_list_columns = [(CharColumn("metadata"), UInt32Column("metadata_offset"))] diff --git a/python/tskit/tables.py b/python/tskit/tables.py index 2c82fb5577..32ffb68d8e 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -27,6 +27,7 @@ import datetime import itertools import json +import numbers import sys import warnings from dataclasses import dataclass @@ -335,16 +336,42 @@ def __setattr__(self, name, value): def __getitem__(self, index): """ - Return the specified row of this table, decoding metadata if it is present. - Supports negative indexing, e.g. ``table[-5]``. + If passed an integer, return the specified row of this table, decoding metadata + if it is present. Supports negative indexing, e.g. ``table[-5]``. + If passed a slice, iterable or array return a new table containing the specified + rows. Similar to numpy fancy indexing, if the array or iterables contains + booleans then the index acts as a mask, returning those rows for which the mask + is True. + + :param index: the zero-index of a desired row, a slice of the desired rows, an + iterable or array of the desired row numbers, or a boolean array to use as + a mask. + """ + + if isinstance(index, numbers.Integral): + # Single row by integer + if index < 0: + index += len(self) + if index < 0 or index >= len(self): + raise IndexError("Index out of bounds") + return self.row_class(*self.ll_table.get_row(index)) + elif isinstance(index, numbers.Number): + raise TypeError("Index must be integer, slice or iterable") + elif isinstance(index, slice): + index = range(*index.indices(len(self))) + else: + index = np.asarray(index) + if index.dtype == np.bool_: + if len(index) != len(self): + raise IndexError("Boolean index must be same length as table") + index = np.flatnonzero(index) + index = util.safe_np_int_cast(index, np.int32) - :param int index: the zero-index of the desired row - """ - if index < 0: - index += len(self) - if index < 0 or index >= len(self): - raise IndexError("Index out of bounds") - return self.row_class(*self.ll_table.get_row(index)) + ret = self.__class__() + ret.metadata_schema = self.metadata_schema + ret.ll_table.extend(self.ll_table, row_indexes=index) + + return ret def append(self, row): """