1
1
import unittest
2
- from unittest .mock import patch
2
+ from unittest .mock import MagicMock , patch
3
3
4
4
from azure .functions .extension .base import (
5
5
HttpV2FeatureChecker ,
6
6
ModuleTrackerMeta ,
7
+ RequestSynchronizer ,
7
8
RequestTrackerMeta ,
8
9
ResponseLabels ,
9
10
ResponseTrackerMeta ,
@@ -66,6 +67,10 @@ class TestRequest2:
66
67
class TestRequest3 :
67
68
pass
68
69
70
+ class Syncronizer (RequestSynchronizer ):
71
+ def sync_route_params (self , request , path_params ):
72
+ pass
73
+
69
74
def setUp (self ):
70
75
# Reset _request_type before each test
71
76
RequestTrackerMeta ._request_type = None
@@ -81,42 +86,71 @@ class TestClass(metaclass=RequestTrackerMeta):
81
86
str (context .exception ), "Request type not provided for class TestClass"
82
87
)
83
88
89
+ def test_request_synchronizer_not_provided (self ):
90
+ # Define a class without providing the synchronizer attribute
91
+ with self .assertRaises (Exception ) as context :
92
+
93
+ class TestClass (metaclass = RequestTrackerMeta ):
94
+ request_type = self .TestRequest1
95
+
96
+ self .assertEqual (
97
+ str (context .exception ),
98
+ "Request synchronizer not provided for class TestClass" ,
99
+ )
100
+
84
101
def test_single_request_type (self ):
85
102
# Define a class providing a request_type attribute
86
103
class TestClass (metaclass = RequestTrackerMeta ):
87
104
request_type = self .TestRequest1
105
+ synchronizer = self .Syncronizer ()
88
106
89
107
# Ensure the request_type is correctly recorded
90
108
self .assertEqual (RequestTrackerMeta .get_request_type (), self .TestRequest1 )
109
+ self .assertTrue (
110
+ isinstance (RequestTrackerMeta .get_synchronizer (), RequestSynchronizer )
111
+ )
91
112
# Ensure check_type returns True for the provided request_type
92
113
self .assertTrue (RequestTrackerMeta .check_type (self .TestRequest1 ))
114
+ self .assertFalse (RequestTrackerMeta .check_type (self .TestRequest2 ))
93
115
94
116
def test_multiple_request_types_same (self ):
95
117
# Define a class providing the same request_type attribute
96
118
class TestClass1 (metaclass = RequestTrackerMeta ):
97
119
request_type = self .TestRequest1
120
+ synchronizer = self .Syncronizer ()
98
121
99
122
# Ensure the request_type is correctly recorded
100
123
self .assertEqual (RequestTrackerMeta .get_request_type (), self .TestRequest1 )
124
+ self .assertTrue (
125
+ isinstance (RequestTrackerMeta .get_synchronizer (), RequestSynchronizer )
126
+ )
101
127
# Ensure check_type returns True for the provided request_type
102
128
self .assertTrue (RequestTrackerMeta .check_type (self .TestRequest1 ))
103
129
104
130
# Define another class providing the same request_type attribute
105
131
class TestClass2 (metaclass = RequestTrackerMeta ):
106
132
request_type = self .TestRequest1
133
+ synchronizer = self .Syncronizer ()
107
134
108
135
# Ensure the request_type remains the same
109
136
self .assertEqual (RequestTrackerMeta .get_request_type (), self .TestRequest1 )
137
+ self .assertTrue (
138
+ isinstance (RequestTrackerMeta .get_synchronizer (), RequestSynchronizer )
139
+ )
110
140
# Ensure check_type still returns True for the original request_type
111
141
self .assertTrue (RequestTrackerMeta .check_type (self .TestRequest1 ))
112
142
113
143
def test_multiple_request_types_different (self ):
114
144
# Define a class providing a different request_type attribute
115
145
class TestClass1 (metaclass = RequestTrackerMeta ):
116
146
request_type = self .TestRequest1
147
+ synchronizer = self .Syncronizer ()
117
148
118
149
# Ensure the request_type is correctly recorded
119
150
self .assertEqual (RequestTrackerMeta .get_request_type (), self .TestRequest1 )
151
+ self .assertTrue (
152
+ isinstance (RequestTrackerMeta .get_synchronizer (), RequestSynchronizer )
153
+ )
120
154
# Ensure check_type returns True for the provided request_type
121
155
self .assertTrue (RequestTrackerMeta .check_type (self .TestRequest1 ))
122
156
@@ -134,9 +168,30 @@ class TestClass2(metaclass=RequestTrackerMeta):
134
168
135
169
# Ensure the request_type remains the same after the exception
136
170
self .assertEqual (RequestTrackerMeta .get_request_type (), self .TestRequest1 )
171
+ self .assertTrue (
172
+ isinstance (RequestTrackerMeta .get_synchronizer (), RequestSynchronizer )
173
+ )
137
174
# Ensure check_type still returns True for the original request_type
138
175
self .assertTrue (RequestTrackerMeta .check_type (self .TestRequest1 ))
139
176
177
+ def test_pytype_is_none (self ):
178
+ self .assertFalse (RequestTrackerMeta .check_type (None ))
179
+
180
+ def test_pytype_is_not_class (self ):
181
+ self .assertFalse (RequestTrackerMeta .check_type ("string" ))
182
+
183
+ def test_sync_route_params_raises_not_implemented_error (self ):
184
+ class MockSyncronizer (RequestSynchronizer ):
185
+ def sync_route_params (self , request , path_params ):
186
+ super ().sync_route_params (request , path_params )
187
+
188
+ # Create an instance of RequestSynchronizer
189
+ synchronizer = MockSyncronizer ()
190
+
191
+ # Ensure that calling sync_route_params raises NotImplementedError
192
+ with self .assertRaises (NotImplementedError ):
193
+ synchronizer .sync_route_params (None , None )
194
+
140
195
141
196
class TestResponseTrackerMeta (unittest .TestCase ):
142
197
class MockResponse1 :
@@ -208,13 +263,36 @@ class TestResponse2(metaclass=ResponseTrackerMeta):
208
263
ResponseTrackerMeta .get_response_type (ResponseLabels .STANDARD ),
209
264
self .MockResponse1 ,
210
265
)
266
+ self .assertEqual (
267
+ ResponseTrackerMeta .get_standard_response_type (), self .MockResponse1
268
+ )
211
269
self .assertEqual (
212
270
ResponseTrackerMeta .get_response_type (ResponseLabels .STREAMING ),
213
271
self .MockResponse2 ,
214
272
)
215
273
self .assertTrue (ResponseTrackerMeta .check_type (self .MockResponse1 ))
216
274
self .assertTrue (ResponseTrackerMeta .check_type (self .MockResponse2 ))
217
275
276
+ def test_response_label_not_provided (self ):
277
+ with self .assertRaises (Exception ) as context :
278
+
279
+ class TestResponse (metaclass = ResponseTrackerMeta ):
280
+ response_type = self .MockResponse1
281
+
282
+ self .assertEqual (
283
+ str (context .exception ), "Response label not provided for class TestResponse"
284
+ )
285
+
286
+ def test_response_type_not_provided (self ):
287
+ with self .assertRaises (Exception ) as context :
288
+
289
+ class TestResponse (metaclass = ResponseTrackerMeta ):
290
+ label = "test_label_1"
291
+
292
+ self .assertEqual (
293
+ str (context .exception ), "Response type not provided for class TestResponse"
294
+ )
295
+
218
296
219
297
class TestWebApp (unittest .TestCase ):
220
298
def test_route_and_get_app (self ):
@@ -228,6 +306,34 @@ def get_app(self):
228
306
app = MockWebApp ()
229
307
self .assertEqual (app .get_app (), "MockApp" )
230
308
309
+ def test_route_method_raises_not_implemented_error (self ):
310
+ class MockWebApp (WebApp ):
311
+ def get_app (self ):
312
+ pass
313
+
314
+ def route (self , func ):
315
+ super ().route (func )
316
+
317
+ with self .assertRaises (NotImplementedError ):
318
+ # Create a mock WebApp instance
319
+ mock_web_app = MockWebApp ()
320
+ # Call the route method
321
+ mock_web_app .route (None )
322
+
323
+ def test_get_app_method_raises_not_implemented_error (self ):
324
+ class MockWebApp (WebApp ):
325
+ def route (self , func ):
326
+ pass
327
+
328
+ def get_app (self ):
329
+ super ().get_app ()
330
+
331
+ with self .assertRaises (NotImplementedError ):
332
+ # Create a mock WebApp instance
333
+ mock_web_app = MockWebApp ()
334
+ # Call the get_app method
335
+ mock_web_app .get_app ()
336
+
231
337
232
338
class TestWebServer (unittest .TestCase ):
233
339
def test_web_server_initialization (self ):
@@ -238,12 +344,36 @@ def route(self, func):
238
344
def get_app (self ):
239
345
return "MockApp"
240
346
347
+ class MockWebServer (WebServer ):
348
+ async def serve (self ):
349
+ pass
350
+
241
351
mock_web_app = MockWebApp ()
242
- server = WebServer ("localhost" , 8080 , mock_web_app )
352
+ server = MockWebServer ("localhost" , 8080 , mock_web_app )
243
353
self .assertEqual (server .hostname , "localhost" )
244
354
self .assertEqual (server .port , 8080 )
245
355
self .assertEqual (server .web_app , "MockApp" )
246
356
357
+ async def test_serve_method_raises_not_implemented_error (self ):
358
+ # Create a mock WebApp instance
359
+ class MockWebApp (WebApp ):
360
+ def route (self , func ):
361
+ pass
362
+
363
+ def get_app (self ):
364
+ pass
365
+
366
+ class MockWebServer (WebServer ):
367
+ async def serve (self ):
368
+ super ().serve ()
369
+
370
+ # Create a WebServer instance with the mock WebApp
371
+ server = MockWebServer ("localhost" , 8080 , MockWebApp ())
372
+
373
+ # Ensure that calling the serve method raises NotImplementedError
374
+ with self .assertRaises (NotImplementedError ):
375
+ await server .serve ()
376
+
247
377
248
378
class TestHttpV2Enabled (unittest .TestCase ):
249
379
@patch ("azure.functions.extension.base.ModuleTrackerMeta.module_imported" )
@@ -253,3 +383,23 @@ def test_http_v2_enabled(self, mock_module_imported):
253
383
254
384
mock_module_imported .return_value = False
255
385
self .assertFalse (HttpV2FeatureChecker .http_v2_enabled ())
386
+
387
+
388
+ class TestResponseLabels (unittest .TestCase ):
389
+ def test_enum_values (self ):
390
+ self .assertEqual (ResponseLabels .STANDARD .value , "standard" )
391
+ self .assertEqual (ResponseLabels .STREAMING .value , "streaming" )
392
+ self .assertEqual (ResponseLabels .FILE .value , "file" )
393
+ self .assertEqual (ResponseLabels .HTML .value , "html" )
394
+ self .assertEqual (ResponseLabels .JSON .value , "json" )
395
+ self .assertEqual (ResponseLabels .ORJSON .value , "orjson" )
396
+ self .assertEqual (ResponseLabels .PLAIN_TEXT .value , "plain_text" )
397
+ self .assertEqual (ResponseLabels .REDIRECT .value , "redirect" )
398
+ self .assertEqual (ResponseLabels .UJSON .value , "ujson" )
399
+ self .assertEqual (ResponseLabels .INT .value , "int" )
400
+ self .assertEqual (ResponseLabels .FLOAT .value , "float" )
401
+ self .assertEqual (ResponseLabels .STR .value , "str" )
402
+ self .assertEqual (ResponseLabels .LIST .value , "list" )
403
+ self .assertEqual (ResponseLabels .DICT .value , "dict" )
404
+ self .assertEqual (ResponseLabels .BOOL .value , "bool" )
405
+ self .assertEqual (ResponseLabels .PYDANTIC .value , "pydantic" )
0 commit comments