diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 381db6e8..0faf5924 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -6,6 +6,13 @@ Release Notes v0.3.0 (Unreleased) ------------------- +Enhancements +~~~~~~~~~~~~ + +- Ensure that there is no ``intent`` conflict between the variables + declared in a model. This check is explicit at Model creation and a + more meaningful error message is shown when it fails (:issue:`57`). + v0.2.1 (7 November 2018) ------------------------ diff --git a/xsimlab/model.py b/xsimlab/model.py index a43d198c..e829b0b6 100644 --- a/xsimlab/model.py +++ b/xsimlab/model.py @@ -145,6 +145,34 @@ def set_process_keys(self): if od_key is not None: p_obj.__xsimlab_od_keys__[var.name] = od_key + def ensure_no_intent_conflict(self): + """Raise an error if more than one variable with + intent='out' targets the same variable. + + """ + filter_out = lambda var: ( + var.metadata['intent'] == VarIntent.OUT and + var.metadata['var_type'] != VarType.ON_DEMAND + ) + + targets = defaultdict(list) + + for p_name, p_obj in self._processes_obj.items(): + for var in filter_variables(p_obj, func=filter_out).values(): + target_key = p_obj.__xsimlab_store_keys__.get(var.name) + targets[target_key].append((p_name, var.name)) + + conflicts = {k: v for k, v in targets.items() if len(v) > 1} + + if conflicts: + conflicts_str = {k: ' and '.join(["'{}.{}'".format(*i) for i in v]) + for k, v in conflicts.items()} + msg = '\n'.join(["'{}.{}' set by: {}".format(*k, v) + for k, v in conflicts_str.items()]) + + raise ValueError( + "Conflict(s) found in given variable intents:\n" + msg) + def get_all_variables(self): """Get all variables in the model as a list of ``(process_name, var_name)`` tuples. @@ -364,6 +392,8 @@ def __init__(self, processes): self._all_vars = builder.get_all_variables() self._all_vars_dict = None + builder.ensure_no_intent_conflict() + self._input_vars = builder.get_input_variables() self._input_vars_dict = None diff --git a/xsimlab/tests/test_model.py b/xsimlab/tests/test_model.py index 908e94f3..7685c9bc 100644 --- a/xsimlab/tests/test_model.py +++ b/xsimlab/tests/test_model.py @@ -1,7 +1,7 @@ import pytest import xsimlab as xs -from xsimlab.tests.fixture_model import AddOnDemand, InitProfile +from xsimlab.tests.fixture_model import AddOnDemand, InitProfile, Profile class TestModelBuilder(object): @@ -52,6 +52,15 @@ def test_get_all_variables(self, model): assert all([p_name in model for p_name, _ in model.all_vars]) assert ('profile', 'u') in model.all_vars + def test_ensure_no_intent_conflict(self, model): + @xs.process + class Foo(object): + u = xs.foreign(Profile, 'u', intent='out') + + with pytest.raises(ValueError) as excinfo: + invalid_model = model.update_processes({'foo': Foo}) + assert "Conflict(s)" in str(excinfo.value) + def test_get_input_variables(self, model): expected = {('init_profile', 'n_points'), ('roll', 'shift'),