Skip to content

Commit 334b819

Browse files
committed
Issue #391/#651 more review tweaks
1 parent 2e58607 commit 334b819

File tree

3 files changed

+76
-10
lines changed

3 files changed

+76
-10
lines changed

openeo/internal/graph_building.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,8 @@ class MultiLeafGraph(FlatGraphableMixin):
458458
Container for process graphs with multiple leaf/result nodes.
459459
"""
460460

461+
__slots__ = ["_leaves"]
462+
461463
def __init__(self, leaves: Iterable[FlatGraphableMixin]):
462464
self._leaves = list(leaves)
463465

openeo/rest/multiresult.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ class MultiResult(FlatGraphableMixin):
3434
.. versionadded:: 0.35.0
3535
"""
3636

37+
__slots__ = ("_multi_leaf_graph", "_connection")
38+
3739
def __init__(self, leaves: List[FlatGraphableMixin], connection: Optional[Connection] = None):
3840
"""
3941
Build a :py:class:`MultiResult` instance from multiple leaf nodes
@@ -47,11 +49,14 @@ def __init__(self, leaves: List[FlatGraphableMixin], connection: Optional[Connec
4749
are not already associated with a connection.
4850
"""
4951
self._multi_leaf_graph = MultiLeafGraph(leaves=leaves)
50-
self._connection = self._common_connection(leaves=leaves, connection=connection)
52+
self._connection = self._extract_connection(leaves=leaves, connection=connection)
5153

5254
@staticmethod
53-
def _common_connection(leaves: List[FlatGraphableMixin], connection: Optional[Connection] = None) -> Connection:
54-
"""Find common connection. Fails if there are multiple or none."""
55+
def _extract_connection(leaves: List[FlatGraphableMixin], connection: Optional[Connection] = None) -> Connection:
56+
"""
57+
Extract common connection from leaves and/or explicitly provided connection.
58+
Fails if there are multiple or none.
59+
"""
5560
connections = set()
5661
if connection:
5762
connections.add(connection)

tests/rest/test_multiresult.py

Lines changed: 66 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,17 @@
11
import pytest
22

3+
from openeo import BatchJob
34
from openeo.rest._testing import DummyBackend
45
from openeo.rest.multiresult import MultiResult
56

67

78
class TestMultiResultHandling:
8-
9-
def test_create_job_method(self, dummy_backend):
10-
con = dummy_backend.connection
11-
cube = con.load_collection("S2")
9+
def test_flat_graph(self, dummy_backend):
10+
cube = dummy_backend.connection.load_collection("S2")
1211
save1 = cube.save_result(format="GTiff")
1312
save2 = cube.save_result(format="netCDF")
1413
multi_result = MultiResult([save1, save2])
15-
multi_result.create_job()
16-
assert dummy_backend.get_batch_pg() == {
14+
assert multi_result.flat_graph() == {
1715
"loadcollection1": {
1816
"process_id": "load_collection",
1917
"arguments": {"id": "S2", "spatial_extent": None, "temporal_extent": None},
@@ -29,7 +27,37 @@ def test_create_job_method(self, dummy_backend):
2927
},
3028
}
3129

32-
def test_create_job_on_connection(self, con120, dummy_backend):
30+
def test_create_job_method(self, dummy_backend):
31+
con = dummy_backend.connection
32+
cube = con.load_collection("S2")
33+
save1 = cube.save_result(format="GTiff")
34+
save2 = cube.save_result(format="netCDF")
35+
multi_result = MultiResult([save1, save2])
36+
multi_result.create_job(title="multi result test")
37+
assert dummy_backend.batch_jobs == {
38+
"job-000": {
39+
"job_id": "job-000",
40+
"pg": {
41+
"loadcollection1": {
42+
"process_id": "load_collection",
43+
"arguments": {"id": "S2", "spatial_extent": None, "temporal_extent": None},
44+
},
45+
"saveresult1": {
46+
"process_id": "save_result",
47+
"arguments": {"data": {"from_node": "loadcollection1"}, "format": "GTiff", "options": {}},
48+
},
49+
"saveresult2": {
50+
"process_id": "save_result",
51+
"arguments": {"data": {"from_node": "loadcollection1"}, "format": "netCDF", "options": {}},
52+
"result": True,
53+
},
54+
},
55+
"status": "created",
56+
"title": "multi result test",
57+
}
58+
}
59+
60+
def test_create_job_through_connection(self, con120, dummy_backend):
3361
con = dummy_backend.connection
3462
cube = con120.load_collection("S2")
3563
save1 = cube.save_result(format="GTiff")
@@ -51,3 +79,34 @@ def test_create_job_on_connection(self, con120, dummy_backend):
5179
"result": True,
5280
},
5381
}
82+
83+
def test_execute_batch(self, dummy_backend):
84+
con = dummy_backend.connection
85+
cube = con.load_collection("S2")
86+
save1 = cube.save_result(format="GTiff")
87+
save2 = cube.save_result(format="netCDF")
88+
multi_result = MultiResult([save1, save2])
89+
job = multi_result.execute_batch(title="multi result test")
90+
assert isinstance(job, BatchJob)
91+
assert dummy_backend.batch_jobs == {
92+
"job-000": {
93+
"job_id": "job-000",
94+
"pg": {
95+
"loadcollection1": {
96+
"process_id": "load_collection",
97+
"arguments": {"id": "S2", "spatial_extent": None, "temporal_extent": None},
98+
},
99+
"saveresult1": {
100+
"process_id": "save_result",
101+
"arguments": {"data": {"from_node": "loadcollection1"}, "format": "GTiff", "options": {}},
102+
},
103+
"saveresult2": {
104+
"process_id": "save_result",
105+
"arguments": {"data": {"from_node": "loadcollection1"}, "format": "netCDF", "options": {}},
106+
"result": True,
107+
},
108+
},
109+
"status": "finished",
110+
"title": "multi result test",
111+
}
112+
}

0 commit comments

Comments
 (0)