diff --git a/.mypy.ini b/.mypy.ini index 77bf7465..0eee2044 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -1,6 +1,5 @@ [mypy] python_version = 3.10 -plugins = numpy.typing.mypy_plugin ignore_errors = False warn_redundant_casts = True diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0feb1bd8..3d1fb890 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,17 +9,17 @@ ci: skip: [] repos: - repo: https://github.com/rbubley/mirrors-prettier - rev: v3.5.3 + rev: v3.6.2 hooks: - id: prettier - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.9 + rev: v0.12.8 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] - id: ruff-format - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.15.0 + rev: v1.17.1 hooks: - id: mypy additional_dependencies: [numpy, types-requests] diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 2dd9019f..546b71b1 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -895,6 +895,7 @@ def show( cs_contents.query(f"cs == '{cs}'").iloc[0, :].values.tolist() ) ax = fig_params.ax if fig_params.axs is None else fig_params.axs[i] + assert isinstance(ax, Axes) wants_images = False wants_labels = False diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 52297005..f126c5c0 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -340,6 +340,8 @@ def _render_shapes( vmin = aggregate_with_reduction[0].values if norm.vmin is None else norm.vmin vmax = aggregate_with_reduction[1].values if norm.vmin is None else norm.vmax if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax: + assert norm.vmin is not None + assert norm.vmax is not None # value (vmin=vmax) is placed in the middle of the colorbar so that we can distinguish it from over and # under values in case clip=True or clip=False with cmap(under)=cmap(0) & cmap(over)=cmap(1) vmin = norm.vmin - 0.5 @@ -693,6 +695,8 @@ def _render_points( vmin = aggregate_with_reduction[0].values if norm.vmin is None else norm.vmin vmax = aggregate_with_reduction[1].values if norm.vmax is None else norm.vmax if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax: + assert norm.vmin is not None + assert norm.vmax is not None # value (vmin=vmax) is placed in the middle of the colorbar so that we can distinguish it from over and # under values in case clip=True or clip=False with cmap(under)=cmap(0) & cmap(over)=cmap(1) vmin = norm.vmin - 0.5 diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index a2e8f767..4e4f0b5f 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -526,6 +526,8 @@ def _prepare_cmap_norm( cmap = copy(cmap) + assert isinstance(cmap, Colormap), f"Invalid type of `cmap`: {type(cmap)}, expected `Colormap`." + if norm is None: norm = Normalize(vmin=None, vmax=None, clip=False) @@ -2045,11 +2047,20 @@ def _validate_image_render_params( spatial_element_ch = ( spatial_element.c if isinstance(spatial_element, DataArray) else spatial_element["scale0"].c ) - if (channel := param_dict["channel"]) is not None and ( - (isinstance(channel[0], int) and max([abs(ch) for ch in channel]) <= len(spatial_element_ch)) - or all(ch in spatial_element_ch for ch in channel) + + channel = param_dict["channel"] + channel_list: list[str] | list[int] | None + if isinstance(channel, list): + type_ = type(channel[0]) + assert all(isinstance(ch, type_) for ch in channel), "All channels must be of the same type." + # mypy complains that channel_list can be also of type list[str | int] + channel_list = [channel] if isinstance(channel, int | str) else channel # type: ignore[assignment] + + if channel_list is not None and ( + (isinstance(channel_list[0], int) and max([abs(ch) for ch in channel_list]) <= len(spatial_element_ch)) # type: ignore[arg-type] + or all(ch in spatial_element_ch for ch in channel_list) ): - element_params[el]["channel"] = channel + element_params[el]["channel"] = channel_list else: element_params[el]["channel"] = None @@ -2057,18 +2068,20 @@ def _validate_image_render_params( if isinstance(palette := param_dict["palette"], list): if len(palette) == 1: - palette_length = len(channel) if channel is not None else len(spatial_element_ch) + palette_length = len(channel_list) if channel_list is not None else len(spatial_element_ch) palette = palette * palette_length - if (channel is not None and len(palette) != len(channel)) and len(palette) != len(spatial_element_ch): + if (channel_list is not None and len(palette) != len(channel_list)) and len(palette) != len( + spatial_element_ch + ): palette = None element_params[el]["palette"] = palette element_params[el]["na_color"] = param_dict["na_color"] if (cmap := param_dict["cmap"]) is not None: if len(cmap) == 1: - cmap_length = len(channel) if channel is not None else len(spatial_element_ch) + cmap_length = len(channel_list) if channel_list is not None else len(spatial_element_ch) cmap = cmap * cmap_length - if (channel is not None and len(cmap) != len(channel)) or len(cmap) != len(spatial_element_ch): + if (channel_list is not None and len(cmap) != len(channel_list)) or len(cmap) != len(spatial_element_ch): cmap = None element_params[el]["cmap"] = cmap element_params[el]["norm"] = param_dict["norm"] @@ -2364,7 +2377,9 @@ def _get_datashader_trans_matrix_of_single_element( # no flipping needed return tm # for a Translation, we need the transposed transformation matrix - return tm.T + tm_T = tm.T + assert isinstance(tm_T, np.ndarray) + return tm_T def _get_transformation_matrix_for_datashader(