Skip to content

Http V2 (Part 1): Add req sychroznier and built in resp types #23

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .web import (
HttpV2FeatureChecker,
ModuleTrackerMeta,
RequestSynchronizer,
RequestTrackerMeta,
ResponseLabels,
ResponseTrackerMeta,
Expand All @@ -35,6 +36,7 @@
"ResponseLabels",
"WebServer",
"WebApp",
"RequestSynchronizer",
]

__version__ = "1.0.0a2"
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import abc
import inspect
from abc import abstractmethod
from enum import Enum
Expand Down Expand Up @@ -34,28 +35,37 @@ def module_imported(cls):

class RequestTrackerMeta(type):
_request_type = None
_synchronizer: None

def __new__(cls, name, bases, dct, **kwargs):
new_class = super().__new__(cls, name, bases, dct)

request_type = dct.get("request_type")

if request_type is None:
raise Exception(f"Request type not provided for class {name}")
raise TypeError(f"Request type not provided for class {name}")

if cls._request_type is not None and cls._request_type != request_type:
raise Exception(
raise TypeError(
f"Only one request type shall be recorded for class {name} "
f"but found {cls._request_type} and {request_type}"
)
cls._request_type = request_type
cls._synchronizer = dct.get("synchronizer")

if cls._synchronizer is None:
raise TypeError(f"Request synchronizer not provided for class {name}")

return new_class

@classmethod
def get_request_type(cls):
return cls._request_type

@classmethod
def get_synchronizer(cls):
return cls._synchronizer

@classmethod
def check_type(cls, pytype: type) -> bool:
if pytype is not None and inspect.isclass(pytype):
Expand All @@ -65,6 +75,12 @@ def check_type(cls, pytype: type) -> bool:
return False


class RequestSynchronizer(abc.ABC):
@abstractmethod
def sync_route_params(self, request, path_params):
raise NotImplementedError()


class ResponseTrackerMeta(type):
_response_types = {}

Expand All @@ -75,14 +91,14 @@ def __new__(cls, name, bases, dct, **kwargs):
response_type = dct.get("response_type")

if label is None:
raise Exception(f"Response label not provided for class {name}")
raise TypeError(f"Response label not provided for class {name}")
if response_type is None:
raise Exception(f"Response type not provided for class {name}")
raise TypeError(f"Response type not provided for class {name}")
if (
cls._response_types.get(label) is not None
and cls._response_types.get(label) != response_type
):
raise Exception(
raise TypeError(
f"Only one response type shall be recorded for class {name} "
f"but found {cls._response_types.get(label)} and {response_type}"
)
Expand All @@ -109,25 +125,29 @@ def check_type(cls, pytype: type) -> bool:
return False


class WebApp(metaclass=ModuleTrackerMeta):
class ABCModuleTrackerMeta(abc.ABCMeta, ModuleTrackerMeta):
pass


class WebApp(metaclass=ABCModuleTrackerMeta):
@abstractmethod
def route(self, func: Callable):
pass
raise NotImplementedError()

@abstractmethod
def get_app(self):
pass
raise NotImplementedError()


class WebServer(metaclass=ModuleTrackerMeta):
class WebServer(metaclass=ABCModuleTrackerMeta):
def __init__(self, hostname, port, web_app: WebApp):
self.hostname = hostname
self.port = port
self.web_app = web_app.get_app()

@abstractmethod
async def serve(self):
pass
raise NotImplementedError() # pragma: no cover


class HttpV2FeatureChecker:
Expand All @@ -146,3 +166,10 @@ class ResponseLabels(Enum):
PLAIN_TEXT = "plain_text"
REDIRECT = "redirect"
UJSON = "ujson"
INT = "int"
FLOAT = "float"
STR = "str"
LIST = "list"
DICT = "dict"
BOOL = "bool"
PYDANTIC = "pydantic"
154 changes: 152 additions & 2 deletions azure-functions-extension-base/tests/test_web.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import unittest
from unittest.mock import patch
from unittest.mock import MagicMock, patch

from azure.functions.extension.base import (
HttpV2FeatureChecker,
ModuleTrackerMeta,
RequestSynchronizer,
RequestTrackerMeta,
ResponseLabels,
ResponseTrackerMeta,
Expand Down Expand Up @@ -66,6 +67,10 @@ class TestRequest2:
class TestRequest3:
pass

class Syncronizer(RequestSynchronizer):
def sync_route_params(self, request, path_params):
pass

def setUp(self):
# Reset _request_type before each test
RequestTrackerMeta._request_type = None
Expand All @@ -81,42 +86,71 @@ class TestClass(metaclass=RequestTrackerMeta):
str(context.exception), "Request type not provided for class TestClass"
)

def test_request_synchronizer_not_provided(self):
# Define a class without providing the synchronizer attribute
with self.assertRaises(Exception) as context:

class TestClass(metaclass=RequestTrackerMeta):
request_type = self.TestRequest1

self.assertEqual(
str(context.exception),
"Request synchronizer not provided for class TestClass",
)

def test_single_request_type(self):
# Define a class providing a request_type attribute
class TestClass(metaclass=RequestTrackerMeta):
request_type = self.TestRequest1
synchronizer = self.Syncronizer()

# Ensure the request_type is correctly recorded
self.assertEqual(RequestTrackerMeta.get_request_type(), self.TestRequest1)
self.assertTrue(
isinstance(RequestTrackerMeta.get_synchronizer(), RequestSynchronizer)
)
# Ensure check_type returns True for the provided request_type
self.assertTrue(RequestTrackerMeta.check_type(self.TestRequest1))
self.assertFalse(RequestTrackerMeta.check_type(self.TestRequest2))

def test_multiple_request_types_same(self):
# Define a class providing the same request_type attribute
class TestClass1(metaclass=RequestTrackerMeta):
request_type = self.TestRequest1
synchronizer = self.Syncronizer()

# Ensure the request_type is correctly recorded
self.assertEqual(RequestTrackerMeta.get_request_type(), self.TestRequest1)
self.assertTrue(
isinstance(RequestTrackerMeta.get_synchronizer(), RequestSynchronizer)
)
# Ensure check_type returns True for the provided request_type
self.assertTrue(RequestTrackerMeta.check_type(self.TestRequest1))

# Define another class providing the same request_type attribute
class TestClass2(metaclass=RequestTrackerMeta):
request_type = self.TestRequest1
synchronizer = self.Syncronizer()

# Ensure the request_type remains the same
self.assertEqual(RequestTrackerMeta.get_request_type(), self.TestRequest1)
self.assertTrue(
isinstance(RequestTrackerMeta.get_synchronizer(), RequestSynchronizer)
)
# Ensure check_type still returns True for the original request_type
self.assertTrue(RequestTrackerMeta.check_type(self.TestRequest1))

def test_multiple_request_types_different(self):
# Define a class providing a different request_type attribute
class TestClass1(metaclass=RequestTrackerMeta):
request_type = self.TestRequest1
synchronizer = self.Syncronizer()

# Ensure the request_type is correctly recorded
self.assertEqual(RequestTrackerMeta.get_request_type(), self.TestRequest1)
self.assertTrue(
isinstance(RequestTrackerMeta.get_synchronizer(), RequestSynchronizer)
)
# Ensure check_type returns True for the provided request_type
self.assertTrue(RequestTrackerMeta.check_type(self.TestRequest1))

Expand All @@ -134,9 +168,30 @@ class TestClass2(metaclass=RequestTrackerMeta):

# Ensure the request_type remains the same after the exception
self.assertEqual(RequestTrackerMeta.get_request_type(), self.TestRequest1)
self.assertTrue(
isinstance(RequestTrackerMeta.get_synchronizer(), RequestSynchronizer)
)
# Ensure check_type still returns True for the original request_type
self.assertTrue(RequestTrackerMeta.check_type(self.TestRequest1))

def test_pytype_is_none(self):
self.assertFalse(RequestTrackerMeta.check_type(None))

def test_pytype_is_not_class(self):
self.assertFalse(RequestTrackerMeta.check_type("string"))

def test_sync_route_params_raises_not_implemented_error(self):
class MockSyncronizer(RequestSynchronizer):
def sync_route_params(self, request, path_params):
super().sync_route_params(request, path_params)

# Create an instance of RequestSynchronizer
synchronizer = MockSyncronizer()

# Ensure that calling sync_route_params raises NotImplementedError
with self.assertRaises(NotImplementedError):
synchronizer.sync_route_params(None, None)


class TestResponseTrackerMeta(unittest.TestCase):
class MockResponse1:
Expand Down Expand Up @@ -208,13 +263,36 @@ class TestResponse2(metaclass=ResponseTrackerMeta):
ResponseTrackerMeta.get_response_type(ResponseLabels.STANDARD),
self.MockResponse1,
)
self.assertEqual(
ResponseTrackerMeta.get_standard_response_type(), self.MockResponse1
)
self.assertEqual(
ResponseTrackerMeta.get_response_type(ResponseLabels.STREAMING),
self.MockResponse2,
)
self.assertTrue(ResponseTrackerMeta.check_type(self.MockResponse1))
self.assertTrue(ResponseTrackerMeta.check_type(self.MockResponse2))

def test_response_label_not_provided(self):
with self.assertRaises(Exception) as context:

class TestResponse(metaclass=ResponseTrackerMeta):
response_type = self.MockResponse1

self.assertEqual(
str(context.exception), "Response label not provided for class TestResponse"
)

def test_response_type_not_provided(self):
with self.assertRaises(Exception) as context:

class TestResponse(metaclass=ResponseTrackerMeta):
label = "test_label_1"

self.assertEqual(
str(context.exception), "Response type not provided for class TestResponse"
)


class TestWebApp(unittest.TestCase):
def test_route_and_get_app(self):
Expand All @@ -228,6 +306,34 @@ def get_app(self):
app = MockWebApp()
self.assertEqual(app.get_app(), "MockApp")

def test_route_method_raises_not_implemented_error(self):
class MockWebApp(WebApp):
def get_app(self):
pass

def route(self, func):
super().route(func)

with self.assertRaises(NotImplementedError):
# Create a mock WebApp instance
mock_web_app = MockWebApp()
# Call the route method
mock_web_app.route(None)

def test_get_app_method_raises_not_implemented_error(self):
class MockWebApp(WebApp):
def route(self, func):
pass

def get_app(self):
super().get_app()

with self.assertRaises(NotImplementedError):
# Create a mock WebApp instance
mock_web_app = MockWebApp()
# Call the get_app method
mock_web_app.get_app()


class TestWebServer(unittest.TestCase):
def test_web_server_initialization(self):
Expand All @@ -238,12 +344,36 @@ def route(self, func):
def get_app(self):
return "MockApp"

class MockWebServer(WebServer):
async def serve(self):
pass

mock_web_app = MockWebApp()
server = WebServer("localhost", 8080, mock_web_app)
server = MockWebServer("localhost", 8080, mock_web_app)
self.assertEqual(server.hostname, "localhost")
self.assertEqual(server.port, 8080)
self.assertEqual(server.web_app, "MockApp")

async def test_serve_method_raises_not_implemented_error(self):
# Create a mock WebApp instance
class MockWebApp(WebApp):
def route(self, func):
pass

def get_app(self):
pass

class MockWebServer(WebServer):
async def serve(self):
super().serve()

# Create a WebServer instance with the mock WebApp
server = MockWebServer("localhost", 8080, MockWebApp())

# Ensure that calling the serve method raises NotImplementedError
with self.assertRaises(NotImplementedError):
await server.serve()


class TestHttpV2Enabled(unittest.TestCase):
@patch("azure.functions.extension.base.ModuleTrackerMeta.module_imported")
Expand All @@ -253,3 +383,23 @@ def test_http_v2_enabled(self, mock_module_imported):

mock_module_imported.return_value = False
self.assertFalse(HttpV2FeatureChecker.http_v2_enabled())


class TestResponseLabels(unittest.TestCase):
def test_enum_values(self):
self.assertEqual(ResponseLabels.STANDARD.value, "standard")
self.assertEqual(ResponseLabels.STREAMING.value, "streaming")
self.assertEqual(ResponseLabels.FILE.value, "file")
self.assertEqual(ResponseLabels.HTML.value, "html")
self.assertEqual(ResponseLabels.JSON.value, "json")
self.assertEqual(ResponseLabels.ORJSON.value, "orjson")
self.assertEqual(ResponseLabels.PLAIN_TEXT.value, "plain_text")
self.assertEqual(ResponseLabels.REDIRECT.value, "redirect")
self.assertEqual(ResponseLabels.UJSON.value, "ujson")
self.assertEqual(ResponseLabels.INT.value, "int")
self.assertEqual(ResponseLabels.FLOAT.value, "float")
self.assertEqual(ResponseLabels.STR.value, "str")
self.assertEqual(ResponseLabels.LIST.value, "list")
self.assertEqual(ResponseLabels.DICT.value, "dict")
self.assertEqual(ResponseLabels.BOOL.value, "bool")
self.assertEqual(ResponseLabels.PYDANTIC.value, "pydantic")