From e909b04c78c3015365092ebfe0e590a86b0e92a9 Mon Sep 17 00:00:00 2001 From: Tom Christian Date: Wed, 22 Jun 2022 12:09:40 -0700 Subject: [PATCH 1/7] Respect forwarding headers in pgstac --- CHANGES.md | 1 + .../pgstac/stac_fastapi/pgstac/core.py | 4 +- .../stac_fastapi/pgstac/models/links.py | 51 ++++++++++++- stac_fastapi/pgstac/tests/api/test_api.py | 71 +++++++++++++++++++ .../pgstac/tests/resources/test_collection.py | 65 +++++++++++++++++ .../pgstac/tests/resources/test_item.py | 51 ++++++++++++- 6 files changed, 238 insertions(+), 5 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index f28d116c6..8594c6a8f 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -8,6 +8,7 @@ * Ability to POST an ItemCollection to the collections/{collectionId}/items route. ([#367](https://github.com/stac-utils/stac-fastapi/pull/367)) * Add STAC API - Collections conformance class. ([383](https://github.com/stac-utils/stac-fastapi/pull/383)) * Bulk item inserts for pgstac implementation. ([411](https://github.com/stac-utils/stac-fastapi/pull/411)) +* Respect `Forwarded` or `X-Forwarded-*` request headers when building links to better accommodate load balancers and proxies. ### Changed diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py index 811afbba6..e110759d2 100644 --- a/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py @@ -18,7 +18,7 @@ from starlette.requests import Request from stac_fastapi.pgstac.config import Settings -from stac_fastapi.pgstac.models.links import CollectionLinks, ItemLinks, PagingLinks +from stac_fastapi.pgstac.models.links import CollectionLinks, ItemLinks, PagingLinks, get_base_url_from_request from stac_fastapi.pgstac.types.search import PgstacSearch from stac_fastapi.pgstac.utils import filter_fields from stac_fastapi.types.core import AsyncBaseCoreClient @@ -35,7 +35,7 @@ class CoreCrudClient(AsyncBaseCoreClient): async def all_collections(self, **kwargs) -> Collections: """Read all collections from the database.""" request: Request = kwargs["request"] - base_url = str(request.base_url) + base_url = get_base_url_from_request(request) pool = request.app.state.readpool async with pool.acquire() as conn: diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/models/links.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/models/links.py index 4816c0969..c0e57b51c 100644 --- a/stac_fastapi/pgstac/stac_fastapi/pgstac/models/links.py +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/models/links.py @@ -1,5 +1,7 @@ """link helpers.""" +import re +from http.client import HTTP_PORT, HTTPS_PORT from typing import Any, Dict, List, Optional from urllib.parse import ParseResult, parse_qs, unquote, urlencode, urljoin, urlparse @@ -36,6 +38,51 @@ def merge_params(url: str, newparams: Dict) -> str: return href +def get_base_url_from_request(request: Request) -> str: + """ + Account for forwarding headers when deriving base URL. + Prioritise standard Forwarded header, look for non-standard X-Forwarded-* if missing. + Default to what can be derived from the URL if no headers provided. + """ + domain = request.url.hostname + proto = request.url.scheme + port_str = str(request.url.port) if request.url.port is not None else None + forwarded = request.headers.get("forwarded") + if forwarded is not None: + parts = forwarded.split(";") + for part in parts: + if len(part) > 0 and re.search("=", part): + key, value = part.split("=") + if key == "proto": + proto = value + elif key == "host": + host_parts = value.split(":") + domain = host_parts[0] + port_str = host_parts[1] if len(host_parts) == 2 else None + else: + proto = request.headers.get("x-forwarded-proto", proto) + port_str = request.headers.get("x-forwarded-port", port_str) + port_suffix = "" + if port_str is not None and port_str.isdigit(): + if (proto == "http" and port_str == str(HTTP_PORT)) or ( + proto == "https" and port_str == str(HTTPS_PORT) + ): + pass + else: + port_suffix = f":{port_str}" + # ensure url ends with slash + url = re.sub( + r"([^/])$", + r"\1/", + urljoin( + f"{proto}://{domain}{port_suffix}", + # ensure root path starts with slash + re.sub(r"^([^/])", r"/\1", request.scope.get("root_path")), + ), + ) + return url + + @attr.s class BaseLinks: """Create inferred links common to collections and items.""" @@ -45,12 +92,12 @@ class BaseLinks: @property def base_url(self): """Get the base url.""" - return str(self.request.base_url) + return get_base_url_from_request(self.request) @property def url(self): """Get the current request url.""" - return str(self.request.url) + return str(self.request.url).replace(str(self.request.base_url), self.base_url) def resolve(self, url): """Resolve url to the current request url.""" diff --git a/stac_fastapi/pgstac/tests/api/test_api.py b/stac_fastapi/pgstac/tests/api/test_api.py index f4d783b11..b805a9a37 100644 --- a/stac_fastapi/pgstac/tests/api/test_api.py +++ b/stac_fastapi/pgstac/tests/api/test_api.py @@ -1,4 +1,5 @@ from datetime import datetime, timedelta +import pytest STAC_CORE_ROUTES = [ "GET /", @@ -281,3 +282,73 @@ async def test_search_line_string_intersects( assert resp.status_code == 200 resp_json = resp.json() assert len(resp_json["features"]) == 1 + + +@pytest.mark.asyncio +async def test_search_forwarded_header( + load_test_data, app_client, load_test_collection +): + coll = load_test_collection + item = load_test_data("test_item.json") + await app_client.post(f"/collections/{coll.id}/items", json=item) + resp = await app_client.post( + "/search", + json={ + "collections": [item["collection"]], + }, + headers={"Forwarded": "proto=https;host=test:1234"}, + ) + features = resp.json()["features"] + assert len(features) > 0 + for feature in features: + for link in feature["links"]: + assert link["href"].startswith("https://test:1234/") + + +@pytest.mark.asyncio +async def test_search_x_forwarded_headers( + load_test_data, app_client, load_test_collection +): + coll = load_test_collection + item = load_test_data("test_item.json") + await app_client.post(f"/collections/{coll.id}/items", json=item) + resp = await app_client.post( + "/search", + json={ + "collections": [item["collection"]], + }, + headers={ + "X-Forwarded-Proto": "https", + "X-Forwarded-Port": "1234", + }, + ) + features = resp.json()["features"] + assert len(features) > 0 + for feature in features: + for link in feature["links"]: + assert link["href"].startswith("https://test:1234/") + + +@pytest.mark.asyncio +async def test_search_duplicate_forward_headers( + load_test_data, app_client, load_test_collection +): + coll = load_test_collection + item = load_test_data("test_item.json") + await app_client.post(f"/collections/{coll.id}/items", json=item) + resp = await app_client.post( + "/search", + json={ + "collections": [item["collection"]], + }, + headers={ + "Forwarded": "proto=https;host=test:1234", + "X-Forwarded-Proto": "http", + "X-Forwarded-Port": "4321", + }, + ) + features = resp.json()["features"] + assert len(features) > 0 + for feature in features: + for link in feature["links"]: + assert link["href"].startswith("https://test:1234/") diff --git a/stac_fastapi/pgstac/tests/resources/test_collection.py b/stac_fastapi/pgstac/tests/resources/test_collection.py index 937803db4..b3a4d5509 100644 --- a/stac_fastapi/pgstac/tests/resources/test_collection.py +++ b/stac_fastapi/pgstac/tests/resources/test_collection.py @@ -1,6 +1,7 @@ from typing import Callable import pystac +import pytest from stac_pydantic import Collection @@ -164,3 +165,67 @@ async def test_returns_license_link(app_client, load_test_collection): resp_json = resp.json() link_rel_types = [link["rel"] for link in resp_json["links"]] assert "license" in link_rel_types + + +@pytest.mark.asyncio +async def test_get_collection_forwarded_header(app_client, load_test_collection): + coll = load_test_collection + resp = await app_client.get( + f"/collections/{coll.id}", + headers={"Forwarded": "proto=https;host=test:1234"}, + ) + for link in [ + link + for link in resp.json()["links"] + if link["rel"] in ["items", "parent", "root", "self"] + ]: + assert link["href"].startswith("https://test:1234/") + + +@pytest.mark.asyncio +async def test_get_collection_x_forwarded_headers(app_client, load_test_collection): + coll = load_test_collection + resp = await app_client.get( + f"/collections/{coll.id}", + headers={ + "X-Forwarded-Port": "1234", + "X-Forwarded-Proto": "https", + }, + ) + for link in [ + link + for link in resp.json()["links"] + if link["rel"] in ["items", "parent", "root", "self"] + ]: + assert link["href"].startswith("https://test:1234/") + + +@pytest.mark.asyncio +async def test_get_collection_duplicate_forwarded_headers( + app_client, load_test_collection +): + coll = load_test_collection + resp = await app_client.get( + f"/collections/{coll.id}", + headers={ + "Forwarded": "proto=https;host=test:1234", + "X-Forwarded-Port": "4321", + "X-Forwarded-Proto": "http", + }, + ) + for link in [ + link + for link in resp.json()["links"] + if link["rel"] in ["items", "parent", "root", "self"] + ]: + assert link["href"].startswith("https://test:1234/") + + +@pytest.mark.asyncio +async def test_get_collections_forwarded_header(app_client, load_test_collection): + resp = await app_client.get( + "/collections", + headers={"Forwarded": "proto=https;host=test:1234"}, + ) + for link in resp.json()["links"]: + assert link["href"].startswith("https://test:1234/") \ No newline at end of file diff --git a/stac_fastapi/pgstac/tests/resources/test_item.py b/stac_fastapi/pgstac/tests/resources/test_item.py index 40b4b514a..4e3c968f6 100644 --- a/stac_fastapi/pgstac/tests/resources/test_item.py +++ b/stac_fastapi/pgstac/tests/resources/test_item.py @@ -1,10 +1,12 @@ import json import uuid from datetime import timedelta +from http.client import HTTP_PORT from typing import Callable from urllib.parse import parse_qs, urljoin, urlparse import pystac +import pytest from httpx import AsyncClient from pystac.utils import datetime_to_str from shapely.geometry import Polygon @@ -1170,11 +1172,12 @@ async def test_relative_link_construction(): "type": "http", "scheme": "http", "method": "PUT", - "root_path": "http://test/stac", + "root_path": "/stac", # root_path should not have proto, domain, or port "path": "/", "raw_path": b"/tab/abc", "query_string": b"", "headers": {}, + "server": ("test", HTTP_PORT), } ) links = CollectionLinks(collection_id="naip", request=req) @@ -1394,3 +1397,49 @@ async def test_item_merge_raster_bands( assert len(red_bands[0].keys()) == 6 # The merged item should have kept the item value rather than the base value assert red_bands[0]["offset"] == 2.03976 + + +@pytest.mark.asyncio +async def test_get_collection_items_forwarded_header( + app_client, load_test_collection, load_test_item +): + coll = load_test_collection + resp = await app_client.get( + f"/collections/{coll.id}/items", + headers={"Forwarded": "proto=https;host=test:1234"}, + ) + for link in resp.json()["features"][0]["links"]: + assert link["href"].startswith("https://test:1234/") + + +@pytest.mark.asyncio +async def test_get_collection_items_x_forwarded_headers( + app_client, load_test_collection, load_test_item +): + coll = load_test_collection + resp = await app_client.get( + f"/collections/{coll.id}/items", + headers={ + "X-Forwarded-Port": "1234", + "X-Forwarded-Proto": "https", + }, + ) + for link in resp.json()["features"][0]["links"]: + assert link["href"].startswith("https://test:1234/") + + +@pytest.mark.asyncio +async def test_get_collection_items_duplicate_forwarded_headers( + app_client, load_test_collection, load_test_item +): + coll = load_test_collection + resp = await app_client.get( + f"/collections/{coll.id}/items", + headers={ + "Forwarded": "proto=https;host=test:1234", + "X-Forwarded-Port": "4321", + "X-Forwarded-Proto": "http", + }, + ) + for link in resp.json()["features"][0]["links"]: + assert link["href"].startswith("https://test:1234/") \ No newline at end of file From 95c278149375a6b2a99640efe5a6cd25965fe3ff Mon Sep 17 00:00:00 2001 From: Tom Christian Date: Wed, 22 Jun 2022 12:26:48 -0700 Subject: [PATCH 2/7] Respect forwarding headers in sqlalchemy --- .../stac_fastapi/sqlalchemy/core.py | 19 ++++--- .../stac_fastapi/sqlalchemy/links.py | 57 +++++++++++++++++++ .../stac_fastapi/sqlalchemy/transactions.py | 13 +++-- stac_fastapi/sqlalchemy/tests/api/test_api.py | 55 ++++++++++++++++++ .../tests/resources/test_collection.py | 43 ++++++++++++++ .../sqlalchemy/tests/resources/test_item.py | 40 +++++++++++++ 6 files changed, 212 insertions(+), 15 deletions(-) create mode 100644 stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/links.py diff --git a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/core.py b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/core.py index cd1ca9eea..92ab35474 100644 --- a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/core.py +++ b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/core.py @@ -22,6 +22,7 @@ from stac_fastapi.sqlalchemy import serializers from stac_fastapi.sqlalchemy.extensions.query import Operator +from stac_fastapi.sqlalchemy.links import get_base_url_from_request from stac_fastapi.sqlalchemy.models import database from stac_fastapi.sqlalchemy.session import Session from stac_fastapi.sqlalchemy.tokens import PaginationTokenClient @@ -62,7 +63,7 @@ def _lookup_id( def all_collections(self, **kwargs) -> Collections: """Read all collections from the database.""" - base_url = str(kwargs["request"].base_url) + base_url = get_base_url_from_request(kwargs["request"]) with self.session.reader.context_session() as session: collections = session.query(self.collection_table).all() serialized_collections = [ @@ -93,7 +94,7 @@ def all_collections(self, **kwargs) -> Collections: def get_collection(self, collection_id: str, **kwargs) -> Collection: """Get collection by id.""" - base_url = str(kwargs["request"].base_url) + base_url = get_base_url_from_request(kwargs["request"]) with self.session.reader.context_session() as session: collection = self._lookup_id(collection_id, self.collection_table, session) return self.collection_serializer.db_to_stac(collection, base_url) @@ -102,7 +103,7 @@ def item_collection( self, collection_id: str, limit: int = 10, token: str = None, **kwargs ) -> ItemCollection: """Read an item collection from the database.""" - base_url = str(kwargs["request"].base_url) + base_url = get_base_url_from_request(kwargs["request"]) with self.session.reader.context_session() as session: collection_children = ( session.query(self.item_table) @@ -136,7 +137,7 @@ def item_collection( { "rel": Relations.next.value, "type": "application/geo+json", - "href": f"{kwargs['request'].base_url}collections/{collection_id}/items?token={page.next}&limit={limit}", + "href": f"{base_url}collections/{collection_id}/items?token={page.next}&limit={limit}", "method": "GET", } ) @@ -145,7 +146,7 @@ def item_collection( { "rel": Relations.previous.value, "type": "application/geo+json", - "href": f"{kwargs['request'].base_url}collections/{collection_id}/items?token={page.previous}&limit={limit}", + "href": f"{base_url}collections/{collection_id}/items?token={page.previous}&limit={limit}", "method": "GET", } ) @@ -173,7 +174,7 @@ def item_collection( def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item: """Get item by id.""" - base_url = str(kwargs["request"].base_url) + base_url = get_base_url_from_request(kwargs["request"]) with self.session.reader.context_session() as session: db_query = session.query(self.item_table) db_query = db_query.filter(self.item_table.collection_id == collection_id) @@ -262,7 +263,7 @@ def post_search( self, search_request: BaseSearchPostRequest, **kwargs ) -> ItemCollection: """POST search catalog.""" - base_url = str(kwargs["request"].base_url) + base_url = get_base_url_from_request(kwargs["request"]) with self.session.reader.context_session() as session: token = ( self.get_token(search_request.token) if search_request.token else False @@ -395,7 +396,7 @@ def post_search( { "rel": Relations.next.value, "type": "application/geo+json", - "href": f"{kwargs['request'].base_url}search", + "href": f"{base_url}search", "method": "POST", "body": {"token": page.next}, "merge": True, @@ -406,7 +407,7 @@ def post_search( { "rel": Relations.previous.value, "type": "application/geo+json", - "href": f"{kwargs['request'].base_url}search", + "href": f"{base_url}search", "method": "POST", "body": {"token": page.previous}, "merge": True, diff --git a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/links.py b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/links.py new file mode 100644 index 000000000..862530776 --- /dev/null +++ b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/links.py @@ -0,0 +1,57 @@ +"""Functionality to assist link construction.""" + +import re +from http.client import HTTP_PORT, HTTPS_PORT +from urllib.parse import urljoin + +from starlette.requests import Request + + +def get_base_url_from_request(request: Request) -> str: + """ + Account for forwarding headers when deriving base URL. + + Prioritise standard Forwarded header, look for non-standard X-Forwarded-* if missing. + Default to what can be derived from the URL if no headers provided. + """ + if not isinstance(request, Request): + # Lots of tests execute setup logic with MockStarletteRequest + # and this type only has a single property: base_url. + return request.base_url + domain = request.url.hostname + proto = request.url.scheme + port_str = str(request.url.port) if request.url.port is not None else None + forwarded = request.headers.get("forwarded") + if forwarded is not None: + parts = forwarded.split(";") + for part in parts: + if len(part) > 0 and re.search("=", part): + key, value = part.split("=") + if key == "proto": + proto = value + elif key == "host": + host_parts = value.split(":") + domain = host_parts[0] + port_str = host_parts[1] if len(host_parts) == 2 else None + else: + proto = request.headers.get("x-forwarded-proto", proto) + port_str = request.headers.get("x-forwarded-port", port_str) + port_suffix = "" + if port_str is not None and port_str.isdigit(): + if (proto == "http" and port_str == str(HTTP_PORT)) or ( + proto == "https" and port_str == str(HTTPS_PORT) + ): + pass + else: + port_suffix = f":{port_str}" + # ensure url ends with slash + url = re.sub( + r"([^/])$", + r"\1/", + urljoin( + f"{proto}://{domain}{port_suffix}", + # ensure root path starts with slash + re.sub(r"^([^/])", r"/\1", request.scope.get("root_path")), + ), + ) + return url \ No newline at end of file diff --git a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/transactions.py b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/transactions.py index 1ae1d6f2e..a9d06615c 100644 --- a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/transactions.py +++ b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/transactions.py @@ -11,6 +11,7 @@ Items, ) from stac_fastapi.sqlalchemy import serializers +from stac_fastapi.sqlalchemy.links import get_base_url_from_request from stac_fastapi.sqlalchemy.models import database from stac_fastapi.sqlalchemy.session import Session from stac_fastapi.types import stac as stac_types @@ -38,7 +39,7 @@ def create_item( self, model: Union[stac_types.Item, stac_types.ItemCollection], **kwargs ) -> Optional[stac_types.Item]: """Create item.""" - base_url = str(kwargs["request"].base_url) + base_url = get_base_url_from_request(kwargs["request"]) # If a feature collection is posted if model["type"] == "FeatureCollection": @@ -56,7 +57,7 @@ def create_collection( self, collection: stac_types.Collection, **kwargs ) -> Optional[Union[stac_types.Collection, Response]]: """Create collection.""" - base_url = str(kwargs["request"].base_url) + base_url = get_base_url_from_request(kwargs["request"]) data = self.collection_serializer.stac_to_db(collection) with self.session.writer.context_session() as session: session.add(data) @@ -66,7 +67,7 @@ def update_item( self, item: stac_types.Item, **kwargs ) -> Optional[Union[stac_types.Item, Response]]: """Update item.""" - base_url = str(kwargs["request"].base_url) + base_url = get_base_url_from_request(kwargs["request"]) with self.session.reader.context_session() as session: query = session.query(self.item_table).filter( self.item_table.id == item["id"] @@ -87,7 +88,7 @@ def update_collection( self, collection: stac_types.Collection, **kwargs ) -> Optional[Union[stac_types.Collection, Response]]: """Update collection.""" - base_url = str(kwargs["request"].base_url) + base_url = get_base_url_from_request(kwargs["request"]) with self.session.reader.context_session() as session: query = session.query(self.collection_table).filter( self.collection_table.id == collection["id"] @@ -105,7 +106,7 @@ def delete_item( self, item_id: str, collection_id: str, **kwargs ) -> Optional[Union[stac_types.Item, Response]]: """Delete item.""" - base_url = str(kwargs["request"].base_url) + base_url = get_base_url_from_request(kwargs["request"]) with self.session.writer.context_session() as session: query = session.query(self.item_table).filter( self.item_table.collection_id == collection_id @@ -123,7 +124,7 @@ def delete_collection( self, collection_id: str, **kwargs ) -> Optional[Union[stac_types.Collection, Response]]: """Delete collection.""" - base_url = str(kwargs["request"].base_url) + base_url = get_base_url_from_request(kwargs["request"]) with self.session.writer.context_session() as session: query = session.query(self.collection_table).filter( self.collection_table.id == collection_id diff --git a/stac_fastapi/sqlalchemy/tests/api/test_api.py b/stac_fastapi/sqlalchemy/tests/api/test_api.py index 0abd7cb00..6d36b001c 100644 --- a/stac_fastapi/sqlalchemy/tests/api/test_api.py +++ b/stac_fastapi/sqlalchemy/tests/api/test_api.py @@ -304,3 +304,58 @@ def test_app_fields_extension_return_all_properties( assert feature["properties"][expected_prop][0:19] == expected_value[0:19] else: assert feature["properties"][expected_prop] == expected_value + + +def test_app_search_response_forwarded_header( + load_test_data, app_client, postgres_transactions +): + item = load_test_data("test_item.json") + postgres_transactions.create_item(item, request=MockStarletteRequest) + + resp = app_client.get( + "/search", + params={"collections": ["test-collection"]}, + headers={"Forwarded": "proto=https;host=testserver:1234"}, + ) + for feature in resp.json()["features"]: + for link in feature["links"]: + assert link["href"].startswith("https://testserver:1234/") + + +def test_app_search_response_x_forwarded_headers( + load_test_data, app_client, postgres_transactions +): + item = load_test_data("test_item.json") + postgres_transactions.create_item(item, request=MockStarletteRequest) + + resp = app_client.get( + "/search", + params={"collections": ["test-collection"]}, + headers={ + "X-Forwarded-Port": "1234", + "X-Forwarded-Proto": "https", + }, + ) + for feature in resp.json()["features"]: + for link in feature["links"]: + assert link["href"].startswith("https://testserver:1234/") + + +def test_app_search_response_duplicate_forwarded_headers( + load_test_data, app_client, postgres_transactions +): + item = load_test_data("test_item.json") + postgres_transactions.create_item(item, request=MockStarletteRequest) + + resp = app_client.get( + "/search", + params={"collections": ["test-collection"]}, + headers={ + "Forwarded": "proto=https;host=testserver:1234", + "X-Forwarded-Port": "4321", + "X-Forwarded-Proto": "http", + }, + ) + for feature in resp.json()["features"]: + for link in feature["links"]: + assert link["href"].startswith("https://testserver:1234/") \ No newline at end of file diff --git a/stac_fastapi/sqlalchemy/tests/resources/test_collection.py b/stac_fastapi/sqlalchemy/tests/resources/test_collection.py index b0d8b3d66..c3bba18ac 100644 --- a/stac_fastapi/sqlalchemy/tests/resources/test_collection.py +++ b/stac_fastapi/sqlalchemy/tests/resources/test_collection.py @@ -73,3 +73,46 @@ def test_returns_valid_collection(app_client, load_test_data): resp_json, root=mock_root, preserve_dict=False ) collection.validate() + + +def test_get_collection_forwarded_header(app_client, load_test_data): + test_collection = load_test_data("test_collection.json") + app_client.put("/collections", json=test_collection) + + resp = app_client.get( + f"/collections/{test_collection['id']}", + headers={"Forwarded": "proto=https;host=testserver:1234"}, + ) + for link in resp.json()["links"]: + assert link["href"].startswith("https://testserver:1234/") + + +def test_get_collection_x_forwarded_headers(app_client, load_test_data): + test_collection = load_test_data("test_collection.json") + app_client.put("/collections", json=test_collection) + + resp = app_client.get( + f"/collections/{test_collection['id']}", + headers={ + "X-Forwarded-Port": "1234", + "X-Forwarded-Proto": "https", + }, + ) + for link in resp.json()["links"]: + assert link["href"].startswith("https://testserver:1234/") + + +def test_get_collection_duplicate_forwarded_headers(app_client, load_test_data): + test_collection = load_test_data("test_collection.json") + app_client.put("/collections", json=test_collection) + + resp = app_client.get( + f"/collections/{test_collection['id']}", + headers={ + "Forwarded": "proto=https;host=testserver:1234", + "X-Forwarded-Port": "4321", + "X-Forwarded-Proto": "http", + }, + ) + for link in resp.json()["links"]: + assert link["href"].startswith("https://testserver:1234/") \ No newline at end of file diff --git a/stac_fastapi/sqlalchemy/tests/resources/test_item.py b/stac_fastapi/sqlalchemy/tests/resources/test_item.py index 2f671de68..9d9f73656 100644 --- a/stac_fastapi/sqlalchemy/tests/resources/test_item.py +++ b/stac_fastapi/sqlalchemy/tests/resources/test_item.py @@ -942,3 +942,43 @@ def test_search_datetime_validation_errors(app_client): resp = app_client.get("/search?datetime={}".format(dt)) assert resp.status_code == 400 + + +def test_get_item_forwarded_header(app_client, load_test_data): + test_item = load_test_data("test_item.json") + app_client.post(f"/collections/{test_item['collection']}/items", json=test_item) + get_item = app_client.get( + f"/collections/{test_item['collection']}/items/{test_item['id']}", + headers={"Forwarded": "proto=https;host=testserver:1234"}, + ) + for link in get_item.json()["links"]: + assert link["href"].startswith("https://testserver:1234/") + + +def test_get_item_x_forwarded_headers(app_client, load_test_data): + test_item = load_test_data("test_item.json") + app_client.post(f"/collections/{test_item['collection']}/items", json=test_item) + get_item = app_client.get( + f"/collections/{test_item['collection']}/items/{test_item['id']}", + headers={ + "X-Forwarded-Port": "1234", + "X-Forwarded-Proto": "https", + }, + ) + for link in get_item.json()["links"]: + assert link["href"].startswith("https://testserver:1234/") + + +def test_get_item_duplicate_forwarded_headers(app_client, load_test_data): + test_item = load_test_data("test_item.json") + app_client.post(f"/collections/{test_item['collection']}/items", json=test_item) + get_item = app_client.get( + f"/collections/{test_item['collection']}/items/{test_item['id']}", + headers={ + "Forwarded": "proto=https;host=testserver:1234", + "X-Forwarded-Port": "4321", + "X-Forwarded-Proto": "http", + }, + ) + for link in get_item.json()["links"]: + assert link["href"].startswith("https://testserver:1234/") \ No newline at end of file From d77370caf7ba03d4a31d85a4be2c5be51e253a28 Mon Sep 17 00:00:00 2001 From: Tom Christian Date: Tue, 5 Jul 2022 09:44:34 -0700 Subject: [PATCH 3/7] Linting fixes --- stac_fastapi/pgstac/stac_fastapi/pgstac/core.py | 7 ++++++- stac_fastapi/pgstac/stac_fastapi/pgstac/models/links.py | 1 + stac_fastapi/pgstac/tests/api/test_api.py | 1 + stac_fastapi/pgstac/tests/resources/test_collection.py | 2 +- stac_fastapi/pgstac/tests/resources/test_item.py | 4 ++-- stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/links.py | 2 +- stac_fastapi/sqlalchemy/tests/api/test_api.py | 2 +- stac_fastapi/sqlalchemy/tests/resources/test_collection.py | 2 +- stac_fastapi/sqlalchemy/tests/resources/test_item.py | 2 +- 9 files changed, 15 insertions(+), 8 deletions(-) diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py index e110759d2..bca8e3598 100644 --- a/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py @@ -18,7 +18,12 @@ from starlette.requests import Request from stac_fastapi.pgstac.config import Settings -from stac_fastapi.pgstac.models.links import CollectionLinks, ItemLinks, PagingLinks, get_base_url_from_request +from stac_fastapi.pgstac.models.links import ( + CollectionLinks, + ItemLinks, + PagingLinks, + get_base_url_from_request, +) from stac_fastapi.pgstac.types.search import PgstacSearch from stac_fastapi.pgstac.utils import filter_fields from stac_fastapi.types.core import AsyncBaseCoreClient diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/models/links.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/models/links.py index c0e57b51c..a7ab697d3 100644 --- a/stac_fastapi/pgstac/stac_fastapi/pgstac/models/links.py +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/models/links.py @@ -41,6 +41,7 @@ def merge_params(url: str, newparams: Dict) -> str: def get_base_url_from_request(request: Request) -> str: """ Account for forwarding headers when deriving base URL. + Prioritise standard Forwarded header, look for non-standard X-Forwarded-* if missing. Default to what can be derived from the URL if no headers provided. """ diff --git a/stac_fastapi/pgstac/tests/api/test_api.py b/stac_fastapi/pgstac/tests/api/test_api.py index b805a9a37..ed19a5ee2 100644 --- a/stac_fastapi/pgstac/tests/api/test_api.py +++ b/stac_fastapi/pgstac/tests/api/test_api.py @@ -1,4 +1,5 @@ from datetime import datetime, timedelta + import pytest STAC_CORE_ROUTES = [ diff --git a/stac_fastapi/pgstac/tests/resources/test_collection.py b/stac_fastapi/pgstac/tests/resources/test_collection.py index b3a4d5509..bbb8c124a 100644 --- a/stac_fastapi/pgstac/tests/resources/test_collection.py +++ b/stac_fastapi/pgstac/tests/resources/test_collection.py @@ -228,4 +228,4 @@ async def test_get_collections_forwarded_header(app_client, load_test_collection headers={"Forwarded": "proto=https;host=test:1234"}, ) for link in resp.json()["links"]: - assert link["href"].startswith("https://test:1234/") \ No newline at end of file + assert link["href"].startswith("https://test:1234/") diff --git a/stac_fastapi/pgstac/tests/resources/test_item.py b/stac_fastapi/pgstac/tests/resources/test_item.py index 4e3c968f6..a56fd16dd 100644 --- a/stac_fastapi/pgstac/tests/resources/test_item.py +++ b/stac_fastapi/pgstac/tests/resources/test_item.py @@ -1172,7 +1172,7 @@ async def test_relative_link_construction(): "type": "http", "scheme": "http", "method": "PUT", - "root_path": "/stac", # root_path should not have proto, domain, or port + "root_path": "/stac", # root_path should not have proto, domain, or port "path": "/", "raw_path": b"/tab/abc", "query_string": b"", @@ -1442,4 +1442,4 @@ async def test_get_collection_items_duplicate_forwarded_headers( }, ) for link in resp.json()["features"][0]["links"]: - assert link["href"].startswith("https://test:1234/") \ No newline at end of file + assert link["href"].startswith("https://test:1234/") diff --git a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/links.py b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/links.py index 862530776..dbd11e0e1 100644 --- a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/links.py +++ b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/links.py @@ -54,4 +54,4 @@ def get_base_url_from_request(request: Request) -> str: re.sub(r"^([^/])", r"/\1", request.scope.get("root_path")), ), ) - return url \ No newline at end of file + return url diff --git a/stac_fastapi/sqlalchemy/tests/api/test_api.py b/stac_fastapi/sqlalchemy/tests/api/test_api.py index 6d36b001c..a9a78df66 100644 --- a/stac_fastapi/sqlalchemy/tests/api/test_api.py +++ b/stac_fastapi/sqlalchemy/tests/api/test_api.py @@ -358,4 +358,4 @@ def test_app_search_response_duplicate_forwarded_headers( ) for feature in resp.json()["features"]: for link in feature["links"]: - assert link["href"].startswith("https://testserver:1234/") \ No newline at end of file + assert link["href"].startswith("https://testserver:1234/") diff --git a/stac_fastapi/sqlalchemy/tests/resources/test_collection.py b/stac_fastapi/sqlalchemy/tests/resources/test_collection.py index c3bba18ac..275b2684f 100644 --- a/stac_fastapi/sqlalchemy/tests/resources/test_collection.py +++ b/stac_fastapi/sqlalchemy/tests/resources/test_collection.py @@ -115,4 +115,4 @@ def test_get_collection_duplicate_forwarded_headers(app_client, load_test_data): }, ) for link in resp.json()["links"]: - assert link["href"].startswith("https://testserver:1234/") \ No newline at end of file + assert link["href"].startswith("https://testserver:1234/") diff --git a/stac_fastapi/sqlalchemy/tests/resources/test_item.py b/stac_fastapi/sqlalchemy/tests/resources/test_item.py index 9d9f73656..d7618c2c1 100644 --- a/stac_fastapi/sqlalchemy/tests/resources/test_item.py +++ b/stac_fastapi/sqlalchemy/tests/resources/test_item.py @@ -981,4 +981,4 @@ def test_get_item_duplicate_forwarded_headers(app_client, load_test_data): }, ) for link in get_item.json()["links"]: - assert link["href"].startswith("https://testserver:1234/") \ No newline at end of file + assert link["href"].startswith("https://testserver:1234/") From 47bfc857b073a5c1e749219c47025572e3fcaf46 Mon Sep 17 00:00:00 2001 From: Tom Christian Date: Tue, 5 Jul 2022 16:52:50 -0700 Subject: [PATCH 4/7] Moved to proxy header middleware --- .gitignore | 5 +- stac_fastapi/api/stac_fastapi/api/app.py | 5 +- .../api/stac_fastapi/api/middleware.py | 96 ++++++++++++++++++- .../pgstac/stac_fastapi/pgstac/core.py | 9 +- .../stac_fastapi/pgstac/models/links.py | 52 +--------- stac_fastapi/pgstac/tests/api/test_api.py | 21 ++++ .../stac_fastapi/sqlalchemy/core.py | 19 ++-- .../stac_fastapi/sqlalchemy/links.py | 57 ----------- .../stac_fastapi/sqlalchemy/transactions.py | 13 ++- stac_fastapi/sqlalchemy/tests/api/test_api.py | 16 ++++ 10 files changed, 159 insertions(+), 134 deletions(-) delete mode 100644 stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/links.py diff --git a/.gitignore b/.gitignore index 407d0fe9a..1cc0fbf51 100644 --- a/.gitignore +++ b/.gitignore @@ -129,4 +129,7 @@ docs/api/* .envrc # Virtualenv -venv \ No newline at end of file +venv + +# IDE +.vscode \ No newline at end of file diff --git a/stac_fastapi/api/stac_fastapi/api/app.py b/stac_fastapi/api/stac_fastapi/api/app.py index a9f8a5542..9761deaa6 100644 --- a/stac_fastapi/api/stac_fastapi/api/app.py +++ b/stac_fastapi/api/stac_fastapi/api/app.py @@ -14,6 +14,7 @@ from starlette.responses import JSONResponse, Response from stac_fastapi.api.errors import DEFAULT_STATUS_CODES, add_exception_handlers +from stac_fastapi.api.middleware import ProxyHeaderMiddleware from stac_fastapi.api.models import ( APIRequest, CollectionUri, @@ -91,7 +92,9 @@ class StacApi: ) pagination_extension = attr.ib(default=TokenPaginationExtension) response_class: Type[Response] = attr.ib(default=JSONResponse) - middlewares: List = attr.ib(default=attr.Factory(lambda: [BrotliMiddleware])) + middlewares: List = attr.ib( + default=attr.Factory(lambda: [BrotliMiddleware, ProxyHeaderMiddleware]) + ) route_dependencies: List[Tuple[List[Scope], List[Depends]]] = attr.ib(default=[]) def get_extension(self, extension: Type[ApiExtension]) -> Optional[ApiExtension]: diff --git a/stac_fastapi/api/stac_fastapi/api/middleware.py b/stac_fastapi/api/stac_fastapi/api/middleware.py index acb00915b..48237aad6 100644 --- a/stac_fastapi/api/stac_fastapi/api/middleware.py +++ b/stac_fastapi/api/stac_fastapi/api/middleware.py @@ -1,11 +1,14 @@ """api middleware.""" -from typing import Callable +import re +from http.client import HTTP_PORT, HTTPS_PORT +from typing import Callable, List, Tuple from fastapi import APIRouter, FastAPI from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.routing import Match +from starlette.types import ASGIApp, Receive, Scope, Send def router_middleware(app: FastAPI, router: APIRouter): @@ -29,3 +32,94 @@ async def _middleware(request: Request, call_next): return func return deco + + +class ProxyHeaderMiddleware: + """ + Account for forwarding headers when deriving base URL. + + Prioritise standard Forwarded header, look for non-standard X-Forwarded-* if missing. + Default to what can be derived from the URL if no headers provided. + Middleware updates the host header that is interpreted by starlette when deriving Request.base_url. + """ + + def __init__(self, app: ASGIApp): + """Create proxy header middleware.""" + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """Call from stac-fastapi framework.""" + if scope["type"] == "http": + proto, domain, port = self._get_forwarded_url_parts(scope) + scope["scheme"] = proto + if domain is not None: + port_suffix = "" + if port is not None: + if (proto == "http" and port != HTTP_PORT) or ( + proto == "https" and port != HTTPS_PORT + ): + port_suffix = f":{port}" + scope["headers"] = self._replace_header_value_by_name( + scope, + "host", + f"{domain}{port_suffix}", + ) + await self.app(scope, receive, send) + + def _get_forwarded_url_parts(self, scope: Scope) -> Tuple[str]: + proto = scope.get("scheme", "http") + header_host = self._get_header_value_by_name(scope, "host") + if header_host is None: + domain, port = scope.get("server") + else: + header_host_parts = header_host.split(":") + domain, port = ( + header_host_parts + if len(header_host_parts) == 2 + else header_host_parts[0], + None, + ) + forwarded = self._get_header_value_by_name(scope, "forwarded") + if forwarded is not None: + parts = forwarded.split(";") + for part in parts: + if len(part) > 0 and re.search("=", part): + key, value = part.split("=") + if key == "proto": + proto = value + elif key == "host": + host_parts = value.split(":") + domain = host_parts[0] + try: + port = int(host_parts[1]) if len(host_parts) == 2 else None + except ValueError: + # ignore ports that are not valid integers + pass + else: + proto = self._get_header_value_by_name(scope, "x-forwarded-proto", proto) + port_str = self._get_header_value_by_name(scope, "x-forwarded-port", port) + try: + port = int(port_str) if port_str is not None else None + except ValueError: + # ignore ports that are not valid integers + pass + + return (proto, domain, port) + + def _get_header_value_by_name( + self, scope: Scope, header_name: str, default_value: str = None + ) -> str: + headers = scope["headers"] + candidates = [ + value.decode() for key, value in headers if key.decode() == header_name + ] + return candidates[0] if len(candidates) == 1 else default_value + + def _replace_header_value_by_name( + self, scope: Scope, header_name: str, new_value: str + ) -> List[Tuple[str]]: + return [ + (name, value) + for name, value in scope["headers"] + if name.decode() != header_name + ] + [(str.encode(header_name), str.encode(new_value))] diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py index bca8e3598..811afbba6 100644 --- a/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py @@ -18,12 +18,7 @@ from starlette.requests import Request from stac_fastapi.pgstac.config import Settings -from stac_fastapi.pgstac.models.links import ( - CollectionLinks, - ItemLinks, - PagingLinks, - get_base_url_from_request, -) +from stac_fastapi.pgstac.models.links import CollectionLinks, ItemLinks, PagingLinks from stac_fastapi.pgstac.types.search import PgstacSearch from stac_fastapi.pgstac.utils import filter_fields from stac_fastapi.types.core import AsyncBaseCoreClient @@ -40,7 +35,7 @@ class CoreCrudClient(AsyncBaseCoreClient): async def all_collections(self, **kwargs) -> Collections: """Read all collections from the database.""" request: Request = kwargs["request"] - base_url = get_base_url_from_request(request) + base_url = str(request.base_url) pool = request.app.state.readpool async with pool.acquire() as conn: diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/models/links.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/models/links.py index a7ab697d3..4816c0969 100644 --- a/stac_fastapi/pgstac/stac_fastapi/pgstac/models/links.py +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/models/links.py @@ -1,7 +1,5 @@ """link helpers.""" -import re -from http.client import HTTP_PORT, HTTPS_PORT from typing import Any, Dict, List, Optional from urllib.parse import ParseResult, parse_qs, unquote, urlencode, urljoin, urlparse @@ -38,52 +36,6 @@ def merge_params(url: str, newparams: Dict) -> str: return href -def get_base_url_from_request(request: Request) -> str: - """ - Account for forwarding headers when deriving base URL. - - Prioritise standard Forwarded header, look for non-standard X-Forwarded-* if missing. - Default to what can be derived from the URL if no headers provided. - """ - domain = request.url.hostname - proto = request.url.scheme - port_str = str(request.url.port) if request.url.port is not None else None - forwarded = request.headers.get("forwarded") - if forwarded is not None: - parts = forwarded.split(";") - for part in parts: - if len(part) > 0 and re.search("=", part): - key, value = part.split("=") - if key == "proto": - proto = value - elif key == "host": - host_parts = value.split(":") - domain = host_parts[0] - port_str = host_parts[1] if len(host_parts) == 2 else None - else: - proto = request.headers.get("x-forwarded-proto", proto) - port_str = request.headers.get("x-forwarded-port", port_str) - port_suffix = "" - if port_str is not None and port_str.isdigit(): - if (proto == "http" and port_str == str(HTTP_PORT)) or ( - proto == "https" and port_str == str(HTTPS_PORT) - ): - pass - else: - port_suffix = f":{port_str}" - # ensure url ends with slash - url = re.sub( - r"([^/])$", - r"\1/", - urljoin( - f"{proto}://{domain}{port_suffix}", - # ensure root path starts with slash - re.sub(r"^([^/])", r"/\1", request.scope.get("root_path")), - ), - ) - return url - - @attr.s class BaseLinks: """Create inferred links common to collections and items.""" @@ -93,12 +45,12 @@ class BaseLinks: @property def base_url(self): """Get the base url.""" - return get_base_url_from_request(self.request) + return str(self.request.base_url) @property def url(self): """Get the current request url.""" - return str(self.request.url).replace(str(self.request.base_url), self.base_url) + return str(self.request.url) def resolve(self, url): """Resolve url to the current request url.""" diff --git a/stac_fastapi/pgstac/tests/api/test_api.py b/stac_fastapi/pgstac/tests/api/test_api.py index ed19a5ee2..f2f58e8ad 100644 --- a/stac_fastapi/pgstac/tests/api/test_api.py +++ b/stac_fastapi/pgstac/tests/api/test_api.py @@ -285,6 +285,27 @@ async def test_search_line_string_intersects( assert len(resp_json["features"]) == 1 +@pytest.mark.asyncio +async def test_landing_forwarded_header( + load_test_data, app_client, load_test_collection +): + coll = load_test_collection + item = load_test_data("test_item.json") + await app_client.post(f"/collections/{coll.id}/items", json=item) + response = ( + await app_client.get( + "/", + headers={ + "Forwarded": "proto=https;host=test:1234", + "X-Forwarded-Proto": "http", + "X-Forwarded-Port": "4321", + }, + ) + ).json() + for link in response["links"]: + assert link["href"].startswith("https://test:1234/") + + @pytest.mark.asyncio async def test_search_forwarded_header( load_test_data, app_client, load_test_collection diff --git a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/core.py b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/core.py index 92ab35474..cd1ca9eea 100644 --- a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/core.py +++ b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/core.py @@ -22,7 +22,6 @@ from stac_fastapi.sqlalchemy import serializers from stac_fastapi.sqlalchemy.extensions.query import Operator -from stac_fastapi.sqlalchemy.links import get_base_url_from_request from stac_fastapi.sqlalchemy.models import database from stac_fastapi.sqlalchemy.session import Session from stac_fastapi.sqlalchemy.tokens import PaginationTokenClient @@ -63,7 +62,7 @@ def _lookup_id( def all_collections(self, **kwargs) -> Collections: """Read all collections from the database.""" - base_url = get_base_url_from_request(kwargs["request"]) + base_url = str(kwargs["request"].base_url) with self.session.reader.context_session() as session: collections = session.query(self.collection_table).all() serialized_collections = [ @@ -94,7 +93,7 @@ def all_collections(self, **kwargs) -> Collections: def get_collection(self, collection_id: str, **kwargs) -> Collection: """Get collection by id.""" - base_url = get_base_url_from_request(kwargs["request"]) + base_url = str(kwargs["request"].base_url) with self.session.reader.context_session() as session: collection = self._lookup_id(collection_id, self.collection_table, session) return self.collection_serializer.db_to_stac(collection, base_url) @@ -103,7 +102,7 @@ def item_collection( self, collection_id: str, limit: int = 10, token: str = None, **kwargs ) -> ItemCollection: """Read an item collection from the database.""" - base_url = get_base_url_from_request(kwargs["request"]) + base_url = str(kwargs["request"].base_url) with self.session.reader.context_session() as session: collection_children = ( session.query(self.item_table) @@ -137,7 +136,7 @@ def item_collection( { "rel": Relations.next.value, "type": "application/geo+json", - "href": f"{base_url}collections/{collection_id}/items?token={page.next}&limit={limit}", + "href": f"{kwargs['request'].base_url}collections/{collection_id}/items?token={page.next}&limit={limit}", "method": "GET", } ) @@ -146,7 +145,7 @@ def item_collection( { "rel": Relations.previous.value, "type": "application/geo+json", - "href": f"{base_url}collections/{collection_id}/items?token={page.previous}&limit={limit}", + "href": f"{kwargs['request'].base_url}collections/{collection_id}/items?token={page.previous}&limit={limit}", "method": "GET", } ) @@ -174,7 +173,7 @@ def item_collection( def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item: """Get item by id.""" - base_url = get_base_url_from_request(kwargs["request"]) + base_url = str(kwargs["request"].base_url) with self.session.reader.context_session() as session: db_query = session.query(self.item_table) db_query = db_query.filter(self.item_table.collection_id == collection_id) @@ -263,7 +262,7 @@ def post_search( self, search_request: BaseSearchPostRequest, **kwargs ) -> ItemCollection: """POST search catalog.""" - base_url = get_base_url_from_request(kwargs["request"]) + base_url = str(kwargs["request"].base_url) with self.session.reader.context_session() as session: token = ( self.get_token(search_request.token) if search_request.token else False @@ -396,7 +395,7 @@ def post_search( { "rel": Relations.next.value, "type": "application/geo+json", - "href": f"{base_url}search", + "href": f"{kwargs['request'].base_url}search", "method": "POST", "body": {"token": page.next}, "merge": True, @@ -407,7 +406,7 @@ def post_search( { "rel": Relations.previous.value, "type": "application/geo+json", - "href": f"{base_url}search", + "href": f"{kwargs['request'].base_url}search", "method": "POST", "body": {"token": page.previous}, "merge": True, diff --git a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/links.py b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/links.py deleted file mode 100644 index dbd11e0e1..000000000 --- a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/links.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Functionality to assist link construction.""" - -import re -from http.client import HTTP_PORT, HTTPS_PORT -from urllib.parse import urljoin - -from starlette.requests import Request - - -def get_base_url_from_request(request: Request) -> str: - """ - Account for forwarding headers when deriving base URL. - - Prioritise standard Forwarded header, look for non-standard X-Forwarded-* if missing. - Default to what can be derived from the URL if no headers provided. - """ - if not isinstance(request, Request): - # Lots of tests execute setup logic with MockStarletteRequest - # and this type only has a single property: base_url. - return request.base_url - domain = request.url.hostname - proto = request.url.scheme - port_str = str(request.url.port) if request.url.port is not None else None - forwarded = request.headers.get("forwarded") - if forwarded is not None: - parts = forwarded.split(";") - for part in parts: - if len(part) > 0 and re.search("=", part): - key, value = part.split("=") - if key == "proto": - proto = value - elif key == "host": - host_parts = value.split(":") - domain = host_parts[0] - port_str = host_parts[1] if len(host_parts) == 2 else None - else: - proto = request.headers.get("x-forwarded-proto", proto) - port_str = request.headers.get("x-forwarded-port", port_str) - port_suffix = "" - if port_str is not None and port_str.isdigit(): - if (proto == "http" and port_str == str(HTTP_PORT)) or ( - proto == "https" and port_str == str(HTTPS_PORT) - ): - pass - else: - port_suffix = f":{port_str}" - # ensure url ends with slash - url = re.sub( - r"([^/])$", - r"\1/", - urljoin( - f"{proto}://{domain}{port_suffix}", - # ensure root path starts with slash - re.sub(r"^([^/])", r"/\1", request.scope.get("root_path")), - ), - ) - return url diff --git a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/transactions.py b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/transactions.py index a9d06615c..1ae1d6f2e 100644 --- a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/transactions.py +++ b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/transactions.py @@ -11,7 +11,6 @@ Items, ) from stac_fastapi.sqlalchemy import serializers -from stac_fastapi.sqlalchemy.links import get_base_url_from_request from stac_fastapi.sqlalchemy.models import database from stac_fastapi.sqlalchemy.session import Session from stac_fastapi.types import stac as stac_types @@ -39,7 +38,7 @@ def create_item( self, model: Union[stac_types.Item, stac_types.ItemCollection], **kwargs ) -> Optional[stac_types.Item]: """Create item.""" - base_url = get_base_url_from_request(kwargs["request"]) + base_url = str(kwargs["request"].base_url) # If a feature collection is posted if model["type"] == "FeatureCollection": @@ -57,7 +56,7 @@ def create_collection( self, collection: stac_types.Collection, **kwargs ) -> Optional[Union[stac_types.Collection, Response]]: """Create collection.""" - base_url = get_base_url_from_request(kwargs["request"]) + base_url = str(kwargs["request"].base_url) data = self.collection_serializer.stac_to_db(collection) with self.session.writer.context_session() as session: session.add(data) @@ -67,7 +66,7 @@ def update_item( self, item: stac_types.Item, **kwargs ) -> Optional[Union[stac_types.Item, Response]]: """Update item.""" - base_url = get_base_url_from_request(kwargs["request"]) + base_url = str(kwargs["request"].base_url) with self.session.reader.context_session() as session: query = session.query(self.item_table).filter( self.item_table.id == item["id"] @@ -88,7 +87,7 @@ def update_collection( self, collection: stac_types.Collection, **kwargs ) -> Optional[Union[stac_types.Collection, Response]]: """Update collection.""" - base_url = get_base_url_from_request(kwargs["request"]) + base_url = str(kwargs["request"].base_url) with self.session.reader.context_session() as session: query = session.query(self.collection_table).filter( self.collection_table.id == collection["id"] @@ -106,7 +105,7 @@ def delete_item( self, item_id: str, collection_id: str, **kwargs ) -> Optional[Union[stac_types.Item, Response]]: """Delete item.""" - base_url = get_base_url_from_request(kwargs["request"]) + base_url = str(kwargs["request"].base_url) with self.session.writer.context_session() as session: query = session.query(self.item_table).filter( self.item_table.collection_id == collection_id @@ -124,7 +123,7 @@ def delete_collection( self, collection_id: str, **kwargs ) -> Optional[Union[stac_types.Collection, Response]]: """Delete collection.""" - base_url = get_base_url_from_request(kwargs["request"]) + base_url = str(kwargs["request"].base_url) with self.session.writer.context_session() as session: query = session.query(self.collection_table).filter( self.collection_table.id == collection_id diff --git a/stac_fastapi/sqlalchemy/tests/api/test_api.py b/stac_fastapi/sqlalchemy/tests/api/test_api.py index a9a78df66..79029b2ab 100644 --- a/stac_fastapi/sqlalchemy/tests/api/test_api.py +++ b/stac_fastapi/sqlalchemy/tests/api/test_api.py @@ -306,6 +306,22 @@ def test_app_fields_extension_return_all_properties( assert feature["properties"][expected_prop] == expected_value +def test_landing_forwarded_header(load_test_data, app_client, postgres_transactions): + item = load_test_data("test_item.json") + postgres_transactions.create_item(item, request=MockStarletteRequest) + + response = app_client.get( + "/", + headers={ + "Forwarded": "proto=https;host=test:1234", + "X-Forwarded-Proto": "http", + "X-Forwarded-Port": "4321", + }, + ).json() + for link in response["links"]: + assert link["href"].startswith("https://test:1234/") + + def test_app_search_response_forwarded_header( load_test_data, app_client, postgres_transactions ): From 0e9c7ed92f184716c18c6da0b2d84c4611bddbd3 Mon Sep 17 00:00:00 2001 From: geospatial-jeff Date: Mon, 1 Aug 2022 10:55:30 -0600 Subject: [PATCH 5/7] add makefile command to run api test cases --- Makefile | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Makefile b/Makefile index fe2b6fe32..36187c2e8 100644 --- a/Makefile +++ b/Makefile @@ -46,6 +46,10 @@ test-sqlalchemy: run-joplin-sqlalchemy test-pgstac: $(run_pgstac) /bin/bash -c 'export && ./scripts/wait-for-it.sh database:5432 && cd /app/stac_fastapi/pgstac/tests/ && pytest -vvv' +.PHONY: test-api +test-api: + $(run_sqlalchemy) /bin/bash -c 'cd /app/stac_fastapi/api && pytest -svvv' + .PHONY: run-database run-database: docker-compose run --rm database From 9dbf35533c6f38e7b7e2abfc0532729dc013b0a2 Mon Sep 17 00:00:00 2001 From: geospatial-jeff Date: Mon, 1 Aug 2022 12:07:30 -0600 Subject: [PATCH 6/7] bugfix, make method static --- stac_fastapi/api/stac_fastapi/api/middleware.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/stac_fastapi/api/stac_fastapi/api/middleware.py b/stac_fastapi/api/stac_fastapi/api/middleware.py index 48237aad6..793b75794 100644 --- a/stac_fastapi/api/stac_fastapi/api/middleware.py +++ b/stac_fastapi/api/stac_fastapi/api/middleware.py @@ -67,18 +67,18 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) def _get_forwarded_url_parts(self, scope: Scope) -> Tuple[str]: + print(scope) proto = scope.get("scheme", "http") header_host = self._get_header_value_by_name(scope, "host") if header_host is None: domain, port = scope.get("server") else: header_host_parts = header_host.split(":") - domain, port = ( - header_host_parts - if len(header_host_parts) == 2 - else header_host_parts[0], - None, - ) + if len(header_host_parts) == 2: + domain, port = header_host_parts + else: + domain = header_host_parts[0] + port = None forwarded = self._get_header_value_by_name(scope, "forwarded") if forwarded is not None: parts = forwarded.split(";") @@ -115,8 +115,9 @@ def _get_header_value_by_name( ] return candidates[0] if len(candidates) == 1 else default_value + @staticmethod def _replace_header_value_by_name( - self, scope: Scope, header_name: str, new_value: str + scope: Scope, header_name: str, new_value: str ) -> List[Tuple[str]]: return [ (name, value) From 998a02e8ecbc04c3b9c90ef58490bf611e74e06f Mon Sep 17 00:00:00 2001 From: geospatial-jeff Date: Mon, 1 Aug 2022 12:09:48 -0600 Subject: [PATCH 7/7] add unittests for proxy header middleware --- stac_fastapi/api/tests/test_middleware.py | 140 ++++++++++++++++++++++ 1 file changed, 140 insertions(+) create mode 100644 stac_fastapi/api/tests/test_middleware.py diff --git a/stac_fastapi/api/tests/test_middleware.py b/stac_fastapi/api/tests/test_middleware.py new file mode 100644 index 000000000..cfe299328 --- /dev/null +++ b/stac_fastapi/api/tests/test_middleware.py @@ -0,0 +1,140 @@ +import pytest +from starlette.applications import Starlette + +from stac_fastapi.api.middleware import ProxyHeaderMiddleware + + +@pytest.fixture +def proxy_header_middleware() -> ProxyHeaderMiddleware: + app = Starlette() + return ProxyHeaderMiddleware(app) + + +@pytest.mark.parametrize( + "headers,key,expected", + [ + ([(b"host", b"testserver")], "host", "testserver"), + ([(b"host", b"testserver")], "user-agent", None), + ( + [(b"host", b"testserver"), (b"accept-encoding", b"gzip, deflate, br")], + "accept-encoding", + "gzip, deflate, br", + ), + ], +) +def test_get_header_value_by_name( + proxy_header_middleware: ProxyHeaderMiddleware, headers, key, expected +): + scope = {"headers": headers} + actual = proxy_header_middleware._get_header_value_by_name(scope, key) + assert actual == expected + + +@pytest.mark.parametrize( + "headers,key,value", + [ + ([(b"host", b"testserver")], "host", "another-server"), + ([(b"host", b"testserver")], "user-agent", "agent"), + ( + [(b"host", b"testserver"), (b"accept-encoding", b"gzip, deflate, br")], + "accept-encoding", + "deflate", + ), + ], +) +def test_replace_header_value_by_name( + proxy_header_middleware: ProxyHeaderMiddleware, headers, key, value +): + scope = {"headers": headers} + updated_headers = proxy_header_middleware._replace_header_value_by_name( + scope, key, value + ) + + header_value = proxy_header_middleware._get_header_value_by_name( + {"headers": updated_headers}, key + ) + assert header_value == value + + +@pytest.mark.parametrize( + "scope,expected", + [ + ( + {"scheme": "https", "server": ["testserver", 80], "headers": []}, + ("https", "testserver", 80), + ), + ( + { + "scheme": "http", + "server": ["testserver", 80], + "headers": [(b"host", b"testserver:81")], + }, + ("http", "testserver", 81), + ), + ( + { + "scheme": "http", + "server": ["testserver", 80], + "headers": [(b"host", b"testserver")], + }, + ("http", "testserver", None), + ), + ( + { + "scheme": "http", + "server": ["testserver", 80], + "headers": [(b"forwarded", b"proto=https;host=test:1234")], + }, + ("https", "test", 1234), + ), + ( + { + "scheme": "http", + "server": ["testserver", 80], + "headers": [(b"forwarded", b"proto=https;host=test:not-an-integer")], + }, + ("https", "test", 80), + ), + ( + { + "scheme": "http", + "server": ["testserver", 80], + "headers": [(b"x-forwarded-proto", b"https")], + }, + ("https", "testserver", 80), + ), + ( + { + "scheme": "http", + "server": ["testserver", 80], + "headers": [(b"x-forwarded-port", b"1111")], + }, + ("http", "testserver", 1111), + ), + ( + { + "scheme": "http", + "server": ["testserver", 80], + "headers": [(b"x-forwarded-port", b"not-an-integer")], + }, + ("http", "testserver", 80), + ), + ( + { + "scheme": "http", + "server": ["testserver", 80], + "headers": [ + (b"forwarded", b"proto=https;host=test:1234"), + (b"x-forwarded-port", b"1111"), + (b"x-forwarded-proto", b"https"), + ], + }, + ("https", "test", 1234), + ), + ], +) +def test_get_forwarded_url_parts( + proxy_header_middleware: ProxyHeaderMiddleware, scope, expected +): + actual = proxy_header_middleware._get_forwarded_url_parts(scope) + assert actual == expected