Skip to content

Commit d9d9f90

Browse files
committed
fixup! fixup! fixup! fixup! fixup! fixup! Issue #114/#141 convert inline GeoJSON in aggregate_spatial to VectorCube
1 parent 842b3d7 commit d9d9f90

File tree

6 files changed

+131
-36
lines changed

6 files changed

+131
-36
lines changed

openeo_driver/ProcessGraphDeserializer.py

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -740,32 +740,24 @@ def fit_class_random_forest(args: dict, env: EvalEnv) -> DriverMlModel:
740740
return DriverMlModel()
741741

742742
predictors = extract_arg(args, 'predictors')
743-
if not isinstance(predictors, AggregatePolygonSpatialResult):
744-
# TODO #114 EP-3981 add support for real vector cubes.
743+
if not isinstance(predictors, (AggregatePolygonSpatialResult, DriverVectorCube)):
744+
# TODO #114 EP-3981 drop AggregatePolygonSpatialResult support.
745745
raise ProcessParameterInvalidException(
746-
parameter="predictors", process="fit_class_random_forest",
747-
reason=f"should be non-temporal vector-cube (got `{type(predictors)}`)."
748-
)
749-
target = extract_arg(args, 'target')
750-
if not (
751-
isinstance(target, dict)
752-
and target.get("type") == "FeatureCollection"
753-
and isinstance(target.get("features"), list)
754-
):
755-
# TODO #114 EP-3981 vector cube support
756-
raise ProcessParameterInvalidException(
757-
parameter="target", process="fit_class_random_forest",
758-
reason='only GeoJSON FeatureCollection is currently supported.',
746+
parameter="predictors",
747+
process="fit_class_random_forest",
748+
reason=f"should be non-temporal vector-cube, got {type(predictors)}.",
759749
)
760-
if any(
761-
# TODO: allow string based target labels too?
762-
not isinstance(deep_get(f, "properties", "target", default=None), int)
763-
for f in target.get("features", [])
764-
):
750+
751+
target = extract_arg(args, "target")
752+
if isinstance(target, DriverVectorCube):
753+
pass
754+
elif isinstance(target, dict) and target.get("type") == "FeatureCollection":
755+
target = env.backend_implementation.DriverVectorCube.from_geojson(target)
756+
else:
765757
raise ProcessParameterInvalidException(
766758
parameter="target",
767759
process="fit_class_random_forest",
768-
reason="Each feature (from target feature collection) should have an integer 'target' property.",
760+
reason=f"expected vector-cube like value but got {type(target)}.",
769761
)
770762

771763
# TODO: get defaults from process spec?
@@ -1056,7 +1048,7 @@ def aggregate_spatial(args: dict, env: EvalEnv) -> DriverDataCube:
10561048
if isinstance(geoms, DriverVectorCube):
10571049
geoms = geoms
10581050
elif isinstance(geoms, dict):
1059-
geoms = DriverVectorCube.from_geojson(geoms)
1051+
geoms = env.backend_implementation.DriverVectorCube.from_geojson(geoms)
10601052
elif isinstance(geoms, DelayedVector):
10611053
geoms = geoms.path
10621054
else:
@@ -1559,7 +1551,7 @@ def to_vector_cube(args: Dict, env: EvalEnv):
15591551
# TODO: standardization of something like this? https://github.com/Open-EO/openeo-processes/issues/346
15601552
data = extract_arg(args, "data", process_id="to_vector_cube")
15611553
if isinstance(data, dict) and data.get("type") in {"Polygon", "MultiPolygon", "Feature", "FeatureCollection"}:
1562-
return DriverVectorCube.from_geojson(data)
1554+
return env.backend_implementation.DriverVectorCube.from_geojson(data)
15631555
# TODO: support more inputs: string with geojson, string with WKT, list of WKT, string with URL to GeoJSON, ...
15641556
raise FeatureUnsupportedException(f"Converting {type(data)} to vector cube is not supported")
15651557

openeo_driver/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.20.4a1'
1+
__version__ = "0.21.0a1"

