From d578ae7e78a5bda2202aa9fceaa5af37c9f618a8 Mon Sep 17 00:00:00 2001 From: Abhinav Singh Date: Sat, 13 Apr 2024 17:07:42 +0530 Subject: [PATCH 1/2] Enable TLS interception support even when proxy pool plugin is enabled --- proxy/core/connection/server.py | 11 ++++ proxy/http/proxy/plugin.py | 8 ++- proxy/http/proxy/server.py | 91 +++++++++++++++++++++++++-------- proxy/plugin/proxy_pool.py | 6 +++ 4 files changed, 95 insertions(+), 21 deletions(-) diff --git a/proxy/core/connection/server.py b/proxy/core/connection/server.py index 31233049f7..c7bd4faa05 100644 --- a/proxy/core/connection/server.py +++ b/proxy/core/connection/server.py @@ -25,6 +25,17 @@ def __init__(self, host: str, port: int) -> None: self._conn: Optional[TcpOrTlsSocket] = None self.addr: HostPort = (host, port) self.closed = True + self._proxy = False + + def is_secure(self) -> bool: + return isinstance(self._conn, ssl.SSLSocket) + + def mark_as_proxy(self) -> None: + self._proxy = True + + @property + def is_proxy(self) -> bool: + return self._proxy @property def connection(self) -> TcpOrTlsSocket: diff --git a/proxy/http/proxy/plugin.py b/proxy/http/proxy/plugin.py index a9c10e88f3..b44b1a872f 100644 --- a/proxy/http/proxy/plugin.py +++ b/proxy/http/proxy/plugin.py @@ -17,7 +17,7 @@ from ...core.event import EventQueue from ..descriptors import DescriptorsHandlerMixin from ...common.utils import tls_interception_enabled - +from ...core.connection import TcpServerConnection if TYPE_CHECKING: # pragma: no cover from ...common.types import HostPort @@ -69,6 +69,12 @@ def resolve_dns(self, host: str, port: int) -> Tuple[Optional[str], Optional['Ho """ return None, None + def upstream_connection( + self, + request: HttpParser, + ) -> Optional[TcpServerConnection]: + return None + # No longer abstract since 2.4.0 # # @abstractmethod diff --git a/proxy/http/proxy/server.py b/proxy/http/proxy/server.py index 70d3369ec4..2361a2bf1d 100644 --- a/proxy/http/proxy/server.py +++ b/proxy/http/proxy/server.py @@ -21,7 +21,7 @@ import logging import threading import subprocess -from typing import Any, Dict, List, Union, Optional, cast +from typing import Any, Dict, List, Union, Optional from .plugin import HttpProxyBasePlugin from ..parser import HttpParser, httpParserTypes, httpParserStates @@ -487,6 +487,14 @@ def on_request_complete(self) -> Union[socket.socket, bool]: # Connect to upstream if do_connect: self.connect_upstream() + else: + # If a plugin asked us not to connect to upstream + # check if any plugin is managing an upstream connection. + for plugin in self.plugins.values(): + up = plugin.upstream_connection(self.request) + if up is not None: + self.upstream = up + break # Invoke plugin.handle_client_request for plugin in self.plugins.values(): @@ -756,13 +764,28 @@ def intercept(self) -> Union[socket.socket, bool]: return self.client.connection def wrap_server(self) -> bool: - assert self.upstream is not None - assert isinstance(self.upstream.connection, socket.socket) + assert self.upstream is not None and self.request.host + return self._wrap_server( + self.upstream, + host=self.request.host, + ca_file=self.flags.ca_file, + ) + + @staticmethod + def _wrap_server( + upstream: TcpServerConnection, + host: bytes, + ca_file: Optional[str] = None, + ) -> bool: + assert isinstance(upstream.connection, socket.socket) do_close = False + if upstream.is_proxy: + # Don't wrap upstream if its part of proxy chain + return do_close try: - self.upstream.wrap( - text_(self.request.host), - self.flags.ca_file, + upstream.wrap( + text_(host), + ca_file, as_non_blocking=True, ) except ssl.SSLCertVerificationError: # Server raised certificate verification error @@ -770,40 +793,68 @@ def wrap_server(self) -> bool: # we will cache such upstream hosts and avoid intercepting them for future # requests. logger.warning( - 'ssl.SSLCertVerificationError: ' + - 'Server raised cert verification error for upstream: {0}'.format( - self.upstream.addr[0], + "ssl.SSLCertVerificationError: " + + "Server raised cert verification error for upstream: {0}".format( + upstream.addr[0], ), ) do_close = True except ssl.SSLError as e: if e.reason == 'SSLV3_ALERT_HANDSHAKE_FAILURE': logger.warning( - '{0}: '.format(e.reason) + - 'Server raised handshake alert failure for upstream: {0}'.format( - self.upstream.addr[0], + "{0}: ".format(e.reason) + + "Server raised handshake alert failure for upstream: {0}".format( + upstream.addr[0], ), ) else: logger.exception( - 'SSLError when wrapping client for upstream: {0}'.format( - self.upstream.addr[0], - ), exc_info=e, + "SSLError when wrapping client for upstream: {0}".format( + upstream.addr[0], + ), + exc_info=e, ) do_close = True if not do_close: - assert isinstance(self.upstream.connection, ssl.SSLSocket) + assert isinstance(upstream.connection, ssl.SSLSocket) return do_close def wrap_client(self) -> bool: assert self.upstream is not None and self.flags.ca_signing_key_file is not None - assert isinstance(self.upstream.connection, ssl.SSLSocket) + certificate: Optional[Dict[str, Any]] = None + if isinstance(self.upstream.connection, ssl.SSLSocket): + certificate = self.upstream.connection.getpeercert() + else: + assert self.upstream.is_proxy and self.request.host and self.request.port + if self.flags.enable_conn_pool: + assert self.upstream_conn_pool + with self.lock: + _, upstream = self.upstream_conn_pool.acquire( + (text_(self.request.host), self.request.port), + ) + else: + _, upstream = True, TcpServerConnection( + text_(self.request.host), + self.request.port, + ) + # Connect with overridden upstream IP and source address + # if any of the plugin returned a non-null value. + upstream.connect() + upstream.connection.setblocking(False) + do_close = self._wrap_server( + upstream, + host=self.request.host, + ca_file=self.flags.ca_file, + ) + if do_close: + return do_close + assert isinstance(upstream.connection, ssl.SSLSocket) + certificate = upstream.connection.getpeercert() + assert certificate do_close = False try: # TODO: Perform async certificate generation - generated_cert = self.generate_upstream_certificate( - cast(Dict[str, Any], self.upstream.connection.getpeercert()), - ) + generated_cert = self.generate_upstream_certificate(certificate) self.client.wrap(self.flags.ca_signing_key_file, generated_cert) except subprocess.TimeoutExpired as e: # Popen communicate timeout logger.exception( diff --git a/proxy/plugin/proxy_pool.py b/proxy/plugin/proxy_pool.py index c244adeda1..a70e3dc18f 100644 --- a/proxy/plugin/proxy_pool.py +++ b/proxy/plugin/proxy_pool.py @@ -14,6 +14,8 @@ import ipaddress from typing import Any, Dict, List, Optional +from proxy.core.connection import TcpServerConnection + from ..http import Url, httpHeaders, httpMethods from ..core.base import TcpUpstreamConnectionHandler from ..http.proxy import HttpProxyBasePlugin @@ -78,6 +80,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: def handle_upstream_data(self, raw: memoryview) -> None: self.client.queue(raw) + def upstream_connection(self, request: HttpParser) -> Optional[TcpServerConnection]: + return self.upstream + def before_upstream_connection( self, request: HttpParser, ) -> Optional[HttpParser]: @@ -107,6 +112,7 @@ def before_upstream_connection( logger.debug('Using endpoint: {0}:{1}'.format(*endpoint_tuple)) self.initialize_upstream(*endpoint_tuple) assert self.upstream + self.upstream.mark_as_proxy() try: self.upstream.connect() except TimeoutError: From 0ee1203958d549da08b0aa0aa9444c8dfb84f5d8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 13 Apr 2024 11:40:17 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- proxy/http/proxy/plugin.py | 1 + proxy/http/proxy/server.py | 10 +++++----- proxy/plugin/proxy_pool.py | 1 - 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/proxy/http/proxy/plugin.py b/proxy/http/proxy/plugin.py index b44b1a872f..c1dfd34fa4 100644 --- a/proxy/http/proxy/plugin.py +++ b/proxy/http/proxy/plugin.py @@ -19,6 +19,7 @@ from ...common.utils import tls_interception_enabled from ...core.connection import TcpServerConnection + if TYPE_CHECKING: # pragma: no cover from ...common.types import HostPort from ...core.connection import UpstreamConnectionPool diff --git a/proxy/http/proxy/server.py b/proxy/http/proxy/server.py index 2361a2bf1d..fcaf3e7306 100644 --- a/proxy/http/proxy/server.py +++ b/proxy/http/proxy/server.py @@ -793,8 +793,8 @@ def _wrap_server( # we will cache such upstream hosts and avoid intercepting them for future # requests. logger.warning( - "ssl.SSLCertVerificationError: " - + "Server raised cert verification error for upstream: {0}".format( + 'ssl.SSLCertVerificationError: ' + + 'Server raised cert verification error for upstream: {0}'.format( upstream.addr[0], ), ) @@ -802,14 +802,14 @@ def _wrap_server( except ssl.SSLError as e: if e.reason == 'SSLV3_ALERT_HANDSHAKE_FAILURE': logger.warning( - "{0}: ".format(e.reason) - + "Server raised handshake alert failure for upstream: {0}".format( + '{0}: '.format(e.reason) + + 'Server raised handshake alert failure for upstream: {0}'.format( upstream.addr[0], ), ) else: logger.exception( - "SSLError when wrapping client for upstream: {0}".format( + 'SSLError when wrapping client for upstream: {0}'.format( upstream.addr[0], ), exc_info=e, diff --git a/proxy/plugin/proxy_pool.py b/proxy/plugin/proxy_pool.py index a70e3dc18f..eecac398a2 100644 --- a/proxy/plugin/proxy_pool.py +++ b/proxy/plugin/proxy_pool.py @@ -15,7 +15,6 @@ from typing import Any, Dict, List, Optional from proxy.core.connection import TcpServerConnection - from ..http import Url, httpHeaders, httpMethods from ..core.base import TcpUpstreamConnectionHandler from ..http.proxy import HttpProxyBasePlugin