22
22
from __future__ import annotations
23
23
24
24
import logging
25
- from typing import Sequence , cast
25
+ from typing import Sequence
26
26
27
27
import numpy as np
28
- from lsst .scarlet .lite .detect_pybind11 import Footprint # type: ignore
28
+ from lsst .scarlet .lite .detect_pybind11 import Footprint , get_footprints # type: ignore
29
29
30
- from .bbox import Box
30
+ from .bbox import Box , overlapped_slices
31
31
from .image import Image
32
32
from .utils import continue_class
33
- from .wavelet import get_multiresolution_support , starlet_transform
33
+ from .wavelet import (
34
+ get_multiresolution_support ,
35
+ get_starlet_scales ,
36
+ multiband_starlet_reconstruction ,
37
+ starlet_transform ,
38
+ )
34
39
35
40
logger = logging .getLogger ("scarlet.detect" )
36
41
@@ -111,30 +116,34 @@ def union(self, other: Footprint) -> Image | None:
111
116
return footprint1 | footprint2
112
117
113
118
114
- def footprints_to_image (footprints : Sequence [Footprint ], shape : tuple [ int , int ] ) -> Image :
119
+ def footprints_to_image (footprints : Sequence [Footprint ], bbox : Box ) -> Image :
115
120
"""Convert a set of scarlet footprints to a pixelized image.
116
121
117
122
Parameters
118
123
----------
119
124
footprints:
120
125
The footprints to convert into an image.
121
- shape :
122
- The shape of the image that is created from the footprints.
126
+ box :
127
+ The full box of the image that will contain the footprints.
123
128
124
129
Returns
125
130
-------
126
131
result:
127
132
The image created from the footprints.
128
133
"""
129
- result = Image .from_box (Box ( shape ) , dtype = int )
134
+ result = Image .from_box (bbox , dtype = int )
130
135
for k , footprint in enumerate (footprints ):
131
- bbox = bounds_to_bbox (footprint .bounds )
132
- fp_image = Image (footprint .data , yx0 = cast (tuple [int , int ], bbox .origin ))
133
- result = result + fp_image * (k + 1 )
136
+ slices = overlapped_slices (result .bbox , footprint .bbox )
137
+ result .data [slices [0 ]] += footprint .data [slices [1 ]] * (k + 1 )
134
138
return result
135
139
136
140
137
- def get_wavelets (images : np .ndarray , variance : np .ndarray , scales : int | None = None ) -> np .ndarray :
141
+ def get_wavelets (
142
+ images : np .ndarray ,
143
+ variance : np .ndarray ,
144
+ scales : int | None = None ,
145
+ generation : int = 2 ,
146
+ ) -> np .ndarray :
138
147
"""Calculate wavelet coefficents given a set of images and their variances
139
148
140
149
Parameters
@@ -157,9 +166,10 @@ def get_wavelets(images: np.ndarray, variance: np.ndarray, scales: int | None =
157
166
"""
158
167
sigma = np .median (np .sqrt (variance ), axis = (1 , 2 ))
159
168
# Create the wavelet coefficients for the significant pixels
160
- coeffs = []
169
+ scales = get_starlet_scales (images [0 ].shape , scales )
170
+ coeffs = np .empty ((scales + 1 ,) + images .shape , dtype = images .dtype )
161
171
for b , image in enumerate (images ):
162
- _coeffs = starlet_transform (image , scales = scales )
172
+ _coeffs = starlet_transform (image , scales = scales , generation = generation )
163
173
support = get_multiresolution_support (
164
174
image = image ,
165
175
starlets = _coeffs ,
@@ -168,8 +178,8 @@ def get_wavelets(images: np.ndarray, variance: np.ndarray, scales: int | None =
168
178
epsilon = 1e-1 ,
169
179
max_iter = 20 ,
170
180
)
171
- coeffs . append (( support * _coeffs ).astype (images .dtype ) )
172
- return np . array ( coeffs )
181
+ coeffs [:, b ] = ( support . support * _coeffs ).astype (images .dtype )
182
+ return coeffs
173
183
174
184
175
185
def get_detect_wavelets (images : np .ndarray , variance : np .ndarray , scales : int = 3 ) -> np .ndarray :
@@ -206,4 +216,96 @@ def get_detect_wavelets(images: np.ndarray, variance: np.ndarray, scales: int =
206
216
epsilon = 1e-1 ,
207
217
max_iter = 20 ,
208
218
)
209
- return (support * _coeffs ).astype (images .dtype )
219
+ return (support .support * _coeffs ).astype (images .dtype )
220
+
221
+
222
+ def detect_footprints (
223
+ images : np .ndarray ,
224
+ variance : np .ndarray ,
225
+ scales : int = 2 ,
226
+ generation : int = 2 ,
227
+ origin : tuple [int , int ] | None = None ,
228
+ min_separation : float = 4 ,
229
+ min_area : int = 4 ,
230
+ peak_thresh : float = 5 ,
231
+ footprint_thresh : float = 5 ,
232
+ find_peaks : bool = True ,
233
+ remove_high_freq : bool = True ,
234
+ min_pixel_detect : int = 1 ,
235
+ ) -> list [Footprint ]:
236
+ """Detect footprints in an image
237
+
238
+ Parameters
239
+ ----------
240
+ images:
241
+ The array of images with shape `(bands, Ny, Nx)` for which to
242
+ calculate wavelet coefficients.
243
+ variance:
244
+ An array of variances with the same shape as `images`.
245
+ scales:
246
+ The maximum number of wavelet scales to use.
247
+ If `remove_high_freq` is `False`, then this argument is ignored.
248
+ generation:
249
+ The generation of the starlet transform to use.
250
+ If `remove_high_freq` is `False`, then this argument is ignored.
251
+ origin:
252
+ The location (y, x) of the lower corner of the image.
253
+ min_separation:
254
+ The minimum separation between peaks in pixels.
255
+ min_area:
256
+ The minimum area of a footprint in pixels.
257
+ peak_thresh:
258
+ The threshold for peak detection.
259
+ footprint_thresh:
260
+ The threshold for footprint detection.
261
+ find_peaks:
262
+ If `True`, then detect peaks in the detection image,
263
+ otherwise only the footprints are returned.
264
+ remove_high_freq:
265
+ If `True`, then remove high frequency wavelet coefficients
266
+ before detecting peaks.
267
+ min_pixel_detect:
268
+ The minimum number of bands that must be above the
269
+ detection threshold for a pixel to be included in a footprint.
270
+ """
271
+
272
+ if origin is None :
273
+ origin = (0 , 0 )
274
+ if remove_high_freq :
275
+ # Build the wavelet coefficients
276
+ wavelets = get_wavelets (
277
+ images ,
278
+ variance ,
279
+ scales = scales ,
280
+ generation = generation ,
281
+ )
282
+ # Remove the high frequency wavelets.
283
+ # This has the effect of preventing high frequency noise
284
+ # from interfering with the detection of peak positions.
285
+ wavelets [0 ] = 0
286
+ # Reconstruct the image from the remaining wavelet coefficients
287
+ _images = multiband_starlet_reconstruction (
288
+ wavelets ,
289
+ generation = generation ,
290
+ )
291
+ else :
292
+ _images = images
293
+ # Build a SNR weighted detection image
294
+ sigma = np .median (np .sqrt (variance ), axis = (1 , 2 )) / 2
295
+ detection = np .sum (_images / sigma [:, None , None ], axis = 0 )
296
+ if min_pixel_detect > 1 :
297
+ mask = np .sum (images > 0 , axis = 0 ) >= min_pixel_detect
298
+ detection [~ mask ] = 0
299
+ # Detect peaks on the detection image
300
+ footprints = get_footprints (
301
+ detection ,
302
+ min_separation ,
303
+ min_area ,
304
+ peak_thresh ,
305
+ footprint_thresh ,
306
+ find_peaks ,
307
+ origin [0 ],
308
+ origin [1 ],
309
+ )
310
+
311
+ return footprints
0 commit comments