Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

--
- Rainbow multi-band scaler didn't work with list inputs https://github.com/light-curve/light-curve-python/issues/492 https://github.com/light-curve/light-curve-python/pull/493

### Security

Expand Down
13 changes: 13 additions & 0 deletions light-curve/light_curve/light_curve_py/features/rainbow/_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ class Scaler:
Either a single value or an array of the same shape as the input array
"""

def __eq__(self, other):
if not isinstance(other, Scaler):
return False
return np.array_equal(self.shift, other.shift) and np.array_equal(self.scale, other.scale)

@classmethod
def from_time(cls, t) -> "Scaler":
"""Create a Scaler from a time array
Expand Down Expand Up @@ -55,13 +60,21 @@ class MultiBandScaler(Scaler):
per_band_shift: Dict[str, float]
"""Shift to apply to each band"""

def __eq__(self, other):
if not isinstance(other, MultiBandScaler):
return False
return super().__eq__(other) and self.per_band_shift == other.per_band_shift

@classmethod
def from_flux(cls, flux, band, *, with_baseline: bool) -> "MultiBandScaler":
"""Create a Scaler from a flux array.

It uses standard deviation for the scale. For the shift, it is either
zero (`with_baseline=False`) or the mean of each band otherwise.
"""
flux = np.asarray(flux)
band = np.asarray(band)

uniq_bands = np.unique(band)
per_band_shift = dict.fromkeys(uniq_bands, 0.0)
shift_array = np.zeros(len(flux))
Expand Down
12 changes: 12 additions & 0 deletions light-curve/tests/light_curve_py/features/test_rainbow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np

from light_curve.light_curve_py import RainbowFit
from light_curve.light_curve_py.features.rainbow._scaler import MultiBandScaler


def test_noisy_with_baseline():
Expand Down Expand Up @@ -113,3 +114,14 @@ def test_noisy_all_functions_combination():
# plt.show()

np.testing.assert_allclose(actual[:-1], expected[:-1], rtol=0.1)


def test_scaler_from_flux_list_input():
"https://github.com/light-curve/light-curve-python/issues/492"
# Was failing
scaler1 = MultiBandScaler.from_flux(
flux=[1.0, 2.0, 3.0, 4.0], band=np.array(["g", "r", "g", "r"]), with_baseline=True
)
# Was not failing, but was wrong
scaler2 = MultiBandScaler.from_flux(flux=[1.0, 2.0, 3.0, 4.0], band=["g", "r", "g", "r"], with_baseline=True)
assert scaler1 == scaler2
Loading