diff --git a/CHANGELOG.md b/CHANGELOG.md index e45649ed..6973baa3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,10 +8,17 @@ and this project adheres to [Semantic Versioning][]. [keep a changelog]: https://keepachangelog.com/en/1.0.0/ [semantic versioning]: https://semver.org/spec/v2.0.0.html +## [0.2.9] - tbd + +### Fixed + +- Transformations of Points and Shapes are now applied before rendering with datashader (#378) + ## [0.2.8] - 2024-11-26 ### Changed -- Support for `xarray.DataTree` (which moved from `datatree.DataTree`) #380 + +- Support for `xarray.DataTree` (which moved from `datatree.DataTree`) (#380) ## [0.2.7] - 2024-10-24 @@ -45,10 +52,6 @@ and this project adheres to [Semantic Versioning][]. ## [0.2.5] - 2024-08-23 -### Added - -- - ### Changed - Replaced `outline` parameter in `render_labels` with alpha-based logic (#323) diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 78c59436..8513dba9 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -162,7 +162,7 @@ def render_shapes( palette: list[str] | str | None = None, na_color: ColorLike | None = "default", outline_width: float | int = 1.5, - outline_color: str | list[float] = "#000000ff", + outline_color: str | list[float] = "#000000", outline_alpha: float | int = 0.0, cmap: Colormap | str | None = None, norm: Normalize | None = None, @@ -208,9 +208,11 @@ def render_shapes( won't be shown. outline_width : float | int, default 1.5 Width of the border. - outline_color : str | list[float], default "#000000ff" - Color of the border. Can either be a named color ("red"), a hex representation ("#000000ff") or a list of - floats that represent RGB/RGBA values (1.0, 0.0, 0.0, 1.0). + outline_color : str | list[float], default "#000000" + Color of the border. Can either be a named color ("red"), a hex representation ("#000000") or a list of + floats that represent RGB/RGBA values (1.0, 0.0, 0.0, 1.0). If the hex representation includes alpha, e.g. + "#000000ff", the last two positions are ignored, since the alpha of the outlines is solely controlled by + `outline_alpha`. outline_alpha : float | int, default 0.0 Alpha value for the outline of shapes. Invisible by default. cmap : Colormap | str | None, optional diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index a3a6c1a1..6ced5b27 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -18,10 +18,9 @@ from matplotlib.colors import ListedColormap, Normalize from scanpy._settings import settings as sc_settings from spatialdata import get_extent -from spatialdata.models import PointsModel, get_table_keys -from spatialdata.transformations import ( - set_transformation, -) +from spatialdata.models import PointsModel, ShapesModel, get_table_keys +from spatialdata.transformations import get_transformation, set_transformation +from spatialdata.transformations.transformations import Identity from xarray import DataTree from spatialdata_plot._logging import logger @@ -44,6 +43,7 @@ _get_colors_for_categorical_obs, _get_extent_and_range_for_datashader_canvas, _get_linear_colormap, + _get_transformation_matrix_for_datashader, _is_coercable_to_float, _map_color_seg, _maybe_set_colors, @@ -148,7 +148,7 @@ def _render_shapes( colorbar = False if col_for_color is None else legend_params.colorbar # Apply the transformation to the PatchCollection's paths - trans, _ = _prepare_transformation(sdata_filt.shapes[element], coordinate_system) + trans, trans_data = _prepare_transformation(sdata_filt.shapes[element], coordinate_system) shapes = gpd.GeoDataFrame(shapes, geometry="geometry") @@ -168,14 +168,6 @@ def _render_shapes( ) if method == "datashader": - trans += ax.transData - - plot_width, plot_height, x_ext, y_ext, factor = _get_extent_and_range_for_datashader_canvas( - sdata_filt.shapes[element], coordinate_system, ax, fig_params - ) - - cvs = ds.Canvas(plot_width=plot_width, plot_height=plot_height, x_range=x_ext, y_range=y_ext) - _geometry = shapes["geometry"] is_point = _geometry.type == "Point" @@ -184,19 +176,31 @@ def _render_shapes( scale = shapes[is_point]["radius"] * render_params.scale sdata_filt.shapes[element].loc[is_point, "geometry"] = _geometry[is_point].buffer(scale.to_numpy()) + # apply transformations to the individual points + element_trans = get_transformation(sdata_filt.shapes[element]) + tm = _get_transformation_matrix_for_datashader(element_trans) + transformed_element = sdata_filt.shapes[element].transform( + lambda x: (np.hstack([x, np.ones((x.shape[0], 1))]) @ tm)[:, :2] + ) + transformed_element = ShapesModel.parse( + gpd.GeoDataFrame(data=sdata_filt.shapes[element].drop("geometry", axis=1), geometry=transformed_element) + ) + + plot_width, plot_height, x_ext, y_ext, factor = _get_extent_and_range_for_datashader_canvas( + transformed_element, coordinate_system, ax, fig_params + ) + + cvs = ds.Canvas(plot_width=plot_width, plot_height=plot_height, x_range=x_ext, y_range=y_ext) + # in case we are coloring by a column in table - if col_for_color is not None and col_for_color not in sdata_filt.shapes[element].columns: - sdata_filt.shapes[element][col_for_color] = ( - color_vector if color_source_vector is None else color_source_vector - ) + if col_for_color is not None and col_for_color not in transformed_element.columns: + transformed_element[col_for_color] = color_vector if color_source_vector is None else color_source_vector # Render shapes with datashader color_by_categorical = col_for_color is not None and color_source_vector is not None aggregate_with_reduction = None if col_for_color is not None and (render_params.groups is None or len(render_params.groups) > 1): if color_by_categorical: - agg = cvs.polygons( - sdata_filt.shapes[element], geometry="geometry", agg=ds.by(col_for_color, ds.count()) - ) + agg = cvs.polygons(transformed_element, geometry="geometry", agg=ds.by(col_for_color, ds.count())) else: reduction_name = render_params.ds_reduction if render_params.ds_reduction is not None else "mean" logger.info( @@ -204,16 +208,16 @@ def _render_shapes( "to the matplotlib result." ) agg = _datashader_aggregate_with_function( - render_params.ds_reduction, cvs, sdata_filt.shapes[element], col_for_color, "shapes" + render_params.ds_reduction, cvs, transformed_element, col_for_color, "shapes" ) # save min and max values for drawing the colorbar aggregate_with_reduction = (agg.min(), agg.max()) else: - agg = cvs.polygons(sdata_filt.shapes[element], geometry="geometry", agg=ds.count()) + agg = cvs.polygons(transformed_element, geometry="geometry", agg=ds.count()) # render outlines if needed if (render_outlines := render_params.outline_alpha) > 0: agg_outlines = cvs.line( - sdata_filt.shapes[element], + transformed_element, geometry="geometry", line_width=render_params.outline_params.linewidth, ) @@ -287,13 +291,23 @@ def _render_shapes( rgba_image, trans_data = _create_image_from_datashader_result(ds_result, factor, ax) _cax = _ax_show_and_transform( - rgba_image, trans_data, ax, zorder=render_params.zorder, alpha=render_params.fill_alpha + rgba_image, + trans_data, + ax, + zorder=render_params.zorder, + alpha=render_params.fill_alpha, + extent=x_ext + y_ext, ) # render outline image if needed if render_outlines: rgba_image, trans_data = _create_image_from_datashader_result(ds_outlines, factor, ax) _ax_show_and_transform( - rgba_image, trans_data, ax, zorder=render_params.zorder, alpha=render_params.outline_alpha + rgba_image, + trans_data, + ax, + zorder=render_params.zorder, + alpha=render_params.outline_alpha, + extent=x_ext + y_ext, ) cax = None @@ -330,7 +344,7 @@ def _render_shapes( if not values_are_categorical: # If the user passed a Normalize object with vmin/vmax we'll use those, - # # if not we'll use the min/max of the color_vector + # if not we'll use the min/max of the color_vector _cax.set_clim( vmin=render_params.cmap_params.norm.vmin or min(color_vector), vmax=render_params.cmap_params.norm.vmax or max(color_vector), @@ -468,7 +482,7 @@ def _render_points( if color_source_vector is None and render_params.transfunc is not None: color_vector = render_params.transfunc(color_vector) - _, trans_data = _prepare_transformation(sdata.points[element], coordinate_system, ax) + trans, trans_data = _prepare_transformation(sdata.points[element], coordinate_system, ax) norm = copy(render_params.cmap_params.norm) @@ -491,8 +505,15 @@ def _render_points( # use dpi/100 as a factor for cases where dpi!=100 px = int(np.round(np.sqrt(render_params.size) * (fig_params.fig.dpi / 100))) + # apply transformations + transformed_element = PointsModel.parse( + trans.transform(sdata_filt.points[element][["x", "y"]]), + annotation=sdata_filt.points[element][sdata_filt.points[element].columns.drop(["x", "y"])], + transformations={coordinate_system: Identity()}, + ) + plot_width, plot_height, x_ext, y_ext, factor = _get_extent_and_range_for_datashader_canvas( - sdata_filt.points[element], coordinate_system, ax, fig_params + transformed_element, coordinate_system, ax, fig_params ) # use datashader for the visualization of points @@ -502,7 +523,7 @@ def _render_points( aggregate_with_reduction = None if col_for_color is not None and (render_params.groups is None or len(render_params.groups) > 1): if color_by_categorical: - agg = cvs.points(sdata_filt.points[element], "x", "y", agg=ds.by(col_for_color, ds.count())) + agg = cvs.points(transformed_element, "x", "y", agg=ds.by(col_for_color, ds.count())) else: reduction_name = render_params.ds_reduction if render_params.ds_reduction is not None else "sum" logger.info( @@ -510,12 +531,12 @@ def _render_points( "to the matplotlib result." ) agg = _datashader_aggregate_with_function( - render_params.ds_reduction, cvs, sdata_filt.points[element], col_for_color, "points" + render_params.ds_reduction, cvs, transformed_element, col_for_color, "points" ) # save min and max values for drawing the colorbar aggregate_with_reduction = (agg.min(), agg.max()) else: - agg = cvs.points(sdata_filt.points[element], "x", "y", agg=ds.count()) + agg = cvs.points(transformed_element, "x", "y", agg=ds.count()) if norm.vmin is not None or norm.vmax is not None: norm.vmin = np.min(agg) if norm.vmin is None else norm.vmin @@ -573,7 +594,14 @@ def _render_points( ) rgba_image, trans_data = _create_image_from_datashader_result(ds_result, factor, ax) - _ax_show_and_transform(rgba_image, trans_data, ax, zorder=render_params.zorder, alpha=render_params.alpha) + _ax_show_and_transform( + rgba_image, + trans_data, + ax, + zorder=render_params.zorder, + alpha=render_params.alpha, + extent=x_ext + y_ext, + ) cax = None if aggregate_with_reduction is not None: diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index dd3d5ec0..2fe377cc 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -19,6 +19,7 @@ import matplotlib.transforms as mtransforms import numpy as np import numpy.ma as ma +import numpy.typing as npt import pandas as pd import shapely import spatialdata as sd @@ -58,8 +59,11 @@ from spatialdata._core.query.relational_query import _locate_value, _ValueOrigin from spatialdata._types import ArrayLike from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, SpatialElement, get_model + +# from spatialdata.transformations.transformations import Scale +from spatialdata.transformations import Affine, Identity, MapAxis, Scale, Translation +from spatialdata.transformations import Sequence as SDSequence from spatialdata.transformations.operations import get_transformation -from spatialdata.transformations.transformations import Scale from xarray import DataArray, DataTree from spatialdata_plot._logging import logger @@ -1977,12 +1981,29 @@ def _ax_show_and_transform( alpha: float | None = None, cmap: ListedColormap | LinearSegmentedColormap | None = None, zorder: int = 0, + extent: list[float] | None = None, ) -> matplotlib.image.AxesImage: + # default extent in mpl: + image_extent = [-0.5, array.shape[1] - 0.5, array.shape[0] - 0.5, -0.5] + if extent is not None: + # make sure extent is [x_min, x_max, y_min, y_max] + if extent[3] < extent[2]: + extent[2], extent[3] = extent[3], extent[2] + if extent[0] < 0: + x_factor = array.shape[1] / (extent[1] - extent[0]) + image_extent[0] = image_extent[0] + (extent[0] * x_factor) + image_extent[1] = image_extent[1] + (extent[0] * x_factor) + if extent[2] < 0: + y_factor = array.shape[0] / (extent[3] - extent[2]) + image_extent[2] = image_extent[2] + (extent[2] * y_factor) + image_extent[3] = image_extent[3] + (extent[2] * y_factor) + if not cmap and alpha is not None: im = ax.imshow( array, alpha=alpha, zorder=zorder, + extent=tuple(image_extent), ) im.set_transform(trans_data) else: @@ -1990,6 +2011,7 @@ def _ax_show_and_transform( array, cmap=cmap, zorder=zorder, + extent=tuple(image_extent), ) im.set_transform(trans_data) return im @@ -2055,7 +2077,7 @@ def _get_extent_and_range_for_datashader_canvas( def _create_image_from_datashader_result( ds_result: ds.transfer_functions.Image, factor: float, ax: Axes -) -> tuple[MaskedArray[tuple[int, ...], Any], matplotlib.transforms.CompositeGenericTransform]: +) -> tuple[MaskedArray[tuple[int, ...], Any], matplotlib.transforms.Transform]: # create SpatialImage from datashader output to get it back to original size rgba_image_data = ds_result.to_numpy().base rgba_image_data = np.transpose(rgba_image_data, (2, 0, 1)) @@ -2187,3 +2209,34 @@ def _prepare_transformation( trans_data = trans + ax.transData if ax is not None else None return trans, trans_data + + +def _get_datashader_trans_matrix_of_single_element( + trans: Identity | Scale | Affine | MapAxis | Translation, +) -> npt.NDArray[Any]: + flip_matrix = np.array([[1, 0, 0], [0, -1, 0], [0, 0, 1]]) + tm: npt.NDArray[Any] = trans.to_affine_matrix(("x", "y"), ("x", "y")) + + if isinstance(trans, Identity): + return np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + if isinstance(trans, (Scale | Affine)): + # idea: "flip the y-axis", apply transformation, flip back + flip_and_transform: npt.NDArray[Any] = flip_matrix @ tm @ flip_matrix + return flip_and_transform + if isinstance(trans, MapAxis): + # no flipping needed + return tm + # for a Translation, we need the transposed transformation matrix + return tm.T + + +def _get_transformation_matrix_for_datashader( + trans: Scale | Identity | Affine | MapAxis | Translation | SDSequence, +) -> npt.NDArray[Any]: + """Get the affine matrix needed to transform shapes for rendering with datashader.""" + if isinstance(trans, SDSequence): + tm = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + for x in trans.transformations: + tm = tm @ _get_datashader_trans_matrix_of_single_element(x) + return tm + return _get_datashader_trans_matrix_of_single_element(trans) diff --git a/tests/_images/Points_datashader_can_transform_points.png b/tests/_images/Points_datashader_can_transform_points.png new file mode 100644 index 00000000..201b2db1 Binary files /dev/null and b/tests/_images/Points_datashader_can_transform_points.png differ diff --git a/tests/_images/Points_points_transformed_ds_agrees_with_mpl.png b/tests/_images/Points_points_transformed_ds_agrees_with_mpl.png new file mode 100644 index 00000000..c26c62ce Binary files /dev/null and b/tests/_images/Points_points_transformed_ds_agrees_with_mpl.png differ diff --git a/tests/_images/Shapes_datashader_can_transform_circles.png b/tests/_images/Shapes_datashader_can_transform_circles.png new file mode 100644 index 00000000..60cde073 Binary files /dev/null and b/tests/_images/Shapes_datashader_can_transform_circles.png differ diff --git a/tests/_images/Shapes_datashader_can_transform_multipolygons.png b/tests/_images/Shapes_datashader_can_transform_multipolygons.png new file mode 100644 index 00000000..09e56c63 Binary files /dev/null and b/tests/_images/Shapes_datashader_can_transform_multipolygons.png differ diff --git a/tests/_images/Shapes_datashader_can_transform_polygons.png b/tests/_images/Shapes_datashader_can_transform_polygons.png new file mode 100644 index 00000000..fb2552ff Binary files /dev/null and b/tests/_images/Shapes_datashader_can_transform_polygons.png differ diff --git a/tests/pl/test_render_points.py b/tests/pl/test_render_points.py index b2dc9179..e3e99099 100644 --- a/tests/pl/test_render_points.py +++ b/tests/pl/test_render_points.py @@ -1,3 +1,5 @@ +import math + import dask.dataframe import matplotlib import matplotlib.pyplot as plt @@ -7,6 +9,8 @@ from anndata import AnnData from spatialdata import SpatialData, deepcopy from spatialdata.models import PointsModel, TableModel +from spatialdata.transformations import Affine, Identity, MapAxis, Scale, Sequence, Translation +from spatialdata.transformations._utils import _set_transformations import spatialdata_plot # noqa: F401 from tests.conftest import DPI, PlotTester, PlotTesterMeta @@ -173,3 +177,38 @@ def test_plot_mpl_and_datashader_point_sizes_agree_after_altered_dpi(self, sdata sdata_blobs.pl.render_points(element="blobs_points", size=400, color="blue").pl.render_points( element="blobs_points", size=400, color="yellow", method="datashader", alpha=0.8 ).pl.show(dpi=200) + + def test_plot_points_transformed_ds_agrees_with_mpl(self): + sdata = SpatialData( + points={ + "points1": PointsModel.parse( + pd.DataFrame({"y": [0, 0, 10, 10, 4, 6, 4, 6], "x": [0, 10, 10, 0, 4, 6, 6, 4]}), + transformations={"global": Scale([2, 2], ("y", "x"))}, + ) + }, + ) + sdata.pl.render_points("points1", method="matplotlib", size=50, color="lightgrey").pl.render_points( + "points1", method="datashader", size=10, color="red" + ).pl.show() + + def test_plot_datashader_can_transform_points(self, sdata_blobs: SpatialData): + theta = math.pi / 1.7 + rotation = Affine( + [ + [math.cos(theta), -math.sin(theta), 0], + [math.sin(theta), math.cos(theta), 0], + [0, 0, 1], + ], + input_axes=("x", "y"), + output_axes=("x", "y"), + ) + + scale = Scale([-1.3, 1.8], axes=("x", "y")) + identity = Identity() + mapaxis = MapAxis({"x": "y", "y": "x"}) + translation = Translation([20, -65], ("x", "y")) + seq = Sequence([mapaxis, scale, identity, translation, rotation]) + + _set_transformations(sdata_blobs["blobs_points"], {"global": seq}) + + sdata_blobs.pl.render_points("blobs_points", method="datashader", color="black", size=5).pl.show() diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py index ff86a4f0..d683189a 100644 --- a/tests/pl/test_render_shapes.py +++ b/tests/pl/test_render_shapes.py @@ -1,3 +1,5 @@ +import math + import anndata import geopandas as gpd import matplotlib @@ -10,6 +12,8 @@ from shapely.geometry import MultiPolygon, Point, Polygon from spatialdata import SpatialData, deepcopy from spatialdata.models import ShapesModel, TableModel +from spatialdata.transformations import Affine, Identity, MapAxis, Scale, Sequence, Translation +from spatialdata.transformations._utils import _set_transformations import spatialdata_plot # noqa: F401 from tests.conftest import DPI, PlotTester, PlotTesterMeta @@ -377,3 +381,69 @@ def test_plot_can_set_clims_clip(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_shapes( "blobs_circles", color="dummy_gene_expression", norm=norm, table_name="new_table" ).pl.show() + + def test_plot_datashader_can_transform_polygons(self, sdata_blobs: SpatialData): + theta = math.pi / 1.7 + rotation = Affine( + [ + [math.cos(theta), -math.sin(theta), 0], + [math.sin(theta), math.cos(theta), 0], + [0, 0, 1], + ], + input_axes=("x", "y"), + output_axes=("x", "y"), + ) + + scale = Scale([-1.3, 1.8], axes=("x", "y")) + identity = Identity() + mapaxis = MapAxis({"x": "y", "y": "x"}) + translation = Translation([20, -65], ("x", "y")) + seq = Sequence([mapaxis, scale, identity, translation, rotation]) + + _set_transformations(sdata_blobs["blobs_polygons"], {"global": seq}) + + sdata_blobs.pl.render_shapes("blobs_polygons", method="datashader", outline_alpha=1.0).pl.show() + + def test_plot_datashader_can_transform_multipolygons(self, sdata_blobs: SpatialData): + theta = math.pi / 1.7 + rotation = Affine( + [ + [math.cos(theta), -math.sin(theta), 0], + [math.sin(theta), math.cos(theta), 0], + [0, 0, 1], + ], + input_axes=("x", "y"), + output_axes=("x", "y"), + ) + + scale = Scale([-1.3, 1.8], axes=("x", "y")) + identity = Identity() + mapaxis = MapAxis({"x": "y", "y": "x"}) + translation = Translation([20, -65], ("x", "y")) + seq = Sequence([mapaxis, scale, identity, translation, rotation]) + + _set_transformations(sdata_blobs["blobs_multipolygons"], {"global": seq}) + + sdata_blobs.pl.render_shapes("blobs_multipolygons", method="datashader", outline_alpha=1.0).pl.show() + + def test_plot_datashader_can_transform_circles(self, sdata_blobs: SpatialData): + theta = math.pi / 1.7 + rotation = Affine( + [ + [math.cos(theta), -math.sin(theta), 0], + [math.sin(theta), math.cos(theta), 0], + [0, 0, 1], + ], + input_axes=("x", "y"), + output_axes=("x", "y"), + ) + + scale = Scale([-1.3, 1.8], axes=("x", "y")) + identity = Identity() + mapaxis = MapAxis({"x": "y", "y": "x"}) + translation = Translation([20, -65], ("x", "y")) + seq = Sequence([mapaxis, scale, identity, translation, rotation]) + + _set_transformations(sdata_blobs["blobs_circles"], {"global": seq}) + + sdata_blobs.pl.render_shapes("blobs_circles", method="datashader", outline_alpha=1.0).pl.show()