Skip to content

Commit 33820d1

Browse files
feat(event_handler): support Header parameter validation in OpenAPI schema (#3687)
* Adding header - Initial commit * Adding header - Fix VPC Lattice Payload * Adding header - tests and final changes * Making sonarqube happy * Adding documentation * Rafactoring to be complaint with RFC * Adding tests * Adding test with Uppercase variables * Revert event changes * Adding HTTP RFC * Adding getter/setter to clean the code * Adding getter/setter to clean the code * Addressing Ruben's feedback
1 parent ced0a3d commit 33820d1

18 files changed

+873
-213
lines changed

aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,22 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
8181
query_string,
8282
)
8383

84+
# Normalize header values before validate this
85+
headers = _normalize_multi_header_values_with_param(
86+
app.current_event.resolved_headers_field,
87+
route.dependant.header_params,
88+
)
89+
90+
# Process header values
91+
header_values, header_errors = _request_params_to_args(
92+
route.dependant.header_params,
93+
headers,
94+
)
95+
8496
values.update(path_values)
8597
values.update(query_values)
86-
errors += path_errors + query_errors
98+
values.update(header_values)
99+
errors += path_errors + query_errors + header_errors
87100

88101
# Process the request body, if it exists
89102
if route.dependant.body_params:
@@ -243,12 +256,14 @@ def _request_params_to_args(
243256
errors = []
244257

245258
for field in required_params:
246-
value = received_params.get(field.alias)
247-
248259
field_info = field.field_info
260+
261+
# To ensure early failure, we check if it's not an instance of Param.
249262
if not isinstance(field_info, Param):
250263
raise AssertionError(f"Expected Param field_info, got {field_info}")
251264

265+
value = received_params.get(field.alias)
266+
252267
loc = (field_info.in_.value, field.alias)
253268

254269
# If we don't have a value, see if it's required or has a default
@@ -377,3 +392,30 @@ def _normalize_multi_query_string_with_param(query_string: Optional[Dict[str, st
377392
except KeyError:
378393
pass
379394
return query_string
395+
396+
397+
def _normalize_multi_header_values_with_param(headers: Optional[Dict[str, str]], params: Sequence[ModelField]):
398+
"""
399+
Extract and normalize resolved_headers_field
400+
401+
Parameters
402+
----------
403+
headers: Dict
404+
A dictionary containing the initial header parameters.
405+
params: Sequence[ModelField]
406+
A sequence of ModelField objects representing parameters.
407+
408+
Returns
409+
-------
410+
A dictionary containing the processed headers.
411+
"""
412+
if headers:
413+
for param in filter(is_scalar_field, params):
414+
try:
415+
if len(headers[param.alias]) == 1:
416+
# if the target parameter is a scalar and the list contains only 1 element
417+
# we keep the first value of the headers regardless if there are more in the payload
418+
headers[param.alias] = headers[param.alias][0]
419+
except KeyError:
420+
pass
421+
return headers

aws_lambda_powertools/event_handler/openapi/dependant.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
from aws_lambda_powertools.event_handler.openapi.params import (
1515
Body,
1616
Dependant,
17+
Header,
1718
Param,
1819
ParamTypes,
1920
Query,
2021
_File,
2122
_Form,
22-
_Header,
2323
analyze_param,
2424
create_response_field,
2525
get_flat_dependant,
@@ -59,16 +59,21 @@ def add_param_to_fields(
5959
6060
"""
6161
field_info = cast(Param, field.field_info)
62-
if field_info.in_ == ParamTypes.path:
63-
dependant.path_params.append(field)
64-
elif field_info.in_ == ParamTypes.query:
65-
dependant.query_params.append(field)
66-
elif field_info.in_ == ParamTypes.header:
67-
dependant.header_params.append(field)
62+
63+
# Dictionary to map ParamTypes to their corresponding lists in dependant
64+
param_type_map = {
65+
ParamTypes.path: dependant.path_params,
66+
ParamTypes.query: dependant.query_params,
67+
ParamTypes.header: dependant.header_params,
68+
ParamTypes.cookie: dependant.cookie_params,
69+
}
70+
71+
# Check if field_info.in_ is a valid key in param_type_map and append the field to the corresponding list
72+
# or raise an exception if it's not a valid key.
73+
if field_info.in_ in param_type_map:
74+
param_type_map[field_info.in_].append(field)
6875
else:
69-
if field_info.in_ != ParamTypes.cookie:
70-
raise AssertionError(f"Unsupported param type: {field_info.in_}")
71-
dependant.cookie_params.append(field)
76+
raise AssertionError(f"Unsupported param type: {field_info.in_}")
7277

7378

7479
def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
@@ -265,7 +270,7 @@ def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
265270
return False
266271
elif is_scalar_field(field=param_field):
267272
return False
268-
elif isinstance(param_field.field_info, (Query, _Header)) and is_scalar_sequence_field(param_field):
273+
elif isinstance(param_field.field_info, (Query, Header)) and is_scalar_sequence_field(param_field):
269274
return False
270275
else:
271276
if not isinstance(param_field.field_info, Body):

aws_lambda_powertools/event_handler/openapi/params.py

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ def __init__(
486486
)
487487

488488

489-
class _Header(Param):
489+
class Header(Param):
490490
"""
491491
A class used internally to represent a header parameter in a path operation.
492492
"""
@@ -527,12 +527,75 @@ def __init__(
527527
json_schema_extra: Union[Dict[str, Any], None] = None,
528528
**extra: Any,
529529
):
530+
"""
531+
Constructs a new Query param.
532+
533+
Parameters
534+
----------
535+
default: Any
536+
The default value of the parameter
537+
default_factory: Callable[[], Any], optional
538+
Callable that will be called when a default value is needed for this field
539+
annotation: Any, optional
540+
The type annotation of the parameter
541+
alias: str, optional
542+
The public name of the field
543+
alias_priority: int, optional
544+
Priority of the alias. This affects whether an alias generator is used
545+
validation_alias: str | AliasPath | AliasChoices | None, optional
546+
Alias to be used for validation only
547+
serialization_alias: str | AliasPath | AliasChoices | None, optional
548+
Alias to be used for serialization only
549+
convert_underscores: bool
550+
If true convert "_" to "-"
551+
See RFC: https://www.rfc-editor.org/rfc/rfc9110.html#name-field-name-registry
552+
title: str, optional
553+
The title of the parameter
554+
description: str, optional
555+
The description of the parameter
556+
gt: float, optional
557+
Only applies to numbers, required the field to be "greater than"
558+
ge: float, optional
559+
Only applies to numbers, required the field to be "greater than or equal"
560+
lt: float, optional
561+
Only applies to numbers, required the field to be "less than"
562+
le: float, optional
563+
Only applies to numbers, required the field to be "less than or equal"
564+
min_length: int, optional
565+
Only applies to strings, required the field to have a minimum length
566+
max_length: int, optional
567+
Only applies to strings, required the field to have a maximum length
568+
pattern: str, optional
569+
Only applies to strings, requires the field match against a regular expression pattern string
570+
discriminator: str, optional
571+
Parameter field name for discriminating the type in a tagged union
572+
strict: bool, optional
573+
Enables Pydantic's strict mode for the field
574+
multiple_of: float, optional
575+
Only applies to numbers, requires the field to be a multiple of the given value
576+
allow_inf_nan: bool, optional
577+
Only applies to numbers, requires the field to allow infinity and NaN values
578+
max_digits: int, optional
579+
Only applies to Decimals, requires the field to have a maxmium number of digits within the decimal.
580+
decimal_places: int, optional
581+
Only applies to Decimals, requires the field to have at most a number of decimal places
582+
examples: List[Any], optional
583+
A list of examples for the parameter
584+
deprecated: bool, optional
585+
If `True`, the parameter will be marked as deprecated
586+
include_in_schema: bool, optional
587+
If `False`, the parameter will be excluded from the generated OpenAPI schema
588+
json_schema_extra: Dict[str, Any], optional
589+
Extra values to include in the generated OpenAPI schema
590+
"""
530591
self.convert_underscores = convert_underscores
592+
self._alias = alias
593+
531594
super().__init__(
532595
default=default,
533596
default_factory=default_factory,
534597
annotation=annotation,
535-
alias=alias,
598+
alias=self._alias,
536599
alias_priority=alias_priority,
537600
validation_alias=validation_alias,
538601
serialization_alias=serialization_alias,
@@ -558,6 +621,18 @@ def __init__(
558621
**extra,
559622
)
560623

624+
@property
625+
def alias(self):
626+
return self._alias
627+
628+
@alias.setter
629+
def alias(self, value: Optional[str] = None):
630+
if value is not None:
631+
# Headers are case-insensitive according to RFC 7540 (HTTP/2), so we lower the parameter name
632+
# This ensures that customers can access headers with any casing, as per the RFC guidelines.
633+
# Reference: https://www.rfc-editor.org/rfc/rfc7540#section-8.1.2
634+
self._alias = value.lower()
635+
561636

562637
class Body(FieldInfo):
563638
"""

aws_lambda_powertools/utilities/data_classes/alb_event.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,17 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]:
4242

4343
return self.query_string_parameters
4444

45+
@property
46+
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
47+
headers: Dict[str, Any] = {}
48+
49+
if self.multi_value_headers:
50+
headers = self.multi_value_headers
51+
else:
52+
headers = self.headers
53+
54+
return {key.lower(): value for key, value in headers.items()}
55+
4556
@property
4657
def multi_value_headers(self) -> Optional[Dict[str, List[str]]]:
4758
return self.get("multiValueHeaders")

aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,17 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]:
125125

126126
return self.query_string_parameters
127127

128+
@property
129+
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
130+
headers: Dict[str, Any] = {}
131+
132+
if self.multi_value_headers:
133+
headers = self.multi_value_headers
134+
else:
135+
headers = self.headers
136+
137+
return {key.lower(): value for key, value in headers.items()}
138+
128139
@property
129140
def request_context(self) -> APIGatewayEventRequestContext:
130141
return APIGatewayEventRequestContext(self._data)
@@ -316,3 +327,11 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]:
316327
return query_string
317328

318329
return {}
330+
331+
@property
332+
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
333+
if self.headers is not None:
334+
headers = {key.lower(): value.split(",") if "," in value else value for key, value in self.headers.items()}
335+
return headers
336+
337+
return {}

aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, List, Optional
1+
from typing import Any, Dict, List, Optional
22

33
from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent, DictWrapper
44

@@ -112,3 +112,7 @@ def query_string_parameters(self) -> Optional[Dict[str, str]]:
112112
@property
113113
def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
114114
return self.query_string_parameters
115+
116+
@property
117+
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
118+
return {}

aws_lambda_powertools/utilities/data_classes/common.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,21 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
114114
"""
115115
return self.query_string_parameters
116116

117+
@property
118+
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
119+
"""
120+
This property determines the appropriate header to be used
121+
as a trusted source for validating OpenAPI.
122+
123+
This is necessary because different resolvers use different formats to encode
124+
headers parameters.
125+
126+
Headers are case-insensitive according to RFC 7540 (HTTP/2), so we lower the header name
127+
This ensures that customers can access headers with any casing, as per the RFC guidelines.
128+
Reference: https://www.rfc-editor.org/rfc/rfc7540#section-8.1.2
129+
"""
130+
return self.headers
131+
117132
@property
118133
def is_base64_encoded(self) -> Optional[bool]:
119134
return self.get("isBase64Encoded")

aws_lambda_powertools/utilities/data_classes/vpc_lattice.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,14 @@ def query_string_parameters(self) -> Dict[str, str]:
145145
def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
146146
return self.query_string_parameters
147147

148+
@property
149+
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
150+
if self.headers is not None:
151+
headers = {key.lower(): value.split(",") if "," in value else value for key, value in self.headers.items()}
152+
return headers
153+
154+
return {}
155+
148156

149157
class vpcLatticeEventV2Identity(DictWrapper):
150158
@property
@@ -259,3 +267,10 @@ def query_string_parameters(self) -> Optional[Dict[str, str]]:
259267
@property
260268
def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
261269
return self.query_string_parameters
270+
271+
@property
272+
def resolved_headers_field(self) -> Optional[Dict[str, str]]:
273+
if self.headers is not None:
274+
return {key.lower(): value for key, value in self.headers.items()}
275+
276+
return {}

0 commit comments

Comments
 (0)