Skip to content

Commit 168dd56

Browse files
authored
Merge pull request #7 from lsst/tickets/DM-41840
tickets/DM-41840: Implement improved detection algorithms
2 parents ae00866 + a843edc commit 168dd56

18 files changed

+600
-192
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[build-system]
22
requires = [
3-
"setuptools<65",
3+
"setuptools",
44
"lsst-versions >= 1.3.0",
55
"wheel",
66
"pybind11 >= 2.5.0",

python/lsst/scarlet/lite/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
except ImportError:
1111
pass
1212

13-
from . import initialization, io, measure, models, operators, utils
13+
from . import initialization, io, measure, models, operators, utils, wavelet
1414
from .fft import *
1515
from .image import *
1616
from .observation import *

python/lsst/scarlet/lite/bbox.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def from_data(x: np.ndarray, threshold: float = 0) -> Box:
228228
nonzero = np.where(sel)
229229
bounds = []
230230
for dim in range(len(x.shape)):
231-
bounds.append((nonzero[dim].min(), nonzero[dim].max() + 1))
231+
bounds.append((int(nonzero[dim].min()), int(nonzero[dim].max() + 1)))
232232
else:
233233
bounds = [(0, 0)] * len(x.shape)
234234
return Box.from_bounds(*bounds)

python/lsst/scarlet/lite/blend.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,15 +113,25 @@ def get_model(self, convolve: bool = False, use_flux: bool = False) -> Image:
113113
return self.observation.convolve(model)
114114
return model
115115

116-
def _grad_log_likelihood(self) -> Image:
117-
"""Gradient of the likelihood wrt the unconvolved model"""
116+
def _grad_log_likelihood(self) -> tuple[Image, np.ndarray]:
117+
"""Gradient of the likelihood wrt the unconvolved model
118+
119+
Returns
120+
-------
121+
result:
122+
The gradient of the likelihood wrt the model
123+
model_data:
124+
The convol model data used to calculate the gradient.
125+
This can be useful for debugging but is not used in
126+
production.
127+
"""
118128
model = self.get_model(convolve=True)
119129
# Update the loss
120130
self.loss.append(self.observation.log_likelihood(model))
121131
# Calculate the gradient wrt the model d(logL)/d(model)
122132
result = self.observation.weights * (model - self.observation.images)
123133
result = self.observation.convolve(result, grad=True)
124-
return result
134+
return result, model.data
125135

126136
@property
127137
def log_likelihood(self) -> float:
@@ -244,7 +254,7 @@ def fit(
244254
# Update each component given the current gradient
245255
for component in self.components:
246256
overlap = component.bbox & self.bbox
247-
component.update(self.it, grad_log_likelihood[overlap].data)
257+
component.update(self.it, grad_log_likelihood[0][overlap].data)
248258
# Check to see if any components need to be resized
249259
if do_resize:
250260
component.resize(self.bbox)

python/lsst/scarlet/lite/detect.py

Lines changed: 119 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,20 @@
2222
from __future__ import annotations
2323

2424
import logging
25-
from typing import Sequence, cast
25+
from typing import Sequence
2626

2727
import numpy as np
28-
from lsst.scarlet.lite.detect_pybind11 import Footprint # type: ignore
28+
from lsst.scarlet.lite.detect_pybind11 import Footprint, get_footprints # type: ignore
2929

30-
from .bbox import Box
30+
from .bbox import Box, overlapped_slices
3131
from .image import Image
3232
from .utils import continue_class
33-
from .wavelet import get_multiresolution_support, starlet_transform
33+
from .wavelet import (
34+
get_multiresolution_support,
35+
get_starlet_scales,
36+
multiband_starlet_reconstruction,
37+
starlet_transform,
38+
)
3439

3540
logger = logging.getLogger("scarlet.detect")
3641

@@ -111,30 +116,34 @@ def union(self, other: Footprint) -> Image | None:
111116
return footprint1 | footprint2
112117

113118

114-
def footprints_to_image(footprints: Sequence[Footprint], shape: tuple[int, int]) -> Image:
119+
def footprints_to_image(footprints: Sequence[Footprint], bbox: Box) -> Image:
115120
"""Convert a set of scarlet footprints to a pixelized image.
116121
117122
Parameters
118123
----------
119124
footprints:
120125
The footprints to convert into an image.
121-
shape:
122-
The shape of the image that is created from the footprints.
126+
box:
127+
The full box of the image that will contain the footprints.
123128
124129
Returns
125130
-------
126131
result:
127132
The image created from the footprints.
128133
"""
129-
result = Image.from_box(Box(shape), dtype=int)
134+
result = Image.from_box(bbox, dtype=int)
130135
for k, footprint in enumerate(footprints):
131-
bbox = bounds_to_bbox(footprint.bounds)
132-
fp_image = Image(footprint.data, yx0=cast(tuple[int, int], bbox.origin))
133-
result = result + fp_image * (k + 1)
136+
slices = overlapped_slices(result.bbox, footprint.bbox)
137+
result.data[slices[0]] += footprint.data[slices[1]] * (k + 1)
134138
return result
135139

136140

137-
def get_wavelets(images: np.ndarray, variance: np.ndarray, scales: int | None = None) -> np.ndarray:
141+
def get_wavelets(
142+
images: np.ndarray,
143+
variance: np.ndarray,
144+
scales: int | None = None,
145+
generation: int = 2,
146+
) -> np.ndarray:
138147
"""Calculate wavelet coefficents given a set of images and their variances
139148
140149
Parameters
@@ -157,9 +166,10 @@ def get_wavelets(images: np.ndarray, variance: np.ndarray, scales: int | None =
157166
"""
158167
sigma = np.median(np.sqrt(variance), axis=(1, 2))
159168
# Create the wavelet coefficients for the significant pixels
160-
coeffs = []
169+
scales = get_starlet_scales(images[0].shape, scales)
170+
coeffs = np.empty((scales + 1,) + images.shape, dtype=images.dtype)
161171
for b, image in enumerate(images):
162-
_coeffs = starlet_transform(image, scales=scales)
172+
_coeffs = starlet_transform(image, scales=scales, generation=generation)
163173
support = get_multiresolution_support(
164174
image=image,
165175
starlets=_coeffs,
@@ -168,8 +178,8 @@ def get_wavelets(images: np.ndarray, variance: np.ndarray, scales: int | None =
168178
epsilon=1e-1,
169179
max_iter=20,
170180
)
171-
coeffs.append((support * _coeffs).astype(images.dtype))
172-
return np.array(coeffs)
181+
coeffs[:, b] = (support.support * _coeffs).astype(images.dtype)
182+
return coeffs
173183

174184

175185
def get_detect_wavelets(images: np.ndarray, variance: np.ndarray, scales: int = 3) -> np.ndarray:
@@ -206,4 +216,96 @@ def get_detect_wavelets(images: np.ndarray, variance: np.ndarray, scales: int =
206216
epsilon=1e-1,
207217
max_iter=20,
208218
)
209-
return (support * _coeffs).astype(images.dtype)
219+
return (support.support * _coeffs).astype(images.dtype)
220+
221+
222+
def detect_footprints(
223+
images: np.ndarray,
224+
variance: np.ndarray,
225+
scales: int = 2,
226+
generation: int = 2,
227+
origin: tuple[int, int] | None = None,
228+
min_separation: float = 4,
229+
min_area: int = 4,
230+
peak_thresh: float = 5,
231+
footprint_thresh: float = 5,
232+
find_peaks: bool = True,
233+
remove_high_freq: bool = True,
234+
min_pixel_detect: int = 1,
235+
) -> list[Footprint]:
236+
"""Detect footprints in an image
237+
238+
Parameters
239+
----------
240+
images:
241+
The array of images with shape `(bands, Ny, Nx)` for which to
242+
calculate wavelet coefficients.
243+
variance:
244+
An array of variances with the same shape as `images`.
245+
scales:
246+
The maximum number of wavelet scales to use.
247+
If `remove_high_freq` is `False`, then this argument is ignored.
248+
generation:
249+
The generation of the starlet transform to use.
250+
If `remove_high_freq` is `False`, then this argument is ignored.
251+
origin:
252+
The location (y, x) of the lower corner of the image.
253+
min_separation:
254+
The minimum separation between peaks in pixels.
255+
min_area:
256+
The minimum area of a footprint in pixels.
257+
peak_thresh:
258+
The threshold for peak detection.
259+
footprint_thresh:
260+
The threshold for footprint detection.
261+
find_peaks:
262+
If `True`, then detect peaks in the detection image,
263+
otherwise only the footprints are returned.
264+
remove_high_freq:
265+
If `True`, then remove high frequency wavelet coefficients
266+
before detecting peaks.
267+
min_pixel_detect:
268+
The minimum number of bands that must be above the
269+
detection threshold for a pixel to be included in a footprint.
270+
"""
271+
272+
if origin is None:
273+
origin = (0, 0)
274+
if remove_high_freq:
275+
# Build the wavelet coefficients
276+
wavelets = get_wavelets(
277+
images,
278+
variance,
279+
scales=scales,
280+
generation=generation,
281+
)
282+
# Remove the high frequency wavelets.
283+
# This has the effect of preventing high frequency noise
284+
# from interfering with the detection of peak positions.
285+
wavelets[0] = 0
286+
# Reconstruct the image from the remaining wavelet coefficients
287+
_images = multiband_starlet_reconstruction(
288+
wavelets,
289+
generation=generation,
290+
)
291+
else:
292+
_images = images
293+
# Build a SNR weighted detection image
294+
sigma = np.median(np.sqrt(variance), axis=(1, 2)) / 2
295+
detection = np.sum(_images / sigma[:, None, None], axis=0)
296+
if min_pixel_detect > 1:
297+
mask = np.sum(images > 0, axis=0) >= min_pixel_detect
298+
detection[~mask] = 0
299+
# Detect peaks on the detection image
300+
footprints = get_footprints(
301+
detection,
302+
min_separation,
303+
min_area,
304+
peak_thresh,
305+
footprint_thresh,
306+
find_peaks,
307+
origin[0],
308+
origin[1],
309+
)
310+
311+
return footprints

0 commit comments

Comments
 (0)