Commit 31766cb

Issue #401/#449 support format guessing in VectorCube.execute_batch

1 parent 8d1d947

File tree: 4 files changed, +154 -57 lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Added

-- Add support in `VectorCube.download()` to guess output format from extension of a given filename
+- Add support in `VectorCube.download()` and `VectorCube.execute_batch()` to guess output format from extension of a given filename
   ([#401](https://github.com/Open-EO/openeo-python-client/issues/401), [#449](https://github.com/Open-EO/openeo-python-client/issues/449))
 - Added `load_stac` for Client Side Processing, based on the [openeo-processes-dask implementation](https://github.com/Open-EO/openeo-processes-dask/pull/127)
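To illustrate the changelog entry, a minimal usage sketch of the new behavior. The backend URL and collection id are placeholders, and `raster_to_vector()` is just one way to obtain a `VectorCube`; only the `execute_batch(outputfile=...)` guessing is what this commit adds.

    import openeo

    connection = openeo.connect("https://openeo.example").authenticate_oidc()  # placeholder URL
    cube = connection.load_collection("SENTINEL2_L2A")  # placeholder collection id
    vector_cube = cube.raster_to_vector()

    # New in this commit: for batch jobs the output format is guessed from the
    # filename extension (here GeoJSON), as VectorCube.download() already did.
    vector_cube.execute_batch(outputfile="result.geojson")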

openeo/rest/datacube.py

Lines changed: 2 additions & 2 deletions
@@ -1942,7 +1942,7 @@ def download(
         :param options: Optional, file format options
         :return: None if the result is stored to disk, or a bytes object returned by the backend.
         """
-        if format is None and outputfile is not None:
+        if format is None and outputfile:
             # TODO #401/#449 don't guess/override format if there is already a save_result with format?
             format = guess_format(outputfile)
         cube = self._ensure_save_result(format=format, options=options)
@@ -2062,7 +2062,7 @@ def execute_batch(
         """
         if "format" in format_options and not out_format:
             out_format = format_options["format"]  # align with 'download' call arg name
-        if not out_format and outputfile:
+        if out_format is None and outputfile:
             # TODO #401/#449 don't guess/override format if there is already a save_result with format?
             out_format = guess_format(outputfile)
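Both call sites delegate the extension sniffing to `guess_format`. A minimal sketch of the behavior these hunks rely on, assuming `guess_format` is importable from `openeo.util` (an assumption; the expectations mirror the test parametrization in this commit):

    from pathlib import Path

    from openeo.util import guess_format  # assumed import location

    assert guess_format("result.nc") == "netCDF"
    assert guess_format(Path("result.geojson")) == "GeoJSON"
    # Guessing only kicks in when no explicit format is given.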

openeo/rest/vectorcube.py

Lines changed: 16 additions & 11 deletions
@@ -139,16 +139,22 @@ def execute(self) -> dict:
     def download(self, outputfile: Union[str, pathlib.Path], format: Optional[str] = None, options: dict = None):
         # TODO #401 make outputfile optional (See DataCube.download)
         # TODO #401/#449 don't guess/override format if there is already a save_result with format?
-        if format is None and outputfile is not None:
+        if format is None and outputfile:
             format = guess_format(outputfile)
         cube = self._ensure_save_result(format=format, options=options)
         return self._connection.download(cube.flat_graph(), outputfile)

     def execute_batch(
-            self,
-            outputfile: Union[str, pathlib.Path] = None, out_format: str = None,
-            print=print, max_poll_interval=60, connection_retry_interval=30,
-            job_options=None, **format_options) -> BatchJob:
+        self,
+        outputfile: Optional[Union[str, pathlib.Path]] = None,
+        out_format: Optional[str] = None,
+        print=print,
+        max_poll_interval: float = 60,
+        connection_retry_interval: float = 30,
+        job_options: Optional[dict] = None,
+        # TODO: avoid using kwargs as format options
+        **format_options,
+    ) -> BatchJob:
         """
         Evaluate the process graph by creating a batch job, and retrieving the results when it is finished.
         This method is mostly recommended if the batch job is expected to run in a reasonable amount of time.
@@ -159,8 +165,11 @@ def execute_batch(
         :param outputfile: The path of a file to which a result can be written
         :param out_format: (optional) Format of the job result.
         :param format_options: String Parameters for the job result format
-
         """
+        if out_format is None and outputfile:
+            # TODO #401/#449 don't guess/override format if there is already a save_result with format?
+            out_format = guess_format(outputfile)
+
         job = self.create_job(out_format, job_options=job_options, **format_options)
         return job.run_synchronous(
             # TODO #135 support multi file result sets too
@@ -193,11 +202,7 @@ def create_job(
         """
         # TODO: avoid using all kwargs as format_options
         # TODO: centralize `create_job` for `DataCube`, `VectorCube`, `MlModel`, ...
-        cube = self
-        if out_format:
-            # add `save_result` node
-            # TODO #401: avoid duplicate save_result
-            cube = cube.save_result(format=out_format, options=format_options)
+        cube = self._ensure_save_result(format=out_format, options=format_options or None)
         return self._connection.create_job(
             process_graph=cube.flat_graph(),
             title=title,
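Note how `create_job` now routes through `_ensure_save_result`: a cube that already carries a `save_result` node keeps it instead of getting a duplicate appended. A sketch of the intended behavior, mirroring `test_save_result_and_download` below (assuming an existing `vector_cube`):

    vector_cube = vector_cube.save_result(format="GeoJSON")

    # Both paths now reuse the existing save_result node; the flat graph
    # ends up with a single save_result either way.
    vector_cube.download("result.geojson")
    vector_cube.execute_batch(outputfile="result.geojson")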

tests/rest/datacube/test_vectorcube.py

Lines changed: 135 additions & 43 deletions
@@ -1,8 +1,9 @@
+import re
 from pathlib import Path
-from typing import List

 import pytest

+from openeo import Connection
 from openeo.internal.graph_building import PGNode
 from openeo.rest.vectorcube import VectorCube

@@ -13,42 +14,101 @@ def vector_cube(con100) -> VectorCube:
     return VectorCube(graph=pgnode, connection=con100)


-class DownloadSpy:
+class DummyBackend:
     """
-    Test helper to track download requests and optionally override next response to return.
+    Dummy backend that handles sync/batch execution requests
+    and allows inspection of posted process graphs
     """

-    __slots__ = ["requests", "next_response"]
+    def __init__(self, requests_mock, connection: Connection):
+        self.connection = connection
+        self.sync_requests = []
+        self.batch_jobs = {}
+        self.next_result = b"Result data"
+        requests_mock.post(connection.build_url("/result"), content=self._handle_post_result)
+        requests_mock.post(connection.build_url("/jobs"), content=self._handle_post_jobs)
+        requests_mock.post(
+            re.compile(connection.build_url(r"/jobs/(job-\d+)/results$")), content=self._handle_post_job_results
+        )
+        requests_mock.get(re.compile(connection.build_url(r"/jobs/(job-\d+)$")), json=self._handle_get_job)
+        requests_mock.get(
+            re.compile(connection.build_url(r"/jobs/(job-\d+)/results$")), json=self._handle_get_job_results
+        )
+        requests_mock.get(
+            re.compile(connection.build_url("/jobs/(.*?)/results/result.data$")),
+            content=self._handle_get_job_result_asset,
+        )

-    def __init__(self):
-        self.requests: List[dict] = []
-        self.next_response: bytes = b"Spy data"
+    def _handle_post_result(self, request, context):
+        """Handler of `POST /result` (synchronous execute)."""
+        pg = request.json()["process"]["process_graph"]
+        self.sync_requests.append(pg)
+        return self.next_result

-    @property
-    def only_request(self) -> dict:
-        """Get process graph of the only request done"""
-        assert len(self.requests) == 1
-        return self.requests[-1]
+    def _handle_post_jobs(self, request, context):
+        """Handler of `POST /jobs` (create batch job)."""
+        pg = request.json()["process"]["process_graph"]
+        job_id = f"job-{len(self.batch_jobs):03d}"
+        self.batch_jobs[job_id] = {"job_id": job_id, "pg": pg, "status": "created"}
+        context.status_code = 201
+        context.headers["openeo-identifier"] = job_id

-    @property
-    def last_request(self) -> dict:
-        """Get last process graph"""
-        assert len(self.requests) > 0
-        return self.requests[-1]
+    def _get_job_id(self, request) -> str:
+        match = re.match(r"^/jobs/(job-\d+)(/|$)", request.path)
+        if not match:
+            raise ValueError(f"Failed to extract job_id from {request.path}")
+        job_id = match.group(1)
+        assert job_id in self.batch_jobs
+        return job_id

+    def _handle_post_job_results(self, request, context):
+        """Handler of `POST /job/{job_id}/results` (start batch job)."""
+        job_id = self._get_job_id(request)
+        assert self.batch_jobs[job_id]["status"] == "created"
+        # TODO: support custom status sequence (instead of directly going to status "finished")?
+        self.batch_jobs[job_id]["status"] = "finished"
+        context.status_code = 202

-@pytest.fixture
-def download_spy(requests_mock, con100) -> DownloadSpy:
-    """Test fixture to spy on (and mock) `POST /result` (download) requests."""
-    spy = DownloadSpy()
+    def _handle_get_job(self, request, context):
+        """Handler of `GET /job/{job_id}` (get batch job status and metadata)."""
+        job_id = self._get_job_id(request)
+        return {"id": job_id, "status": self.batch_jobs[job_id]["status"]}

-    def post_result(request, context):
-        pg = request.json()["process"]["process_graph"]
-        spy.requests.append(pg)
-        return spy.next_response
+    def _handle_get_job_results(self, request, context):
+        """Handler of `GET /job/{job_id}/results` (list batch job results)."""
+        job_id = self._get_job_id(request)
+        assert self.batch_jobs[job_id]["status"] == "finished"
+        return {
+            "id": job_id,
+            "assets": {"result.data": {"href": self.connection.build_url(f"/jobs/{job_id}/results/result.data")}},
+        }
+
+    def _handle_get_job_result_asset(self, request, context):
+        """Handler of `GET /job/{job_id}/results/result.data` (get batch job result asset)."""
+        job_id = self._get_job_id(request)
+        assert self.batch_jobs[job_id]["status"] == "finished"
+        return self.next_result
+
+    def get_sync_pg(self) -> dict:
+        """Get the one and only synchronous process graph"""
+        assert len(self.sync_requests) == 1
+        return self.sync_requests[0]
+
+    def get_batch_pg(self) -> dict:
+        """Get the one and only batch process graph"""
+        assert len(self.batch_jobs) == 1
+        return self.batch_jobs[max(self.batch_jobs.keys())]["pg"]

-    requests_mock.post(con100.build_url("/result"), content=post_result)
-    yield spy
+    def get_pg(self) -> dict:
+        """Get the one and only process graph (sync or batch)"""
+        pgs = self.sync_requests + [b["pg"] for b in self.batch_jobs.values()]
+        assert len(pgs) == 1
+        return pgs[0]
+
+
+@pytest.fixture
+def dummy_backend(requests_mock, con100) -> DummyBackend:
+    yield DummyBackend(requests_mock=requests_mock, connection=con100)


 def test_raster_to_vector(con100):
@@ -91,13 +151,19 @@ def test_raster_to_vector(con100):
     ],
 )
 @pytest.mark.parametrize("path_class", [str, Path])
+@pytest.mark.parametrize("exec_mode", ["sync", "batch"])
 def test_download_auto_save_result_only_file(
-    vector_cube, download_spy, tmp_path, filename, expected_format, path_class
+    vector_cube, dummy_backend, tmp_path, filename, expected_format, path_class, exec_mode
 ):
     output_path = tmp_path / filename
-    vector_cube.download(path_class(output_path))
+    if exec_mode == "sync":
+        vector_cube.download(path_class(output_path))
+    elif exec_mode == "batch":
+        vector_cube.execute_batch(outputfile=path_class(output_path))
+    else:
+        raise ValueError(exec_mode)

-    assert download_spy.only_request == {
+    assert dummy_backend.get_pg() == {
         "createvectorcube1": {"process_id": "create_vector_cube", "arguments": {}},
         "saveresult1": {
             "process_id": "save_result",
@@ -109,7 +175,7 @@ def test_download_auto_save_result_only_file(
             "result": True,
         },
     }
-    assert output_path.read_bytes() == b"Spy data"
+    assert output_path.read_bytes() == b"Result data"


 @pytest.mark.parametrize(
@@ -126,11 +192,19 @@ def test_download_auto_save_result_only_file(
         # TODO #449 more formats to autodetect?
     ],
 )
-def test_download_auto_save_result_with_format(vector_cube, download_spy, tmp_path, filename, format, expected_format):
+@pytest.mark.parametrize("exec_mode", ["sync", "batch"])
+def test_download_auto_save_result_with_format(
+    vector_cube, dummy_backend, tmp_path, filename, format, expected_format, exec_mode
+):
     output_path = tmp_path / filename
-    vector_cube.download(output_path, format=format)
+    if exec_mode == "sync":
+        vector_cube.download(output_path, format=format)
+    elif exec_mode == "batch":
+        vector_cube.execute_batch(outputfile=output_path, out_format=format)
+    else:
+        raise ValueError(exec_mode)

-    assert download_spy.only_request == {
+    assert dummy_backend.get_pg() == {
         "createvectorcube1": {"process_id": "create_vector_cube", "arguments": {}},
         "saveresult1": {
             "process_id": "save_result",
@@ -142,14 +216,23 @@ def test_download_auto_save_result_with_format(vector_cube, download_spy, tmp_pa
             "result": True,
         },
     }
-    assert output_path.read_bytes() == b"Spy data"
+    assert output_path.read_bytes() == b"Result data"


-def test_download_auto_save_result_with_options(vector_cube, download_spy, tmp_path):
+@pytest.mark.parametrize("exec_mode", ["sync", "batch"])
+def test_download_auto_save_result_with_options(vector_cube, dummy_backend, tmp_path, exec_mode):
     output_path = tmp_path / "result.json"
-    vector_cube.download(output_path, format="GeoJSON", options={"precision": 7})
+    format = "GeoJSON"
+    options = {"precision": 7}

-    assert download_spy.only_request == {
+    if exec_mode == "sync":
+        vector_cube.download(output_path, format=format, options=options)
+    elif exec_mode == "batch":
+        vector_cube.execute_batch(outputfile=output_path, out_format=format, **options)
+    else:
+        raise ValueError(exec_mode)
+
+    assert dummy_backend.get_pg() == {
         "createvectorcube1": {"process_id": "create_vector_cube", "arguments": {}},
         "saveresult1": {
             "process_id": "save_result",
@@ -161,7 +244,7 @@ def test_download_auto_save_result_with_options(vector_cube, download_spy, tmp_p
             "result": True,
         },
     }
-    assert output_path.read_bytes() == b"Spy data"
+    assert output_path.read_bytes() == b"Result data"


 @pytest.mark.parametrize(
@@ -173,17 +256,26 @@ def test_download_auto_save_result_with_options(vector_cube, download_spy, tmp_p
         ("result.nc", "netCDF", "netCDF"),
     ],
 )
-def test_save_result_and_download(vector_cube, download_spy, tmp_path, output_file, format, expected_format):
+@pytest.mark.parametrize("exec_mode", ["sync", "batch"])
+def test_save_result_and_download(
+    vector_cube, dummy_backend, tmp_path, output_file, format, expected_format, exec_mode
+):
     """e.g. https://github.com/Open-EO/openeo-geopyspark-driver/issues/477"""
     vector_cube = vector_cube.save_result(format=format)
     output_path = tmp_path / output_file
-    vector_cube.download(output_path)
-    assert download_spy.only_request == {
+    if exec_mode == "sync":
+        vector_cube.download(output_path)
+    elif exec_mode == "batch":
+        vector_cube.execute_batch(outputfile=output_path)
+    else:
+        raise ValueError(exec_mode)
+
+    assert dummy_backend.get_pg() == {
         "createvectorcube1": {"process_id": "create_vector_cube", "arguments": {}},
         "saveresult1": {
             "process_id": "save_result",
             "arguments": {"data": {"from_node": "createvectorcube1"}, "format": expected_format, "options": {}},
             "result": True,
         },
     }
-    assert output_path.read_bytes() == b"Spy data"
+    assert output_path.read_bytes() == b"Result data"

Comments (0)