Skip to content

Commit a84abab

Browse files
authored
[WebServer] Refactor routing to allow same path for websocket and web requests (#962)
* Switch to WS * Refactor
1 parent 474cce1 commit a84abab

File tree

1 file changed

+60
-56
lines changed

1 file changed

+60
-56
lines changed

proxy/http/server/web.py

+60-56
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import logging
1515
import mimetypes
1616

17-
from typing import List, Optional, Dict, Union, Any, Pattern
17+
from typing import List, Optional, Dict, Tuple, Union, Any, Pattern
1818

1919
from ...common.constants import DEFAULT_STATIC_SERVER_DIR
2020
from ...common.constants import DEFAULT_ENABLE_STATIC_SERVER, DEFAULT_ENABLE_WEB_SERVER
@@ -28,7 +28,7 @@
2828
from ..websocket import WebsocketFrame, websocketOpcodes
2929
from ..parser import HttpParser, httpParserTypes
3030
from ..protocols import httpProtocols
31-
from ..responses import NOT_FOUND_RESPONSE_PKT, NOT_IMPLEMENTED_RESPONSE_PKT, okResponse
31+
from ..responses import NOT_FOUND_RESPONSE_PKT, okResponse
3232

3333
from .plugin import HttpWebServerBasePlugin
3434
from .protocols import httpProtocolTypes
@@ -138,65 +138,28 @@ def read_and_build_static_file_response(path: str) -> memoryview:
138138
except FileNotFoundError:
139139
return NOT_FOUND_RESPONSE_PKT
140140

141-
def try_upgrade(self) -> bool:
142-
if self.request.has_header(b'connection') and \
143-
self.request.header(b'connection').lower() == b'upgrade':
144-
if self.request.has_header(b'upgrade') and \
145-
self.request.header(b'upgrade').lower() == b'websocket':
146-
self.client.queue(
147-
memoryview(
148-
build_websocket_handshake_response(
149-
WebsocketFrame.key_to_accept(
150-
self.request.header(b'Sec-WebSocket-Key'),
151-
),
152-
),
141+
def switch_to_websocket(self) -> None:
142+
self.client.queue(
143+
memoryview(
144+
build_websocket_handshake_response(
145+
WebsocketFrame.key_to_accept(
146+
self.request.header(b'Sec-WebSocket-Key'),
153147
),
154-
)
155-
self.switched_protocol = httpProtocolTypes.WEBSOCKET
156-
else:
157-
self.client.queue(NOT_IMPLEMENTED_RESPONSE_PKT)
158-
return True
159-
return False
148+
),
149+
),
150+
)
151+
self.switched_protocol = httpProtocolTypes.WEBSOCKET
160152

161153
def on_request_complete(self) -> Union[socket.socket, bool]:
162154
path = self.request.path or b'/'
163-
# Routing for Http(s) requests
164-
protocol = httpProtocolTypes.HTTPS \
165-
if self.encryption_enabled() else \
166-
httpProtocolTypes.HTTP
167-
for route in self.routes[protocol]:
168-
if route.match(text_(path)):
169-
self.route = self.routes[protocol][route]
170-
assert self.route
171-
self.route.handle_request(self.request)
172-
if self.request.has_header(b'connection') and \
173-
self.request.header(b'connection').lower() == b'close':
174-
return True
175-
return False
176-
# If a websocket route exists for the path, try upgrade
177-
for route in self.routes[httpProtocolTypes.WEBSOCKET]:
178-
if route.match(text_(path)):
179-
self.route = self.routes[httpProtocolTypes.WEBSOCKET][route]
180-
# Connection upgrade
181-
teardown = self.try_upgrade()
182-
if teardown:
183-
return True
184-
# For upgraded connections, nothing more to do
185-
if self.switched_protocol:
186-
# Invoke plugin.on_websocket_open
187-
assert self.route
188-
self.route.on_websocket_open()
189-
return False
190-
break
155+
# Try route
156+
teardown = self._try_route(path)
157+
if teardown:
158+
return teardown
191159
# No-route found, try static serving if enabled
192-
if self.flags.enable_static_server:
193-
path = text_(path).split('?', 1)[0]
194-
self.client.queue(
195-
self.read_and_build_static_file_response(
196-
self.flags.static_server_dir + path,
197-
),
198-
)
199-
return True
160+
teardown = self._try_static_file(path)
161+
if teardown:
162+
return teardown
200163
# Catch all unhandled web server requests, return 404
201164
self.client.queue(NOT_FOUND_RESPONSE_PKT)
202165
return True
@@ -305,3 +268,44 @@ def on_client_connection_close(self) -> None:
305268

306269
def access_log(self, context: Dict[str, Any]) -> None:
307270
logger.info(DEFAULT_WEB_ACCESS_LOG_FORMAT.format_map(context))
271+
272+
@property
273+
def _protocol(self) -> Tuple[bool, int]:
274+
do_ws_upgrade = self.request.is_connection_upgrade and \
275+
self.request.header(b'upgrade').lower() == b'websocket'
276+
return do_ws_upgrade, httpProtocolTypes.WEBSOCKET \
277+
if do_ws_upgrade \
278+
else httpProtocolTypes.HTTPS \
279+
if self.encryption_enabled() \
280+
else httpProtocolTypes.HTTP
281+
282+
def _try_route(self, path: bytes) -> bool:
283+
do_ws_upgrade, protocol = self._protocol
284+
for route in self.routes[protocol]:
285+
if route.match(text_(path)):
286+
self.route = self.routes[protocol][route]
287+
assert self.route
288+
# Optionally, upgrade protocol
289+
if do_ws_upgrade:
290+
self.switch_to_websocket()
291+
assert self.route
292+
# Invoke plugin.on_websocket_open
293+
self.route.on_websocket_open()
294+
else:
295+
# Invoke plugin.handle_request
296+
self.route.handle_request(self.request)
297+
if self.request.has_header(b'connection') and \
298+
self.request.header(b'connection').lower() == b'close':
299+
return True
300+
return False
301+
302+
def _try_static_file(self, path: bytes) -> bool:
303+
if self.flags.enable_static_server:
304+
path = text_(path).split('?', 1)[0]
305+
self.client.queue(
306+
self.read_and_build_static_file_response(
307+
self.flags.static_server_dir + path,
308+
),
309+
)
310+
return True
311+
return False

0 commit comments

Comments
 (0)