Skip to content

Commit fb49c43

Browse files
committed
Refactor
1 parent 00758ea commit fb49c43

File tree

2 files changed

+49
-41
lines changed

2 files changed

+49
-41
lines changed

proxy/http/parser/parser.py

-7
Original file line numberDiff line numberDiff line change
@@ -172,13 +172,6 @@ def is_http_1_1_keep_alive(self) -> bool:
172172
self.header(b'Connection').lower() == b'keep-alive'
173173
)
174174

175-
@property
176-
def is_websocket_upgrade(self) -> bool:
177-
return self.has_header(b'connection') and \
178-
self.header(b'connection').lower() == b'upgrade' and \
179-
self.has_header(b'upgrade') and \
180-
self.header(b'upgrade').lower() == b'websocket'
181-
182175
@property
183176
def is_connection_upgrade(self) -> bool:
184177
"""Returns true for websocket upgrade requests."""

proxy/http/server/web.py

+49-34
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import logging
1515
import mimetypes
1616

17-
from typing import List, Optional, Dict, Union, Any, Pattern
17+
from typing import List, Optional, Dict, Tuple, Union, Any, Pattern
1818

1919
from ...common.constants import DEFAULT_STATIC_SERVER_DIR
2020
from ...common.constants import DEFAULT_ENABLE_STATIC_SERVER, DEFAULT_ENABLE_WEB_SERVER
@@ -152,40 +152,14 @@ def switch_to_websocket(self) -> None:
152152

153153
def on_request_complete(self) -> Union[socket.socket, bool]:
154154
path = self.request.path or b'/'
155-
# Routing for Http(s) requests
156-
protocol = httpProtocolTypes.HTTPS \
157-
if self.encryption_enabled() else \
158-
httpProtocolTypes.HTTP
159-
for route in self.routes[protocol]:
160-
if route.match(text_(path)):
161-
self.route = self.routes[protocol][route]
162-
assert self.route
163-
self.route.handle_request(self.request)
164-
if self.request.has_header(b'connection') and \
165-
self.request.header(b'connection').lower() == b'close':
166-
return True
167-
return False
168-
# If a websocket route exists for the path, try upgrade
169-
for route in self.routes[httpProtocolTypes.WEBSOCKET]:
170-
if route.match(text_(path)):
171-
self.route = self.routes[httpProtocolTypes.WEBSOCKET][route]
172-
# Connection upgrade
173-
if self.request.is_websocket_upgrade:
174-
self.switch_to_websocket()
175-
# Invoke plugin.on_websocket_open
176-
assert self.route
177-
self.route.on_websocket_open()
178-
return False
179-
break
155+
# Try route
156+
teardown = self._try_route(path)
157+
if teardown:
158+
return teardown
180159
# No-route found, try static serving if enabled
181-
if self.flags.enable_static_server:
182-
path = text_(path).split('?', 1)[0]
183-
self.client.queue(
184-
self.read_and_build_static_file_response(
185-
self.flags.static_server_dir + path,
186-
),
187-
)
188-
return True
160+
teardown = self._try_static_file(path)
161+
if teardown:
162+
return teardown
189163
# Catch all unhandled web server requests, return 404
190164
self.client.queue(NOT_FOUND_RESPONSE_PKT)
191165
return True
@@ -294,3 +268,44 @@ def on_client_connection_close(self) -> None:
294268

295269
def access_log(self, context: Dict[str, Any]) -> None:
296270
logger.info(DEFAULT_WEB_ACCESS_LOG_FORMAT.format_map(context))
271+
272+
@property
273+
def _protocol(self) -> Tuple[bool, int]:
274+
do_ws_upgrade = self.request.is_connection_upgrade and \
275+
self.request.header(b'upgrade').lower() == b'websocket'
276+
return do_ws_upgrade, httpProtocolTypes.WEBSOCKET \
277+
if do_ws_upgrade \
278+
else httpProtocolTypes.HTTPS \
279+
if self.encryption_enabled() \
280+
else httpProtocolTypes.HTTP
281+
282+
def _try_route(self, path: bytes) -> bool:
283+
do_ws_upgrade, protocol = self._protocol
284+
for route in self.routes[protocol]:
285+
if route.match(text_(path)):
286+
self.route = self.routes[protocol][route]
287+
assert self.route
288+
# Optionally, upgrade protocol
289+
if do_ws_upgrade:
290+
self.switch_to_websocket()
291+
assert self.route
292+
# Invoke plugin.on_websocket_open
293+
self.route.on_websocket_open()
294+
else:
295+
# Invoke plugin.handle_request
296+
self.route.handle_request(self.request)
297+
if self.request.has_header(b'connection') and \
298+
self.request.header(b'connection').lower() == b'close':
299+
return True
300+
return False
301+
302+
def _try_static_file(self, path: bytes) -> bool:
303+
if self.flags.enable_static_server:
304+
path = text_(path).split('?', 1)[0]
305+
self.client.queue(
306+
self.read_and_build_static_file_response(
307+
self.flags.static_server_dir + path,
308+
),
309+
)
310+
return True
311+
return False

0 commit comments

Comments
 (0)