Skip to content

Add errors option to curvefit #7891

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jun 16, 2023
Merged
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ New Features

- Added support for multidimensional initial guess and bounds in :py:meth:`DataArray.curvefit` (:issue:`7768`, :pull:`7821`).
By `András Gunyhó <https://github.com/mgunyho>`_.
- Add a ``allow_failures`` flag to :py:meth:`Dataset.curve_fit` that allows
returning NaN for the parameters and covariances of failed fits, rather than
failing the whole series of fits (:issue:`6317`, :pull:`7891`).
By `Dominik Stańczak <https://github.com/StanczakDominik>`_ and `András Gunyhó <https://github.com/mgunyho>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
7 changes: 7 additions & 0 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -6162,6 +6162,7 @@ def curvefit(
p0: dict[str, float | DataArray] | None = None,
bounds: dict[str, tuple[float | DataArray, float | DataArray]] | None = None,
param_names: Sequence[str] | None = None,
allow_failures: bool = False,
kwargs: dict[str, Any] | None = None,
) -> Dataset:
"""
Expand Down Expand Up @@ -6206,6 +6207,11 @@ def curvefit(
this will be automatically determined by arguments of `func`. `param_names`
should be manually supplied when fitting a function that takes a variable
number of parameters.
allow_failures: bool, default: False
If True and the underlying `scipy.optimize_curve_fit` optimization fails for
any of the fits, return NaN in coefficients and covariances for those
coordinates. Helpful when fitting multiple curves and some of the data just
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove the Helpful... Part. It just adds clutter for little benefit.

If you want to add such tips do it above the argument list.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in d3decbb. I was also wondering if the name could be something different, like allow_errors for example. Are there any other examples of similar args in xarray, so we could be consistent with those?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point.

In other places we use errors = "raise" | "ignore" | "warn" (the warn part is optional)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I'll change it to that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in efce0b3. I was considering adding the option to also warn, but looks like scipy has its own internal mechanism for producing warnings, see here: https://github.com/scipy/scipy/blob/v1.10.1/scipy/optimize/_minpack_py.py#L491. I'm not sure how to control it and didn't want to get into that now. Maybe a warning is now always issued? Although I didn't see any warnings in pytest on my setup at least.

doesn't fit your model.
**kwargs : optional
Additional keyword arguments to passed to scipy curve_fit.

Expand Down Expand Up @@ -6310,6 +6316,7 @@ def curvefit(
p0=p0,
bounds=bounds,
param_names=param_names,
allow_failures=allow_failures,
kwargs=kwargs,
)

Expand Down
16 changes: 15 additions & 1 deletion xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8631,6 +8631,7 @@ def curvefit(
p0: dict[str, float | DataArray] | None = None,
bounds: dict[str, tuple[float | DataArray, float | DataArray]] | None = None,
param_names: Sequence[str] | None = None,
allow_failures: bool = False,
kwargs: dict[str, Any] | None = None,
) -> T_Dataset:
"""
Expand Down Expand Up @@ -8675,6 +8676,11 @@ def curvefit(
this will be automatically determined by arguments of `func`. `param_names`
should be manually supplied when fitting a function that takes a variable
number of parameters.
allow_failures: bool, default: False
If True and the underlying `scipy.optimize_curve_fit` optimization fails for
any of the fits, return NaN in coefficients and covariances for those
coordinates. Helpful when fitting multiple curves and some of the data just
doesn't fit your model.
**kwargs : optional
Additional keyword arguments to passed to scipy curve_fit.

Expand Down Expand Up @@ -8793,7 +8799,15 @@ def _wrapper(Y, *args, **kwargs):
pcov = np.full([n_params, n_params], np.nan)
return popt, pcov
x = np.squeeze(x)
popt, pcov = curve_fit(func, x, y, p0=p0_, bounds=(lb, ub), **kwargs)

try:
popt, pcov = curve_fit(func, x, y, p0=p0_, bounds=(lb, ub), **kwargs)
except RuntimeError:
if not allow_failures:
raise
popt = np.full([n_params], np.nan)
pcov = np.full([n_params, n_params], np.nan)

return popt, pcov

result = type(self)()
Expand Down
42 changes: 42 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4571,6 +4571,48 @@ def sine(t, a, f, p):
bounds={"a": (0, DataArray([1], coords={"foo": [1]}))},
)

@requires_scipy
@pytest.mark.parametrize("use_dask", [True, False])
def test_curvefit_allow_failures(self, use_dask: bool) -> None:
if use_dask and not has_dask:
pytest.skip("requires dask")

# nonsense function to make the optimization fail
def line(x, a, b):
if a > 10:
return 0
return a * x + b

da = DataArray(
[[1, 3, 5], [0, 20, 40]],
coords={"i": [1, 2], "x": [0.0, 1.0, 2.0]},
)

if use_dask:
da = da.chunk({"i": 1})

expected = DataArray(
[[2, 1], [np.nan, np.nan]], coords={"i": [1, 2], "param": ["a", "b"]}
)

with pytest.raises(RuntimeError):
da.curvefit(
coords="x",
func=line,
# limit maximum number of calls so the optimization fails
kwargs=dict(maxfev=5),
)

fit = da.curvefit(
coords="x",
func=line,
allow_failures=True,
# limit maximum number of calls so the optimization fails
kwargs=dict(maxfev=5),
)

assert_allclose(fit.curvefit_coefficients, expected)


class TestReduce:
@pytest.fixture(autouse=True)
Expand Down