diff --git a/examples/https_connect_tunnel.py b/examples/https_connect_tunnel.py index b138b7bdce..135a472b6b 100644 --- a/examples/https_connect_tunnel.py +++ b/examples/https_connect_tunnel.py @@ -15,7 +15,6 @@ from proxy import Proxy from proxy.common.utils import build_http_response from proxy.http import httpStatusCodes -from proxy.http.parser import httpParserStates from proxy.core.base import BaseTcpTunnelHandler @@ -58,7 +57,7 @@ def handle_data(self, data: memoryview) -> Optional[bool]: # CONNECT requests are short and we need not worry about # receiving partial request bodies here. - assert self.request.state == httpParserStates.COMPLETE + assert self.request.is_complete # Establish connection with upstream self.connect_upstream() diff --git a/proxy/common/constants.py b/proxy/common/constants.py index b2e4f6d71f..9a41874574 100644 --- a/proxy/common/constants.py +++ b/proxy/common/constants.py @@ -45,6 +45,8 @@ def _env_threadless_compliant() -> bool: SLASH = b'/' HTTP_1_0 = b'HTTP/1.0' HTTP_1_1 = b'HTTP/1.1' +HTTP_URL_PREFIX = b'http://' +HTTPS_URL_PREFIX = b'https://' PROXY_AGENT_HEADER_KEY = b'Proxy-agent' PROXY_AGENT_HEADER_VALUE = b'proxy.py v' + \ diff --git a/proxy/http/handler.py b/proxy/http/handler.py index 6c158e2d22..d8f28e035d 100644 --- a/proxy/http/handler.py +++ b/proxy/http/handler.py @@ -214,7 +214,7 @@ def handle_data(self, data: memoryview) -> Optional[bool]: # TODO(abhinavsingh): Remove .tobytes after parser is # memoryview compliant self.request.parse(data.tobytes()) - if self.request.state == httpParserStates.COMPLETE: + if self.request.is_complete: # Invoke plugin.on_request_complete for plugin in self.plugins.values(): upgraded_sock = plugin.on_request_complete() diff --git a/proxy/http/parser/parser.py b/proxy/http/parser/parser.py index 3554cf2bfe..c446b34b15 100644 --- a/proxy/http/parser/parser.py +++ b/proxy/http/parser/parser.py @@ -17,7 +17,7 @@ from ...common.constants import DEFAULT_DISABLE_HEADERS, COLON, DEFAULT_ENABLE_PROXY_PROTOCOL from ...common.constants import HTTP_1_1, SLASH, CRLF from ...common.constants import WHITESPACE, DEFAULT_HTTP_PORT -from ...common.utils import build_http_request, build_http_response, find_http_line, text_ +from ...common.utils import build_http_request, build_http_response, text_ from ...common.flag import flags from ..url import Url @@ -63,10 +63,12 @@ def __init__( if enable_proxy_protocol: assert self.type == httpParserTypes.REQUEST_PARSER self.protocol = ProxyProtocol() + # Request attributes self.host: Optional[bytes] = None self.port: Optional[int] = None self.path: Optional[bytes] = None self.method: Optional[bytes] = None + # Response attributes self.code: Optional[bytes] = None self.reason: Optional[bytes] = None self.version: Optional[bytes] = None @@ -78,7 +80,7 @@ def __init__( # - Keys are lower case header names. # - Values are 2-tuple containing original # header and it's value as received. - self.headers: Dict[bytes, Tuple[bytes, bytes]] = {} + self.headers: Optional[Dict[bytes, Tuple[bytes, bytes]]] = None self.body: Optional[bytes] = None self.chunk: Optional[ChunkParser] = None # Internal request line as a url structure @@ -109,19 +111,24 @@ def response(cls: Type[T], raw: bytes) -> T: def header(self, key: bytes) -> bytes: """Convenient method to return original header value from internal data structure.""" - if key.lower() not in self.headers: + if self.headers is None or key.lower() not in self.headers: raise KeyError('%s not found in headers', text_(key)) return self.headers[key.lower()][1] def has_header(self, key: bytes) -> bool: """Returns true if header key was found in payload.""" + if self.headers is None: + return False return key.lower() in self.headers def add_header(self, key: bytes, value: bytes) -> bytes: """Add/Update a header to internal data structure. Returns key with which passed (key, value) tuple is available.""" + if self.headers is None: + self.headers = {} k = key.lower() + # k = key self.headers[k] = (key, value) return k @@ -132,7 +139,7 @@ def add_headers(self, headers: List[Tuple[bytes, bytes]]) -> None: def del_header(self, header: bytes) -> None: """Delete a header from internal data structure.""" - if header.lower() in self.headers: + if self.headers and header.lower() in self.headers: del self.headers[header.lower()] def del_headers(self, headers: List[bytes]) -> None: @@ -151,6 +158,10 @@ def has_host(self) -> bool: NOTE: Host field WILL be None for incoming local WebServer requests.""" return self.host is not None + @property + def is_complete(self) -> bool: + return self.state == httpParserStates.COMPLETE + @property def is_http_1_1_keep_alive(self) -> bool: """Returns true for HTTP/1.1 keep-alive connections.""" @@ -185,30 +196,34 @@ def content_expected(self) -> bool: @property def body_expected(self) -> bool: """Returns true if content or chunked response is expected.""" - return self.content_expected or self.is_chunked_encoded + return self._content_expected or self._is_chunked_encoded def parse(self, raw: bytes) -> None: """Parses HTTP request out of raw bytes. Check for `HttpParser.state` after `parse` has successfully returned.""" - self.total_size += len(raw) + size = len(raw) + self.total_size += size raw = self.buffer + raw - self.buffer, more = b'', len(raw) > 0 + self.buffer, more = b'', size > 0 while more and self.state != httpParserStates.COMPLETE: # gte with HEADERS_COMPLETE also encapsulated RCVING_BODY state - more, raw = self._process_body(raw) \ - if self.state >= httpParserStates.HEADERS_COMPLETE else \ - self._process_line_and_headers(raw) + if self.state >= httpParserStates.HEADERS_COMPLETE: + more, raw = self._process_body(raw) + elif self.state == httpParserStates.INITIALIZED: + more, raw = self._process_line(raw) + else: + more, raw = self._process_headers(raw) # When server sends a response line without any header or body e.g. # HTTP/1.1 200 Connection established\r\n\r\n - if self.state == httpParserStates.LINE_RCVD and \ - raw == CRLF and \ - self.type == httpParserTypes.RESPONSE_PARSER: + if self.type == httpParserTypes.RESPONSE_PARSER and \ + self.state == httpParserStates.LINE_RCVD and \ + raw == CRLF: self.state = httpParserStates.COMPLETE # Mark request as complete if headers received and no incoming # body indication received. elif self.state == httpParserStates.HEADERS_COMPLETE and \ - not self.body_expected and \ + not (self._content_expected or self._is_chunked_encoded) and \ raw == b'': self.state = httpParserStates.COMPLETE self.buffer = raw @@ -229,7 +244,7 @@ def build(self, disable_headers: Optional[List[bytes]] = None, for_proxy: bool = COLON + str(self.port).encode() + path - ) if not self.is_https_tunnel else (self.host + COLON + str(self.port).encode()) + ) if not self._is_https_tunnel else (self.host + COLON + str(self.port).encode()) return build_http_request( self.method, path, self.version, headers={} if not self.headers else { @@ -263,7 +278,7 @@ def _process_body(self, raw: bytes) -> Tuple[bool, bytes]: # the latter MUST be ignored. # # TL;DR -- Give transfer-encoding header preference over content-length. - if self.is_chunked_encoded: + if self._is_chunked_encoded: if not self.chunk: self.chunk = ChunkParser() raw = self.chunk.parse(raw) @@ -271,7 +286,7 @@ def _process_body(self, raw: bytes) -> Tuple[bool, bytes]: self.body = self.chunk.body self.state = httpParserStates.COMPLETE more = False - elif self.content_expected: + elif self._content_expected: self.state = httpParserStates.RCVING_BODY if self.body is None: self.body = b'' @@ -297,7 +312,7 @@ def _process_body(self, raw: bytes) -> Tuple[bool, bytes]: more, raw = False, b'' return more, raw - def _process_line_and_headers(self, raw: bytes) -> Tuple[bool, bytes]: + def _process_headers(self, raw: bytes) -> Tuple[bool, bytes]: """Returns False when no CRLF could be found in received bytes. TODO: We should not return until parser reaches headers complete @@ -308,60 +323,59 @@ def _process_line_and_headers(self, raw: bytes) -> Tuple[bool, bytes]: This will also help make the parser even more stateless. """ while True: - line, raw = find_http_line(raw) - if line is None: + parts = raw.split(CRLF, 1) + if len(parts) == 1: return False, raw - - if self.state == httpParserStates.INITIALIZED: - self._process_line(line) - if self.state == httpParserStates.INITIALIZED: - # return len(raw) > 0, raw - continue - elif self.state in (httpParserStates.LINE_RCVD, httpParserStates.RCVING_HEADERS): - if self.state == httpParserStates.LINE_RCVD: - self.state = httpParserStates.RCVING_HEADERS + line, raw = parts[0], parts[1] + if self.state in (httpParserStates.LINE_RCVD, httpParserStates.RCVING_HEADERS): if line == b'' or line.strip() == b'': # Blank line received. self.state = httpParserStates.HEADERS_COMPLETE else: + self.state = httpParserStates.RCVING_HEADERS self._process_header(line) - # If raw length is now zero, bail out # If we have received all headers, bail out if raw == b'' or self.state == httpParserStates.HEADERS_COMPLETE: break return len(raw) > 0, raw - def _process_line(self, raw: bytes) -> None: - if self.type == httpParserTypes.REQUEST_PARSER: - if self.protocol is not None and self.protocol.version is None: - # We expect to receive entire proxy protocol v1 line - # in one network read and don't expect partial packets - self.protocol.parse(raw) - else: + def _process_line(self, raw: bytes) -> Tuple[bool, bytes]: + while True: + parts = raw.split(CRLF, 1) + if len(parts) == 1: + return False, raw + line, raw = parts[0], parts[1] + if self.type == httpParserTypes.REQUEST_PARSER: + if self.protocol is not None and self.protocol.version is None: + # We expect to receive entire proxy protocol v1 line + # in one network read and don't expect partial packets + self.protocol.parse(line) + continue # Ref: https://datatracker.ietf.org/doc/html/rfc2616#section-5.1 - line = raw.split(WHITESPACE, 2) - if len(line) == 3: - self.method = line[0].upper() + parts = line.split(WHITESPACE, 2) + if len(parts) == 3: + self.method = parts[0] if self.method == httpMethods.CONNECT: self._is_https_tunnel = True - self.set_url(line[1]) - self.version = line[2] + self.set_url(parts[1]) + self.version = parts[2] self.state = httpParserStates.LINE_RCVD - else: - # To avoid a possible attack vector, we raise exception - # if parser receives an invalid request line. - # - # TODO: Better to use raise HttpProtocolException, - # but we should solve circular import problem first. - raise ValueError('Invalid request line') - else: - line = raw.split(WHITESPACE, 2) - self.version = line[0] - self.code = line[1] + break + # To avoid a possible attack vector, we raise exception + # if parser receives an invalid request line. + # + # TODO: Better to use raise HttpProtocolException, + # but we should solve circular import problem first. + raise ValueError('Invalid request line') + parts = line.split(WHITESPACE, 2) + self.version = parts[0] + self.code = parts[1] # Our own WebServerPlugin example currently doesn't send any reason - if len(line) == 3: - self.reason = line[2] + if len(parts) == 3: + self.reason = parts[2] self.state = httpParserStates.LINE_RCVD + break + return len(raw) > 0, raw def _process_header(self, raw: bytes) -> None: parts = raw.split(COLON, 1) @@ -380,20 +394,16 @@ def _process_header(self, raw: bytes) -> None: def _get_body_or_chunks(self) -> Optional[bytes]: return ChunkParser.to_chunks(self.body) \ - if self.body and self.is_chunked_encoded else \ + if self.body and self._is_chunked_encoded else \ self.body def _set_line_attributes(self) -> None: if self.type == httpParserTypes.REQUEST_PARSER: - if self.is_https_tunnel and self._url: + assert self._url + if self._is_https_tunnel: self.host = self._url.hostname self.port = 443 if self._url.port is None else self._url.port - elif self._url: + else: self.host, self.port = self._url.hostname, self._url.port \ if self._url.port else DEFAULT_HTTP_PORT - else: - raise KeyError( - 'Invalid request. Method: %r, Url: %r' % - (self.method, self._url), - ) self.path = self._url.remainder diff --git a/proxy/http/proxy/auth.py b/proxy/http/proxy/auth.py index 5653a9528c..be0bccd79b 100644 --- a/proxy/http/proxy/auth.py +++ b/proxy/http/proxy/auth.py @@ -38,7 +38,7 @@ class AuthPlugin(HttpProxyBasePlugin): def before_upstream_connection( self, request: HttpParser, ) -> Optional[HttpParser]: - if self.flags.auth_code: + if self.flags.auth_code and request.headers: if b'proxy-authorization' not in request.headers: raise ProxyAuthenticationFailed() parts = request.headers[b'proxy-authorization'][1].split() diff --git a/proxy/http/proxy/server.py b/proxy/http/proxy/server.py index 7a7622bd12..1c13458986 100644 --- a/proxy/http/proxy/server.py +++ b/proxy/http/proxy/server.py @@ -312,7 +312,7 @@ async def read_from_descriptors(self, r: Readables) -> bool: # currently response parsing is disabled when TLS interception is enabled. # # or self.tls_interception_enabled(): - if self.response.state == httpParserStates.COMPLETE: + if self.response.is_complete: self.handle_pipeline_response(raw) else: # TODO(abhinavsingh): Remove .tobytes after parser is @@ -436,7 +436,7 @@ def on_response_chunk(self, chunk: List[memoryview]) -> List[memoryview]: # and response objects. # # if not self.request.is_https_tunnel and \ - # self.response.state == httpParserStates.COMPLETE: + # self.response.is_complete: # self.access_log() return chunk @@ -465,7 +465,7 @@ def on_client_data(self, raw: memoryview) -> Optional[memoryview]: # For http proxy requests, handle pipeline case. # We also handle pipeline scenario for https proxy # requests is TLS interception is enabled. - if self.request.state == httpParserStates.COMPLETE and ( + if self.request.is_complete and ( not self.request.is_https_tunnel or self.tls_interception_enabled() ): @@ -488,7 +488,7 @@ def on_client_data(self, raw: memoryview) -> Optional[memoryview]: # TODO(abhinavsingh): Remove .tobytes after parser is # memoryview compliant self.pipeline_request.parse(raw.tobytes()) - if self.pipeline_request.state == httpParserStates.COMPLETE: + if self.pipeline_request.is_complete: for plugin in self.plugins.values(): assert self.pipeline_request is not None r = plugin.handle_client_request(self.pipeline_request) @@ -592,7 +592,7 @@ def handle_pipeline_response(self, raw: memoryview) -> None: # TODO(abhinavsingh): Remove .tobytes after parser is memoryview # compliant self.pipeline_response.parse(raw.tobytes()) - if self.pipeline_response.state == httpParserStates.COMPLETE: + if self.pipeline_response.is_complete: self.pipeline_response = None def connect_upstream(self) -> None: @@ -912,7 +912,12 @@ def emit_request_complete(self) -> None: if self.request.is_https_tunnel else 'http://%s:%d%s' % (text_(self.request.host), self.request.port, text_(self.request.path)), 'method': text_(self.request.method), - 'headers': {text_(k): text_(v[1]) for k, v in self.request.headers.items()}, + 'headers': {} + if not self.request.headers else + { + text_(k): text_(v[1]) + for k, v in self.request.headers.items() + }, 'body': text_(self.request.body) if self.request.method == httpMethods.POST else None, @@ -923,7 +928,7 @@ def emit_request_complete(self) -> None: def emit_response_events(self, chunk_size: int) -> None: if not self.flags.enable_events: return - if self.response.state == httpParserStates.COMPLETE: + if self.response.is_complete: self.emit_response_complete() elif self.response.state == httpParserStates.RCVING_BODY: self.emit_response_chunk_received(chunk_size) @@ -937,7 +942,12 @@ def emit_response_headers_complete(self) -> None: request_id=self.uid.hex, event_name=eventNames.RESPONSE_HEADERS_COMPLETE, event_payload={ - 'headers': {text_(k): text_(v[1]) for k, v in self.response.headers.items()}, + 'headers': {} + if not self.response.headers else + { + text_(k): text_(v[1]) + for k, v in self.response.headers.items() + }, }, publisher_id=self.__class__.__name__, ) diff --git a/proxy/http/server/web.py b/proxy/http/server/web.py index eacee97024..5c9c3630df 100644 --- a/proxy/http/server/web.py +++ b/proxy/http/server/web.py @@ -28,7 +28,7 @@ from ..exception import HttpProtocolException from ..plugin import HttpProtocolHandlerPlugin from ..websocket import WebsocketFrame, websocketOpcodes -from ..parser import HttpParser, httpParserStates, httpParserTypes +from ..parser import HttpParser, httpParserTypes from .plugin import HttpWebServerBasePlugin from .protocols import httpProtocolTypes @@ -274,7 +274,7 @@ def on_client_data(self, raw: memoryview) -> Optional[memoryview]: return None # If 1st valid request was completed and it's a HTTP/1.1 keep-alive # And only if we have a route, parse pipeline requests - if self.request.state == httpParserStates.COMPLETE and \ + if self.request.is_complete and \ self.request.is_http_1_1_keep_alive and \ self.route is not None: if self.pipeline_request is None: @@ -284,7 +284,7 @@ def on_client_data(self, raw: memoryview) -> Optional[memoryview]: # TODO(abhinavsingh): Remove .tobytes after parser is memoryview # compliant self.pipeline_request.parse(raw.tobytes()) - if self.pipeline_request.state == httpParserStates.COMPLETE: + if self.pipeline_request.is_complete: self.route.handle_request(self.pipeline_request) if not self.pipeline_request.is_http_1_1_keep_alive: logger.error( diff --git a/proxy/http/url.py b/proxy/http/url.py index 177e3823ca..a7fc4390cb 100644 --- a/proxy/http/url.py +++ b/proxy/http/url.py @@ -15,7 +15,7 @@ """ from typing import Optional, Tuple -from ..common.constants import COLON, SLASH +from ..common.constants import COLON, SLASH, HTTP_URL_PREFIX, HTTPS_URL_PREFIX from ..common.utils import text_ @@ -65,11 +65,11 @@ def from_bytes(cls, raw: bytes) -> 'Url': We use heuristics based approach for our URL parser. """ - sraw = raw.decode('utf-8') - if sraw[0] == SLASH.decode('utf-8'): + if raw[0] == 47: # SLASH == 47 return cls(remainder=raw) - if sraw.startswith('https://') or sraw.startswith('http://'): - is_https = sraw.startswith('https://') + is_http = raw.startswith(HTTP_URL_PREFIX) + is_https = raw.startswith(HTTPS_URL_PREFIX) + if is_http or is_https: rest = raw[len(b'https://'):] \ if is_https \ else raw[len(b'http://'):] @@ -88,21 +88,26 @@ def from_bytes(cls, raw: bytes) -> 'Url': @staticmethod def parse_host_and_port(raw: bytes) -> Tuple[bytes, Optional[int]]: - parts = raw.split(COLON) + parts = raw.split(COLON, 2) + num_parts = len(parts) port: Optional[int] = None - if len(parts) == 1: + # No port found + if num_parts == 1: return parts[0], None - if len(parts) == 2: - host, port = COLON.join(parts[:-1]), int(parts[-1]) - if len(parts) > 2: - try: - port = int(parts[-1]) - host = COLON.join(parts[:-1]) - except ValueError: - # If unable to convert last part into port, - # this is the IPv6 scenario. Treat entire - # data as host - host, port = raw, None + # Host and port found + if num_parts == 2: + return COLON.join(parts[:-1]), int(parts[-1]) + # More than a single COLON i.e. IPv6 scenario + try: + # Try to resolve last part as an int port + last_token = parts[-1].split(COLON) + port = int(last_token[-1]) + host = COLON.join(parts[:-1]) + COLON + \ + COLON.join(last_token[:-1]) + except ValueError: + # If unable to convert last part into port, + # treat entire data as host + host, port = raw, None # patch up invalid ipv6 scenario rhost = host.decode('utf-8') if COLON.decode('utf-8') in rhost and \ diff --git a/proxy/plugin/filter_by_url_regex.py b/proxy/plugin/filter_by_url_regex.py index ef5dd80530..557a8cff3c 100644 --- a/proxy/plugin/filter_by_url_regex.py +++ b/proxy/plugin/filter_by_url_regex.py @@ -59,9 +59,8 @@ def handle_client_request( request_host = None if request.host: request_host = request.host - else: - if b'host' in request.headers: - request_host = request.header(b'host') + elif request.headers and b'host' in request.headers: + request_host = request.header(b'host') if not request_host: logger.error("Cannot determine host") diff --git a/proxy/plugin/modify_chunk_response.py b/proxy/plugin/modify_chunk_response.py index f1ace20a04..05e6c6f3eb 100644 --- a/proxy/plugin/modify_chunk_response.py +++ b/proxy/plugin/modify_chunk_response.py @@ -10,7 +10,7 @@ """ from typing import Any -from ..http.parser import HttpParser, httpParserTypes, httpParserStates +from ..http.parser import HttpParser, httpParserTypes from ..http.proxy import HttpProxyBasePlugin @@ -34,7 +34,7 @@ def handle_upstream_chunk(self, chunk: memoryview) -> memoryview: # Note that these chunks also include headers self.response.parse(chunk.tobytes()) # If response is complete, modify and dispatch to client - if self.response.state == httpParserStates.COMPLETE: + if self.response.is_complete: # Avoid setting a body for responses where a body is not expected. # Otherwise, example curl will report warnings. if self.response.body_expected: diff --git a/tests/http/test_http_parser.py b/tests/http/test_http_parser.py index 2d1cc883c9..7e3c974377 100644 --- a/tests/http/test_http_parser.py +++ b/tests/http/test_http_parser.py @@ -193,7 +193,8 @@ def test_has_header(self) -> None: self.assertTrue(self.parser.has_header(b'key')) def test_set_host_port_raises(self) -> None: - with self.assertRaises(KeyError): + # Assertion for url will fail + with self.assertRaises(AssertionError): self.parser._set_line_attributes() def test_find_line(self) -> None: @@ -243,6 +244,7 @@ def test_get_full_parse(self) -> None: self.assertEqual(self.parser._url.port, None) self.assertEqual(self.parser.version, b'HTTP/1.1') self.assertEqual(self.parser.state, httpParserStates.COMPLETE) + assert self.parser.headers self.assertEqual( self.parser.headers[b'host'], (b'Host', b'example.com'), ) @@ -296,7 +298,7 @@ def test_get_partial_parse1(self) -> None: self.parser.total_size, len(pkt) + len(CRLF) + len(host_hdr), ) - self.assertDictEqual(self.parser.headers, {}) + assert self.parser.headers is None self.assertEqual(self.parser.buffer, b'Host: localhost:8080') self.assertEqual(self.parser.state, httpParserStates.LINE_RCVD) @@ -305,6 +307,7 @@ def test_get_partial_parse1(self) -> None: self.parser.total_size, len(pkt) + (3 * len(CRLF)) + len(host_hdr), ) + assert self.parser.headers is not None self.assertEqual( self.parser.headers[b'host'], ( @@ -330,6 +333,7 @@ def test_get_partial_parse2(self) -> None: self.assertEqual(self.parser.state, httpParserStates.LINE_RCVD) self.parser.parse(b'localhost:8080' + CRLF) + assert self.parser.headers self.assertEqual( self.parser.headers[b'host'], ( @@ -345,6 +349,7 @@ def test_get_partial_parse2(self) -> None: self.parser.parse(b'Content-Type: text/plain' + CRLF) self.assertEqual(self.parser.buffer, b'') + assert self.parser.headers self.assertEqual( self.parser.headers[b'content-type'], ( b'Content-Type', @@ -373,6 +378,7 @@ def test_post_full_parse(self) -> None: self.assertEqual(self.parser._url.hostname, b'localhost') self.assertEqual(self.parser._url.port, None) self.assertEqual(self.parser.version, b'HTTP/1.1') + assert self.parser.headers self.assertEqual( self.parser.headers[b'content-type'], (b'Content-Type', b'application/x-www-form-urlencoded'), @@ -528,6 +534,7 @@ def test_response_parse(self) -> None: b'