From af43883575a87a9e612f9de855bc621c09f51de0 Mon Sep 17 00:00:00 2001 From: Matthew Brett Date: Wed, 20 Aug 2014 13:22:41 -0400 Subject: [PATCH 01/24] NF: add version of Paul Ivanov's slice viewer Thanks Paul... --- nibabel/viewers.py | 215 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 215 insertions(+) create mode 100644 nibabel/viewers.py diff --git a/nibabel/viewers.py b/nibabel/viewers.py new file mode 100644 index 0000000000..7ae4627d9a --- /dev/null +++ b/nibabel/viewers.py @@ -0,0 +1,215 @@ +""" Utilities for viewing images + +Includes version of OrthoSlicer3D code by our own Paul Ivanov +""" +from __future__ import division, print_function + +import numpy as np + +from .optpkg import optional_package + +plt, _, _ = optional_package('matplotlib.pyplot') +mpl_img, _, _ = optional_package('matplotlib.image') + +# Assumes the following layout +# +# ^ +---------+ ^ +---------+ +# | | | | | | +# | | | | +# z | 2 | z | 3 | +# | | | | +# | | | | | | +# v +---------+ v +---------+ +# <-- x --> <-- y --> +# ^ +---------+ +# | | | +# | | +# y | 1 | +# | | +# | | | +# v +---------+ +# <-- x --> + +class OrthoSlicer3D(object): + """Orthogonal-plane slicer. + + OrthoSlicer3d expects 3-dimensional data, and by default it creates a + figure with 3 axes, one for each slice orientation. + + There are two modes, "following on" and "following off". In "following on" + mode, moving the mouse in any one axis will select out the corresponding + slices in the other two. The mode is "following off" when the figure is + first created. Clicking the left mouse button toggles mouse following and + triggers a full redraw (to update the ticks, for example). Scrolling up and + down moves the slice up and down in the current axis. + + Example + ------- + import numpy as np + a = np.sin(np.linspace(0,np.pi,20)) + b = np.sin(np.linspace(0,np.pi*5,20)) + data = np.outer(a,b)[..., np.newaxis]*a + OrthoSlicer3D(data).show() + """ + def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', + pcnt_range=None): + """ + Parameters + ---------- + data : 3 dimensional ndarray + The data that will be displayed by the slicer + axes : None or length 3 sequence of mpl.Axes, optional + 3 axes instances for the X, Y, and Z slices, or None (default) + aspect_ratio : float or length 3 sequence, optional + stretch factors for X, Y, Z directions + cmap : colormap identifier, optional + String or cmap instance specifying colormap. Will be passed as + ``cmap`` argument to ``plt.imshow``. + pcnt_range : length 2 sequence, optional + Percentile range over which to scale image for display. If None, + scale between image mean and max. If sequence, min and max + percentile over which to scale image. + """ + data_shape = np.array(data.shape[:3]) # allow trailing RGB dimension + aspect_ratio = np.array(aspect_ratio) + if axes is None: # make the axes + # ^ +---------+ ^ +---------+ + # | | | | | | + # | | | | + # z | 2 | z | 3 | + # | | | | + # | | | | | | + # v +---------+ v +---------+ + # <-- x --> <-- y --> + # ^ +---------+ + # | | | + # | | + # y | 1 | + # | | + # | | | + # v +---------+ + # <-- x --> + fig = plt.figure() + x, y, z = data_shape * aspect_ratio + maxw = float(x + y) + maxh = float(y + z) + yh = y / maxh + xw = x / maxw + yw = y / maxw + zh = z / maxh + # z slice (if usual transverse acquisition => axial slice) + ax1 = fig.add_axes((0., 0., xw, yh)) + # y slice (usually coronal) + ax2 = fig.add_axes((0, yh, xw, zh)) + # x slice (usually sagittal) + ax3 = fig.add_axes((xw, yh, yw, zh)) + axes = (ax1, ax2, ax3) + else: + if not np.all(aspect_ratio == 1): + raise ValueError('Aspect ratio must be 1 for external axes') + ax1, ax2, ax3 = axes + + self.data = data + + if pcnt_range is None: + vmin, vmax = data.min(), data.max() + else: + vmin, vmax = np.percentile(data, pcnt_range) + + kw = dict(vmin=vmin, + vmax=vmax, + aspect='auto', + interpolation='nearest', + cmap=cmap, + origin='lower') + # Start midway through each axis + st_x, st_y, st_z = (data_shape - 1) / 2. + n_x, n_y, n_z = data_shape + z_get_slice = lambda i: self.data[:, :, min(i, n_z-1)].T + y_get_slice = lambda i: self.data[:, min(i, n_y-1), :].T + x_get_slice = lambda i: self.data[min(i, n_x-1), :, :].T + im1 = ax1.imshow(z_get_slice(st_z), **kw) + im2 = ax2.imshow(y_get_slice(st_y), **kw) + im3 = ax3.imshow(x_get_slice(st_x), **kw) + im1.get_slice, im2.get_slice, im3.get_slice = ( + z_get_slice, y_get_slice, x_get_slice) + # idx is the current slice number for each panel + im1.idx, im2.idx, im3.idx = st_z, st_y, st_x + # set the maximum dimensions for indexing + im1.size, im2.size, im3.size = n_z, n_y, n_x + # setup pairwise connections between the slice dimensions + im1.imx = im3 # x move in panel 1 (usually axial) + im1.imy = im2 # y move in panel 1 + im2.imx = im3 # x move in panel 2 (usually coronal) + im2.imy = im1 + im3.imx = im2 # x move in panel 3 (usually sagittal) + im3.imy = im1 + + self.follow = False + self.figs = set([ax.figure for ax in axes]) + for fig in self.figs: + fig.canvas.mpl_connect('button_press_event', self.on_click) + fig.canvas.mpl_connect('scroll_event', self.on_scroll) + fig.canvas.mpl_connect('motion_notify_event', self.on_mousemove) + + def show(self): + """ Show the slicer; convenience for ``plt.show()`` + """ + plt.show() + + def _axis_artist(self, event): + """ Return artist if within axes, and is an image, else None + """ + if not getattr(event, 'inaxes'): + return None + artist = event.inaxes.images[0] + return artist if isinstance(artist, mpl_img.AxesImage) else None + + def on_click(self, event): + if event.button == 1: + self.follow = not self.follow + plt.draw() + + def on_scroll(self, event): + assert event.button in ('up', 'down') + im = self._axis_artist(event) + if im is None: + return + im.idx += 1 if event.button == 'up' else -1 + im.idx %= im.size + im.set_data(im.get_slice(im.idx)) + ax = im.axes + ax.draw_artist(im) + ax.figure.canvas.blit(ax.bbox) + + def on_mousemove(self, event): + if not self.follow: + return + im = self._axis_artist(event) + if im is None: + return + ax = im.axes + imx, imy = im.imx, im.imy + x, y = np.round((event.xdata, event.ydata)).astype(int) + imx.set_data(imx.get_slice(x)) + imy.set_data(imy.get_slice(y)) + imx.idx = x + imy.idx = y + for i in imx, imy: + ax = i.axes + ax.draw_artist(i) + ax.figure.canvas.blit(ax.bbox) + + +if __name__ == '__main__': + a = np.sin(np.linspace(0,np.pi,20)) + b = np.sin(np.linspace(0,np.pi*5,20)) + data = np.outer(a,b)[..., np.newaxis]*a + # all slices + OrthoSlicer3D(data).show() + + # broken out into three separate figures + f, ax1 = plt.subplots() + f, ax2 = plt.subplots() + f, ax3 = plt.subplots() + OrthoSlicer3D(data, axes=(ax1, ax2, ax3)).show() From 8bb7fb04b8e79d19553bc79483202438fd85fc87 Mon Sep 17 00:00:00 2001 From: Eric89GXL Date: Sun, 19 Oct 2014 11:59:03 -0700 Subject: [PATCH 02/24] ENH: Add crosshairs, modify mode --- .travis.yml | 1 + nibabel/__init__.py | 1 + nibabel/tests/test_viewers.py | 51 ++++++++++++ nibabel/viewers.py | 144 ++++++++++++++++++++-------------- 4 files changed, 138 insertions(+), 59 deletions(-) create mode 100644 nibabel/tests/test_viewers.py diff --git a/.travis.yml b/.travis.yml index 2f010533e2..210d56129e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,6 +3,7 @@ # munges each line before executing it to print out the exit status. It's okay # for it to be on multiple physical lines, so long as you remember: - There # can't be any leading "-"s - All newlines will be removed, so use ";"s + language: python # Run jobs on container-based infrastructure, can be overridden per job diff --git a/nibabel/__init__.py b/nibabel/__init__.py index 779f6e8587..4d8791d7d9 100644 --- a/nibabel/__init__.py +++ b/nibabel/__init__.py @@ -64,6 +64,7 @@ from .imageclasses import class_map, ext_map, all_image_classes from . import trackvis from . import mriutils +from . import viewers # be friendly on systems with ancient numpy -- no tests, but at least # importable diff --git a/nibabel/tests/test_viewers.py b/nibabel/tests/test_viewers.py new file mode 100644 index 0000000000..ce0713b4dc --- /dev/null +++ b/nibabel/tests/test_viewers.py @@ -0,0 +1,51 @@ +# emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## +# +# See COPYING file distributed along with the NiBabel package for the +# copyright and license terms. +# +### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## + +import numpy as np +from collections import namedtuple as nt + +from ..optpkg import optional_package +from ..viewers import OrthoSlicer3D + +from numpy.testing.decorators import skipif + +from nose.tools import assert_raises + +plt, has_mpl = optional_package('matplotlib.pyplot')[:2] +needs_mpl = skipif(not has_mpl, 'These tests need matplotlib') + + +@needs_mpl +def test_viewer(): + # Test viewer + a = np.sin(np.linspace(0, np.pi, 20)) + b = np.sin(np.linspace(0, np.pi*5, 30)) + data = np.outer(a, b)[..., np.newaxis] * a + viewer = OrthoSlicer3D(data) + plt.draw() + + # fake some events + viewer.on_scroll(nt('event', 'button inaxes')('up', None)) # outside axes + viewer.on_scroll(nt('event', 'button inaxes')('up', plt.gca())) # in axes + # tracking on + viewer.on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, + None, 1)) + viewer.on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, + plt.gca(), 1)) + # tracking off + viewer.on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, + None, None)) + viewer.close() + + # other cases + fig, axes = plt.subplots(1, 3) + plt.close(fig) + OrthoSlicer3D(data, pcnt_range=[0.1, 0.9], axes=axes) + assert_raises(ValueError, OrthoSlicer3D, data, aspect_ratio=[1, 2, 3], + axes=axes) diff --git a/nibabel/viewers.py b/nibabel/viewers.py index 7ae4627d9a..bfdea2a489 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -1,10 +1,12 @@ """ Utilities for viewing images -Includes version of OrthoSlicer3D code by our own Paul Ivanov +Includes version of OrthoSlicer3D code originally written by our own +Paul Ivanov. """ from __future__ import division, print_function import numpy as np +from functools import partial from .optpkg import optional_package @@ -30,26 +32,32 @@ # v +---------+ # <-- x --> + +def _set_viewer_slice(idx, im): + """Helper to set a viewer slice number""" + im.idx = idx + im.set_data(im.get_slice(im.idx)) + for fun in im.cross_setters: + fun([idx] * 2) + + class OrthoSlicer3D(object): """Orthogonal-plane slicer. OrthoSlicer3d expects 3-dimensional data, and by default it creates a figure with 3 axes, one for each slice orientation. - There are two modes, "following on" and "following off". In "following on" - mode, moving the mouse in any one axis will select out the corresponding - slices in the other two. The mode is "following off" when the figure is - first created. Clicking the left mouse button toggles mouse following and - triggers a full redraw (to update the ticks, for example). Scrolling up and + Clicking and dragging the mouse in any one axis will select out the + corresponding slices in the other two. Scrolling up and down moves the slice up and down in the current axis. Example ------- - import numpy as np - a = np.sin(np.linspace(0,np.pi,20)) - b = np.sin(np.linspace(0,np.pi*5,20)) - data = np.outer(a,b)[..., np.newaxis]*a - OrthoSlicer3D(data).show() + >>> import numpy as np + >>> a = np.sin(np.linspace(0,np.pi,20)) + >>> b = np.sin(np.linspace(0,np.pi*5,20)) + >>> data = np.outer(a,b)[..., np.newaxis]*a + >>> OrthoSlicer3D(data).show() """ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', pcnt_range=None): @@ -70,9 +78,9 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', scale between image mean and max. If sequence, min and max percentile over which to scale image. """ - data_shape = np.array(data.shape[:3]) # allow trailing RGB dimension - aspect_ratio = np.array(aspect_ratio) - if axes is None: # make the axes + data_shape = np.array(data.shape[:3]) # allow trailing RGB dimension + aspect_ratio = np.array(aspect_ratio, float) + if axes is None: # make the axes # ^ +---------+ ^ +---------+ # | | | | | | # | | | | @@ -122,8 +130,10 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', interpolation='nearest', cmap=cmap, origin='lower') + # Start midway through each axis st_x, st_y, st_z = (data_shape - 1) / 2. + sts = (st_x, st_y, st_z) n_x, n_y, n_z = data_shape z_get_slice = lambda i: self.data[:, :, min(i, n_z-1)].T y_get_slice = lambda i: self.data[:, min(i, n_y-1), :].T @@ -133,22 +143,51 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', im3 = ax3.imshow(x_get_slice(st_x), **kw) im1.get_slice, im2.get_slice, im3.get_slice = ( z_get_slice, y_get_slice, x_get_slice) + self._ims = (im1, im2, im3) + # idx is the current slice number for each panel im1.idx, im2.idx, im3.idx = st_z, st_y, st_x + # set the maximum dimensions for indexing im1.size, im2.size, im3.size = n_z, n_y, n_x + + # set up axis crosshairs + colors = ['r', 'g', 'b'] + for ax, im, idx_1, idx_2 in zip(axes, self._ims, [0, 0, 1], [1, 2, 2]): + im.x_line = ax.plot([sts[idx_1]] * 2, + [-0.5, data.shape[idx_2] - 0.5], + color=colors[idx_1], linestyle='-', + alpha=0.25)[0] + im.y_line = ax.plot([-0.5, data.shape[idx_1] - 0.5], + [sts[idx_2]] * 2, + color=colors[idx_2], linestyle='-', + alpha=0.25)[0] + ax.axis('tight') + ax.patch.set_visible(False) + ax.set_frame_on(False) + ax.axes.get_yaxis().set_visible(False) + ax.axes.get_xaxis().set_visible(False) + + # monkey-patch some functions + im1.set_viewer_slice = partial(_set_viewer_slice, im=im1) + im2.set_viewer_slice = partial(_set_viewer_slice, im=im2) + im3.set_viewer_slice = partial(_set_viewer_slice, im=im3) + # setup pairwise connections between the slice dimensions - im1.imx = im3 # x move in panel 1 (usually axial) - im1.imy = im2 # y move in panel 1 - im2.imx = im3 # x move in panel 2 (usually coronal) - im2.imy = im1 - im3.imx = im2 # x move in panel 3 (usually sagittal) - im3.imy = im1 - - self.follow = False + im1.x_im = im3 # x move in panel 1 (usually axial) + im1.y_im = im2 # y move in panel 1 + im2.x_im = im3 # x move in panel 2 (usually coronal) + im2.y_im = im1 # y move in panel 2 + im3.x_im = im2 # x move in panel 3 (usually sagittal) + im3.y_im = im1 # y move in panel 3 + + # when an index changes, which crosshairs need to be updated + im1.cross_setters = [im2.y_line.set_ydata, im3.y_line.set_ydata] + im2.cross_setters = [im1.y_line.set_ydata, im3.x_line.set_xdata] + im3.cross_setters = [im1.x_line.set_xdata, im2.x_line.set_xdata] + self.figs = set([ax.figure for ax in axes]) for fig in self.figs: - fig.canvas.mpl_connect('button_press_event', self.on_click) fig.canvas.mpl_connect('scroll_event', self.on_scroll) fig.canvas.mpl_connect('motion_notify_event', self.on_mousemove) @@ -157,59 +196,46 @@ def show(self): """ plt.show() + def close(self): + """Close the viewer figures + """ + for f in self.figs: + plt.close(f) + def _axis_artist(self, event): - """ Return artist if within axes, and is an image, else None + """Return artist if within axes, and is an image, else None """ if not getattr(event, 'inaxes'): return None artist = event.inaxes.images[0] return artist if isinstance(artist, mpl_img.AxesImage) else None - def on_click(self, event): - if event.button == 1: - self.follow = not self.follow - plt.draw() - def on_scroll(self, event): assert event.button in ('up', 'down') im = self._axis_artist(event) if im is None: return - im.idx += 1 if event.button == 'up' else -1 - im.idx %= im.size - im.set_data(im.get_slice(im.idx)) - ax = im.axes - ax.draw_artist(im) - ax.figure.canvas.blit(ax.bbox) + idx = (im.idx + (1 if event.button == 'up' else -1)) + idx = max(min(idx, im.size - 1), 0) + im.set_viewer_slice(idx) + self._draw_ims() def on_mousemove(self, event): - if not self.follow: + if event.button != 1: # only enabled while dragging return im = self._axis_artist(event) if im is None: return - ax = im.axes - imx, imy = im.imx, im.imy + x_im, y_im = im.x_im, im.y_im x, y = np.round((event.xdata, event.ydata)).astype(int) - imx.set_data(imx.get_slice(x)) - imy.set_data(imy.get_slice(y)) - imx.idx = x - imy.idx = y - for i in imx, imy: - ax = i.axes - ax.draw_artist(i) + for i, idx in zip((x_im, y_im), (x, y)): + i.set_viewer_slice(idx) + self._draw_ims() + + def _draw_ims(self): + for im in self._ims: + ax = im.axes + ax.draw_artist(im) + ax.draw_artist(im.x_line) + ax.draw_artist(im.y_line) ax.figure.canvas.blit(ax.bbox) - - -if __name__ == '__main__': - a = np.sin(np.linspace(0,np.pi,20)) - b = np.sin(np.linspace(0,np.pi*5,20)) - data = np.outer(a,b)[..., np.newaxis]*a - # all slices - OrthoSlicer3D(data).show() - - # broken out into three separate figures - f, ax1 = plt.subplots() - f, ax2 = plt.subplots() - f, ax3 = plt.subplots() - OrthoSlicer3D(data, axes=(ax1, ax2, ax3)).show() From 50975718234ac83acdc40d0cf53f548c0d401cad Mon Sep 17 00:00:00 2001 From: Eric89GXL Date: Sun, 19 Oct 2014 12:14:43 -0700 Subject: [PATCH 03/24] FIX: Minor fixes --- nibabel/viewers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nibabel/viewers.py b/nibabel/viewers.py index bfdea2a489..53ecfd9f79 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -57,8 +57,9 @@ class OrthoSlicer3D(object): >>> a = np.sin(np.linspace(0,np.pi,20)) >>> b = np.sin(np.linspace(0,np.pi*5,20)) >>> data = np.outer(a,b)[..., np.newaxis]*a - >>> OrthoSlicer3D(data).show() + >>> OrthoSlicer3D(data).show() # doctest: +SKIP """ + # Skip doctest above b/c not all systems have mpl installed def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', pcnt_range=None): """ From 6fda4d89f867b886f9e545de0d070b0078a8e82d Mon Sep 17 00:00:00 2001 From: Eric89GXL Date: Sun, 19 Oct 2014 19:22:17 -0700 Subject: [PATCH 04/24] ENH: Add set_indices method --- nibabel/tests/test_viewers.py | 10 +++--- nibabel/viewers.py | 57 ++++++++++++++++++++++------------- 2 files changed, 42 insertions(+), 25 deletions(-) diff --git a/nibabel/tests/test_viewers.py b/nibabel/tests/test_viewers.py index ce0713b4dc..ec61c8dbca 100644 --- a/nibabel/tests/test_viewers.py +++ b/nibabel/tests/test_viewers.py @@ -33,14 +33,16 @@ def test_viewer(): # fake some events viewer.on_scroll(nt('event', 'button inaxes')('up', None)) # outside axes viewer.on_scroll(nt('event', 'button inaxes')('up', plt.gca())) # in axes - # tracking on + # "click" outside axes, then once in each axis, then move without click viewer.on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, None, 1)) - viewer.on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, - plt.gca(), 1)) - # tracking off + for im in viewer._ims: + viewer.on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, + im.axes, + 1)) viewer.on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, None, None)) + viewer.set_indices(0, 1, 2) viewer.close() # other cases diff --git a/nibabel/viewers.py b/nibabel/viewers.py index 53ecfd9f79..d71a76976e 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -35,10 +35,10 @@ def _set_viewer_slice(idx, im): """Helper to set a viewer slice number""" - im.idx = idx + im.idx = max(min(int(round(idx)), im.size - 1), 0) im.set_data(im.get_slice(im.idx)) for fun in im.cross_setters: - fun([idx] * 2) + fun([im.idx] * 2) class OrthoSlicer3D(object): @@ -133,24 +133,21 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', origin='lower') # Start midway through each axis - st_x, st_y, st_z = (data_shape - 1) / 2. - sts = (st_x, st_y, st_z) - n_x, n_y, n_z = data_shape - z_get_slice = lambda i: self.data[:, :, min(i, n_z-1)].T - y_get_slice = lambda i: self.data[:, min(i, n_y-1), :].T - x_get_slice = lambda i: self.data[min(i, n_x-1), :, :].T - im1 = ax1.imshow(z_get_slice(st_z), **kw) - im2 = ax2.imshow(y_get_slice(st_y), **kw) - im3 = ax3.imshow(x_get_slice(st_x), **kw) + z_get_slice = lambda i: self.data[:, :, i].T + y_get_slice = lambda i: self.data[:, i, :].T + x_get_slice = lambda i: self.data[i, :, :].T + sts = (data_shape - 1) // 2 + im1 = ax1.imshow(z_get_slice(sts[2]), **kw) + im2 = ax2.imshow(y_get_slice(sts[1]), **kw) + im3 = ax3.imshow(x_get_slice(sts[0]), **kw) + # idx is the current slice number for each panel + im1.idx, im2.idx, im3.idx = sts + self._ims = (im1, im2, im3) im1.get_slice, im2.get_slice, im3.get_slice = ( z_get_slice, y_get_slice, x_get_slice) - self._ims = (im1, im2, im3) - - # idx is the current slice number for each panel - im1.idx, im2.idx, im3.idx = st_z, st_y, st_x # set the maximum dimensions for indexing - im1.size, im2.size, im3.size = n_z, n_y, n_x + im1.size, im2.size, im3.size = data_shape # set up axis crosshairs colors = ['r', 'g', 'b'] @@ -191,6 +188,7 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', for fig in self.figs: fig.canvas.mpl_connect('scroll_event', self.on_scroll) fig.canvas.mpl_connect('motion_notify_event', self.on_mousemove) + fig.canvas.mpl_connect('button_press_event', self.on_mousemove) def show(self): """ Show the slicer; convenience for ``plt.show()`` @@ -203,6 +201,26 @@ def close(self): for f in self.figs: plt.close(f) + def set_indices(self, x=None, y=None, z=None): + """Set current displayed slice indices + + Parameters + ---------- + x : int | None + Index to use. If None, do not change. + y : int | None + Index to use. If None, do not change. + z : int | None + Index to use. If None, do not change. + """ + draw = False + for im, val in zip(self._ims, (z, y, x)): + if val is not None: + im.set_viewer_slice(val) + draw = True + if draw: + self._draw_ims() + def _axis_artist(self, event): """Return artist if within axes, and is an image, else None """ @@ -216,8 +234,7 @@ def on_scroll(self, event): im = self._axis_artist(event) if im is None: return - idx = (im.idx + (1 if event.button == 'up' else -1)) - idx = max(min(idx, im.size - 1), 0) + idx = im.idx + (1 if event.button == 'up' else -1) im.set_viewer_slice(idx) self._draw_ims() @@ -227,9 +244,7 @@ def on_mousemove(self, event): im = self._axis_artist(event) if im is None: return - x_im, y_im = im.x_im, im.y_im - x, y = np.round((event.xdata, event.ydata)).astype(int) - for i, idx in zip((x_im, y_im), (x, y)): + for i, idx in zip((im.x_im, im.y_im), (event.xdata, event.ydata)): i.set_viewer_slice(idx) self._draw_ims() From e55a15e7d56ff1d72ec2c54aadcc3539d19e900b Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Wed, 22 Oct 2014 17:08:03 -0700 Subject: [PATCH 05/24] ENH: Allow time dimension --- nibabel/spatialimages.py | 19 +++ nibabel/viewers.py | 358 +++++++++++++++++++++++---------------- 2 files changed, 228 insertions(+), 149 deletions(-) diff --git a/nibabel/spatialimages.py b/nibabel/spatialimages.py index 23ab9c9c2d..d3ea9d3c33 100644 --- a/nibabel/spatialimages.py +++ b/nibabel/spatialimages.py @@ -139,6 +139,7 @@ from .filebasedimages import FileBasedHeader, FileBasedImage from .filebasedimages import ImageFileError # flake8: noqa; for back-compat +from .viewers import OrthoSlicer3D from .volumeutils import shape_zoom_affine @@ -661,3 +662,21 @@ def __getitem__(self): raise TypeError("Cannot slice image objects; consider slicing image " "array data with `img.dataobj[slice]` or " "`img.get_data()[slice]`") + + def plot(self, show=True): + """Plot the image using OrthoSlicer3D + + Parameters + ---------- + show : bool + If True, the viewer will be shown. + + Returns + ------- + viewer : instance of OrthoSlicer3D + The viewer. + """ + out = OrthoSlicer3D(self.get_data()) + if show: + out.show() + return out diff --git a/nibabel/viewers.py b/nibabel/viewers.py index d71a76976e..00fea018c2 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -6,40 +6,12 @@ from __future__ import division, print_function import numpy as np -from functools import partial from .optpkg import optional_package plt, _, _ = optional_package('matplotlib.pyplot') mpl_img, _, _ = optional_package('matplotlib.image') -# Assumes the following layout -# -# ^ +---------+ ^ +---------+ -# | | | | | | -# | | | | -# z | 2 | z | 3 | -# | | | | -# | | | | | | -# v +---------+ v +---------+ -# <-- x --> <-- y --> -# ^ +---------+ -# | | | -# | | -# y | 1 | -# | | -# | | | -# v +---------+ -# <-- x --> - - -def _set_viewer_slice(idx, im): - """Helper to set a viewer slice number""" - im.idx = max(min(int(round(idx)), im.size - 1), 0) - im.set_data(im.get_slice(im.idx)) - for fun in im.cross_setters: - fun([im.idx] * 2) - class OrthoSlicer3D(object): """Orthogonal-plane slicer. @@ -61,26 +33,42 @@ class OrthoSlicer3D(object): """ # Skip doctest above b/c not all systems have mpl installed def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', - pcnt_range=None): + pcnt_range=(1., 99.), figsize=(8, 8)): """ Parameters ---------- - data : 3 dimensional ndarray - The data that will be displayed by the slicer - axes : None or length 3 sequence of mpl.Axes, optional - 3 axes instances for the X, Y, and Z slices, or None (default) - aspect_ratio : float or length 3 sequence, optional - stretch factors for X, Y, Z directions - cmap : colormap identifier, optional + data : ndarray + The data that will be displayed by the slicer. Should have 3+ + dimensions. + axes : tuple of mpl.Axes | None, optional + 3 or 4 axes instances for the X, Y, Z slices plus volumes, + or None (default). + aspect_ratio : array-like, optional + Stretch factors for X, Y, Z directions. + cmap : str | instance of cmap, optional String or cmap instance specifying colormap. Will be passed as ``cmap`` argument to ``plt.imshow``. - pcnt_range : length 2 sequence, optional + pcnt_range : array-like, optional Percentile range over which to scale image for display. If None, scale between image mean and max. If sequence, min and max percentile over which to scale image. + figsize : tuple + Figure size (in inches) to use if axes are None. """ - data_shape = np.array(data.shape[:3]) # allow trailing RGB dimension - aspect_ratio = np.array(aspect_ratio, float) + ar = np.array(aspect_ratio, float) + if ar.shape != (3,) or np.any(ar <= 0): + raise ValueError('aspect ratio must have exactly 3 elements >= 0') + aspect_ratio = dict(x=ar[0], y=ar[1], z=ar[2]) + data = np.asanyarray(data) + if data.ndim < 3: + raise RuntimeError('data must have at least 3 dimensions') + self._volume_dims = data.shape[3:] + self._current_vol_data = data[:, :, :, 0] if data.ndim > 3 else data + self._data = data + pcnt_range = (0, 100) if pcnt_range is None else pcnt_range + vmin, vmax = np.percentile(data, pcnt_range) + del data + if axes is None: # make the axes # ^ +---------+ ^ +---------+ # | | | | | | @@ -90,105 +78,110 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', # | | | | | | # v +---------+ v +---------+ # <-- x --> <-- y --> - # ^ +---------+ - # | | | - # | | - # y | 1 | - # | | - # | | | - # v +---------+ - # <-- x --> - fig = plt.figure() - x, y, z = data_shape * aspect_ratio - maxw = float(x + y) - maxh = float(y + z) - yh = y / maxh - xw = x / maxw - yw = y / maxw - zh = z / maxh - # z slice (if usual transverse acquisition => axial slice) - ax1 = fig.add_axes((0., 0., xw, yh)) - # y slice (usually coronal) - ax2 = fig.add_axes((0, yh, xw, zh)) - # x slice (usually sagittal) - ax3 = fig.add_axes((xw, yh, yw, zh)) - axes = (ax1, ax2, ax3) - else: - if not np.all(aspect_ratio == 1): - raise ValueError('Aspect ratio must be 1 for external axes') - ax1, ax2, ax3 = axes - - self.data = data + # ^ +---------+ ^ +---------+ + # | | | | | | + # | | | | + # y | 1 | A | 4 | + # | | | | + # | | | | | | + # v +---------+ v +---------+ + # <-- x --> <-- t --> - if pcnt_range is None: - vmin, vmax = data.min(), data.max() + fig, axes = plt.subplots(2, 2) + fig.set_size_inches(figsize, forward=True) + self._axes = dict(x=axes[0, 1], y=axes[0, 0], z=axes[1, 0], + v=axes[1, 1]) + plt.tight_layout(pad=0.1) + if not self.multi_volume: + fig.delaxes(self._axes['v']) + del self._axes['v'] else: - vmin, vmax = np.percentile(data, pcnt_range) + self._axes = dict(z=axes[0], y=axes[1], x=axes[2]) + if len(axes) > 3: + self._axes['v'] = axes[3] - kw = dict(vmin=vmin, - vmax=vmax, - aspect='auto', - interpolation='nearest', - cmap=cmap, - origin='lower') + kw = dict(vmin=vmin, vmax=vmax, aspect=1, interpolation='nearest', + cmap=cmap, origin='lower') - # Start midway through each axis - z_get_slice = lambda i: self.data[:, :, i].T - y_get_slice = lambda i: self.data[:, i, :].T - x_get_slice = lambda i: self.data[i, :, :].T - sts = (data_shape - 1) // 2 - im1 = ax1.imshow(z_get_slice(sts[2]), **kw) - im2 = ax2.imshow(y_get_slice(sts[1]), **kw) - im3 = ax3.imshow(x_get_slice(sts[0]), **kw) - # idx is the current slice number for each panel - im1.idx, im2.idx, im3.idx = sts - self._ims = (im1, im2, im3) - im1.get_slice, im2.get_slice, im3.get_slice = ( - z_get_slice, y_get_slice, x_get_slice) - - # set the maximum dimensions for indexing - im1.size, im2.size, im3.size = data_shape + # Start midway through each axis, idx is current slice number + self._ims, self._sizes, self._idx = dict(), dict(), dict() + colors = dict() + for k, size in zip('xyz', self._data.shape[:3]): + self._idx[k] = size // 2 + self._ims[k] = self._axes[k].imshow(self._get_slice(k), **kw) + self._sizes[k] = size + colors[k] = (0, 1, 0) + self._idx['v'] = 0 + labels = dict(z='ILSR', y='ALPR', x='AIPS') # set up axis crosshairs - colors = ['r', 'g', 'b'] - for ax, im, idx_1, idx_2 in zip(axes, self._ims, [0, 0, 1], [1, 2, 2]): - im.x_line = ax.plot([sts[idx_1]] * 2, - [-0.5, data.shape[idx_2] - 0.5], - color=colors[idx_1], linestyle='-', - alpha=0.25)[0] - im.y_line = ax.plot([-0.5, data.shape[idx_1] - 0.5], - [sts[idx_2]] * 2, - color=colors[idx_2], linestyle='-', - alpha=0.25)[0] - ax.axis('tight') + for type_, i_1, i_2 in zip('zyx', 'xxy', 'yzz'): + ax = self._axes[type_] + im = self._ims[type_] + label = labels[type_] + # add slice lines + im.vert_line = ax.plot([self._idx[i_1]] * 2, + [-0.5, self._sizes[i_2] - 0.5], + color=colors[i_1], linestyle='-')[0] + im.horiz_line = ax.plot([-0.5, self._sizes[i_1] - 0.5], + [self._idx[i_2]] * 2, + color=colors[i_2], linestyle='-')[0] + # add text labels (top, right, bottom, left) + lims = [0, self._sizes[i_1], 0, self._sizes[i_2]] + bump = 0.01 + poss = [[lims[1] / 2., lims[3]], + [(1 + bump) * lims[1], lims[3] / 2.], + [lims[1] / 2., 0], + [lims[0] - bump * lims[1], lims[3] / 2.]] + anchors = [['center', 'bottom'], ['left', 'center'], + ['center', 'top'], ['right', 'center']] + im.texts = [ax.text(pos[0], pos[1], lab, + horizontalalignment=anchor[0], + verticalalignment=anchor[1]) + for pos, anchor, lab in zip(poss, anchors, label)] + ax.axis(lims) + ax.set_aspect(aspect_ratio[type_]) ax.patch.set_visible(False) ax.set_frame_on(False) ax.axes.get_yaxis().set_visible(False) ax.axes.get_xaxis().set_visible(False) - # monkey-patch some functions - im1.set_viewer_slice = partial(_set_viewer_slice, im=im1) - im2.set_viewer_slice = partial(_set_viewer_slice, im=im2) - im3.set_viewer_slice = partial(_set_viewer_slice, im=im3) + # Set up volumes axis + if self.multi_volume: + ax = self._axes['v'] + ax.set_axis_bgcolor('k') + ax.set_title('Volumes') + n_vols = np.prod(self._volume_dims) + print(n_vols) + y = np.mean(np.mean(np.mean(self._data, 0), 0), 0).ravel() + y = np.concatenate((y, [y[-1]])) + x = np.arange(n_vols + 1) - 0.5 + step = ax.step(x, y, where='post', color='y')[0] + ax.set_xticks(np.unique(np.linspace(0, n_vols - 1, 5).astype(int))) + ax.set_xlim(x[0], x[-1]) + line = ax.plot([0, 0], ax.get_ylim(), color=(0, 1, 0))[0] + self._time_lines = [line, step] # setup pairwise connections between the slice dimensions - im1.x_im = im3 # x move in panel 1 (usually axial) - im1.y_im = im2 # y move in panel 1 - im2.x_im = im3 # x move in panel 2 (usually coronal) - im2.y_im = im1 # y move in panel 2 - im3.x_im = im2 # x move in panel 3 (usually sagittal) - im3.y_im = im1 # y move in panel 3 + self._click_update_keys = dict(x='yz', y='xz', z='xy') # when an index changes, which crosshairs need to be updated - im1.cross_setters = [im2.y_line.set_ydata, im3.y_line.set_ydata] - im2.cross_setters = [im1.y_line.set_ydata, im3.x_line.set_xdata] - im3.cross_setters = [im1.x_line.set_xdata, im2.x_line.set_xdata] + self._cross_setters = dict( + x=[self._ims['z'].vert_line.set_xdata, + self._ims['y'].vert_line.set_xdata], + y=[self._ims['z'].horiz_line.set_ydata, + self._ims['x'].vert_line.set_xdata], + z=[self._ims['y'].horiz_line.set_ydata, + self._ims['x'].horiz_line.set_ydata]) - self.figs = set([ax.figure for ax in axes]) - for fig in self.figs: - fig.canvas.mpl_connect('scroll_event', self.on_scroll) - fig.canvas.mpl_connect('motion_notify_event', self.on_mousemove) - fig.canvas.mpl_connect('button_press_event', self.on_mousemove) + self._figs = set([a.figure for a in self._axes.values()]) + for fig in self._figs: + fig.canvas.mpl_connect('scroll_event', self._on_scroll) + fig.canvas.mpl_connect('motion_notify_event', self._on_mousemove) + fig.canvas.mpl_connect('button_press_event', self._on_mousemove) + fig.canvas.mpl_connect('key_press_event', self._on_keypress) + plt.draw() + self._draw() def show(self): """ Show the slicer; convenience for ``plt.show()`` @@ -198,10 +191,15 @@ def show(self): def close(self): """Close the viewer figures """ - for f in self.figs: + for f in self._figs: plt.close(f) - def set_indices(self, x=None, y=None, z=None): + @property + def multi_volume(self): + """Whether or not the displayed data is multi-volume""" + return len(self._volume_dims) > 0 + + def set_indices(self, x=None, y=None, z=None, v=None): """Set current displayed slice indices Parameters @@ -212,46 +210,108 @@ def set_indices(self, x=None, y=None, z=None): Index to use. If None, do not change. z : int | None Index to use. If None, do not change. + v : int | None + Volume index to use. If None, do not change. """ + x = int(x) if x is not None else None + y = int(y) if y is not None else None + z = int(z) if z is not None else None + v = int(v) if v is not None else None draw = False - for im, val in zip(self._ims, (z, y, x)): + if v is not None: + if not self.multi_volume: + raise RuntimeError('cannot change volume index of ' + 'single-volume image') + self._set_vol_idx(v, draw=False) # delay draw + draw = True + for key, val in zip('zyx', (z, y, x)): if val is not None: - im.set_viewer_slice(val) + self._set_viewer_slice(key, val) draw = True if draw: - self._draw_ims() + self._draw() - def _axis_artist(self, event): - """Return artist if within axes, and is an image, else None - """ - if not getattr(event, 'inaxes'): + def _set_vol_idx(self, idx, draw=True): + """Helper to change which volume is shown""" + max_ = np.prod(self._volume_dims) + self._idx['v'] = max(min(int(round(idx)), max_ - 1), 0) + # Must reset what is shown + self._current_vol_data = self._data[:, :, :, self._idx['v']] + for key in 'xyz': + self._ims[key].set_data(self._get_slice(key)) + self._time_lines[0].set_xdata([self._idx['v']] * 2) + if draw: + self._draw() + + def _get_slice(self, key): + """Helper to get the current slice image""" + ii = dict(x=0, y=1, z=2)[key] + return np.take(self._current_vol_data, self._idx[key], axis=ii).T + + def _set_viewer_slice(self, key, idx): + """Helper to set a viewer slice number""" + self._idx[key] = max(min(int(round(idx)), self._sizes[key] - 1), 0) + self._ims[key].set_data(self._get_slice(key)) + for fun in self._cross_setters[key]: + fun([self._idx[key]] * 2) + + def _in_axis(self, event): + """Return axis key if within one of our axes, else None""" + if getattr(event, 'inaxes') is None: return None - artist = event.inaxes.images[0] - return artist if isinstance(artist, mpl_img.AxesImage) else None + for key, ax in self._axes.items(): + if event.inaxes is ax: + return key + return None - def on_scroll(self, event): + def _on_scroll(self, event): assert event.button in ('up', 'down') - im = self._axis_artist(event) - if im is None: + key = self._in_axis(event) + if key is None: return - idx = im.idx + (1 if event.button == 'up' else -1) - im.set_viewer_slice(idx) - self._draw_ims() + delta = 10 if event.key is not None and 'control' in event.key else 1 + if event.key is not None and 'shift' in event.key: + if not self.multi_volume: + return + key = 'v' # shift: change volume in any axis + idx = self._idx[key] + (delta if event.button == 'up' else -delta) + if key == 'v': + self._set_vol_idx(idx) + else: + self._set_viewer_slice(key, idx) + self._draw() - def on_mousemove(self, event): + def _on_mousemove(self, event): if event.button != 1: # only enabled while dragging return - im = self._axis_artist(event) - if im is None: + key = self._in_axis(event) + if key is None: return - for i, idx in zip((im.x_im, im.y_im), (event.xdata, event.ydata)): - i.set_viewer_slice(idx) - self._draw_ims() + if key == 'v': + self._set_vol_idx(event.xdata) + else: + for sub_key, idx in zip(self._click_update_keys[key], + (event.xdata, event.ydata)): + self._set_viewer_slice(sub_key, idx) + self._draw() - def _draw_ims(self): - for im in self._ims: + def _on_keypress(self, event): + if event.key is not None and 'escape' in event.key: + self.close() + + def _draw(self): + for im in self._ims.values(): ax = im.axes + ax.draw_artist(ax.patch) ax.draw_artist(im) - ax.draw_artist(im.x_line) - ax.draw_artist(im.y_line) + ax.draw_artist(im.vert_line) + ax.draw_artist(im.horiz_line) + ax.figure.canvas.blit(ax.bbox) + for t in im.texts: + ax.draw_artist(t) + if self.multi_volume: + ax = self._axes['v'] + ax.draw_artist(ax.patch) + for artist in self._time_lines: + ax.draw_artist(artist) ax.figure.canvas.blit(ax.bbox) From f9c6ae09c81c4319356d8640412e65311e6f84d7 Mon Sep 17 00:00:00 2001 From: Eric89GXL Date: Thu, 23 Oct 2014 00:30:01 -0700 Subject: [PATCH 06/24] ENH: Better tests --- nibabel/spatialimages.py | 18 +++++++-------- nibabel/tests/test_viewers.py | 43 ++++++++++++++++++++++------------- nibabel/viewers.py | 27 +++++++++++----------- 3 files changed, 49 insertions(+), 39 deletions(-) diff --git a/nibabel/spatialimages.py b/nibabel/spatialimages.py index d3ea9d3c33..d798352950 100644 --- a/nibabel/spatialimages.py +++ b/nibabel/spatialimages.py @@ -663,20 +663,18 @@ def __getitem__(self): "array data with `img.dataobj[slice]` or " "`img.get_data()[slice]`") - def plot(self, show=True): + def plot(self): """Plot the image using OrthoSlicer3D - Parameters - ---------- - show : bool - If True, the viewer will be shown. - Returns ------- viewer : instance of OrthoSlicer3D The viewer. + + Notes + ----- + This requires matplotlib. If a non-interactive backend is used, + consider using viewer.show() (equivalently plt.show()) to show + the figure. """ - out = OrthoSlicer3D(self.get_data()) - if show: - out.show() - return out + return OrthoSlicer3D(self.get_data()) diff --git a/nibabel/tests/test_viewers.py b/nibabel/tests/test_viewers.py index ec61c8dbca..4cae4211d2 100644 --- a/nibabel/tests/test_viewers.py +++ b/nibabel/tests/test_viewers.py @@ -26,28 +26,39 @@ def test_viewer(): # Test viewer a = np.sin(np.linspace(0, np.pi, 20)) b = np.sin(np.linspace(0, np.pi*5, 30)) - data = np.outer(a, b)[..., np.newaxis] * a + data = (np.outer(a, b)[..., np.newaxis] * a)[:, :, :, np.newaxis] viewer = OrthoSlicer3D(data) plt.draw() - # fake some events - viewer.on_scroll(nt('event', 'button inaxes')('up', None)) # outside axes - viewer.on_scroll(nt('event', 'button inaxes')('up', plt.gca())) # in axes + # fake some events, inside and outside axes + viewer._on_scroll(nt('event', 'button inaxes key')('up', None, None)) + for ax in (viewer._axes['x'], viewer._axes['v']): + viewer._on_scroll(nt('event', 'button inaxes key')('up', ax, None)) + viewer._on_scroll(nt('event', 'button inaxes key')('up', ax, 'shift')) # "click" outside axes, then once in each axis, then move without click - viewer.on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, - None, 1)) - for im in viewer._ims: - viewer.on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, - im.axes, - 1)) - viewer.on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, - None, None)) + viewer._on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, + None, 1)) + for ax in viewer._axes.values(): + viewer._on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, + ax, 1)) + viewer._on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, + None, None)) viewer.set_indices(0, 1, 2) + viewer.set_indices(v=10) viewer.close() + # non-multi-volume + viewer = OrthoSlicer3D(data[:, :, :, 0]) + assert_raises(ValueError, viewer.set_indices, v=10) # not multi-volume + viewer._on_scroll(nt('event', 'button inaxes key')('up', viewer._axes['x'], + 'shift')) + viewer._on_keypress(nt('event', 'key')('escape')) + # other cases - fig, axes = plt.subplots(1, 3) + fig, axes = plt.subplots(1, 4) plt.close(fig) - OrthoSlicer3D(data, pcnt_range=[0.1, 0.9], axes=axes) - assert_raises(ValueError, OrthoSlicer3D, data, aspect_ratio=[1, 2, 3], - axes=axes) + OrthoSlicer3D(data, pcnt_range=[0.1, 0.9], axes=axes, + aspect_ratio=[1, 2, 3]) + OrthoSlicer3D(data, axes=axes[:3]) + assert_raises(ValueError, OrthoSlicer3D, data, aspect_ratio=[1, 2]) + assert_raises(ValueError, OrthoSlicer3D, data[:, :, 0, 0]) diff --git a/nibabel/viewers.py b/nibabel/viewers.py index 00fea018c2..834b23f4e9 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -11,6 +11,7 @@ plt, _, _ = optional_package('matplotlib.pyplot') mpl_img, _, _ = optional_package('matplotlib.image') +mpl_patch, _, _ = optional_package('matplotlib.patches') class OrthoSlicer3D(object): @@ -61,7 +62,7 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', aspect_ratio = dict(x=ar[0], y=ar[1], z=ar[2]) data = np.asanyarray(data) if data.ndim < 3: - raise RuntimeError('data must have at least 3 dimensions') + raise ValueError('data must have at least 3 dimensions') self._volume_dims = data.shape[3:] self._current_vol_data = data[:, :, :, 0] if data.ndim > 3 else data self._data = data @@ -147,20 +148,23 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', ax.axes.get_xaxis().set_visible(False) # Set up volumes axis - if self.multi_volume: + if self.multi_volume and 'v' in self._axes: ax = self._axes['v'] ax.set_axis_bgcolor('k') ax.set_title('Volumes') n_vols = np.prod(self._volume_dims) - print(n_vols) y = np.mean(np.mean(np.mean(self._data, 0), 0), 0).ravel() y = np.concatenate((y, [y[-1]])) x = np.arange(n_vols + 1) - 0.5 step = ax.step(x, y, where='post', color='y')[0] ax.set_xticks(np.unique(np.linspace(0, n_vols - 1, 5).astype(int))) ax.set_xlim(x[0], x[-1]) - line = ax.plot([0, 0], ax.get_ylim(), color=(0, 1, 0))[0] - self._time_lines = [line, step] + lims = ax.get_ylim() + patch = mpl_patch.Rectangle([-0.5, lims[0]], 1., np.diff(lims)[0], + fill=True, facecolor=(0, 1, 0), + edgecolor=(0, 1, 0), alpha=0.25) + ax.add_patch(patch) + self._time_lines = [patch, step] # setup pairwise connections between the slice dimensions self._click_update_keys = dict(x='yz', y='xz', z='xy') @@ -180,11 +184,9 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', fig.canvas.mpl_connect('motion_notify_event', self._on_mousemove) fig.canvas.mpl_connect('button_press_event', self._on_mousemove) fig.canvas.mpl_connect('key_press_event', self._on_keypress) - plt.draw() - self._draw() def show(self): - """ Show the slicer; convenience for ``plt.show()`` + """ Show the slicer in blocking mode; convenience for ``plt.show()`` """ plt.show() @@ -220,8 +222,8 @@ def set_indices(self, x=None, y=None, z=None, v=None): draw = False if v is not None: if not self.multi_volume: - raise RuntimeError('cannot change volume index of ' - 'single-volume image') + raise ValueError('cannot change volume index of single-volume ' + 'image') self._set_vol_idx(v, draw=False) # delay draw draw = True for key, val in zip('zyx', (z, y, x)): @@ -239,7 +241,7 @@ def _set_vol_idx(self, idx, draw=True): self._current_vol_data = self._data[:, :, :, self._idx['v']] for key in 'xyz': self._ims[key].set_data(self._get_slice(key)) - self._time_lines[0].set_xdata([self._idx['v']] * 2) + self._time_lines[0].set_x(self._idx['v'] - 0.5) if draw: self._draw() @@ -262,7 +264,6 @@ def _in_axis(self, event): for key, ax in self._axes.items(): if event.inaxes is ax: return key - return None def _on_scroll(self, event): assert event.button in ('up', 'down') @@ -309,7 +310,7 @@ def _draw(self): ax.figure.canvas.blit(ax.bbox) for t in im.texts: ax.draw_artist(t) - if self.multi_volume: + if self.multi_volume and 'v' in self._axes: # user might only pass 3 ax = self._axes['v'] ax.draw_artist(ax.patch) for artist in self._time_lines: From 11ff239af481349247f0cf35cdaea22a38a90976 Mon Sep 17 00:00:00 2001 From: Eric89GXL Date: Sat, 25 Oct 2014 22:46:52 -0700 Subject: [PATCH 07/24] FIX: Update volume plot --- nibabel/tests/test_viewers.py | 1 + nibabel/viewers.py | 78 +++++++++++++++++++++-------------- 2 files changed, 47 insertions(+), 32 deletions(-) diff --git a/nibabel/tests/test_viewers.py b/nibabel/tests/test_viewers.py index 4cae4211d2..64ea1df514 100644 --- a/nibabel/tests/test_viewers.py +++ b/nibabel/tests/test_viewers.py @@ -27,6 +27,7 @@ def test_viewer(): a = np.sin(np.linspace(0, np.pi, 20)) b = np.sin(np.linspace(0, np.pi*5, 30)) data = (np.outer(a, b)[..., np.newaxis] * a)[:, :, :, np.newaxis] + data = data * np.array([1., 2.]) # give it a # of volumes > 1 viewer = OrthoSlicer3D(data) plt.draw() diff --git a/nibabel/viewers.py b/nibabel/viewers.py index 834b23f4e9..540a9b6e67 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -93,7 +93,7 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', self._axes = dict(x=axes[0, 1], y=axes[0, 0], z=axes[1, 0], v=axes[1, 1]) plt.tight_layout(pad=0.1) - if not self.multi_volume: + if self.n_volumes <= 1: fig.delaxes(self._axes['v']) del self._axes['v'] else: @@ -109,7 +109,7 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', colors = dict() for k, size in zip('xyz', self._data.shape[:3]): self._idx[k] = size // 2 - self._ims[k] = self._axes[k].imshow(self._get_slice(k), **kw) + self._ims[k] = self._axes[k].imshow(self._get_slice_data(k), **kw) self._sizes[k] = size colors[k] = (0, 1, 0) self._idx['v'] = 0 @@ -148,23 +148,24 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', ax.axes.get_xaxis().set_visible(False) # Set up volumes axis - if self.multi_volume and 'v' in self._axes: + if self.n_volumes > 1 and 'v' in self._axes: ax = self._axes['v'] ax.set_axis_bgcolor('k') ax.set_title('Volumes') - n_vols = np.prod(self._volume_dims) - y = np.mean(np.mean(np.mean(self._data, 0), 0), 0).ravel() - y = np.concatenate((y, [y[-1]])) - x = np.arange(n_vols + 1) - 0.5 + y = self._get_voxel_levels() + x = np.arange(self.n_volumes + 1) - 0.5 step = ax.step(x, y, where='post', color='y')[0] - ax.set_xticks(np.unique(np.linspace(0, n_vols - 1, 5).astype(int))) + ax.set_xticks(np.unique(np.linspace(0, self.n_volumes - 1, + 5).astype(int))) ax.set_xlim(x[0], x[-1]) - lims = ax.get_ylim() - patch = mpl_patch.Rectangle([-0.5, lims[0]], 1., np.diff(lims)[0], - fill=True, facecolor=(0, 1, 0), - edgecolor=(0, 1, 0), alpha=0.25) + yl = [self._data.min(), self._data.max()] + yl = [l + s * np.diff(lims)[0] for l, s in zip(yl, [-1.01, 1.01])] + patch = mpl_patch.Rectangle([-0.5, yl[0]], 1., np.diff(yl)[0], + fill=True, facecolor=(0, 1, 0), + edgecolor=(0, 1, 0), alpha=0.25) ax.add_patch(patch) - self._time_lines = [patch, step] + ax.set_ylim(yl) + self._volume_ax_objs = dict(step=step, patch=patch) # setup pairwise connections between the slice dimensions self._click_update_keys = dict(x='yz', y='xz', z='xy') @@ -197,9 +198,9 @@ def close(self): plt.close(f) @property - def multi_volume(self): - """Whether or not the displayed data is multi-volume""" - return len(self._volume_dims) > 0 + def n_volumes(self): + """Number of volumes in the data""" + return int(np.prod(self._volume_dims)) def set_indices(self, x=None, y=None, z=None, v=None): """Set current displayed slice indices @@ -221,31 +222,43 @@ def set_indices(self, x=None, y=None, z=None, v=None): v = int(v) if v is not None else None draw = False if v is not None: - if not self.multi_volume: + if self.n_volumes <= 1: raise ValueError('cannot change volume index of single-volume ' 'image') - self._set_vol_idx(v, draw=False) # delay draw + self._set_vol_idx(v) draw = True for key, val in zip('zyx', (z, y, x)): if val is not None: self._set_viewer_slice(key, val) draw = True if draw: + self._update_voxel_levels() self._draw() - def _set_vol_idx(self, idx, draw=True): - """Helper to change which volume is shown""" + def _get_voxel_levels(self): + """Get levels of the current voxel as a function of volume""" + y = self._data[self._idx['x'], + self._idx['y'], + self._idx['z'], :].ravel() + y = np.concatenate((y, [y[-1]])) + return y + + def _update_voxel_levels(self): + """Update voxel levels in time plot""" + if self.n_volumes > 1: + self._volume_ax_objs['step'].set_ydata(self._get_voxel_levels()) + + def _set_vol_idx(self, idx): + """Change which volume is shown""" max_ = np.prod(self._volume_dims) self._idx['v'] = max(min(int(round(idx)), max_ - 1), 0) # Must reset what is shown self._current_vol_data = self._data[:, :, :, self._idx['v']] for key in 'xyz': - self._ims[key].set_data(self._get_slice(key)) - self._time_lines[0].set_x(self._idx['v'] - 0.5) - if draw: - self._draw() + self._ims[key].set_data(self._get_slice_data(key)) + self._volume_ax_objs['patch'].set_x(self._idx['v'] - 0.5) - def _get_slice(self, key): + def _get_slice_data(self, key): """Helper to get the current slice image""" ii = dict(x=0, y=1, z=2)[key] return np.take(self._current_vol_data, self._idx[key], axis=ii).T @@ -253,7 +266,7 @@ def _get_slice(self, key): def _set_viewer_slice(self, key, idx): """Helper to set a viewer slice number""" self._idx[key] = max(min(int(round(idx)), self._sizes[key] - 1), 0) - self._ims[key].set_data(self._get_slice(key)) + self._ims[key].set_data(self._get_slice_data(key)) for fun in self._cross_setters[key]: fun([self._idx[key]] * 2) @@ -272,7 +285,7 @@ def _on_scroll(self, event): return delta = 10 if event.key is not None and 'control' in event.key else 1 if event.key is not None and 'shift' in event.key: - if not self.multi_volume: + if self.n_volumes <= 1: return key = 'v' # shift: change volume in any axis idx = self._idx[key] + (delta if event.button == 'up' else -delta) @@ -280,6 +293,7 @@ def _on_scroll(self, event): self._set_vol_idx(idx) else: self._set_viewer_slice(key, idx) + self._update_voxel_levels() self._draw() def _on_mousemove(self, event): @@ -294,6 +308,7 @@ def _on_mousemove(self, event): for sub_key, idx in zip(self._click_update_keys[key], (event.xdata, event.ydata)): self._set_viewer_slice(sub_key, idx) + self._update_voxel_levels() self._draw() def _on_keypress(self, event): @@ -303,16 +318,15 @@ def _on_keypress(self, event): def _draw(self): for im in self._ims.values(): ax = im.axes - ax.draw_artist(ax.patch) ax.draw_artist(im) ax.draw_artist(im.vert_line) ax.draw_artist(im.horiz_line) ax.figure.canvas.blit(ax.bbox) for t in im.texts: ax.draw_artist(t) - if self.multi_volume and 'v' in self._axes: # user might only pass 3 + if self.n_volumes > 1 and 'v' in self._axes: # user might only pass 3 ax = self._axes['v'] - ax.draw_artist(ax.patch) - for artist in self._time_lines: - ax.draw_artist(artist) + ax.draw_artist(ax.patch) # axis bgcolor to erase old lines + for key in ('step', 'patch'): + ax.draw_artist(self._volume_ax_objs[key]) ax.figure.canvas.blit(ax.bbox) From bf54667489b76a1af087aa79949f90a3cc20b50a Mon Sep 17 00:00:00 2001 From: Eric89GXL Date: Sat, 25 Oct 2014 23:13:25 -0700 Subject: [PATCH 08/24] FIX: Remove monkey patching --- nibabel/tests/test_viewers.py | 37 ++++++++--------- nibabel/viewers.py | 78 +++++++++++++++++++---------------- 2 files changed, 59 insertions(+), 56 deletions(-) diff --git a/nibabel/tests/test_viewers.py b/nibabel/tests/test_viewers.py index 64ea1df514..8be12f104d 100644 --- a/nibabel/tests/test_viewers.py +++ b/nibabel/tests/test_viewers.py @@ -28,32 +28,28 @@ def test_viewer(): b = np.sin(np.linspace(0, np.pi*5, 30)) data = (np.outer(a, b)[..., np.newaxis] * a)[:, :, :, np.newaxis] data = data * np.array([1., 2.]) # give it a # of volumes > 1 - viewer = OrthoSlicer3D(data) + v = OrthoSlicer3D(data) plt.draw() # fake some events, inside and outside axes - viewer._on_scroll(nt('event', 'button inaxes key')('up', None, None)) - for ax in (viewer._axes['x'], viewer._axes['v']): - viewer._on_scroll(nt('event', 'button inaxes key')('up', ax, None)) - viewer._on_scroll(nt('event', 'button inaxes key')('up', ax, 'shift')) + v._on_scroll(nt('event', 'button inaxes key')('up', None, None)) + for ax in (v._axes['x'], v._axes['v']): + v._on_scroll(nt('event', 'button inaxes key')('up', ax, None)) + v._on_scroll(nt('event', 'button inaxes key')('up', ax, 'shift')) # "click" outside axes, then once in each axis, then move without click - viewer._on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, - None, 1)) - for ax in viewer._axes.values(): - viewer._on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, - ax, 1)) - viewer._on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, - None, None)) - viewer.set_indices(0, 1, 2) - viewer.set_indices(v=10) - viewer.close() + v._on_mouse(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, None, 1)) + for ax in v._axes.values(): + v._on_mouse(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, ax, 1)) + v._on_mouse(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, None, None)) + v.set_indices(0, 1, 2) + v.set_indices(v=10) + v.close() # non-multi-volume - viewer = OrthoSlicer3D(data[:, :, :, 0]) - assert_raises(ValueError, viewer.set_indices, v=10) # not multi-volume - viewer._on_scroll(nt('event', 'button inaxes key')('up', viewer._axes['x'], - 'shift')) - viewer._on_keypress(nt('event', 'key')('escape')) + v = OrthoSlicer3D(data[:, :, :, 0]) + assert_raises(ValueError, v.set_indices, v=10) # not multi-volume + v._on_scroll(nt('event', 'button inaxes key')('up', v._axes['x'], 'shift')) + v._on_keypress(nt('event', 'key')('escape')) # other cases fig, axes = plt.subplots(1, 4) @@ -63,3 +59,4 @@ def test_viewer(): OrthoSlicer3D(data, axes=axes[:3]) assert_raises(ValueError, OrthoSlicer3D, data, aspect_ratio=[1, 2]) assert_raises(ValueError, OrthoSlicer3D, data[:, :, 0, 0]) + assert_raises(ValueError, OrthoSlicer3D, data, affine=np.eye(3)) diff --git a/nibabel/viewers.py b/nibabel/viewers.py index 540a9b6e67..d73a0ec547 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -33,8 +33,8 @@ class OrthoSlicer3D(object): >>> OrthoSlicer3D(data).show() # doctest: +SKIP """ # Skip doctest above b/c not all systems have mpl installed - def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', - pcnt_range=(1., 99.), figsize=(8, 8)): + def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), affine=None, + cmap='gray', pcnt_range=(1., 99.), figsize=(8, 8)): """ Parameters ---------- @@ -46,13 +46,14 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', or None (default). aspect_ratio : array-like, optional Stretch factors for X, Y, Z directions. + affine : array-like | None + Affine transform for the data. This is used to determine + how the data should be sliced for plotting into the X, Y, + and Z view axes. If None, identity is assumed. cmap : str | instance of cmap, optional - String or cmap instance specifying colormap. Will be passed as - ``cmap`` argument to ``plt.imshow``. + String or cmap instance specifying colormap. pcnt_range : array-like, optional - Percentile range over which to scale image for display. If None, - scale between image mean and max. If sequence, min and max - percentile over which to scale image. + Percentile range over which to scale image for display. figsize : tuple Figure size (in inches) to use if axes are None. """ @@ -63,6 +64,10 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', data = np.asanyarray(data) if data.ndim < 3: raise ValueError('data must have at least 3 dimensions') + affine = np.array(affine, float) if affine is not None else np.eye(4) + if affine.ndim != 2 or affine.shape != (4, 4): + raise ValueError('affine must be a 4x4 matrix') + self._affine = affine self._volume_dims = data.shape[3:] self._current_vol_data = data[:, :, :, 0] if data.ndim > 3 else data self._data = data @@ -116,17 +121,16 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', labels = dict(z='ILSR', y='ALPR', x='AIPS') # set up axis crosshairs + self._crosshairs = dict() for type_, i_1, i_2 in zip('zyx', 'xxy', 'yzz'): - ax = self._axes[type_] - im = self._ims[type_] - label = labels[type_] - # add slice lines - im.vert_line = ax.plot([self._idx[i_1]] * 2, - [-0.5, self._sizes[i_2] - 0.5], - color=colors[i_1], linestyle='-')[0] - im.horiz_line = ax.plot([-0.5, self._sizes[i_1] - 0.5], - [self._idx[i_2]] * 2, - color=colors[i_2], linestyle='-')[0] + ax, label = self._axes[type_], labels[type_] + vert = ax.plot([self._idx[i_1]] * 2, + [-0.5, self._sizes[i_2] - 0.5], + color=colors[i_1], linestyle='-')[0] + horiz = ax.plot([-0.5, self._sizes[i_1] - 0.5], + [self._idx[i_2]] * 2, + color=colors[i_2], linestyle='-')[0] + self._crosshairs[type_] = dict(vert=vert, horiz=horiz) # add text labels (top, right, bottom, left) lims = [0, self._sizes[i_1], 0, self._sizes[i_2]] bump = 0.01 @@ -136,10 +140,10 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', [lims[0] - bump * lims[1], lims[3] / 2.]] anchors = [['center', 'bottom'], ['left', 'center'], ['center', 'top'], ['right', 'center']] - im.texts = [ax.text(pos[0], pos[1], lab, - horizontalalignment=anchor[0], - verticalalignment=anchor[1]) - for pos, anchor, lab in zip(poss, anchors, label)] + for pos, anchor, lab in zip(poss, anchors, label): + ax.text(pos[0], pos[1], lab, + horizontalalignment=anchor[0], + verticalalignment=anchor[1]) ax.axis(lims) ax.set_aspect(aspect_ratio[type_]) ax.patch.set_visible(False) @@ -172,18 +176,18 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', # when an index changes, which crosshairs need to be updated self._cross_setters = dict( - x=[self._ims['z'].vert_line.set_xdata, - self._ims['y'].vert_line.set_xdata], - y=[self._ims['z'].horiz_line.set_ydata, - self._ims['x'].vert_line.set_xdata], - z=[self._ims['y'].horiz_line.set_ydata, - self._ims['x'].horiz_line.set_ydata]) + x=[self._crosshairs['z']['vert'].set_xdata, + self._crosshairs['y']['vert'].set_xdata], + y=[self._crosshairs['z']['horiz'].set_ydata, + self._crosshairs['x']['vert'].set_xdata], + z=[self._crosshairs['y']['horiz'].set_ydata, + self._crosshairs['x']['horiz'].set_ydata]) self._figs = set([a.figure for a in self._axes.values()]) for fig in self._figs: fig.canvas.mpl_connect('scroll_event', self._on_scroll) - fig.canvas.mpl_connect('motion_notify_event', self._on_mousemove) - fig.canvas.mpl_connect('button_press_event', self._on_mousemove) + fig.canvas.mpl_connect('motion_notify_event', self._on_mouse) + fig.canvas.mpl_connect('button_press_event', self._on_mouse) fig.canvas.mpl_connect('key_press_event', self._on_keypress) def show(self): @@ -279,6 +283,7 @@ def _in_axis(self, event): return key def _on_scroll(self, event): + """Handle mpl scroll wheel event""" assert event.button in ('up', 'down') key = self._in_axis(event) if key is None: @@ -296,7 +301,8 @@ def _on_scroll(self, event): self._update_voxel_levels() self._draw() - def _on_mousemove(self, event): + def _on_mouse(self, event): + """Handle mpl mouse move and button press events""" if event.button != 1: # only enabled while dragging return key = self._in_axis(event) @@ -312,18 +318,18 @@ def _on_mousemove(self, event): self._draw() def _on_keypress(self, event): + """Handle mpl keypress events""" if event.key is not None and 'escape' in event.key: self.close() def _draw(self): - for im in self._ims.values(): - ax = im.axes + """Update all four (or three) plots""" + for key in 'xyz': + ax, im = self._axes[key], self._ims[key] ax.draw_artist(im) - ax.draw_artist(im.vert_line) - ax.draw_artist(im.horiz_line) + for line in self._crosshairs[key].values(): + ax.draw_artist(line) ax.figure.canvas.blit(ax.bbox) - for t in im.texts: - ax.draw_artist(t) if self.n_volumes > 1 and 'v' in self._axes: # user might only pass 3 ax = self._axes['v'] ax.draw_artist(ax.patch) # axis bgcolor to erase old lines From 62bce405739982f8d67bc5151365ee00f303de49 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Mon, 27 Oct 2014 16:58:05 -0700 Subject: [PATCH 09/24] WIP --- nibabel/spatialimages.py | 2 +- nibabel/viewers.py | 32 ++++++++++++++++---------------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/nibabel/spatialimages.py b/nibabel/spatialimages.py index d798352950..a3926bc91b 100644 --- a/nibabel/spatialimages.py +++ b/nibabel/spatialimages.py @@ -677,4 +677,4 @@ def plot(self): consider using viewer.show() (equivalently plt.show()) to show the figure. """ - return OrthoSlicer3D(self.get_data()) + return OrthoSlicer3D(self.get_data(), self.get_affine()) diff --git a/nibabel/viewers.py b/nibabel/viewers.py index d73a0ec547..10edd14d09 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -8,6 +8,7 @@ import numpy as np from .optpkg import optional_package +from .orientations import aff2axcodes, axcodes2ornt plt, _, _ = optional_package('matplotlib.pyplot') mpl_img, _, _ = optional_package('matplotlib.image') @@ -33,23 +34,22 @@ class OrthoSlicer3D(object): >>> OrthoSlicer3D(data).show() # doctest: +SKIP """ # Skip doctest above b/c not all systems have mpl installed - def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), affine=None, - cmap='gray', pcnt_range=(1., 99.), figsize=(8, 8)): + def __init__(self, data, affine=None, axes=None, cmap='gray', + pcnt_range=(1., 99.), figsize=(8, 8)): """ Parameters ---------- data : ndarray The data that will be displayed by the slicer. Should have 3+ dimensions. - axes : tuple of mpl.Axes | None, optional - 3 or 4 axes instances for the X, Y, Z slices plus volumes, - or None (default). - aspect_ratio : array-like, optional - Stretch factors for X, Y, Z directions. affine : array-like | None Affine transform for the data. This is used to determine how the data should be sliced for plotting into the X, Y, - and Z view axes. If None, identity is assumed. + and Z view axes. If None, identity is assumed. The aspect + ratio of the data are inferred from the affine transform. + axes : tuple of mpl.Axes | None, optional + 3 or 4 axes instances for the X, Y, Z slices plus volumes, + or None (default). cmap : str | instance of cmap, optional String or cmap instance specifying colormap. pcnt_range : array-like, optional @@ -57,17 +57,17 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), affine=None, figsize : tuple Figure size (in inches) to use if axes are None. """ - ar = np.array(aspect_ratio, float) - if ar.shape != (3,) or np.any(ar <= 0): - raise ValueError('aspect ratio must have exactly 3 elements >= 0') - aspect_ratio = dict(x=ar[0], y=ar[1], z=ar[2]) data = np.asanyarray(data) if data.ndim < 3: raise ValueError('data must have at least 3 dimensions') affine = np.array(affine, float) if affine is not None else np.eye(4) if affine.ndim != 2 or affine.shape != (4, 4): raise ValueError('affine must be a 4x4 matrix') - self._affine = affine + self._affine = affine.copy() + self._codes = axcodes2ornt(aff2axcodes(self._affine)) # XXX USE FOR ORDERING + print(self._codes) + self._scalers = np.abs(self._affine).max(axis=0)[:3] + self._inv_affine = np.linalg.inv(affine) self._volume_dims = data.shape[3:] self._current_vol_data = data[:, :, :, 0] if data.ndim > 3 else data self._data = data @@ -122,7 +122,7 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), affine=None, # set up axis crosshairs self._crosshairs = dict() - for type_, i_1, i_2 in zip('zyx', 'xxy', 'yzz'): + for type_, i_1, i_2 in zip('xyz', 'yxx', 'zzy'): ax, label = self._axes[type_], labels[type_] vert = ax.plot([self._idx[i_1]] * 2, [-0.5, self._sizes[i_2] - 0.5], @@ -145,7 +145,7 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), affine=None, horizontalalignment=anchor[0], verticalalignment=anchor[1]) ax.axis(lims) - ax.set_aspect(aspect_ratio[type_]) + # ax.set_aspect(aspect_ratio[type_]) # XXX FIX ax.patch.set_visible(False) ax.set_frame_on(False) ax.axes.get_yaxis().set_visible(False) @@ -206,7 +206,7 @@ def n_volumes(self): """Number of volumes in the data""" return int(np.prod(self._volume_dims)) - def set_indices(self, x=None, y=None, z=None, v=None): + def set_position(self, x=None, y=None, z=None, v=None): """Set current displayed slice indices Parameters From e5169328085d86c2862a0b457f3508deac217935 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 28 Oct 2014 12:50:18 -0700 Subject: [PATCH 10/24] WIP: Closer to correcting orientation --- nibabel/tests/test_viewers.py | 7 +-- nibabel/viewers.py | 88 ++++++++++++++++++++++------------- 2 files changed, 57 insertions(+), 38 deletions(-) diff --git a/nibabel/tests/test_viewers.py b/nibabel/tests/test_viewers.py index 8be12f104d..6ec5c1b1ef 100644 --- a/nibabel/tests/test_viewers.py +++ b/nibabel/tests/test_viewers.py @@ -41,22 +41,17 @@ def test_viewer(): for ax in v._axes.values(): v._on_mouse(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, ax, 1)) v._on_mouse(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, None, None)) - v.set_indices(0, 1, 2) - v.set_indices(v=10) v.close() # non-multi-volume v = OrthoSlicer3D(data[:, :, :, 0]) - assert_raises(ValueError, v.set_indices, v=10) # not multi-volume v._on_scroll(nt('event', 'button inaxes key')('up', v._axes['x'], 'shift')) v._on_keypress(nt('event', 'key')('escape')) # other cases fig, axes = plt.subplots(1, 4) plt.close(fig) - OrthoSlicer3D(data, pcnt_range=[0.1, 0.9], axes=axes, - aspect_ratio=[1, 2, 3]) + OrthoSlicer3D(data, pcnt_range=[0.1, 0.9], axes=axes) OrthoSlicer3D(data, axes=axes[:3]) - assert_raises(ValueError, OrthoSlicer3D, data, aspect_ratio=[1, 2]) assert_raises(ValueError, OrthoSlicer3D, data[:, :, 0, 0]) assert_raises(ValueError, OrthoSlicer3D, data, affine=np.eye(3)) diff --git a/nibabel/viewers.py b/nibabel/viewers.py index 10edd14d09..8893d67467 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -29,7 +29,7 @@ class OrthoSlicer3D(object): ------- >>> import numpy as np >>> a = np.sin(np.linspace(0,np.pi,20)) - >>> b = np.sin(np.linspace(0,np.pi*5,20)) + >>> b = np.sin(np.linspace(0,np.pi*5,20))asa >>> data = np.outer(a,b)[..., np.newaxis]*a >>> OrthoSlicer3D(data).show() # doctest: +SKIP """ @@ -44,11 +44,12 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', dimensions. affine : array-like | None Affine transform for the data. This is used to determine - how the data should be sliced for plotting into the X, Y, - and Z view axes. If None, identity is assumed. The aspect - ratio of the data are inferred from the affine transform. + how the data should be sliced for plotting into the saggital, + coronal, and axial view axes. If None, identity is assumed. + The aspect ratio of the data are inferred from the affine + transform. axes : tuple of mpl.Axes | None, optional - 3 or 4 axes instances for the X, Y, Z slices plus volumes, + 3 or 4 axes instances for the 3 slices plus volumes, or None (default). cmap : str | instance of cmap, optional String or cmap instance specifying colormap. @@ -63,39 +64,43 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', affine = np.array(affine, float) if affine is not None else np.eye(4) if affine.ndim != 2 or affine.shape != (4, 4): raise ValueError('affine must be a 4x4 matrix') + # determine our orientation self._affine = affine.copy() - self._codes = axcodes2ornt(aff2axcodes(self._affine)) # XXX USE FOR ORDERING - print(self._codes) + codes = axcodes2ornt(aff2axcodes(self._affine)) + order = np.argsort([c[0] for c in codes]) + flips = np.array([c[1] for c in codes])[order] + self._order = dict(x=int(order[0]), y=int(order[1]), z=int(order[2])) + self._flips = dict(x=flips[0], y=flips[1], z=flips[2]) self._scalers = np.abs(self._affine).max(axis=0)[:3] self._inv_affine = np.linalg.inv(affine) + # current volume info self._volume_dims = data.shape[3:] self._current_vol_data = data[:, :, :, 0] if data.ndim > 3 else data self._data = data - pcnt_range = (0, 100) if pcnt_range is None else pcnt_range vmin, vmax = np.percentile(data, pcnt_range) del data if axes is None: # make the axes # ^ +---------+ ^ +---------+ # | | | | | | + # | Sag | | Cor | + # S | 1 | S | 2 | # | | | | - # z | 2 | z | 3 | # | | | | - # | | | | | | - # v +---------+ v +---------+ - # <-- x --> <-- y --> - # ^ +---------+ ^ +---------+ - # | | | | | | + # +---------+ +---------+ + # A --> <-- R + # ^ +---------+ +---------+ + # | | | | | + # | Axial | | | + # A | 3 | | 4 | # | | | | - # y | 1 | A | 4 | # | | | | - # | | | | | | - # v +---------+ v +---------+ - # <-- x --> <-- t --> + # +---------+ +---------+ + # <-- R <-- t --> fig, axes = plt.subplots(2, 2) fig.set_size_inches(figsize, forward=True) - self._axes = dict(x=axes[0, 1], y=axes[0, 0], z=axes[1, 0], + self._axes = dict(x=axes[0, 0], y=axes[0, 1], z=axes[1, 0], v=axes[1, 1]) plt.tight_layout(pad=0.1) if self.n_volumes <= 1: @@ -111,14 +116,15 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', # Start midway through each axis, idx is current slice number self._ims, self._sizes, self._idx = dict(), dict(), dict() + self._vol = 0 colors = dict() - for k, size in zip('xyz', self._data.shape[:3]): + for k in 'xyz': + size = self._data.shape[self._order[k]] self._idx[k] = size // 2 self._ims[k] = self._axes[k].imshow(self._get_slice_data(k), **kw) self._sizes[k] = size colors[k] = (0, 1, 0) - self._idx['v'] = 0 - labels = dict(z='ILSR', y='ALPR', x='AIPS') + labels = dict(x='SAIP', y='SLIR', z='ALPR') # set up axis crosshairs self._crosshairs = dict() @@ -231,7 +237,7 @@ def set_position(self, x=None, y=None, z=None, v=None): 'image') self._set_vol_idx(v) draw = True - for key, val in zip('zyx', (z, y, x)): + for key, val in zip('xyz', (x, y, z)): if val is not None: self._set_viewer_slice(key, val) draw = True @@ -241,9 +247,11 @@ def set_position(self, x=None, y=None, z=None, v=None): def _get_voxel_levels(self): """Get levels of the current voxel as a function of volume""" - y = self._data[self._idx['x'], - self._idx['y'], - self._idx['z'], :].ravel() + # XXX THIS IS WRONG + #y = self._data[self._idx['x'], + # self._idx['y'], + # self._idx['z'], :].ravel() + y = self._data[0, 0, 0, :].ravel() y = np.concatenate((y, [y[-1]])) return y @@ -255,20 +263,34 @@ def _update_voxel_levels(self): def _set_vol_idx(self, idx): """Change which volume is shown""" max_ = np.prod(self._volume_dims) - self._idx['v'] = max(min(int(round(idx)), max_ - 1), 0) + self._vol = max(min(int(round(idx)), max_ - 1), 0) # Must reset what is shown - self._current_vol_data = self._data[:, :, :, self._idx['v']] + self._current_vol_data = self._data[:, :, :, self._vol] for key in 'xyz': self._ims[key].set_data(self._get_slice_data(key)) - self._volume_ax_objs['patch'].set_x(self._idx['v'] - 0.5) + self._volume_ax_objs['patch'].set_x(self._vol - 0.5) def _get_slice_data(self, key): """Helper to get the current slice image""" - ii = dict(x=0, y=1, z=2)[key] - return np.take(self._current_vol_data, self._idx[key], axis=ii).T + assert key in ['x', 'y', 'z'] + data = np.take(self._current_vol_data, self._idx[key], + axis=self._order[key]) + # saggital: get to S/A + # coronal: get to S/L + # axial: get to A/L + xaxes = dict(x='y', y='x', z='x') + yaxes = dict(x='z', y='z', z='y') + if self._order[xaxes[key]] < self._order[yaxes[key]]: + data = data.T + if self._flips[xaxes[key]]: + data = data[:, ::-1] + if self._flips[yaxes[key]]: + data = data[::-1] + return data def _set_viewer_slice(self, key, idx): """Helper to set a viewer slice number""" + assert key in ['x', 'y', 'z'] self._idx[key] = max(min(int(round(idx)), self._sizes[key] - 1), 0) self._ims[key].set_data(self._get_slice_data(key)) for fun in self._cross_setters[key]: @@ -293,7 +315,9 @@ def _on_scroll(self, event): if self.n_volumes <= 1: return key = 'v' # shift: change volume in any axis - idx = self._idx[key] + (delta if event.button == 'up' else -delta) + assert key in ['x', 'y', 'z', 'v'] + idx = self._idx[key] if key != 'v' else self._vol + idx += delta if event.button == 'up' else -delta if key == 'v': self._set_vol_idx(idx) else: From b58a5d14439452b2f58dfcbcfe923b44a44f8aef Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 28 Oct 2014 12:58:34 -0700 Subject: [PATCH 11/24] WIP: Fixed ratio --- nibabel/viewers.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/nibabel/viewers.py b/nibabel/viewers.py index 8893d67467..f5b611aafb 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -68,7 +68,7 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', self._affine = affine.copy() codes = axcodes2ornt(aff2axcodes(self._affine)) order = np.argsort([c[0] for c in codes]) - flips = np.array([c[1] for c in codes])[order] + flips = np.array([c[1] < 0 for c in codes])[order] self._order = dict(x=int(order[0]), y=int(order[1]), z=int(order[2])) self._flips = dict(x=flips[0], y=flips[1], z=flips[2]) self._scalers = np.abs(self._affine).max(axis=0)[:3] @@ -128,7 +128,10 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', # set up axis crosshairs self._crosshairs = dict() - for type_, i_1, i_2 in zip('xyz', 'yxx', 'zzy'): + r = [self._scalers[self._order['z']] / self._scalers[self._order['y']], + self._scalers[self._order['z']] / self._scalers[self._order['x']], + self._scalers[self._order['y']] / self._scalers[self._order['x']]] + for type_, i_1, i_2, ratio in zip('xyz', 'yxx', 'zzy', r): ax, label = self._axes[type_], labels[type_] vert = ax.plot([self._idx[i_1]] * 2, [-0.5, self._sizes[i_2] - 0.5], @@ -151,7 +154,7 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', horizontalalignment=anchor[0], verticalalignment=anchor[1]) ax.axis(lims) - # ax.set_aspect(aspect_ratio[type_]) # XXX FIX + ax.set_aspect(ratio) ax.patch.set_visible(False) ax.set_frame_on(False) ax.axes.get_yaxis().set_visible(False) From e54b5368d5a13547df437fa2beced7e2c4a01981 Mon Sep 17 00:00:00 2001 From: Eric89GXL Date: Tue, 28 Oct 2014 14:48:18 -0700 Subject: [PATCH 12/24] FIX: FIX time plot --- nibabel/viewers.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/nibabel/viewers.py b/nibabel/viewers.py index f5b611aafb..8be3345172 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -250,11 +250,10 @@ def set_position(self, x=None, y=None, z=None, v=None): def _get_voxel_levels(self): """Get levels of the current voxel as a function of volume""" - # XXX THIS IS WRONG - #y = self._data[self._idx['x'], - # self._idx['y'], - # self._idx['z'], :].ravel() - y = self._data[0, 0, 0, :].ravel() + idx = [0] * 3 + for key in 'xyz': + idx[self._order[key]] = self._idx[key] + y = self._data[idx[0], idx[1], idx[2], :].ravel() y = np.concatenate((y, [y[-1]])) return y From 64aa8495bd17f5f015e24d34f71355c14b6ee409 Mon Sep 17 00:00:00 2001 From: Eric89GXL Date: Wed, 29 Oct 2014 16:18:48 -0700 Subject: [PATCH 13/24] FIX: Fix orientations and interactions --- nibabel/tests/test_viewers.py | 16 +- nibabel/viewers.py | 320 +++++++++++++++++++++------------- 2 files changed, 208 insertions(+), 128 deletions(-) diff --git a/nibabel/tests/test_viewers.py b/nibabel/tests/test_viewers.py index 6ec5c1b1ef..fa4a336a30 100644 --- a/nibabel/tests/test_viewers.py +++ b/nibabel/tests/test_viewers.py @@ -14,6 +14,7 @@ from ..viewers import OrthoSlicer3D from numpy.testing.decorators import skipif +from numpy.testing import assert_array_equal from nose.tools import assert_raises @@ -29,7 +30,7 @@ def test_viewer(): data = (np.outer(a, b)[..., np.newaxis] * a)[:, :, :, np.newaxis] data = data * np.array([1., 2.]) # give it a # of volumes > 1 v = OrthoSlicer3D(data) - plt.draw() + assert_array_equal(v.position, (0, 0, 0)) # fake some events, inside and outside axes v._on_scroll(nt('event', 'button inaxes key')('up', None, None)) @@ -41,6 +42,8 @@ def test_viewer(): for ax in v._axes.values(): v._on_mouse(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, ax, 1)) v._on_mouse(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, None, None)) + v.set_volume_idx(1) + v.set_volume_idx(1) # should just pass v.close() # non-multi-volume @@ -51,7 +54,14 @@ def test_viewer(): # other cases fig, axes = plt.subplots(1, 4) plt.close(fig) - OrthoSlicer3D(data, pcnt_range=[0.1, 0.9], axes=axes) - OrthoSlicer3D(data, axes=axes[:3]) + v1 = OrthoSlicer3D(data, pcnt_range=[0.1, 0.9], axes=axes) + aff = np.array([[0, 1, 0, 3], [-1, 0, 0, 2], [0, 0, 2, 1], [0, 0, 0, 1]], + float) + v2 = OrthoSlicer3D(data, affine=aff, axes=axes[:3]) assert_raises(ValueError, OrthoSlicer3D, data[:, :, 0, 0]) assert_raises(ValueError, OrthoSlicer3D, data, affine=np.eye(3)) + assert_raises(TypeError, v2.link_to, 1) + v2.link_to(v1) + v2.link_to(v1) # shouldn't do anything + v1.close() + v2.close() diff --git a/nibabel/viewers.py b/nibabel/viewers.py index 8be3345172..f2d43231a8 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -6,6 +6,7 @@ from __future__ import division, print_function import numpy as np +import weakref from .optpkg import optional_package from .orientations import aff2axcodes, axcodes2ornt @@ -91,7 +92,7 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', # A --> <-- R # ^ +---------+ +---------+ # | | | | | - # | Axial | | | + # | Axial | | Vol | # A | 3 | | 4 | # | | | | # | | | | @@ -111,37 +112,31 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', if len(axes) > 3: self._axes['v'] = axes[3] - kw = dict(vmin=vmin, vmax=vmax, aspect=1, interpolation='nearest', - cmap=cmap, origin='lower') - # Start midway through each axis, idx is current slice number - self._ims, self._sizes, self._idx = dict(), dict(), dict() - self._vol = 0 - colors = dict() - for k in 'xyz': - size = self._data.shape[self._order[k]] - self._idx[k] = size // 2 - self._ims[k] = self._axes[k].imshow(self._get_slice_data(k), **kw) - self._sizes[k] = size - colors[k] = (0, 1, 0) - labels = dict(x='SAIP', y='SLIR', z='ALPR') + self._ims, self._sizes, self._data_idx = dict(), dict(), dict() # set up axis crosshairs self._crosshairs = dict() r = [self._scalers[self._order['z']] / self._scalers[self._order['y']], self._scalers[self._order['z']] / self._scalers[self._order['x']], self._scalers[self._order['y']] / self._scalers[self._order['x']]] - for type_, i_1, i_2, ratio in zip('xyz', 'yxx', 'zzy', r): - ax, label = self._axes[type_], labels[type_] - vert = ax.plot([self._idx[i_1]] * 2, - [-0.5, self._sizes[i_2] - 0.5], - color=colors[i_1], linestyle='-')[0] - horiz = ax.plot([-0.5, self._sizes[i_1] - 0.5], - [self._idx[i_2]] * 2, - color=colors[i_2], linestyle='-')[0] - self._crosshairs[type_] = dict(vert=vert, horiz=horiz) + for k in 'xyz': + self._sizes[k] = self._data.shape[self._order[k]] + for k, xax, yax, ratio, label in zip('xyz', 'yxx', 'zzy', r, + ('SAIP', 'SLIR', 'ALPR')): + ax = self._axes[k] + d = np.zeros((self._sizes[yax], self._sizes[xax])) + self._ims[k] = self._axes[k].imshow(d, vmin=vmin, vmax=vmax, + aspect=1, cmap=cmap, + interpolation='nearest', + origin='lower') + vert = ax.plot([0] * 2, [-0.5, self._sizes[yax] - 0.5], + color=(0, 1, 0), linestyle='-')[0] + horiz = ax.plot([-0.5, self._sizes[xax] - 0.5], [0] * 2, + color=(0, 1, 0), linestyle='-')[0] + self._crosshairs[k] = dict(vert=vert, horiz=horiz) # add text labels (top, right, bottom, left) - lims = [0, self._sizes[i_1], 0, self._sizes[i_2]] + lims = [0, self._sizes[xax], 0, self._sizes[yax]] bump = 0.01 poss = [[lims[1] / 2., lims[3]], [(1 + bump) * lims[1], lims[3] / 2.], @@ -159,13 +154,15 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', ax.set_frame_on(False) ax.axes.get_yaxis().set_visible(False) ax.axes.get_xaxis().set_visible(False) + self._data_idx[k] = 0 + self._data_idx['v'] = -1 # Set up volumes axis if self.n_volumes > 1 and 'v' in self._axes: ax = self._axes['v'] ax.set_axis_bgcolor('k') ax.set_title('Volumes') - y = self._get_voxel_levels() + y = np.zeros(self.n_volumes + 1) x = np.arange(self.n_volumes + 1) - 0.5 step = ax.step(x, y, where='post', color='y')[0] ax.set_xticks(np.unique(np.linspace(0, self.n_volumes - 1, @@ -180,18 +177,6 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', ax.set_ylim(yl) self._volume_ax_objs = dict(step=step, patch=patch) - # setup pairwise connections between the slice dimensions - self._click_update_keys = dict(x='yz', y='xz', z='xy') - - # when an index changes, which crosshairs need to be updated - self._cross_setters = dict( - x=[self._crosshairs['z']['vert'].set_xdata, - self._crosshairs['y']['vert'].set_xdata], - y=[self._crosshairs['z']['horiz'].set_ydata, - self._crosshairs['x']['vert'].set_xdata], - z=[self._crosshairs['y']['horiz'].set_ydata, - self._crosshairs['x']['horiz'].set_ydata]) - self._figs = set([a.figure for a in self._axes.values()]) for fig in self._figs: fig.canvas.mpl_connect('scroll_event', self._on_scroll) @@ -199,8 +184,20 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', fig.canvas.mpl_connect('button_press_event', self._on_mouse) fig.canvas.mpl_connect('key_press_event', self._on_keypress) + # actually set data meaningfully + self._position = np.zeros(4) + self._position[3] = 1. # convenience for affine multn + self._changing = False # keep track of status to avoid loops + self._links = [] # other viewers this one is linked to + for fig in self._figs: + fig.canvas.draw() + self._set_volume_index(0, update_slices=False) + self._set_position(0., 0., 0.) + self._draw() + + # User-level functions ################################################### def show(self): - """ Show the slicer in blocking mode; convenience for ``plt.show()`` + """Show the slicer in blocking mode; convenience for ``plt.show()`` """ plt.show() @@ -209,95 +206,156 @@ def close(self): """ for f in self._figs: plt.close(f) + for link in self._links: + link()._unlink(self) @property def n_volumes(self): """Number of volumes in the data""" return int(np.prod(self._volume_dims)) - def set_position(self, x=None, y=None, z=None, v=None): + @property + def position(self): + """The current coordinates""" + return self._position[:3].copy() + + def link_to(self, other): + """Link positional changes between two canvases + + Parameters + ---------- + other : instance of OrthoSlicer3D + Other viewer to use to link movements. + """ + if not isinstance(other, self.__class__): + raise TypeError('other must be an instance of %s, not %s' + % (self.__class__.__name__, type(other))) + self._link(other, is_primary=True) + + def _link(self, other, is_primary): + """Link a viewer""" + ref = weakref.ref(other) + if ref in self._links: + return + self._links.append(ref) + if is_primary: + other._link(self, is_primary=False) + other.set_position(*self.position) + + def _unlink(self, other): + """Unlink a viewer""" + ref = weakref.ref(other) + if ref in self._links: + self._links.pop(self._links.index(ref)) + ref()._unlink(self) + + def _notify_links(self): + """Notify linked canvases of a position change""" + for link in self._links: + link().set_position(*self.position[:3]) + + def set_position(self, x=None, y=None, z=None): """Set current displayed slice indices Parameters ---------- - x : int | None - Index to use. If None, do not change. - y : int | None - Index to use. If None, do not change. - z : int | None - Index to use. If None, do not change. - v : int | None - Volume index to use. If None, do not change. + x : float | None + X coordinate to use. If None, do not change. + y : float | None + Y coordinate to use. If None, do not change. + z : float | None + Z coordinate to use. If None, do not change. """ - x = int(x) if x is not None else None - y = int(y) if y is not None else None - z = int(z) if z is not None else None - v = int(v) if v is not None else None - draw = False - if v is not None: - if self.n_volumes <= 1: - raise ValueError('cannot change volume index of single-volume ' - 'image') - self._set_vol_idx(v) - draw = True - for key, val in zip('xyz', (x, y, z)): - if val is not None: - self._set_viewer_slice(key, val) - draw = True - if draw: - self._update_voxel_levels() - self._draw() - - def _get_voxel_levels(self): - """Get levels of the current voxel as a function of volume""" - idx = [0] * 3 - for key in 'xyz': - idx[self._order[key]] = self._idx[key] - y = self._data[idx[0], idx[1], idx[2], :].ravel() - y = np.concatenate((y, [y[-1]])) - return y - - def _update_voxel_levels(self): - """Update voxel levels in time plot""" - if self.n_volumes > 1: - self._volume_ax_objs['step'].set_ydata(self._get_voxel_levels()) - - def _set_vol_idx(self, idx): - """Change which volume is shown""" + self._set_position(x, y, z) + self._draw() + + def set_volume_idx(self, v): + """Set current displayed volume index + + Parameters + ---------- + v : int + Volume index. + """ + self._set_volume_index(v) + self._draw() + + def _set_volume_index(self, v, update_slices=True): + """Set the plot data using a volume index""" + v = self._data_idx['v'] if v is None else int(round(v)) + if v == self._data_idx['v']: + return max_ = np.prod(self._volume_dims) - self._vol = max(min(int(round(idx)), max_ - 1), 0) - # Must reset what is shown - self._current_vol_data = self._data[:, :, :, self._vol] + self._data_idx['v'] = max(min(int(round(v)), max_ - 1), 0) + idx = (slice(None), slice(None), slice(None)) + if self._data.ndim > 3: + idx = idx + tuple(np.unravel_index(self._data_idx['v'], + self._volume_dims)) + self._current_vol_data = self._data[idx] + # update all of our slice plots + if update_slices: + self._set_position(None, None, None, notify=False) + + def _set_position(self, x, y, z, notify=True): + """Set the plot data using a physical position""" + # deal with volume first + if self._changing: + return + self._changing = True + x = self._position[0] if x is None else float(x) + y = self._position[1] if y is None else float(y) + z = self._position[2] if z is None else float(z) + + # deal with slicing appropriately + self._position[:3] = [x, y, z] + idxs = np.dot(self._inv_affine, self._position)[:3] + for key, idx in zip('xyz', idxs): + self._data_idx[key] = max(min(int(round(idx)), + self._sizes[key] - 1), 0) for key in 'xyz': - self._ims[key].set_data(self._get_slice_data(key)) - self._volume_ax_objs['patch'].set_x(self._vol - 0.5) - - def _get_slice_data(self, key): - """Helper to get the current slice image""" - assert key in ['x', 'y', 'z'] - data = np.take(self._current_vol_data, self._idx[key], - axis=self._order[key]) - # saggital: get to S/A - # coronal: get to S/L - # axial: get to A/L - xaxes = dict(x='y', y='x', z='x') - yaxes = dict(x='z', y='z', z='y') - if self._order[xaxes[key]] < self._order[yaxes[key]]: - data = data.T - if self._flips[xaxes[key]]: - data = data[:, ::-1] - if self._flips[yaxes[key]]: - data = data[::-1] - return data - - def _set_viewer_slice(self, key, idx): - """Helper to set a viewer slice number""" - assert key in ['x', 'y', 'z'] - self._idx[key] = max(min(int(round(idx)), self._sizes[key] - 1), 0) - self._ims[key].set_data(self._get_slice_data(key)) - for fun in self._cross_setters[key]: - fun([self._idx[key]] * 2) - + # saggital: get to S/A + # coronal: get to S/L + # axial: get to A/L + data = np.take(self._current_vol_data, self._data_idx[key], + axis=self._order[key]) + xax = dict(x='y', y='x', z='x')[key] + yax = dict(x='z', y='z', z='y')[key] + if self._order[xax] < self._order[yax]: + data = data.T + if self._flips[xax]: + data = data[:, ::-1] + if self._flips[yax]: + data = data[::-1] + self._ims[key].set_data(data) + # deal with crosshairs + loc = self._data_idx[key] + if self._flips[key]: + loc = self._sizes[key] - loc + loc = [loc] * 2 + if key == 'x': + self._crosshairs['z']['vert'].set_xdata(loc) + self._crosshairs['y']['vert'].set_xdata(loc) + elif key == 'y': + self._crosshairs['z']['horiz'].set_ydata(loc) + self._crosshairs['x']['vert'].set_xdata(loc) + else: # key == 'z' + self._crosshairs['y']['horiz'].set_ydata(loc) + self._crosshairs['x']['horiz'].set_ydata(loc) + + # Update volume trace + if self.n_volumes > 1 and 'v' in self._axes: + idx = [0] * 3 + for key in 'xyz': + idx[self._order[key]] = self._data_idx[key] + vdata = self._data[idx[0], idx[1], idx[2], :].ravel() + vdata = np.concatenate((vdata, [vdata[-1]])) + self._volume_ax_objs['patch'].set_x(self._data_idx['v'] - 0.5) + self._volume_ax_objs['step'].set_ydata(vdata) + if notify: + self._notify_links() + self._changing = False + + # Matplotlib handlers #################################################### def _in_axis(self, event): """Return axis key if within one of our axes, else None""" if getattr(event, 'inaxes') is None: @@ -312,19 +370,25 @@ def _on_scroll(self, event): key = self._in_axis(event) if key is None: return - delta = 10 if event.key is not None and 'control' in event.key else 1 if event.key is not None and 'shift' in event.key: if self.n_volumes <= 1: return key = 'v' # shift: change volume in any axis assert key in ['x', 'y', 'z', 'v'] - idx = self._idx[key] if key != 'v' else self._vol - idx += delta if event.button == 'up' else -delta + dv = 10. if event.key is not None and 'control' in event.key else 1. + dv *= 1. if event.button == 'up' else -1. + dv *= -1 if self._flips.get(key, False) else 1 + val = self._data_idx[key] + dv if key == 'v': - self._set_vol_idx(idx) + self._set_volume_index(val) else: - self._set_viewer_slice(key, idx) - self._update_voxel_levels() + coords = {key: val} + for k in 'xyz': + if k not in coords: + coords[k] = self._data_idx[k] + coords = np.array([coords['x'], coords['y'], coords['z'], 1.]) + coords = np.dot(self._affine, coords)[:3] + self._set_position(coords[0], coords[1], coords[2]) self._draw() def _on_mouse(self, event): @@ -335,12 +399,18 @@ def _on_mouse(self, event): if key is None: return if key == 'v': - self._set_vol_idx(event.xdata) + # volume plot directly translates + self._set_volume_index(event.xdata) else: - for sub_key, idx in zip(self._click_update_keys[key], - (event.xdata, event.ydata)): - self._set_viewer_slice(sub_key, idx) - self._update_voxel_levels() + # translate click xdata/ydata to physical position + xax, yax = dict(x='yz', y='xz', z='xy')[key] + x, y = event.xdata, event.ydata + x = self._sizes[xax] - x if self._flips[xax] else x + y = self._sizes[yax] - y if self._flips[yax] else y + idxs = {xax: x, yax: y, key: self._data_idx[key]} + idxs = np.array([idxs['x'], idxs['y'], idxs['z'], 1.]) + pos = np.dot(self._affine, idxs)[:3] + self._set_position(*pos) self._draw() def _on_keypress(self, event): From 4bf9d5c558c9f63d09a88c6472b9357e73342c3d Mon Sep 17 00:00:00 2001 From: Eric89GXL Date: Wed, 29 Oct 2014 16:26:15 -0700 Subject: [PATCH 14/24] FIX: Minor fixes --- nibabel/viewers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nibabel/viewers.py b/nibabel/viewers.py index f2d43231a8..3870b6f9b2 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -30,7 +30,7 @@ class OrthoSlicer3D(object): ------- >>> import numpy as np >>> a = np.sin(np.linspace(0,np.pi,20)) - >>> b = np.sin(np.linspace(0,np.pi*5,20))asa + >>> b = np.sin(np.linspace(0,np.pi*5,20)) >>> data = np.outer(a,b)[..., np.newaxis]*a >>> OrthoSlicer3D(data).show() # doctest: +SKIP """ @@ -189,6 +189,7 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', self._position[3] = 1. # convenience for affine multn self._changing = False # keep track of status to avoid loops self._links = [] # other viewers this one is linked to + plt.draw() for fig in self._figs: fig.canvas.draw() self._set_volume_index(0, update_slices=False) From 8b3aaa9eedf0336290489ab0a8a7446242123932 Mon Sep 17 00:00:00 2001 From: Eric89GXL Date: Wed, 29 Oct 2014 16:34:22 -0700 Subject: [PATCH 15/24] FIX: Fix test --- nibabel/tests/test_viewers.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/nibabel/tests/test_viewers.py b/nibabel/tests/test_viewers.py index fa4a336a30..a5d0ca4709 100644 --- a/nibabel/tests/test_viewers.py +++ b/nibabel/tests/test_viewers.py @@ -10,6 +10,12 @@ import numpy as np from collections import namedtuple as nt +try: + import matplotlib + matplotlib.use('agg') +except Exception: + pass + from ..optpkg import optional_package from ..viewers import OrthoSlicer3D From 9b18cd38b45d28287aebd421e55a0b51dacfa9ac Mon Sep 17 00:00:00 2001 From: Eric89GXL Date: Wed, 29 Oct 2014 16:42:40 -0700 Subject: [PATCH 16/24] FIX: Better testing --- nibabel/tests/test_viewers.py | 10 ++++------ nibabel/viewers.py | 12 ++++++++---- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/nibabel/tests/test_viewers.py b/nibabel/tests/test_viewers.py index a5d0ca4709..e639c5e38c 100644 --- a/nibabel/tests/test_viewers.py +++ b/nibabel/tests/test_viewers.py @@ -10,11 +10,6 @@ import numpy as np from collections import namedtuple as nt -try: - import matplotlib - matplotlib.use('agg') -except Exception: - pass from ..optpkg import optional_package from ..viewers import OrthoSlicer3D @@ -24,13 +19,16 @@ from nose.tools import assert_raises -plt, has_mpl = optional_package('matplotlib.pyplot')[:2] +matplotlib, has_mpl = optional_package('matplotlib')[:2] needs_mpl = skipif(not has_mpl, 'These tests need matplotlib') +if has_mpl: + matplotlib.use('Agg') @needs_mpl def test_viewer(): # Test viewer + plt = optional_package('matplotlib.pyplot')[0] a = np.sin(np.linspace(0, np.pi, 20)) b = np.sin(np.linspace(0, np.pi*5, 30)) data = (np.outer(a, b)[..., np.newaxis] * a)[:, :, :, np.newaxis] diff --git a/nibabel/viewers.py b/nibabel/viewers.py index 3870b6f9b2..12d4f581e7 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -11,10 +11,6 @@ from .optpkg import optional_package from .orientations import aff2axcodes, axcodes2ornt -plt, _, _ = optional_package('matplotlib.pyplot') -mpl_img, _, _ = optional_package('matplotlib.image') -mpl_patch, _, _ = optional_package('matplotlib.patches') - class OrthoSlicer3D(object): """Orthogonal-plane slicer. @@ -59,6 +55,12 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', figsize : tuple Figure size (in inches) to use if axes are None. """ + # Nest imports so that matplotlib.use() has the appropriate + # effect in testing + plt, _, _ = optional_package('matplotlib.pyplot') + mpl_img, _, _ = optional_package('matplotlib.image') + mpl_patch, _, _ = optional_package('matplotlib.patches') + data = np.asanyarray(data) if data.ndim < 3: raise ValueError('data must have at least 3 dimensions') @@ -200,11 +202,13 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', def show(self): """Show the slicer in blocking mode; convenience for ``plt.show()`` """ + plt, _, _ = optional_package('matplotlib.pyplot') plt.show() def close(self): """Close the viewer figures """ + plt, _, _ = optional_package('matplotlib.pyplot') for f in self._figs: plt.close(f) for link in self._links: From 34cef801c8c75ae39ced426c00567790d495d4e6 Mon Sep 17 00:00:00 2001 From: Eric89GXL Date: Wed, 29 Oct 2014 17:17:04 -0700 Subject: [PATCH 17/24] STY: Remove dicts in favor of lists --- nibabel/tests/test_viewers.py | 6 +- nibabel/viewers.py | 181 ++++++++++++++++------------------ 2 files changed, 90 insertions(+), 97 deletions(-) diff --git a/nibabel/tests/test_viewers.py b/nibabel/tests/test_viewers.py index e639c5e38c..e0bdfae814 100644 --- a/nibabel/tests/test_viewers.py +++ b/nibabel/tests/test_viewers.py @@ -38,12 +38,12 @@ def test_viewer(): # fake some events, inside and outside axes v._on_scroll(nt('event', 'button inaxes key')('up', None, None)) - for ax in (v._axes['x'], v._axes['v']): + for ax in (v._axes[0], v._axes[3]): v._on_scroll(nt('event', 'button inaxes key')('up', ax, None)) v._on_scroll(nt('event', 'button inaxes key')('up', ax, 'shift')) # "click" outside axes, then once in each axis, then move without click v._on_mouse(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, None, 1)) - for ax in v._axes.values(): + for ax in v._axes: v._on_mouse(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, ax, 1)) v._on_mouse(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, None, None)) v.set_volume_idx(1) @@ -52,7 +52,7 @@ def test_viewer(): # non-multi-volume v = OrthoSlicer3D(data[:, :, :, 0]) - v._on_scroll(nt('event', 'button inaxes key')('up', v._axes['x'], 'shift')) + v._on_scroll(nt('event', 'button inaxes key')('up', v._axes[0], 'shift')) v._on_keypress(nt('event', 'key')('escape')) # other cases diff --git a/nibabel/viewers.py b/nibabel/viewers.py index 12d4f581e7..bfde6bfbfc 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -70,10 +70,9 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', # determine our orientation self._affine = affine.copy() codes = axcodes2ornt(aff2axcodes(self._affine)) - order = np.argsort([c[0] for c in codes]) - flips = np.array([c[1] < 0 for c in codes])[order] - self._order = dict(x=int(order[0]), y=int(order[1]), z=int(order[2])) - self._flips = dict(x=flips[0], y=flips[1], z=flips[2]) + self._order = np.argsort([c[0] for c in codes]) + self._flips = np.array([c[1] < 0 for c in codes])[self._order] + self._flips = list(self._flips) + [False] # add volume dim self._scalers = np.abs(self._affine).max(axis=0)[:3] self._inv_affine = np.linalg.inv(affine) # current volume info @@ -87,7 +86,7 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', # ^ +---------+ ^ +---------+ # | | | | | | # | Sag | | Cor | - # S | 1 | S | 2 | + # S | 0 | S | 1 | # | | | | # | | | | # +---------+ +---------+ @@ -95,7 +94,7 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', # ^ +---------+ +---------+ # | | | | | # | Axial | | Vol | - # A | 3 | | 4 | + # A | 2 | | 3 | # | | | | # | | | | # +---------+ +---------+ @@ -103,40 +102,38 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', fig, axes = plt.subplots(2, 2) fig.set_size_inches(figsize, forward=True) - self._axes = dict(x=axes[0, 0], y=axes[0, 1], z=axes[1, 0], - v=axes[1, 1]) + self._axes = [axes[0, 0], axes[0, 1], axes[1, 0], axes[1, 1]] plt.tight_layout(pad=0.1) if self.n_volumes <= 1: - fig.delaxes(self._axes['v']) - del self._axes['v'] + fig.delaxes(self._axes[3]) + self._axes.pop(-1) else: - self._axes = dict(z=axes[0], y=axes[1], x=axes[2]) + self._axes = [axes[0], axes[1], axes[2]] if len(axes) > 3: - self._axes['v'] = axes[3] + self._axes.append(axes[3]) # Start midway through each axis, idx is current slice number - self._ims, self._sizes, self._data_idx = dict(), dict(), dict() + self._ims, self._data_idx = list(), list() # set up axis crosshairs - self._crosshairs = dict() - r = [self._scalers[self._order['z']] / self._scalers[self._order['y']], - self._scalers[self._order['z']] / self._scalers[self._order['x']], - self._scalers[self._order['y']] / self._scalers[self._order['x']]] - for k in 'xyz': - self._sizes[k] = self._data.shape[self._order[k]] - for k, xax, yax, ratio, label in zip('xyz', 'yxx', 'zzy', r, - ('SAIP', 'SLIR', 'ALPR')): - ax = self._axes[k] + self._crosshairs = [None] * 3 + r = [self._scalers[self._order[2]] / self._scalers[self._order[1]], + self._scalers[self._order[2]] / self._scalers[self._order[0]], + self._scalers[self._order[1]] / self._scalers[self._order[0]]] + self._sizes = [self._data.shape[o] for o in self._order] + for ii, xax, yax, ratio, label in zip([0, 1, 2], [1, 0, 0], [2, 2, 1], + r, ('SAIP', 'SLIR', 'ALPR')): + ax = self._axes[ii] d = np.zeros((self._sizes[yax], self._sizes[xax])) - self._ims[k] = self._axes[k].imshow(d, vmin=vmin, vmax=vmax, - aspect=1, cmap=cmap, - interpolation='nearest', - origin='lower') + im = self._axes[ii].imshow(d, vmin=vmin, vmax=vmax, aspect=1, + cmap=cmap, interpolation='nearest', + origin='lower') + self._ims.append(im) vert = ax.plot([0] * 2, [-0.5, self._sizes[yax] - 0.5], color=(0, 1, 0), linestyle='-')[0] horiz = ax.plot([-0.5, self._sizes[xax] - 0.5], [0] * 2, color=(0, 1, 0), linestyle='-')[0] - self._crosshairs[k] = dict(vert=vert, horiz=horiz) + self._crosshairs[ii] = dict(vert=vert, horiz=horiz) # add text labels (top, right, bottom, left) lims = [0, self._sizes[xax], 0, self._sizes[yax]] bump = 0.01 @@ -156,12 +153,12 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', ax.set_frame_on(False) ax.axes.get_yaxis().set_visible(False) ax.axes.get_xaxis().set_visible(False) - self._data_idx[k] = 0 - self._data_idx['v'] = -1 + self._data_idx.append(0) + self._data_idx.append(-1) # volume # Set up volumes axis - if self.n_volumes > 1 and 'v' in self._axes: - ax = self._axes['v'] + if self.n_volumes > 1 and len(self._axes) > 3: + ax = self._axes[3] ax.set_axis_bgcolor('k') ax.set_title('Volumes') y = np.zeros(self.n_volumes + 1) @@ -179,7 +176,7 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', ax.set_ylim(yl) self._volume_ax_objs = dict(step=step, patch=patch) - self._figs = set([a.figure for a in self._axes.values()]) + self._figs = set([a.figure for a in self._axes]) for fig in self._figs: fig.canvas.mpl_connect('scroll_event', self._on_scroll) fig.canvas.mpl_connect('motion_notify_event', self._on_mouse) @@ -287,14 +284,14 @@ def set_volume_idx(self, v): def _set_volume_index(self, v, update_slices=True): """Set the plot data using a volume index""" - v = self._data_idx['v'] if v is None else int(round(v)) - if v == self._data_idx['v']: + v = self._data_idx[3] if v is None else int(round(v)) + if v == self._data_idx[3]: return max_ = np.prod(self._volume_dims) - self._data_idx['v'] = max(min(int(round(v)), max_ - 1), 0) + self._data_idx[3] = max(min(int(round(v)), max_ - 1), 0) idx = (slice(None), slice(None), slice(None)) if self._data.ndim > 3: - idx = idx + tuple(np.unravel_index(self._data_idx['v'], + idx = idx + tuple(np.unravel_index(self._data_idx[3], self._volume_dims)) self._current_vol_data = self._data[idx] # update all of our slice plots @@ -314,47 +311,46 @@ def _set_position(self, x, y, z, notify=True): # deal with slicing appropriately self._position[:3] = [x, y, z] idxs = np.dot(self._inv_affine, self._position)[:3] - for key, idx in zip('xyz', idxs): - self._data_idx[key] = max(min(int(round(idx)), - self._sizes[key] - 1), 0) - for key in 'xyz': + for ii, (size, idx) in enumerate(zip(self._sizes, idxs)): + self._data_idx[ii] = max(min(int(round(idx)), size - 1), 0) + for ii in range(3): # saggital: get to S/A # coronal: get to S/L # axial: get to A/L - data = np.take(self._current_vol_data, self._data_idx[key], - axis=self._order[key]) - xax = dict(x='y', y='x', z='x')[key] - yax = dict(x='z', y='z', z='y')[key] + data = np.take(self._current_vol_data, self._data_idx[ii], + axis=self._order[ii]) + xax = [1, 0, 0][ii] + yax = [2, 2, 1][ii] if self._order[xax] < self._order[yax]: data = data.T if self._flips[xax]: data = data[:, ::-1] if self._flips[yax]: data = data[::-1] - self._ims[key].set_data(data) + self._ims[ii].set_data(data) # deal with crosshairs - loc = self._data_idx[key] - if self._flips[key]: - loc = self._sizes[key] - loc + loc = self._data_idx[ii] + if self._flips[ii]: + loc = self._sizes[ii] - loc loc = [loc] * 2 - if key == 'x': - self._crosshairs['z']['vert'].set_xdata(loc) - self._crosshairs['y']['vert'].set_xdata(loc) - elif key == 'y': - self._crosshairs['z']['horiz'].set_ydata(loc) - self._crosshairs['x']['vert'].set_xdata(loc) - else: # key == 'z' - self._crosshairs['y']['horiz'].set_ydata(loc) - self._crosshairs['x']['horiz'].set_ydata(loc) + if ii == 0: + self._crosshairs[2]['vert'].set_xdata(loc) + self._crosshairs[1]['vert'].set_xdata(loc) + elif ii == 1: + self._crosshairs[2]['horiz'].set_ydata(loc) + self._crosshairs[0]['vert'].set_xdata(loc) + else: # ii == 2 + self._crosshairs[1]['horiz'].set_ydata(loc) + self._crosshairs[0]['horiz'].set_ydata(loc) # Update volume trace - if self.n_volumes > 1 and 'v' in self._axes: - idx = [0] * 3 - for key in 'xyz': - idx[self._order[key]] = self._data_idx[key] - vdata = self._data[idx[0], idx[1], idx[2], :].ravel() + if self.n_volumes > 1 and len(self._axes) > 3: + idx = [None, Ellipsis] * 3 + for ii in range(3): + idx[self._order[ii]] = self._data_idx[ii] + vdata = self._data[idx].ravel() vdata = np.concatenate((vdata, [vdata[-1]])) - self._volume_ax_objs['patch'].set_x(self._data_idx['v'] - 0.5) + self._volume_ax_objs['patch'].set_x(self._data_idx[3] - 0.5) self._volume_ax_objs['step'].set_ydata(vdata) if notify: self._notify_links() @@ -362,60 +358,57 @@ def _set_position(self, x, y, z, notify=True): # Matplotlib handlers #################################################### def _in_axis(self, event): - """Return axis key if within one of our axes, else None""" + """Return axis index if within one of our axes, else None""" if getattr(event, 'inaxes') is None: return None - for key, ax in self._axes.items(): + for ii, ax in enumerate(self._axes): if event.inaxes is ax: - return key + return ii def _on_scroll(self, event): """Handle mpl scroll wheel event""" assert event.button in ('up', 'down') - key = self._in_axis(event) - if key is None: + ii = self._in_axis(event) + if ii is None: return if event.key is not None and 'shift' in event.key: if self.n_volumes <= 1: return - key = 'v' # shift: change volume in any axis - assert key in ['x', 'y', 'z', 'v'] + ii = 3 # shift: change volume in any axis + assert ii in range(4) dv = 10. if event.key is not None and 'control' in event.key else 1. dv *= 1. if event.button == 'up' else -1. - dv *= -1 if self._flips.get(key, False) else 1 - val = self._data_idx[key] + dv - if key == 'v': + dv *= -1 if self._flips[ii] else 1 + val = self._data_idx[ii] + dv + if ii == 3: self._set_volume_index(val) else: - coords = {key: val} - for k in 'xyz': - if k not in coords: - coords[k] = self._data_idx[k] - coords = np.array([coords['x'], coords['y'], coords['z'], 1.]) - coords = np.dot(self._affine, coords)[:3] - self._set_position(coords[0], coords[1], coords[2]) + coords = [self._data_idx[k] for k in range(3)] + [1.] + coords[ii] = val + self._set_position(*np.dot(self._affine, coords)[:3]) self._draw() def _on_mouse(self, event): """Handle mpl mouse move and button press events""" if event.button != 1: # only enabled while dragging return - key = self._in_axis(event) - if key is None: + ii = self._in_axis(event) + if ii is None: return - if key == 'v': + if ii == 3: # volume plot directly translates self._set_volume_index(event.xdata) else: # translate click xdata/ydata to physical position - xax, yax = dict(x='yz', y='xz', z='xy')[key] + xax, yax = [[1, 2], [0, 2], [0, 1]][ii] x, y = event.xdata, event.ydata x = self._sizes[xax] - x if self._flips[xax] else x y = self._sizes[yax] - y if self._flips[yax] else y - idxs = {xax: x, yax: y, key: self._data_idx[key]} - idxs = np.array([idxs['x'], idxs['y'], idxs['z'], 1.]) - pos = np.dot(self._affine, idxs)[:3] - self._set_position(*pos) + idxs = [None, None, None, 1.] + idxs[xax] = x + idxs[yax] = y + idxs[ii] = self._data_idx[ii] + self._set_position(*np.dot(self._affine, idxs)[:3]) self._draw() def _on_keypress(self, event): @@ -425,14 +418,14 @@ def _on_keypress(self, event): def _draw(self): """Update all four (or three) plots""" - for key in 'xyz': - ax, im = self._axes[key], self._ims[key] - ax.draw_artist(im) - for line in self._crosshairs[key].values(): + for ii in range(3): + ax = self._axes[ii] + ax.draw_artist(self._ims[ii]) + for line in self._crosshairs[ii].values(): ax.draw_artist(line) ax.figure.canvas.blit(ax.bbox) - if self.n_volumes > 1 and 'v' in self._axes: # user might only pass 3 - ax = self._axes['v'] + if self.n_volumes > 1 and len(self._axes) > 3: + ax = self._axes[3] ax.draw_artist(ax.patch) # axis bgcolor to erase old lines for key in ('step', 'patch'): ax.draw_artist(self._volume_ax_objs[key]) From 485ecda5fbe8bb9ccca1891884ac03ac4f7bec3b Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Fri, 31 Oct 2014 17:02:37 -0700 Subject: [PATCH 18/24] FIX: Cleanup on close --- nibabel/viewers.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/nibabel/viewers.py b/nibabel/viewers.py index bfde6bfbfc..055227c87b 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -182,6 +182,7 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', fig.canvas.mpl_connect('motion_notify_event', self._on_mouse) fig.canvas.mpl_connect('button_press_event', self._on_mouse) fig.canvas.mpl_connect('key_press_event', self._on_keypress) + fig.canvas.mpl_connect('close_event', self._cleanup) # actually set data meaningfully self._position = np.zeros(4) @@ -205,9 +206,13 @@ def show(self): def close(self): """Close the viewer figures """ + self._cleanup() plt, _, _ = optional_package('matplotlib.pyplot') for f in self._figs: plt.close(f) + + def _cleanup(self): + """Clean up before closing""" for link in self._links: link()._unlink(self) From 59cd10f412de4b4c17186647c16abc1a2b83696a Mon Sep 17 00:00:00 2001 From: Eric89GXL Date: Fri, 31 Oct 2014 22:34:54 -0700 Subject: [PATCH 19/24] FIX: Better unlinking --- nibabel/spatialimages.py | 3 ++- nibabel/tests/test_viewers.py | 4 +++- nibabel/viewers.py | 21 ++++++++++++++++++--- 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/nibabel/spatialimages.py b/nibabel/spatialimages.py index a3926bc91b..0f0a421e9a 100644 --- a/nibabel/spatialimages.py +++ b/nibabel/spatialimages.py @@ -677,4 +677,5 @@ def plot(self): consider using viewer.show() (equivalently plt.show()) to show the figure. """ - return OrthoSlicer3D(self.get_data(), self.get_affine()) + return OrthoSlicer3D(self.get_data(), self.get_affine(), + title=self.get_filename()) diff --git a/nibabel/tests/test_viewers.py b/nibabel/tests/test_viewers.py index e0bdfae814..cfaf925647 100644 --- a/nibabel/tests/test_viewers.py +++ b/nibabel/tests/test_viewers.py @@ -17,7 +17,7 @@ from numpy.testing.decorators import skipif from numpy.testing import assert_array_equal -from nose.tools import assert_raises +from nose.tools import assert_raises, assert_true matplotlib, has_mpl = optional_package('matplotlib')[:2] needs_mpl = skipif(not has_mpl, 'These tests need matplotlib') @@ -35,6 +35,7 @@ def test_viewer(): data = data * np.array([1., 2.]) # give it a # of volumes > 1 v = OrthoSlicer3D(data) assert_array_equal(v.position, (0, 0, 0)) + assert_true('OrthoSlicer3D' in repr(v)) # fake some events, inside and outside axes v._on_scroll(nt('event', 'button inaxes key')('up', None, None)) @@ -49,6 +50,7 @@ def test_viewer(): v.set_volume_idx(1) v.set_volume_idx(1) # should just pass v.close() + v._draw() # should be safe # non-multi-volume v = OrthoSlicer3D(data[:, :, :, 0]) diff --git a/nibabel/viewers.py b/nibabel/viewers.py index 055227c87b..735d1e8eef 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -32,7 +32,7 @@ class OrthoSlicer3D(object): """ # Skip doctest above b/c not all systems have mpl installed def __init__(self, data, affine=None, axes=None, cmap='gray', - pcnt_range=(1., 99.), figsize=(8, 8)): + pcnt_range=(1., 99.), figsize=(8, 8), title=None): """ Parameters ---------- @@ -60,6 +60,8 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', plt, _, _ = optional_package('matplotlib.pyplot') mpl_img, _, _ = optional_package('matplotlib.image') mpl_patch, _, _ = optional_package('matplotlib.patches') + self._title = title + self._closed = False data = np.asanyarray(data) if data.ndim < 3: @@ -107,6 +109,8 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', if self.n_volumes <= 1: fig.delaxes(self._axes[3]) self._axes.pop(-1) + if self._title is not None: + fig.canvas.set_window_title(str(title)) else: self._axes = [axes[0], axes[1], axes[2]] if len(axes) > 3: @@ -196,6 +200,14 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', self._set_position(0., 0., 0.) self._draw() + def __repr__(self): + title = '' if self._title is None else ('%s ' % self._title) + vol = '' if self.n_volumes <= 1 else (', %s' % self.n_volumes) + r = ('<%s: %s(%s, %s, %s%s)>' + % (self.__class__.__name__, title, self._sizes[0], self._sizes[1], + self._sizes[2], vol)) + return r + # User-level functions ################################################### def show(self): """Show the slicer in blocking mode; convenience for ``plt.show()`` @@ -213,8 +225,9 @@ def close(self): def _cleanup(self): """Clean up before closing""" - for link in self._links: - link()._unlink(self) + self._closed = True + for link in list(self._links): # make a copy before iterating + self._unlink(link()) @property def n_volumes(self): @@ -423,6 +436,8 @@ def _on_keypress(self, event): def _draw(self): """Update all four (or three) plots""" + if self._closed: # make sure we don't draw when we shouldn't + return for ii in range(3): ax = self._axes[ii] ax.draw_artist(self._ims[ii]) From f08525fb0516fa7bc6ae76781d9ea94cb02548bb Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Thu, 11 Feb 2016 10:36:46 -0500 Subject: [PATCH 20/24] FIX: Minor fix --- nibabel/viewers.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/nibabel/viewers.py b/nibabel/viewers.py index 735d1e8eef..641aa3039e 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -163,7 +163,10 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', # Set up volumes axis if self.n_volumes > 1 and len(self._axes) > 3: ax = self._axes[3] - ax.set_axis_bgcolor('k') + try: + ax.set_facecolor('k') + except AttributeError: # old mpl + ax.set_axis_bgcolor('k') ax.set_title('Volumes') y = np.zeros(self.n_volumes + 1) x = np.arange(self.n_volumes + 1) - 0.5 @@ -363,7 +366,7 @@ def _set_position(self, x, y, z, notify=True): # Update volume trace if self.n_volumes > 1 and len(self._axes) > 3: - idx = [None, Ellipsis] * 3 + idx = [slice(None)] * len(self._axes) for ii in range(3): idx[self._order[ii]] = self._data_idx[ii] vdata = self._data[idx].ravel() From 416f7845bdcd7609e27c451ca57b96cd58f7ff44 Mon Sep 17 00:00:00 2001 From: "Gregory R. Lee" Date: Fri, 12 Feb 2016 10:54:01 -0500 Subject: [PATCH 21/24] ENH: Add +/- key support for incrementing/decrementing the volume for 4D data sets Add tests and update the OrthoSlicer3D docstring to describe 4D support. --- nibabel/tests/test_viewers.py | 12 +++++++++++- nibabel/viewers.py | 16 ++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/nibabel/tests/test_viewers.py b/nibabel/tests/test_viewers.py index cfaf925647..4d6fd64380 100644 --- a/nibabel/tests/test_viewers.py +++ b/nibabel/tests/test_viewers.py @@ -15,7 +15,7 @@ from ..viewers import OrthoSlicer3D from numpy.testing.decorators import skipif -from numpy.testing import assert_array_equal +from numpy.testing import assert_array_equal, assert_equal from nose.tools import assert_raises, assert_true @@ -48,7 +48,17 @@ def test_viewer(): v._on_mouse(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, ax, 1)) v._on_mouse(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, None, None)) v.set_volume_idx(1) + + # decrement/increment volume numbers via keypress v.set_volume_idx(1) # should just pass + v._on_keypress(nt('event', 'key')('-')) # decrement + assert_equal(v._data_idx[3], 0) + v._on_keypress(nt('event', 'key')('+')) # increment + assert_equal(v._data_idx[3], 1) + v._on_keypress(nt('event', 'key')('-')) + v._on_keypress(nt('event', 'key')('=')) # alternative increment key + assert_equal(v._data_idx[3], 1) + v.close() v._draw() # should be safe diff --git a/nibabel/viewers.py b/nibabel/viewers.py index 641aa3039e..d2a19a7585 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -22,6 +22,12 @@ class OrthoSlicer3D(object): corresponding slices in the other two. Scrolling up and down moves the slice up and down in the current axis. + OrthoSlicer3d also supports 4-dimensional data, where multiple + 3-dimensional volumes are stacked along the last axis. For 4-dimensional + data the fourth figure axis can be used to control which 3-dimensional + volume is displayed. Alternatively, the - key can be used to decrement the + displayed volume and the + or = keys can be used to increment it. + Example ------- >>> import numpy as np @@ -436,6 +442,16 @@ def _on_keypress(self, event): """Handle mpl keypress events""" if event.key is not None and 'escape' in event.key: self.close() + elif event.key in ["=", '+']: + # increment volume index + new_idx = min(self._data_idx[3]+1, self.n_volumes) + self._set_volume_index(new_idx, update_slices=True) + self._draw() + elif event.key == '-': + # decrement volume index + new_idx = max(self._data_idx[3]-1, 0) + self._set_volume_index(new_idx, update_slices=True) + self._draw() def _draw(self): """Update all four (or three) plots""" From 4459cef8af7e2aefd519e579209108bf366b3583 Mon Sep 17 00:00:00 2001 From: "Gregory R. Lee" Date: Fri, 12 Feb 2016 11:05:22 -0500 Subject: [PATCH 22/24] FIX: raise TypeError on complex input prior to figure creation --- nibabel/tests/test_viewers.py | 5 +++++ nibabel/viewers.py | 2 ++ 2 files changed, 7 insertions(+) diff --git a/nibabel/tests/test_viewers.py b/nibabel/tests/test_viewers.py index 4d6fd64380..7d1812762c 100644 --- a/nibabel/tests/test_viewers.py +++ b/nibabel/tests/test_viewers.py @@ -66,6 +66,11 @@ def test_viewer(): v = OrthoSlicer3D(data[:, :, :, 0]) v._on_scroll(nt('event', 'button inaxes key')('up', v._axes[0], 'shift')) v._on_keypress(nt('event', 'key')('escape')) + v.close() + + # complex input should raise a TypeError prior to figure creation + assert_raises(TypeError, OrthoSlicer3D, + data[:, :, :, 0].astype(np.complex64)) # other cases fig, axes = plt.subplots(1, 4) diff --git a/nibabel/viewers.py b/nibabel/viewers.py index d2a19a7585..4505fea3eb 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -72,6 +72,8 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', data = np.asanyarray(data) if data.ndim < 3: raise ValueError('data must have at least 3 dimensions') + if np.iscomplexobj(data): + raise TypeError("Complex data not supported") affine = np.array(affine, float) if affine is not None else np.eye(4) if affine.ndim != 2 or affine.shape != (4, 4): raise ValueError('affine must be a 4x4 matrix') From ba1243e4e3f4f72b6a053fd6d7514a84690576a2 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Wed, 17 Feb 2016 09:07:01 -0500 Subject: [PATCH 23/24] FIX: Fix operators --- nibabel/viewers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nibabel/viewers.py b/nibabel/viewers.py index 4505fea3eb..7ab14c0b27 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -446,12 +446,12 @@ def _on_keypress(self, event): self.close() elif event.key in ["=", '+']: # increment volume index - new_idx = min(self._data_idx[3]+1, self.n_volumes) + new_idx = min(self._data_idx[3] + 1, self.n_volumes) self._set_volume_index(new_idx, update_slices=True) self._draw() elif event.key == '-': # decrement volume index - new_idx = max(self._data_idx[3]-1, 0) + new_idx = max(self._data_idx[3] - 1, 0) self._set_volume_index(new_idx, update_slices=True) self._draw() From 2a8a73c6b2f931ab62d9ead070a89312903c7cb9 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Mon, 29 Feb 2016 10:13:07 -0500 Subject: [PATCH 24/24] FIX: Clean up code --- nibabel/spatialimages.py | 4 +- nibabel/tests/test_viewers.py | 8 +++- nibabel/viewers.py | 74 +++++++++++++++++++++++++---------- 3 files changed, 63 insertions(+), 23 deletions(-) diff --git a/nibabel/spatialimages.py b/nibabel/spatialimages.py index 0f0a421e9a..1d3f8b6a34 100644 --- a/nibabel/spatialimages.py +++ b/nibabel/spatialimages.py @@ -663,7 +663,7 @@ def __getitem__(self): "array data with `img.dataobj[slice]` or " "`img.get_data()[slice]`") - def plot(self): + def orthoview(self): """Plot the image using OrthoSlicer3D Returns @@ -677,5 +677,5 @@ def plot(self): consider using viewer.show() (equivalently plt.show()) to show the figure. """ - return OrthoSlicer3D(self.get_data(), self.get_affine(), + return OrthoSlicer3D(self.dataobj, self.affine, title=self.get_filename()) diff --git a/nibabel/tests/test_viewers.py b/nibabel/tests/test_viewers.py index 7d1812762c..e78a0ecb22 100644 --- a/nibabel/tests/test_viewers.py +++ b/nibabel/tests/test_viewers.py @@ -48,6 +48,10 @@ def test_viewer(): v._on_mouse(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, ax, 1)) v._on_mouse(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, None, None)) v.set_volume_idx(1) + v.cmap = 'hot' + v.clim = (0, 3) + assert_raises(ValueError, OrthoSlicer3D.clim.fset, v, (0.,)) # bad limits + assert_raises(ValueError, OrthoSlicer3D.cmap.fset, v, 'foo') # wrong cmap # decrement/increment volume numbers via keypress v.set_volume_idx(1) # should just pass @@ -75,11 +79,13 @@ def test_viewer(): # other cases fig, axes = plt.subplots(1, 4) plt.close(fig) - v1 = OrthoSlicer3D(data, pcnt_range=[0.1, 0.9], axes=axes) + v1 = OrthoSlicer3D(data, axes=axes) aff = np.array([[0, 1, 0, 3], [-1, 0, 0, 2], [0, 0, 2, 1], [0, 0, 0, 1]], float) v2 = OrthoSlicer3D(data, affine=aff, axes=axes[:3]) + # bad data (not 3+ dim) assert_raises(ValueError, OrthoSlicer3D, data[:, :, 0, 0]) + # bad affine (not 4x4) assert_raises(ValueError, OrthoSlicer3D, data, affine=np.eye(3)) assert_raises(TypeError, v2.link_to, 1) v2.link_to(v1) diff --git a/nibabel/viewers.py b/nibabel/viewers.py index 7ab14c0b27..718879e317 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -8,6 +8,7 @@ import numpy as np import weakref +from .affines import voxel_sizes from .optpkg import optional_package from .orientations import aff2axcodes, axcodes2ornt @@ -37,29 +38,25 @@ class OrthoSlicer3D(object): >>> OrthoSlicer3D(data).show() # doctest: +SKIP """ # Skip doctest above b/c not all systems have mpl installed - def __init__(self, data, affine=None, axes=None, cmap='gray', - pcnt_range=(1., 99.), figsize=(8, 8), title=None): + def __init__(self, data, affine=None, axes=None, title=None): """ Parameters ---------- - data : ndarray + data : array-like The data that will be displayed by the slicer. Should have 3+ dimensions. - affine : array-like | None + affine : array-like or None Affine transform for the data. This is used to determine how the data should be sliced for plotting into the saggital, coronal, and axial view axes. If None, identity is assumed. The aspect ratio of the data are inferred from the affine transform. - axes : tuple of mpl.Axes | None, optional + axes : tuple of mpl.Axes or None, optional 3 or 4 axes instances for the 3 slices plus volumes, or None (default). - cmap : str | instance of cmap, optional - String or cmap instance specifying colormap. - pcnt_range : array-like, optional - Percentile range over which to scale image for display. - figsize : tuple - Figure size (in inches) to use if axes are None. + title : str or None + The title to display. Can be None (default) to display no + title. """ # Nest imports so that matplotlib.use() has the appropriate # effect in testing @@ -75,21 +72,21 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', if np.iscomplexobj(data): raise TypeError("Complex data not supported") affine = np.array(affine, float) if affine is not None else np.eye(4) - if affine.ndim != 2 or affine.shape != (4, 4): + if affine.shape != (4, 4): raise ValueError('affine must be a 4x4 matrix') # determine our orientation - self._affine = affine.copy() + self._affine = affine codes = axcodes2ornt(aff2axcodes(self._affine)) self._order = np.argsort([c[0] for c in codes]) self._flips = np.array([c[1] < 0 for c in codes])[self._order] self._flips = list(self._flips) + [False] # add volume dim - self._scalers = np.abs(self._affine).max(axis=0)[:3] + self._scalers = voxel_sizes(self._affine) self._inv_affine = np.linalg.inv(affine) # current volume info self._volume_dims = data.shape[3:] self._current_vol_data = data[:, :, :, 0] if data.ndim > 3 else data self._data = data - vmin, vmax = np.percentile(data, pcnt_range) + self._clim = np.percentile(data, (1., 99.)) del data if axes is None: # make the axes @@ -111,7 +108,7 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', # <-- R <-- t --> fig, axes = plt.subplots(2, 2) - fig.set_size_inches(figsize, forward=True) + fig.set_size_inches((8, 8), forward=True) self._axes = [axes[0, 0], axes[0, 1], axes[1, 0], axes[1, 1]] plt.tight_layout(pad=0.1) if self.n_volumes <= 1: @@ -132,14 +129,14 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', r = [self._scalers[self._order[2]] / self._scalers[self._order[1]], self._scalers[self._order[2]] / self._scalers[self._order[0]], self._scalers[self._order[1]] / self._scalers[self._order[0]]] - self._sizes = [self._data.shape[o] for o in self._order] + self._sizes = [self._data.shape[order] for order in self._order] for ii, xax, yax, ratio, label in zip([0, 1, 2], [1, 0, 0], [2, 2, 1], r, ('SAIP', 'SLIR', 'ALPR')): ax = self._axes[ii] d = np.zeros((self._sizes[yax], self._sizes[xax])) - im = self._axes[ii].imshow(d, vmin=vmin, vmax=vmax, aspect=1, - cmap=cmap, interpolation='nearest', - origin='lower') + im = self._axes[ii].imshow( + d, vmin=self._clim[0], vmax=self._clim[1], aspect=1, + cmap='gray', interpolation='nearest', origin='lower') self._ims.append(im) vert = ax.plot([0] * 2, [-0.5, self._sizes[yax] - 0.5], color=(0, 1, 0), linestyle='-')[0] @@ -240,6 +237,11 @@ def _cleanup(self): for link in list(self._links): # make a copy before iterating self._unlink(link()) + def draw(self): + """Redraw the current image""" + for fig in self._figs: + fig.canvas.draw() + @property def n_volumes(self): """Number of volumes in the data""" @@ -250,6 +252,38 @@ def position(self): """The current coordinates""" return self._position[:3].copy() + @property + def figs(self): + """A tuple of the figure(s) containing the axes""" + return tuple(self._figs) + + @property + def cmap(self): + """The current colormap""" + return self._cmap + + @cmap.setter + def cmap(self, cmap): + for im in self._ims: + im.set_cmap(cmap) + self._cmap = cmap + self.draw() + + @property + def clim(self): + """The current color limits""" + return self._clim + + @clim.setter + def clim(self, clim): + clim = np.array(clim, float) + if clim.shape != (2,): + raise ValueError('clim must be a 2-element array-like') + for im in self._ims: + im.set_clim(clim) + self._clim = tuple(clim) + self.draw() + def link_to(self, other): """Link positional changes between two canvases