diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 70c10596463..69e1c22c381 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -1592,6 +1592,7 @@ def enable_swagger( middlewares: List[Callable[..., Response]], optional List of middlewares to be used for the swagger route. """ + from aws_lambda_powertools.event_handler.openapi.compat import model_json from aws_lambda_powertools.event_handler.openapi.models import Server if not swagger_base_url: @@ -1640,7 +1641,28 @@ def swagger_handler(): license_info=license_info, ) - body = generate_swagger_html(spec, swagger_js, swagger_css) + # The .replace(' or similar tags. Escaping the forward slash in str: +def generate_swagger_html(spec: str, path: str, js_url: str, css_url: str) -> str: """ Generate Swagger UI HTML page Parameters ---------- - spec: OpenAPI + spec: str The OpenAPI spec + path: str + The path to the Swagger documentation js_url: str The URL to the Swagger UI JavaScript file css_url: str The URL to the Swagger UI CSS file """ - from aws_lambda_powertools.event_handler.openapi.compat import model_json - - # The .replace(' or similar tags. Escaping the forward slash in @@ -60,7 +44,7 @@ def generate_swagger_html(spec: "OpenAPI", js_url: str, css_url: str) -> str: layout: "BaseLayout", showExtensions: true, showCommonExtensions: true, - spec: {escaped_spec}, + spec: {spec}, presets: [ SwaggerUIBundle.presets.apis, SwaggerUIBundle.SwaggerUIStandalonePreset @@ -71,6 +55,7 @@ def generate_swagger_html(spec: "OpenAPI", js_url: str, css_url: str) -> str: }} var ui = SwaggerUIBundle(swaggerUIOptions) + ui.specActions.updateUrl('{path}?format=json'); """.strip() diff --git a/tests/functional/event_handler/test_openapi_swagger.py b/tests/functional/event_handler/test_openapi_swagger.py index da2bfe199f2..18ed85ed676 100644 --- a/tests/functional/event_handler/test_openapi_swagger.py +++ b/tests/functional/event_handler/test_openapi_swagger.py @@ -1,3 +1,6 @@ +import json +from typing import Dict + from aws_lambda_powertools.event_handler import APIGatewayRestResolver from tests.functional.utils import load_event @@ -68,3 +71,31 @@ def test_openapi_swagger_with_custom_base_url_no_embedded_assets(): LOAD_GW_EVENT["path"] = "/swagger.js" result = app(LOAD_GW_EVENT, {}) assert result["statusCode"] == 404 + + +def test_openapi_swagger_json_view_with_default_path(): + app = APIGatewayRestResolver(enable_validation=True) + app.enable_swagger(title="OpenAPI JSON View") + LOAD_GW_EVENT["path"] = "/swagger" + LOAD_GW_EVENT["queryStringParameters"] = {"format": "json"} + + result = app(LOAD_GW_EVENT, {}) + + assert result["statusCode"] == 200 + assert result["multiValueHeaders"]["Content-Type"] == ["application/json"] + assert isinstance(json.loads(result["body"]), Dict) + assert "OpenAPI JSON View" in result["body"] + + +def test_openapi_swagger_json_view_with_custom_path(): + app = APIGatewayRestResolver(enable_validation=True) + app.enable_swagger(path="/fizzbuzz/foobar", title="OpenAPI JSON View") + LOAD_GW_EVENT["path"] = "/fizzbuzz/foobar" + LOAD_GW_EVENT["queryStringParameters"] = {"format": "json"} + + result = app(LOAD_GW_EVENT, {}) + + assert result["statusCode"] == 200 + assert result["multiValueHeaders"]["Content-Type"] == ["application/json"] + assert isinstance(json.loads(result["body"]), Dict) + assert "OpenAPI JSON View" in result["body"]