Commit d63d93b

Improve auto-adding of save_result (#623, #401, #583, #391)

- Check the whole process graph for pre-existing `save_result` nodes, not just the final node.
- Disallow the ambiguous combination of an explicit `save_result` node with a `download`/`create_job` call that also specifies an output format.

1 parent 26bef79
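
For illustration, a minimal usage sketch of the intended behavior after this change (the connection setup and collection id "S2" are assumed for the example, not part of this commit): a `save_result` node anywhere in the process graph now counts as pre-existing, and combining it with an explicit format in `download()` or `create_job()` raises an error instead of silently overriding the existing node.

    # Hypothetical setup, for illustration only.
    cube = connection.load_collection("S2").save_result(format="GTiff")

    # Fine: no explicit format given, the pre-existing `save_result` node is kept as-is.
    cube.download("result.tiff")

    # Raises OpenEoClientException: an explicit format is ambiguous in combination
    # with the `save_result` node that is already in the process graph.
    cube.download("result.tiff", format="GTiff")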

5 files changed: +126 additions, −55 deletions

openeo/internal/graph_building.py

Lines changed: 19 additions & 1 deletion
@@ -14,7 +14,7 @@
 import sys
 from contextlib import nullcontext
 from pathlib import Path
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, Optional, Tuple, Union
 
 from openeo.api.process import Parameter
 from openeo.internal.process_graph_visitor import (
@@ -225,6 +225,24 @@ def from_flat_graph(flat_graph: dict, parameters: Optional[dict] = None) -> PGNode:
         return PGNodeGraphUnflattener.unflatten(flat_graph=flat_graph, parameters=parameters)
 
 
+    def walk_nodes(self) -> Iterator[PGNode]:
+        """Walk this node and all it's parents"""
+        # TODO: option to do deep walk (walk through child graphs too)?
+        yield self
+
+        def walk(x) -> Iterator[PGNode]:
+            if isinstance(x, PGNode):
+                yield from x.walk_nodes()
+            elif isinstance(x, dict):
+                for v in x.values():
+                    yield from walk(v)
+            elif isinstance(x, (list, tuple)):
+                for v in x:
+                    yield from walk(v)
+
+        yield from walk(self.arguments)
+
+
 def as_flat_graph(x: Union[dict, FlatGraphableMixin, Path, Any]) -> Dict[str, dict]:
     """
     Convert given object to a internal flat dict graph representation.
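
The new `walk_nodes()` helper yields the node itself and then every `PGNode` reachable through its arguments, recursing into dicts, lists and tuples, but not into child process graphs. A small sketch of how it can be used (the node ids are made up for illustration; see also the new tests in tests/internal/test_graphbuilding.py below):

    from openeo.internal.graph_building import PGNode

    node = PGNode("foo", data=PGNode("load"), size={"x": PGNode("add", x=PGNode("five"), y=3)})
    print(sorted(n.process_id for n in node.walk_nodes()))
    # ['add', 'five', 'foo', 'load']  (the traversal order itself is not guaranteed)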

openeo/rest/datacube.py

Lines changed: 17 additions & 22 deletions
@@ -2097,8 +2097,11 @@ def save_result(
 
     def _ensure_save_result(
         self,
+        *,
         format: Optional[str] = None,
         options: Optional[dict] = None,
+        weak_format: Optional[str] = None,
+        method: str,
     ) -> DataCube:
         """
         Make sure there is a (final) `save_result` node in the process graph.
@@ -2110,25 +2113,19 @@ def _ensure_save_result(
         :return:
         """
         # TODO #401 Unify with VectorCube._ensure_save_result and move to generic data cube parent class (not only for raster cubes, but also vector cubes)
-        result_node = self.result_node()
-        if result_node.process_id == "save_result":
-            # There is already a `save_result` node:
-            # check if it is consistent with given format/options (if any)
-            args = result_node.arguments
-            if format is not None and format.lower() != args["format"].lower():
-                raise ValueError(
-                    f"Existing `save_result` node with different format {args['format']!r} != {format!r}"
-                )
-            if options is not None and options != args["options"]:
-                raise ValueError(
-                    f"Existing `save_result` node with different options {args['options']!r} != {options!r}"
-                )
-            cube = self
-        else:
+        save_result_nodes = [n for n in self.result_node().walk_nodes() if n.process_id == "save_result"]
+
+        cube = self
+        if not save_result_nodes:
             # No `save_result` node yet: automatically add it.
-            cube = self.save_result(
-                format=format or self._DEFAULT_RASTER_FORMAT, options=options
+            cube = cube.save_result(format=format or weak_format or self._DEFAULT_RASTER_FORMAT, options=options)
+        elif format or options:
+            raise OpenEoClientException(
+                f"{method} with explicit output {'format' if format else 'options'} {format or options!r},"
+                f" but the process graph already has `save_result` node(s)"
+                f" which is ambiguous and should not be combined."
             )
+
         return cube
 
     def download(
@@ -2152,10 +2149,8 @@ def download(
            (overruling the connection's ``auto_validate`` setting).
        :return: None if the result is stored to disk, or a bytes object returned by the backend.
        """
-        if format is None and outputfile:
-            # TODO #401/#449 don't guess/override format if there is already a save_result with format?
-            format = guess_format(outputfile)
-        cube = self._ensure_save_result(format=format, options=options)
+        weak_format = guess_format(outputfile) if outputfile else None
+        cube = self._ensure_save_result(format=format, options=options, weak_format=weak_format, method="Download")
         return self._connection.download(cube.flat_graph(), outputfile, validate=validate)
 
     def validate(self) -> List[dict]:
@@ -2321,7 +2316,7 @@ def create_job(
         # TODO: add option to also automatically start the job?
         # TODO: avoid using all kwargs as format_options
         # TODO: centralize `create_job` for `DataCube`, `VectorCube`, `MlModel`, ...
-        cube = self._ensure_save_result(format=out_format, options=format_options or None)
+        cube = self._ensure_save_result(format=out_format, options=format_options or None, method="Creating job")
         return self._connection.create_job(
             process_graph=cube.flat_graph(),
             title=title,
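
In `download()`, the format guessed from the output filename is now only a weak hint: it is used when a `save_result` node has to be added automatically, but it no longer clashes with a node that is already there. A short usage sketch (connection setup and collection id assumed for illustration):

    cube = connection.load_collection("S2")

    # No `save_result` node yet: the ".nc" extension acts as a weak hint,
    # so a `save_result` node with format "netCDF" is added automatically.
    cube.download("result.nc")

    # An explicit `format` argument takes precedence over the filename-based guess.
    cube.download("result.nc", format="GTiff")

    # A pre-existing `save_result` node is no longer overridden by the filename guess
    # (this used to raise "Existing `save_result` node with different format ...").
    cube.save_result(format="GTiff").download("result.nc")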

tests/internal/test_graphbuilding.py

Lines changed: 33 additions & 0 deletions
@@ -379,3 +379,36 @@ def test_parameter_substitution_undefined(self):
         }
         with pytest.raises(ProcessGraphVisitException, match="No substitution value for parameter 'increment'"):
             _ = PGNodeGraphUnflattener.unflatten(flat_graph, parameters={"other": 100})
+
+
+def test_walk_nodes_basic():
+    node = PGNode("foo")
+    walk = node.walk_nodes()
+    assert next(walk) is node
+    with pytest.raises(StopIteration):
+        next(walk)
+
+
+def test_walk_nodes_args():
+    data = PGNode("load")
+    geometry = PGNode("vector")
+    node = PGNode("foo", data=data, geometry=geometry)
+
+    walk = node.walk_nodes()
+    assert next(walk) is node
+    rest = list(walk)
+    assert rest == [data, geometry] or rest == [geometry, data]
+
+
+def test_walk_nodes_nested():
+    node = PGNode(
+        "foo",
+        cubes=[PGNode("load1"), PGNode("load2")],
+        size={
+            "x": PGNode("add", x=PGNode("five"), y=3),
+            "y": PGNode("max"),
+        },
+    )
+    walk = list(node.walk_nodes())
+    assert all(isinstance(n, PGNode) for n in walk)
+    assert set(n.process_id for n in walk) == {"load1", "max", "foo", "load2", "add", "five"}

tests/rest/datacube/test_datacube.py

Lines changed: 41 additions & 15 deletions
@@ -7,6 +7,7 @@
 
 import contextlib
 import pathlib
+import re
 from datetime import date, datetime
 from unittest import mock
 
@@ -710,21 +711,25 @@ def test_create_job_out_format(
     @pytest.mark.parametrize(
         ["save_result_format", "execute_format", "expected"],
         [
-            ("GTiff", "GTiff", "GTiff"),
+            (None, None, "GTiff"),
+            (None, "GTiff", "GTiff"),
             ("GTiff", None, "GTiff"),
-            ("NetCDF", "NetCDF", "NetCDF"),
+            (None, "NetCDF", "NetCDF"),
             ("NetCDF", None, "NetCDF"),
         ],
     )
-    def test_create_job_existing_save_result(
+    def test_save_result_and_create_job_at_most_one_with_format(
         self,
         s2cube,
         get_create_job_pg,
         save_result_format,
         execute_format,
         expected,
     ):
-        cube = s2cube.save_result(format=save_result_format)
+        cube = s2cube
+        if save_result_format:
+            cube = cube.save_result(format=save_result_format)
+
         cube.create_job(out_format=execute_format)
         pg = get_create_job_pg()
         assert set(pg.keys()) == {"loadcollection1", "saveresult1"}
@@ -740,13 +745,21 @@ def test_create_job_existing_save_result(
 
     @pytest.mark.parametrize(
         ["save_result_format", "execute_format"],
-        [("NetCDF", "GTiff"), ("GTiff", "NetCDF")],
+        [
+            ("NetCDF", "NetCDF"),
+            ("GTiff", "NetCDF"),
+        ],
     )
-    def test_create_job_existing_save_result_incompatible(
-        self, s2cube, save_result_format, execute_format
-    ):
+    def test_save_result_and_create_job_both_with_format(self, s2cube, save_result_format, execute_format):
         cube = s2cube.save_result(format=save_result_format)
-        with pytest.raises(ValueError):
+        with pytest.raises(
+            OpenEoClientException,
+            match=re.escape(
+                "Creating job with explicit output format 'NetCDF',"
+                " but the process graph already has `save_result` node(s)"
+                " which is ambiguous and should not be combined."
+            ),
+        ):
             cube.create_job(out_format=execute_format)
 
     def test_execute_batch_defaults(self, s2cube, get_create_job_pg, recwarn, caplog):
@@ -808,21 +821,24 @@ def test_execute_batch_out_format_from_output_file(
     @pytest.mark.parametrize(
         ["save_result_format", "execute_format", "expected"],
         [
-            ("GTiff", "GTiff", "GTiff"),
+            (None, None, "GTiff"),
+            (None, "GTiff", "GTiff"),
             ("GTiff", None, "GTiff"),
-            ("NetCDF", "NetCDF", "NetCDF"),
+            (None, "NetCDF", "NetCDF"),
             ("NetCDF", None, "NetCDF"),
         ],
    )
-    def test_execute_batch_existing_save_result(
+    def test_save_result_and_execute_batch_at_most_one_with_format(
        self,
        s2cube,
        get_create_job_pg,
        save_result_format,
        execute_format,
        expected,
    ):
-        cube = s2cube.save_result(format=save_result_format)
+        cube = s2cube
+        if save_result_format:
+            cube = cube.save_result(format=save_result_format)
         cube.execute_batch(out_format=execute_format)
         pg = get_create_job_pg()
         assert set(pg.keys()) == {"loadcollection1", "saveresult1"}
@@ -838,13 +854,23 @@ def test_execute_batch_existing_save_result(
 
     @pytest.mark.parametrize(
         ["save_result_format", "execute_format"],
-        [("NetCDF", "GTiff"), ("GTiff", "NetCDF")],
+        [
+            ("NetCDF", "NetCDF"),
+            ("GTiff", "NetCDF"),
+        ],
    )
    def test_execute_batch_existing_save_result_incompatible(
        self, s2cube, save_result_format, execute_format
    ):
        cube = s2cube.save_result(format=save_result_format)
-        with pytest.raises(ValueError):
+        with pytest.raises(
+            OpenEoClientException,
+            match=re.escape(
+                "Creating job with explicit output format 'NetCDF',"
+                " but the process graph already has `save_result` node(s)"
+                " which is ambiguous and should not be combined."
+            ),
+        ):
            cube.execute_batch(out_format=execute_format)
 
    def test_save_result_format_options_vs_create_job(elf, s2cube, get_create_job_pg):

tests/rest/datacube/test_datacube100.py

Lines changed: 16 additions & 17 deletions
@@ -3257,50 +3257,49 @@ def test_apply_append_math_keep_context(con100):
         ({}, "result.nc", {}, b"this is netCDF data"),
         ({"format": "GTiff"}, "result.tiff", {}, b"this is GTiff data"),
         ({"format": "GTiff"}, "result.tif", {}, b"this is GTiff data"),
-        (
-            {"format": "GTiff"},
-            "result.nc",
-            {},
-            ValueError(
-                "Existing `save_result` node with different format 'GTiff' != 'netCDF'"
-            ),
-        ),
+        ({"format": "GTiff"}, "result.nc", {}, b"this is GTiff data"),
         ({}, "result.tiff", {"format": "GTiff"}, b"this is GTiff data"),
         ({}, "result.nc", {"format": "netCDF"}, b"this is netCDF data"),
         ({}, "result.meh", {"format": "netCDF"}, b"this is netCDF data"),
         (
             {"format": "GTiff"},
             "result.tiff",
             {"format": "GTiff"},
-            b"this is GTiff data",
+            OpenEoClientException(
+                "Download with explicit output format 'GTiff', but the process graph already has `save_result` node(s) which is ambiguous and should not be combined."
+            ),
         ),
         (
             {"format": "netCDF"},
             "result.tiff",
             {"format": "NETCDF"},
-            b"this is netCDF data",
+            OpenEoClientException(
+                "Download with explicit output format 'NETCDF', but the process graph already has `save_result` node(s) which is ambiguous and should not be combined."
+            ),
         ),
         (
             {"format": "netCDF"},
             "result.json",
             {"format": "JSON"},
-            ValueError(
-                "Existing `save_result` node with different format 'netCDF' != 'JSON'"
+            OpenEoClientException(
+                "Download with explicit output format 'JSON', but the process graph already has `save_result` node(s) which is ambiguous and should not be combined."
             ),
         ),
         ({"options": {}}, "result.tiff", {}, b"this is GTiff data"),
         (
             {"options": {"quality": "low"}},
             "result.tiff",
             {"options": {"quality": "low"}},
-            b"this is GTiff data",
+            OpenEoClientException(
+                "Download with explicit output options {'quality': 'low'}, but the process graph already has `save_result` node(s) which is ambiguous and should not be combined."
+            ),
         ),
         (
             {"options": {"colormap": "jet"}},
             "result.tiff",
             {"options": {"quality": "low"}},
-            ValueError(
-                "Existing `save_result` node with different options {'colormap': 'jet'} != {'quality': 'low'}"
+            OpenEoClientException(
+                "Download with explicit output options {'quality': 'low'}, but the process graph already has `save_result` node(s) which is ambiguous and should not be combined."
            ),
        ),
    ],
@@ -3328,8 +3327,8 @@ def post_result(request, context):
         cube = cube.save_result(**save_result_kwargs)
 
     path = tmp_path / download_filename
-    if isinstance(expected, ValueError):
-        with pytest.raises(ValueError, match=str(expected)):
+    if isinstance(expected, Exception):
+        with pytest.raises(type(expected), match=re.escape(str(expected))):
            cube.download(str(path), **download_kwargs)
        assert post_result_mock.call_count == 0
    else: