Skip to content

Commit 1bf0850

Browse files
committed
fixes
Signed-off-by: sewon.jeon <[email protected]>
1 parent aaf2833 commit 1bf0850

File tree

3 files changed

+37
-25
lines changed

3 files changed

+37
-25
lines changed

monai/transforms/post/array.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -753,7 +753,14 @@ def __call__(self, img: Sequence[NdarrayOrTensor] | NdarrayOrTensor) -> NdarrayO
753753

754754
class GenerateHeatmap(Transform):
755755
"""
756-
Generate per-landmark gaussian response maps for 2D or 3D coordinates.
756+
Generate per-landmark Gaussian heatmaps for 2D or 3D coordinates.
757+
758+
Notes:
759+
- Coordinates are interpreted in voxel units and expected in (Y, X) for 2D or (Z, Y, X) for 3D.
760+
- Output shape:
761+
- Non-batched points (N, D): (N, H, W[, D])
762+
- Batched points (B, N, D): (B, N, H, W[, D])
763+
- Each channel corresponds to one landmark.
757764
758765
Args:
759766
sigma: gaussian standard deviation. A single value is broadcast across all spatial dimensions.
@@ -829,11 +836,13 @@ def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None
829836
continue
830837
region = heatmap[b_idx, idx][window_slices]
831838
gaussian = self._evaluate_gaussian(coord_shifts, sigma)
832-
torch.maximum(region, gaussian, out=region)
839+
updated = torch.maximum(region, gaussian)
840+
# write back
841+
region.copy_(updated)
833842
if self.normalize:
834-
max_val = heatmap[b_idx, idx].max()
835-
if max_val.item() > 0:
836-
heatmap[b_idx, idx] /= max_val
843+
peak = updated.max()
844+
if peak.item() > 0:
845+
heatmap[b_idx, idx] /= peak
837846

838847
if not is_batched:
839848
heatmap = heatmap.squeeze(0)
@@ -851,7 +860,9 @@ def _resolve_spatial_shape(self, call_shape: Sequence[int] | None, spatial_dims:
851860
if len(shape_tuple) == 1:
852861
shape_tuple = shape_tuple * spatial_dims # type: ignore
853862
else:
854-
raise ValueError("spatial_shape length must match spatial dimension of the landmarks.")
863+
raise ValueError(
864+
"spatial_shape length must match the landmarks' spatial dims (or pass a single int to broadcast)."
865+
)
855866
return tuple(int(s) for s in shape_tuple)
856867

857868
def _resolve_sigma(self, spatial_dims: int) -> tuple[float, ...]:
@@ -879,7 +890,7 @@ def _make_window(
879890
if start >= stop:
880891
return None, ()
881892
slices.append(slice(start, stop))
882-
coord_shifts.append(torch.arange(start, stop, device=device, dtype=self.torch_dtype) - float(c))
893+
coord_shifts.append(torch.arange(start, stop, device=device, dtype=torch.float32) - float(c))
883894
return tuple(slices), tuple(coord_shifts)
884895

885896
def _evaluate_gaussian(self, coord_shifts: tuple[torch.Tensor, ...], sigma: tuple[float, ...]) -> torch.Tensor:
@@ -897,13 +908,15 @@ def _evaluate_gaussian(self, coord_shifts: tuple[torch.Tensor, ...], sigma: tupl
897908
shape = tuple(len(axis) for axis in coord_shifts)
898909
if 0 in shape:
899910
return torch.zeros(shape, dtype=self.torch_dtype, device=device)
900-
exponent = torch.zeros(shape, dtype=self.torch_dtype, device=device)
911+
exponent = torch.zeros(shape, dtype=torch.float32, device=device)
901912
for dim, (shift, sig) in enumerate(zip(coord_shifts, sigma)):
902-
scaled = (shift / float(sig)) ** 2
913+
shift32 = shift.to(torch.float32)
914+
scaled = (shift32 / float(sig)) ** 2
903915
reshape_shape = [1] * len(coord_shifts)
904916
reshape_shape[dim] = shift.numel()
905917
exponent += scaled.reshape(reshape_shape)
906-
return torch.exp(-0.5 * exponent)
918+
gauss = torch.exp(-0.5 * exponent)
919+
return gauss.to(dtype=self.torch_dtype)
907920

908921

909922
class ProbNMS(Transform):

monai/transforms/post/dictionary.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,13 @@ class GenerateHeatmapd(MapTransform):
517517
"""
518518
Dictionary-based wrapper of :py:class:`monai.transforms.GenerateHeatmap`.
519519
Converts landmark coordinates into gaussian heatmaps and optionally copies metadata from a reference image.
520+
521+
Notes:
522+
- Default heatmap_keys are generated as "{key}_heatmap" for each input key
523+
- Shape inference precedence: static spatial_shape > ref_image
524+
- Output shapes:
525+
- Non-batched points (N, D): (N, H, W[, D])
526+
- Batched points (B, N, D): (B, N, H, W[, D])
520527
"""
521528

522529
backend = GenerateHeatmap.backend
@@ -538,7 +545,7 @@ def __init__(
538545
spatial_shape: Sequence[int] | Sequence[Sequence[int]] | None = None,
539546
truncated: float = 4.0,
540547
normalize: bool = True,
541-
dtype: np.dtype | type = np.float32,
548+
dtype: np.dtype | torch.dtype | type = np.float32,
542549
allow_missing_keys: bool = False,
543550
) -> None:
544551
super().__init__(keys, allow_missing_keys)
@@ -567,6 +574,7 @@ def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:
567574
)
568575
# Copy metadata if reference is MetaTensor
569576
if isinstance(reference, MetaTensor) and isinstance(heatmap, MetaTensor):
577+
heatmap.affine = reference.affine
570578
self._update_spatial_metadata(heatmap, reference)
571579
d[out_key] = heatmap
572580
return d
@@ -640,18 +648,8 @@ def _shape_from_reference(self, reference: Any, spatial_dims: int) -> tuple[int,
640648

641649
def _update_spatial_metadata(self, heatmap: MetaTensor, reference: MetaTensor) -> None:
642650
"""Update spatial metadata of heatmap based on its dimensions."""
643-
# Determine if batched based on reference's batch dimension
644-
ref_spatial_shape = reference.meta.get("spatial_shape", [])
645-
ref_is_batched = len(reference.shape) > len(ref_spatial_shape) + 1
646-
647-
if heatmap.ndim == 5: # 3D batched: (B, C, H, W, D)
648-
spatial_shape = heatmap.shape[2:]
649-
elif heatmap.ndim == 4: # 2D batched (B, C, H, W) or 3D non-batched (C, H, W, D)
650-
# Disambiguate: 2D batched vs 3D non-batched
651-
spatial_shape = heatmap.shape[2:] if ref_is_batched else heatmap.shape[1:]
652-
else: # 2D non-batched: (C, H, W)
653-
spatial_shape = heatmap.shape[1:]
654-
651+
# trailing dims after channel are spatial regardless of batch presence
652+
spatial_shape = heatmap.shape[-(reference.ndim - 1) :]
655653
heatmap.meta["spatial_shape"] = tuple(int(v) for v in spatial_shape)
656654

657655

tests/transforms/test_generate_heatmapd.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
f"dict_static_shape_{len(shape)}d",
5454
np.array([[1.0] * len(shape)], dtype=np.float32),
5555
{"spatial_shape": shape},
56-
(1,) + shape,
56+
(1, *shape),
5757
np.float32,
5858
]
5959
)
@@ -165,7 +165,8 @@ def test_dict_batched_with_ref(self, _, points, params, expected_shape, _expecte
165165
assert_allclose(heatmap.affine, image.affine, type_test=False)
166166

167167
# Check max values
168-
max_vals = heatmap.max(dim=2)[0].max(dim=2)[0].max(dim=2)[0]
168+
hm2 = heatmap.reshape(heatmap.shape[0], heatmap.shape[1], -1)
169+
max_vals = hm2.max(dim=2)[0]
169170
np.testing.assert_allclose(
170171
max_vals.cpu().numpy(), np.ones((expected_shape[0], expected_shape[1])), rtol=1e-5, atol=1e-5
171172
)

0 commit comments

Comments
 (0)