diff --git a/CHANGELOG.md b/CHANGELOG.md index 04f2c4aa..b7f86600 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,7 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed --- +- Rainbow multi-band scaler didn't work with list inputs https://github.com/light-curve/light-curve-python/issues/492 https://github.com/light-curve/light-curve-python/pull/493 ### Security diff --git a/light-curve/light_curve/light_curve_py/features/rainbow/_scaler.py b/light-curve/light_curve/light_curve_py/features/rainbow/_scaler.py index 14b34a1a..8db11ce7 100644 --- a/light-curve/light_curve/light_curve_py/features/rainbow/_scaler.py +++ b/light-curve/light_curve/light_curve_py/features/rainbow/_scaler.py @@ -23,6 +23,11 @@ class Scaler: Either a single value or an array of the same shape as the input array """ + def __eq__(self, other): + if not isinstance(other, Scaler): + return False + return np.array_equal(self.shift, other.shift) and np.array_equal(self.scale, other.scale) + @classmethod def from_time(cls, t) -> "Scaler": """Create a Scaler from a time array @@ -55,6 +60,11 @@ class MultiBandScaler(Scaler): per_band_shift: Dict[str, float] """Shift to apply to each band""" + def __eq__(self, other): + if not isinstance(other, MultiBandScaler): + return False + return super().__eq__(other) and self.per_band_shift == other.per_band_shift + @classmethod def from_flux(cls, flux, band, *, with_baseline: bool) -> "MultiBandScaler": """Create a Scaler from a flux array. @@ -62,6 +72,9 @@ def from_flux(cls, flux, band, *, with_baseline: bool) -> "MultiBandScaler": It uses standard deviation for the scale. For the shift, it is either zero (`with_baseline=False`) or the mean of each band otherwise. """ + flux = np.asarray(flux) + band = np.asarray(band) + uniq_bands = np.unique(band) per_band_shift = dict.fromkeys(uniq_bands, 0.0) shift_array = np.zeros(len(flux)) diff --git a/light-curve/tests/light_curve_py/features/test_rainbow.py b/light-curve/tests/light_curve_py/features/test_rainbow.py index deadddc8..0017280b 100644 --- a/light-curve/tests/light_curve_py/features/test_rainbow.py +++ b/light-curve/tests/light_curve_py/features/test_rainbow.py @@ -1,6 +1,7 @@ import numpy as np from light_curve.light_curve_py import RainbowFit +from light_curve.light_curve_py.features.rainbow._scaler import MultiBandScaler def test_noisy_with_baseline(): @@ -113,3 +114,14 @@ def test_noisy_all_functions_combination(): # plt.show() np.testing.assert_allclose(actual[:-1], expected[:-1], rtol=0.1) + + +def test_scaler_from_flux_list_input(): + "https://github.com/light-curve/light-curve-python/issues/492" + # Was failing + scaler1 = MultiBandScaler.from_flux( + flux=[1.0, 2.0, 3.0, 4.0], band=np.array(["g", "r", "g", "r"]), with_baseline=True + ) + # Was not failing, but was wrong + scaler2 = MultiBandScaler.from_flux(flux=[1.0, 2.0, 3.0, 4.0], band=["g", "r", "g", "r"], with_baseline=True) + assert scaler1 == scaler2