From 5b384a515012645a6f002086edd077b89ff24453 Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Fri, 19 Mar 2021 09:00:19 +0000 Subject: [PATCH] Allow mutations to be plotted per edge not per tree --- python/CHANGELOG.rst | 3 + python/tests/data/svg/internal_sample_ts.svg | 27 ++- python/tests/data/svg/tree.svg | 26 +-- python/tests/data/svg/tree_both_axes.svg | 36 ++-- python/tests/data/svg/tree_muts.svg | 42 ++-- python/tests/data/svg/tree_muts_all_edge.svg | 120 +++++++++++ python/tests/data/svg/tree_timed_muts.svg | 42 ++-- python/tests/data/svg/tree_x_axis.svg | 52 +++-- python/tests/data/svg/tree_y_axis.svg | 59 +++--- python/tests/data/svg/ts.svg | 27 ++- python/tests/data/svg/ts_multiroot.svg | 2 +- python/tests/data/svg/ts_mut_highlight.svg | 27 ++- python/tests/data/svg/ts_mut_times.svg | 27 ++- .../tests/data/svg/ts_mut_times_logscale.svg | 27 ++- .../tests/data/svg/ts_mutations_no_edges.svg | 2 +- .../data/svg/ts_mutations_timed_no_edges.svg | 2 +- python/tests/data/svg/ts_no_axes.svg | 16 +- python/tests/data/svg/ts_nonbinary.svg | 2 +- python/tests/data/svg/ts_plain.svg | 16 +- python/tests/data/svg/ts_plain_no_xlab.svg | 16 +- python/tests/data/svg/ts_plain_y.svg | 16 +- python/tests/data/svg/ts_rank.svg | 27 ++- python/tests/data/svg/ts_xlabel.svg | 27 ++- python/tests/data/svg/ts_y_axis.svg | 27 ++- python/tests/data/svg/ts_y_axis_log.svg | 27 ++- python/tests/data/svg/ts_y_axis_regular.svg | 27 ++- python/tests/test_drawing.py | 40 ++++ python/tskit/drawing.py | 190 +++++++++++++----- python/tskit/trees.py | 10 + 29 files changed, 732 insertions(+), 230 deletions(-) create mode 100644 python/tests/data/svg/tree_muts_all_edge.svg diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 24bd64cb88..425f4e3147 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -6,6 +6,9 @@ **Features** +- SVG visualization of a single tree allows all mutations on an edge to be plotted + via the ``all_edge_mutations`` param (:user:`hyanwong`,:issue:`1253`, :pr:`1258`). + **Fixes** -------------------- diff --git a/python/tests/data/svg/internal_sample_ts.svg b/python/tests/data/svg/internal_sample_ts.svg index 4db49d6830..03da9e4685 100644 --- a/python/tests/data/svg/internal_sample_ts.svg +++ b/python/tests/data/svg/internal_sample_ts.svg @@ -1,7 +1,7 @@ - + @@ -76,13 +76,22 @@ - + + + + + + + + + + @@ -184,8 +193,13 @@ 2 - + + + + + 6 + 3 @@ -282,8 +296,13 @@ 2 - + + + + + 7 + 3 diff --git a/python/tests/data/svg/tree.svg b/python/tests/data/svg/tree.svg index ca28992362..deed14abe4 100644 --- a/python/tests/data/svg/tree.svg +++ b/python/tests/data/svg/tree.svg @@ -1,38 +1,38 @@ - + - - - + + + 0 - - + + 1 - + 4 - - - + + + 2 - - + + 3 - + 5 diff --git a/python/tests/data/svg/tree_both_axes.svg b/python/tests/data/svg/tree_both_axes.svg index 2ccad3904a..3058f1e007 100644 --- a/python/tests/data/svg/tree_both_axes.svg +++ b/python/tests/data/svg/tree_both_axes.svg @@ -1,7 +1,7 @@ - + @@ -24,23 +24,23 @@ - + Time - - + + 0.11 - + 0.00 - + 1.11 @@ -56,33 +56,33 @@ - - - + + + 0 - - + + 1 - + 4 - - - + + + 2 - - + + 3 - + 5 diff --git a/python/tests/data/svg/tree_muts.svg b/python/tests/data/svg/tree_muts.svg index a67d60cbcb..4757a9aacb 100644 --- a/python/tests/data/svg/tree_muts.svg +++ b/python/tests/data/svg/tree_muts.svg @@ -1,54 +1,54 @@ - + - - - - + + + + 0 - - + + 1 - + 4 - - - + + + 2 - - + + 3 - - - + + + 2 5 - - - + + + 0 - - + + 1 diff --git a/python/tests/data/svg/tree_muts_all_edge.svg b/python/tests/data/svg/tree_muts_all_edge.svg new file mode 100644 index 0000000000..fb9b58802a --- /dev/null +++ b/python/tests/data/svg/tree_muts_all_edge.svg @@ -0,0 +1,120 @@ + + + + + + + + + + + + + Genome position + + + + + + 0.06 + + + + + + 0.79 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 0 + + + + + 1 + + + + + + 3 + + + + + 4 + + + 4 + + + + + + 2 + + + + + + + 6 + + + + + 7 + + + 3 + + + + 5 + + + + + + 5 + + + 7 + + + + diff --git a/python/tests/data/svg/tree_timed_muts.svg b/python/tests/data/svg/tree_timed_muts.svg index da878b24b7..5e6c301732 100644 --- a/python/tests/data/svg/tree_timed_muts.svg +++ b/python/tests/data/svg/tree_timed_muts.svg @@ -1,54 +1,54 @@ - + - - - - + + + + 0 - - + + 1 - + 4 - - - + + + 2 - - + + 3 - - - + + + 2 5 - - - + + + 0 - - + + 1 diff --git a/python/tests/data/svg/tree_x_axis.svg b/python/tests/data/svg/tree_x_axis.svg index 0bf522b335..f23a085645 100644 --- a/python/tests/data/svg/tree_x_axis.svg +++ b/python/tests/data/svg/tree_x_axis.svg @@ -1,7 +1,7 @@ - + @@ -33,9 +33,12 @@ - + + + + @@ -43,50 +46,55 @@ - - - - + + + + 0 - - + + 1 - - - + + + 3 - - + + 4 4 - - - + + + 2 - - + + + + + + 6 + 3 - + 5 - - - + + + 5 diff --git a/python/tests/data/svg/tree_y_axis.svg b/python/tests/data/svg/tree_y_axis.svg index e37b7fdf08..8eed1aeff1 100644 --- a/python/tests/data/svg/tree_y_axis.svg +++ b/python/tests/data/svg/tree_y_axis.svg @@ -1,37 +1,37 @@ - + - + Time (relative steps) - - + + 0.00 - + 1.00 - + 2.00 - + @@ -41,50 +41,55 @@ - - - - + + + + 0 - - + + 1 - - - + + + 3 - - + + 4 4 - - - + + + 2 - - + + + + + + 6 + 3 - + 5 - - - + + + 5 diff --git a/python/tests/data/svg/ts.svg b/python/tests/data/svg/ts.svg index 3eed7e0685..92f7a96f8e 100644 --- a/python/tests/data/svg/ts.svg +++ b/python/tests/data/svg/ts.svg @@ -1,7 +1,7 @@ - + @@ -76,13 +76,22 @@ - + + + + + + + + + + @@ -174,8 +183,13 @@ 2 - + + + + + 6 + 3 @@ -257,8 +271,13 @@ 2 - + + + + + 7 + 3 diff --git a/python/tests/data/svg/ts_multiroot.svg b/python/tests/data/svg/ts_multiroot.svg index 1140fa2b57..6b5c71ce43 100644 --- a/python/tests/data/svg/ts_multiroot.svg +++ b/python/tests/data/svg/ts_multiroot.svg @@ -1,7 +1,7 @@ - + diff --git a/python/tests/data/svg/ts_mut_highlight.svg b/python/tests/data/svg/ts_mut_highlight.svg index c7beab410e..c6a50dac35 100644 --- a/python/tests/data/svg/ts_mut_highlight.svg +++ b/python/tests/data/svg/ts_mut_highlight.svg @@ -1,7 +1,7 @@ - + @@ -76,13 +76,22 @@ - + + + + + + + + + + @@ -174,8 +183,13 @@ 2 - + + + + + 6 + 3 @@ -257,8 +271,13 @@ 2 - + + + + + 7 + 3 diff --git a/python/tests/data/svg/ts_mut_times.svg b/python/tests/data/svg/ts_mut_times.svg index 064b278be0..b8c9b9a2a3 100644 --- a/python/tests/data/svg/ts_mut_times.svg +++ b/python/tests/data/svg/ts_mut_times.svg @@ -1,7 +1,7 @@ - + @@ -76,13 +76,22 @@ - + + + + + + + + + + @@ -174,8 +183,13 @@ 2 - + + + + + 6 + 3 @@ -257,8 +271,13 @@ 2 - + + + + + 7 + 3 diff --git a/python/tests/data/svg/ts_mut_times_logscale.svg b/python/tests/data/svg/ts_mut_times_logscale.svg index 6b05dbe881..3abde3e824 100644 --- a/python/tests/data/svg/ts_mut_times_logscale.svg +++ b/python/tests/data/svg/ts_mut_times_logscale.svg @@ -1,7 +1,7 @@ - + @@ -76,13 +76,22 @@ - + + + + + + + + + + @@ -174,8 +183,13 @@ 2 - + + + + + 6 + 3 @@ -257,8 +271,13 @@ 2 - + + + + + 7 + 3 diff --git a/python/tests/data/svg/ts_mutations_no_edges.svg b/python/tests/data/svg/ts_mutations_no_edges.svg index ca5236eaa0..1eb89ca4f5 100644 --- a/python/tests/data/svg/ts_mutations_no_edges.svg +++ b/python/tests/data/svg/ts_mutations_no_edges.svg @@ -1,7 +1,7 @@ - + diff --git a/python/tests/data/svg/ts_mutations_timed_no_edges.svg b/python/tests/data/svg/ts_mutations_timed_no_edges.svg index a05fa76538..6fe89cb001 100644 --- a/python/tests/data/svg/ts_mutations_timed_no_edges.svg +++ b/python/tests/data/svg/ts_mutations_timed_no_edges.svg @@ -1,7 +1,7 @@ - + diff --git a/python/tests/data/svg/ts_no_axes.svg b/python/tests/data/svg/ts_no_axes.svg index b2d87ef320..02a4987765 100644 --- a/python/tests/data/svg/ts_no_axes.svg +++ b/python/tests/data/svg/ts_no_axes.svg @@ -1,7 +1,7 @@ - + @@ -93,8 +93,13 @@ 2 - + + + + + 6 + 3 @@ -176,8 +181,13 @@ 2 - + + + + + 7 + 3 diff --git a/python/tests/data/svg/ts_nonbinary.svg b/python/tests/data/svg/ts_nonbinary.svg index 1e291aa863..90701c0953 100644 --- a/python/tests/data/svg/ts_nonbinary.svg +++ b/python/tests/data/svg/ts_nonbinary.svg @@ -1,7 +1,7 @@ - + diff --git a/python/tests/data/svg/ts_plain.svg b/python/tests/data/svg/ts_plain.svg index 1d83e8b1fc..ccc3f666f7 100644 --- a/python/tests/data/svg/ts_plain.svg +++ b/python/tests/data/svg/ts_plain.svg @@ -1,7 +1,7 @@ - + @@ -137,8 +137,13 @@ 2 - + + + + + 6 + 3 @@ -220,8 +225,13 @@ 2 - + + + + + 7 + 3 diff --git a/python/tests/data/svg/ts_plain_no_xlab.svg b/python/tests/data/svg/ts_plain_no_xlab.svg index 5923917ec2..05f0b883e7 100644 --- a/python/tests/data/svg/ts_plain_no_xlab.svg +++ b/python/tests/data/svg/ts_plain_no_xlab.svg @@ -1,7 +1,7 @@ - + @@ -134,8 +134,13 @@ 2 - + + + + + 6 + 3 @@ -217,8 +222,13 @@ 2 - + + + + + 7 + 3 diff --git a/python/tests/data/svg/ts_plain_y.svg b/python/tests/data/svg/ts_plain_y.svg index 4e62bf0b7f..6c8a10a47c 100644 --- a/python/tests/data/svg/ts_plain_y.svg +++ b/python/tests/data/svg/ts_plain_y.svg @@ -1,7 +1,7 @@ - + @@ -164,8 +164,13 @@ 2 - + + + + + 6 + 3 @@ -247,8 +252,13 @@ 2 - + + + + + 7 + 3 diff --git a/python/tests/data/svg/ts_rank.svg b/python/tests/data/svg/ts_rank.svg index ffaad8d8b0..f0a267e20b 100644 --- a/python/tests/data/svg/ts_rank.svg +++ b/python/tests/data/svg/ts_rank.svg @@ -1,7 +1,7 @@ - + @@ -76,13 +76,22 @@ - + + + + + + + + + + @@ -222,8 +231,13 @@ 2 - + + + + + 6 + 3 @@ -305,8 +319,13 @@ 2 - + + + + + 7 + 3 diff --git a/python/tests/data/svg/ts_xlabel.svg b/python/tests/data/svg/ts_xlabel.svg index 15c54c80a1..46f676c06d 100644 --- a/python/tests/data/svg/ts_xlabel.svg +++ b/python/tests/data/svg/ts_xlabel.svg @@ -1,7 +1,7 @@ - + @@ -76,13 +76,22 @@ - + + + + + + + + + + @@ -174,8 +183,13 @@ 2 - + + + + + 6 + 3 @@ -257,8 +271,13 @@ 2 - + + + + + 7 + 3 diff --git a/python/tests/data/svg/ts_y_axis.svg b/python/tests/data/svg/ts_y_axis.svg index 308945ba8d..700da37e91 100644 --- a/python/tests/data/svg/ts_y_axis.svg +++ b/python/tests/data/svg/ts_y_axis.svg @@ -1,7 +1,7 @@ - + @@ -76,13 +76,22 @@ - + + + + + + + + + + @@ -222,8 +231,13 @@ 2 - + + + + + 6 + 3 @@ -305,8 +319,13 @@ 2 - + + + + + 7 + 3 diff --git a/python/tests/data/svg/ts_y_axis_log.svg b/python/tests/data/svg/ts_y_axis_log.svg index ab0732f08c..23d49a9db4 100644 --- a/python/tests/data/svg/ts_y_axis_log.svg +++ b/python/tests/data/svg/ts_y_axis_log.svg @@ -1,7 +1,7 @@ - + @@ -76,13 +76,22 @@ - + + + + + + + + + + @@ -222,8 +231,13 @@ 2 - + + + + + 6 + 3 @@ -305,8 +319,13 @@ 2 - + + + + + 7 + 3 diff --git a/python/tests/data/svg/ts_y_axis_regular.svg b/python/tests/data/svg/ts_y_axis_regular.svg index 59130d2392..ca59d41490 100644 --- a/python/tests/data/svg/ts_y_axis_regular.svg +++ b/python/tests/data/svg/ts_y_axis_regular.svg @@ -1,7 +1,7 @@ - + @@ -76,13 +76,22 @@ - + + + + + + + + + + @@ -250,8 +259,13 @@ 2 - + + + + + 6 + 3 @@ -333,8 +347,13 @@ 2 - + + + + + 7 + 3 diff --git a/python/tests/test_drawing.py b/python/tests/test_drawing.py index 57aa268d38..16b883a9fb 100644 --- a/python/tests/test_drawing.py +++ b/python/tests/test_drawing.py @@ -197,6 +197,7 @@ def get_simple_ts(self, use_mutation_times=False): 0.06 0 0.3 Empty 0.5 XXX + 0.91 T """ ) muts = io.StringIO( @@ -208,6 +209,8 @@ def get_simple_ts(self, use_mutation_times=False): 1 4 C -1 1.6 1 4 G 3 1.5 2 7 G -1 10 + 2 3 C 5 1 + 4 3 G -1 1 """ ) ts = tskit.load_text(nodes, edges, sites=sites, mutations=muts, strict=False) @@ -1783,6 +1786,33 @@ def test_unplotted_mutation(self): svg_no_css = svg[svg.find("") :] assert svg_no_css.count("fill-opacity:0") == 1 + @pytest.mark.parametrize("all_muts", [False, True]) + @pytest.mark.parametrize("x_axis", [False, True]) + def test_extra_mutations(self, all_muts, x_axis): + # The simple_ts has 2 mutations on an edge which spans the whole ts + # One mut is within tree 1, the other within tree 3 + ts = self.get_simple_ts() + extra_mut_copies = 0 + if all_muts: + extra_mut_copies = 2 if x_axis else 1 + extra_right = ts.at_index(1) + svg = extra_right.draw_svg(all_edge_mutations=all_muts, x_axis=x_axis) + self.verify_basic_svg(svg) + svg_no_css = svg[svg.find("") :] + assert svg_no_css.count("extra") == 1 * extra_mut_copies + + extra_right_and_left = ts.at_index(2) + svg = extra_right_and_left.draw_svg(all_edge_mutations=all_muts, x_axis=x_axis) + self.verify_basic_svg(svg) + svg_no_css = svg[svg.find("") :] + assert svg_no_css.count("extra") == 2 * extra_mut_copies + + extra_left = ts.at_index(3) + svg = extra_left.draw_svg(all_edge_mutations=all_muts, x_axis=x_axis) + self.verify_basic_svg(svg) + svg_no_css = svg[svg.find("") :] + assert svg_no_css.count("extra") == 1 * extra_mut_copies + def test_max_tree_height(self): nodes = io.StringIO( """\ @@ -2096,6 +2126,16 @@ def test_known_svg_tree_root_mut(self, overwrite_viz, draw_plotbox): svg = tree.draw_svg(debug_box=draw_plotbox) self.verify_known_svg(svg, "tree_muts.svg", overwrite_viz) + def test_known_svg_tree_mut_all_edge(self, overwrite_viz, draw_plotbox): + tree = self.get_simple_ts().at_index(1) + size = (300, 400) + svg = tree.draw_svg( + size=size, debug_box=draw_plotbox, all_edge_mutations=True, x_axis=True + ) + self.verify_known_svg( + svg, "tree_muts_all_edge.svg", overwrite_viz, width=size[0], height=size[1] + ) + def test_known_svg_tree_timed_root_mut(self, overwrite_viz, draw_plotbox): tree = self.get_simple_ts(use_mutation_times=True).at_index(0) svg = tree.draw_svg(debug_box=draw_plotbox) diff --git a/python/tskit/drawing.py b/python/tskit/drawing.py index 87aa48d1ae..19eef87517 100644 --- a/python/tskit/drawing.py +++ b/python/tskit/drawing.py @@ -314,8 +314,8 @@ class SvgPlot: """ The base class for plotting either a tree or a tree sequence as an SVG file""" standard_style = ( - ".tree-sequence .background path {fill: #808080; fill-opacity:0}" - ".tree-sequence .background path:nth-child(odd) {fill-opacity:.1}" + ".background path {fill: #808080; fill-opacity:0}" + ".background path:nth-child(odd) {fill-opacity:.1}" ".axes {font-size: 14px}" ".x-axis .tick .lab {font-weight: bold}" ".axes, .tree {font-size: 14px; text-anchor:middle}" @@ -327,8 +327,10 @@ class SvgPlot: ".node > .sym {fill: black; stroke: none}" ".site > .sym {stroke: black}" ".mut text {fill: red; font-style: italic}" + ".mut.extra text {fill: hotpink}" ".mut line {fill: none; stroke: none}" # Default hide mut line to expose edges ".mut .sym {fill: none; stroke: red}" + ".mut.extra .sym {stroke: hotpink}" ".node .mut .sym {stroke-width: 1.5px}" ".tree text, .tree-sequence text {dominant-baseline: central}" ".plotbox .lab.lft {text-anchor: end}" @@ -390,6 +392,7 @@ def __init__( y_label = "Time" self.x_label = x_label self.y_label = y_label + self.mutations_outside_tree = set() # mutations in here get an additional class def get_plotbox(self): """ @@ -446,7 +449,7 @@ def draw_x_axis( tick_labels=None, # Tick labels below axis. If None, use the position value tick_length_lower=default_tick_length, tick_length_upper=None, # If None, use the same as tick_length_lower - sites=None, # An iterator over site objects to plot as ticks above the x axis + site_muts=None, # A dict of site id => mutation to plot as ticks on the x axis ): if not self.x_axis and not self.x_label: return @@ -474,7 +477,7 @@ def draw_x_axis( label_precision = 0 if integer_ticks else 2 tick_labels = [f"{lab:.{label_precision}f}" for lab in tick_labels] - upper_length = -tick_length_upper if sites is None else 0 + upper_length = -tick_length_upper if site_muts is None else 0 for pos, lab in itertools.zip_longest(tick_positions, tick_labels): tick = x_axis.add( dwg.g( @@ -488,9 +491,10 @@ def draw_x_axis( self.add_text_in_group( lab, tick, pos=(0, tick_length_lower), group_class="lab" ) - if sites is not None: - # Add sites as upper chevrons - for s in sites: + if site_muts is not None: + # Add sites as vertical lines with overlaid mutations as upper chevrons + for s_id, mutations in site_muts.items(): + s = self.ts.site(s_id) x = self.x_transform(s.position) site = x_axis.add( dwg.g( @@ -500,8 +504,11 @@ def draw_x_axis( site.add( dwg.line((0, 0), (0, rnd(-tick_length_upper)), class_="sym") ) - for i, m in enumerate(reversed(s.mutations)): - mut = dwg.g(class_=f"mut m{m.id}") + for i, m in enumerate(reversed(mutations)): + mutation_class = f"mut m{m.id}" + if m.id in self.mutations_outside_tree: + mutation_class += " extra" + mut = dwg.g(class_=mutation_class) h = -i * 4 - 1.5 w = tick_length_upper / 4 mut.add( @@ -567,6 +574,55 @@ def draw_y_axis( text_anchor="end", ) + def shade_background( + self, + breaks, + tick_length_lower, + tree_width=None, + bottom_padding=None, + ): + if not self.x_axis: + return + if tree_width is None: + tree_width = self.plotbox.width + if bottom_padding is None: + bottom_padding = self.plotbox.pad_bottom + plot_breaks = self.x_transform(np.array(breaks)) + dwg = self.drawing + + # For tree sequences, we need to add on the background shaded regions + self.root_groups["background"] = self.dwg_base.add(dwg.g(class_="background")) + y = self.image_size[1] - self.x_axis_offset + for i in range(1, len(breaks)): + break_x = plot_breaks[i] + prev_break_x = plot_breaks[i - 1] + tree_x = i * tree_width + self.plotbox.left + prev_tree_x = (i - 1) * tree_width + self.plotbox.left + # Shift diagonal lines between tree & axis into the treebox a little + diag_height = y - (self.image_size[1] - bottom_padding) + self.root_groups["background"].add( + # NB: the path below draws straight diagonal lines between the tree boxes + # and the X axis. An alternative implementation using bezier curves could + # substitute the following for lines 2 and 4 of the path spec string + # "l0,{box_h:g} c0,{diag_h} {rdiag_x},0 {rdiag_x},{diag_h} " + # "c0,-{diag_h} {ldiag_x},0 {ldiag_x},-{diag_h} l0,-{box_h:g}z" + dwg.path( + "M{start_x:g},0 l{box_w:g},0 " # Top left to top right of tree box + "l0,{box_h:g} l{rdiag_x:g},{diag_h:g} " # Down to axis + "l0,{tick_h:g} l{ax_x:g},0 l0,-{tick_h:g} " # Between axis ticks + "l{ldiag_x:g},-{diag_h:g} l0,-{box_h:g}z".format( # Up from axis + start_x=rnd(prev_tree_x), + box_w=rnd(tree_x - prev_tree_x), + box_h=rnd(y - diag_height), + rdiag_x=rnd(break_x - tree_x), + diag_h=rnd(diag_height), + tick_h=rnd(tick_length_lower), + ax_x=rnd(prev_break_x - break_x), + ldiag_x=rnd(prev_tree_x - prev_break_x), + ) + ) + ) + def x_transform(self, x): raise NotImplementedError( "No transform func defined for genome pos -> plot coords" @@ -641,7 +697,7 @@ def __init__( for tree in ts.trees() ) # TODO add general padding arguments following matplotlib's terminology. - self.set_spacing(top=0, left=20, bottom=15, right=20) + self.set_spacing(top=0, left=20, bottom=10, right=20) svg_trees = [ SvgTree( tree, @@ -728,44 +784,18 @@ def draw_x_axis( lambda x: self.plotbox.left + x / self.ts.sequence_length * self.plotbox.width ) - plot_breaks = self.x_transform(breaks) - dwg = self.drawing - - # For tree sequences, we need to add on the background shaded regions - self.root_groups["background"] = self.dwg_base.add( - dwg.g(class_="background") - ) - # plotbox_bottom_padding += 10 # extra space for the diagonal lines - y = self.image_size[1] - self.x_axis_offset - for i in range(1, len(breaks)): - break_x = plot_breaks[i] - prev_break_x = plot_breaks[i - 1] - tree_x = i * self.tree_plotbox.max_x + self.plotbox.left - prev_tree_x = (i - 1) * self.tree_plotbox.max_x + self.plotbox.left - # Shift diagonal lines between tree & axis into the treebox a little - diag_height = y - ( - self.plotbox.bottom - self.tree_plotbox.pad_bottom - ) - self.root_groups["background"].add( - dwg.path( - f"M{rnd(prev_tree_x):g},0 " - f"l{rnd(tree_x-prev_tree_x):g},0 " - f"l0,{rnd(y - diag_height):g} " - f"l{rnd(break_x-tree_x):g},{rnd(diag_height):g} " - # NB for curves try "c0,{1} {0},0 {0},{1}" instead of above - f"l0,{rnd(tick_length_lower):g} " - f"l{rnd(prev_break_x-break_x):g},0 " - f"l0,{rnd(-tick_length_lower):g} " - f"l{rnd(prev_tree_x-prev_break_x):g},{rnd(-diag_height):g} " - # NB for curves try "c0,{1} {0},0 {0},{1}" instead of above - f"l0,{rnd(diag_height - y):g}z", - ) - ) + self.shade_background( + breaks, + tick_length_lower, + self.tree_plotbox.max_x, + self.plotbox.pad_bottom + self.tree_plotbox.pad_bottom, + ) + site_muts = {s.id: s.mutations for s in self.ts.sites()} super().draw_x_axis( tick_positions=breaks, tick_length_lower=tick_length_lower, tick_length_upper=tick_length_upper, - sites=self.ts.sites(), + site_muts=site_muts, ) else: @@ -804,6 +834,7 @@ def __init__( y_label=None, y_ticks=None, y_gridlines=None, + all_edge_mutations=None, tree_height_scale=None, node_attrs=None, mutation_attrs=None, @@ -817,8 +848,9 @@ def __init__( if symbol_size is None: symbol_size = 6 self.symbol_size = symbol_size + ts = tree.tree_sequence super().__init__( - tree.tree_sequence, + ts, size, root_svg_attributes, style, @@ -857,6 +889,42 @@ def __init__( f"Mutations {unplotted} are above nodes which are not present in the " "displayed tree, so are not plotted on the topology." ) + self.left_extent = tree.interval.left + self.right_extent = tree.interval.right + if all_edge_mutations: + tree_left = tree.interval.left + tree_right = tree.interval.right + edge_left = ts.tables.edges.left + edge_right = ts.tables.edges.right + node_edges = tree._node_edges() + # whittle mutations down so we only need look at those above the tree nodes + mut_t = ts.tables.mutations + focal_mutations = np.isin(mut_t.node, np.fromiter(nodes, mut_t.node.dtype)) + mutation_nodes = mut_t.node[focal_mutations] + mutation_positions = ts.tables.sites.position[mut_t.site][focal_mutations] + mutation_ids = np.arange(ts.num_mutations, dtype=int)[focal_mutations] + for m_id, node, pos in zip( + mutation_ids, mutation_nodes, mutation_positions + ): + curr_edge = node_edges[node] + if curr_edge >= 0: + if ( + edge_left[curr_edge] <= pos < tree_left + ): # Mutation on this edge but to left of plotted tree + self.node_mutations[node].append(ts.mutation(m_id)) + self.mutations_outside_tree.add(m_id) + self.left_extent = min(self.left_extent, pos) + elif ( + tree_right <= pos < edge_right[curr_edge] + ): # Mutation on this edge but to right of plotted tree + self.node_mutations[node].append(ts.mutation(m_id)) + self.mutations_outside_tree.add(m_id) + self.right_extent = max(self.right_extent, pos) + if self.right_extent != tree.interval.right: + # Use nextafter so extent of plotting incorporates the mutation + self.right_extent = np.nextafter( + self.right_extent, self.right_extent + 1 + ) # attributes for symbols half_symbol_size = "{:g}".format(rnd(symbol_size / 2)) symbol_size = "{:g}".format(rnd(symbol_size)) @@ -885,8 +953,8 @@ def __init__( add_class(self.node_label_attrs[u], "lab") # class 'lab' for label if node_label_attrs is not None and u in node_label_attrs: self.node_label_attrs[u].update(node_label_attrs[u]) - for site in tree.sites(): - for mutation in site.mutations: + for _, mutations in self.node_mutations.items(): + for mutation in mutations: m = mutation.id # We need to offset the rectangle so that it's centred self.mutation_attrs[m] = { @@ -907,14 +975,28 @@ def __init__( self.mutation_label_attrs[m].update(mutation_label_attrs[m]) add_class(self.mutation_label_attrs[m], "lab") - self.set_spacing(top=10, left=20, bottom=10, right=20) + self.set_spacing(top=10, left=20, bottom=15, right=20) self.assign_y_coordinates(max_tree_height, force_root_branch) self.assign_x_coordinates() + tick_length_lower = self.default_tick_length # TODO - parameterize + tick_length_upper = self.default_tick_length_site # TODO - parameterize + if all_edge_mutations: + self.shade_background(tree.interval, tick_length_lower) + + first_site, last_site = np.searchsorted( + self.ts.tables.sites.position, [self.left_extent, self.right_extent] + ) + site_muts = {site_id: [] for site_id in range(first_site, last_site)} + # Only use mutations plotted on the tree (not necessarily all at the site) + for muts in self.node_mutations.values(): + for mut in muts: + site_muts[mut.site].append(mut) + self.draw_x_axis( tick_positions=np.array(tree.interval), - tick_length_lower=self.default_tick_length, # TODO - parameterize - tick_length_upper=self.default_tick_length_site, # TODO - parameterize - sites=tree.sites(), + tick_length_lower=tick_length_lower, + tick_length_upper=tick_length_upper, + site_muts=site_muts, ) self.draw_y_axis( lower=self.y_transform(0), @@ -1076,7 +1158,9 @@ def assign_x_coordinates(self): self.node_x_coord_map = node_x_coord_map # Transform is not for nodes but for genome positions self.x_transform = lambda x: ( - (x - self.tree.interval.left) / self.tree.interval.span * self.plotbox.width + (x - self.left_extent) + / (self.right_extent - self.left_extent) + * self.plotbox.width + self.plotbox.left ) @@ -1180,6 +1264,8 @@ def draw_tree(self): mutation_class = f"mut m{mutation.id} s{mutation.site}" if util.is_unknown_time(self.ts.mutation(mutation.id).time): mutation_class += " unknown_time" + if mutation.id in self.mutations_outside_tree: + mutation_class += " extra" mut_group = curr_svg_group.add( dwg.g(class_=mutation_class, transform=f"translate(0 {rnd(dy)})") ) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 5c46ed1da1..0412ca836c 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -1469,6 +1469,7 @@ def draw_svg( y_label=None, y_ticks=None, y_gridlines=None, + all_edge_mutations=None, **kwargs, ): """ @@ -1632,6 +1633,14 @@ def draw_svg( node value. :param bool y_gridlines: Whether to plot horizontal lines behind the tree at each y tickmark. + :param bool all_edge_mutations: The edge on which a mutation occurs may span + multiple trees. If ``False`` or ``None`` (default) mutations are only drawn + on an edge if their site position exists within the genomic interval covered + by this tree. If ``True``, all mutations on each edge of the tree are drawn, + even if the their genomic position is to the left or right of the tree + itself (by default these "extra" mutations are drawn in a different colour). + Note that this means that independent drawings of different trees + from the same tree sequence may share some plotted mutations. :return: An SVG representation of a tree. :rtype: str @@ -1654,6 +1663,7 @@ def draw_svg( y_label=y_label, y_ticks=y_ticks, y_gridlines=y_gridlines, + all_edge_mutations=all_edge_mutations, **kwargs, ) output = draw.drawing.tostring()