Skip to content

feat(event_source): add class APIGatewayAuthorizerResponseWebSocket #6058

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 10, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@
import warnings
from typing import Any, overload

from typing_extensions import deprecated
from typing_extensions import deprecated, override

from aws_lambda_powertools.utilities.data_classes.common import (
BaseRequestContext,
@@ -28,9 +28,10 @@ def __init__(
aws_account_id: str,
api_id: str,
stage: str,
http_method: str,
http_method: str | None,
resource: str,
partition: str = "aws",
is_websocket_authorizer: bool = False,
):
self.partition = partition
self.region = region
@@ -40,39 +41,54 @@ def __init__(
self.http_method = http_method
# Remove matching "/" from `resource`.
self.resource = resource.lstrip("/")
self.is_websocket_authorizer = is_websocket_authorizer

@property
def arn(self) -> str:
"""Build an arn from its parts
eg: arn:aws:execute-api:us-east-1:123456789012:abcdef123/test/GET/request"""
return (
f"arn:{self.partition}:execute-api:{self.region}:{self.aws_account_id}:{self.api_id}/{self.stage}/"
f"{self.http_method}/{self.resource}"
)
base_arn = f"arn:{self.partition}:execute-api:{self.region}:{self.aws_account_id}:{self.api_id}/{self.stage}"

if not self.is_websocket_authorizer:
return f"{base_arn}/{self.http_method}/{self.resource}"
else:
return f"{base_arn}/{self.resource}"


def parse_api_gateway_arn(arn: str) -> APIGatewayRouteArn:
def parse_api_gateway_arn(arn: str, is_websocket_authorizer: bool = False) -> APIGatewayRouteArn:
"""Parses a gateway route arn as a APIGatewayRouteArn class

Parameters
----------
arn : str
ARN string for a methodArn or a routeArn
is_websocket_authorizer: bool
If it's a API Gateway Websocket

Returns
-------
APIGatewayRouteArn
"""
arn_parts = arn.split(":")
api_gateway_arn_parts = arn_parts[5].split("/")

if not is_websocket_authorizer:
http_method = api_gateway_arn_parts[2]
resource = "/".join(api_gateway_arn_parts[3:]) if len(api_gateway_arn_parts) >= 4 else ""
else:
http_method = None
resource = "/".join(api_gateway_arn_parts[2:])

return APIGatewayRouteArn(
partition=arn_parts[1],
region=arn_parts[3],
aws_account_id=arn_parts[4],
api_id=api_gateway_arn_parts[0],
stage=api_gateway_arn_parts[1],
http_method=api_gateway_arn_parts[2],
http_method=http_method,
# conditional allow us to handle /path/{proxy+} resources, as their length changes.
resource="/".join(api_gateway_arn_parts[3:]) if len(api_gateway_arn_parts) >= 4 else "",
resource=resource,
is_websocket_authorizer=is_websocket_authorizer,
)


@@ -512,13 +528,14 @@ def _add_route(self, effect: str, http_method: str, resource: str, conditions: l
raise ValueError(f"Invalid resource path: {resource}. Path should match {self.path_regex}")

resource_arn = APIGatewayRouteArn(
self.region,
self.aws_account_id,
self.api_id,
self.stage,
http_method,
resource,
self.partition,
region=self.region,
aws_account_id=self.aws_account_id,
api_id=self.api_id,
stage=self.stage,
http_method=http_method,
resource=resource,
partition=self.partition,
is_websocket_authorizer=False,
).arn

route = {"resourceArn": resource_arn, "conditions": conditions}
@@ -617,3 +634,127 @@ def asdict(self) -> dict[str, Any]:
response["context"] = self.context

return response


class APIGatewayAuthorizerResponseWebSocket(APIGatewayAuthorizerResponse):
"""The IAM Policy Response required for API Gateway WebSocket APIs

Based on: - https://github.com/awslabs/aws-apigateway-lambda-authorizer-blueprints/blob/\
master/blueprints/python/api-gateway-authorizer-python.py

Documentation:
-------------
- https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-lambda-authorizer.html
- https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-lambda-authorizer-output.html
"""

@staticmethod
def from_route_arn(
arn: str,
principal_id: str,
context: dict | None = None,
usage_identifier_key: str | None = None,
) -> APIGatewayAuthorizerResponseWebSocket:
parsed_arn = parse_api_gateway_arn(arn, is_websocket_authorizer=True)
return APIGatewayAuthorizerResponseWebSocket(
principal_id,
parsed_arn.region,
parsed_arn.aws_account_id,
parsed_arn.api_id,
parsed_arn.stage,
context,
usage_identifier_key,
)

# Note: we need ignore[override] because we are removing the http_method field
@override
def _add_route(self, effect: str, resource: str, conditions: list[dict] | None = None): # type: ignore[override]
"""Adds a route to the internal lists of allowed or denied routes. Each object in
the internal list contains a resource ARN and a condition statement. The condition
statement can be null."""
resource_arn = APIGatewayRouteArn(
region=self.region,
aws_account_id=self.aws_account_id,
api_id=self.api_id,
stage=self.stage,
http_method=None,
resource=resource,
partition=self.partition,
is_websocket_authorizer=True,
).arn

route = {"resourceArn": resource_arn, "conditions": conditions}

if effect.lower() == "allow":
self._allow_routes.append(route)
else: # deny
self._deny_routes.append(route)

@override
def allow_all_routes(self):
"""Adds a '*' allow to the policy to authorize access to all methods of an API"""
self._add_route(effect="Allow", resource="*")

@override
def deny_all_routes(self):
"""Adds a '*' allow to the policy to deny access to all methods of an API"""

self._add_route(effect="Deny", resource="*")

# Note: we need ignore[override] because we are removing the http_method field
@override
def allow_route(self, resource: str, conditions: list[dict] | None = None): # type: ignore[override]
"""
Add an API Gateway Websocket method to the list of allowed methods for the policy.

This method adds an API Gateway Websocket method Resource path) to the list of
allowed methods for the policy. It optionally includes conditions for the policy statement.

Parameters
----------
resource : str
The API Gateway resource path to allow.
conditions : list[dict] | None, optional
A list of condition dictionaries to apply to the policy statement.
Default is None.

Notes
-----
For more information on AWS policy conditions, see:
https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements.html#Condition

Example
--------
>>> policy = APIGatewayAuthorizerResponseWebSocket(...)
>>> policy.allow_route("/api/users", [{"StringEquals": {"aws:RequestTag/Environment": "Production"}}])
"""
self._add_route(effect="Allow", resource=resource, conditions=conditions)

# Note: we need ignore[override] because we are removing the http_method field
@override
def deny_route(self, resource: str, conditions: list[dict] | None = None): # type: ignore[override]
"""
Add an API Gateway Websocket method to the list of allowed methods for the policy.

This method adds an API Gateway Websocket method Resource path) to the list of
denied methods for the policy. It optionally includes conditions for the policy statement.

Parameters
----------
resource : str
The API Gateway resource path to allow.
conditions : list[dict] | None, optional
A list of condition dictionaries to apply to the policy statement.
Default is None.

Notes
-----
For more information on AWS policy conditions, see:
https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements.html#Condition

Example
--------
>>> policy = APIGatewayAuthorizerResponseWebSocket(...)
>>> policy.deny_route("/api/users", [{"StringEquals": {"aws:RequestTag/Environment": "Production"}}])
"""
self._add_route(effect="Deny", resource=resource, conditions=conditions)
10 changes: 8 additions & 2 deletions docs/utilities/data_classes.md
Original file line number Diff line number Diff line change
@@ -131,12 +131,18 @@ It is used for [API Gateway Rest API Lambda Authorizer payload](https://docs.aws

Use **`APIGatewayAuthorizerRequestEvent`** for type `REQUEST` and **`APIGatewayAuthorizerTokenEvent`** for type `TOKEN`.

=== "app.py"
=== "Rest APIs"

```python hl_lines="2-4 8"
```python hl_lines="2-4 8 18"
--8<-- "examples/event_sources/src/apigw_authorizer_request.py"
```

=== "WebSocket APIs"

```python hl_lines="2-4 8 18"
--8<-- "examples/event_sources/src/apigw_authorizer_request_websocket.py"
```

=== "API Gateway Authorizer Request Example Event"

```json hl_lines="3 11"
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from aws_lambda_powertools.utilities.data_classes import event_source
from aws_lambda_powertools.utilities.data_classes.api_gateway_authorizer_event import (
APIGatewayAuthorizerRequestEvent,
APIGatewayAuthorizerResponseWebSocket,
)


@event_source(data_class=APIGatewayAuthorizerRequestEvent)
def lambda_handler(event: APIGatewayAuthorizerRequestEvent, context):
# Simple auth check (replace with your actual auth logic)
is_authorized = event.headers.get("HeaderAuth1") == "headerValue1"

if not is_authorized:
return {"principalId": "", "policyDocument": {"Version": "2012-10-17", "Statement": []}}

arn = event.parsed_arn

policy = APIGatewayAuthorizerResponseWebSocket(
principal_id="user",
context={"user": "example"},
region=arn.region,
aws_account_id=arn.aws_account_id,
api_id=arn.api_id,
stage=arn.stage,
)

policy.allow_all_routes()

return policy.asdict()
81 changes: 81 additions & 0 deletions tests/events/apiGatewayAuthorizerWebSocketEvent.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
{
"type":"REQUEST",
"methodArn":"arn:aws:execute-api:us-east-1:533568316194:c5jwxq709g/production/$connect",
"headers":{
"Authorization":"Leo",
"Connection":"upgrade",
"content-length":"0",
"Host":"c5jwxq709g.execute-api.us-east-1.amazonaws.com",
"Sec-WebSocket-Extensions":"permessage-deflate; client_max_window_bits",
"Sec-WebSocket-Version":"13",
"Upgrade":"websocket",
"X-Amzn-Trace-Id":"Root=1-6797b6d3-64f9c928577f3ac56f5368ce",
"X-Forwarded-For":"93.108.161.96",
"X-Forwarded-Port":"443",
"X-Forwarded-Proto":"https"
},
"multiValueHeaders":{
"Authorization":[
"Leo"
],
"Connection":[
"upgrade"
],
"content-length":[
"0"
],
"Host":[
"c5jwxq709g.execute-api.us-east-1.amazonaws.com"
],
"Sec-WebSocket-Extensions":[
"permessage-deflate; client_max_window_bits"
],
"Sec-WebSocket-Key":[
"CYZZrfNgEcgzzzwL44qytQ=="
],
"Sec-WebSocket-Version":[
"13"
],
"Upgrade":[
"websocket"
],
"X-Amzn-Trace-Id":[
"Root=1-6797b6d3-64f9c928577f3ac56f5368ce"
],
"X-Forwarded-For":[
"93.108.161.96"
],
"X-Forwarded-Port":[
"443"
],
"X-Forwarded-Proto":[
"https"
]
},
"queryStringParameters":{

},
"multiValueQueryStringParameters":{

},
"stageVariables":{

},
"requestContext":{
"routeKey":"$connect",
"eventType":"CONNECT",
"extendedRequestId":"FDmBIG3EoAMEqYA=",
"requestTime":"27/Jan/2025:16:39:47 +0000",
"messageDirection":"IN",
"stage":"production",
"connectedAt":1737995987617,
"requestTimeEpoch":1737995987617,
"identity":{
"sourceIp":"93.108.161.96"
},
"requestId":"FDmBIG3EoAMEqYA=",
"domainName":"c5jwxq709g.execute-api.us-east-1.amazonaws.com",
"connectionId":"FDmBIeapIAMCIQg=",
"apiId":"c5jwxq709g"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import pytest

from aws_lambda_powertools.utilities.data_classes.api_gateway_authorizer_event import (
DENY_ALL_RESPONSE,
APIGatewayAuthorizerResponseWebSocket,
)


@pytest.fixture
def builder():
return APIGatewayAuthorizerResponseWebSocket("foo", "us-west-1", "123456789", "fantom", "dev")


def test_authorizer_response_no_statement(builder: APIGatewayAuthorizerResponseWebSocket):
# GIVEN a builder with no statements
with pytest.raises(ValueError) as ex:
# WHEN calling build
builder.asdict()

# THEN raise a name error for not statements
assert str(ex.value) == "No statements defined for the policy"


def test_authorizer_response_allow_all_routes_with_context():
arn = "arn:aws:execute-api:us-west-1:123456789:fantom/dev/$connect"
builder = APIGatewayAuthorizerResponseWebSocket.from_route_arn(arn, principal_id="foo", context={"name": "Foo"})
builder.allow_all_routes()
assert builder.asdict() == {
"principalId": "foo",
"policyDocument": {
"Version": "2012-10-17",
"Statement": [
{
"Action": "execute-api:Invoke",
"Effect": "Allow",
"Resource": ["arn:aws:execute-api:us-west-1:123456789:fantom/dev/*"],
},
],
},
"context": {"name": "Foo"},
}


def test_authorizer_response_allow_all_routes_with_usage_identifier_key():
arn = "arn:aws:execute-api:us-east-1:1111111111:api/dev/y"
builder = APIGatewayAuthorizerResponseWebSocket.from_route_arn(arn, principal_id="cow", usage_identifier_key="key")
builder.allow_all_routes()
assert builder.asdict() == {
"principalId": "cow",
"policyDocument": {
"Version": "2012-10-17",
"Statement": [
{
"Action": "execute-api:Invoke",
"Effect": "Allow",
"Resource": ["arn:aws:execute-api:us-east-1:1111111111:api/dev/*"],
},
],
},
"usageIdentifierKey": "key",
}


def test_authorizer_response_deny_all_routes(builder: APIGatewayAuthorizerResponseWebSocket):
builder.deny_all_routes()
assert builder.asdict() == {
"principalId": "foo",
"policyDocument": {
"Version": "2012-10-17",
"Statement": [
{
"Action": "execute-api:Invoke",
"Effect": "Deny",
"Resource": ["arn:aws:execute-api:us-west-1:123456789:fantom/dev/*"],
},
],
},
}


def test_authorizer_response_allow_route(builder: APIGatewayAuthorizerResponseWebSocket):
builder.allow_route(resource="/foo")
assert builder.asdict() == {
"policyDocument": {
"Version": "2012-10-17",
"Statement": [
{
"Action": "execute-api:Invoke",
"Effect": "Allow",
"Resource": ["arn:aws:execute-api:us-west-1:123456789:fantom/dev/foo"],
},
],
},
"principalId": "foo",
}


def test_authorizer_response_deny_route(builder: APIGatewayAuthorizerResponseWebSocket):
builder.deny_route(resource="foo")
assert builder.asdict() == {
"principalId": "foo",
"policyDocument": {
"Version": "2012-10-17",
"Statement": [
{
"Action": "execute-api:Invoke",
"Effect": "Deny",
"Resource": ["arn:aws:execute-api:us-west-1:123456789:fantom/dev/foo"],
},
],
},
}


def test_authorizer_response_allow_route_with_conditions(builder: APIGatewayAuthorizerResponseWebSocket):
condition = {"StringEquals": {"method.request.header.Content-Type": "text/html"}}
builder.allow_route(
resource="/foo",
conditions=[condition],
)
assert builder.asdict() == {
"principalId": "foo",
"policyDocument": {
"Version": "2012-10-17",
"Statement": [
{
"Action": "execute-api:Invoke",
"Effect": "Allow",
"Resource": ["arn:aws:execute-api:us-west-1:123456789:fantom/dev/foo"],
"Condition": [{"StringEquals": {"method.request.header.Content-Type": "text/html"}}],
},
],
},
}


def test_authorizer_response_deny_route_with_conditions(builder: APIGatewayAuthorizerResponseWebSocket):
condition = {"StringEquals": {"method.request.header.Content-Type": "application/json"}}
builder.deny_route(resource="/foo", conditions=[condition])
assert builder.asdict() == {
"principalId": "foo",
"policyDocument": {
"Version": "2012-10-17",
"Statement": [
{
"Action": "execute-api:Invoke",
"Effect": "Deny",
"Resource": ["arn:aws:execute-api:us-west-1:123456789:fantom/dev/foo"],
"Condition": [{"StringEquals": {"method.request.header.Content-Type": "application/json"}}],
},
],
},
}


def test_deny_all():
# CHECK we always explicitly deny all
statements = DENY_ALL_RESPONSE["policyDocument"]["Statement"]
assert len(statements) == 1
assert statements[0] == {
"Action": "execute-api:Invoke",
"Effect": "Deny",
"Resource": ["*"],
}


def test_authorizer_response_allow_route_with_underscore(builder: APIGatewayAuthorizerResponseWebSocket):
builder.allow_route(resource="/has_underscore")
assert builder.asdict() == {
"principalId": "foo",
"policyDocument": {
"Version": "2012-10-17",
"Statement": [
{
"Action": "execute-api:Invoke",
"Effect": "Allow",
"Resource": ["arn:aws:execute-api:us-west-1:123456789:fantom/dev/has_underscore"],
},
],
},
}