Skip to content

Commit b7580fe

Browse files
Fix/forwarding headers (#415)
* Respect forwarding headers in pgstac * Respect forwarding headers in sqlalchemy * Linting fixes * Moved to proxy header middleware * add makefile command to run api test cases * bugfix, make method static * add unittests for proxy header middleware Co-authored-by: Jeff Albrecht <[email protected]>
1 parent 286fd98 commit b7580fe

File tree

12 files changed

+611
-4
lines changed

12 files changed

+611
-4
lines changed

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,4 +129,7 @@ docs/api/*
129129
.envrc
130130

131131
# Virtualenv
132-
venv
132+
venv
133+
134+
# IDE
135+
.vscode

CHANGES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
* Ability to POST an ItemCollection to the collections/{collectionId}/items route. ([#367](https://github.com/stac-utils/stac-fastapi/pull/367))
99
* Add STAC API - Collections conformance class. ([383](https://github.com/stac-utils/stac-fastapi/pull/383))
1010
* Bulk item inserts for pgstac implementation. ([411](https://github.com/stac-utils/stac-fastapi/pull/411))
11+
* Respect `Forwarded` or `X-Forwarded-*` request headers when building links to better accommodate load balancers and proxies. ([#415](https://github.com/stac-utils/stac-fastapi/pull/415))
1112

1213
### Changed
1314

Makefile

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ test-sqlalchemy: run-joplin-sqlalchemy
4646
test-pgstac:
4747
$(run_pgstac) /bin/bash -c 'export && ./scripts/wait-for-it.sh database:5432 && cd /app/stac_fastapi/pgstac/tests/ && pytest -vvv'
4848

49+
.PHONY: test-api
50+
test-api:
51+
$(run_sqlalchemy) /bin/bash -c 'cd /app/stac_fastapi/api && pytest -svvv'
52+
4953
.PHONY: run-database
5054
run-database:
5155
docker-compose run --rm database

stac_fastapi/api/stac_fastapi/api/app.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from starlette.responses import JSONResponse, Response
1515

1616
from stac_fastapi.api.errors import DEFAULT_STATUS_CODES, add_exception_handlers
17+
from stac_fastapi.api.middleware import ProxyHeaderMiddleware
1718
from stac_fastapi.api.models import (
1819
APIRequest,
1920
CollectionUri,
@@ -91,7 +92,9 @@ class StacApi:
9192
)
9293
pagination_extension = attr.ib(default=TokenPaginationExtension)
9394
response_class: Type[Response] = attr.ib(default=JSONResponse)
94-
middlewares: List = attr.ib(default=attr.Factory(lambda: [BrotliMiddleware]))
95+
middlewares: List = attr.ib(
96+
default=attr.Factory(lambda: [BrotliMiddleware, ProxyHeaderMiddleware])
97+
)
9598
route_dependencies: List[Tuple[List[Scope], List[Depends]]] = attr.ib(default=[])
9699

97100
def get_extension(self, extension: Type[ApiExtension]) -> Optional[ApiExtension]:

stac_fastapi/api/stac_fastapi/api/middleware.py

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
"""api middleware."""
22

3-
from typing import Callable
3+
import re
4+
from http.client import HTTP_PORT, HTTPS_PORT
5+
from typing import Callable, List, Tuple
46

57
from fastapi import APIRouter, FastAPI
68
from starlette.middleware.base import BaseHTTPMiddleware
79
from starlette.requests import Request
810
from starlette.routing import Match
11+
from starlette.types import ASGIApp, Receive, Scope, Send
912

1013

1114
def router_middleware(app: FastAPI, router: APIRouter):
@@ -29,3 +32,95 @@ async def _middleware(request: Request, call_next):
2932
return func
3033

3134
return deco
35+
36+
37+
class ProxyHeaderMiddleware:
    """
    Account for forwarding headers when deriving base URL.

    Prioritise standard Forwarded header, look for non-standard X-Forwarded-* if missing.
    Default to what can be derived from the URL if no headers provided.
    Middleware updates the host header that is interpreted by starlette when deriving Request.base_url.
    """

    def __init__(self, app: ASGIApp):
        """Create proxy header middleware."""
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Rewrite scheme and host header of HTTP requests, then call the wrapped app.

        Non-HTTP scopes are passed through untouched.
        """
        if scope["type"] == "http":
            proto, domain, port = self._get_forwarded_url_parts(scope)
            scope["scheme"] = proto
            if domain is not None:
                port_suffix = ""
                if port is not None:
                    # Only spell out the port when it is not the default one
                    # for the scheme, so generated links stay canonical.
                    if (proto == "http" and port != HTTP_PORT) or (
                        proto == "https" and port != HTTPS_PORT
                    ):
                        port_suffix = f":{port}"
                scope["headers"] = self._replace_header_value_by_name(
                    scope,
                    "host",
                    f"{domain}{port_suffix}",
                )
        await self.app(scope, receive, send)

    def _get_forwarded_url_parts(self, scope: Scope) -> Tuple[str]:
        """Derive the externally visible (proto, domain, port) for a request.

        Starting point is the scope's scheme plus the Host header (or the
        scope's ``server`` entry when there is no Host header). A single
        ``Forwarded`` header then overrides ``proto`` and ``host``; only when
        no usable ``Forwarded`` header is present are ``X-Forwarded-Proto`` and
        ``X-Forwarded-Port`` consulted.

        NOTE: the annotation is kept for compatibility; the value actually
        returned is a 3-tuple ``(proto: str, domain: str or None,
        port: int or None)``.
        """
        proto = scope.get("scheme", "http")
        domain = None
        port = None

        header_host = self._get_header_value_by_name(scope, "host")
        if header_host is None:
            # "server" is optional in an ASGI scope and may be None.
            server = scope.get("server")
            if server is not None:
                domain, port = server
        else:
            domain, host_port = self._split_host_port(header_host)
            # Normalise to int right away so later comparisons against
            # HTTP_PORT / HTTPS_PORT compare like with like.
            port = self._parse_port(host_port, None)

        forwarded = self._get_header_value_by_name(scope, "forwarded")
        if forwarded is not None:
            # The header may carry one comma-separated element per proxy; each
            # element is a ';'-separated list of key=value pairs. The first
            # occurrence of a parameter wins.
            seen = set()
            for element in forwarded.split(","):
                for pair in element.split(";"):
                    # partition() tolerates '=' inside the value.
                    key, sep, value = pair.partition("=")
                    key = key.strip().lower()
                    value = value.strip().strip('"')
                    if not sep or not value or key in seen:
                        continue
                    seen.add(key)
                    if key == "proto":
                        proto = value
                    elif key == "host":
                        domain, forwarded_port = self._split_host_port(value)
                        # A host without port resets the port; a port that is
                        # not a valid integer is ignored (previous one kept).
                        port = self._parse_port(forwarded_port, port)
        else:
            proto = self._get_header_value_by_name(scope, "x-forwarded-proto", proto)
            port_str = self._get_header_value_by_name(scope, "x-forwarded-port", port)
            # ignore ports that are not valid integers
            port = self._parse_port(port_str, port)

        return (proto, domain, port)

    @staticmethod
    def _split_host_port(value: str) -> tuple:
        """Split ``host[:port]`` into ``(host, port_string_or_None)``.

        Bracketed IPv6 literals (``[::1]:8000``) keep their brackets in the
        host part. Anything that is not exactly ``host:port`` yields the text
        before the first colon and no port.
        """
        if value.startswith("["):
            closing = value.find("]")
            if closing != -1:
                remainder = value[closing + 1 :]
                port = remainder[1:] if remainder.startswith(":") else None
                return value[: closing + 1], port
        pieces = value.split(":")
        if len(pieces) == 2:
            return pieces[0], pieces[1]
        return pieces[0], None

    @staticmethod
    def _parse_port(value, fallback):
        """Return ``value`` as an int, None if it is None, else ``fallback`` when invalid."""
        if value is None:
            return None
        try:
            return int(value)
        except ValueError:
            return fallback

    def _get_header_value_by_name(
        self, scope: Scope, header_name: str, default_value: str = None
    ) -> str:
        """Return the value of ``header_name`` if it occurs exactly once.

        ``default_value`` is returned when the header is missing or repeated.
        Values are decoded as latin-1, which cannot fail on arbitrary bytes.
        """
        wanted = header_name.encode("latin-1")
        candidates = [
            value.decode("latin-1")
            for name, value in scope["headers"]
            if name == wanted
        ]
        return candidates[0] if len(candidates) == 1 else default_value

    @staticmethod
    def _replace_header_value_by_name(
        scope: Scope, header_name: str, new_value: str
    ) -> List[Tuple[bytes, bytes]]:
        """Return the scope's headers with ``header_name`` set to ``new_value``.

        Every existing occurrence of the header is dropped and a single new
        entry is appended; the scope itself is not modified.
        """
        encoded_name = header_name.encode("latin-1")
        kept = [
            (name, value) for name, value in scope["headers"] if name != encoded_name
        ]
        kept.append((encoded_name, new_value.encode("latin-1")))
        return kept
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import pytest
2+
from starlette.applications import Starlette
3+
4+
from stac_fastapi.api.middleware import ProxyHeaderMiddleware
5+
6+
7+
@pytest.fixture
def proxy_header_middleware() -> ProxyHeaderMiddleware:
    """Provide a ProxyHeaderMiddleware wrapping a bare Starlette application."""
    return ProxyHeaderMiddleware(Starlette())
11+
12+
13+
@pytest.mark.parametrize(
    "headers,key,expected",
    [
        # header present
        ([(b"host", b"testserver")], "host", "testserver"),
        # header absent
        ([(b"host", b"testserver")], "user-agent", None),
        # header present among several
        (
            [(b"host", b"testserver"), (b"accept-encoding", b"gzip, deflate, br")],
            "accept-encoding",
            "gzip, deflate, br",
        ),
    ],
)
def test_get_header_value_by_name(
    proxy_header_middleware: ProxyHeaderMiddleware, headers, key, expected
):
    """Looking a header up by name yields its decoded value, or None if absent."""
    looked_up = proxy_header_middleware._get_header_value_by_name(
        {"headers": headers}, key
    )
    assert looked_up == expected
31+
32+
33+
@pytest.mark.parametrize(
    "headers,key,value",
    [
        # overwrite the only header
        ([(b"host", b"testserver")], "host", "another-server"),
        # add a header that was not there before
        ([(b"host", b"testserver")], "user-agent", "agent"),
        # overwrite one header among several
        (
            [(b"host", b"testserver"), (b"accept-encoding", b"gzip, deflate, br")],
            "accept-encoding",
            "deflate",
        ),
    ],
)
def test_replace_header_value_by_name(
    proxy_header_middleware: ProxyHeaderMiddleware, headers, key, value
):
    """After replacing a header, reading it back returns the new value."""
    replaced = proxy_header_middleware._replace_header_value_by_name(
        {"headers": headers}, key, value
    )
    read_back = proxy_header_middleware._get_header_value_by_name(
        {"headers": replaced}, key
    )
    assert read_back == value
57+
58+
59+
def _scope_with(headers, scheme="http"):
    """Build a minimal scope served from testserver:80 carrying the given headers."""
    return {"scheme": scheme, "server": ["testserver", 80], "headers": headers}


@pytest.mark.parametrize(
    "scope,expected",
    [
        # no headers at all: fall back to scheme and server entry of the scope
        (_scope_with([], scheme="https"), ("https", "testserver", 80)),
        # Host header with and without an explicit port
        (_scope_with([(b"host", b"testserver:81")]), ("http", "testserver", 81)),
        (_scope_with([(b"host", b"testserver")]), ("http", "testserver", None)),
        # standard Forwarded header, valid and invalid port
        (
            _scope_with([(b"forwarded", b"proto=https;host=test:1234")]),
            ("https", "test", 1234),
        ),
        (
            _scope_with([(b"forwarded", b"proto=https;host=test:not-an-integer")]),
            ("https", "test", 80),
        ),
        # non-standard X-Forwarded-* headers
        (_scope_with([(b"x-forwarded-proto", b"https")]), ("https", "testserver", 80)),
        (_scope_with([(b"x-forwarded-port", b"1111")]), ("http", "testserver", 1111)),
        (
            _scope_with([(b"x-forwarded-port", b"not-an-integer")]),
            ("http", "testserver", 80),
        ),
        # Forwarded takes precedence over X-Forwarded-*
        (
            _scope_with(
                [
                    (b"forwarded", b"proto=https;host=test:1234"),
                    (b"x-forwarded-port", b"1111"),
                    (b"x-forwarded-proto", b"https"),
                ]
            ),
            ("https", "test", 1234),
        ),
    ],
)
def test_get_forwarded_url_parts(
    proxy_header_middleware: ProxyHeaderMiddleware, scope, expected
):
    """The (proto, domain, port) triple is derived from scope and forwarding headers."""
    url_parts = proxy_header_middleware._get_forwarded_url_parts(scope)
    assert url_parts == expected

0 commit comments

Comments
 (0)