Skip to content

[WebServer] Refactor routing to allow same path for websocket and web requests #962

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 11, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 60 additions & 56 deletions proxy/http/server/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import logging
import mimetypes

from typing import List, Optional, Dict, Union, Any, Pattern
from typing import List, Optional, Dict, Tuple, Union, Any, Pattern

from ...common.constants import DEFAULT_STATIC_SERVER_DIR
from ...common.constants import DEFAULT_ENABLE_STATIC_SERVER, DEFAULT_ENABLE_WEB_SERVER
Expand All @@ -28,7 +28,7 @@
from ..websocket import WebsocketFrame, websocketOpcodes
from ..parser import HttpParser, httpParserTypes
from ..protocols import httpProtocols
from ..responses import NOT_FOUND_RESPONSE_PKT, NOT_IMPLEMENTED_RESPONSE_PKT, okResponse
from ..responses import NOT_FOUND_RESPONSE_PKT, okResponse

from .plugin import HttpWebServerBasePlugin
from .protocols import httpProtocolTypes
Expand Down Expand Up @@ -138,65 +138,28 @@ def read_and_build_static_file_response(path: str) -> memoryview:
except FileNotFoundError:
return NOT_FOUND_RESPONSE_PKT

def try_upgrade(self) -> bool:
if self.request.has_header(b'connection') and \
self.request.header(b'connection').lower() == b'upgrade':
if self.request.has_header(b'upgrade') and \
self.request.header(b'upgrade').lower() == b'websocket':
self.client.queue(
memoryview(
build_websocket_handshake_response(
WebsocketFrame.key_to_accept(
self.request.header(b'Sec-WebSocket-Key'),
),
),
def switch_to_websocket(self) -> None:
self.client.queue(
memoryview(
build_websocket_handshake_response(
WebsocketFrame.key_to_accept(
self.request.header(b'Sec-WebSocket-Key'),
),
)
self.switched_protocol = httpProtocolTypes.WEBSOCKET
else:
self.client.queue(NOT_IMPLEMENTED_RESPONSE_PKT)
return True
return False
),
),
)
self.switched_protocol = httpProtocolTypes.WEBSOCKET

def on_request_complete(self) -> Union[socket.socket, bool]:
path = self.request.path or b'/'
# Routing for Http(s) requests
protocol = httpProtocolTypes.HTTPS \
if self.encryption_enabled() else \
httpProtocolTypes.HTTP
for route in self.routes[protocol]:
if route.match(text_(path)):
self.route = self.routes[protocol][route]
assert self.route
self.route.handle_request(self.request)
if self.request.has_header(b'connection') and \
self.request.header(b'connection').lower() == b'close':
return True
return False
# If a websocket route exists for the path, try upgrade
for route in self.routes[httpProtocolTypes.WEBSOCKET]:
if route.match(text_(path)):
self.route = self.routes[httpProtocolTypes.WEBSOCKET][route]
# Connection upgrade
teardown = self.try_upgrade()
if teardown:
return True
# For upgraded connections, nothing more to do
if self.switched_protocol:
# Invoke plugin.on_websocket_open
assert self.route
self.route.on_websocket_open()
return False
break
# Try route
teardown = self._try_route(path)
if teardown:
return teardown
# No-route found, try static serving if enabled
if self.flags.enable_static_server:
path = text_(path).split('?', 1)[0]
self.client.queue(
self.read_and_build_static_file_response(
self.flags.static_server_dir + path,
),
)
return True
teardown = self._try_static_file(path)
if teardown:
return teardown
# Catch all unhandled web server requests, return 404
self.client.queue(NOT_FOUND_RESPONSE_PKT)
return True
Expand Down Expand Up @@ -305,3 +268,44 @@ def on_client_connection_close(self) -> None:

def access_log(self, context: Dict[str, Any]) -> None:
logger.info(DEFAULT_WEB_ACCESS_LOG_FORMAT.format_map(context))

@property
def _protocol(self) -> Tuple[bool, int]:
do_ws_upgrade = self.request.is_connection_upgrade and \
self.request.header(b'upgrade').lower() == b'websocket'
return do_ws_upgrade, httpProtocolTypes.WEBSOCKET \
if do_ws_upgrade \
else httpProtocolTypes.HTTPS \
if self.encryption_enabled() \
else httpProtocolTypes.HTTP

def _try_route(self, path: bytes) -> bool:
do_ws_upgrade, protocol = self._protocol
for route in self.routes[protocol]:
if route.match(text_(path)):
self.route = self.routes[protocol][route]
assert self.route
# Optionally, upgrade protocol
if do_ws_upgrade:
self.switch_to_websocket()
assert self.route
# Invoke plugin.on_websocket_open
self.route.on_websocket_open()
else:
# Invoke plugin.handle_request
self.route.handle_request(self.request)
if self.request.has_header(b'connection') and \
self.request.header(b'connection').lower() == b'close':
return True
return False

def _try_static_file(self, path: bytes) -> bool:
if self.flags.enable_static_server:
path = text_(path).split('?', 1)[0]
self.client.queue(
self.read_and_build_static_file_response(
self.flags.static_server_dir + path,
),
)
return True
return False