179 changes: 178 additions & 1 deletion monai/transforms/post/array.py
@@ -16,6 +16,7 @@

import warnings
from collections.abc import Callable, Iterable, Sequence
from typing import ClassVar

import numpy as np
import torch
@@ -38,7 +39,14 @@
remove_small_objects,
)
from monai.transforms.utils_pytorch_numpy_unification import unravel_index
from monai.utils import TransformBackends, convert_data_type, convert_to_tensor, ensure_tuple, look_up_option
from monai.utils import (
TransformBackends,
convert_data_type,
convert_to_tensor,
ensure_tuple,
get_equivalent_dtype,
look_up_option,
)
from monai.utils.type_conversion import convert_to_dst_type

__all__ = [
@@ -54,6 +62,7 @@
"SobelGradients",
"VoteEnsemble",
"Invert",
"GenerateHeatmap",
"DistanceTransformEDT",
]

@@ -742,6 +751,174 @@ def __call__(self, img: Sequence[NdarrayOrTensor] | NdarrayOrTensor) -> NdarrayO
return self.post_convert(out_pt, img)


class GenerateHeatmap(Transform):
"""
Generate per-landmark Gaussian heatmaps for 2D or 3D coordinates.

Notes:
- Coordinates are given in voxel units, one value per spatial axis in the same order as ``spatial_shape`` (e.g. (Y, X) for a 2D (H, W) heatmap).
- Output shape:
- Non-batched points (N, D): (N, H, W[, D])
- Batched points (B, N, D): (B, N, H, W[, D])
- Each channel corresponds to one landmark.

Args:
sigma: Gaussian standard deviation(s). A single value is broadcast across all spatial dimensions; a sequence supplies one value per spatial dimension.
spatial_shape: optional fallback spatial shape. If ``None``, the shape must be provided when calling the transform.
truncated: extent, in multiples of ``sigma``, used to crop the Gaussian support window.
normalize: normalize every heatmap channel to ``[0, 1]`` when ``True``.
dtype: target dtype for the generated heatmaps (accepts numpy or torch dtypes).

Raises:
ValueError: when ``sigma`` is non-positive or ``spatial_shape`` cannot be resolved.

"""

backend: ClassVar[list] = [TransformBackends.NUMPY, TransformBackends.TORCH]

def __init__(
self,
sigma: Sequence[float] | float = 5.0,
spatial_shape: Sequence[int] | None = None,
truncated: float = 4.0,
normalize: bool = True,
dtype: np.dtype | torch.dtype | type = np.float32,
) -> None:
if isinstance(sigma, Sequence) and not isinstance(sigma, (str, bytes)):
if any(s <= 0 for s in sigma):
raise ValueError("sigma values must be positive.")
self._sigma = tuple(float(s) for s in sigma)
else:
if float(sigma) <= 0:
raise ValueError("sigma must be positive.")
self._sigma = (float(sigma),)
if truncated <= 0:
raise ValueError("truncated must be positive.")
self.truncated = float(truncated)
self.normalize = normalize
self.torch_dtype = get_equivalent_dtype(dtype, torch.Tensor)
self.numpy_dtype = get_equivalent_dtype(dtype, np.ndarray)
self.spatial_shape = None if spatial_shape is None else tuple(int(s) for s in spatial_shape)

def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None = None) -> NdarrayOrTensor:
original_points = points
points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False)

is_batched = points_t.ndim == 3
if not is_batched:
if points_t.ndim != 2:
raise ValueError(
"points must be a 2D or 3D array with shape (num_points, spatial_dims) or (B, num_points, spatial_dims)."
)
points_t = points_t.unsqueeze(0) # Add a batch dimension

if points_t.shape[-1] not in (2, 3):
raise ValueError("GenerateHeatmap only supports 2D or 3D landmarks.")

device = points_t.device
batch_size, num_points, spatial_dims = points_t.shape

target_shape = self._resolve_spatial_shape(spatial_shape, spatial_dims)
sigma = self._resolve_sigma(spatial_dims)
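# half-width of the cropped support window per axis: ceil(truncated * sigma) voxels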
radius = tuple(int(np.ceil(self.truncated * s)) for s in sigma)

heatmap = torch.zeros((batch_size, num_points, *target_shape), dtype=self.torch_dtype, device=device)
image_bounds = tuple(int(s) for s in target_shape)
for b_idx in range(batch_size):
for idx, center in enumerate(points_t[b_idx]):
center_vals = center.tolist()
if not np.all(np.isfinite(center_vals)):
continue
if not self._is_inside(center_vals, image_bounds):
continue
window_slices, coord_shifts = self._make_window(center_vals, radius, image_bounds, device)
if window_slices is None:
continue
region = heatmap[b_idx, idx][window_slices]
gaussian = self._evaluate_gaussian(coord_shifts, sigma)
updated = torch.maximum(region, gaussian)
# write back
region.copy_(updated)
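# the cropped window contains the Gaussian peak, so its max is used to scale the whole channel to [0, 1]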
if self.normalize:
peak = updated.max()
if peak.item() > 0:
heatmap[b_idx, idx] /= peak

if not is_batched:
heatmap = heatmap.squeeze(0)

target_dtype = self.torch_dtype if isinstance(original_points, (torch.Tensor, MetaTensor)) else self.numpy_dtype
converted, _, _ = convert_to_dst_type(heatmap, original_points, dtype=target_dtype)
return converted

def _resolve_spatial_shape(self, call_shape: Sequence[int] | None, spatial_dims: int) -> tuple[int, ...]:
shape = call_shape if call_shape is not None else self.spatial_shape
if shape is None:
raise ValueError("spatial_shape must be provided either at construction time or call time.")
shape_tuple = ensure_tuple(shape)
if len(shape_tuple) != spatial_dims:
if len(shape_tuple) == 1:
shape_tuple = shape_tuple * spatial_dims # type: ignore
else:
raise ValueError(
"spatial_shape length must match the landmarks' spatial dims (or pass a single int to broadcast)."
)
return tuple(int(s) for s in shape_tuple)

def _resolve_sigma(self, spatial_dims: int) -> tuple[float, ...]:
if len(self._sigma) == spatial_dims:
return self._sigma
if len(self._sigma) == 1:
return self._sigma * spatial_dims
raise ValueError("sigma sequence length must equal the number of spatial dimensions.")

@staticmethod
def _is_inside(center: Sequence[float], bounds: tuple[int, ...]) -> bool:
for c, size in zip(center, bounds):
if not (0 <= c < size):
return False
return True

def _make_window(
self, center: Sequence[float], radius: tuple[int, ...], bounds: tuple[int, ...], device: torch.device
) -> tuple[tuple[slice, ...] | None, tuple[torch.Tensor, ...]]:
slices: list[slice] = []
coord_shifts: list[torch.Tensor] = []
for c, r, size in zip(center, radius, bounds):
start = max(int(np.floor(c - r)), 0)
stop = min(int(np.ceil(c + r)) + 1, size)
if start >= stop:
return None, ()
slices.append(slice(start, stop))
coord_shifts.append(torch.arange(start, stop, device=device, dtype=torch.float32) - float(c))
return tuple(slices), tuple(coord_shifts)

def _evaluate_gaussian(self, coord_shifts: tuple[torch.Tensor, ...], sigma: tuple[float, ...]) -> torch.Tensor:
"""
Evaluate Gaussian at given coordinate shifts with specified sigmas.

Args:
coord_shifts: Per-dimension coordinate offsets from center.
sigma: Per-dimension standard deviations.

Returns:
Gaussian values at the specified coordinates.
"""
device = coord_shifts[0].device
shape = tuple(len(axis) for axis in coord_shifts)
if 0 in shape:
return torch.zeros(shape, dtype=self.torch_dtype, device=device)
exponent = torch.zeros(shape, dtype=torch.float32, device=device)
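# build exp(-0.5 * sum_d ((x_d - c_d) / sigma_d) ** 2) by broadcasting each 1-D axis term over the window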
for dim, (shift, sig) in enumerate(zip(coord_shifts, sigma)):
shift32 = shift.to(torch.float32)
scaled = (shift32 / float(sig)) ** 2
reshape_shape = [1] * len(coord_shifts)
reshape_shape[dim] = shift.numel()
exponent += scaled.reshape(reshape_shape)
gauss = torch.exp(-0.5 * exponent)
return gauss.to(dtype=self.torch_dtype)
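
For reference, a minimal usage sketch of the transform above; the landmark values, sigma, and spatial shape are illustrative:

import numpy as np
from monai.transforms.post.array import GenerateHeatmap

# two 2D landmarks in (Y, X) voxel coordinates
points = np.array([[12.0, 20.0], [40.0, 33.0]], dtype=np.float32)
to_heatmap = GenerateHeatmap(sigma=3.0, spatial_shape=(64, 64))
heatmaps = to_heatmap(points)
print(heatmaps.shape)  # (2, 64, 64): one normalized channel per landmark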


class ProbNMS(Transform):
"""
Performs probability based non-maximum suppression (NMS) on the probabilities map via
148 changes: 148 additions & 0 deletions monai/transforms/post/dictionary.py
@@ -35,6 +35,7 @@
AsDiscrete,
DistanceTransformEDT,
FillHoles,
GenerateHeatmap,
KeepLargestConnectedComponent,
LabelFilter,
LabelToContour,
@@ -48,6 +49,7 @@
from monai.transforms.utility.array import ToTensor
from monai.transforms.utils import allow_missing_keys_mode, convert_applied_interp_mode
from monai.utils import PostFix, convert_to_tensor, ensure_tuple, ensure_tuple_rep
from monai.utils.type_conversion import convert_to_dst_type

__all__ = [
"ActivationsD",
@@ -95,6 +97,9 @@
"DistanceTransformEDTd",
"DistanceTransformEDTD",
"DistanceTransformEDTDict",
"GenerateHeatmapd",
"GenerateHeatmapD",
"GenerateHeatmapDict",
]

DEFAULT_POST_FIX = PostFix.meta()
@@ -508,6 +513,149 @@ def __init__(self, keys: KeysCollection, output_key: str | None = None, num_clas
super().__init__(keys, ensemble, output_key)


class GenerateHeatmapd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.GenerateHeatmap`.
Converts landmark coordinates into Gaussian heatmaps and optionally copies metadata from a reference image.

Notes:
- Default ``heatmap_keys`` are generated as ``"{key}_heatmap"`` for each input key.
- Shape inference precedence: a static ``spatial_shape`` takes priority over ``ref_image_keys``.
- Output shapes:
- Non-batched points (N, D): (N, H, W[, D])
- Batched points (B, N, D): (B, N, H, W[, D])
"""

backend = GenerateHeatmap.backend

# Error messages
_ERR_HEATMAP_KEYS_LEN = "heatmap_keys length must match keys length."
_ERR_REF_KEYS_LEN = "ref_image_keys length must match keys length when provided."
_ERR_SHAPE_LEN = "spatial_shape length must match keys length when providing per-key shapes."
_ERR_NO_SHAPE = "Unable to determine spatial shape for GenerateHeatmapd. Provide spatial_shape or ref_image_keys."
_ERR_INVALID_POINTS = "landmark arrays must be 2D or 3D with shape (N, D) or (B, N, D)."
_ERR_REF_NO_SHAPE = "Reference data must define a shape attribute."

def __init__(
self,
keys: KeysCollection,
sigma: Sequence[float] | float = 5.0,
heatmap_keys: KeysCollection | None = None,
ref_image_keys: KeysCollection | None = None,
spatial_shape: Sequence[int] | Sequence[Sequence[int]] | None = None,
truncated: float = 4.0,
normalize: bool = True,
dtype: np.dtype | torch.dtype | type = np.float32,
allow_missing_keys: bool = False,
) -> None:
super().__init__(keys, allow_missing_keys)
self.heatmap_keys = self._prepare_heatmap_keys(heatmap_keys)
self.ref_image_keys = self._prepare_optional_keys(ref_image_keys)
self.static_shapes = self._prepare_shapes(spatial_shape)
self.generator = GenerateHeatmap(
sigma=sigma, spatial_shape=None, truncated=truncated, normalize=normalize, dtype=dtype
)

def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:
d = dict(data)
for key, out_key, ref_key, static_shape in self.key_iterator(
d, self.heatmap_keys, self.ref_image_keys, self.static_shapes
):
points = d[key]
shape = self._determine_shape(points, static_shape, d, ref_key)
# The GenerateHeatmap transform will handle type conversion based on input points
heatmap = self.generator(points, spatial_shape=shape)
# If there's a reference image and we need to match its type/device
reference = d.get(ref_key) if ref_key is not None and ref_key in d else None
if reference is not None and isinstance(reference, (torch.Tensor, np.ndarray)):
# Convert to match reference type and device while preserving heatmap's dtype
heatmap, _, _ = convert_to_dst_type(
heatmap, reference, dtype=heatmap.dtype, device=getattr(reference, "device", None)
)
# Copy metadata if reference is MetaTensor
if isinstance(reference, MetaTensor) and isinstance(heatmap, MetaTensor):
heatmap.affine = reference.affine
self._update_spatial_metadata(heatmap, reference)
d[out_key] = heatmap
return d

def _prepare_heatmap_keys(self, heatmap_keys: KeysCollection | None) -> tuple[Hashable, ...]:
if heatmap_keys is None:
return tuple(f"{key}_heatmap" for key in self.keys)
keys_tuple = ensure_tuple(heatmap_keys)
if len(keys_tuple) == 1 and len(self.keys) > 1:
keys_tuple = keys_tuple * len(self.keys)
if len(keys_tuple) != len(self.keys):
raise ValueError(self._ERR_HEATMAP_KEYS_LEN)
return keys_tuple

def _prepare_optional_keys(self, maybe_keys: KeysCollection | None) -> tuple[Hashable | None, ...]:
if maybe_keys is None:
return (None,) * len(self.keys)
keys_tuple = ensure_tuple(maybe_keys)
if len(keys_tuple) == 1 and len(self.keys) > 1:
keys_tuple = keys_tuple * len(self.keys)
if len(keys_tuple) != len(self.keys):
raise ValueError(self._ERR_REF_KEYS_LEN)
return tuple(keys_tuple)

def _prepare_shapes(
self, spatial_shape: Sequence[int] | Sequence[Sequence[int]] | None
) -> tuple[tuple[int, ...] | None, ...]:
if spatial_shape is None:
return (None,) * len(self.keys)
shape_tuple = ensure_tuple(spatial_shape)
if shape_tuple and all(isinstance(v, (int, np.integer)) for v in shape_tuple):
shape = tuple(int(v) for v in shape_tuple)
return (shape,) * len(self.keys)
if len(shape_tuple) == 1 and len(self.keys) > 1:
shape_tuple = shape_tuple * len(self.keys)
if len(shape_tuple) != len(self.keys):
raise ValueError(self._ERR_SHAPE_LEN)
prepared: list[tuple[int, ...] | None] = []
for item in shape_tuple:
if item is None:
prepared.append(None)
else:
dims = ensure_tuple(item)
prepared.append(tuple(int(v) for v in dims))
return tuple(prepared)

def _determine_shape(
self, points: Any, static_shape: tuple[int, ...] | None, data: Mapping[Hashable, Any], ref_key: Hashable | None
) -> tuple[int, ...]:
if static_shape is not None:
return static_shape
points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False)
if points_t.ndim not in (2, 3):
raise ValueError(self._ERR_INVALID_POINTS)
spatial_dims = int(points_t.shape[-1])
if ref_key is not None and ref_key in data:
return self._shape_from_reference(data[ref_key], spatial_dims)
raise ValueError(self._ERR_NO_SHAPE)

def _shape_from_reference(self, reference: Any, spatial_dims: int) -> tuple[int, ...]:
if isinstance(reference, MetaTensor):
meta_shape = reference.meta.get("spatial_shape")
if meta_shape is not None:
dims = ensure_tuple(meta_shape)
if len(dims) == spatial_dims:
return tuple(int(v) for v in dims)
return tuple(int(v) for v in reference.shape[-spatial_dims:])
if hasattr(reference, "shape"):
return tuple(int(v) for v in reference.shape[-spatial_dims:])
raise ValueError(self._ERR_REF_NO_SHAPE)

def _update_spatial_metadata(self, heatmap: MetaTensor, reference: MetaTensor) -> None:
"""Update spatial metadata of heatmap based on its dimensions."""
# the trailing ``reference.ndim - 1`` dims of the heatmap are spatial, whether or not the points were batched
spatial_shape = heatmap.shape[-(reference.ndim - 1) :]
heatmap.meta["spatial_shape"] = tuple(int(v) for v in spatial_shape)


GenerateHeatmapD = GenerateHeatmapDict = GenerateHeatmapd
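
A minimal sketch of the dictionary wrapper in use, inferring the heatmap shape from a reference image; the key names, landmark values, and sizes are illustrative:

import numpy as np
import torch
from monai.transforms.post.dictionary import GenerateHeatmapd

data = {
    "image": torch.zeros(1, 96, 96),                   # (C, H, W) reference image
    "points": np.array([[30.0, 40.0], [60.0, 10.0]]),  # (N, 2) landmarks in (Y, X)
}
xform = GenerateHeatmapd(keys="points", sigma=4.0, ref_image_keys="image")
out = xform(data)
print(out["points_heatmap"].shape)  # (2, 96, 96), returned as a torch tensor to match the reference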


class ProbNMSd(MapTransform):
"""
Performs probability based non-maximum suppression (NMS) on the probabilities map via