Skip to content

Commit 0f10972

Browse files
committed
Issue #401 Improve automatic adding of save_result
1 parent d505757 commit 0f10972

File tree

3 files changed

+148
-38
lines changed

3 files changed

+148
-38
lines changed

openeo/rest/datacube.py

Lines changed: 69 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ class DataCube(_ProcessGraphAbstraction):
5858
and this process graph can be "grown" to a desired workflow by calling the appropriate methods.
5959
"""
6060

61+
# TODO: set this based on back-end or user preference?
62+
_DEFAULT_RASTER_FORMAT = "GTiff"
63+
6164
def __init__(self, graph: PGNode, connection: 'openeo.Connection', metadata: CollectionMetadata = None):
6265
super().__init__(pgnode=graph, connection=connection)
6366
self.metadata = CollectionMetadata.get_or_create(metadata)
@@ -1810,36 +1813,41 @@ def atmospheric_correction(
18101813
})
18111814

18121815
@openeo_process
1813-
def save_result(self, format: str = "GTiff", options: dict = None) -> 'DataCube':
1816+
def save_result(
1817+
self, format: str = _DEFAULT_RASTER_FORMAT, options: Optional[dict] = None
1818+
) -> "DataCube":
18141819
formats = set(self._connection.list_output_formats().keys())
1820+
# TODO: map format to correct casing too?
18151821
if format.lower() not in {f.lower() for f in formats}:
18161822
raise ValueError("Invalid format {f!r}. Should be one of {s}".format(f=format, s=formats))
18171823
return self.process(
18181824
process_id="save_result",
18191825
arguments={
18201826
"data": THIS,
18211827
"format": format,
1828+
# TODO: leave out options if unset?
18221829
"options": options or {}
18231830
}
18241831
)
18251832

1826-
def download(
1827-
self, outputfile: Union[str, pathlib.Path, None] = None, format: Optional[str] = None,
1828-
options: Optional[dict] = None
1829-
):
1833+
def _ensure_save_result(
1834+
self, format: Optional[str] = None, options: Optional[dict] = None
1835+
) -> "DataCube":
18301836
"""
1831-
Download image collection, e.g. as GeoTIFF.
1832-
If outputfile is provided, the result is stored on disk locally, otherwise, a bytes object is returned.
1833-
The bytes object can be passed on to a suitable decoder for decoding.
1837+
Make sure there is a (final) `save_result` node in the process graph.
1838+
If there is already one: check if it is consistent with the given format/options (if any)
1839+
If there is none yet: add a new one.
18341840
1835-
:param outputfile: Optional, an output file if the result needs to be stored on disk.
1836-
:param format: Optional, an output format supported by the backend.
1837-
:param options: Optional, file format options
1838-
:return: None if the result is stored to disk, or a bytes object returned by the backend.
1841+
:param format: (optional) desired `save_result` file format
1842+
:param options: (optional) desired `save_result` file format parameters
1843+
:return: data cube ending in a `save_result` node (this cube itself if it already has one)
18391844
"""
1840-
if self.result_node().process_id == "save_result":
1841-
# There is already a `save_result` node: check if it is consistent with given format/options
1842-
args = self.result_node().arguments
1845+
# TODO: move to generic data cube parent class (not only for raster cubes, but also vector cubes)
1846+
result_node = self.result_node()
1847+
if result_node.process_id == "save_result":
1848+
# There is already a `save_result` node:
1849+
# check if it is consistent with given format/options (if any)
1850+
args = result_node.arguments
18431851
if format is not None and format.lower() != args["format"].lower():
18441852
raise ValueError(
18451853
f"Existing `save_result` node with different format {args['format']!r} != {format!r}"
@@ -1851,10 +1859,28 @@ def download(
18511859
cube = self
18521860
else:
18531861
# No `save_result` node yet: automatically add it.
1854-
if not format:
1855-
format = guess_format(outputfile) if outputfile else "GTiff"
1856-
cube = self.save_result(format=format, options=options)
1862+
cube = self.save_result(
1863+
format=format or self._DEFAULT_RASTER_FORMAT, options=options
1864+
)
1865+
return cube
18571866

1867+
def download(
1868+
self, outputfile: Union[str, pathlib.Path, None] = None, format: Optional[str] = None,
1869+
options: Optional[dict] = None
1870+
):
1871+
"""
1872+
Download image collection, e.g. as GeoTIFF.
1873+
If outputfile is provided, the result is stored on disk locally, otherwise, a bytes object is returned.
1874+
The bytes object can be passed on to a suitable decoder for decoding.
1875+
1876+
:param outputfile: Optional, an output file if the result needs to be stored on disk.
1877+
:param format: Optional, an output format supported by the backend.
1878+
:param options: Optional, file format options
1879+
:return: None if the result is stored to disk, or a bytes object returned by the backend.
1880+
"""
1881+
if format is None and outputfile is not None:
1882+
format = guess_format(outputfile)
1883+
cube = self._ensure_save_result(format=format, options=options)
18581884
return self._connection.download(cube.flat_graph(), outputfile)
18591885

18601886
def validate(self) -> List[dict]:
@@ -1869,27 +1895,36 @@ def tiled_viewing_service(self, type: str, **kwargs) -> Service:
18691895
return self._connection.create_service(self.flat_graph(), type=type, **kwargs)
18701896

18711897
def execute_batch(
1872-
self,
1873-
outputfile: Union[str, pathlib.Path] = None, out_format: str = None,
1874-
print=print, max_poll_interval=60, connection_retry_interval=30,
1875-
job_options=None, **format_options) -> BatchJob:
1898+
self,
1899+
outputfile: Optional[Union[str, pathlib.Path]] = None,
1900+
out_format: Optional[str] = None,
1901+
*,
1902+
print: typing.Callable[[str], None] = print,
1903+
max_poll_interval: float = 60,
1904+
connection_retry_interval: float = 30,
1905+
job_options: Optional[dict] = None,
1906+
# TODO: avoid `format_options` as keyword arguments
1907+
**format_options,
1908+
) -> BatchJob:
18761909
"""
18771910
Evaluate the process graph by creating a batch job, and retrieving the results when it is finished.
18781911
This method is mostly recommended if the batch job is expected to run in a reasonable amount of time.
18791912
18801913
For very long-running jobs, you probably do not want to keep the client running.
18811914
1882-
:param job_options:
18831915
:param outputfile: The path of a file to which a result can be written
1884-
:param out_format: (optional) Format of the job result.
1885-
:param format_options: String Parameters for the job result format
1886-
1916+
:param out_format: (optional) File format to use for the job result.
1917+
:param job_options: custom job options.
1918+
:param format_options: output file format parameters.
18871919
"""
18881920
if "format" in format_options and not out_format:
18891921
out_format = format_options["format"] # align with 'download' call arg name
1890-
if not out_format:
1891-
out_format = guess_format(outputfile) if outputfile else "GTiff"
1892-
job = self.create_job(out_format, job_options=job_options, **format_options)
1922+
if not out_format and outputfile:
1923+
out_format = guess_format(outputfile)
1924+
1925+
job = self.create_job(
1926+
format=out_format, job_options=job_options, format_options=format_options
1927+
)
18931928
return job.run_synchronous(
18941929
outputfile=outputfile,
18951930
print=print, max_poll_interval=max_poll_interval, connection_retry_interval=connection_retry_interval
@@ -1904,6 +1939,7 @@ def create_job(
19041939
plan: Optional[str] = None,
19051940
budget: Optional[float] = None,
19061941
job_options: Optional[dict] = None,
1942+
# TODO: avoid `format_options` as keyword arguments
19071943
**format_options,
19081944
) -> BatchJob:
19091945
"""
@@ -1914,22 +1950,19 @@ def create_job(
19141950
it still needs to be started and tracked explicitly.
19151951
Use :py:meth:`execute_batch` instead to have the openEO Python client take care of that job management.
19161952
1917-
:param out_format: String Format of the job result.
1953+
:param out_format: output file format.
19181954
:param title: job title
19191955
:param description: job description
19201956
:param plan: billing plan
19211957
:param budget: maximum cost the request is allowed to produce
1922-
:param job_options: A dictionary containing (custom) job options
1923-
:param format_options: String Parameters for the job result format
1958+
:param job_options: custom job options.
1959+
:param format_options: output file format parameters.
19241960
:return: Created job.
19251961
"""
19261962
# TODO: add option to also automatically start the job?
19271963
# TODO: avoid using all kwargs as format_options
19281964
# TODO: centralize `create_job` for `DataCube`, `VectorCube`, `MlModel`, ...
1929-
cube = self
1930-
if out_format:
1931-
# add `save_result` node
1932-
cube = cube.save_result(format=out_format, options=format_options)
1965+
cube = self._ensure_save_result(format=out_format, options=format_options)
19331966
return self._connection.create_job(
19341967
process_graph=cube.flat_graph(),
19351968
title=title,

openeo/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ def deep_set(data: dict, *keys, value):
437437
raise ValueError("No keys given")
438438

439439

440-
def guess_format(filename: Union[str, Path]):
440+
def guess_format(filename: Union[str, Path]) -> str:
441441
"""
442442
Guess the output format from a given filename and return the corrected format.
443443
Any names not in the dict get passed through.

tests/rest/datacube/test_datacube.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,18 @@
44
- 1.0.0-style DataCube
55
66
"""
7-
7+
import functools
88
from datetime import date, datetime
99
import pathlib
1010

11+
import mock
1112
import numpy as np
1213
import pytest
1314
import shapely
1415
import shapely.geometry
1516

1617
from openeo.capabilities import ComparableVersion
18+
from openeo.internal.warnings import UserDeprecationWarning
1719
from openeo.rest import BandMathException
1820
from openeo.rest.datacube import DataCube
1921
from .conftest import API_URL
@@ -446,3 +448,78 @@ def result_callback(request, context):
446448
requests_mock.post(API_URL + '/result', content=result_callback)
447449
result = connection.load_collection("S2").download(format=format)
448450
assert result == b"data"
451+
452+
453+
class TestExecuteBatch:
454+
@pytest.fixture
455+
def get_create_job_pg(self, connection):
456+
"""Fixture to help intercept the process graph that was passed to Connection.create_job"""
457+
with mock.patch.object(connection, "create_job") as create_job:
458+
459+
def get() -> dict:
460+
assert create_job.call_count == 1
461+
return create_job.call_args.kwargs["process_graph"]
462+
463+
yield get
464+
465+
def test_basic(self, connection, s2cube, get_create_job_pg, recwarn, caplog):
466+
s2cube.execute_batch()
467+
pg = get_create_job_pg()
468+
assert set(pg.keys()) == {"loadcollection1", "saveresult1"}
469+
assert pg["saveresult1"] == {
470+
"process_id": "save_result",
471+
"arguments": {
472+
"data": {"from_node": "loadcollection1"},
473+
"format": "GTiff",
474+
"options": {},
475+
},
476+
"result": True,
477+
}
478+
assert recwarn.list == []
479+
assert caplog.records == []
480+
481+
@pytest.mark.parametrize(
482+
["format", "expected"],
483+
[(None, "GTiff"), ("GTiff", "GTiff"), ("gtiff", "gtiff"), ("NetCDF", "NetCDF")],
484+
)
485+
def test_format(
486+
self, connection, s2cube, get_create_job_pg, format, expected, recwarn, caplog
487+
):
488+
s2cube.execute_batch(format=format)
489+
pg = get_create_job_pg()
490+
assert set(pg.keys()) == {"loadcollection1", "saveresult1"}
491+
assert pg["saveresult1"] == {
492+
"process_id": "save_result",
493+
"arguments": {
494+
"data": {"from_node": "loadcollection1"},
495+
"format": expected,
496+
"options": {},
497+
},
498+
"result": True,
499+
}
500+
assert recwarn.list == []
501+
assert caplog.records == []
502+
503+
@pytest.mark.parametrize(
504+
["out_format", "expected"],
505+
[("GTiff", "GTiff"), ("NetCDF", "NetCDF")],
506+
)
507+
def test_out_format(
508+
self, connection, s2cube, get_create_job_pg, out_format, expected
509+
):
510+
with pytest.warns(
511+
UserDeprecationWarning,
512+
match="`out_format`.*is deprecated.*use `format` instead",
513+
):
514+
s2cube.execute_batch(out_format=out_format)
515+
pg = get_create_job_pg()
516+
assert set(pg.keys()) == {"loadcollection1", "saveresult1"}
517+
assert pg["saveresult1"] == {
518+
"process_id": "save_result",
519+
"arguments": {
520+
"data": {"from_node": "loadcollection1"},
521+
"format": expected,
522+
"options": {},
523+
},
524+
"result": True,
525+
}

0 commit comments

Comments
 (0)