Skip to content

Commit e501851

Browse files
committed
Issue #391 Initial implementation of MultiResult
1 parent 7c321be commit e501851

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

openeo/internal/graph_building.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import sys
1515
from contextlib import nullcontext
1616
from pathlib import Path
17-
from typing import Any, Dict, Iterator, Optional, Tuple, Union
17+
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
1818

1919
from openeo.api.process import Parameter
2020
from openeo.internal.process_graph_visitor import (
@@ -438,3 +438,24 @@ def _process_from_parameter(self, name: str) -> Any:
438438
if name not in self._parameters:
439439
raise ProcessGraphVisitException("No substitution value for parameter {p!r}.".format(p=name))
440440
return self._parameters[name]
441+
442+
443+
class MultiResult(FlatGraphableMixin):
444+
"""
445+
Handler of use cases where there are multiple result nodes
446+
(or other leaf nodes) in a process graph.
447+
"""
448+
449+
def __init__(self, leaves: List[FlatGraphableMixin]):
450+
self._leaves = leaves
451+
452+
def flat_graph(self) -> Dict[str, dict]:
453+
result = {}
454+
for leaf in self._leaves:
455+
leaf_graph = leaf.flat_graph()
456+
existing = set(leaf_graph.keys()).intersection(result.keys())
457+
if existing:
458+
# TODO: automatic renaming of duplicate node ids?
459+
raise ValueError(f"Duplicate node ids while building multi-result process graph: {sorted(existing)}")
460+
result.update(leaf_graph)
461+
return result

tests/internal/test_graphbuilding.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import io
2+
import re
23
import textwrap
34

45
import pytest
@@ -7,6 +8,7 @@
78
from openeo.api.process import Parameter
89
from openeo.internal.graph_building import (
910
FlatGraphNodeIdGenerator,
11+
MultiResult,
1012
PGNode,
1113
PGNodeGraphUnflattener,
1214
ReduceNode,
@@ -412,3 +414,19 @@ def test_walk_nodes_nested():
412414
walk = list(node.walk_nodes())
413415
assert all(isinstance(n, PGNode) for n in walk)
414416
assert set(n.process_id for n in walk) == {"load1", "max", "foo", "load2", "add", "five"}
417+
418+
419+
class TestMultiResult:
420+
def test_simple(self):
421+
multi = MultiResult([PGNode("foo"), PGNode("bar")])
422+
assert multi.flat_graph() == {
423+
"foo1": {"process_id": "foo", "arguments": {}, "result": True},
424+
"bar1": {"process_id": "bar", "arguments": {}, "result": True},
425+
}
426+
427+
def test_simple_duplicates(self):
428+
multi = MultiResult([PGNode("foo"), PGNode("foo")])
429+
with pytest.raises(
430+
ValueError, match=re.escape("Duplicate node ids while building multi-result process graph: ['foo1']")
431+
):
432+
multi.flat_graph()

0 commit comments

Comments
 (0)