diff --git a/.gitignore b/.gitignore index 407d0fe9a..1cc0fbf51 100644 --- a/.gitignore +++ b/.gitignore @@ -129,4 +129,7 @@ docs/api/* .envrc # Virtualenv -venv \ No newline at end of file +venv + +# IDE +.vscode \ No newline at end of file diff --git a/CHANGES.md b/CHANGES.md index 3152e9b2e..bf14d02ea 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -8,6 +8,7 @@ * Ability to POST an ItemCollection to the collections/{collectionId}/items route. ([#367](https://github.com/stac-utils/stac-fastapi/pull/367)) * Add STAC API - Collections conformance class. ([383](https://github.com/stac-utils/stac-fastapi/pull/383)) * Bulk item inserts for pgstac implementation. ([411](https://github.com/stac-utils/stac-fastapi/pull/411)) +* Respect `Forwarded` or `X-Forwarded-*` request headers when building links to better accommodate load balancers and proxies. ### Changed diff --git a/Makefile b/Makefile index fe2b6fe32..36187c2e8 100644 --- a/Makefile +++ b/Makefile @@ -46,6 +46,10 @@ test-sqlalchemy: run-joplin-sqlalchemy test-pgstac: $(run_pgstac) /bin/bash -c 'export && ./scripts/wait-for-it.sh database:5432 && cd /app/stac_fastapi/pgstac/tests/ && pytest -vvv' +.PHONY: test-api +test-api: + $(run_sqlalchemy) /bin/bash -c 'cd /app/stac_fastapi/api && pytest -svvv' + .PHONY: run-database run-database: docker-compose run --rm database diff --git a/stac_fastapi/api/stac_fastapi/api/app.py b/stac_fastapi/api/stac_fastapi/api/app.py index a9f8a5542..9761deaa6 100644 --- a/stac_fastapi/api/stac_fastapi/api/app.py +++ b/stac_fastapi/api/stac_fastapi/api/app.py @@ -14,6 +14,7 @@ from starlette.responses import JSONResponse, Response from stac_fastapi.api.errors import DEFAULT_STATUS_CODES, add_exception_handlers +from stac_fastapi.api.middleware import ProxyHeaderMiddleware from stac_fastapi.api.models import ( APIRequest, CollectionUri, @@ -91,7 +92,9 @@ class StacApi: ) pagination_extension = attr.ib(default=TokenPaginationExtension) response_class: Type[Response] = 
class ProxyHeaderMiddleware:
    """
    Account for forwarding headers when deriving base URL.

    Prioritise standard Forwarded header, look for non-standard X-Forwarded-* if missing.
    Default to what can be derived from the URL if no headers provided.
    Middleware updates the host header that is interpreted by starlette when deriving Request.base_url.

    NOTE(review): the forwarding headers are used exactly as received; this
    presumably expects to run behind a proxy that sets or strips them, since a
    direct client could otherwise choose the scheme/host used in generated
    links -- confirm against the deployment.
    """

    def __init__(self, app: ASGIApp):
        """Create proxy header middleware."""
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Call from stac-fastapi framework.

        For ``http`` scopes, overwrite ``scope["scheme"]`` and the ``host``
        header with the values derived from the forwarding headers, then hand
        the (mutated) scope to the wrapped application. Other scope types are
        passed through untouched.
        """
        if scope["type"] == "http":
            proto, domain, port = self._get_forwarded_url_parts(scope)
            scope["scheme"] = proto
            if domain is not None:
                # Only spell out the port when it differs from the scheme's
                # default; for schemes other than http/https no port is added.
                default_port = {"http": HTTP_PORT, "https": HTTPS_PORT}.get(proto)
                show_port = (
                    port is not None
                    and default_port is not None
                    and port != default_port
                )
                host = f"{domain}:{port}" if show_port else domain
                scope["headers"] = self._replace_header_value_by_name(
                    scope,
                    "host",
                    host,
                )
        await self.app(scope, receive, send)

    def _get_forwarded_url_parts(self, scope: Scope) -> Tuple:
        """Derive the externally visible URL parts for a request.

        Returns a ``(proto, domain, port)`` tuple. ``proto`` is a string,
        ``domain`` is a string or ``None`` (no Host header and no ``server``
        entry in the scope), ``port`` is an int or ``None``.

        Precedence: the ``Forwarded`` header wins when present; otherwise
        ``X-Forwarded-Proto`` / ``X-Forwarded-Port`` are consulted; anything
        not supplied falls back to the Host header, then to ``scope["server"]``
        and ``scope["scheme"]``.
        """
        proto = scope.get("scheme", "http")

        header_host = self._get_header_value_by_name(scope, "host")
        if header_host is None:
            # "server" may be missing or None in an ASGI scope.
            domain, port = scope.get("server") or (None, None)
        else:
            domain, port_text = self._split_host_port(header_host)
            # Normalise to int so the default-port comparison in __call__
            # never compares a str with an int.
            port = self._parse_port(port_text)

        forwarded = self._get_header_value_by_name(scope, "forwarded")
        if forwarded is not None:
            # RFC 7239: elements are comma separated (one per proxy, the first
            # being closest to the client) and each holds ';'-separated
            # key=value pairs with case-insensitive keys and optionally quoted
            # values. The first occurrence of each key is used.
            seen = set()
            for pair in re.split("[;,]", forwarded):
                key, separator, value = pair.partition("=")
                key = key.strip().lower()
                value = value.strip().strip('"')
                if not separator or not value or key in seen:
                    continue
                seen.add(key)
                if key == "proto":
                    proto = value.lower()
                elif key == "host":
                    domain, port_text = self._split_host_port(value)
                    if port_text is None:
                        port = None
                    else:
                        parsed = self._parse_port(port_text)
                        if parsed is not None:
                            port = parsed
                        # else: ignore ports that are not valid integers and
                        # keep the previously derived port
        else:
            forwarded_proto = self._get_header_value_by_name(
                scope, "x-forwarded-proto"
            )
            if forwarded_proto is not None:
                candidate = self._first_list_item(forwarded_proto).lower()
                if candidate:
                    proto = candidate
            forwarded_port = self._get_header_value_by_name(
                scope, "x-forwarded-port"
            )
            if forwarded_port is not None:
                parsed = self._parse_port(self._first_list_item(forwarded_port))
                if parsed is not None:
                    # ports that are not valid integers are ignored
                    port = parsed

        return (proto, domain, port)

    @staticmethod
    def _first_list_item(header_value: str) -> str:
        """Return the first comma-separated item of a header value, trimmed."""
        return header_value.split(",", 1)[0].strip()

    @staticmethod
    def _split_host_port(host: str) -> Tuple:
        """Split ``host[:port]`` into ``(domain, port_text)``.

        ``port_text`` is the raw, unvalidated port string, or ``None`` when no
        port is present. Bracketed IPv6 literals (``[::1]:8000``) keep their
        brackets in ``domain``.
        """
        if host.startswith("["):
            closing = host.find("]")
            if closing != -1:
                remainder = host[closing + 1 :]
                port_text = remainder[1:] if remainder.startswith(":") else None
                return host[: closing + 1], port_text
        pieces = host.split(":")
        if len(pieces) == 2:
            return pieces[0], pieces[1]
        return pieces[0], None

    @staticmethod
    def _parse_port(port_text):
        """Return ``port_text`` as an int, or ``None`` if missing or not a valid port."""
        if port_text is None:
            return None
        try:
            port = int(port_text)
        except ValueError:
            return None
        return port if 0 < port < 65536 else None

    def _get_header_value_by_name(
        self, scope: Scope, header_name: str, default_value: str = None
    ) -> str:
        """Return the value of ``header_name`` from ``scope["headers"]``.

        ``default_value`` is returned when the header is absent or when it
        occurs more than once. Header bytes are decoded as latin-1, which
        cannot fail on arbitrary bytes.
        """
        matches = [
            value.decode("latin-1")
            for key, value in scope["headers"]
            if key.decode("latin-1") == header_name
        ]
        return matches[0] if len(matches) == 1 else default_value

    @staticmethod
    def _replace_header_value_by_name(
        scope: Scope, header_name: str, new_value: str
    ) -> List[Tuple[bytes, bytes]]:
        """Return a copy of the scope's headers with ``header_name`` set to ``new_value``.

        Every existing occurrence of the header is dropped and a single new
        ``(name, value)`` byte pair is appended at the end; the scope itself
        is not modified.
        """
        kept = [
            (name, value)
            for name, value in scope["headers"]
            if name.decode("latin-1") != header_name
        ]
        kept.append((header_name.encode("latin-1"), new_value.encode("latin-1")))
        return kept
updated_headers = proxy_header_middleware._replace_header_value_by_name( + scope, key, value + ) + + header_value = proxy_header_middleware._get_header_value_by_name( + {"headers": updated_headers}, key + ) + assert header_value == value + + +@pytest.mark.parametrize( + "scope,expected", + [ + ( + {"scheme": "https", "server": ["testserver", 80], "headers": []}, + ("https", "testserver", 80), + ), + ( + { + "scheme": "http", + "server": ["testserver", 80], + "headers": [(b"host", b"testserver:81")], + }, + ("http", "testserver", 81), + ), + ( + { + "scheme": "http", + "server": ["testserver", 80], + "headers": [(b"host", b"testserver")], + }, + ("http", "testserver", None), + ), + ( + { + "scheme": "http", + "server": ["testserver", 80], + "headers": [(b"forwarded", b"proto=https;host=test:1234")], + }, + ("https", "test", 1234), + ), + ( + { + "scheme": "http", + "server": ["testserver", 80], + "headers": [(b"forwarded", b"proto=https;host=test:not-an-integer")], + }, + ("https", "test", 80), + ), + ( + { + "scheme": "http", + "server": ["testserver", 80], + "headers": [(b"x-forwarded-proto", b"https")], + }, + ("https", "testserver", 80), + ), + ( + { + "scheme": "http", + "server": ["testserver", 80], + "headers": [(b"x-forwarded-port", b"1111")], + }, + ("http", "testserver", 1111), + ), + ( + { + "scheme": "http", + "server": ["testserver", 80], + "headers": [(b"x-forwarded-port", b"not-an-integer")], + }, + ("http", "testserver", 80), + ), + ( + { + "scheme": "http", + "server": ["testserver", 80], + "headers": [ + (b"forwarded", b"proto=https;host=test:1234"), + (b"x-forwarded-port", b"1111"), + (b"x-forwarded-proto", b"https"), + ], + }, + ("https", "test", 1234), + ), + ], +) +def test_get_forwarded_url_parts( + proxy_header_middleware: ProxyHeaderMiddleware, scope, expected +): + actual = proxy_header_middleware._get_forwarded_url_parts(scope) + assert actual == expected diff --git a/stac_fastapi/pgstac/tests/api/test_api.py 
b/stac_fastapi/pgstac/tests/api/test_api.py index b109ad946..3619055ed 100644 --- a/stac_fastapi/pgstac/tests/api/test_api.py +++ b/stac_fastapi/pgstac/tests/api/test_api.py @@ -1,5 +1,7 @@ from datetime import datetime, timedelta +import pytest + STAC_CORE_ROUTES = [ "GET /", "GET /collections", @@ -290,3 +292,94 @@ async def test_search_line_string_intersects( assert resp.status_code == 200 resp_json = resp.json() assert len(resp_json["features"]) == 1 + + +@pytest.mark.asyncio +async def test_landing_forwarded_header( + load_test_data, app_client, load_test_collection +): + coll = load_test_collection + item = load_test_data("test_item.json") + await app_client.post(f"/collections/{coll.id}/items", json=item) + response = ( + await app_client.get( + "/", + headers={ + "Forwarded": "proto=https;host=test:1234", + "X-Forwarded-Proto": "http", + "X-Forwarded-Port": "4321", + }, + ) + ).json() + for link in response["links"]: + assert link["href"].startswith("https://test:1234/") + + +@pytest.mark.asyncio +async def test_search_forwarded_header( + load_test_data, app_client, load_test_collection +): + coll = load_test_collection + item = load_test_data("test_item.json") + await app_client.post(f"/collections/{coll.id}/items", json=item) + resp = await app_client.post( + "/search", + json={ + "collections": [item["collection"]], + }, + headers={"Forwarded": "proto=https;host=test:1234"}, + ) + features = resp.json()["features"] + assert len(features) > 0 + for feature in features: + for link in feature["links"]: + assert link["href"].startswith("https://test:1234/") + + +@pytest.mark.asyncio +async def test_search_x_forwarded_headers( + load_test_data, app_client, load_test_collection +): + coll = load_test_collection + item = load_test_data("test_item.json") + await app_client.post(f"/collections/{coll.id}/items", json=item) + resp = await app_client.post( + "/search", + json={ + "collections": [item["collection"]], + }, + headers={ + "X-Forwarded-Proto": "https", + 
"X-Forwarded-Port": "1234", + }, + ) + features = resp.json()["features"] + assert len(features) > 0 + for feature in features: + for link in feature["links"]: + assert link["href"].startswith("https://test:1234/") + + +@pytest.mark.asyncio +async def test_search_duplicate_forward_headers( + load_test_data, app_client, load_test_collection +): + coll = load_test_collection + item = load_test_data("test_item.json") + await app_client.post(f"/collections/{coll.id}/items", json=item) + resp = await app_client.post( + "/search", + json={ + "collections": [item["collection"]], + }, + headers={ + "Forwarded": "proto=https;host=test:1234", + "X-Forwarded-Proto": "http", + "X-Forwarded-Port": "4321", + }, + ) + features = resp.json()["features"] + assert len(features) > 0 + for feature in features: + for link in feature["links"]: + assert link["href"].startswith("https://test:1234/") diff --git a/stac_fastapi/pgstac/tests/resources/test_collection.py b/stac_fastapi/pgstac/tests/resources/test_collection.py index 937803db4..bbb8c124a 100644 --- a/stac_fastapi/pgstac/tests/resources/test_collection.py +++ b/stac_fastapi/pgstac/tests/resources/test_collection.py @@ -1,6 +1,7 @@ from typing import Callable import pystac +import pytest from stac_pydantic import Collection @@ -164,3 +165,67 @@ async def test_returns_license_link(app_client, load_test_collection): resp_json = resp.json() link_rel_types = [link["rel"] for link in resp_json["links"]] assert "license" in link_rel_types + + +@pytest.mark.asyncio +async def test_get_collection_forwarded_header(app_client, load_test_collection): + coll = load_test_collection + resp = await app_client.get( + f"/collections/{coll.id}", + headers={"Forwarded": "proto=https;host=test:1234"}, + ) + for link in [ + link + for link in resp.json()["links"] + if link["rel"] in ["items", "parent", "root", "self"] + ]: + assert link["href"].startswith("https://test:1234/") + + +@pytest.mark.asyncio +async def 
test_get_collection_x_forwarded_headers(app_client, load_test_collection): + coll = load_test_collection + resp = await app_client.get( + f"/collections/{coll.id}", + headers={ + "X-Forwarded-Port": "1234", + "X-Forwarded-Proto": "https", + }, + ) + for link in [ + link + for link in resp.json()["links"] + if link["rel"] in ["items", "parent", "root", "self"] + ]: + assert link["href"].startswith("https://test:1234/") + + +@pytest.mark.asyncio +async def test_get_collection_duplicate_forwarded_headers( + app_client, load_test_collection +): + coll = load_test_collection + resp = await app_client.get( + f"/collections/{coll.id}", + headers={ + "Forwarded": "proto=https;host=test:1234", + "X-Forwarded-Port": "4321", + "X-Forwarded-Proto": "http", + }, + ) + for link in [ + link + for link in resp.json()["links"] + if link["rel"] in ["items", "parent", "root", "self"] + ]: + assert link["href"].startswith("https://test:1234/") + + +@pytest.mark.asyncio +async def test_get_collections_forwarded_header(app_client, load_test_collection): + resp = await app_client.get( + "/collections", + headers={"Forwarded": "proto=https;host=test:1234"}, + ) + for link in resp.json()["links"]: + assert link["href"].startswith("https://test:1234/") diff --git a/stac_fastapi/pgstac/tests/resources/test_item.py b/stac_fastapi/pgstac/tests/resources/test_item.py index 40b4b514a..a56fd16dd 100644 --- a/stac_fastapi/pgstac/tests/resources/test_item.py +++ b/stac_fastapi/pgstac/tests/resources/test_item.py @@ -1,10 +1,12 @@ import json import uuid from datetime import timedelta +from http.client import HTTP_PORT from typing import Callable from urllib.parse import parse_qs, urljoin, urlparse import pystac +import pytest from httpx import AsyncClient from pystac.utils import datetime_to_str from shapely.geometry import Polygon @@ -1170,11 +1172,12 @@ async def test_relative_link_construction(): "type": "http", "scheme": "http", "method": "PUT", - "root_path": "http://test/stac", + "root_path": 
"/stac", # root_path should not have proto, domain, or port "path": "/", "raw_path": b"/tab/abc", "query_string": b"", "headers": {}, + "server": ("test", HTTP_PORT), } ) links = CollectionLinks(collection_id="naip", request=req) @@ -1394,3 +1397,49 @@ async def test_item_merge_raster_bands( assert len(red_bands[0].keys()) == 6 # The merged item should have kept the item value rather than the base value assert red_bands[0]["offset"] == 2.03976 + + +@pytest.mark.asyncio +async def test_get_collection_items_forwarded_header( + app_client, load_test_collection, load_test_item +): + coll = load_test_collection + resp = await app_client.get( + f"/collections/{coll.id}/items", + headers={"Forwarded": "proto=https;host=test:1234"}, + ) + for link in resp.json()["features"][0]["links"]: + assert link["href"].startswith("https://test:1234/") + + +@pytest.mark.asyncio +async def test_get_collection_items_x_forwarded_headers( + app_client, load_test_collection, load_test_item +): + coll = load_test_collection + resp = await app_client.get( + f"/collections/{coll.id}/items", + headers={ + "X-Forwarded-Port": "1234", + "X-Forwarded-Proto": "https", + }, + ) + for link in resp.json()["features"][0]["links"]: + assert link["href"].startswith("https://test:1234/") + + +@pytest.mark.asyncio +async def test_get_collection_items_duplicate_forwarded_headers( + app_client, load_test_collection, load_test_item +): + coll = load_test_collection + resp = await app_client.get( + f"/collections/{coll.id}/items", + headers={ + "Forwarded": "proto=https;host=test:1234", + "X-Forwarded-Port": "4321", + "X-Forwarded-Proto": "http", + }, + ) + for link in resp.json()["features"][0]["links"]: + assert link["href"].startswith("https://test:1234/") diff --git a/stac_fastapi/sqlalchemy/tests/api/test_api.py b/stac_fastapi/sqlalchemy/tests/api/test_api.py index 1ee196923..ba7f5a655 100644 --- a/stac_fastapi/sqlalchemy/tests/api/test_api.py +++ b/stac_fastapi/sqlalchemy/tests/api/test_api.py @@ -319,3 
+319,74 @@ def test_app_fields_extension_return_all_properties( assert feature["properties"][expected_prop][0:19] == expected_value[0:19] else: assert feature["properties"][expected_prop] == expected_value + + +def test_landing_forwarded_header(load_test_data, app_client, postgres_transactions): + item = load_test_data("test_item.json") + postgres_transactions.create_item(item, request=MockStarletteRequest) + + response = app_client.get( + "/", + headers={ + "Forwarded": "proto=https;host=test:1234", + "X-Forwarded-Proto": "http", + "X-Forwarded-Port": "4321", + }, + ).json() + for link in response["links"]: + assert link["href"].startswith("https://test:1234/") + + +def test_app_search_response_forwarded_header( + load_test_data, app_client, postgres_transactions +): + item = load_test_data("test_item.json") + postgres_transactions.create_item(item, request=MockStarletteRequest) + + resp = app_client.get( + "/search", + params={"collections": ["test-collection"]}, + headers={"Forwarded": "proto=https;host=testserver:1234"}, + ) + for feature in resp.json()["features"]: + for link in feature["links"]: + assert link["href"].startswith("https://testserver:1234/") + + +def test_app_search_response_x_forwarded_headers( + load_test_data, app_client, postgres_transactions +): + item = load_test_data("test_item.json") + postgres_transactions.create_item(item, request=MockStarletteRequest) + + resp = app_client.get( + "/search", + params={"collections": ["test-collection"]}, + headers={ + "X-Forwarded-Port": "1234", + "X-Forwarded-Proto": "https", + }, + ) + for feature in resp.json()["features"]: + for link in feature["links"]: + assert link["href"].startswith("https://testserver:1234/") + + +def test_app_search_response_duplicate_forwarded_headers( + load_test_data, app_client, postgres_transactions +): + item = load_test_data("test_item.json") + postgres_transactions.create_item(item, request=MockStarletteRequest) + + resp = app_client.get( + "/search", + 
params={"collections": ["test-collection"]}, + headers={ + "Forwarded": "proto=https;host=testserver:1234", + "X-Forwarded-Port": "4321", + "X-Forwarded-Proto": "http", + }, + ) + for feature in resp.json()["features"]: + for link in feature["links"]: + assert link["href"].startswith("https://testserver:1234/") diff --git a/stac_fastapi/sqlalchemy/tests/resources/test_collection.py b/stac_fastapi/sqlalchemy/tests/resources/test_collection.py index b0d8b3d66..275b2684f 100644 --- a/stac_fastapi/sqlalchemy/tests/resources/test_collection.py +++ b/stac_fastapi/sqlalchemy/tests/resources/test_collection.py @@ -73,3 +73,46 @@ def test_returns_valid_collection(app_client, load_test_data): resp_json, root=mock_root, preserve_dict=False ) collection.validate() + + +def test_get_collection_forwarded_header(app_client, load_test_data): + test_collection = load_test_data("test_collection.json") + app_client.put("/collections", json=test_collection) + + resp = app_client.get( + f"/collections/{test_collection['id']}", + headers={"Forwarded": "proto=https;host=testserver:1234"}, + ) + for link in resp.json()["links"]: + assert link["href"].startswith("https://testserver:1234/") + + +def test_get_collection_x_forwarded_headers(app_client, load_test_data): + test_collection = load_test_data("test_collection.json") + app_client.put("/collections", json=test_collection) + + resp = app_client.get( + f"/collections/{test_collection['id']}", + headers={ + "X-Forwarded-Port": "1234", + "X-Forwarded-Proto": "https", + }, + ) + for link in resp.json()["links"]: + assert link["href"].startswith("https://testserver:1234/") + + +def test_get_collection_duplicate_forwarded_headers(app_client, load_test_data): + test_collection = load_test_data("test_collection.json") + app_client.put("/collections", json=test_collection) + + resp = app_client.get( + f"/collections/{test_collection['id']}", + headers={ + "Forwarded": "proto=https;host=testserver:1234", + "X-Forwarded-Port": "4321", + 
"X-Forwarded-Proto": "http", + }, + ) + for link in resp.json()["links"]: + assert link["href"].startswith("https://testserver:1234/") diff --git a/stac_fastapi/sqlalchemy/tests/resources/test_item.py b/stac_fastapi/sqlalchemy/tests/resources/test_item.py index 2f671de68..d7618c2c1 100644 --- a/stac_fastapi/sqlalchemy/tests/resources/test_item.py +++ b/stac_fastapi/sqlalchemy/tests/resources/test_item.py @@ -942,3 +942,43 @@ def test_search_datetime_validation_errors(app_client): resp = app_client.get("/search?datetime={}".format(dt)) assert resp.status_code == 400 + + +def test_get_item_forwarded_header(app_client, load_test_data): + test_item = load_test_data("test_item.json") + app_client.post(f"/collections/{test_item['collection']}/items", json=test_item) + get_item = app_client.get( + f"/collections/{test_item['collection']}/items/{test_item['id']}", + headers={"Forwarded": "proto=https;host=testserver:1234"}, + ) + for link in get_item.json()["links"]: + assert link["href"].startswith("https://testserver:1234/") + + +def test_get_item_x_forwarded_headers(app_client, load_test_data): + test_item = load_test_data("test_item.json") + app_client.post(f"/collections/{test_item['collection']}/items", json=test_item) + get_item = app_client.get( + f"/collections/{test_item['collection']}/items/{test_item['id']}", + headers={ + "X-Forwarded-Port": "1234", + "X-Forwarded-Proto": "https", + }, + ) + for link in get_item.json()["links"]: + assert link["href"].startswith("https://testserver:1234/") + + +def test_get_item_duplicate_forwarded_headers(app_client, load_test_data): + test_item = load_test_data("test_item.json") + app_client.post(f"/collections/{test_item['collection']}/items", json=test_item) + get_item = app_client.get( + f"/collections/{test_item['collection']}/items/{test_item['id']}", + headers={ + "Forwarded": "proto=https;host=testserver:1234", + "X-Forwarded-Port": "4321", + "X-Forwarded-Proto": "http", + }, + ) + for link in get_item.json()["links"]: 
+ assert link["href"].startswith("https://testserver:1234/")