Skip to content

Commit 99f9559

Browse files
mgunyhoStanczakDominikdcherian
authored
Add errors option to curvefit (#7891)
* Add allow_failures flag to Dataset.curve_fit * Reword docstring * Add allow_failures flag also to DataArray * Add unit test for curvefit with allow_failures * Update whats-new Co-authored-by: Dominik Stańczak <[email protected]> * Add PR to whats-new * Update docstring * Rename allow_failures to errors to be consistent with other methods * Compute array so test passes also with dask * Check error message * Update whats-new --------- Co-authored-by: Dominik Stańczak <[email protected]> Co-authored-by: Deepak Cherian <[email protected]>
1 parent 71defdd commit 99f9559

File tree

4 files changed

+69
-1
lines changed

4 files changed

+69
-1
lines changed

doc/whats-new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ New Features
2525

2626
- Added support for multidimensional initial guess and bounds in :py:meth:`DataArray.curvefit` (:issue:`7768`, :pull:`7821`).
2727
By `András Gunyhó <https://github.com/mgunyho>`_.
28+
- Add an ``errors`` option to :py:meth:`Dataset.curve_fit` that allows
29+
returning NaN for the parameters and covariances of failed fits, rather than
30+
failing the whole series of fits (:issue:`6317`, :pull:`7891`).
31+
By `Dominik Stańczak <https://github.com/StanczakDominik>`_ and `András Gunyhó <https://github.com/mgunyho>`_.
2832

2933
Breaking changes
3034
~~~~~~~~~~~~~~~~

xarray/core/dataarray.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6162,6 +6162,7 @@ def curvefit(
61626162
p0: dict[str, float | DataArray] | None = None,
61636163
bounds: dict[str, tuple[float | DataArray, float | DataArray]] | None = None,
61646164
param_names: Sequence[str] | None = None,
6165+
errors: ErrorOptions = "raise",
61656166
kwargs: dict[str, Any] | None = None,
61666167
) -> Dataset:
61676168
"""
@@ -6206,6 +6207,10 @@ def curvefit(
62066207
this will be automatically determined by arguments of `func`. `param_names`
62076208
should be manually supplied when fitting a function that takes a variable
62086209
number of parameters.
6210+
errors : {"raise", "ignore"}, default: "raise"
6211+
If 'raise', any errors from the `scipy.optimize_curve_fit` optimization will
6212+
raise an exception. If 'ignore', the coefficients and covariances for the
6213+
coordinates where the fitting failed will be NaN.
62096214
**kwargs : optional
62106215
Additional keyword arguments to passed to scipy curve_fit.
62116216
@@ -6312,6 +6317,7 @@ def curvefit(
63126317
p0=p0,
63136318
bounds=bounds,
63146319
param_names=param_names,
6320+
errors=errors,
63156321
kwargs=kwargs,
63166322
)
63176323

xarray/core/dataset.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8631,6 +8631,7 @@ def curvefit(
86318631
p0: dict[str, float | DataArray] | None = None,
86328632
bounds: dict[str, tuple[float | DataArray, float | DataArray]] | None = None,
86338633
param_names: Sequence[str] | None = None,
8634+
errors: ErrorOptions = "raise",
86348635
kwargs: dict[str, Any] | None = None,
86358636
) -> T_Dataset:
86368637
"""
@@ -8675,6 +8676,10 @@ def curvefit(
86758676
this will be automatically determined by arguments of `func`. `param_names`
86768677
should be manually supplied when fitting a function that takes a variable
86778678
number of parameters.
8679+
errors : {"raise", "ignore"}, default: "raise"
8680+
If 'raise', any errors from the `scipy.optimize_curve_fit` optimization will
8681+
raise an exception. If 'ignore', the coefficients and covariances for the
8682+
coordinates where the fitting failed will be NaN.
86788683
**kwargs : optional
86798684
Additional keyword arguments to passed to scipy curve_fit.
86808685
@@ -8757,6 +8762,9 @@ def curvefit(
87578762
f"dimensions {preserved_dims}."
87588763
)
87598764

8765+
if errors not in ["raise", "ignore"]:
8766+
raise ValueError('errors must be either "raise" or "ignore"')
8767+
87608768
# Broadcast all coords with each other
87618769
coords_ = broadcast(*coords_)
87628770
coords_ = [
@@ -8793,7 +8801,15 @@ def _wrapper(Y, *args, **kwargs):
87938801
pcov = np.full([n_params, n_params], np.nan)
87948802
return popt, pcov
87958803
x = np.squeeze(x)
8796-
popt, pcov = curve_fit(func, x, y, p0=p0_, bounds=(lb, ub), **kwargs)
8804+
8805+
try:
8806+
popt, pcov = curve_fit(func, x, y, p0=p0_, bounds=(lb, ub), **kwargs)
8807+
except RuntimeError:
8808+
if errors == "raise":
8809+
raise
8810+
popt = np.full([n_params], np.nan)
8811+
pcov = np.full([n_params, n_params], np.nan)
8812+
87978813
return popt, pcov
87988814

87998815
result = type(self)()

xarray/tests/test_dataarray.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4584,6 +4584,48 @@ def sine(t, a, f, p):
45844584
bounds={"a": (0, DataArray([1], coords={"foo": [1]}))},
45854585
)
45864586

4587+
@requires_scipy
4588+
@pytest.mark.parametrize("use_dask", [True, False])
4589+
def test_curvefit_ignore_errors(self, use_dask: bool) -> None:
4590+
if use_dask and not has_dask:
4591+
pytest.skip("requires dask")
4592+
4593+
# nonsense function to make the optimization fail
4594+
def line(x, a, b):
4595+
if a > 10:
4596+
return 0
4597+
return a * x + b
4598+
4599+
da = DataArray(
4600+
[[1, 3, 5], [0, 20, 40]],
4601+
coords={"i": [1, 2], "x": [0.0, 1.0, 2.0]},
4602+
)
4603+
4604+
if use_dask:
4605+
da = da.chunk({"i": 1})
4606+
4607+
expected = DataArray(
4608+
[[2, 1], [np.nan, np.nan]], coords={"i": [1, 2], "param": ["a", "b"]}
4609+
)
4610+
4611+
with pytest.raises(RuntimeError, match="calls to function has reached maxfev"):
4612+
da.curvefit(
4613+
coords="x",
4614+
func=line,
4615+
# limit maximum number of calls so the optimization fails
4616+
kwargs=dict(maxfev=5),
4617+
).compute() # have to compute to raise the error
4618+
4619+
fit = da.curvefit(
4620+
coords="x",
4621+
func=line,
4622+
errors="ignore",
4623+
# limit maximum number of calls so the optimization fails
4624+
kwargs=dict(maxfev=5),
4625+
).compute()
4626+
4627+
assert_allclose(fit.curvefit_coefficients, expected)
4628+
45874629

45884630
class TestReduce:
45894631
@pytest.fixture(autouse=True)

0 commit comments

Comments
 (0)