Skip to content

Add errors option to curvefit #7891

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jun 16, 2023
Merged
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ New Features

- Added support for multidimensional initial guess and bounds in :py:meth:`DataArray.curvefit` (:issue:`7768`, :pull:`7821`).
By `András Gunyhó <https://github.com/mgunyho>`_.
- Add an ``errors`` option to :py:meth:`Dataset.curve_fit` that allows
returning NaN for the parameters and covariances of failed fits, rather than
failing the whole series of fits (:issue:`6317`, :pull:`7891`).
By `Dominik Stańczak <https://github.com/StanczakDominik>`_ and `András Gunyhó <https://github.com/mgunyho>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
6 changes: 6 additions & 0 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -6162,6 +6162,7 @@ def curvefit(
p0: dict[str, float | DataArray] | None = None,
bounds: dict[str, tuple[float | DataArray, float | DataArray]] | None = None,
param_names: Sequence[str] | None = None,
errors: ErrorOptions = "raise",
kwargs: dict[str, Any] | None = None,
) -> Dataset:
"""
Expand Down Expand Up @@ -6206,6 +6207,10 @@ def curvefit(
this will be automatically determined by arguments of `func`. `param_names`
should be manually supplied when fitting a function that takes a variable
number of parameters.
errors : {"raise", "ignore"}, default: "raise"
If 'raise', any errors from the `scipy.optimize_curve_fit` optimization will
raise an exception. If 'ignore', the coefficients and covariances for the
coordinates where the fitting failed will be NaN.
**kwargs : optional
Additional keyword arguments to passed to scipy curve_fit.

Expand Down Expand Up @@ -6312,6 +6317,7 @@ def curvefit(
p0=p0,
bounds=bounds,
param_names=param_names,
errors=errors,
kwargs=kwargs,
)

Expand Down
18 changes: 17 additions & 1 deletion xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8631,6 +8631,7 @@ def curvefit(
p0: dict[str, float | DataArray] | None = None,
bounds: dict[str, tuple[float | DataArray, float | DataArray]] | None = None,
param_names: Sequence[str] | None = None,
errors: ErrorOptions = "raise",
kwargs: dict[str, Any] | None = None,
) -> T_Dataset:
"""
Expand Down Expand Up @@ -8675,6 +8676,10 @@ def curvefit(
this will be automatically determined by arguments of `func`. `param_names`
should be manually supplied when fitting a function that takes a variable
number of parameters.
errors : {"raise", "ignore"}, default: "raise"
If 'raise', any errors from the `scipy.optimize_curve_fit` optimization will
raise an exception. If 'ignore', the coefficients and covariances for the
coordinates where the fitting failed will be NaN.
**kwargs : optional
Additional keyword arguments to passed to scipy curve_fit.

Expand Down Expand Up @@ -8757,6 +8762,9 @@ def curvefit(
f"dimensions {preserved_dims}."
)

if errors not in ["raise", "ignore"]:
raise ValueError('errors must be either "raise" or "ignore"')

# Broadcast all coords with each other
coords_ = broadcast(*coords_)
coords_ = [
Expand Down Expand Up @@ -8793,7 +8801,15 @@ def _wrapper(Y, *args, **kwargs):
pcov = np.full([n_params, n_params], np.nan)
return popt, pcov
x = np.squeeze(x)
popt, pcov = curve_fit(func, x, y, p0=p0_, bounds=(lb, ub), **kwargs)

try:
popt, pcov = curve_fit(func, x, y, p0=p0_, bounds=(lb, ub), **kwargs)
except RuntimeError:
if errors == "raise":
raise
popt = np.full([n_params], np.nan)
pcov = np.full([n_params, n_params], np.nan)

return popt, pcov

result = type(self)()
Expand Down
42 changes: 42 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4571,6 +4571,48 @@ def sine(t, a, f, p):
bounds={"a": (0, DataArray([1], coords={"foo": [1]}))},
)

@requires_scipy
@pytest.mark.parametrize("use_dask", [True, False])
def test_curvefit_ignore_errors(self, use_dask: bool) -> None:
if use_dask and not has_dask:
pytest.skip("requires dask")

# nonsense function to make the optimization fail
def line(x, a, b):
if a > 10:
return 0
return a * x + b

da = DataArray(
[[1, 3, 5], [0, 20, 40]],
coords={"i": [1, 2], "x": [0.0, 1.0, 2.0]},
)

if use_dask:
da = da.chunk({"i": 1})

expected = DataArray(
[[2, 1], [np.nan, np.nan]], coords={"i": [1, 2], "param": ["a", "b"]}
)

with pytest.raises(RuntimeError, match="calls to function has reached maxfev"):
da.curvefit(
coords="x",
func=line,
# limit maximum number of calls so the optimization fails
kwargs=dict(maxfev=5),
).compute() # have to compute to raise the error

fit = da.curvefit(
coords="x",
func=line,
errors="ignore",
# limit maximum number of calls so the optimization fails
kwargs=dict(maxfev=5),
).compute()

assert_allclose(fit.curvefit_coefficients, expected)


class TestReduce:
@pytest.fixture(autouse=True)
Expand Down