@@ -753,7 +753,14 @@ def __call__(self, img: Sequence[NdarrayOrTensor] | NdarrayOrTensor) -> NdarrayO
753
753
754
754
class GenerateHeatmap (Transform ):
755
755
"""
756
- Generate per-landmark gaussian response maps for 2D or 3D coordinates.
756
+ Generate per-landmark Gaussian heatmaps for 2D or 3D coordinates.
757
+
758
+ Notes:
759
+ - Coordinates are interpreted in voxel units and expected in (Y, X) for 2D or (Z, Y, X) for 3D.
760
+ - Output shape:
761
+ - Non-batched points (N, S): (N, *spatial_shape), where S is the number of spatial dims (2 or 3)
762
+ - Batched points (B, N, S): (B, N, *spatial_shape)
763
+ - Each channel corresponds to one landmark.
757
764
758
765
Args:
759
766
sigma: gaussian standard deviation. A single value is broadcast across all spatial dimensions.
@@ -829,11 +836,13 @@ def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None
829
836
continue
830
837
region = heatmap [b_idx , idx ][window_slices ]
831
838
gaussian = self ._evaluate_gaussian (coord_shifts , sigma )
832
- torch .maximum (region , gaussian , out = region )
839
+ updated = torch .maximum (region , gaussian )
840
+ # write back
841
+ region .copy_ (updated )
833
842
if self .normalize :
834
- max_val = heatmap [ b_idx , idx ] .max ()
835
- if max_val .item () > 0 :
836
- heatmap [b_idx , idx ] /= max_val
843
+ peak = updated .max ()
844
+ if peak .item () > 0 :
845
+ heatmap [b_idx , idx ] /= peak
837
846
838
847
if not is_batched :
839
848
heatmap = heatmap .squeeze (0 )
@@ -851,7 +860,9 @@ def _resolve_spatial_shape(self, call_shape: Sequence[int] | None, spatial_dims:
851
860
if len (shape_tuple ) == 1 :
852
861
shape_tuple = shape_tuple * spatial_dims # type: ignore
853
862
else :
854
- raise ValueError ("spatial_shape length must match spatial dimension of the landmarks." )
863
+ raise ValueError (
864
+ "spatial_shape length must match the landmarks' spatial dims (or pass a single int to broadcast)."
865
+ )
855
866
return tuple (int (s ) for s in shape_tuple )
856
867
857
868
def _resolve_sigma (self , spatial_dims : int ) -> tuple [float , ...]:
@@ -879,7 +890,7 @@ def _make_window(
879
890
if start >= stop :
880
891
return None , ()
881
892
slices .append (slice (start , stop ))
882
- coord_shifts .append (torch .arange (start , stop , device = device , dtype = self . torch_dtype ) - float (c ))
893
+ coord_shifts .append (torch .arange (start , stop , device = device , dtype = torch . float32 ) - float (c ))
883
894
return tuple (slices ), tuple (coord_shifts )
884
895
885
896
def _evaluate_gaussian (self , coord_shifts : tuple [torch .Tensor , ...], sigma : tuple [float , ...]) -> torch .Tensor :
@@ -897,13 +908,15 @@ def _evaluate_gaussian(self, coord_shifts: tuple[torch.Tensor, ...], sigma: tupl
897
908
shape = tuple (len (axis ) for axis in coord_shifts )
898
909
if 0 in shape :
899
910
return torch .zeros (shape , dtype = self .torch_dtype , device = device )
900
- exponent = torch .zeros (shape , dtype = self . torch_dtype , device = device )
911
+ exponent = torch .zeros (shape , dtype = torch . float32 , device = device )
901
912
for dim , (shift , sig ) in enumerate (zip (coord_shifts , sigma )):
902
- scaled = (shift / float (sig )) ** 2
913
+ shift32 = shift .to (torch .float32 )
914
+ scaled = (shift32 / float (sig )) ** 2
903
915
reshape_shape = [1 ] * len (coord_shifts )
904
916
reshape_shape [dim ] = shift .numel ()
905
917
exponent += scaled .reshape (reshape_shape )
906
- return torch .exp (- 0.5 * exponent )
918
+ gauss = torch .exp (- 0.5 * exponent )
919
+ return gauss .to (dtype = self .torch_dtype )
907
920
908
921
909
922
class ProbNMS (Transform ):
0 commit comments