diff --git a/azure-functions-extension-fastapi/azure/functions/extension/fastapi/__init__.py b/azure-functions-extension-fastapi/azure/functions/extension/fastapi/__init__.py index 3e45af9..5a2228b 100644 --- a/azure-functions-extension-fastapi/azure/functions/extension/fastapi/__init__.py +++ b/azure-functions-extension-fastapi/azure/functions/extension/fastapi/__init__.py @@ -10,13 +10,14 @@ UJSONResponse, ) -from .web import WebApp, WebServer +from .web import RequestSynchronizer, WebApp, WebServer __all__ = [ "WebServer", "WebApp", "Request", "Response", + "RequestSynchronizer", "StreamingResponse", "HTMLResponse", "PlainTextResponse", diff --git a/azure-functions-extension-fastapi/azure/functions/extension/fastapi/web.py b/azure-functions-extension-fastapi/azure/functions/extension/fastapi/web.py index e799504..2b55e1f 100644 --- a/azure-functions-extension-fastapi/azure/functions/extension/fastapi/web.py +++ b/azure-functions-extension-fastapi/azure/functions/extension/fastapi/web.py @@ -5,6 +5,7 @@ import uvicorn from azure.functions.extension.base import ( + RequestSynchronizer, RequestTrackerMeta, ResponseLabels, ResponseTrackerMeta, @@ -22,10 +23,24 @@ from fastapi.responses import RedirectResponse as FastApiRedirectResponse from fastapi.responses import StreamingResponse as FastApiStreamingResponse from fastapi.responses import UJSONResponse as FastApiUJSONResponse +from pydantic import BaseModel + + +class RequestSynchronizer(RequestSynchronizer): + def sync_route_params(self, request, path_params): + # add null checks for request and path_params + if request is None: + raise TypeError("Request object is None") + if path_params is None: + raise TypeError("Path parameters are None") + + request.path_params.clear() + request.path_params.update(path_params) class Request(metaclass=RequestTrackerMeta): request_type = FastApiRequest + synchronizer = RequestSynchronizer() class Response(metaclass=ResponseTrackerMeta): @@ -73,6 +88,41 @@ class FileResponse(metaclass=ResponseTrackerMeta): response_type = FastApiFileResponse +class StrResponse(metaclass=ResponseTrackerMeta): + label = ResponseLabels.STR + response_type = str + + +class DictResponse(metaclass=ResponseTrackerMeta): + label = ResponseLabels.DICT + response_type = dict + + +class BoolResponse(metaclass=ResponseTrackerMeta): + label = ResponseLabels.BOOL + response_type = bool + + +class PydanticResponse(metaclass=ResponseTrackerMeta): + label = ResponseLabels.PYDANTIC + response_type = BaseModel + + +class IntResponse(metaclass=ResponseTrackerMeta): + label = ResponseLabels.INT + response_type = int + + +class FloatResponse(metaclass=ResponseTrackerMeta): + label = ResponseLabels.FLOAT + response_type = float + + +class ListResponse(metaclass=ResponseTrackerMeta): + label = ResponseLabels.LIST + response_type = list + + class WebApp(WebApp): def __init__(self): self.web_app = FastAPI() diff --git a/azure-functions-extension-fastapi/pyproject.toml b/azure-functions-extension-fastapi/pyproject.toml index 2be4135..f12f1b8 100644 --- a/azure-functions-extension-fastapi/pyproject.toml +++ b/azure-functions-extension-fastapi/pyproject.toml @@ -27,7 +27,8 @@ classifiers= [ dependencies = [ 'azure-functions-extension-base', 'fastapi==0.110.0', - 'uvicorn==0.28.0' + 'uvicorn==0.28.0', + 'pydantic==2.6.4', ] [project.optional-dependencies] diff --git a/azure-functions-extension-fastapi/tests/test_web.py b/azure-functions-extension-fastapi/tests/test_web.py index 32c6be2..ee2bf64 100644 --- a/azure-functions-extension-fastapi/tests/test_web.py +++ b/azure-functions-extension-fastapi/tests/test_web.py @@ -7,17 +7,25 @@ ResponseLabels, ResponseTrackerMeta, ) -from azure.functions.extension.fastapi import JSONResponse -from azure.functions.extension.fastapi import Request as FastApiRequest -from azure.functions.extension.fastapi import Response as FastApiResponse -from azure.functions.extension.fastapi import WebApp, WebServer +from azure.functions.extension.fastapi import RequestSynchronizer, WebApp, WebServer from fastapi import FastAPI +from fastapi import Request as FastApiRequest +from fastapi import Response as FastApiResponse +from fastapi.responses import FileResponse as FastApiFileResponse +from fastapi.responses import HTMLResponse as FastApiHTMLResponse +from fastapi.responses import JSONResponse as FastApiJSONResponse +from fastapi.responses import ORJSONResponse as FastApiORJSONResponse +from fastapi.responses import PlainTextResponse as FastApiPlainTextResponse +from fastapi.responses import RedirectResponse as FastApiRedirectResponse +from fastapi.responses import StreamingResponse as FastApiStreamingResponse +from fastapi.responses import UJSONResponse as FastApiUJSONResponse class TestRequestTrackerMeta(unittest.TestCase): def test_request_type_defined(self): class Request(metaclass=RequestTrackerMeta): request_type = FastApiRequest + synchronizer = RequestSynchronizer() self.assertTrue(hasattr(Request, "request_type")) self.assertEqual(Request.request_type, FastApiRequest) @@ -59,7 +67,7 @@ class Response1(metaclass=ResponseTrackerMeta): class Response2(metaclass=ResponseTrackerMeta): label = ResponseLabels.STANDARD - response_type = JSONResponse + response_type = FastApiJSONResponse self.assertTrue( "Only one response type shall be recorded" in str(context.exception) @@ -126,3 +134,108 @@ async def serve(): async def run_serve(self): await self.web_server.serve() + + +class TestRequestSynchronizer(unittest.TestCase): + def test_sync_route_params(self): + # Create a mock request object + mock_request = MagicMock() + + # Define some path parameters + path_params = {"param1": "value1", "param2": "value2"} + + # Create an instance of the ConcreteRequestSynchronizer + synchronizer = RequestSynchronizer() + + # Call the sync_route_params method with the mock request and path parameters + synchronizer.sync_route_params(mock_request, path_params) + + # Assert that the request's path_params have been updated with the provided path parameters + mock_request.path_params.clear.assert_called_once() + mock_request.path_params.update.assert_called_once_with(path_params) + + def test_sync_route_params_missing_request(self): + # Create an instance of the ConcreteRequestSynchronizer + synchronizer = RequestSynchronizer() + + # Define some path parameters + path_params = {"param1": "value1", "param2": "value2"} + + # Call the sync_route_params method with a None request and path parameters + with self.assertRaises(TypeError): + synchronizer.sync_route_params(None, path_params) + + def test_sync_route_params_missing_path_params(self): + # Create a mock request object + mock_request = MagicMock() + + # Create an instance of the ConcreteRequestSynchronizer + synchronizer = RequestSynchronizer() + + # Call the sync_route_params method with the mock request and None path parameters + with self.assertRaises(TypeError): + synchronizer.sync_route_params(mock_request, None) + + +class TestExtensionClasses(unittest.TestCase): + def test_request(self): + from azure.functions.extension.fastapi.web import Request + + self.assertEqual(RequestTrackerMeta.get_request_type(), FastApiRequest) + self.assertTrue( + isinstance(RequestTrackerMeta.get_synchronizer(), RequestSynchronizer) + ) + + def test_response(self): + self.assertEqual( + ResponseTrackerMeta.get_response_type(ResponseLabels.STANDARD), + FastApiResponse, + ) + + def test_streaming_response(self): + self.assertEqual( + ResponseTrackerMeta.get_response_type(ResponseLabels.STREAMING), + FastApiStreamingResponse, + ) + + def test_html_response(self): + self.assertEqual( + ResponseTrackerMeta.get_response_type(ResponseLabels.HTML), + FastApiHTMLResponse, + ) + + def test_plain_text_response(self): + self.assertEqual( + ResponseTrackerMeta.get_response_type(ResponseLabels.PLAIN_TEXT), + FastApiPlainTextResponse, + ) + + def test_redirect_response(self): + self.assertEqual( + ResponseTrackerMeta.get_response_type(ResponseLabels.REDIRECT), + FastApiRedirectResponse, + ) + + def test_json_response(self): + self.assertEqual( + ResponseTrackerMeta.get_response_type(ResponseLabels.JSON), + FastApiJSONResponse, + ) + + def test_ujson_response(self): + self.assertEqual( + ResponseTrackerMeta.get_response_type(ResponseLabels.UJSON), + FastApiUJSONResponse, + ) + + def test_orjson_response(self): + self.assertEqual( + ResponseTrackerMeta.get_response_type(ResponseLabels.ORJSON), + FastApiORJSONResponse, + ) + + def test_file_response(self): + self.assertEqual( + ResponseTrackerMeta.get_response_type(ResponseLabels.FILE), + FastApiFileResponse, + )