diff --git a/datatree/mapping.py b/datatree/mapping.py
index 6f8c65ae..bd41cdbd 100644
--- a/datatree/mapping.py
+++ b/datatree/mapping.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import functools
+import sys
 from itertools import repeat
 from textwrap import dedent
 from typing import TYPE_CHECKING, Callable, Tuple
@@ -202,10 +203,15 @@ def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
                     ],
                 )
             )
+            func_with_error_context = _handle_errors_with_path_context(
+                node_of_first_tree.path
+            )(func)
 
             # Now we can call func on the data in this particular set of corresponding nodes
             results = (
-                func(*node_args_as_datasets, **node_kwargs_as_datasets)
+                func_with_error_context(
+                    *node_args_as_datasets, **node_kwargs_as_datasets
+                )
                 if node_of_first_tree.has_data
                 else None
             )
@@ -251,6 +257,35 @@ def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
     return _map_over_subtree
 
 
+def _handle_errors_with_path_context(path):
+    """Wraps given function so that if it fails it also raises path to node on which it failed."""
+
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            try:
+                return func(*args, **kwargs)
+            except Exception as e:
+                # Add the context information to the error message
+                add_note(
+                    e, f"Raised whilst mapping function over node with path {path}"
+                )
+                raise
+
+        return wrapper
+
+    return decorator
+
+
+def add_note(err: BaseException, msg: str) -> None:
+    """Attach ``msg`` to ``err`` as a PEP 678 note, on any supported python version."""
+    # TODO: remove once python 3.10 can be dropped
+    if sys.version_info < (3, 11):
+        err.__notes__ = getattr(err, "__notes__", []) + [msg]
+    else:
+        err.add_note(msg)
+
+
 def _check_single_set_return_values(path_to_node, obj):
     """Check types returned from single evaluation of func, and return number of return values received from func."""
     if isinstance(obj, (Dataset, DataArray)):
diff --git a/datatree/tests/test_mapping.py b/datatree/tests/test_mapping.py
index 894b0480..79cb1378 100644
--- a/datatree/tests/test_mapping.py
+++ b/datatree/tests/test_mapping.py
@@ -264,6 +264,23 @@ def check_for_data(ds):
 
         dt.map_over_subtree(check_for_data)
 
+    def test_error_contains_path_of_offending_node(self, create_test_datatree):
+        dt = create_test_datatree()
+        dt["set1"]["bad_var"] = 0
+
+        def fail_on_specific_node(ds):
+            if "bad_var" in ds:
+                raise ValueError("Failed because 'bad_var' present in dataset")
+
+        with pytest.raises(ValueError) as excinfo:
+            dt.map_over_subtree(fail_on_specific_node)
+
+        # Check the PEP 678 notes directly, since whether pytest's ``match``
+        # searches exception notes depends on the pytest version.
+        expected = "Raised whilst mapping function over node with path /set1"
+        notes = getattr(excinfo.value, "__notes__", [])
+        assert any(expected in note for note in notes)
+
 
 class TestMutableOperations:
     def test_construct_using_type(self):
diff --git a/docs/source/whats-new.rst b/docs/source/whats-new.rst
index 607af509..d82353b8 100644
--- a/docs/source/whats-new.rst
+++ b/docs/source/whats-new.rst
@@ -23,6 +23,10 @@ v0.0.13 (unreleased)
 New Features
 ~~~~~~~~~~~~
 
+- Indicate which node caused the problem if an error is encountered while applying a user function using :py:func:`map_over_subtree`
+  (:issue:`190`, :pull:`264`). The note is only shown in the default traceback when using python 3.11 or later.
+  By `Tom Nicholas <https://github.com/TomNicholas>`_.
+
 Breaking changes
 ~~~~~~~~~~~~~~~~
 