diff --git a/.dockerignore b/.dockerignore index 5173d0657..bdd66ba7a 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,3 +1,4 @@ +tests docker-compose.yml **/__pycache__ *.pyc diff --git a/docker-compose.yml b/docker-compose.yml index e41968aaa..fc716165c 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -9,9 +9,9 @@ services: dockerfile: Dockerfile platform: linux/amd64 volumes: - - ./stac_fastapi:/app/stac_fastapi - - ./scripts:/app/scripts - - ./tests:/app/tests + - ./stac_fastapi:/opt/src/stac_fastapi + - ./scripts:/opt/src/scripts + - ./tests:/opt/src/tests command: - bash - -c diff --git a/stac_fastapi/api/app.py b/stac_fastapi/api/app.py index d18844e5c..70ec259d1 100644 --- a/stac_fastapi/api/app.py +++ b/stac_fastapi/api/app.py @@ -143,7 +143,7 @@ def register_landing_page(self): response_model_exclude_none=True, methods=["GET"], endpoint=self._create_endpoint( - self.client.landing_page, EmptyRequest, self.response_class + self.client.handle_landing_page, EmptyRequest, self.response_class ), ) @@ -164,7 +164,7 @@ def register_conformance_classes(self): response_model_exclude_none=True, methods=["GET"], endpoint=self._create_endpoint( - self.client.conformance, EmptyRequest, self.response_class + self.client.handle_conformance, EmptyRequest, self.response_class ), ) @@ -183,7 +183,7 @@ def register_get_item(self): response_model_exclude_none=True, methods=["GET"], endpoint=self._create_endpoint( - self.client.get_item, ItemUri, self.response_class + self.client.handle_get_item, ItemUri, self.response_class ), ) @@ -205,7 +205,9 @@ def register_post_search(self): response_model_exclude_none=True, methods=["POST"], endpoint=self._create_endpoint( - self.client.post_search, self.search_post_request_model, GeoJSONResponse + self.client.handle_post_search, + self.search_post_request_model, + GeoJSONResponse, ), ) @@ -227,7 +229,9 @@ def register_get_search(self): response_model_exclude_none=True, methods=["GET"], endpoint=self._create_endpoint( - self.client.get_search, self.search_get_request_model, GeoJSONResponse + self.client.handle_get_search, + self.search_get_request_model, + GeoJSONResponse, ), ) @@ -248,7 +252,7 @@ def register_get_collections(self): response_model_exclude_none=True, methods=["GET"], endpoint=self._create_endpoint( - self.client.all_collections, EmptyRequest, self.response_class + self.client.handle_all_collections, EmptyRequest, self.response_class ), ) @@ -267,7 +271,7 @@ def register_get_collection(self): response_model_exclude_none=True, methods=["GET"], endpoint=self._create_endpoint( - self.client.get_collection, CollectionUri, self.response_class + self.client.handle_get_collection, CollectionUri, self.response_class ), ) @@ -298,7 +302,7 @@ def register_get_item_collection(self): response_model_exclude_none=True, methods=["GET"], endpoint=self._create_endpoint( - self.client.item_collection, request_model, self.response_class + self.client.handle_collection_items, request_model, self.response_class ), ) diff --git a/stac_fastapi/extensions/core/transaction.py b/stac_fastapi/extensions/core/transaction.py index 5967e7128..b93cccd14 100644 --- a/stac_fastapi/extensions/core/transaction.py +++ b/stac_fastapi/extensions/core/transaction.py @@ -91,7 +91,7 @@ def register_create_item(self): response_model_exclude_unset=True, response_model_exclude_none=True, methods=["POST"], - endpoint=self._create_endpoint(self.client.create_item, PostItem), + endpoint=self._create_endpoint(self.client.handle_create_item, PostItem), ) def register_update_item(self): @@ -104,7 +104,7 @@ def register_update_item(self): response_model_exclude_unset=True, response_model_exclude_none=True, methods=["PUT"], - endpoint=self._create_endpoint(self.client.update_item, PutItem), + endpoint=self._create_endpoint(self.client.handle_update_item, PutItem), ) def register_delete_item(self): @@ -117,7 +117,7 @@ def register_delete_item(self): response_model_exclude_unset=True, response_model_exclude_none=True, methods=["DELETE"], - endpoint=self._create_endpoint(self.client.delete_item, ItemUri), + endpoint=self._create_endpoint(self.client.handle_delete_item, ItemUri), ) def register_create_collection(self): @@ -131,7 +131,7 @@ def register_create_collection(self): response_model_exclude_none=True, methods=["POST"], endpoint=self._create_endpoint( - self.client.create_collection, stac_types.Collection + self.client.handle_create_collection, stac_types.Collection ), ) @@ -146,7 +146,7 @@ def register_update_collection(self): response_model_exclude_none=True, methods=["PUT"], endpoint=self._create_endpoint( - self.client.update_collection, stac_types.Collection + self.client.handle_update_collection, stac_types.Collection ), ) @@ -161,7 +161,7 @@ def register_delete_collection(self): response_model_exclude_none=True, methods=["DELETE"], endpoint=self._create_endpoint( - self.client.delete_collection, CollectionUri + self.client.handle_delete_collection, CollectionUri ), ) diff --git a/stac_fastapi/extensions/third_party/bulk_transactions.py b/stac_fastapi/extensions/third_party/bulk_transactions.py index 3fe25c9d1..74e79168e 100644 --- a/stac_fastapi/extensions/third_party/bulk_transactions.py +++ b/stac_fastapi/extensions/third_party/bulk_transactions.py @@ -36,7 +36,7 @@ def _chunks(lst, n): yield lst[i : i + n] @abc.abstractmethod - def bulk_item_insert( + def handle_bulk_item_insert( self, items: Items, chunk_size: Optional[int] = None, **kwargs ) -> str: """Bulk creation of items. @@ -57,7 +57,7 @@ class AsyncBaseBulkTransactionsClient(abc.ABC): """BulkTransactionsClient.""" @abc.abstractmethod - async def bulk_item_insert(self, items: Items, **kwargs) -> str: + async def handle_bulk_item_insert(self, items: Items, **kwargs) -> str: """Bulk creation of items. Args: @@ -125,7 +125,7 @@ def register(self, app: FastAPI) -> None: response_model_exclude_none=True, methods=["POST"], endpoint=self._create_endpoint( - self.client.bulk_item_insert, items_request_model + self.client.handle_bulk_item_insert, items_request_model ), ) app.include_router(router, tags=["Bulk Transaction Extension"]) diff --git a/stac_fastapi/types/core.py b/stac_fastapi/types/core.py index bce7ca2a2..fc932f9b1 100644 --- a/stac_fastapi/types/core.py +++ b/stac_fastapi/types/core.py @@ -11,12 +11,14 @@ from stac_pydantic.version import STAC_VERSION from starlette.responses import Response +from stac_fastapi.api.errors import NotFoundError from stac_fastapi.types import stac as stac_types from stac_fastapi.types.conformance import BASE_CONFORMANCE_CLASSES from stac_fastapi.types.extension import ApiExtension +from stac_fastapi.types.links import CollectionLinks from stac_fastapi.types.requests import get_base_url from stac_fastapi.types.search import BaseSearchPostRequest -from stac_fastapi.types.stac import Conformance +from stac_fastapi.types.stac import Collection, Conformance NumType = Union[float, int] StacType = Dict[str, Any] @@ -27,7 +29,7 @@ class BaseTransactionsClient(abc.ABC): """Defines a pattern for implementing the STAC API Transaction Extension.""" @abc.abstractmethod - def create_item( + def handle_create_item( self, collection_id: str, item: stac_types.Item, **kwargs ) -> Optional[Union[stac_types.Item, Response]]: """Create a new item. @@ -45,7 +47,7 @@ def create_item( ... @abc.abstractmethod - def update_item( + def handle_update_item( self, collection_id: str, item_id: str, item: stac_types.Item, **kwargs ) -> Optional[Union[stac_types.Item, Response]]: """Perform a complete update on an existing item. @@ -64,7 +66,7 @@ def update_item( ... @abc.abstractmethod - def delete_item( + def handle_delete_item( self, item_id: str, collection_id: str, **kwargs ) -> Optional[Union[stac_types.Item, Response]]: """Delete an item from a collection. @@ -81,7 +83,7 @@ def delete_item( ... @abc.abstractmethod - def create_collection( + def handle_create_collection( self, collection: stac_types.Collection, **kwargs ) -> Optional[Union[stac_types.Collection, Response]]: """Create a new collection. @@ -97,7 +99,7 @@ def create_collection( ... @abc.abstractmethod - def update_collection( + def handle_update_collection( self, collection: stac_types.Collection, **kwargs ) -> Optional[Union[stac_types.Collection, Response]]: """Perform a complete update on an existing collection. @@ -116,7 +118,7 @@ def update_collection( ... @abc.abstractmethod - def delete_collection( + def handle_delete_collection( self, collection_id: str, **kwargs ) -> Optional[Union[stac_types.Collection, Response]]: """Delete a collection. @@ -137,7 +139,7 @@ class AsyncBaseTransactionsClient(abc.ABC): """Defines a pattern for implementing the STAC transaction extension.""" @abc.abstractmethod - async def create_item( + async def handle_create_item( self, collection_id: str, item: stac_types.Item, **kwargs ) -> Optional[Union[stac_types.Item, Response]]: """Create a new item. @@ -154,7 +156,7 @@ async def create_item( ... @abc.abstractmethod - async def update_item( + async def handle_update_item( self, collection_id: str, item_id: str, item: stac_types.Item, **kwargs ) -> Optional[Union[stac_types.Item, Response]]: """Perform a complete update on an existing item. @@ -172,7 +174,7 @@ async def update_item( ... @abc.abstractmethod - async def delete_item( + async def handle_delete_item( self, item_id: str, collection_id: str, **kwargs ) -> Optional[Union[stac_types.Item, Response]]: """Delete an item from a collection. @@ -189,7 +191,7 @@ async def delete_item( ... @abc.abstractmethod - async def create_collection( + async def handle_create_collection( self, collection: stac_types.Collection, **kwargs ) -> Optional[Union[stac_types.Collection, Response]]: """Create a new collection. @@ -205,7 +207,7 @@ async def create_collection( ... @abc.abstractmethod - async def update_collection( + async def handle_update_collection( self, collection: stac_types.Collection, **kwargs ) -> Optional[Union[stac_types.Collection, Response]]: """Perform a complete update on an existing collection. @@ -223,7 +225,7 @@ async def update_collection( ... @abc.abstractmethod - async def delete_collection( + async def handle_delete_collection( self, collection_id: str, **kwargs ) -> Optional[Union[stac_types.Collection, Response]]: """Delete a collection. @@ -341,7 +343,18 @@ def list_conformance_classes(self): return base_conformance - def landing_page(self, **kwargs) -> stac_types.LandingPage: + @abc.abstractmethod + def list_all_collections(self) -> List[stac_types.Collection]: + """Return a list of all available collections. + + This method MUST be defined in the backend implementation. + + Returns: + List of STAC Collection-like dictionaries + """ + ... + + def handle_landing_page(self, **kwargs) -> stac_types.LandingPage: """Landing page. Called with `GET /`. @@ -361,8 +374,8 @@ def landing_page(self, **kwargs) -> stac_types.LandingPage: ) # Add Collections links - collections = self.all_collections(request=kwargs["request"]) - for collection in collections["collections"]: + collections = self.list_all_collections() + for collection in collections: landing_page["links"].append( { "rel": Relations.child.value, @@ -398,7 +411,7 @@ def landing_page(self, **kwargs) -> stac_types.LandingPage: return landing_page - def conformance(self, **kwargs) -> stac_types.Conformance: + def handle_conformance(self, **kwargs) -> stac_types.Conformance: """Conformance classes. Called with `GET /conformance`. @@ -409,7 +422,7 @@ def conformance(self, **kwargs) -> stac_types.Conformance: return Conformance(conformsTo=self.conformance_classes()) @abc.abstractmethod - def post_search( + def handle_post_search( self, search_request: BaseSearchPostRequest, **kwargs ) -> stac_types.ItemCollection: """Cross catalog search (POST). @@ -425,7 +438,7 @@ def post_search( ... @abc.abstractmethod - def get_search( + def handle_get_search( self, collections: Optional[List[str]] = None, ids: Optional[List[str]] = None, @@ -448,7 +461,9 @@ def get_search( ... @abc.abstractmethod - def get_item(self, item_id: str, collection_id: str, **kwargs) -> stac_types.Item: + def handle_get_item( + self, item_id: str, collection_id: str, **kwargs + ) -> stac_types.Item: """Get item by id. Called with `GET /collections/{collection_id}/items/{item_id}`. @@ -462,8 +477,7 @@ def get_item(self, item_id: str, collection_id: str, **kwargs) -> stac_types.Ite """ ... - @abc.abstractmethod - def all_collections(self, **kwargs) -> stac_types.Collections: + def handle_all_collections(self, **kwargs) -> stac_types.Collections: """Get all available collections. Called with `GET /collections`. @@ -471,10 +485,45 @@ def all_collections(self, **kwargs) -> stac_types.Collections: Returns: A list of collections. """ - ... + request: Request = kwargs["request"] + base_url = get_base_url(request) + collections = self.list_all_collections() + linked_collections: List[stac_types.Collection] = [] + if collections is not None and len(collections) > 0: + for c in collections: + coll = stac_types.Collection(**c) + coll["links"] = CollectionLinks( + collection_id=coll["id"], request=request + ).get_links(extra_links=coll.get("links")) + + linked_collections.append(coll) + + links = [ + { + "rel": Relations.root.value, + "type": MimeTypes.json, + "href": base_url, + }, + { + "rel": Relations.parent.value, + "type": MimeTypes.json, + "href": base_url, + }, + { + "rel": Relations.self.value, + "type": MimeTypes.json, + "href": urljoin(base_url, "collections"), + }, + ] + collection_list = stac_types.Collections( + collections=linked_collections or [], links=links + ) + return collection_list @abc.abstractmethod - def get_collection(self, collection_id: str, **kwargs) -> stac_types.Collection: + def handle_get_collection( + self, collection_id: str, **kwargs + ) -> stac_types.Collection: """Get collection by id. Called with `GET /collections/{collection_id}`. @@ -488,7 +537,7 @@ def get_collection(self, collection_id: str, **kwargs) -> stac_types.Collection: ... @abc.abstractmethod - def item_collection( + def handle_collection_items( self, collection_id: str, limit: int = 10, token: str = None, **kwargs ) -> stac_types.ItemCollection: """Get all items from a specific collection. @@ -534,7 +583,33 @@ def extension_is_enabled(self, extension: str) -> bool: """Check if an api extension is enabled.""" return any([type(ext).__name__ == extension for ext in self.extensions]) - async def landing_page(self, **kwargs) -> stac_types.LandingPage: + @abc.abstractmethod + async def fetch_all_collections( + self, request: Request + ) -> List[stac_types.Collection]: + """Return a list of all available collections. + + This method MUST be defined in the backend implementation. + + Returns: + List of STAC Collection-like dictionaries + """ + ... + + @abc.abstractclassmethod + async def fetch_collection( + self, collection_id: str, request: Request + ) -> Optional[stac_types.Collection]: + """Return the STAC Collection with the given ID, or `None` if the Collection does not exist. + + This method MUST be defined in the backend implementation. + + Returns: + Dictionary representing the STAC Collection, or `None` if no Collection with the given ID is found + """ + ... + + async def handle_landing_page(self, **kwargs) -> stac_types.LandingPage: """Landing page. Called with `GET /`. @@ -552,7 +627,7 @@ async def landing_page(self, **kwargs) -> stac_types.LandingPage: conformance_classes=self.conformance_classes(), extension_schemas=extension_schemas, ) - collections = await self.all_collections(request=kwargs["request"]) + collections = await self.handle_all_collections(request=kwargs["request"]) for collection in collections["collections"]: landing_page["links"].append( { @@ -589,7 +664,7 @@ async def landing_page(self, **kwargs) -> stac_types.LandingPage: return landing_page - async def conformance(self, **kwargs) -> stac_types.Conformance: + async def handle_conformance(self, **kwargs) -> stac_types.Conformance: """Conformance classes. Called with `GET /conformance`. @@ -600,7 +675,7 @@ async def conformance(self, **kwargs) -> stac_types.Conformance: return Conformance(conformsTo=self.conformance_classes()) @abc.abstractmethod - async def post_search( + async def handle_post_search( self, search_request: BaseSearchPostRequest, **kwargs ) -> stac_types.ItemCollection: """Cross catalog search (POST). @@ -616,7 +691,7 @@ async def post_search( ... @abc.abstractmethod - async def get_search( + async def handle_get_search( self, collections: Optional[List[str]] = None, ids: Optional[List[str]] = None, @@ -639,7 +714,7 @@ async def get_search( ... @abc.abstractmethod - async def get_item( + async def handle_get_item( self, item_id: str, collection_id: str, **kwargs ) -> stac_types.Item: """Get item by id. @@ -655,8 +730,7 @@ async def get_item( """ ... - @abc.abstractmethod - async def all_collections(self, **kwargs) -> stac_types.Collections: + async def handle_all_collections(self, **kwargs) -> stac_types.Collections: """Get all available collections. Called with `GET /collections`. @@ -664,10 +738,42 @@ async def all_collections(self, **kwargs) -> stac_types.Collections: Returns: A list of collections. """ - ... + request: Request = kwargs["request"] + base_url = get_base_url(request) + collections = await self.fetch_all_collections(request) + linked_collections: List[stac_types.Collection] = [] + if collections is not None and len(collections) > 0: + for c in collections: + coll = stac_types.Collection(**c) + coll["links"] = CollectionLinks( + collection_id=coll["id"], base_url=base_url + ).get_links(extra_links=coll.get("links")) + + linked_collections.append(coll) + + links = [ + { + "rel": Relations.root.value, + "type": MimeTypes.json, + "href": base_url, + }, + { + "rel": Relations.parent.value, + "type": MimeTypes.json, + "href": base_url, + }, + { + "rel": Relations.self.value, + "type": MimeTypes.json, + "href": urljoin(base_url, "collections"), + }, + ] + collection_list = stac_types.Collections( + collections=linked_collections or [], links=links + ) + return collection_list - @abc.abstractmethod - async def get_collection( + async def handle_get_collection( self, collection_id: str, **kwargs ) -> stac_types.Collection: """Get collection by id. @@ -680,10 +786,21 @@ async def get_collection( Returns: Collection. """ - ... + request: Request = kwargs["request"] + base_url = get_base_url(request) + + collection = await self.fetch_collection(collection_id, request) + if collection is None: + raise NotFoundError(f"Collection {collection_id} does not exist.") + + collection["links"] = CollectionLinks( + collection_id=collection_id, base_url=base_url + ).get_links(extra_links=collection.get("links")) + + return Collection(**collection) @abc.abstractmethod - async def item_collection( + async def handle_collection_items( self, collection_id: str, limit: int = 10, token: str = None, **kwargs ) -> stac_types.ItemCollection: """Get all items from a specific collection. diff --git a/stac_fastapi/types/links.py b/stac_fastapi/types/links.py index 0349984b1..3aa1225af 100644 --- a/stac_fastapi/types/links.py +++ b/stac_fastapi/types/links.py @@ -1,6 +1,6 @@ """link helpers.""" -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from urllib.parse import urljoin import attr @@ -36,6 +36,10 @@ def root(self) -> Dict[str, Any]: """Return the catalog root.""" return dict(rel=Relations.root, type=MimeTypes.json, href=self.base_url) + def resolve(self, url): + """Resolve url to the current request url.""" + return urljoin(str(self.base_url), str(url)) + @attr.s class CollectionLinks(BaseLinks): @@ -65,6 +69,36 @@ def create_links(self) -> List[Dict[str, Any]]: """Return all inferred links.""" return [self.self(), self.parent(), self.items(), self.root()] + def get_links( + self, extra_links: Optional[List[Dict[str, Any]]] = None + ) -> List[Dict[str, Any]]: + """ + Generate all the links. + + Get the links object for a stac resource by iterating through + available methods on this class that start with link_. + """ + # join passed in links with generated links + # and update relative paths + links = self.create_links() + + if extra_links: + # For extra links passed in, + # add links modified with a resolved href. + # Drop any links that are dynamically + # determined by the server (e.g. self, parent, etc.) + # Resolving the href allows for relative paths + # to be stored in pgstac and for the hrefs in the + # links of response STAC objects to be resolved + # to the request url. + links += [ + {**link, "href": self.resolve(link["href"])} + for link in extra_links + if link["rel"] not in INFERRED_LINK_RELS + ] + + return links + @attr.s class ItemLinks(BaseLinks): diff --git a/tests/api/test_api.py b/tests/api/test_api.py index ab5a304d4..36e6495ba 100644 --- a/tests/api/test_api.py +++ b/tests/api/test_api.py @@ -1,9 +1,12 @@ -from fastapi import Depends, HTTPException, security, status +from typing import List + +from fastapi import Depends, HTTPException, Request, security, status from starlette.testclient import TestClient from stac_fastapi.api.app import StacApi from stac_fastapi.extensions.core import TokenPaginationExtension, TransactionExtension from stac_fastapi.types import config, core +from stac_fastapi.types import stac as stac_types class TestRouteDependencies: @@ -77,44 +80,44 @@ def test_add_route_dependencies_after_building_api(self): class DummyCoreClient(core.BaseCoreClient): - def all_collections(self, *args, **kwargs): + def list_all_collections(self, request: Request) -> List[stac_types.Collection]: ... - def get_collection(self, *args, **kwargs): + def handle_get_collection(self, *args, **kwargs): ... - def get_item(self, *args, **kwargs): + def handle_get_item(self, *args, **kwargs): ... - def get_search(self, *args, **kwargs): + def handle_get_search(self, *args, **kwargs): ... - def post_search(self, *args, **kwargs): + def handle_post_search(self, *args, **kwargs): ... - def item_collection(self, *args, **kwargs): + def handle_collection_items(self, *args, **kwargs): ... class DummyTransactionsClient(core.BaseTransactionsClient): """Defines a pattern for implementing the STAC transaction extension.""" - def create_item(self, *args, **kwargs): + def handle_create_item(self, *args, **kwargs): return "dummy response" - def update_item(self, *args, **kwargs): + def handle_update_item(self, *args, **kwargs): return "dummy response" - def delete_item(self, *args, **kwargs): + def handle_delete_item(self, *args, **kwargs): return "dummy response" - def create_collection(self, *args, **kwargs): + def handle_create_collection(self, *args, **kwargs): return "dummy response" - def update_collection(self, *args, **kwargs): + def handle_update_collection(self, *args, **kwargs): return "dummy response" - def delete_collection(self, *args, **kwargs): + def handle_delete_collection(self, *args, **kwargs): return "dummy response"