|
14 | 14 | import logging
|
15 | 15 | import mimetypes
|
16 | 16 |
|
17 |
| -from typing import List, Optional, Dict, Union, Any, Pattern |
| 17 | +from typing import List, Optional, Dict, Tuple, Union, Any, Pattern |
18 | 18 |
|
19 | 19 | from ...common.constants import DEFAULT_STATIC_SERVER_DIR
|
20 | 20 | from ...common.constants import DEFAULT_ENABLE_STATIC_SERVER, DEFAULT_ENABLE_WEB_SERVER
|
|
28 | 28 | from ..websocket import WebsocketFrame, websocketOpcodes
|
29 | 29 | from ..parser import HttpParser, httpParserTypes
|
30 | 30 | from ..protocols import httpProtocols
|
31 |
| -from ..responses import NOT_FOUND_RESPONSE_PKT, NOT_IMPLEMENTED_RESPONSE_PKT, okResponse |
| 31 | +from ..responses import NOT_FOUND_RESPONSE_PKT, okResponse |
32 | 32 |
|
33 | 33 | from .plugin import HttpWebServerBasePlugin
|
34 | 34 | from .protocols import httpProtocolTypes
|
@@ -138,65 +138,28 @@ def read_and_build_static_file_response(path: str) -> memoryview:
|
138 | 138 | except FileNotFoundError:
|
139 | 139 | return NOT_FOUND_RESPONSE_PKT
|
140 | 140 |
|
141 |
| - def try_upgrade(self) -> bool: |
142 |
| - if self.request.has_header(b'connection') and \ |
143 |
| - self.request.header(b'connection').lower() == b'upgrade': |
144 |
| - if self.request.has_header(b'upgrade') and \ |
145 |
| - self.request.header(b'upgrade').lower() == b'websocket': |
146 |
| - self.client.queue( |
147 |
| - memoryview( |
148 |
| - build_websocket_handshake_response( |
149 |
| - WebsocketFrame.key_to_accept( |
150 |
| - self.request.header(b'Sec-WebSocket-Key'), |
151 |
| - ), |
152 |
| - ), |
| 141 | + def switch_to_websocket(self) -> None: |
| 142 | + self.client.queue( |
| 143 | + memoryview( |
| 144 | + build_websocket_handshake_response( |
| 145 | + WebsocketFrame.key_to_accept( |
| 146 | + self.request.header(b'Sec-WebSocket-Key'), |
153 | 147 | ),
|
154 |
| - ) |
155 |
| - self.switched_protocol = httpProtocolTypes.WEBSOCKET |
156 |
| - else: |
157 |
| - self.client.queue(NOT_IMPLEMENTED_RESPONSE_PKT) |
158 |
| - return True |
159 |
| - return False |
| 148 | + ), |
| 149 | + ), |
| 150 | + ) |
| 151 | + self.switched_protocol = httpProtocolTypes.WEBSOCKET |
160 | 152 |
|
161 | 153 | def on_request_complete(self) -> Union[socket.socket, bool]:
|
162 | 154 | path = self.request.path or b'/'
|
163 |
| - # Routing for Http(s) requests |
164 |
| - protocol = httpProtocolTypes.HTTPS \ |
165 |
| - if self.encryption_enabled() else \ |
166 |
| - httpProtocolTypes.HTTP |
167 |
| - for route in self.routes[protocol]: |
168 |
| - if route.match(text_(path)): |
169 |
| - self.route = self.routes[protocol][route] |
170 |
| - assert self.route |
171 |
| - self.route.handle_request(self.request) |
172 |
| - if self.request.has_header(b'connection') and \ |
173 |
| - self.request.header(b'connection').lower() == b'close': |
174 |
| - return True |
175 |
| - return False |
176 |
| - # If a websocket route exists for the path, try upgrade |
177 |
| - for route in self.routes[httpProtocolTypes.WEBSOCKET]: |
178 |
| - if route.match(text_(path)): |
179 |
| - self.route = self.routes[httpProtocolTypes.WEBSOCKET][route] |
180 |
| - # Connection upgrade |
181 |
| - teardown = self.try_upgrade() |
182 |
| - if teardown: |
183 |
| - return True |
184 |
| - # For upgraded connections, nothing more to do |
185 |
| - if self.switched_protocol: |
186 |
| - # Invoke plugin.on_websocket_open |
187 |
| - assert self.route |
188 |
| - self.route.on_websocket_open() |
189 |
| - return False |
190 |
| - break |
| 155 | + # Try route |
| 156 | + teardown = self._try_route(path) |
| 157 | + if teardown: |
| 158 | + return teardown |
191 | 159 | # No-route found, try static serving if enabled
|
192 |
| - if self.flags.enable_static_server: |
193 |
| - path = text_(path).split('?', 1)[0] |
194 |
| - self.client.queue( |
195 |
| - self.read_and_build_static_file_response( |
196 |
| - self.flags.static_server_dir + path, |
197 |
| - ), |
198 |
| - ) |
199 |
| - return True |
| 160 | + teardown = self._try_static_file(path) |
| 161 | + if teardown: |
| 162 | + return teardown |
200 | 163 | # Catch all unhandled web server requests, return 404
|
201 | 164 | self.client.queue(NOT_FOUND_RESPONSE_PKT)
|
202 | 165 | return True
|
@@ -305,3 +268,44 @@ def on_client_connection_close(self) -> None:
|
305 | 268 |
|
306 | 269 | def access_log(self, context: Dict[str, Any]) -> None:
|
307 | 270 | logger.info(DEFAULT_WEB_ACCESS_LOG_FORMAT.format_map(context))
|
| 271 | + |
| 272 | + @property |
| 273 | + def _protocol(self) -> Tuple[bool, int]: |
| 274 | + do_ws_upgrade = self.request.is_connection_upgrade and \ |
| 275 | + self.request.header(b'upgrade').lower() == b'websocket' |
| 276 | + return do_ws_upgrade, httpProtocolTypes.WEBSOCKET \ |
| 277 | + if do_ws_upgrade \ |
| 278 | + else httpProtocolTypes.HTTPS \ |
| 279 | + if self.encryption_enabled() \ |
| 280 | + else httpProtocolTypes.HTTP |
| 281 | + |
| 282 | + def _try_route(self, path: bytes) -> bool: |
| 283 | + do_ws_upgrade, protocol = self._protocol |
| 284 | + for route in self.routes[protocol]: |
| 285 | + if route.match(text_(path)): |
| 286 | + self.route = self.routes[protocol][route] |
| 287 | + assert self.route |
| 288 | + # Optionally, upgrade protocol |
| 289 | + if do_ws_upgrade: |
| 290 | + self.switch_to_websocket() |
| 291 | + assert self.route |
| 292 | + # Invoke plugin.on_websocket_open |
| 293 | + self.route.on_websocket_open() |
| 294 | + else: |
| 295 | + # Invoke plugin.handle_request |
| 296 | + self.route.handle_request(self.request) |
| 297 | + if self.request.has_header(b'connection') and \ |
| 298 | + self.request.header(b'connection').lower() == b'close': |
| 299 | + return True |
| 300 | + return False |
| 301 | + |
| 302 | + def _try_static_file(self, path: bytes) -> bool: |
| 303 | + if self.flags.enable_static_server: |
| 304 | + path = text_(path).split('?', 1)[0] |
| 305 | + self.client.queue( |
| 306 | + self.read_and_build_static_file_response( |
| 307 | + self.flags.static_server_dir + path, |
| 308 | + ), |
| 309 | + ) |
| 310 | + return True |
| 311 | + return False |
0 commit comments