openeo_driver/backend.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from openeo.capabilities import ComparableVersion
2424
from openeo.internal.process_graph_visitor import ProcessGraphVisitor
2525
from openeo.util import rfc3339, dict_no_none
26-
from openeo_driver.datacube import DriverDataCube, DriverMlModel
26+
from openeo_driver.datacube import DriverDataCube, DriverMlModel, DriverVectorCube
2727
from openeo_driver.datastructs import SarBackscatterArgs
2828
from openeo_driver.dry_run import SourceConstraint
2929
from openeo_driver.errors import CollectionNotFoundException, ServiceUnsupportedException, FeatureUnsupportedException
@@ -570,6 +570,9 @@ class OpenEoBackendImplementation:
570570
enable_basic_auth = True
571571
enable_oidc_auth = True
572572

573+
# Overridable vector cube implementation
574+
DriverVectorCube = DriverVectorCube
575+
573576
def __init__(
574577
self,
575578
secondary_services: Optional[SecondaryServices] = None,

openeo_driver/datacube.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,9 @@ def __init__(
177177
def with_cube(self, cube: xarray.DataArray, flatten_prefix: str = FLATTEN_PREFIX) -> "DriverVectorCube":
178178
"""Create new vector cube with same geometries but new cube"""
179179
log.info(f"Creating vector cube with new cube {cube.name!r}")
180-
return DriverVectorCube(geometries=self._geometries, cube=cube, flatten_prefix=flatten_prefix)
180+
return type(self)(
181+
geometries=self._geometries, cube=cube, flatten_prefix=flatten_prefix
182+
)
181183

182184
@classmethod
183185
def from_fiona(cls, paths: List[str], driver: str, options: dict) -> "DriverVectorCube":
@@ -333,6 +335,15 @@ def __eq__(self, other):
333335
return (isinstance(other, DriverVectorCube)
334336
and np.array_equal(self._as_geopandas_df().values, other._as_geopandas_df().values))
335337

338+
def fit_class_random_forest(
    self,
    target: "DriverVectorCube",
    num_trees: int = 100,
    max_variables: Optional[Union[int, str]] = None,
    seed: Optional[int] = None,
) -> "DriverMlModel":
    """
    Train a random forest classification model with this vector cube as predictors.

    This base implementation is only a placeholder: it always raises,
    so a backend must provide a vector cube subclass that overrides it.

    :param target: vector cube with the training labels
        (NOTE(review): presumably one label per geometry -- confirm against the process spec)
    :param num_trees: number of trees to build (defaults to 100)
    :param max_variables: limit on the variables considered per split
        (NOTE(review): accepted string values are not visible here -- confirm against the process spec)
    :param seed: optional seed for the random generator
    :return: the trained model (never reached in this base implementation)
    :raises NotImplementedError: always, naming the concrete class that lacks an override
    """
    # Name the concrete class so a missing backend override is easy to spot in logs.
    raise NotImplementedError(
        f"{type(self).__name__} does not implement fit_class_random_forest"
    )
346+
336347

337348
class DriverMlModel:
338349
"""Base class for driver-side 'ml-model' data structures"""

openeo_driver/dummy/dummy_backend.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,25 @@ def fit_class_random_forest(
284284
)
285285

286286

287+
class DummyVectorCube(DriverVectorCube):
    """Vector cube variant for the dummy backend."""

    def fit_class_random_forest(
        self,
        target: DriverVectorCube,
        num_trees: int = 100,
        max_variables: Optional[Union[int, str]] = None,
        seed: Optional[int] = None,
    ) -> "DriverMlModel":
        """
        Fake model training: no fitting happens here, the call arguments are
        just handed to a `DummyMlModel`, with both vector cubes converted to GeoJSON.

        :param target: vector cube with the training labels
        :param num_trees: passed through unchanged
        :param max_variables: passed through unchanged
        :param seed: passed through unchanged
        :return: a `DummyMlModel` constructed from the arguments of this call
        """
        # TODO: handle `to_geojson` in `DummyMlModel.write_assets` instead of here?
        creation_kwargs = dict(
            process_id="fit_class_random_forest",
            data=self.to_geojson(),
            target=target.to_geojson(),
            num_trees=num_trees,
            max_variables=max_variables,
            seed=seed,
        )
        return DummyMlModel(**creation_kwargs)
304+
305+
287306
class DummyMlModel(DriverMlModel):
288307

289308
def __init__(self, **kwargs):
@@ -608,6 +627,9 @@ def delete(self, user_id: str, process_id: str) -> None:
608627

609628

610629
class DummyBackendImplementation(OpenEoBackendImplementation):
630+
631+
DriverVectorCube = DummyVectorCube
632+
611633
def __init__(self, processing: Optional[Processing] = None):
612634
super(DummyBackendImplementation, self).__init__(
613635
secondary_services=DummySecondaryServices(),

tests/test_views_execute.py

Lines changed: 77 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2699,17 +2699,84 @@ def test_chunk_polygon(api100):
26992699

27002700
def test_fit_class_random_forest(api100):
27012701
res = api100.check_result("fit_class_random_forest.json")
2702-
assert res.json == DictSubSet({
2703-
"type": "DummyMlModel",
2704-
"creation_data": {
2705-
"process_id": "fit_class_random_forest",
2706-
"data": [[100.0, 100.1, 100.2, 100.3], [101.0, 101.1, 101.2, 101.3]],
2707-
"target": DictSubSet({"type": "FeatureCollection"}),
2708-
"num_trees": 200,
2709-
"max_variables": None,
2710-
"seed": None,
2702+
2703+
geom1 = {
2704+
"type": "Polygon",
2705+
"coordinates": [[[3.0, 5.0], [4.0, 5.0], [4.0, 6.0], [3.0, 6.0], [3.0, 5.0]]],
2706+
}
2707+
geom2 = {
2708+
"type": "Polygon",
2709+
"coordinates": [[[8.0, 1.0], [9.0, 1.0], [9.0, 2.0], [8.0, 2.0], [8.0, 1.0]]],
2710+
}
2711+
assert res.json == DictSubSet(
2712+
{
2713+
"type": "DummyMlModel",
2714+
"creation_data": {
2715+
"process_id": "fit_class_random_forest",
2716+
"data": DictSubSet(
2717+
{
2718+
"type": "FeatureCollection",
2719+
"features": [
2720+
DictSubSet(
2721+
{
2722+
"type": "Feature",
2723+
"id": "0",
2724+
"geometry": geom1,
2725+
"properties": {
2726+
"agg~B02": 2.345,
2727+
"agg~B03": None,
2728+
"agg~B04": 2.0,
2729+
"agg~B08": 3.0,
2730+
"target": 0,
2731+
},
2732+
}
2733+
),
2734+
DictSubSet(
2735+
{
2736+
"type": "Feature",
2737+
"id": "1",
2738+
"geometry": geom2,
2739+
"properties": {
2740+
"agg~B02": 4.0,
2741+
"agg~B03": 5.0,
2742+
"agg~B04": 6.0,
2743+
"agg~B08": 7.0,
2744+
"target": 1,
2745+
},
2746+
}
2747+
),
2748+
],
2749+
}
2750+
),
2751+
"target": DictSubSet(
2752+
{
2753+
"type": "FeatureCollection",
2754+
"features": [
2755+
DictSubSet(
2756+
{
2757+
"type": "Feature",
2758+
"id": "0",
2759+
"geometry": geom1,
2760+
"properties": {"target": 0},
2761+
}
2762+
),
2763+
DictSubSet(
2764+
{
2765+
"type": "Feature",
2766+
"id": "1",
2767+
"geometry": geom2,
2768+
"properties": {"target": 1},
2769+
}
2770+
),
2771+
],
2772+
}
2773+
),
2774+
"max_variables": None,
2775+
"num_trees": 200,
2776+
"seed": None,
2777+
},
27112778
}
2712-
})
2779+
)
27132780

27142781

27152782
def test_if_merge_cubes(api100):

0 commit comments

Comments
 (0)