diff --git a/dash_3d_viewer/slicer.py b/dash_3d_viewer/slicer.py index 0a0f456..071d73f 100644 --- a/dash_3d_viewer/slicer.py +++ b/dash_3d_viewer/slicer.py @@ -4,7 +4,7 @@ from dash.dependencies import Input, Output, State, ALL from dash_core_components import Graph, Slider, Store -from .utils import img_array_to_uri, get_thumbnail_size_from_shape +from .utils import img_array_to_uri, get_thumbnail_size_from_shape, shape3d_to_size2d class DashVolumeSlicer: @@ -12,8 +12,17 @@ class DashVolumeSlicer: Parameters: app (dash.Dash): the Dash application instance. - volume (ndarray): the 3D numpy array to slice through. + volume (ndarray): the 3D numpy array to slice through. The dimensions + are assumed to be in zyx order. If this is not the case, you can + use ``np.swapaxes`` to make it so. + spacing (tuple of floats): The distance between voxels for each dimension (zyx). + The spacing and origin are applied to make the slice drawn in + "scene space" rather than "voxel space". + origin (tuple of floats): The offset for each dimension (zyx). axis (int): the dimension to slice in. Default 0. + reverse_y (bool): Whether to reverse the y-axis, so that the origin of + the slice is in the top-left, rather than bottom-left. Default True. + (This sets the figure's yaxes ``autorange`` to either "reversed" or True.) scene_id (str): the scene that this slicer is part of. Slicers that have the same scene-id show each-other's positions with line indicators. By default this is a hash of ``id(volume)``. @@ -38,7 +47,18 @@ class DashVolumeSlicer: _global_slicer_counter = 0 - def __init__(self, app, volume, axis=0, scene_id=None): + def __init__( + self, + app, + volume, + *, + spacing=None, + origin=None, + axis=0, + reverse_y=True, + scene_id=None + ): + # todo: also implement xyz dim order? if not isinstance(app, Dash): raise TypeError("Expect first arg to be a Dash app.") self._app = app @@ -46,6 +66,10 @@ def __init__(self, app, volume, axis=0, scene_id=None): if not (isinstance(volume, np.ndarray) and volume.ndim == 3): raise TypeError("Expected volume to be a 3D numpy array") self._volume = volume + spacing = (1, 1, 1) if spacing is None else spacing + spacing = float(spacing[0]), float(spacing[1]), float(spacing[2]) + origin = (0, 0, 0) if origin is None else origin + origin = float(origin[0]), float(origin[1]), float(origin[2]) # Check and store axis if not (isinstance(axis, int) and 0 <= axis <= 2): raise ValueError("The given axis must be 0, 1, or 2.") @@ -60,20 +84,26 @@ def __init__(self, app, volume, axis=0, scene_id=None): DashVolumeSlicer._global_slicer_counter += 1 self.context_id = "slicer_" + str(DashVolumeSlicer._global_slicer_counter) - # Get the slice size (width, height), and max index - arr_shape = list(volume.shape) - arr_shape.pop(self._axis) - self._slice_size = tuple(reversed(arr_shape)) - self._max_index = self._volume.shape[self._axis] - 1 + # Prepare slice info + info = { + "shape": tuple(volume.shape), + "axis": self._axis, + "size": shape3d_to_size2d(volume.shape, axis), + "origin": shape3d_to_size2d(origin, axis), + "spacing": shape3d_to_size2d(spacing, axis), + } # Prep low-res slices - thumbnail_size = get_thumbnail_size_from_shape(arr_shape, 32) + thumbnail_size = get_thumbnail_size_from_shape( + (info["size"][1], info["size"][0]), 32 + ) thumbnails = [ img_array_to_uri(self._slice(i), thumbnail_size) - for i in range(self._max_index + 1) + for i in range(info["size"][2]) ] + info["lowres_size"] = thumbnail_size - # Create a placeholder trace + # Create traces # todo: can add "%{z[0]}", but that would be the scaled value ... image_trace = Image( source="", dx=1, dy=1, hovertemplate="(%{x}, %{y})" @@ -97,6 +127,7 @@ def __init__(self, app, volume, axis=0, scene_id=None): scaleanchor="x", showticklabels=False, zeroline=False, + autorange="reversed" if reverse_y else True, ) # Wrap the figure in a graph # todo: or should the user provide this? @@ -106,22 +137,20 @@ def __init__(self, app, volume, axis=0, scene_id=None): config={"scrollZoom": True}, ) # Create a slider object that the user can put in the layout (or not) - # todo: use tooltip to show current value? self.slider = Slider( id=self._subid("slider"), min=0, - max=self._max_index, + max=info["size"][2] - 1, step=1, - value=self._max_index // 2, + value=info["size"][2] // 2, tooltip={"always_visible": False, "placement": "left"}, updatemode="drag", ) # Create the stores that we need (these must be present in the layout) self.stores = [ - Store( - id=self._subid("_slice-size"), data=self._slice_size + thumbnail_size - ), + Store(id=self._subid("info"), data=info), Store(id=self._subid("index"), data=volume.shape[self._axis] // 2), + Store(id=self._subid("position"), data=0), Store(id=self._subid("_requested-slice-index"), data=0), Store(id=self._subid("_slice-data"), data=""), Store(id=self._subid("_slice-data-lowres"), data=thumbnails), @@ -175,6 +204,17 @@ def _create_client_callbacks(self): [Input(self._subid("slider"), "value")], ) + app.clientside_callback( + """ + function update_position(index, info) { + return info.origin[2] + index * info.spacing[2]; + } + """, + Output(self._subid("position"), "data"), + [Input(self._subid("index"), "data")], + [State(self._subid("info"), "data")], + ) + app.clientside_callback( """ function handle_slice_index(index) { @@ -205,7 +245,7 @@ def _create_client_callbacks(self): app.clientside_callback( """ - function handle_incoming_slice(index, index_and_data, indicators, ori_figure, lowres, slice_size) { + function handle_incoming_slice(index, index_and_data, indicators, ori_figure, lowres, info) { let new_index = index_and_data[0]; let new_data = index_and_data[1]; // Store data in cache @@ -214,18 +254,18 @@ def _create_client_callbacks(self): slice_cache[new_index] = new_data; // Get the data we need *now* let data = slice_cache[index]; - let x0 = 0, y0 = 0, dx = 1, dy = 1; + let x0 = info.origin[0], y0 = info.origin[1]; + let dx = info.spacing[0], dy = info.spacing[1]; //slice_cache[new_index] = undefined; // todo: disabled cache for now! // Maybe we do not need an update if (!data) { data = lowres[index]; // Scale the image to take the exact same space as the full-res // version. It's not correct, but it looks better ... - // slice_size = full_w, full_h, low_w, low_h - dx = slice_size[0] / slice_size[2]; - dy = slice_size[1] / slice_size[3]; - x0 = 0.5 * dx - 0.5; - y0 = 0.5 * dy - 0.5; + dx *= info.size[0] / info.lowres_size[0]; + dy *= info.size[1] / info.lowres_size[1]; + x0 += 0.5 * dx - 0.5 * info.spacing[0]; + y0 += 0.5 * dy - 0.5 * info.spacing[1]; } if (data == ori_figure.data[0].source && indicators.version == ori_figure.data[1].version) { return window.dash_clientside.no_update; @@ -253,7 +293,7 @@ def _create_client_callbacks(self): [ State(self._subid("graph"), "figure"), State(self._subid("_slice-data-lowres"), "data"), - State(self._subid("_slice-size"), "data"), + State(self._subid("info"), "data"), ], ) @@ -266,18 +306,22 @@ def _create_client_callbacks(self): # * match any of the selected axii app.clientside_callback( """ - function handle_indicator(indices1, indices2, slice_size, current) { - let w = slice_size[0], h = slice_size[1]; - let dx = w / 20, dy = h / 20; + function handle_indicator(positions1, positions2, info, current) { + let x0 = info.origin[0], y0 = info.origin[1]; + let x1 = x0 + info.size[0] * info.spacing[0], y1 = y0 + info.size[1] * info.spacing[1]; + x0 = x0 - info.spacing[0], y0 = y0 - info.spacing[1]; + let d = ((x1 - x0) + (y1 - y0)) * 0.5 * 0.05; let version = (current.version || 0) + 1; let x = [], y = []; - for (let index of indices1) { - x.push(...[-dx, -1, null, w, w + dx, null]); - y.push(...[index, index, index, index, index, index]); + for (let pos of positions1) { + // x relative to our slice, y in scene-coords + x.push(...[x0 - d, x0, null, x1, x1 + d, null]); + y.push(...[pos, pos, pos, pos, pos, pos]); } - for (let index of indices2) { - x.push(...[index, index, index, index, index, index]); - y.push(...[-dy, -1, null, h, h + dy, null]); + for (let pos of positions2) { + // x in scene-coords, y relative to our slice + x.push(...[pos, pos, pos, pos, pos, pos]); + y.push(...[y0 - d, y0, null, y1, y1 + d, null]); } return { type: 'scatter', @@ -296,7 +340,7 @@ def _create_client_callbacks(self): { "scene": self.scene_id, "context": ALL, - "name": "index", + "name": "position", "axis": axis, }, "data", @@ -304,7 +348,7 @@ def _create_client_callbacks(self): for axis in axii ], [ - State(self._subid("_slice-size"), "data"), + State(self._subid("info"), "data"), State(self._subid("_indicators"), "data"), ], ) diff --git a/dash_3d_viewer/utils.py b/dash_3d_viewer/utils.py index 583e8b2..3bb57a1 100644 --- a/dash_3d_viewer/utils.py +++ b/dash_3d_viewer/utils.py @@ -32,3 +32,14 @@ def get_thumbnail_size_from_shape(shape, base_size): img_pil = PIL.Image.fromarray(img_array) img_pil.thumbnail((base_size, base_size)) return img_pil.size + + +def shape3d_to_size2d(shape, axis): + """Turn a 3d shape (z, y, x) into a local (x', y', z'), + where z' represents the dimension indicated by axis. + """ + shape = list(shape) + axis_value = shape.pop(axis) + size = list(reversed(shape)) + size.append(axis_value) + return tuple(size) diff --git a/examples/slicer_with_1_plus_2_views.py b/examples/slicer_with_1_plus_2_views.py index 312adf3..e04d4cf 100644 --- a/examples/slicer_with_1_plus_2_views.py +++ b/examples/slicer_with_1_plus_2_views.py @@ -3,10 +3,14 @@ This demonstrates how multiple indicators can be shown per axis. Sharing the same scene_id is enough for the slicers to show each-others -position. If the same volume object is given, it works by default, +position. If the same volume object would be given, it works by default, because the default scene_id is a hash of the volume object. Specifying a scene_id provides slice position indicators even when slicing through different volumes. + +Further, this example has one slider showing data with different spacing. +Note how the indicators represent the actual position in "scene coordinates". + """ import dash @@ -17,22 +21,33 @@ app = dash.Dash(__name__) -vol = imageio.volread("imageio:stent.npz") -slicer1 = DashVolumeSlicer(app, vol, axis=1, scene_id="myscene") -slicer2 = DashVolumeSlicer(app, vol, axis=0, scene_id="myscene") -slicer3 = DashVolumeSlicer(app, vol, axis=0, scene_id="myscene") +vol1 = imageio.volread("imageio:stent.npz") + +vol2 = vol1[::3, ::2, :] +spacing = 3, 2, 1 +ori = 110, 120, 140 + + +slicer1 = DashVolumeSlicer( + app, vol1, axis=1, origin=ori, reverse_y=False, scene_id="scene1" +) +slicer2 = DashVolumeSlicer( + app, vol1, axis=0, origin=ori, reverse_y=False, scene_id="scene1" +) +slicer3 = DashVolumeSlicer( + app, vol2, axis=0, origin=ori, spacing=spacing, reverse_y=False, scene_id="scene1" +) app.layout = html.Div( style={ "display": "grid", - "grid-template-columns": "40% 40%", + "gridTemplateColumns": "40% 40%", }, children=[ html.Div( [ html.H1("Coronal"), slicer1.graph, - html.Br(), slicer1.slider, *slicer1.stores, ] @@ -41,7 +56,6 @@ [ html.H1("Transversal 1"), slicer2.graph, - html.Br(), slicer2.slider, *slicer2.stores, ] @@ -51,7 +65,6 @@ [ html.H1("Transversal 2"), slicer3.graph, - html.Br(), slicer3.slider, *slicer3.stores, ] diff --git a/examples/slicer_with_2_views.py b/examples/slicer_with_2_views.py index 73b829c..7019df0 100644 --- a/examples/slicer_with_2_views.py +++ b/examples/slicer_with_2_views.py @@ -17,7 +17,7 @@ app.layout = html.Div( style={ "display": "grid", - "grid-template-columns": "40% 40%", + "gridTemplateColumns": "40% 40%", }, children=[ html.Div( diff --git a/examples/slicer_with_3_views.py b/examples/slicer_with_3_views.py index 54b1a3f..94fccbb 100644 --- a/examples/slicer_with_3_views.py +++ b/examples/slicer_with_3_views.py @@ -15,9 +15,9 @@ # Read volumes and create slicer objects vol = imageio.volread("imageio:stent.npz") -slicer1 = DashVolumeSlicer(app, vol, axis=0) -slicer2 = DashVolumeSlicer(app, vol, axis=1) -slicer3 = DashVolumeSlicer(app, vol, axis=2) +slicer1 = DashVolumeSlicer(app, vol, reverse_y=False, axis=0) +slicer2 = DashVolumeSlicer(app, vol, reverse_y=False, axis=1) +slicer3 = DashVolumeSlicer(app, vol, reverse_y=False, axis=2) # Calculate isosurface and create a figure with a mesh object verts, faces, _, _ = marching_cubes(vol, 300, step_size=2) @@ -30,7 +30,7 @@ app.layout = html.Div( style={ "display": "grid", - "grid-template-columns": "40% 40%", + "gridTemplateColumns": "40% 40%", }, children=[ html.Div( diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..4cdb860 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,14 @@ +from dash_3d_viewer.utils import shape3d_to_size2d + +from pytest import raises + + +def test_shape3d_to_size2d(): + # shape -> z, y, x + # size -> x, y, out-of-plane + assert shape3d_to_size2d((12, 13, 14), 0) == (14, 13, 12) + assert shape3d_to_size2d((12, 13, 14), 1) == (14, 12, 13) + assert shape3d_to_size2d((12, 13, 14), 2) == (13, 12, 14) + + with raises(IndexError): + shape3d_to_size2d((12, 13, 14), 3)