From 327e7467a3f680063d2a1efaeee8979ad1b1c6e1 Mon Sep 17 00:00:00 2001 From: Pete Gadomski Date: Mon, 16 Jun 2025 10:16:06 -0600 Subject: [PATCH 1/2] feat: add iter_search --- Cargo.lock | 2 + Cargo.toml | 2 + python/rustac/rustac.pyi | 66 ++++++++++++++++++++++++++- src/lib.rs | 1 + src/search.rs | 98 ++++++++++++++++++++++++++++++++++++++-- tests/test_search.py | 9 ++++ 6 files changed, 174 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6877702..8ea72fc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3506,6 +3506,8 @@ dependencies = [ "cargo-lock", "clap", "duckdb", + "futures-core", + "futures-util", "geoarrow-array 0.1.0-dev", "geojson", "parquet", diff --git a/Cargo.toml b/Cargo.toml index 82917e6..ee0a3bb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,6 +44,8 @@ pyo3-log = "0.12.1" tracing = "0.1.41" pyo3-object_store = "0.2.0" parquet = "55.1.0" +futures-core = "0.3.31" +futures-util = "0.3.31" [build-dependencies] cargo-lock = "10" diff --git a/python/rustac/rustac.pyi b/python/rustac/rustac.pyi index 65f5fb3..670f10b 100644 --- a/python/rustac/rustac.pyi +++ b/python/rustac/rustac.pyi @@ -284,7 +284,7 @@ async def search( **kwargs: str, ) -> list[dict[str, Any]]: """ - Searches a STAC API server. + Searches a STAC API server or a stac-geoparquet file. Args: href: The STAC API to search. @@ -333,6 +333,70 @@ async def search( ... 
) """ +async def iter_search( + href: str, + *, + intersects: str | dict[str, Any] | None = None, + ids: str | list[str] | None = None, + collections: str | list[str] | None = None, + max_items: int | None = None, + limit: int | None = None, + bbox: list[float] | None = None, + datetime: str | None = None, + include: str | list[str] | None = None, + exclude: str | list[str] | None = None, + sortby: str | list[str | dict[str, str]] | None = None, + filter: str | dict[str, Any] | None = None, + query: dict[str, Any] | None = None, + use_duckdb: bool | None = None, + **kwargs: str, +) -> AsyncIterator[dict[str, Any]]: + """ + Searches a STAC API server and iterates over its items. + + Args: + href: The STAC API to search. + intersects: Searches items + by performing intersection between their geometry and provided GeoJSON + geometry. + ids: Array of Item ids to return. + collections: Array of one or more Collection IDs that + each matching Item must be in. + limit: The page size returned from the server. + bbox: Requested bounding box. + datetime: Single date+time, or a range (`/` separator), + formatted to RFC 3339, section 5.6. Use double dots .. for open + date ranges. + include: fields to include in the response (see [the + extension + docs](https://github.com/stac-api-extensions/fields?tab=readme-ov-file#includeexclude-semantics) + for more on the semantics). + exclude: fields to exclude from the response (see [the + extension + docs](https://github.com/stac-api-extensions/fields?tab=readme-ov-file#includeexclude-semantics) + for more on the semantics). + sortby: Fields by which to sort results (use `-field` to sort descending). + filter: CQL2 filter expression. Strings + will be interpreted as cql2-text, dictionaries as cql2-json. + query: Additional filtering based on properties. + It is recommended to use filter instead, if possible. + kwargs: Additional parameters to pass in to the search. 
+ + Returns: + An iterator over STAC items + + Examples: + >>> search = await rustac.iter_search( + ... "https://landsatlook.usgs.gov/stac-server", + ... collections=["landsat-c2l2-sr"], + ... intersects={"type": "Point", "coordinates": [-105.119, 40.173]}, + ... sortby="-properties.datetime", + ... ) + >>> async for item in search: + ... items.append(item) + ... + """ + async def search_to( outfile: str, href: str, diff --git a/src/lib.rs b/src/lib.rs index 4b73d1e..84b8ba6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -34,6 +34,7 @@ fn rustac(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { )?)?; m.add_function(wrap_pyfunction!(migrate::migrate, m)?)?; m.add_function(wrap_pyfunction!(read::read, m)?)?; + m.add_function(wrap_pyfunction!(search::iter_search, m)?)?; m.add_function(wrap_pyfunction!(search::search, m)?)?; m.add_function(wrap_pyfunction!(search::search_to, m)?)?; m.add_function(wrap_pyfunction!(version::sha, m)?)?; diff --git a/src/search.rs b/src/search.rs index a8f0723..db91630 100644 --- a/src/search.rs +++ b/src/search.rs @@ -1,11 +1,79 @@ use crate::{Error, Json, Result}; +use futures_core::Stream; +use futures_core::stream::BoxStream; +use futures_util::StreamExt; use geojson::Geometry; use pyo3::prelude::*; use pyo3::{Bound, FromPyObject, PyErr, PyResult, exceptions::PyValueError, types::PyDict}; use pyo3_object_store::AnyObjectStore; +use serde_json::{Map, Value}; use stac::Bbox; -use stac_api::{Fields, Filter, Items, Search, Sortby}; +use stac_api::{Client, Fields, Filter, Items, Search, Sortby}; use stac_io::{Format, StacStore}; +use std::sync::Arc; +use tokio::{pin, sync::Mutex}; + +#[pyclass] +struct SearchIterator(Arc>>>>); + +#[pymethods] +impl SearchIterator { + fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + fn __anext__<'py>(&self, py: Python<'py>) -> PyResult> { + let stream = self.0.clone(); + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let mut stream = stream.lock().await; + if let 
Some(result) = stream.next().await { + let item = result.map_err(Error::from)?; + Ok(Some(Json(item))) + } else { + Ok(None) + } + }) + } +} + +#[pyfunction] +#[pyo3(signature = (href, *, intersects=None, ids=None, collections=None, limit=None, bbox=None, datetime=None, include=None, exclude=None, sortby=None, filter=None, query=None, **kwargs))] +#[allow(clippy::too_many_arguments)] +pub fn iter_search<'py>( + py: Python<'py>, + href: String, + intersects: Option, + ids: Option, + collections: Option, + limit: Option, + bbox: Option>, + datetime: Option, + include: Option, + exclude: Option, + sortby: Option>, + filter: Option, + query: Option>, + kwargs: Option>, +) -> PyResult> { + let search = build( + intersects, + ids, + collections, + limit, + bbox, + datetime, + include, + exclude, + sortby, + filter, + query, + kwargs, + )?; + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let stream = iter_search_api(href, search).await?; + Ok(SearchIterator(Arc::new(Mutex::new(Box::pin(stream))))) + }) +} #[pyfunction] #[pyo3(signature = (href, *, intersects=None, ids=None, collections=None, max_items=None, limit=None, bbox=None, datetime=None, include=None, exclude=None, sortby=None, filter=None, query=None, use_duckdb=None, **kwargs))] @@ -165,8 +233,32 @@ async fn search_api( search: Search, max_items: Option, ) -> Result { - let value = stac_api::client::search(&href, search, max_items).await?; - Ok(value) + let stream = iter_search_api(href, search).await?; + pin!(stream); + let mut items = if let Some(max_items) = max_items { + Vec::with_capacity(max_items) + } else { + Vec::new() + }; + while let Some(result) = stream.next().await { + let item = result?; + items.push(item); + if let Some(max_items) = max_items { + if items.len() >= max_items { + break; + } + } + } + Ok(items.into()) +} + +async fn iter_search_api( + href: String, + search: Search, +) -> Result>>> { + let client = Client::new(&href)?; + let stream = client.search(search).await?; + 
Ok(stream) } /// Creates a [Search] from Python arguments. diff --git a/tests/test_search.py b/tests/test_search.py index 665129c..62acd2b 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -113,3 +113,12 @@ async def test_cql(data: Path) -> None: }, max_items=1, ) + + +async def test_iter_search() -> None: + items = [] + search = await rustac.iter_search("https://landsatlook.usgs.gov/stac-server") + async for item in search: + items.append(item) + if len(items) >= 10: + break From 4e54a238f31b6f2b1f0e2095d733d80ec6c51978 Mon Sep 17 00:00:00 2001 From: Pete Gadomski Date: Mon, 16 Jun 2025 10:17:13 -0600 Subject: [PATCH 2/2] chore: update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ecbb5a9..0d40e9f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - `type` field to geoparquet writes ([#136](https://github.com/stac-utils/rustac-py/pull/136), ) - `parquet_compression` argument to `write` and `search_to` ([#150](https://github.com/stac-utils/rustac-py/pull/150)) +- `iter_search` ([#151](https://github.com/stac-utils/rustac-py/pull/151)) ### Fixed