diff --git a/docs/examples.py b/docs/examples.py index b46bec6b8d..dee4a55c58 100644 --- a/docs/examples.py +++ b/docs/examples.py @@ -312,7 +312,7 @@ def write_table(tree): write_table(tree) print(tree.draw_text()) - tree.draw_svg("_static/tree_structure1.svg", tree_height_scale="rank") + tree.draw_svg("_static/tree_structure1.svg", time_scale="rank") edges = """\ left right parent child @@ -327,7 +327,7 @@ def write_table(tree): write_table(tree) print(tree.draw_text()) - tree.draw_svg("_static/tree_structure2.svg", tree_height_scale="rank") + tree.draw_svg("_static/tree_structure2.svg", time_scale="rank") def tree_traversal(): @@ -396,7 +396,7 @@ def finding_nearest_neighbors(): ) tree = ts.first() - tree.draw_svg("_static/different_time_samples.svg", tree_height_scale="rank") + tree.draw_svg("_static/different_time_samples.svg", time_scale="rank") # moving_along_tree_sequence() diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index e943aacd46..1757ce3794 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -34,6 +34,12 @@ - Add ``Table.assert_equals`` and ``TableCollection.assert_equals`` which give an exact report of any differences. (:user:`benjeffery`,:issue:`1076`, :pr:`1328`) +**Changes** + +- In drawing methods ``max_tree_height`` and ``tree_height_scale`` have been deprecated + in favour of ``max_time`` and ``time_scale`` + (:user:`benjeffery`,:issue:`1262`, :pr:`1331`). + **Fixes** - Tree sequences were not properly init'd after unpickling diff --git a/python/tests/test_drawing.py b/python/tests/test_drawing.py index 782dee2f31..a500bdc70c 100644 --- a/python/tests/test_drawing.py +++ b/python/tests/test_drawing.py @@ -590,9 +590,15 @@ def test_unused_args(self): with pytest.raises(ValueError): t.draw(format=self.drawing_format, node_colours={}) with pytest.raises(ValueError): - t.draw(format=self.drawing_format, max_tree_height=1234) + t.draw(format=self.drawing_format, max_time=1234) with pytest.raises(ValueError): - t.draw(format=self.drawing_format, tree_height_scale="time") + with pytest.warns(FutureWarning): + t.draw(format=self.drawing_format, max_tree_height=1234) + with pytest.raises(ValueError): + t.draw(format=self.drawing_format, time_scale="time") + with pytest.raises(ValueError): + with pytest.warns(FutureWarning): + t.draw(format=self.drawing_format, tree_height_scale="time") class TestDrawUnicode(TestDrawText): @@ -1381,7 +1387,7 @@ def test_tree_sequence_non_minlex(self): ) self.verify_text_rendering(ts.draw_text(order="tree"), drawn_tree) - def test_max_tree_height(self): + def test_max_time(self): ts = self.get_simple_ts() tree = ( " 9 \n" @@ -1399,7 +1405,7 @@ def test_max_tree_height(self): "0 1 2 3\n" ) t = ts.first() - self.verify_text_rendering(t.draw_text(max_tree_height="ts"), tree) + self.verify_text_rendering(t.draw_text(max_time="ts"), tree) tree = ( " 9 \n" @@ -1411,10 +1417,10 @@ def test_max_tree_height(self): "0 1 2 3\n" ) t = ts.first() - self.verify_text_rendering(t.draw_text(max_tree_height="tree"), tree) - for bad_max_tree_height in [1, "sdfr", ""]: + self.verify_text_rendering(t.draw_text(max_time="tree"), tree) + for bad_max_time in [1, "sdfr", ""]: with pytest.raises(ValueError): - t.draw_text(max_tree_height=bad_max_tree_height) + t.draw_text(max_time=bad_max_time) class TestDrawSvg(TestTreeDraw, xmlunittest.XmlTestMixin): @@ -1509,7 +1515,7 @@ def test_bad_tick_spacing(self): def test_no_mixed_yscales(self): ts = self.get_simple_ts() with pytest.raises(ValueError, match="varying yscales"): - ts.draw_svg(y_axis=True, max_tree_height="tree") + ts.draw_svg(y_axis=True, max_time="tree") def test_draw_defaults(self): t = self.get_binary_tree() @@ -1521,7 +1527,7 @@ def test_draw_defaults(self): @pytest.mark.parametrize("y_axis", (True, False)) @pytest.mark.parametrize("y_label", (True, False)) @pytest.mark.parametrize( - "tree_height_scale", + "time_scale", ( "rank", "time", @@ -1530,7 +1536,7 @@ def test_draw_defaults(self): @pytest.mark.parametrize("y_ticks", ([], [0, 1], None)) @pytest.mark.parametrize("y_gridlines", (True, False)) def test_draw_svg_y_axis_parameter_combos( - self, y_axis, y_label, tree_height_scale, y_ticks, y_gridlines + self, y_axis, y_label, time_scale, y_ticks, y_gridlines ): t = self.get_binary_tree() svg = t.draw_svg( @@ -1538,7 +1544,7 @@ def test_draw_svg_y_axis_parameter_combos( y_label=y_label, y_ticks=y_ticks, y_gridlines=y_gridlines, - tree_height_scale=tree_height_scale, + time_scale=time_scale, ) self.verify_basic_svg(svg) ts = self.get_simple_ts() @@ -1547,7 +1553,7 @@ def test_draw_svg_y_axis_parameter_combos( y_label=y_label, y_ticks=y_ticks, y_gridlines=y_gridlines, - tree_height_scale=tree_height_scale, + time_scale=time_scale, ) self.verify_basic_svg(svg, width=200 * ts.num_trees) @@ -1680,46 +1686,57 @@ def test_one_mutation_label_colour(self): self.verify_basic_svg(svg) assert svg.count(f'stroke="{colour}"') == 1 - def test_bad_tree_height_scale(self): + def test_bad_time_scale(self): t = self.get_binary_tree() for bad_scale in ["te", "asdf", "", [], b"23"]: with pytest.raises(ValueError): - t.draw_svg(tree_height_scale=bad_scale) + t.draw_svg(time_scale=bad_scale) + with pytest.raises(ValueError): + with pytest.warns(FutureWarning): + t.draw_svg(tree_height_scale=bad_scale) - def test_bad_max_tree_height(self): + def test_bad_max_time(self): t = self.get_binary_tree() for bad_height in ["te", "asdf", "", [], b"23"]: with pytest.raises(ValueError): - t.draw_svg(max_tree_height=bad_height) + t.draw_svg(max_time=bad_height) + with pytest.raises(ValueError): + with pytest.warns(FutureWarning): + t.draw_svg(max_tree_height=bad_height) - def test_height_scale_time_and_max_tree_height(self): + def test_time_scale_time_and_max_time(self): ts = msprime.simulate(5, recombination_rate=2, random_seed=2) t = ts.first() # The default should be the same as tree. - svg1 = t.draw_svg(max_tree_height="tree") + svg1 = t.draw_svg(max_time="tree") self.verify_basic_svg(svg1) svg2 = t.draw_svg() assert svg1 == svg2 - svg3 = t.draw_svg(max_tree_height="ts") + svg3 = t.draw_svg(max_time="ts") assert svg1 != svg3 - svg4 = t.draw_svg(max_tree_height=max(ts.tables.nodes.time)) + svg4 = t.draw_svg(max_time=max(ts.tables.nodes.time)) assert svg3 == svg4 - - def test_height_scale_rank_and_max_tree_height(self): - # Make sure the rank height scale and max_tree_height interact properly. + with pytest.warns(FutureWarning): + svg5 = t.draw_svg(max_tree_height="tree") + assert svg5 == svg1 + svg6 = t.draw_svg(max_time="tree", max_tree_height="i should be ignored") + assert svg6 == svg1 + + def test_time_scale_rank_and_max_time(self): + # Make sure the rank height scale and max_time interact properly. ts = msprime.simulate(5, recombination_rate=2, random_seed=2) t = ts.first() # The default should be the same as tree. - svg1 = t.draw_svg(max_tree_height="tree", tree_height_scale="rank") + svg1 = t.draw_svg(max_time="tree", time_scale="rank") self.verify_basic_svg(svg1) - svg2 = t.draw_svg(tree_height_scale="rank") + svg2 = t.draw_svg(time_scale="rank") assert svg1 == svg2 - svg3 = t.draw_svg(max_tree_height="ts", tree_height_scale="rank") + svg3 = t.draw_svg(max_time="ts", time_scale="rank") assert svg1 != svg3 self.verify_basic_svg(svg3) - # Numeric max tree height not supported for rank scale. + # Numeric max time not supported for rank scale. with pytest.raises(ValueError): - t.draw_svg(max_tree_height=2, tree_height_scale="rank") + t.draw_svg(max_time=2, time_scale="rank") # # TODO: update the tests below here to check the new SVG based interface. @@ -1826,7 +1843,7 @@ def test_extra_mutations(self, all_muts, x_axis): svg_no_css = svg[svg.find("") :] assert svg_no_css.count("extra") == 1 * extra_mut_copies - def test_max_tree_height(self): + def test_max_time(self): nodes = io.StringIO( """\ id is_sample time @@ -1862,8 +1879,11 @@ def test_max_tree_height(self): snippet2 = svg2[svg2.rfind("edge", 0, str_pos) : str_pos] assert snippet1 != snippet2 - svg1 = ts.at_index(0).draw(max_tree_height="ts") - svg2 = ts.at_index(1).draw(max_tree_height="ts") + svg1 = ts.at_index(0).draw(max_time="ts") + svg2 = ts.at_index(1).draw(max_time="ts") + with pytest.warns(FutureWarning): + svg3 = ts.at_index(1).draw(max_tree_height="ts") + assert svg3 == svg2 # when scaled, node 3 should be at the *same* height in both trees, so the edge # definition should be the same self.verify_basic_svg(svg1) @@ -1913,7 +1933,10 @@ def test_draw_integer_breaks_ts(self): def test_draw_even_height_ts(self): ts = msprime.simulate(5, recombination_rate=1, random_seed=1) - svg = ts.draw_svg(max_tree_height="tree") + svg = ts.draw_svg(max_time="tree") + self.verify_basic_svg(svg, width=200 * ts.num_trees) + with pytest.warns(FutureWarning): + svg = ts.draw_svg(max_tree_height="tree") self.verify_basic_svg(svg, width=200 * ts.num_trees) def test_draw_sized_ts(self): @@ -1921,17 +1944,25 @@ def test_draw_sized_ts(self): svg = ts.draw_svg(size=(600, 400)) self.verify_basic_svg(svg, width=600, height=400) - def test_tree_height_scale(self): + def test_time_scale(self): ts = msprime.simulate(4, random_seed=2) - svg = ts.draw_svg(tree_height_scale="time") + svg = ts.draw_svg(time_scale="time") self.verify_basic_svg(svg) - svg = ts.draw_svg(tree_height_scale="log_time") + svg = ts.draw_svg(time_scale="log_time") self.verify_basic_svg(svg) - svg = ts.draw_svg(tree_height_scale="rank") + with pytest.warns(FutureWarning): + svg2 = ts.draw_svg(tree_height_scale="log_time") + assert svg2 == svg + svg = ts.draw_svg(time_scale="rank") self.verify_basic_svg(svg) + svg3 = ts.draw_svg(time_scale="rank", tree_height_scale="ignore me please") + assert svg3 == svg for bad_scale in [0, "", "NOT A SCALE"]: with pytest.raises(ValueError): - ts.draw_svg(tree_height_scale=bad_scale) + ts.draw_svg(time_scale=bad_scale) + with pytest.raises(ValueError): + with pytest.warns(FutureWarning): + ts.draw_svg(tree_height_scale=bad_scale) def test_x_scale(self): ts = msprime.simulate(4, random_seed=2) @@ -1963,7 +1994,15 @@ def test_y_axis(self): ("log_time", "Time"), ("rank", "Node time"), ]: - svg = tree.draw_svg(y_axis=True, tree_height_scale=hscale) + svg = tree.draw_svg(y_axis=True, time_scale=hscale) + if hscale is not None: + with pytest.warns(FutureWarning): + svg2 = tree.draw_svg(y_axis=True, tree_height_scale=hscale) + assert svg2 == svg + svg3 = tree.draw_svg( + y_axis=True, time_scale=hscale, tree_height_scale="ignore me please" + ) + assert svg3 == svg svg_no_css = svg[svg.find("") :] assert label in svg_no_css assert svg_no_css.count("axes") == 1 @@ -2012,11 +2051,9 @@ def test_no_edges_show_empty(self): tables = full_ts.dump_tables() tables.edges.clear() ts = tables.tree_sequence() - for tree_height_scale in ("time", "log_time", "rank"): + for time_scale in ("time", "log_time", "rank"): # SVG should just be a row of 10 sample nodes - svg = ts.draw_svg( - tree_height_scale=tree_height_scale, x_lim=[0, ts.sequence_length] - ) + svg = ts.draw_svg(time_scale=time_scale, x_lim=[0, ts.sequence_length]) self.verify_basic_svg(svg) assert svg.count("rect") == 10 # Sample nodes are rectangles assert svg.count('path class="edge"') == 0 @@ -2199,7 +2236,7 @@ def test_known_svg_tree_y_axis_rank(self, overwrite_viz, draw_plotbox): y_axis=True, y_label=label, y_gridlines=True, - tree_height_scale="rank", + time_scale="rank", style=".y-axis line.grid {stroke: #CCCCCC}", debug_box=draw_plotbox, ) @@ -2292,13 +2329,9 @@ def test_known_svg_ts_highlighted_mut(self, overwrite_viz, draw_plotbox): def test_known_svg_ts_rank(self, overwrite_viz, draw_plotbox): ts = self.get_simple_ts() - svg1 = ts.draw_svg( - tree_height_scale="rank", y_axis=True, debug_box=draw_plotbox - ) + svg1 = ts.draw_svg(time_scale="rank", y_axis=True, debug_box=draw_plotbox) ts = self.get_simple_ts(use_mutation_times=True) - svg2 = ts.draw_svg( - tree_height_scale="rank", y_axis=True, debug_box=draw_plotbox - ) + svg2 = ts.draw_svg(time_scale="rank", y_axis=True, debug_box=draw_plotbox) assert svg1.count('class="site ') == ts.num_sites assert svg1.count('class="mut ') == ts.num_mutations * 2 assert svg1.replace(" unknown_time", "") == svg2 # Trim the unknown_time class @@ -2309,7 +2342,7 @@ def test_known_svg_ts_rank(self, overwrite_viz, draw_plotbox): @pytest.mark.skip(reason="Fails on CI as OSX gives different random numbers") def test_known_svg_nonbinary_ts(self, overwrite_viz, draw_plotbox): ts = self.get_nonbinary_ts() - svg = ts.draw_svg(tree_height_scale="log_time", debug_box=draw_plotbox) + svg = ts.draw_svg(time_scale="log_time", debug_box=draw_plotbox) assert svg.count('class="site ') == ts.num_sites assert svg.count('class="mut ') == ts.num_mutations * 2 self.verify_known_svg( @@ -2395,7 +2428,7 @@ def test_known_svg_ts_y_axis_log(self, overwrite_viz, draw_plotbox): svg = ts.draw_svg( y_axis=True, y_label="Time (log scale)", - tree_height_scale="log_time", + time_scale="log_time", debug_box=draw_plotbox, ) self.verify_known_svg( @@ -2413,7 +2446,7 @@ def test_known_svg_ts_mutation_times(self, overwrite_viz, draw_plotbox): def test_known_svg_ts_mutation_times_logscale(self, overwrite_viz, draw_plotbox): ts = self.get_simple_ts(use_mutation_times=True) - svg = ts.draw_svg(tree_height_scale="log_time", debug_box=draw_plotbox) + svg = ts.draw_svg(time_scale="log_time", debug_box=draw_plotbox) assert svg.count('class="site ') == ts.num_sites assert svg.count('class="mut ') == ts.num_mutations * 2 self.verify_known_svg( diff --git a/python/tests/tsutil.py b/python/tests/tsutil.py index 34c9b7e76e..63cf8213a4 100644 --- a/python/tests/tsutil.py +++ b/python/tests/tsutil.py @@ -138,7 +138,7 @@ def insert_discrete_time_mutations(ts, num_times=4, num_sites=10): """ Inserts mutations in the tree sequence at regularly-spaced num_sites positions, at only a discrete set of times (the same for all trees): at - num_times times evenly spaced between 0 and the maximum tree height. + num_times times evenly spaced between 0 and the maximum time. """ tables = ts.tables tables.sites.clear() diff --git a/python/tskit/drawing.py b/python/tskit/drawing.py index d0f5aed71b..f6c6d56741 100644 --- a/python/tskit/drawing.py +++ b/python/tskit/drawing.py @@ -29,6 +29,7 @@ import math import numbers import operator +import warnings from dataclasses import dataclass from typing import List from typing import Mapping @@ -62,25 +63,23 @@ def check_orientation(orientation): return orientation -def check_max_tree_height(max_tree_height, allow_numeric=True): - if max_tree_height is None: - max_tree_height = "tree" - is_numeric = isinstance(max_tree_height, numbers.Real) - if max_tree_height not in ["tree", "ts"] and not allow_numeric: - raise ValueError("max_tree_height must be 'tree' or 'ts'") - if max_tree_height not in ["tree", "ts"] and (allow_numeric and not is_numeric): - raise ValueError( - "max_tree_height must be a numeric value or one of 'tree' or 'ts'" - ) - return max_tree_height +def check_max_time(max_time, allow_numeric=True): + if max_time is None: + max_time = "tree" + is_numeric = isinstance(max_time, numbers.Real) + if max_time not in ["tree", "ts"] and not allow_numeric: + raise ValueError("max_time must be 'tree' or 'ts'") + if max_time not in ["tree", "ts"] and (allow_numeric and not is_numeric): + raise ValueError("max_time must be a numeric value or one of 'tree' or 'ts'") + return max_time -def check_tree_height_scale(tree_height_scale): - if tree_height_scale is None: - tree_height_scale = "time" - if tree_height_scale not in ["time", "log_time", "rank"]: - raise ValueError("tree_height_scale must be 'time', 'log_time' or 'rank'") - return tree_height_scale +def check_time_scale(time_scale): + if time_scale is None: + time_scale = "time" + if time_scale not in ["time", "log_time", "rank"]: + raise ValueError("time_scale must be 'time', 'log_time' or 'rank'") + return time_scale def check_format(format): # noqa A002 @@ -263,10 +262,26 @@ def draw_tree( mutation_colours=None, format=None, # noqa A002 edge_colours=None, + time_scale=None, tree_height_scale=None, + max_time=None, max_tree_height=None, order=None, ): + if time_scale is None and tree_height_scale is not None: + time_scale = tree_height_scale + # Deprecated in 0.3.6 + warnings.warn( + "tree_height_scale is deprecated; use time_scale instead", + FutureWarning, + ) + if max_time is None and max_tree_height is not None: + max_time = max_tree_height + # Deprecated in 0.3.6 + warnings.warn( + "max_tree_height is deprecated; use max_time instead", + FutureWarning, + ) # See tree.draw() for documentation on these arguments. fmt = check_format(format) @@ -300,8 +315,8 @@ def remap_style(original_map, new_key, none_value): (width, height), node_labels=node_labels, mutation_labels=mutation_labels, - tree_height_scale=tree_height_scale, - max_tree_height=max_tree_height, + time_scale=time_scale, + max_time=max_time, node_attrs=node_attrs, edge_attrs=edge_attrs, node_label_attrs=node_label_attrs, @@ -323,14 +338,14 @@ def remap_style(original_map, new_key, none_value): raise ValueError("Text trees do not support node_colours") if edge_colours is not None: raise ValueError("Text trees do not support edge_colours") - if tree_height_scale is not None: - raise ValueError("Text trees do not support tree_height_scale") + if time_scale is not None: + raise ValueError("Text trees do not support time_scale") use_ascii = fmt == "ascii" text_tree = VerticalTextTree( tree, node_labels=node_labels, - max_tree_height=max_tree_height, + max_time=max_time, use_ascii=use_ascii, orientation=TOP, order=order, @@ -460,7 +475,7 @@ def __init__( root_svg_attributes, style, svg_class, - tree_height_scale, + time_scale, x_axis=None, y_axis=None, x_label=None, @@ -486,13 +501,13 @@ def __init__( self.root_groups = {} self.debug_box = debug_box self.drawing = dwg - self.tree_height_scale = check_tree_height_scale(tree_height_scale) + self.time_scale = check_time_scale(time_scale) self.y_axis = y_axis self.x_axis = x_axis if x_label is None and x_axis: x_label = "Genome position" if y_label is None and y_axis: - if tree_height_scale == "rank": + if time_scale == "rank": y_label = "Node time" else: y_label = "Time" @@ -731,9 +746,7 @@ def x_transform(self, x): ) def y_transform(self, y): - raise NotImplementedError( - "No transform func defined for tree height -> plot pos" - ) + raise NotImplementedError("No transform func defined for time -> plot pos") class SvgTreeSequence(SvgPlot): @@ -748,7 +761,7 @@ def __init__( ts, size, x_scale, - tree_height_scale, + time_scale, node_labels, mutation_labels, root_svg_attributes, @@ -763,21 +776,37 @@ def __init__( y_ticks, y_gridlines, x_lim=None, - max_tree_height=None, + max_time=None, node_attrs=None, mutation_attrs=None, edge_attrs=None, node_label_attrs=None, mutation_label_attrs=None, + tree_height_scale=None, + max_tree_height=None, **kwargs, ): + if max_time is None and max_tree_height is not None: + max_time = max_tree_height + # Deprecated in 0.3.6 + warnings.warn( + "max_tree_height is deprecated; use max_time instead", + FutureWarning, + ) + if time_scale is None and tree_height_scale is not None: + time_scale = tree_height_scale + # Deprecated in 0.3.6 + warnings.warn( + "tree_height_scale is deprecated; use time_scale instead", + FutureWarning, + ) x_lim = check_x_lim(x_lim, max_x=ts.sequence_length) ts, self.tree_status = clip_ts(ts, x_lim[0], x_lim[1]) num_trees = int(np.sum((self.tree_status & OMIT) != OMIT)) if size is None: size = (200 * num_trees, 200) - if max_tree_height is None: - max_tree_height = "ts" + if max_time is None: + max_time = "ts" # X axis shown by default if x_axis is None: x_axis = True @@ -787,7 +816,7 @@ def __init__( root_svg_attributes, style, svg_class="tree-sequence", - tree_height_scale=tree_height_scale, + time_scale=time_scale, x_axis=x_axis, y_axis=y_axis, x_label=x_label, @@ -809,13 +838,13 @@ def __init__( SvgTree( tree, (self.plotbox.width / num_trees, self.plotbox.height), - tree_height_scale=tree_height_scale, + time_scale=time_scale, node_labels=node_labels, mutation_labels=mutation_labels, order=order, force_root_branch=force_root_branch, symbol_size=symbol_size, - max_tree_height=max_tree_height, + max_time=max_time, node_attrs=node_attrs, mutation_attrs=mutation_attrs, edge_attrs=edge_attrs, @@ -846,7 +875,7 @@ def __init__( y_low = self.y_transform(0) # if poss use zero point for lowest axis value if y_ticks is None: y_ticks = np.unique(ts.tables.nodes.time) - if self.tree_height_scale == "rank": + if self.time_scale == "rank": # Ticks labelled by time not rank y_ticks = {pos: f"{val:.2f}" for pos, val in enumerate(y_ticks)} @@ -940,6 +969,7 @@ def __init__( self, tree, size=None, + max_time=None, max_tree_height=None, node_labels=None, mutation_labels=None, @@ -955,6 +985,7 @@ def __init__( y_ticks=None, y_gridlines=None, all_edge_mutations=None, + time_scale=None, tree_height_scale=None, node_attrs=None, mutation_attrs=None, @@ -963,6 +994,20 @@ def __init__( mutation_label_attrs=None, **kwargs, ): + if max_time is None and max_tree_height is not None: + max_time = max_tree_height + # Deprecated in 0.3.6 + warnings.warn( + "max_tree_height is deprecated; use max_time instead", + FutureWarning, + ) + if time_scale is None and tree_height_scale is not None: + time_scale = tree_height_scale + # Deprecated in 0.3.6 + warnings.warn( + "tree_height_scale is deprecated; use time_scale instead", + FutureWarning, + ) if size is None: size = (200, 200) if symbol_size is None: @@ -975,7 +1020,7 @@ def __init__( root_svg_attributes, style, svg_class=f"tree t{tree.index}", - tree_height_scale=tree_height_scale, + time_scale=time_scale, x_axis=x_axis, y_axis=y_axis, x_label=x_label, @@ -1096,7 +1141,7 @@ def __init__( add_class(self.mutation_label_attrs[m], "lab") self.set_spacing(top=10, left=20, bottom=15, right=20) - self.assign_y_coordinates(max_tree_height, force_root_branch) + self.assign_y_coordinates(max_time, force_root_branch) self.assign_x_coordinates() tick_length_lower = self.default_tick_length # TODO - parameterize tick_length_upper = self.default_tick_length_site # TODO - parameterize @@ -1152,7 +1197,7 @@ def process_mutations_over_node(self, u, low_bound, high_bound, ignore_times=Fal def assign_y_coordinates( self, - max_tree_height, + max_time, force_root_branch, bottom_space=SvgPlot.line_height, top_space=SvgPlot.line_height, @@ -1163,14 +1208,12 @@ def assign_y_coordinates( the plotbox, at the bottom for leaf labels, and (potentially, if no root branches are plotted) above the topmost root node for root labels. """ - max_tree_height = check_max_tree_height( - max_tree_height, self.tree_height_scale != "rank" - ) + max_time = check_max_time(max_time, self.time_scale != "rank") node_time = self.ts.tables.nodes.time mut_time = self.ts.tables.mutations.time root_branch_length = 0 - if self.tree_height_scale == "rank": - if max_tree_height == "tree": + if self.time_scale == "rank": + if max_time == "tree": # We only rank the times within the tree in this case. t = np.zeros_like(node_time) for u in self.tree.nodes(): @@ -1181,10 +1224,10 @@ def assign_y_coordinates( depth = {t: j for j, t in enumerate(times)} if self.mutations_over_roots or force_root_branch: root_branch_length = 1 # Will get scaled later - max_tree_height = max(depth.values()) + root_branch_length + max_time = max(depth.values()) + root_branch_length # In pathological cases, all the roots are at 0 - if max_tree_height == 0: - max_tree_height = 1 + if max_time == 0: + max_time = 1 self.node_height = {u: depth[node_time[u]] for u in self.tree.nodes()} for u in self.node_mutations.keys(): parent = self.tree.parent(u) @@ -1196,9 +1239,9 @@ def assign_y_coordinates( u, self.node_height[u], top, ignore_times=True ) else: - assert self.tree_height_scale in ["time", "log_time"] + assert self.time_scale in ["time", "log_time"] self.node_height = {u: node_time[u] for u in self.tree.nodes()} - if max_tree_height == "tree": + if max_time == "tree": max_node_height = max(self.node_height.values()) max_mut_height = np.nanmax( [0] + [mut.time for m in self.node_mutations.values() for mut in m] @@ -1206,25 +1249,25 @@ def assign_y_coordinates( else: max_node_height = self.ts.max_root_time max_mut_height = np.nanmax(np.append(mut_time, 0)) - max_tree_height = max(max_node_height, max_mut_height) # Reuse variable + max_time = max(max_node_height, max_mut_height) # Reuse variable # In pathological cases, all the roots are at 0 - if max_tree_height == 0: - max_tree_height = 1 + if max_time == 0: + max_time = 1 if self.mutations_over_roots or force_root_branch: # Define a minimum root branch length, after transformation if necessary - if self.tree_height_scale != "log_time": - root_branch_length = max_tree_height * self.root_branch_fraction + if self.time_scale != "log_time": + root_branch_length = max_time * self.root_branch_fraction else: - log_height = np.log(max_tree_height + 1) + log_height = np.log(max_time + 1) root_branch_length = ( np.exp(log_height * (1 + self.root_branch_fraction)) - 1 - - max_tree_height + - max_time ) - # If necessary, allow for this extra branch in max_tree_height - if max_node_height + root_branch_length > max_tree_height: - max_tree_height = max_node_height + root_branch_length + # If necessary, allow for this extra branch in max_time + if max_node_height + root_branch_length > max_time: + max_time = max_node_height + root_branch_length for u in self.node_mutations.keys(): parent = self.tree.parent(u) if parent == NULL: @@ -1234,7 +1277,7 @@ def assign_y_coordinates( top = self.node_height[parent] self.process_mutations_over_node(u, self.node_height[u], top) - assert float(max_tree_height) == max_tree_height + assert float(max_time) == max_time # Add extra space above the top and below the bottom of the tree to keep the # node labels within the plotbox (but top label space not needed if the @@ -1245,20 +1288,20 @@ def assign_y_coordinates( if padding_numerator < 0: raise ValueError("Image size too small to allow space to plot tree") # Transform the y values into plot space (inverted y with 0 at the top of screen) - if self.tree_height_scale == "log_time": + if self.time_scale == "log_time": # add 1 so that don't reach log(0) = -inf error. # just shifts entire timeset by 1 unit so shouldn't affect anything - y_scale = padding_numerator / np.log(max_tree_height + 1) + y_scale = padding_numerator / np.log(max_time + 1) self.y_transform = lambda y: zero_pos - np.log(y + 1) * y_scale else: - y_scale = padding_numerator / max_tree_height + y_scale = padding_numerator / max_time self.y_transform = lambda y: zero_pos - y * y_scale # Calculate default root branch length to use (in plot coords). This is a # minimum, as branches with deep root mutations could be longer self.min_root_branch_plot_length = self.y_transform( - max_tree_height - ) - self.y_transform(max_tree_height + root_branch_length) + max_time + ) - self.y_transform(max_time + root_branch_length) def assign_x_coordinates(self): num_leaves = len(list(self.tree.leaves())) @@ -1326,7 +1369,7 @@ def draw_tree(self): tree = self.tree left_child = get_left_child(tree, self.traversal_order) - # Iterate over nodes, adding groups to reflect the tree heirarchy + # Iterate over nodes, adding groups to reflect the tree hierarchy stack = [] for u in tree.roots: grp = dwg.g( @@ -1461,7 +1504,7 @@ def __init__( trees = [ VerticalTextTree( tree, - max_tree_height="ts", + max_time="ts", node_labels=node_labels, use_ascii=use_ascii, order=order, @@ -1564,7 +1607,7 @@ def get_left_child(tree, traversal_order): return left_child -def node_time_depth(tree, min_branch_length=None, max_tree_height="tree"): +def node_time_depth(tree, min_branch_length=None, max_time="tree"): """ Returns a dictionary mapping nodes in the specified tree to their depth in the specified tree (from the root direction). If min_branch_len is @@ -1578,7 +1621,7 @@ def node_time_depth(tree, min_branch_length=None, max_tree_height="tree"): depth = {} # TODO this is basically the same code for the two cases. Refactor so that # we use the same code. - if max_tree_height == "tree": + if max_time == "tree": for u in tree.nodes(): time_node_map[tree.time(u)].append(u) for t in sorted(time_node_map.keys()): @@ -1591,7 +1634,7 @@ def node_time_depth(tree, min_branch_length=None, max_tree_height="tree"): for root in tree.roots: current_depth = max(current_depth, depth[root] + min_branch_length[root]) else: - assert max_tree_height == "ts" + assert max_time == "ts" ts = tree.tree_sequence for node in ts.nodes(): time_node_map[node.time].append(node.id) @@ -1621,16 +1664,14 @@ def __init__( self, tree, node_labels=None, - max_tree_height=None, + max_time=None, use_ascii=False, orientation=None, order=None, ): self.tree = tree self.traversal_order = check_order(order) - self.max_tree_height = check_max_tree_height( - max_tree_height, allow_numeric=False - ) + self.max_time = check_max_time(max_time, allow_numeric=False) self.use_ascii = use_ascii self.orientation = check_orientation(orientation) self.horizontal_line_char = "━" @@ -1689,9 +1730,7 @@ def _assign_time_positions(self): # TODO when we add mutations to the text tree we'll need to take it into # account here. Presumably we need to get the maximum number of mutations # per branch. - self.time_position, total_depth = node_time_depth( - tree, max_tree_height=self.max_tree_height - ) + self.time_position, total_depth = node_time_depth(tree, max_time=self.max_time) self.height = total_depth - 1 def _assign_traversal_positions(self): diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 5d2763f172..1f55bf2444 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -1531,7 +1531,9 @@ def draw_svg( path=None, *, size=None, + time_scale=None, tree_height_scale=None, + max_time=None, max_tree_height=None, node_labels=None, mutation_labels=None, @@ -1659,19 +1661,23 @@ def draw_svg( produced SVG drawing in abstract user units (usually interpreted as pixels on initial display). :type size: tuple(int, int) - :param str tree_height_scale: Control how height values for nodes are computed. + :param str time_scale: Control how height values for nodes are computed. If this is equal to ``"time"`` (the default), node heights are proportional to their time values. If this is equal to ``"log_time"``, node heights are proportional to their log(time) values. If it is equal to ``"rank"``, node heights are spaced equally according to their ranked times. - :param str,float max_tree_height: The maximum tree height value in the current - scaling system (see ``tree_height_scale``). Can be either a string or a - numeric value. If equal to ``"tree"`` (the default), the maximum tree height + :param str tree_height_scale: Deprecated alias for time_scale. (Deprecated in + 0.3.6) + :param str,float max_time: The maximum time value in the current + scaling system (see ``time_scale``). Can be either a string or a + numeric value. If equal to ``"tree"`` (the default), the maximum time is set to be that of the oldest root in the tree. If equal to ``"ts"`` the - maximum height is set to be the height of the oldest root in the tree + maximum time is set to be the time of the oldest root in the tree sequence; this is useful when drawing trees from the same tree sequence as it ensures that node heights are consistent. If a numeric value, this is used as - the maximum tree height by which to scale other nodes. + the maximum time by which to scale other nodes. + :param str,float max_time: Deprecated alias for max_tree_height. (Deprecated in + 0.3.6) :param node_labels: If specified, show custom labels for the nodes (specified by ID) that are present in this map; any nodes not present will not have a label. @@ -1701,7 +1707,7 @@ def draw_svg( :param str x_label: Place a label under the plot. If ``None`` (default) and there is an X axis, create and place an appropriate label. :param bool y_axis: Should the plot have an Y axis line, showing time (or - ranked node time if ``tree_height_scale="rank"``). If ``None`` (default) + ranked node time if ``time_scale="rank"``). If ``None`` (default) do not plot a Y axis. :param str y_label: Place a label to the left of the plot. If ``None`` (default) and there is a Y axis, create and place an appropriate label. @@ -1725,7 +1731,9 @@ def draw_svg( draw = drawing.SvgTree( self, size, + time_scale=time_scale, tree_height_scale=tree_height_scale, + max_time=max_time, max_tree_height=max_tree_height, node_labels=node_labels, mutation_labels=mutation_labels, @@ -1760,7 +1768,9 @@ def draw( mutation_colours=None, format=None, # noqa A002 edge_colours=None, + time_scale=None, tree_height_scale=None, + max_time=None, max_tree_height=None, order=None, ): @@ -1834,22 +1844,26 @@ def draw( joining each node in the map to its parent. As for ``node_colours``, unspecified edges take the default colour, and ``None`` values result in the edge being omitted. (Only supported in the SVG format.) - :param str tree_height_scale: Control how height values for nodes are computed. + :param str time_scale: Control how height values for nodes are computed. If this is equal to ``"time"``, node heights are proportional to their time values. If this is equal to ``"log_time"``, node heights are proportional to their log(time) values. If it is equal to ``"rank"``, node heights are spaced equally according to their ranked times. For SVG output the default is 'time'-scale whereas for text output the default is 'rank'-scale. Time scaling is not currently supported for text output. - :param str,float max_tree_height: The maximum tree height value in the current - scaling system (see ``tree_height_scale``). Can be either a string or a - numeric value. If equal to ``"tree"``, the maximum tree height is set to be + :param str tree_height_scale: Deprecated alias for time_scale. (Deprecated in + 0.3.6) + :param str,float max_time: The maximum time value in the current + scaling system (see ``time_scale``). Can be either a string or a + numeric value. If equal to ``"tree"``, the maximum time is set to be that of the oldest root in the tree. If equal to ``"ts"`` the maximum - height is set to be the height of the oldest root in the tree sequence; + time is set to be the time of the oldest root in the tree sequence; this is useful when drawing trees from the same tree sequence as it ensures that node heights are consistent. If a numeric value, this is used as the - maximum tree height by which to scale other nodes. This parameter + maximum time by which to scale other nodes. This parameter is not currently supported for text output. + :param str max_tree_height: Deprecated alias for max_time. (Deprecated in + 0.3.6) :param str order: The left-to-right ordering of child nodes in the drawn tree. This can be either: ``"minlex"``, which minimises the differences between adjacent trees (see also the ``"minlex_postorder"`` traversal @@ -1870,7 +1884,9 @@ def draw( mutation_labels=mutation_labels, mutation_colours=mutation_colours, edge_colours=edge_colours, + time_scale=time_scale, tree_height_scale=tree_height_scale, + max_time=max_time, max_tree_height=max_tree_height, order=order, ) @@ -5456,6 +5472,7 @@ def draw_svg( *, size=None, x_scale=None, + time_scale=None, tree_height_scale=None, node_labels=None, mutation_labels=None, @@ -5512,11 +5529,13 @@ def draw_svg( corresponds to a tree boundary, which are positioned evenly along the axis, so that the X axis is of variable scale, no background scaling is required, and site positions are not marked on the axis. - :param str tree_height_scale: Control how height values for nodes are computed. + :param str time_scale: Control how height values for nodes are computed. If this is equal to ``"time"``, node heights are proportional to their time values (this is the default). If this is equal to ``"log_time"``, node heights are proportional to their log(time) values. If it is equal to ``"rank"``, node heights are spaced equally according to their ranked times. + :param str tree_height_scale: Deprecated alias for time_scale. (Deprecated in + 0.3.6) :param node_labels: If specified, show custom labels for the nodes (specified by ID) that are present in this map; any nodes not present will not have a label. @@ -5552,7 +5571,7 @@ def draw_svg( not be shown. To force display of the entire tree sequence, including empty flanking regions, specify ``x_lim=[0, ts.sequence_length]``. :param bool y_axis: Should the plot have an Y axis line, showing time (or - ranked node time if ``tree_height_scale="rank"``. If ``None`` (default) + ranked node time if ``time_scale="rank"``. If ``None`` (default) do not plot a Y axis. :param str y_label: Place a label to the left of the plot. If ``None`` (default) and there is a Y axis, create and place an appropriate label. @@ -5580,6 +5599,7 @@ def draw_svg( self, size, x_scale=x_scale, + time_scale=time_scale, tree_height_scale=tree_height_scale, node_labels=node_labels, mutation_labels=mutation_labels,