diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 258a80b0fd..a3b48e4095 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,9 +1,11 @@ --- repos: -# - repo: https://github.com/asottile/add-trailing-comma.git -# rev: v2.0.1 -# hooks: -# - id: add-trailing-comma +- repo: https://github.com/asottile/add-trailing-comma.git + rev: v2.0.1 + hooks: + - id: add-trailing-comma + args: + - --py36-plus # - repo: https://github.com/timothycrosley/isort.git # rev: 5.4.2 diff --git a/examples/https_connect_tunnel.py b/examples/https_connect_tunnel.py index 950186468f..6bf504da12 100644 --- a/examples/https_connect_tunnel.py +++ b/examples/https_connect_tunnel.py @@ -23,16 +23,20 @@ class HttpsConnectTunnelHandler(BaseTcpTunnelHandler): """A https CONNECT tunnel.""" - PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT = memoryview(build_http_response( - httpStatusCodes.OK, - reason=b'Connection established' - )) - - PROXY_TUNNEL_UNSUPPORTED_SCHEME = memoryview(build_http_response( - httpStatusCodes.BAD_REQUEST, - headers={b'Connection': b'close'}, - reason=b'Unsupported protocol scheme' - )) + PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT = memoryview( + build_http_response( + httpStatusCodes.OK, + reason=b'Connection established', + ), + ) + + PROXY_TUNNEL_UNSUPPORTED_SCHEME = memoryview( + build_http_response( + httpStatusCodes.BAD_REQUEST, + headers={b'Connection': b'close'}, + reason=b'Unsupported protocol scheme', + ), + ) def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -49,7 +53,8 @@ def handle_data(self, data: memoryview) -> Optional[bool]: # Drop the request if not a CONNECT request if self.request.method != httpMethods.CONNECT: self.client.queue( - HttpsConnectTunnelHandler.PROXY_TUNNEL_UNSUPPORTED_SCHEME) + HttpsConnectTunnelHandler.PROXY_TUNNEL_UNSUPPORTED_SCHEME, + ) return True # CONNECT requests are short and we need not worry about @@ -61,7 +66,8 @@ def handle_data(self, data: memoryview) -> Optional[bool]: # Queue tunnel established response to client self.client.queue( - HttpsConnectTunnelHandler.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT) + HttpsConnectTunnelHandler.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT, + ) return None @@ -70,7 +76,8 @@ def main() -> None: # This example requires `threadless=True` pool = AcceptorPool( flags=Proxy.initialize(port=12345, num_workers=1, threadless=True), - work_klass=HttpsConnectTunnelHandler) + work_klass=HttpsConnectTunnelHandler, + ) try: pool.setup() while True: diff --git a/examples/pubsub_eventing.py b/examples/pubsub_eventing.py index d5bcd12f67..d8650e5098 100644 --- a/examples/pubsub_eventing.py +++ b/examples/pubsub_eventing.py @@ -25,8 +25,10 @@ num_events_received = [0, 0] -def publisher_process(shutdown_event: multiprocessing.synchronize.Event, - dispatcher_queue: EventQueue) -> None: +def publisher_process( + shutdown_event: multiprocessing.synchronize.Event, + dispatcher_queue: EventQueue, +) -> None: print('publisher starting') try: while not shutdown_event.is_set(): @@ -34,7 +36,7 @@ def publisher_process(shutdown_event: multiprocessing.synchronize.Event, request_id=process_publisher_request_id, event_name=eventNames.WORK_STARTED, event_payload={'time': time.time()}, - publisher_id='eventing_pubsub_process' + publisher_id='eventing_pubsub_process', ) except KeyboardInterrupt: pass @@ -70,7 +72,8 @@ def on_event(payload: Dict[str, Any]) -> None: publisher_shutdown_event = multiprocessing.Event() publisher = multiprocessing.Process( target=publisher_process, args=( - publisher_shutdown_event, event_manager.event_queue, )) + publisher_shutdown_event, event_manager.event_queue, ), + ) publisher.start() try: @@ -80,7 +83,7 @@ def on_event(payload: Dict[str, Any]) -> None: request_id=main_publisher_request_id, event_name=eventNames.WORK_STARTED, event_payload={'time': time.time()}, - publisher_id='eventing_pubsub_main' + publisher_id='eventing_pubsub_main', ) except KeyboardInterrupt: print('bye!!!') @@ -92,5 +95,8 @@ def on_event(payload: Dict[str, Any]) -> None: subscriber.unsubscribe() # Signal dispatcher to shutdown event_manager.stop_event_dispatcher() - print('Received {0} events from main thread, {1} events from another process, in {2} seconds'.format( - num_events_received[0], num_events_received[1], time.time() - start_time)) + print( + 'Received {0} events from main thread, {1} events from another process, in {2} seconds'.format( + num_events_received[0], num_events_received[1], time.time() - start_time, + ), + ) diff --git a/examples/ssl_echo_server.py b/examples/ssl_echo_server.py index 013bc3a5f0..b609abd199 100644 --- a/examples/ssl_echo_server.py +++ b/examples/ssl_echo_server.py @@ -29,11 +29,13 @@ def initialize(self) -> None: conn = wrap_socket( self.client.connection, self.flags.keyfile, - self.flags.certfile) + self.flags.certfile, + ) conn.setblocking(False) # Upgrade plain TcpClientConnection to SSL connection object self.client = TcpClientConnection( - conn=conn, addr=self.client.addr) + conn=conn, addr=self.client.addr, + ) def handle_data(self, data: memoryview) -> Optional[bool]: # echo back to client @@ -49,8 +51,10 @@ def main() -> None: num_workers=1, threadless=True, keyfile='https-key.pem', - certfile='https-signed-cert.pem'), - work_klass=EchoSSLServerHandler) + certfile='https-signed-cert.pem', + ), + work_klass=EchoSSLServerHandler, + ) try: pool.setup() while True: diff --git a/examples/tcp_echo_server.py b/examples/tcp_echo_server.py index c468b7eaf8..16cf3fcb95 100644 --- a/examples/tcp_echo_server.py +++ b/examples/tcp_echo_server.py @@ -32,7 +32,8 @@ def main() -> None: # This example requires `threadless=True` pool = AcceptorPool( flags=Proxy.initialize(port=12345, num_workers=1, threadless=True), - work_klass=EchoServerHandler) + work_klass=EchoServerHandler, + ) try: pool.setup() while True: diff --git a/examples/websocket_client.py b/examples/websocket_client.py index c87ed3e4fb..a382304401 100644 --- a/examples/websocket_client.py +++ b/examples/websocket_client.py @@ -22,10 +22,14 @@ def on_message(frame: WebsocketFrame) -> None: """WebsocketClient on_message callback.""" global client, num_echos, last_dispatch_time - print('Received %r after %d millisec' % - (frame.data, (time.time() - last_dispatch_time) * 1000)) - assert(frame.data == b'hello' and frame.opcode == - websocketOpcodes.TEXT_FRAME) + print( + 'Received %r after %d millisec' % + (frame.data, (time.time() - last_dispatch_time) * 1000), + ) + assert( + frame.data == b'hello' and frame.opcode == + websocketOpcodes.TEXT_FRAME + ) if num_echos > 0: client.queue(static_frame) last_dispatch_time = time.time() @@ -40,7 +44,8 @@ def on_message(frame: WebsocketFrame) -> None: b'echo.websocket.org', 80, b'/', - on_message=on_message) + on_message=on_message, + ) # Perform handshake client.handshake() # Queue some data for client diff --git a/proxy/common/flag.py b/proxy/common/flag.py index 13ae8e6f37..e23789d363 100644 --- a/proxy/common/flag.py +++ b/proxy/common/flag.py @@ -36,7 +36,7 @@ def __init__(self) -> None: self.actions: List[str] = [] self.parser = argparse.ArgumentParser( description='proxy.py v%s' % __version__, - epilog='Proxy.py not working? Report at: %s/issues/new' % __homepage__ + epilog='Proxy.py not working? Report at: %s/issues/new' % __homepage__, ) def add_argument(self, *args: Any, **kwargs: Any) -> argparse.Action: @@ -46,7 +46,8 @@ def add_argument(self, *args: Any, **kwargs: Any) -> argparse.Action: return action def parse_args( - self, input_args: Optional[List[str]]) -> argparse.Namespace: + self, input_args: Optional[List[str]], + ) -> argparse.Namespace: """Parse flags from input arguments.""" self.args = self.parser.parse_args(input_args) return self.args diff --git a/proxy/common/pki.py b/proxy/common/pki.py index 80711fd4c6..1c764a9d71 100644 --- a/proxy/common/pki.py +++ b/proxy/common/pki.py @@ -57,13 +57,14 @@ def remove_passphrase( key_in_path: str, password: str, key_out_path: str, - timeout: int = 10) -> bool: + timeout: int = 10, +) -> bool: """Remove passphrase from a private key.""" command = [ 'openssl', 'rsa', '-passin', 'pass:%s' % password, '-in', key_in_path, - '-out', key_out_path + '-out', key_out_path, ] return run_openssl_command(command, timeout) @@ -72,12 +73,13 @@ def gen_private_key( key_path: str, password: str, bits: int = 2048, - timeout: int = 10) -> bool: + timeout: int = 10, +) -> bool: """Generates a private key.""" command = [ 'openssl', 'genrsa', '-aes256', '-passout', 'pass:%s' % password, - '-out', key_path, str(bits) + '-out', key_path, str(bits), ] return run_openssl_command(command, timeout) @@ -90,7 +92,8 @@ def gen_public_key( alt_subj_names: Optional[List[str]] = None, extended_key_usage: Optional[str] = None, validity_in_days: int = 365, - timeout: int = 10) -> bool: + timeout: int = 10, +) -> bool: """For a given private key, generates a corresponding public key.""" with ssl_config(alt_subj_names, extended_key_usage) as (config_path, has_extension): command = [ @@ -98,7 +101,7 @@ def gen_public_key( '-days', str(validity_in_days), '-subj', subject, '-passin', 'pass:%s' % private_key_password, '-config', config_path, - '-key', private_key_path, '-out', public_key_path + '-key', private_key_path, '-out', public_key_path, ] if has_extension: command.extend([ @@ -112,13 +115,14 @@ def gen_csr( key_path: str, password: str, crt_path: str, - timeout: int = 10) -> bool: + timeout: int = 10, +) -> bool: """Generates a CSR based upon existing certificate and key file.""" command = [ 'openssl', 'x509', '-x509toreq', '-passin', 'pass:%s' % password, '-in', crt_path, '-signkey', key_path, - '-out', csr_path + '-out', csr_path, ] return run_openssl_command(command, timeout) @@ -133,7 +137,8 @@ def sign_csr( alt_subj_names: Optional[List[str]] = None, extended_key_usage: Optional[str] = None, validity_in_days: int = 365, - timeout: int = 10) -> bool: + timeout: int = 10, +) -> bool: """Sign a CSR using CA key and certificate.""" with ext_file(alt_subj_names, extended_key_usage) as extension_path: command = [ @@ -152,7 +157,8 @@ def sign_csr( def get_ext_config( alt_subj_names: Optional[List[str]] = None, - extended_key_usage: Optional[str] = None) -> bytes: + extended_key_usage: Optional[str] = None, +) -> bytes: config = b'' # Add SAN extension if alt_subj_names is not None and len(alt_subj_names) > 0: @@ -169,12 +175,14 @@ def get_ext_config( @contextlib.contextmanager def ext_file( alt_subj_names: Optional[List[str]] = None, - extended_key_usage: Optional[str] = None) -> Generator[str, None, None]: + extended_key_usage: Optional[str] = None, +) -> Generator[str, None, None]: # Write config to temp file config_path = os.path.join(tempfile.gettempdir(), uuid.uuid4().hex) with open(config_path, 'wb') as cnf: cnf.write( - get_ext_config(alt_subj_names, extended_key_usage)) + get_ext_config(alt_subj_names, extended_key_usage), + ) yield config_path @@ -185,7 +193,8 @@ def ext_file( @contextlib.contextmanager def ssl_config( alt_subj_names: Optional[List[str]] = None, - extended_key_usage: Optional[str] = None) -> Generator[Tuple[str, bool], None, None]: + extended_key_usage: Optional[str] = None, +) -> Generator[Tuple[str, bool], None, None]: config = DEFAULT_CONFIG has_extension = False @@ -212,7 +221,7 @@ def run_openssl_command(command: List[str], timeout: int) -> bool: cmd = subprocess.Popen( command, stdout=subprocess.PIPE, - stderr=subprocess.PIPE + stderr=subprocess.PIPE, ) cmd.communicate(timeout=timeout) return cmd.returncode == 0 @@ -221,7 +230,7 @@ def run_openssl_command(command: List[str], timeout: int) -> bool: if __name__ == '__main__': available_actions = ( 'remove_passphrase', 'gen_private_key', 'gen_public_key', - 'gen_csr', 'sign_csr' + 'gen_csr', 'sign_csr', ) parser = argparse.ArgumentParser( @@ -231,7 +240,7 @@ def run_openssl_command(command: List[str], timeout: int) -> bool: 'action', type=str, default=None, - help='Valid actions: ' + ', '.join(available_actions) + help='Valid actions: ' + ', '.join(available_actions), ) parser.add_argument( '--password', @@ -294,17 +303,24 @@ def run_openssl_command(command: List[str], timeout: int) -> bool: if args.action == 'gen_private_key': gen_private_key(args.private_key_path, args.password) elif args.action == 'gen_public_key': - gen_public_key(args.public_key_path, args.private_key_path, - args.password, args.subject) + gen_public_key( + args.public_key_path, args.private_key_path, + args.password, args.subject, + ) elif args.action == 'remove_passphrase': - remove_passphrase(args.private_key_path, args.password, - args.private_key_path) + remove_passphrase( + args.private_key_path, args.password, + args.private_key_path, + ) elif args.action == 'gen_csr': gen_csr( args.csr_path, args.private_key_path, args.password, - args.public_key_path) + args.public_key_path, + ) elif args.action == 'sign_csr': - sign_csr(args.csr_path, args.crt_path, args.private_key_path, args.password, - args.public_key_path, str(int(time.time())), alt_subj_names=[args.hostname, ]) + sign_csr( + args.csr_path, args.crt_path, args.private_key_path, args.password, + args.public_key_path, str(int(time.time())), alt_subj_names=[args.hostname], + ) diff --git a/proxy/common/utils.py b/proxy/common/utils.py index b894da5d08..a79cd8a870 100644 --- a/proxy/common/utils.py +++ b/proxy/common/utils.py @@ -46,22 +46,27 @@ def bytes_(s: Any, encoding: str = 'utf-8', errors: str = 'strict') -> Any: return s -def build_http_request(method: bytes, url: bytes, - protocol_version: bytes = HTTP_1_1, - headers: Optional[Dict[bytes, bytes]] = None, - body: Optional[bytes] = None) -> bytes: +def build_http_request( + method: bytes, url: bytes, + protocol_version: bytes = HTTP_1_1, + headers: Optional[Dict[bytes, bytes]] = None, + body: Optional[bytes] = None, +) -> bytes: """Build and returns a HTTP request packet.""" if headers is None: headers = {} return build_http_pkt( - [method, url, protocol_version], headers, body) + [method, url, protocol_version], headers, body, + ) -def build_http_response(status_code: int, - protocol_version: bytes = HTTP_1_1, - reason: Optional[bytes] = None, - headers: Optional[Dict[bytes, bytes]] = None, - body: Optional[bytes] = None) -> bytes: +def build_http_response( + status_code: int, + protocol_version: bytes = HTTP_1_1, + reason: Optional[bytes] = None, + headers: Optional[Dict[bytes, bytes]] = None, + body: Optional[bytes] = None, +) -> bytes: """Build and returns a HTTP response packet.""" line = [protocol_version, bytes_(status_code)] if reason: @@ -87,9 +92,11 @@ def build_http_header(k: bytes, v: bytes) -> bytes: return k + COLON + WHITESPACE + v -def build_http_pkt(line: List[bytes], - headers: Optional[Dict[bytes, bytes]] = None, - body: Optional[bytes] = None) -> bytes: +def build_http_pkt( + line: List[bytes], + headers: Optional[Dict[bytes, bytes]] = None, + body: Optional[bytes] = None, +) -> bytes: """Build and returns a HTTP request or response packet.""" pkt = WHITESPACE.join(line) + CRLF if headers is not None: @@ -105,7 +112,8 @@ def build_websocket_handshake_request( key: bytes, method: bytes = b'GET', url: bytes = b'/', - host: bytes = b'localhost') -> bytes: + host: bytes = b'localhost', +) -> bytes: """ Build and returns a Websocket handshake request packet. @@ -121,7 +129,7 @@ def build_websocket_handshake_request( b'Upgrade': b'websocket', b'Sec-WebSocket-Key': key, b'Sec-WebSocket-Version': b'13', - } + }, ) @@ -136,8 +144,8 @@ def build_websocket_handshake_response(accept: bytes) -> bytes: headers={ b'Upgrade': b'websocket', b'Connection': b'Upgrade', - b'Sec-WebSocket-Accept': accept - } + b'Sec-WebSocket-Accept': accept, + }, ) @@ -153,15 +161,19 @@ def find_http_line(raw: bytes) -> Tuple[Optional[bytes], bytes]: return line, rest -def wrap_socket(conn: socket.socket, keyfile: str, - certfile: str) -> ssl.SSLSocket: +def wrap_socket( + conn: socket.socket, keyfile: str, + certfile: str, +) -> ssl.SSLSocket: ctx = ssl.create_default_context( - ssl.Purpose.CLIENT_AUTH) + ssl.Purpose.CLIENT_AUTH, + ) ctx.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 ctx.verify_mode = ssl.CERT_NONE ctx.load_cert_chain( certfile=certfile, - keyfile=keyfile) + keyfile=keyfile, + ) return ctx.wrap_socket( conn, server_side=True, @@ -169,18 +181,21 @@ def wrap_socket(conn: socket.socket, keyfile: str, def new_socket_connection( - addr: Tuple[str, int], timeout: int = DEFAULT_TIMEOUT) -> socket.socket: + addr: Tuple[str, int], timeout: int = DEFAULT_TIMEOUT, +) -> socket.socket: conn = None try: ip = ipaddress.ip_address(addr[0]) if ip.version == 4: conn = socket.socket( - socket.AF_INET, socket.SOCK_STREAM, 0) + socket.AF_INET, socket.SOCK_STREAM, 0, + ) conn.settimeout(timeout) conn.connect(addr) else: conn = socket.socket( - socket.AF_INET6, socket.SOCK_STREAM, 0) + socket.AF_INET6, socket.SOCK_STREAM, 0, + ) conn.settimeout(timeout) conn.connect((addr[0], addr[1], 0, 0)) except ValueError: @@ -209,12 +224,14 @@ def __exit__( self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType]) -> None: + exc_tb: Optional[TracebackType], + ) -> None: if self.conn: self.conn.close() def __call__( # type: ignore - self, func: Callable[..., Any]) -> Callable[[Tuple[Any, ...], Dict[str, Any]], Any]: + self, func: Callable[..., Any], + ) -> Callable[[Tuple[Any, ...], Dict[str, Any]], Any]: @functools.wraps(func) def decorated(*args: Any, **kwargs: Any) -> Any: with self as conn: @@ -233,19 +250,24 @@ def get_available_port() -> int: def setup_logger( log_file: Optional[str] = DEFAULT_LOG_FILE, log_level: str = DEFAULT_LOG_LEVEL, - log_format: str = DEFAULT_LOG_FORMAT) -> None: + log_format: str = DEFAULT_LOG_FORMAT, +) -> None: ll = getattr( logging, - {'D': 'DEBUG', + { + 'D': 'DEBUG', 'I': 'INFO', 'W': 'WARNING', 'E': 'ERROR', - 'C': 'CRITICAL'}[log_level.upper()[0]]) + 'C': 'CRITICAL', + }[log_level.upper()[0]], + ) if log_file: logging.basicConfig( filename=log_file, filemode='a', level=ll, - format=log_format) + format=log_format, + ) else: logging.basicConfig(level=ll, format=log_format) diff --git a/proxy/core/acceptor/acceptor.py b/proxy/core/acceptor/acceptor.py index 411bc2ad2d..78e0f02bd9 100644 --- a/proxy/core/acceptor/acceptor.py +++ b/proxy/core/acceptor/acceptor.py @@ -37,7 +37,7 @@ action='store_true', default=DEFAULT_THREADLESS, help='Default: False. When disabled a new thread is spawned ' - 'to handle each client connection.' + 'to handle each client connection.', ) @@ -64,7 +64,8 @@ def __init__( flags: argparse.Namespace, work_klass: Type[Work], lock: multiprocessing.synchronize.Lock, - event_queue: Optional[EventQueue] = None) -> None: + event_queue: Optional[EventQueue] = None, + ) -> None: super().__init__() self.idd = idd self.work_queue: connection.Connection = work_queue @@ -86,7 +87,7 @@ def start_threadless_process(self) -> None: client_queue=pipe[1], flags=self.flags, work_klass=self.work_klass, - event_queue=self.event_queue + event_queue=self.event_queue, ) self.threadless_process.start() logger.debug('Started process %d', self.threadless_process.pid) @@ -106,21 +107,21 @@ def start_work(self, conn: socket.socket, addr: Tuple[str, int]) -> None: send_handle( self.threadless_client_queue, conn.fileno(), - self.threadless_process.pid + self.threadless_process.pid, ) conn.close() else: work = self.work_klass( TcpClientConnection(conn, addr), flags=self.flags, - event_queue=self.event_queue + event_queue=self.event_queue, ) work_thread = threading.Thread(target=work.run) work_thread.daemon = True work.publish_event( event_name=eventNames.WORK_STARTED, event_payload={'fileno': conn.fileno(), 'addr': addr}, - publisher_id=self.__class__.__name__ + publisher_id=self.__class__.__name__, ) work_thread.start() @@ -134,15 +135,17 @@ def run_once(self) -> None: self.start_work(conn, addr) def run(self) -> None: - setup_logger(self.flags.log_file, self.flags.log_level, - self.flags.log_format) + setup_logger( + self.flags.log_file, self.flags.log_level, + self.flags.log_format, + ) self.selector = selectors.DefaultSelector() fileno = recv_handle(self.work_queue) self.work_queue.close() self.sock = socket.fromfd( fileno, family=self.flags.family, - type=socket.SOCK_STREAM + type=socket.SOCK_STREAM, ) try: self.selector.register(self.sock, selectors.EVENT_READ) diff --git a/proxy/core/acceptor/pool.py b/proxy/core/acceptor/pool.py index 2152ab1f28..0f1a28663e 100644 --- a/proxy/core/acceptor/pool.py +++ b/proxy/core/acceptor/pool.py @@ -35,23 +35,27 @@ '--backlog', type=int, default=DEFAULT_BACKLOG, - help='Default: 100. Maximum number of pending connections to proxy server') + help='Default: 100. Maximum number of pending connections to proxy server', +) flags.add_argument( '--hostname', type=str, default=str(DEFAULT_IPV6_HOSTNAME), - help='Default: ::1. Server IP address.') + help='Default: ::1. Server IP address.', +) flags.add_argument( '--port', type=int, default=DEFAULT_PORT, - help='Default: 8899. Server port.') + help='Default: 8899. Server port.', +) flags.add_argument( '--num-workers', type=int, default=DEFAULT_NUM_WORKERS, - help='Defaults to number of CPU cores.') + help='Defaults to number of CPU cores.', +) class AcceptorPool: @@ -73,8 +77,10 @@ class AcceptorPool: `work_klass` must implement `work.Work` class. """ - def __init__(self, flags: argparse.Namespace, - work_klass: Type[Work], event_queue: Optional[EventQueue] = None) -> None: + def __init__( + self, flags: argparse.Namespace, + work_klass: Type[Work], event_queue: Optional[EventQueue] = None, + ) -> None: self.flags = flags self.socket: Optional[socket.socket] = None self.acceptors: List[Acceptor] = [] @@ -109,7 +115,8 @@ def start_workers(self) -> None: logger.debug( 'Started acceptor#%d process %d', acceptor_id, - acceptor.pid) + acceptor.pid, + ) self.acceptors.append(acceptor) self.work_queues.append(work_queue[0]) logger.info('Started %d workers' % self.flags.num_workers) @@ -132,7 +139,7 @@ def setup(self) -> None: send_handle( self.work_queues[index], self.socket.fileno(), - self.acceptors[index].pid + self.acceptors[index].pid, ) self.work_queues[index].close() self.socket.close() diff --git a/proxy/core/acceptor/threadless.py b/proxy/core/acceptor/threadless.py index 0b5d40b3d8..8caf21739b 100644 --- a/proxy/core/acceptor/threadless.py +++ b/proxy/core/acceptor/threadless.py @@ -56,7 +56,8 @@ def __init__( client_queue: connection.Connection, flags: argparse.Namespace, work_klass: Type[Work], - event_queue: Optional[EventQueue] = None) -> None: + event_queue: Optional[EventQueue] = None, + ) -> None: super().__init__() self.client_queue = client_queue self.flags = flags @@ -69,8 +70,10 @@ def __init__( self.loop: Optional[asyncio.AbstractEventLoop] = None @contextlib.contextmanager - def selected_events(self) -> Generator[Tuple[Readables, Writables], - None, None]: + def selected_events(self) -> Generator[ + Tuple[Readables, Writables], + None, None, + ]: events: Dict[socket.socket, int] = {} for work in self.works.values(): events.update(work.get_events()) @@ -92,12 +95,14 @@ def selected_events(self) -> Generator[Tuple[Readables, Writables], async def handle_events( self, fileno: int, readables: Readables, - writables: Writables) -> bool: + writables: Writables + ) -> bool: return self.works[fileno].handle_events(readables, writables) # TODO: Use correct future typing annotations async def wait_for_tasks( - self, tasks: Dict[int, Any]) -> None: + self, tasks: Dict[int, Any] + ) -> None: for work_id in tasks: # TODO: Resolving one handle_events here can block # resolution of other tasks. This can happen when handle_events @@ -116,7 +121,8 @@ async def wait_for_tasks( def fromfd(self, fileno: int) -> socket.socket: return socket.fromfd( fileno, family=socket.AF_INET if self.flags.hostname.version == 4 else socket.AF_INET6, - type=socket.SOCK_STREAM) + type=socket.SOCK_STREAM, + ) def accept_client(self) -> None: addr = self.client_queue.recv() @@ -124,19 +130,20 @@ def accept_client(self) -> None: self.works[fileno] = self.work_klass( TcpClientConnection(conn=self.fromfd(fileno), addr=addr), flags=self.flags, - event_queue=self.event_queue + event_queue=self.event_queue, ) self.works[fileno].publish_event( event_name=eventNames.WORK_STARTED, event_payload={'fileno': fileno, 'addr': addr}, - publisher_id=self.__class__.__name__ + publisher_id=self.__class__.__name__, ) try: self.works[fileno].initialize() except Exception as e: logger.exception( 'Exception occurred during initialization', - exc_info=e) + exc_info=e, + ) self.cleanup(fileno) def cleanup_inactive(self) -> None: @@ -170,7 +177,8 @@ def run_once(self) -> None: tasks = {} for fileno in self.works: tasks[fileno] = self.loop.create_task( - self.handle_events(fileno, readables, writables)) + self.handle_events(fileno, readables, writables), + ) # Accepted client connection from Acceptor if self.client_queue in readables: self.accept_client() @@ -180,8 +188,10 @@ def run_once(self) -> None: self.cleanup_inactive() def run(self) -> None: - setup_logger(self.flags.log_file, self.flags.log_level, - self.flags.log_format) + setup_logger( + self.flags.log_file, self.flags.log_level, + self.flags.log_format, + ) try: self.selector = selectors.DefaultSelector() self.selector.register(self.client_queue, selectors.EVENT_READ) diff --git a/proxy/core/acceptor/work.py b/proxy/core/acceptor/work.py index 6bf3880ecf..dcfabc2831 100644 --- a/proxy/core/acceptor/work.py +++ b/proxy/core/acceptor/work.py @@ -28,7 +28,8 @@ def __init__( client: TcpClientConnection, flags: argparse.Namespace, event_queue: Optional[EventQueue] = None, - uid: Optional[UUID] = None) -> None: + uid: Optional[UUID] = None, + ) -> None: self.client = client self.flags = flags self.event_queue = event_queue @@ -43,7 +44,8 @@ def get_events(self) -> Dict[socket.socket, int]: def handle_events( self, readables: Readables, - writables: Writables) -> bool: + writables: Writables, + ) -> bool: """Handle readable and writable sockets. Return True to shutdown work.""" @@ -63,7 +65,7 @@ def shutdown(self) -> None: self.publish_event( event_name=eventNames.WORK_FINISHED, event_payload={}, - publisher_id=self.__class__.__name__ + publisher_id=self.__class__.__name__, ) def run(self) -> None: @@ -77,7 +79,8 @@ def publish_event( self, event_name: int, event_payload: Dict[str, Any], - publisher_id: Optional[str] = None) -> None: + publisher_id: Optional[str] = None, + ) -> None: """Convenience method provided to publish events into the global event queue.""" if not self.flags.enable_events: return @@ -86,5 +89,5 @@ def publish_event( self.uid.hex, event_name, event_payload, - publisher_id + publisher_id, ) diff --git a/proxy/core/base/tcp_server.py b/proxy/core/base/tcp_server.py index d4f516836e..0eda82ba21 100644 --- a/proxy/core/base/tcp_server.py +++ b/proxy/core/base/tcp_server.py @@ -61,7 +61,8 @@ def get_events(self) -> Dict[socket.socket, int]: def handle_events( self, readables: Readables, - writables: Writables) -> bool: + writables: Writables, + ) -> bool: """Return True to shutdown work.""" do_shutdown = False if self.client.connection in readables: @@ -71,25 +72,33 @@ def handle_events( # Client closed connection, signal shutdown print( 'Connection closed by client {0}'.format( - self.client.addr)) + self.client.addr, + ), + ) do_shutdown = True else: r = self.handle_data(data) if isinstance(r, bool) and r is True: print( 'Implementation signaled shutdown for client {0}'.format( - self.client.addr)) + self.client.addr, + ), + ) if self.client.has_buffer(): print( 'Client {0} has pending buffer, will be flushed before shutting down'.format( - self.client.addr)) + self.client.addr, + ), + ) self.must_flush_before_shutdown = True else: do_shutdown = True except ConnectionResetError: print( 'Connection reset by client {0}'.format( - self.client.addr)) + self.client.addr, + ), + ) do_shutdown = True if self.client.connection in writables: @@ -102,5 +111,7 @@ def handle_events( if do_shutdown: print( 'Shutting down client {0} connection'.format( - self.client.addr)) + self.client.addr, + ), + ) return do_shutdown diff --git a/proxy/core/base/tcp_tunnel.py b/proxy/core/base/tcp_tunnel.py index d83053ae11..8703247cf2 100644 --- a/proxy/core/base/tcp_tunnel.py +++ b/proxy/core/base/tcp_tunnel.py @@ -38,8 +38,11 @@ def initialize(self) -> None: def shutdown(self) -> None: if self.upstream: - print('Connection closed with upstream {0}:{1}'.format( - text_(self.request.host), self.request.port)) + print( + 'Connection closed with upstream {0}:{1}'.format( + text_(self.request.host), self.request.port, + ), + ) self.upstream.close() super().shutdown() @@ -61,7 +64,8 @@ def get_events(self) -> Dict[socket.socket, int]: def handle_events( self, readables: Readables, - writables: Writables) -> bool: + writables: Writables, + ) -> bool: # Handle client events do_shutdown: bool = super().handle_events(readables, writables) if do_shutdown: @@ -82,7 +86,11 @@ def handle_events( def connect_upstream(self) -> None: assert self.request.host and self.request.port self.upstream = TcpServerConnection( - text_(self.request.host), self.request.port) + text_(self.request.host), self.request.port, + ) self.upstream.connect() - print('Connection established with upstream {0}:{1}'.format( - text_(self.request.host), self.request.port)) + print( + 'Connection established with upstream {0}:{1}'.format( + text_(self.request.host), self.request.port, + ), + ) diff --git a/proxy/core/connection/client.py b/proxy/core/connection/client.py index 62597a10d4..bbc72f207e 100644 --- a/proxy/core/connection/client.py +++ b/proxy/core/connection/client.py @@ -18,9 +18,11 @@ class TcpClientConnection(TcpConnection): """An accepted client connection request.""" - def __init__(self, - conn: Union[ssl.SSLSocket, socket.socket], - addr: Tuple[str, int]): + def __init__( + self, + conn: Union[ssl.SSLSocket, socket.socket], + addr: Tuple[str, int], + ): super().__init__(tcpConnectionTypes.CLIENT) self._conn: Optional[Union[ssl.SSLSocket, socket.socket]] = conn self.addr: Tuple[str, int] = addr @@ -40,5 +42,6 @@ def wrap(self, keyfile: str, certfile: str) -> None: # ca_certs=self.flags.ca_cert_file, certfile=certfile, keyfile=keyfile, - ssl_version=ssl.PROTOCOL_TLS) + ssl_version=ssl.PROTOCOL_TLS, + ) self.connection.setblocking(False) diff --git a/proxy/core/connection/connection.py b/proxy/core/connection/connection.py index 73eabe208b..daf464337d 100644 --- a/proxy/core/connection/connection.py +++ b/proxy/core/connection/connection.py @@ -19,10 +19,12 @@ logger = logging.getLogger(__name__) -TcpConnectionTypes = NamedTuple('TcpConnectionTypes', [ - ('SERVER', int), - ('CLIENT', int), -]) +TcpConnectionTypes = NamedTuple( + 'TcpConnectionTypes', [ + ('SERVER', int), + ('CLIENT', int), + ], +) tcpConnectionTypes = TcpConnectionTypes(1, 2) @@ -55,14 +57,16 @@ def send(self, data: bytes) -> int: return self.connection.send(data) def recv( - self, buffer_size: int = DEFAULT_BUFFER_SIZE) -> Optional[memoryview]: + self, buffer_size: int = DEFAULT_BUFFER_SIZE, + ) -> Optional[memoryview]: """Users must handle socket.error exceptions""" data: bytes = self.connection.recv(buffer_size) if len(data) == 0: return None logger.debug( 'received %d bytes from %s' % - (len(data), self.tag)) + (len(data), self.tag), + ) # logger.info(data) return memoryview(data) diff --git a/proxy/core/connection/server.py b/proxy/core/connection/server.py index b19d8e4734..e88b1a8d1f 100644 --- a/proxy/core/connection/server.py +++ b/proxy/core/connection/server.py @@ -40,11 +40,13 @@ def connect(self) -> None: def wrap(self, hostname: str, ca_file: Optional[str]) -> None: ctx = ssl.create_default_context( - ssl.Purpose.SERVER_AUTH, cafile=ca_file) + ssl.Purpose.SERVER_AUTH, cafile=ca_file, + ) ctx.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 ctx.check_hostname = True self.connection.setblocking(True) self._conn = ctx.wrap_socket( self.connection, - server_hostname=hostname) + server_hostname=hostname, + ) self.connection.setblocking(False) diff --git a/proxy/core/event/dispatcher.py b/proxy/core/event/dispatcher.py index 80cb13043d..f65d5bf892 100644 --- a/proxy/core/event/dispatcher.py +++ b/proxy/core/event/dispatcher.py @@ -50,7 +50,8 @@ class EventDispatcher: def __init__( self, shutdown: threading.Event, - event_queue: EventQueue) -> None: + event_queue: EventQueue, + ) -> None: self.shutdown: threading.Event = shutdown self.event_queue: EventQueue = event_queue self.subscribers: Dict[str, DictQueueType] = {} diff --git a/proxy/core/event/manager.py b/proxy/core/event/manager.py index 74e10b033a..9986f82ed6 100644 --- a/proxy/core/event/manager.py +++ b/proxy/core/event/manager.py @@ -40,10 +40,10 @@ def start_event_dispatcher(self) -> None: assert self.event_queue self.event_dispatcher = EventDispatcher( shutdown=self.event_dispatcher_shutdown, - event_queue=self.event_queue + event_queue=self.event_queue, ) self.event_dispatcher_thread = threading.Thread( - target=self.event_dispatcher.run + target=self.event_dispatcher.run, ) self.event_dispatcher_thread.start() logger.debug('Thread ID: %d', self.event_dispatcher_thread.ident) @@ -55,4 +55,5 @@ def stop_event_dispatcher(self) -> None: self.event_dispatcher_thread.join() logger.debug( 'Shutdown of global event dispatcher thread %d successful', - self.event_dispatcher_thread.ident) + self.event_dispatcher_thread.ident, + ) diff --git a/proxy/core/event/names.py b/proxy/core/event/names.py index e56f5626b9..980c1ea89b 100644 --- a/proxy/core/event/names.py +++ b/proxy/core/event/names.py @@ -13,14 +13,16 @@ # Name of the events that eventing framework will support # Ideally this must be configurable via command line or # at-least extendable via plugins. -EventNames = NamedTuple('EventNames', [ - ('SUBSCRIBE', int), - ('UNSUBSCRIBE', int), - ('WORK_STARTED', int), - ('WORK_FINISHED', int), - ('REQUEST_COMPLETE', int), - ('RESPONSE_HEADERS_COMPLETE', int), - ('RESPONSE_CHUNK_RECEIVED', int), - ('RESPONSE_COMPLETE', int), -]) +EventNames = NamedTuple( + 'EventNames', [ + ('SUBSCRIBE', int), + ('UNSUBSCRIBE', int), + ('WORK_STARTED', int), + ('WORK_FINISHED', int), + ('REQUEST_COMPLETE', int), + ('RESPONSE_HEADERS_COMPLETE', int), + ('RESPONSE_CHUNK_RECEIVED', int), + ('RESPONSE_COMPLETE', int), + ], +) eventNames = EventNames(1, 2, 3, 4, 5, 6, 7, 8) diff --git a/proxy/core/event/queue.py b/proxy/core/event/queue.py index b4e6ab9615..23aaf4e689 100644 --- a/proxy/core/event/queue.py +++ b/proxy/core/event/queue.py @@ -45,7 +45,7 @@ def publish( request_id: str, event_name: int, event_payload: Dict[str, Any], - publisher_id: Optional[str] = None + publisher_id: Optional[str] = None, ) -> None: self.queue.put({ 'request_id': request_id, @@ -60,7 +60,8 @@ def publish( def subscribe( self, sub_id: str, - channel: DictQueueType) -> None: + channel: DictQueueType, + ) -> None: """Subscribe to global events.""" self.queue.put({ 'event_name': eventNames.SUBSCRIBE, @@ -69,7 +70,8 @@ def subscribe( def unsubscribe( self, - sub_id: str) -> None: + sub_id: str, + ) -> None: """Unsubscribe by subscriber id.""" self.queue.put({ 'event_name': eventNames.UNSUBSCRIBE, diff --git a/proxy/core/event/subscriber.py b/proxy/core/event/subscriber.py index 90648e0d87..841fe12f5b 100644 --- a/proxy/core/event/subscriber.py +++ b/proxy/core/event/subscriber.py @@ -39,13 +39,15 @@ def subscribe(self, callback: Callable[[Dict[str, Any]], None]) -> None: self.relay_channel = self.manager.Queue() self.relay_thread = threading.Thread( target=self.relay, - args=(self.relay_shutdown, self.relay_channel, callback)) + args=(self.relay_shutdown, self.relay_channel, callback), + ) self.relay_thread.start() self.relay_sub_id = uuid.uuid4().hex self.event_queue.subscribe(self.relay_sub_id, self.relay_channel) logger.debug( 'Subscribed relay sub id %s from core events', - self.relay_sub_id) + self.relay_sub_id, + ) def unsubscribe(self) -> None: if self.relay_sub_id is None: @@ -68,7 +70,8 @@ def unsubscribe(self) -> None: self.relay_thread.join() logger.debug( 'Un-subscribed relay sub id %s from core events', - self.relay_sub_id) + self.relay_sub_id, + ) self.relay_thread = None self.relay_shutdown = None @@ -79,7 +82,8 @@ def unsubscribe(self) -> None: def relay( shutdown: threading.Event, channel: DictQueueType, - callback: Callable[[Dict[str, Any]], None]) -> None: + callback: Callable[[Dict[str, Any]], None], + ) -> None: while not shutdown.is_set(): try: ev = channel.get(timeout=1) diff --git a/proxy/core/ssh/tunnel.py b/proxy/core/ssh/tunnel.py index 07169ec019..bb494010f1 100644 --- a/proxy/core/ssh/tunnel.py +++ b/proxy/core/ssh/tunnel.py @@ -25,7 +25,8 @@ def __init__( remote_addr: Tuple[str, int], private_pem_key: str, remote_proxy_port: int, - conn_handler: Callable[[paramiko.channel.Channel], None]) -> None: + conn_handler: Callable[[paramiko.channel.Channel], None], + ) -> None: self.remote_addr = remote_addr self.ssh_username = ssh_username self.private_pem_key = private_pem_key @@ -41,7 +42,7 @@ def run(self) -> None: hostname=self.remote_addr[0], port=self.remote_addr[1], username=self.ssh_username, - key_filename=self.private_pem_key + key_filename=self.private_pem_key, ) print('SSH connection established...') transport: Optional[paramiko.transport.Transport] = ssh.get_transport( @@ -51,7 +52,8 @@ def run(self) -> None: print('Tunnel port forward setup successful...') while True: conn: Optional[paramiko.channel.Channel] = transport.accept( - timeout=1) + timeout=1, + ) assert conn is not None e = transport.get_exception() if e: diff --git a/proxy/dashboard/dashboard.py b/proxy/dashboard/dashboard.py index dda3ae0ed3..25fb067160 100644 --- a/proxy/dashboard/dashboard.py +++ b/proxy/dashboard/dashboard.py @@ -63,18 +63,25 @@ def handle_request(self, request: HttpParser) -> None: if request.path == b'/dashboard/': self.client.queue( HttpWebServerPlugin.read_and_build_static_file_response( - os.path.join(self.flags.static_server_dir, 'dashboard', 'proxy.html'))) + os.path.join(self.flags.static_server_dir, 'dashboard', 'proxy.html'), + ), + ) elif request.path in ( b'/dashboard', - b'/dashboard/proxy.html'): - self.client.queue(memoryview(build_http_response( - httpStatusCodes.PERMANENT_REDIRECT, reason=b'Permanent Redirect', - headers={ - b'Location': b'/dashboard/', - b'Content-Length': b'0', - b'Connection': b'close', - } - ))) + b'/dashboard/proxy.html', + ): + self.client.queue( + memoryview( + build_http_response( + httpStatusCodes.PERMANENT_REDIRECT, reason=b'Permanent Redirect', + headers={ + b'Location': b'/dashboard/', + b'Content-Length': b'0', + b'Connection': b'close', + }, + ), + ), + ) def on_websocket_open(self) -> None: logger.info('app ws opened') @@ -104,6 +111,11 @@ def on_websocket_close(self) -> None: def reply(self, data: Dict[str, Any]) -> None: self.client.queue( - memoryview(WebsocketFrame.text( - bytes_( - json.dumps(data))))) + memoryview( + WebsocketFrame.text( + bytes_( + json.dumps(data), + ), + ), + ), + ) diff --git a/proxy/dashboard/inspect_traffic.py b/proxy/dashboard/inspect_traffic.py index df82a37f90..4689d293bb 100644 --- a/proxy/dashboard/inspect_traffic.py +++ b/proxy/dashboard/inspect_traffic.py @@ -37,23 +37,31 @@ def handle_message(self, message: Dict[str, Any]) -> None: # inspection can only be enabled if --enable-events is used if not self.flags.enable_events: self.client.queue( - memoryview(WebsocketFrame.text( - bytes_( - json.dumps( - {'id': message['id'], 'response': 'not enabled'}) - ) - )) + memoryview( + WebsocketFrame.text( + bytes_( + json.dumps( + {'id': message['id'], 'response': 'not enabled'}, + ), + ), + ), + ), ) else: self.subscriber.subscribe( lambda event: InspectTrafficPlugin.callback( - self.client, event)) + self.client, event, + ), + ) self.reply( - {'id': message['id'], 'response': 'inspection_enabled'}) + {'id': message['id'], 'response': 'inspection_enabled'}, + ) elif message['method'] == 'disable_inspection': self.subscriber.unsubscribe() - self.reply({'id': message['id'], - 'response': 'inspection_disabled'}) + self.reply({ + 'id': message['id'], + 'response': 'inspection_disabled', + }) else: raise NotImplementedError() @@ -61,6 +69,11 @@ def handle_message(self, message: Dict[str, Any]) -> None: def callback(client: TcpClientConnection, event: Dict[str, Any]) -> None: event['push'] = 'inspect_traffic' client.queue( - memoryview(WebsocketFrame.text( - bytes_( - json.dumps(event))))) + memoryview( + WebsocketFrame.text( + bytes_( + json.dumps(event), + ), + ), + ), + ) diff --git a/proxy/dashboard/plugin.py b/proxy/dashboard/plugin.py index b3787ac2df..8e36a600fc 100644 --- a/proxy/dashboard/plugin.py +++ b/proxy/dashboard/plugin.py @@ -26,7 +26,8 @@ def __init__( self, flags: argparse.Namespace, client: TcpClientConnection, - event_queue: EventQueue) -> None: + event_queue: EventQueue, + ) -> None: self.flags = flags self.client = client self.event_queue = event_queue @@ -51,6 +52,11 @@ def disconnected(self) -> None: def reply(self, data: Dict[str, Any]) -> None: self.client.queue( - memoryview(WebsocketFrame.text( - bytes_( - json.dumps(data))))) + memoryview( + WebsocketFrame.text( + bytes_( + json.dumps(data), + ), + ), + ), + ) diff --git a/proxy/http/chunk_parser.py b/proxy/http/chunk_parser.py index 2a33294e71..6a5258097a 100644 --- a/proxy/http/chunk_parser.py +++ b/proxy/http/chunk_parser.py @@ -14,11 +14,13 @@ from ..common.constants import CRLF, DEFAULT_BUFFER_SIZE -ChunkParserStates = NamedTuple('ChunkParserStates', [ - ('WAITING_FOR_SIZE', int), - ('WAITING_FOR_DATA', int), - ('COMPLETE', int), -]) +ChunkParserStates = NamedTuple( + 'ChunkParserStates', [ + ('WAITING_FOR_SIZE', int), + ('WAITING_FOR_DATA', int), + ('COMPLETE', int), + ], +) chunkParserStates = ChunkParserStates(1, 2, 3) diff --git a/proxy/http/codes.py b/proxy/http/codes.py index 042d27e4f9..21eab204a4 100644 --- a/proxy/http/codes.py +++ b/proxy/http/codes.py @@ -11,37 +11,39 @@ from typing import NamedTuple -HttpStatusCodes = NamedTuple('HttpStatusCodes', [ - # 1xx - ('CONTINUE', int), - ('SWITCHING_PROTOCOLS', int), - # 2xx - ('OK', int), - # 3xx - ('MOVED_PERMANENTLY', int), - ('SEE_OTHER', int), - ('TEMPORARY_REDIRECT', int), - ('PERMANENT_REDIRECT', int), - # 4xx - ('BAD_REQUEST', int), - ('UNAUTHORIZED', int), - ('FORBIDDEN', int), - ('NOT_FOUND', int), - ('PROXY_AUTH_REQUIRED', int), - ('REQUEST_TIMEOUT', int), - ('I_AM_A_TEAPOT', int), - # 5xx - ('INTERNAL_SERVER_ERROR', int), - ('NOT_IMPLEMENTED', int), - ('BAD_GATEWAY', int), - ('GATEWAY_TIMEOUT', int), - ('NETWORK_READ_TIMEOUT_ERROR', int), - ('NETWORK_CONNECT_TIMEOUT_ERROR', int), -]) +HttpStatusCodes = NamedTuple( + 'HttpStatusCodes', [ + # 1xx + ('CONTINUE', int), + ('SWITCHING_PROTOCOLS', int), + # 2xx + ('OK', int), + # 3xx + ('MOVED_PERMANENTLY', int), + ('SEE_OTHER', int), + ('TEMPORARY_REDIRECT', int), + ('PERMANENT_REDIRECT', int), + # 4xx + ('BAD_REQUEST', int), + ('UNAUTHORIZED', int), + ('FORBIDDEN', int), + ('NOT_FOUND', int), + ('PROXY_AUTH_REQUIRED', int), + ('REQUEST_TIMEOUT', int), + ('I_AM_A_TEAPOT', int), + # 5xx + ('INTERNAL_SERVER_ERROR', int), + ('NOT_IMPLEMENTED', int), + ('BAD_GATEWAY', int), + ('GATEWAY_TIMEOUT', int), + ('NETWORK_READ_TIMEOUT_ERROR', int), + ('NETWORK_CONNECT_TIMEOUT_ERROR', int), + ], +) httpStatusCodes = HttpStatusCodes( 100, 101, 200, 301, 303, 307, 308, 400, 401, 403, 404, 407, 408, 418, - 500, 501, 502, 504, 598, 599 + 500, 501, 502, 504, 598, 599, ) diff --git a/proxy/http/exception/http_request_rejected.py b/proxy/http/exception/http_request_rejected.py index 46fd9b04a0..a0fa810fc1 100644 --- a/proxy/http/exception/http_request_rejected.py +++ b/proxy/http/exception/http_request_rejected.py @@ -21,11 +21,13 @@ class HttpRequestRejected(HttpProtocolException): Connections can either be dropped/closed or optionally an HTTP status code can be returned.""" - def __init__(self, - status_code: Optional[int] = None, - reason: Optional[bytes] = None, - headers: Optional[Dict[bytes, bytes]] = None, - body: Optional[bytes] = None): + def __init__( + self, + status_code: Optional[int] = None, + reason: Optional[bytes] = None, + headers: Optional[Dict[bytes, bytes]] = None, + body: Optional[bytes] = None, + ): self.status_code: Optional[int] = status_code self.reason: Optional[bytes] = reason self.headers: Optional[Dict[bytes, bytes]] = headers @@ -33,10 +35,12 @@ def __init__(self, def response(self, _request: HttpParser) -> Optional[memoryview]: if self.status_code: - return memoryview(build_http_response( - status_code=self.status_code, - reason=self.reason, - headers=self.headers, - body=self.body - )) + return memoryview( + build_http_response( + status_code=self.status_code, + reason=self.reason, + headers=self.headers, + body=self.body, + ), + ) return None diff --git a/proxy/http/exception/proxy_auth_failed.py b/proxy/http/exception/proxy_auth_failed.py index ae1c6a4443..053837e8a8 100644 --- a/proxy/http/exception/proxy_auth_failed.py +++ b/proxy/http/exception/proxy_auth_failed.py @@ -20,15 +20,18 @@ class ProxyAuthenticationFailed(HttpProtocolException): """Exception raised when Http Proxy auth is enabled and incoming request doesn't present necessary credentials.""" - RESPONSE_PKT = memoryview(build_http_response( - httpStatusCodes.PROXY_AUTH_REQUIRED, - reason=b'Proxy Authentication Required', - headers={ - PROXY_AGENT_HEADER_KEY: PROXY_AGENT_HEADER_VALUE, - b'Proxy-Authenticate': b'Basic', - b'Connection': b'close', - }, - body=b'Proxy Authentication Required')) + RESPONSE_PKT = memoryview( + build_http_response( + httpStatusCodes.PROXY_AUTH_REQUIRED, + reason=b'Proxy Authentication Required', + headers={ + PROXY_AGENT_HEADER_KEY: PROXY_AGENT_HEADER_VALUE, + b'Proxy-Authenticate': b'Basic', + b'Connection': b'close', + }, + body=b'Proxy Authentication Required', + ), + ) def response(self, _request: HttpParser) -> memoryview: return self.RESPONSE_PKT diff --git a/proxy/http/exception/proxy_conn_failed.py b/proxy/http/exception/proxy_conn_failed.py index 0cec224277..8cbd73d093 100644 --- a/proxy/http/exception/proxy_conn_failed.py +++ b/proxy/http/exception/proxy_conn_failed.py @@ -19,15 +19,17 @@ class ProxyConnectionFailed(HttpProtocolException): """Exception raised when HttpProxyPlugin is unable to establish connection to upstream server.""" - RESPONSE_PKT = memoryview(build_http_response( - httpStatusCodes.BAD_GATEWAY, - reason=b'Bad Gateway', - headers={ - PROXY_AGENT_HEADER_KEY: PROXY_AGENT_HEADER_VALUE, - b'Connection': b'close' - }, - body=b'Bad Gateway' - )) + RESPONSE_PKT = memoryview( + build_http_response( + httpStatusCodes.BAD_GATEWAY, + reason=b'Bad Gateway', + headers={ + PROXY_AGENT_HEADER_KEY: PROXY_AGENT_HEADER_VALUE, + b'Connection': b'close', + }, + body=b'Bad Gateway', + ), + ) def __init__(self, host: str, port: int, reason: str): self.host: str = host diff --git a/proxy/http/handler.py b/proxy/http/handler.py index d4f82b9a5b..41ec608f86 100644 --- a/proxy/http/handler.py +++ b/proxy/http/handler.py @@ -43,13 +43,14 @@ help='Default: 1 MB. Maximum amount of data received from the ' 'client in a single recv() operation. Bump this ' 'value for faster uploads at the expense of ' - 'increased RAM.') + 'increased RAM.', +) flags.add_argument( '--key-file', type=str, default=DEFAULT_KEY_FILE, help='Default: None. Server key file to enable end-to-end TLS encryption with clients. ' - 'If used, must also pass --cert-file.' + 'If used, must also pass --cert-file.', ) flags.add_argument( '--timeout', @@ -58,7 +59,7 @@ help='Default: ' + str(DEFAULT_TIMEOUT) + '. Number of seconds after which ' 'an inactive connection must be dropped. Inactivity is defined by no ' - 'data sent or received by the client.' + 'data sent or received by the client.', ) @@ -68,10 +69,12 @@ class HttpProtocolHandler(Work): Accepts `Client` connection and delegates to HttpProtocolHandlerPlugin. """ - def __init__(self, client: TcpClientConnection, - flags: argparse.Namespace, - event_queue: Optional[EventQueue] = None, - uid: Optional[UUID] = None): + def __init__( + self, client: TcpClientConnection, + flags: argparse.Namespace, + event_queue: Optional[EventQueue] = None, + uid: Optional[UUID] = None, + ): super().__init__(client, flags, event_queue, uid) self.start_time: float = time.time() @@ -87,7 +90,8 @@ def encryption_enabled(self) -> bool: self.flags.certfile is not None def optionally_wrap_socket( - self, conn: socket.socket) -> Union[ssl.SSLSocket, socket.socket]: + self, conn: socket.socket, + ) -> Union[ssl.SSLSocket, socket.socket]: """Attempts to wrap accepted client connection using provided certificates. Shutdown and closes client connection upon error. @@ -112,7 +116,8 @@ def initialize(self) -> None: self.flags, self.client, self.request, - self.event_queue) + self.event_queue, + ) self.plugins[instance.name()] = instance logger.debug('Handling connection %r' % self.client.connection) @@ -124,7 +129,7 @@ def is_inactive(self) -> bool: def get_events(self) -> Dict[socket.socket, int]: events: Dict[socket.socket, int] = { - self.client.connection: selectors.EVENT_READ + self.client.connection: selectors.EVENT_READ, } if self.client.has_buffer(): events[self.client.connection] |= selectors.EVENT_WRITE @@ -146,7 +151,8 @@ def get_events(self) -> Dict[socket.socket, int]: def handle_events( self, readables: Readables, - writables: Writables) -> bool: + writables: Writables, + ) -> bool: """Returns True if proxy must teardown.""" # Flush buffer for ready to write sockets teardown = self.handle_writables(writables) @@ -184,7 +190,8 @@ def shutdown(self) -> None: logger.debug( 'Closing client connection %r ' 'at address %r has buffer %s' % - (self.client.connection, self.client.addr, self.client.has_buffer())) + (self.client.connection, self.client.addr, self.client.has_buffer()), + ) conn = self.client.connection # Unwrap if wrapped before shutdown. @@ -209,10 +216,12 @@ def flush(self) -> None: try: self.selector.register( self.client.connection, - selectors.EVENT_WRITE) + selectors.EVENT_WRITE, + ) while self.client.has_buffer(): - ev: List[Tuple[selectors.SelectorKey, int] - ] = self.selector.select(timeout=1) + ev: List[ + Tuple[selectors.SelectorKey, int] + ] = self.selector.select(timeout=1) if len(ev) == 0: continue self.client.flush() @@ -239,7 +248,8 @@ def handle_writables(self, writables: Writables) -> bool: self.client.flush() except BrokenPipeError: logger.error( - 'BrokenPipeError when flushing buffer for client') + 'BrokenPipeError when flushing buffer for client', + ) return True except OSError: logger.error('OSError when flushing buffer to client') @@ -254,7 +264,8 @@ def handle_readables(self, readables: Readables) -> bool: client_data = self.client.recv(self.flags.client_recvbuf_size) except ssl.SSLWantReadError: # Try again later logger.warning( - 'SSLWantReadError encountered while reading from client, will retry ...') + 'SSLWantReadError encountered while reading from client, will retry ...', + ) return False except socket.error as e: if e.errno == errno.ECONNRESET: @@ -262,7 +273,8 @@ def handle_readables(self, readables: Readables) -> bool: else: logger.exception( 'Exception while receiving from %s connection %r with reason %r' % - (self.client.tag, self.client.connection, e)) + (self.client.tag, self.client.connection, e), + ) return True if client_data is None: @@ -296,7 +308,8 @@ def handle_readables(self, readables: Readables) -> bool: upgraded_sock = plugin.on_request_complete() if isinstance(upgraded_sock, ssl.SSLSocket): logger.debug( - 'Updated client conn to %s', upgraded_sock) + 'Updated client conn to %s', upgraded_sock, + ) self.client._conn = upgraded_sock for plugin_ in self.plugins.values(): if plugin_ != plugin: @@ -305,7 +318,8 @@ def handle_readables(self, readables: Readables) -> bool: return True except HttpProtocolException as e: logger.debug( - 'HttpProtocolException type raised') + 'HttpProtocolException type raised', + ) response: Optional[memoryview] = e.response(self.request) if response: self.client.queue(response) @@ -314,8 +328,10 @@ def handle_readables(self, readables: Readables) -> bool: @contextlib.contextmanager def selected_events(self) -> \ - Generator[Tuple[Readables, Writables], - None, None]: + Generator[ + Tuple[Readables, Writables], + None, None, + ]: events = self.get_events() for fd in events: self.selector.register(fd, events[fd]) @@ -346,7 +362,8 @@ def run(self) -> None: if self.is_inactive(): logger.debug( 'Client buffer is empty and maximum inactivity has reached ' - 'between client and server connection, tearing down...') + 'between client and server connection, tearing down...', + ) break teardown = self.run_once() if teardown: @@ -358,6 +375,7 @@ def run(self) -> None: except Exception as e: logger.exception( 'Exception while handling connection %r' % - self.client.connection, exc_info=e) + self.client.connection, exc_info=e, + ) finally: self.shutdown() diff --git a/proxy/http/inspector/devtools.py b/proxy/http/inspector/devtools.py index 50b34192a4..07b264fcd8 100644 --- a/proxy/http/inspector/devtools.py +++ b/proxy/http/inspector/devtools.py @@ -30,7 +30,7 @@ type=str, default=DEFAULT_DEVTOOLS_WS_PATH, help='Default: /devtools. Only applicable ' - 'if --enable-devtools is used.' + 'if --enable-devtools is used.', ) @@ -52,7 +52,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: def routes(self) -> List[Tuple[int, str]]: return [ - (httpProtocolTypes.WEBSOCKET, text_(self.flags.devtools_ws_path)) + (httpProtocolTypes.WEBSOCKET, text_(self.flags.devtools_ws_path)), ] def handle_request(self, request: HttpParser) -> None: @@ -60,7 +60,8 @@ def handle_request(self, request: HttpParser) -> None: def on_websocket_open(self) -> None: self.subscriber.subscribe( - lambda event: CoreEventsToDevtoolsProtocol.transformer(self.client, event)) + lambda event: CoreEventsToDevtoolsProtocol.transformer(self.client, event), + ) def on_websocket_message(self, frame: WebsocketFrame) -> None: try: @@ -88,7 +89,7 @@ def handle_devtools_message(self, message: Dict[str, Any]) -> None: 'Emulation.canEmulate', ): data: Dict[str, Any] = { - 'result': False + 'result': False, } elif method == 'Page.getResourceTree': data = { @@ -100,9 +101,9 @@ def handle_devtools_message(self, message: Dict[str, Any]) -> None: 'mimeType': 'other', }, 'childFrames': [], - 'resources': [] - } - } + 'resources': [], + }, + }, } elif method == 'Network.getResponseBody': connection_id = message['params']['requestId'] @@ -110,7 +111,7 @@ def handle_devtools_message(self, message: Dict[str, Any]) -> None: 'result': { 'body': text_(CoreEventsToDevtoolsProtocol.RESPONSES[connection_id]), 'base64Encoded': False, - } + }, } else: logging.warning('Unhandled devtools method %s', method) diff --git a/proxy/http/inspector/transformer.py b/proxy/http/inspector/transformer.py index 450ce7bba1..ebc5ddcede 100644 --- a/proxy/http/inspector/transformer.py +++ b/proxy/http/inspector/transformer.py @@ -29,14 +29,17 @@ class CoreEventsToDevtoolsProtocol: RESPONSES: Dict[str, bytes] = {} @staticmethod - def transformer(client: TcpClientConnection, - event: Dict[str, Any]) -> None: + def transformer( + client: TcpClientConnection, + event: Dict[str, Any], + ) -> None: event_name = event['event_name'] if event_name == eventNames.REQUEST_COMPLETE: data = CoreEventsToDevtoolsProtocol.request_complete(event) elif event_name == eventNames.RESPONSE_HEADERS_COMPLETE: data = CoreEventsToDevtoolsProtocol.response_headers_complete( - event) + event, + ) elif event_name == eventNames.RESPONSE_CHUNK_RECEIVED: data = CoreEventsToDevtoolsProtocol.response_chunk_received(event) elif event_name == eventNames.RESPONSE_COMPLETE: @@ -45,9 +48,14 @@ def transformer(client: TcpClientConnection, # drop core events unrelated to Devtools return client.queue( - memoryview(WebsocketFrame.text( - bytes_( - json.dumps(data))))) + memoryview( + WebsocketFrame.text( + bytes_( + json.dumps(data), + ), + ), + ), + ) @staticmethod def request_complete(event: Dict[str, Any]) -> Dict[str, Any]: @@ -75,7 +83,7 @@ def request_complete(event: Dict[str, Any]) -> Dict[str, Any]: 'mixedContentType': 'none', }, 'initiator': { - 'type': 'other' + 'type': 'other', }, } @@ -120,7 +128,7 @@ def response_headers_complete(event: Dict[str, Any]) -> Dict[str, Any]: 'requestHeaders': '', 'remoteIPAddress': '', 'remotePort': '', - } + }, } @staticmethod diff --git a/proxy/http/methods.py b/proxy/http/methods.py index 63b8cd1e98..823c9e8e97 100644 --- a/proxy/http/methods.py +++ b/proxy/http/methods.py @@ -11,17 +11,19 @@ from typing import NamedTuple -HttpMethods = NamedTuple('HttpMethods', [ - ('GET', bytes), - ('HEAD', bytes), - ('POST', bytes), - ('PUT', bytes), - ('DELETE', bytes), - ('CONNECT', bytes), - ('OPTIONS', bytes), - ('TRACE', bytes), - ('PATCH', bytes), -]) +HttpMethods = NamedTuple( + 'HttpMethods', [ + ('GET', bytes), + ('HEAD', bytes), + ('POST', bytes), + ('PUT', bytes), + ('DELETE', bytes), + ('CONNECT', bytes), + ('OPTIONS', bytes), + ('TRACE', bytes), + ('PATCH', bytes), + ], +) httpMethods = HttpMethods( b'GET', b'HEAD', diff --git a/proxy/http/parser.py b/proxy/http/parser.py index 368e11ab87..597e641de1 100644 --- a/proxy/http/parser.py +++ b/proxy/http/parser.py @@ -18,20 +18,24 @@ from ..common.utils import build_http_request, build_http_response, find_http_line, text_ -HttpParserStates = NamedTuple('HttpParserStates', [ - ('INITIALIZED', int), - ('LINE_RCVD', int), - ('RCVING_HEADERS', int), - ('HEADERS_COMPLETE', int), - ('RCVING_BODY', int), - ('COMPLETE', int), -]) +HttpParserStates = NamedTuple( + 'HttpParserStates', [ + ('INITIALIZED', int), + ('LINE_RCVD', int), + ('RCVING_HEADERS', int), + ('HEADERS_COMPLETE', int), + ('RCVING_BODY', int), + ('COMPLETE', int), + ], +) httpParserStates = HttpParserStates(1, 2, 3, 4, 5, 6) -HttpParserTypes = NamedTuple('HttpParserTypes', [ - ('REQUEST_PARSER', int), - ('RESPONSE_PARSER', int), -]) +HttpParserTypes = NamedTuple( + 'HttpParserTypes', [ + ('REQUEST_PARSER', int), + ('RESPONSE_PARSER', int), + ], +) httpParserTypes = HttpParserTypes(1, 2) @@ -126,7 +130,8 @@ def set_line_attributes(self) -> None: else: raise KeyError( 'Invalid request. Method: %r, Url: %r' % - (self.method, self.url)) + (self.method, self.url), + ) self.path = self.build_path() def is_chunked_encoded(self) -> bool: @@ -134,8 +139,10 @@ def is_chunked_encoded(self) -> bool: self.headers[b'transfer-encoding'][1].lower() == b'chunked' def body_expected(self) -> bool: - return (b'content-length' in self.headers and - int(self.header(b'content-length')) > 0) or \ + return ( + b'content-length' in self.headers and + int(self.header(b'content-length')) > 0 + ) or \ self.is_chunked_encoded() def parse(self, raw: bytes) -> None: @@ -150,7 +157,8 @@ def parse(self, raw: bytes) -> None: while more and self.state != httpParserStates.COMPLETE: if self.state in ( httpParserStates.HEADERS_COMPLETE, - httpParserStates.RCVING_BODY): + httpParserStates.RCVING_BODY, + ): if b'content-length' in self.headers: self.state = httpParserStates.RCVING_BODY if self.body is None: @@ -172,7 +180,8 @@ def parse(self, raw: bytes) -> None: more = False else: raise NotImplementedError( - 'Parser shouldn\'t have reached here') + 'Parser shouldn\'t have reached here', + ) else: more, raw = self.process(raw) self.buffer = raw @@ -259,9 +268,11 @@ def build(self, disable_headers: Optional[List[bytes]] = None, for_proxy: bool = return build_http_request( self.method, path, self.version, - headers={} if not self.headers else {self.headers[k][0]: self.headers[k][1] for k in self.headers if - k.lower() not in disable_headers}, - body=body + headers={} if not self.headers else { + self.headers[k][0]: self.headers[k][1] for k in self.headers if + k.lower() not in disable_headers + }, + body=body, ) def build_response(self) -> bytes: @@ -272,8 +283,10 @@ def build_response(self) -> bytes: protocol_version=self.version, reason=self.reason, headers={} if not self.headers else { - self.headers[k][0]: self.headers[k][1] for k in self.headers}, - body=self.body if not self.is_chunked_encoded() else ChunkParser.to_chunks(self.body)) + self.headers[k][0]: self.headers[k][1] for k in self.headers + }, + body=self.body if not self.is_chunked_encoded() else ChunkParser.to_chunks(self.body), + ) def has_host(self) -> bool: """Host field SHOULD be None for incoming local WebServer requests.""" @@ -281,8 +294,10 @@ def has_host(self) -> bool: def is_http_1_1_keep_alive(self) -> bool: return self.version == HTTP_1_1 and \ - (not self.has_header(b'Connection') or - self.header(b'Connection').lower() == b'keep-alive') + ( + not self.has_header(b'Connection') or + self.header(b'Connection').lower() == b'keep-alive' + ) def is_connection_upgrade(self) -> bool: return self.version == HTTP_1_1 and \ diff --git a/proxy/http/plugin.py b/proxy/http/plugin.py index 16c046ab8b..37566f7266 100644 --- a/proxy/http/plugin.py +++ b/proxy/http/plugin.py @@ -50,7 +50,8 @@ def __init__( flags: argparse.Namespace, client: TcpClientConnection, request: HttpParser, - event_queue: EventQueue): + event_queue: EventQueue, + ): self.uid: UUID = uid self.flags: argparse.Namespace = flags self.client: TcpClientConnection = client @@ -67,7 +68,8 @@ def name(self) -> str: @abstractmethod def get_descriptors( - self) -> Tuple[List[socket.socket], List[socket.socket]]: + self, + ) -> Tuple[List[socket.socket], List[socket.socket]]: """Implementations must return a list of descriptions that they wish to read from and write into.""" return [], [] # pragma: no cover diff --git a/proxy/http/proxy/auth.py b/proxy/http/proxy/auth.py index 24d00f2aa9..10e2997b1a 100644 --- a/proxy/http/proxy/auth.py +++ b/proxy/http/proxy/auth.py @@ -22,14 +22,16 @@ type=str, default=DEFAULT_BASIC_AUTH, help='Default: No authentication. Specify colon separated user:password ' - 'to enable basic authentication.') + 'to enable basic authentication.', +) class AuthPlugin(HttpProxyBasePlugin): """Performs proxy authentication.""" def before_upstream_connection( - self, request: HttpParser) -> Optional[HttpParser]: + self, request: HttpParser, + ) -> Optional[HttpParser]: if self.flags.auth_code: if b'proxy-authorization' not in request.headers: raise ProxyAuthenticationFailed() @@ -41,7 +43,8 @@ def before_upstream_connection( return request def handle_client_request( - self, request: HttpParser) -> Optional[HttpParser]: + self, request: HttpParser, + ) -> Optional[HttpParser]: return request def handle_upstream_chunk(self, chunk: memoryview) -> memoryview: diff --git a/proxy/http/proxy/plugin.py b/proxy/http/proxy/plugin.py index aa679c2314..710f50d407 100644 --- a/proxy/http/proxy/plugin.py +++ b/proxy/http/proxy/plugin.py @@ -32,7 +32,8 @@ def __init__( uid: UUID, flags: argparse.Namespace, client: TcpClientConnection, - event_queue: EventQueue) -> None: + event_queue: EventQueue, + ) -> None: self.uid = uid # pragma: no cover self.flags = flags # pragma: no cover self.client = client # pragma: no cover @@ -57,7 +58,8 @@ def name(self) -> str: # # @abstractmethod def get_descriptors( - self) -> Tuple[List[socket.socket], List[socket.socket]]: + self, + ) -> Tuple[List[socket.socket], List[socket.socket]]: return [], [] # pragma: no cover # @abstractmethod @@ -77,7 +79,8 @@ def read_from_descriptors(self, r: Readables) -> bool: @abstractmethod def before_upstream_connection( - self, request: HttpParser) -> Optional[HttpParser]: + self, request: HttpParser, + ) -> Optional[HttpParser]: """Handler called just before Proxy upstream connection is established. Return optionally modified request object. @@ -90,7 +93,8 @@ def before_upstream_connection( # # @abstractmethod def handle_client_data( - self, raw: memoryview) -> Optional[memoryview]: + self, raw: memoryview, + ) -> Optional[memoryview]: """Handler called in special scenarios when an upstream server connection is never established. @@ -104,7 +108,8 @@ def handle_client_data( @abstractmethod def handle_client_request( - self, request: HttpParser) -> Optional[HttpParser]: + self, request: HttpParser, + ) -> Optional[HttpParser]: """Handler called before dispatching client request to upstream. Note: For pipelined (keep-alive) connections, this handler can be diff --git a/proxy/http/proxy/server.py b/proxy/http/proxy/server.py index 9bd3465290..0792763cc6 100644 --- a/proxy/http/proxy/server.py +++ b/proxy/http/proxy/server.py @@ -47,49 +47,50 @@ type=str, default=DEFAULT_CA_KEY_FILE, help='Default: None. CA key to use for signing dynamically generated ' - 'HTTPS certificates. If used, must also pass --ca-cert-file and --ca-signing-key-file' + 'HTTPS certificates. If used, must also pass --ca-cert-file and --ca-signing-key-file', ) flags.add_argument( '--ca-cert-dir', type=str, default=DEFAULT_CA_CERT_DIR, help='Default: ~/.proxy.py. Directory to store dynamically generated certificates. ' - 'Also see --ca-key-file, --ca-cert-file and --ca-signing-key-file' + 'Also see --ca-key-file, --ca-cert-file and --ca-signing-key-file', ) flags.add_argument( '--ca-cert-file', type=str, default=DEFAULT_CA_CERT_FILE, help='Default: None. Signing certificate to use for signing dynamically generated ' - 'HTTPS certificates. If used, must also pass --ca-key-file and --ca-signing-key-file' + 'HTTPS certificates. If used, must also pass --ca-key-file and --ca-signing-key-file', ) flags.add_argument( '--ca-file', type=str, default=DEFAULT_CA_FILE, help='Default: None. Provide path to custom CA file for peer certificate validation. ' - 'Specially useful on MacOS.' + 'Specially useful on MacOS.', ) flags.add_argument( '--ca-signing-key-file', type=str, default=DEFAULT_CA_SIGNING_KEY_FILE, help='Default: None. CA signing key to use for dynamic generation of ' - 'HTTPS certificates. If used, must also pass --ca-key-file and --ca-cert-file' + 'HTTPS certificates. If used, must also pass --ca-key-file and --ca-cert-file', ) flags.add_argument( '--cert-file', type=str, default=DEFAULT_CERT_FILE, help='Default: None. Server certificate to enable end-to-end TLS encryption with clients. ' - 'If used, must also pass --key-file.' + 'If used, must also pass --key-file.', ) flags.add_argument( '--disable-headers', type=str, default=COMMA.join(DEFAULT_DISABLE_HEADERS), help='Default: None. Comma separated list of headers to remove before ' - 'dispatching client request to upstream server.') + 'dispatching client request to upstream server.', +) flags.add_argument( '--server-recvbuf-size', type=int, @@ -97,23 +98,27 @@ help='Default: 1 MB. Maximum amount of data received from the ' 'server in a single recv() operation. Bump this ' 'value for faster downloads at the expense of ' - 'increased RAM.') + 'increased RAM.', +) class HttpProxyPlugin(HttpProtocolHandlerPlugin): """HttpProtocolHandler plugin which implements HttpProxy specifications.""" - PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT = memoryview(build_http_response( - httpStatusCodes.OK, - reason=b'Connection established' - )) + PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT = memoryview( + build_http_response( + httpStatusCodes.OK, + reason=b'Connection established', + ), + ) # Used to synchronization during certificate generation. lock = threading.Lock() def __init__( self, - *args: Any, **kwargs: Any) -> None: + *args: Any, **kwargs: Any, + ) -> None: super().__init__(*args, **kwargs) self.start_time: float = time.time() self.server: Optional[TcpServerConnection] = None @@ -128,7 +133,8 @@ def __init__( self.uid, self.flags, self.client, - self.event_queue) + self.event_queue, + ) self.plugins[instance.name()] = instance def tls_interception_enabled(self) -> bool: @@ -138,7 +144,8 @@ def tls_interception_enabled(self) -> bool: self.flags.ca_cert_file is not None def get_descriptors( - self) -> Tuple[List[socket.socket], List[socket.socket]]: + self, + ) -> Tuple[List[socket.socket], List[socket.socket]]: if not self.request.has_host(): return [], [] @@ -178,15 +185,18 @@ def write_to_descriptors(self, w: Writables) -> bool: self.server.flush() except ssl.SSLWantWriteError: logger.warning( - 'SSLWantWriteError while trying to flush to server, will retry') + 'SSLWantWriteError while trying to flush to server, will retry', + ) return False except BrokenPipeError: logger.error( - 'BrokenPipeError when flushing buffer for server') + 'BrokenPipeError when flushing buffer for server', + ) return True except OSError as e: logger.exception( - 'OSError when flushing buffer to server', exc_info=e) + 'OSError when flushing buffer to server', exc_info=e, + ) return True return False @@ -210,7 +220,8 @@ def read_from_descriptors(self, r: Readables) -> bool: if e.errno == errno.ETIMEDOUT: logger.warning( '%s:%d timed out on recv' % - self.server.addr) + self.server.addr, + ) return True raise e except ssl.SSLWantReadError: # Try again later @@ -220,14 +231,16 @@ def read_from_descriptors(self, r: Readables) -> bool: if e.errno == errno.EHOSTUNREACH: logger.warning( '%s:%d unreachable on recv' % - self.server.addr) + self.server.addr, + ) return True if e.errno == errno.ECONNRESET: logger.warning('Connection reset by upstream: %r' % e) else: logger.exception( 'Exception while receiving from %s connection %r with reason %r' % - (self.server.tag, self.server.connection, e)) + (self.server.tag, self.server.connection, e), + ) return True if raw is None: @@ -315,7 +328,8 @@ def on_client_connection_close(self) -> None: finally: logger.debug( 'Closed server connection, has buffer %s' % - self.server.has_buffer()) + self.server.has_buffer(), + ) def access_log(self, log_attrs: Dict[str, Any]) -> None: access_log_format = DEFAULT_HTTPS_ACCESS_LOG_FORMAT @@ -361,7 +375,8 @@ def on_client_data(self, raw: memoryview) -> Optional[memoryview]: # requests is TLS interception is enabled. if self.request.state == httpParserStates.COMPLETE and ( self.request.method != httpMethods.CONNECT or - self.tls_interception_enabled()): + self.tls_interception_enabled() + ): if self.pipeline_request is not None and \ self.pipeline_request.is_connection_upgrade(): # Previous pipelined request was a WebSocket @@ -372,7 +387,8 @@ def on_client_data(self, raw: memoryview) -> Optional[memoryview]: if self.pipeline_request is None: self.pipeline_request = HttpParser( - httpParserTypes.REQUEST_PARSER) + httpParserTypes.REQUEST_PARSER, + ) # TODO(abhinavsingh): Remove .tobytes after parser is # memoryview compliant @@ -389,7 +405,9 @@ def on_client_data(self, raw: memoryview) -> Optional[memoryview]: # parser is fully memoryview compliant self.server.queue( memoryview( - self.pipeline_request.build())) + self.pipeline_request.build(), + ), + ) if not self.pipeline_request.is_connection_upgrade(): self.pipeline_request = None # For scenarios where we cannot peek into the data, @@ -438,7 +456,8 @@ def on_request_complete(self) -> Union[socket.socket, bool]: if self.server: if self.request.method == httpMethods.CONNECT: self.client.queue( - HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT) + HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT, + ) if self.tls_interception_enabled(): return self.intercept() # If an upstream server connection was established for http request, @@ -448,7 +467,8 @@ def on_request_complete(self) -> Union[socket.socket, bool]: # officially documented in any specification, drop it. # - proxy-authorization is of no use for upstream, remove it. self.request.del_headers( - [b'proxy-authorization', b'proxy-connection']) + [b'proxy-authorization', b'proxy-connection'], + ) # - For HTTP/1.0, connection header defaults to close # - For HTTP/1.1, connection header defaults to keep-alive # Respect headers sent by client instead of manipulating @@ -457,17 +477,23 @@ def on_request_complete(self) -> Union[socket.socket, bool]: # connection headers are meant for communication between client and # first intercepting proxy. self.request.add_headers( - [(b'Via', b'1.1 %s' % PROXY_AGENT_HEADER_VALUE)]) + [(b'Via', b'1.1 %s' % PROXY_AGENT_HEADER_VALUE)], + ) # Disable args.disable_headers before dispatching to upstream self.server.queue( - memoryview(self.request.build( - disable_headers=self.flags.disable_headers))) + memoryview( + self.request.build( + disable_headers=self.flags.disable_headers, + ), + ), + ) return False def handle_pipeline_response(self, raw: memoryview) -> None: if self.pipeline_response is None: self.pipeline_response = HttpParser( - httpParserTypes.RESPONSE_PARSER) + httpParserTypes.RESPONSE_PARSER, + ) # TODO(abhinavsingh): Remove .tobytes after parser is memoryview # compliant self.pipeline_response.parse(raw.tobytes()) @@ -481,12 +507,14 @@ def connect_upstream(self) -> None: try: logger.debug( 'Connecting to upstream %s:%s' % - (text_(host), port)) + (text_(host), port), + ) self.server.connect() self.server.connection.setblocking(False) logger.debug( 'Connected to upstream %s:%s' % - (text_(host), port)) + (text_(host), port), + ) except Exception as e: # TimeoutError, socket.gaierror self.server.closed = True raise ProxyConnectionFailed(text_(host), port, repr(e)) from e @@ -499,7 +527,8 @@ def connect_upstream(self) -> None: # def gen_ca_signed_certificate( - self, cert_file_path: str, certificate: Dict[str, Any]) -> None: + self, cert_file_path: str, certificate: Dict[str, Any], + ) -> None: '''CA signing key (default) is used for generating a public key for common_name, if one already doesn't exist. Using generated public key a CSR request is generated, which is then signed by @@ -507,12 +536,16 @@ def gen_ca_signed_certificate( certificate doesn't already exist. returns signed certificate path.''' - assert(self.request.host and self.flags.ca_cert_dir and self.flags.ca_signing_key_file and - self.flags.ca_key_file and self.flags.ca_cert_file) + assert( + self.request.host and self.flags.ca_cert_dir and self.flags.ca_signing_key_file and + self.flags.ca_key_file and self.flags.ca_cert_file + ) upstream_subject = {s[0][0]: s[0][1] for s in certificate['subject']} - public_key_path = os.path.join(self.flags.ca_cert_dir, - '{0}.{1}'.format(text_(self.request.host), 'pub')) + public_key_path = os.path.join( + self.flags.ca_cert_dir, + '{0}.{1}'.format(text_(self.request.host), 'pub'), + ) private_key_path = self.flags.ca_signing_key_file private_key_password = '' @@ -528,28 +561,36 @@ def gen_ca_signed_certificate( subject = '' for key in keys: if upstream_subject.get(keys[key], None): - subject += '/{0}={1}'.format(key, - upstream_subject.get(keys[key])) - alt_subj_names = [text_(self.request.host), ] + subject += '/{0}={1}'.format( + key, + upstream_subject.get(keys[key]), + ) + alt_subj_names = [text_(self.request.host)] validity_in_days = 365 * 2 timeout = 10 # Generate a public key for the common name if not os.path.isfile(public_key_path): logger.debug('Generating public key %s', public_key_path) - resp = gen_public_key(public_key_path=public_key_path, private_key_path=private_key_path, - private_key_password=private_key_password, subject=subject, alt_subj_names=alt_subj_names, - validity_in_days=validity_in_days, timeout=timeout) + resp = gen_public_key( + public_key_path=public_key_path, private_key_path=private_key_path, + private_key_password=private_key_password, subject=subject, alt_subj_names=alt_subj_names, + validity_in_days=validity_in_days, timeout=timeout, + ) assert(resp is True) - csr_path = os.path.join(self.flags.ca_cert_dir, - '{0}.{1}'.format(text_(self.request.host), 'csr')) + csr_path = os.path.join( + self.flags.ca_cert_dir, + '{0}.{1}'.format(text_(self.request.host), 'csr'), + ) # Generate a CSR request for this common name if not os.path.isfile(csr_path): logger.debug('Generating CSR %s', csr_path) - resp = gen_csr(csr_path=csr_path, key_path=private_key_path, password=private_key_password, - crt_path=public_key_path, timeout=timeout) + resp = gen_csr( + csr_path=csr_path, key_path=private_key_path, password=private_key_password, + crt_path=public_key_path, timeout=timeout, + ) assert(resp is True) ca_key_path = self.flags.ca_key_file @@ -560,10 +601,12 @@ def gen_ca_signed_certificate( # Sign generated CSR if not os.path.isfile(cert_file_path): logger.debug('Signing CSR %s', cert_file_path) - resp = sign_csr(csr_path=csr_path, crt_path=cert_file_path, ca_key_path=ca_key_path, - ca_key_password=ca_key_password, ca_crt_path=ca_crt_path, - serial=str(serial), alt_subj_names=alt_subj_names, - validity_in_days=validity_in_days, timeout=timeout) + resp = sign_csr( + csr_path=csr_path, crt_path=cert_file_path, ca_key_path=ca_key_path, + ca_key_password=ca_key_password, ca_crt_path=ca_crt_path, + serial=str(serial), alt_subj_names=alt_subj_names, + validity_in_days=validity_in_days, timeout=timeout, + ) assert(resp is True) @staticmethod @@ -571,16 +614,21 @@ def generated_cert_file_path(ca_cert_dir: str, host: str) -> str: return os.path.join(ca_cert_dir, '%s.pem' % host) def generate_upstream_certificate( - self, certificate: Dict[str, Any]) -> str: - if not (self.flags.ca_cert_dir and self.flags.ca_signing_key_file and - self.flags.ca_cert_file and self.flags.ca_key_file): + self, certificate: Dict[str, Any], + ) -> str: + if not ( + self.flags.ca_cert_dir and self.flags.ca_signing_key_file and + self.flags.ca_cert_file and self.flags.ca_key_file + ): raise HttpProtocolException( f'For certificate generation all the following flags are mandatory: ' f'--ca-cert-file:{ self.flags.ca_cert_file }, ' f'--ca-key-file:{ self.flags.ca_key_file }, ' - f'--ca-signing-key-file:{ self.flags.ca_signing_key_file }') + f'--ca-signing-key-file:{ self.flags.ca_signing_key_file }', + ) cert_file_path = HttpProxyPlugin.generated_cert_file_path( - self.flags.ca_cert_dir, text_(self.request.host)) + self.flags.ca_cert_dir, text_(self.request.host), + ) with self.lock: if not os.path.isfile(cert_file_path): self.gen_ca_signed_certificate(cert_file_path, certificate) @@ -596,15 +644,18 @@ def intercept(self) -> Union[socket.socket, bool]: self.wrap_client() except subprocess.TimeoutExpired as e: # Popen communicate timeout logger.exception( - 'TimeoutExpired during certificate generation', exc_info=e) + 'TimeoutExpired during certificate generation', exc_info=e, + ) return True except BrokenPipeError: logger.error( - 'BrokenPipeError when wrapping client') + 'BrokenPipeError when wrapping client', + ) return True except OSError as e: logger.exception( - 'OSError when wrapping client', exc_info=e) + 'OSError when wrapping client', exc_info=e, + ) return True # Update all plugin connection reference # TODO(abhinavsingh): Is this required? @@ -622,10 +673,12 @@ def wrap_client(self) -> None: assert self.server is not None and self.flags.ca_signing_key_file is not None assert isinstance(self.server.connection, ssl.SSLSocket) generated_cert = self.generate_upstream_certificate( - cast(Dict[str, Any], self.server.connection.getpeercert())) + cast(Dict[str, Any], self.server.connection.getpeercert()), + ) self.client.wrap(self.flags.ca_signing_key_file, generated_cert) logger.debug( - 'TLS interception using %s', generated_cert) + 'TLS interception using %s', generated_cert, + ) # # Event emitter callbacks @@ -648,9 +701,9 @@ def emit_request_complete(self) -> None: 'headers': {text_(k): text_(v[1]) for k, v in self.request.headers.items()}, 'body': text_(self.request.body) if self.request.method == httpMethods.POST - else None + else None, }, - publisher_id=self.__class__.__name__ + publisher_id=self.__class__.__name__, ) def emit_response_events(self) -> None: diff --git a/proxy/http/server/pac_plugin.py b/proxy/http/server/pac_plugin.py index 20f131ddb9..afcde44f2b 100644 --- a/proxy/http/server/pac_plugin.py +++ b/proxy/http/server/pac_plugin.py @@ -26,13 +26,15 @@ default=DEFAULT_PAC_FILE, help='A file (Proxy Auto Configuration) or string to serve when ' 'the server receives a direct file request. ' - 'Using this option enables proxy.HttpWebServerPlugin.') + 'Using this option enables proxy.HttpWebServerPlugin.', +) flags.add_argument( '--pac-file-url-path', type=str, default=text_(DEFAULT_PAC_FILE_URL_PATH), help='Default: %s. Web server path to serve the PAC file.' % - text_(DEFAULT_PAC_FILE_URL_PATH)) + text_(DEFAULT_PAC_FILE_URL_PATH), +) class HttpWebServerPacFilePlugin(HttpWebServerBasePlugin): @@ -70,9 +72,11 @@ def cache_pac_file_response(self) -> None: content = f.read() except IOError: content = bytes_(self.flags.pac_file) - self.pac_file_response = memoryview(build_http_response( - 200, reason=b'OK', headers={ - b'Content-Type': b'application/x-ns-proxy-autoconfig', - b'Content-Encoding': b'gzip', - }, body=gzip.compress(content) - )) + self.pac_file_response = memoryview( + build_http_response( + 200, reason=b'OK', headers={ + b'Content-Type': b'application/x-ns-proxy-autoconfig', + b'Content-Encoding': b'gzip', + }, body=gzip.compress(content), + ), + ) diff --git a/proxy/http/server/plugin.py b/proxy/http/server/plugin.py index 491219caa0..ea6d540b32 100644 --- a/proxy/http/server/plugin.py +++ b/proxy/http/server/plugin.py @@ -31,7 +31,8 @@ def __init__( uid: UUID, flags: argparse.Namespace, client: TcpClientConnection, - event_queue: EventQueue): + event_queue: EventQueue, + ): self.uid = uid self.flags = flags self.client = client @@ -49,7 +50,8 @@ def __init__( # # @abstractmethod def get_descriptors( - self) -> Tuple[List[socket.socket], List[socket.socket]]: + self, + ) -> Tuple[List[socket.socket], List[socket.socket]]: return [], [] # pragma: no cover # @abstractmethod diff --git a/proxy/http/server/protocols.py b/proxy/http/server/protocols.py index e2a99ae9e2..b0f6202c06 100644 --- a/proxy/http/server/protocols.py +++ b/proxy/http/server/protocols.py @@ -10,9 +10,11 @@ """ from typing import NamedTuple -HttpProtocolTypes = NamedTuple('HttpProtocolTypes', [ - ('HTTP', int), - ('HTTPS', int), - ('WEBSOCKET', int), -]) +HttpProtocolTypes = NamedTuple( + 'HttpProtocolTypes', [ + ('HTTP', int), + ('HTTPS', int), + ('WEBSOCKET', int), + ], +) httpProtocolTypes = HttpProtocolTypes(1, 2, 3) diff --git a/proxy/http/server/web.py b/proxy/http/server/web.py index 04abd298ab..f430e426ad 100644 --- a/proxy/http/server/web.py +++ b/proxy/http/server/web.py @@ -39,30 +39,39 @@ default=DEFAULT_STATIC_SERVER_DIR, help='Default: "public" folder in directory where proxy.py is placed. ' 'This option is only applicable when static server is also enabled. ' - 'See --enable-static-server.' + 'See --enable-static-server.', ) class HttpWebServerPlugin(HttpProtocolHandlerPlugin): """HttpProtocolHandler plugin which handles incoming requests to local web server.""" - DEFAULT_404_RESPONSE = memoryview(build_http_response( - httpStatusCodes.NOT_FOUND, - reason=b'NOT FOUND', - headers={b'Server': PROXY_AGENT_HEADER_VALUE, - b'Connection': b'close'} - )) + DEFAULT_404_RESPONSE = memoryview( + build_http_response( + httpStatusCodes.NOT_FOUND, + reason=b'NOT FOUND', + headers={ + b'Server': PROXY_AGENT_HEADER_VALUE, + b'Connection': b'close', + }, + ), + ) - DEFAULT_501_RESPONSE = memoryview(build_http_response( - httpStatusCodes.NOT_IMPLEMENTED, - reason=b'NOT IMPLEMENTED', - headers={b'Server': PROXY_AGENT_HEADER_VALUE, - b'Connection': b'close'} - )) + DEFAULT_501_RESPONSE = memoryview( + build_http_response( + httpStatusCodes.NOT_IMPLEMENTED, + reason=b'NOT IMPLEMENTED', + headers={ + b'Server': PROXY_AGENT_HEADER_VALUE, + b'Connection': b'close', + }, + ), + ) def __init__( self, - *args: Any, **kwargs: Any) -> None: + *args: Any, **kwargs: Any, + ) -> None: super().__init__(*args, **kwargs) self.start_time: float = time.time() self.pipeline_request: Optional[HttpParser] = None @@ -80,7 +89,8 @@ def __init__( self.uid, self.flags, self.client, - self.event_queue) + self.event_queue, + ) for (protocol, route) in instance.routes(): self.routes[protocol][re.compile(route)] = instance @@ -95,16 +105,19 @@ def read_and_build_static_file_response(path: str) -> memoryview: content_type = mimetypes.guess_type(path)[0] if content_type is None: content_type = 'text/plain' - return memoryview(build_http_response( - httpStatusCodes.OK, - reason=b'OK', - headers={ - b'Content-Type': bytes_(content_type), - b'Cache-Control': b'max-age=86400', - b'Content-Encoding': b'gzip', - b'Connection': b'close', - }, - body=gzip.compress(content))) + return memoryview( + build_http_response( + httpStatusCodes.OK, + reason=b'OK', + headers={ + b'Content-Type': bytes_(content_type), + b'Cache-Control': b'max-age=86400', + b'Content-Encoding': b'gzip', + b'Connection': b'close', + }, + body=gzip.compress(content), + ), + ) def serve_file_or_404(self, path: str) -> bool: """Read and serves a file from disk. @@ -114,7 +127,8 @@ def serve_file_or_404(self, path: str) -> bool: """ try: self.client.queue( - self.read_and_build_static_file_response(path)) + self.read_and_build_static_file_response(path), + ) except IOError: self.client.queue(self.DEFAULT_404_RESPONSE) return True @@ -125,9 +139,14 @@ def try_upgrade(self) -> bool: if self.request.has_header(b'upgrade') and \ self.request.header(b'upgrade').lower() == b'websocket': self.client.queue( - memoryview(build_websocket_handshake_response( - WebsocketFrame.key_to_accept( - self.request.header(b'Sec-WebSocket-Key'))))) + memoryview( + build_websocket_handshake_response( + WebsocketFrame.key_to_accept( + self.request.header(b'Sec-WebSocket-Key'), + ), + ), + ), + ) self.switched_protocol = httpProtocolTypes.WEBSOCKET else: self.client.queue(self.DEFAULT_501_RESPONSE) @@ -175,7 +194,8 @@ def on_request_complete(self) -> Union[socket.socket, bool]: path = text_(self.request.path).split('?')[0] if os.path.isfile(self.flags.static_server_dir + path): return self.serve_file_or_404( - self.flags.static_server_dir + path) + self.flags.static_server_dir + path, + ) # Catch all unhandled web server requests, return 404 self.client.queue(self.DEFAULT_404_RESPONSE) @@ -183,7 +203,8 @@ def on_request_complete(self) -> Union[socket.socket, bool]: # TODO(abhinavsingh): Call plugin get/read/write descriptor callbacks def get_descriptors( - self) -> Tuple[List[socket.socket], List[socket.socket]]: + self, + ) -> Tuple[List[socket.socket], List[socket.socket]]: return [], [] def write_to_descriptors(self, w: Writables) -> bool: @@ -203,7 +224,8 @@ def on_client_data(self, raw: memoryview) -> Optional[memoryview]: remaining = frame.parse(remaining) if frame.opcode == websocketOpcodes.CONNECTION_CLOSE: logger.warning( - 'Client sent connection close packet') + 'Client sent connection close packet', + ) raise HttpProtocolException() else: assert self.route @@ -217,7 +239,8 @@ def on_client_data(self, raw: memoryview) -> Optional[memoryview]: self.route is not None: if self.pipeline_request is None: self.pipeline_request = HttpParser( - httpParserTypes.REQUEST_PARSER) + httpParserTypes.REQUEST_PARSER, + ) # TODO(abhinavsingh): Remove .tobytes after parser is memoryview # compliant self.pipeline_request.parse(raw.tobytes()) @@ -225,7 +248,8 @@ def on_client_data(self, raw: memoryview) -> Optional[memoryview]: self.route.handle_request(self.pipeline_request) if not self.pipeline_request.is_http_1_1_keep_alive(): logger.error( - 'Pipelined request is not keep-alive, will teardown request...') + 'Pipelined request is not keep-alive, will teardown request...', + ) raise HttpProtocolException() self.pipeline_request = None return raw @@ -245,8 +269,11 @@ def on_client_connection_close(self) -> None: def access_log(self) -> None: logger.info( '%s:%s - %s %s - %.2f ms' % - (self.client.addr[0], - self.client.addr[1], - text_(self.request.method), - text_(self.request.path), - (time.time() - self.start_time) * 1000)) + ( + self.client.addr[0], + self.client.addr[1], + text_(self.request.method), + text_(self.request.path), + (time.time() - self.start_time) * 1000, + ), + ) diff --git a/proxy/http/websocket/client.py b/proxy/http/websocket/client.py index 716d0faea3..902b891914 100644 --- a/proxy/http/websocket/client.py +++ b/proxy/http/websocket/client.py @@ -27,19 +27,28 @@ class WebsocketClient(TcpConnection): - def __init__(self, - hostname: bytes, - port: int, - path: bytes = b'/', - on_message: Optional[Callable[[WebsocketFrame], None]] = None) -> None: + def __init__( + self, + hostname: bytes, + port: int, + path: bytes = b'/', + on_message: Optional[Callable[[WebsocketFrame], None]] = None, + ) -> None: super().__init__(tcpConnectionTypes.CLIENT) self.hostname: bytes = hostname self.port: int = port self.path: bytes = path self.sock: socket.socket = new_socket_connection( - (socket.gethostbyname(text_(self.hostname)), self.port)) - self.on_message: Optional[Callable[[ - WebsocketFrame], None]] = on_message + (socket.gethostbyname(text_(self.hostname)), self.port), + ) + self.on_message: Optional[ + Callable[ + [ + WebsocketFrame, + ], + None, + ] + ] = on_message self.selector: selectors.DefaultSelector = selectors.DefaultSelector() @property @@ -56,7 +65,9 @@ def upgrade(self) -> None: build_websocket_handshake_request( key, url=self.path, - host=self.hostname)) + host=self.hostname, + ), + ) response = HttpParser(httpParserTypes.RESPONSE_PARSER) response.parse(self.sock.recv(DEFAULT_BUFFER_SIZE)) accept = response.header(b'Sec-Websocket-Accept') diff --git a/proxy/http/websocket/frame.py b/proxy/http/websocket/frame.py index 55f9d91b16..6814de02bf 100644 --- a/proxy/http/websocket/frame.py +++ b/proxy/http/websocket/frame.py @@ -18,14 +18,16 @@ from typing import TypeVar, Type, Optional, NamedTuple -WebsocketOpcodes = NamedTuple('WebsocketOpcodes', [ - ('CONTINUATION_FRAME', int), - ('TEXT_FRAME', int), - ('BINARY_FRAME', int), - ('CONNECTION_CLOSE', int), - ('PING', int), - ('PONG', int), -]) +WebsocketOpcodes = NamedTuple( + 'WebsocketOpcodes', [ + ('CONTINUATION_FRAME', int), + ('TEXT_FRAME', int), + ('BINARY_FRAME', int), + ('CONNECTION_CLOSE', int), + ('PING', int), + ('PONG', int), + ], +) websocketOpcodes = WebsocketOpcodes(0x0, 0x1, 0x2, 0x8, 0x9, 0xA) @@ -91,35 +93,38 @@ def build(self) -> bytes: (1 << 6 if self.rsv1 else 0) | (1 << 5 if self.rsv2 else 0) | (1 << 4 if self.rsv3 else 0) | - self.opcode - )) + self.opcode, + ), + ) assert self.payload_length is not None if self.payload_length < 126: raw.write( struct.pack( '!B', - (1 << 7 if self.masked else 0) | self.payload_length - ) + (1 << 7 if self.masked else 0) | self.payload_length, + ), ) elif self.payload_length < 1 << 16: raw.write( struct.pack( '!BH', (1 << 7 if self.masked else 0) | 126, - self.payload_length - ) + self.payload_length, + ), ) elif self.payload_length < 1 << 64: raw.write( struct.pack( '!BHQ', (1 << 7 if self.masked else 0) | 127, - self.payload_length - ) + self.payload_length, + ), ) else: - raise ValueError(f'Invalid payload_length { self.payload_length },' - f'maximum allowed { 1 << 64 }') + raise ValueError( + f'Invalid payload_length { self.payload_length },' + f'maximum allowed { 1 << 64 }', + ) if self.masked and self.data: mask = secrets.token_bytes(4) if self.mask is None else self.mask raw.write(mask) diff --git a/proxy/plugin/cache/base.py b/proxy/plugin/cache/base.py index 81a2ef65f9..e74d7dd960 100644 --- a/proxy/plugin/cache/base.py +++ b/proxy/plugin/cache/base.py @@ -30,7 +30,8 @@ class BaseCacheResponsesPlugin(HttpProxyBasePlugin): def __init__( self, *args: Any, - **kwargs: Any) -> None: + **kwargs: Any, + ) -> None: super().__init__(*args, **kwargs) self.store: Optional[CacheStore] = None @@ -38,7 +39,8 @@ def set_store(self, store: CacheStore) -> None: self.store = store def before_upstream_connection( - self, request: HttpParser) -> Optional[HttpParser]: + self, request: HttpParser, + ) -> Optional[HttpParser]: assert self.store try: self.store.open(request) @@ -47,7 +49,8 @@ def before_upstream_connection( return request def handle_client_request( - self, request: HttpParser) -> Optional[HttpParser]: + self, request: HttpParser, + ) -> Optional[HttpParser]: assert self.store return self.store.cache_request(request) diff --git a/proxy/plugin/cache/cache_responses.py b/proxy/plugin/cache/cache_responses.py index f6da087e5d..200a60cb39 100644 --- a/proxy/plugin/cache/cache_responses.py +++ b/proxy/plugin/cache/cache_responses.py @@ -24,5 +24,6 @@ class CacheResponsesPlugin(BaseCacheResponsesPlugin): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.disk_store = OnDiskCacheStore( - uid=self.uid, cache_dir=self.flags.cache_dir) + uid=self.uid, cache_dir=self.flags.cache_dir, + ) self.set_store(self.disk_store) diff --git a/proxy/plugin/cache/store/disk.py b/proxy/plugin/cache/store/disk.py index 1f472a12c5..f50565f866 100644 --- a/proxy/plugin/cache/store/disk.py +++ b/proxy/plugin/cache/store/disk.py @@ -27,7 +27,7 @@ '--cache-dir', type=str, default=tempfile.gettempdir(), - help='Default: A temporary directory. Flag only applicable when cache plugin is used with on-disk storage.' + help='Default: A temporary directory. Flag only applicable when cache plugin is used with on-disk storage.', ) @@ -42,7 +42,8 @@ def __init__(self, uid: UUID, cache_dir: str) -> None: def open(self, request: HttpParser) -> None: self.cache_file_path = os.path.join( self.cache_dir, - '%s-%s.txt' % (text_(request.host), self.uid.hex)) + '%s-%s.txt' % (text_(request.host), self.uid.hex), + ) self.cache_file = open(self.cache_file_path, "wb") def cache_request(self, request: HttpParser) -> Optional[HttpParser]: diff --git a/proxy/plugin/filter_by_client_ip.py b/proxy/plugin/filter_by_client_ip.py index 95169b4884..679679ab70 100644 --- a/proxy/plugin/filter_by_client_ip.py +++ b/proxy/plugin/filter_by_client_ip.py @@ -21,7 +21,7 @@ '--filtered-client-ips', type=str, default='127.0.0.1,::1', - help='Default: 127.0.0.1,::1. Comma separated list of IPv4 and IPv6 addresses.' + help='Default: 127.0.0.1,::1. Comma separated list of IPv4 and IPv6 addresses.', ) @@ -29,18 +29,20 @@ class FilterByClientIpPlugin(HttpProxyBasePlugin): """Drop traffic by inspecting incoming client IP address.""" def before_upstream_connection( - self, request: HttpParser) -> Optional[HttpParser]: + self, request: HttpParser, + ) -> Optional[HttpParser]: if self.client.addr[0] in self.flags.filtered_client_ips.split(','): raise HttpRequestRejected( status_code=httpStatusCodes.I_AM_A_TEAPOT, reason=b'I\'m a tea pot', headers={ b'Connection': b'close', - } + }, ) return request def handle_client_request( - self, request: HttpParser) -> Optional[HttpParser]: + self, request: HttpParser, + ) -> Optional[HttpParser]: return request def handle_upstream_chunk(self, chunk: memoryview) -> memoryview: diff --git a/proxy/plugin/filter_by_upstream.py b/proxy/plugin/filter_by_upstream.py index 7feabccc3f..18c4cb4623 100644 --- a/proxy/plugin/filter_by_upstream.py +++ b/proxy/plugin/filter_by_upstream.py @@ -22,7 +22,7 @@ '--filtered-upstream-hosts', type=str, default='facebook.com,www.facebook.com', - help='Default: Blocks Facebook. Comma separated list of IPv4 and IPv6 addresses.' + help='Default: Blocks Facebook. Comma separated list of IPv4 and IPv6 addresses.', ) @@ -30,19 +30,21 @@ class FilterByUpstreamHostPlugin(HttpProxyBasePlugin): """Drop traffic by inspecting upstream host.""" def before_upstream_connection( - self, request: HttpParser) -> Optional[HttpParser]: + self, request: HttpParser, + ) -> Optional[HttpParser]: print(self.flags.filtered_upstream_hosts) if text_(request.host) in self.flags.filtered_upstream_hosts.split(','): raise HttpRequestRejected( status_code=httpStatusCodes.I_AM_A_TEAPOT, reason=b'I\'m a tea pot', headers={ b'Connection': b'close', - } + }, ) return request def handle_client_request( - self, request: HttpParser) -> Optional[HttpParser]: + self, request: HttpParser, + ) -> Optional[HttpParser]: return request def handle_upstream_chunk(self, chunk: memoryview) -> memoryview: diff --git a/proxy/plugin/filter_by_url_regex.py b/proxy/plugin/filter_by_url_regex.py index 43a4074799..c9e9407111 100644 --- a/proxy/plugin/filter_by_url_regex.py +++ b/proxy/plugin/filter_by_url_regex.py @@ -29,7 +29,7 @@ '--filtered-url-regex-config', type=str, default='', - help='Default: No config. Comma separated list of IPv4 and IPv6 addresses.' + help='Default: No config. Comma separated list of IPv4 and IPv6 addresses.', ) @@ -48,11 +48,13 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self.filters = json.load(f) def before_upstream_connection( - self, request: HttpParser) -> Optional[HttpParser]: + self, request: HttpParser, + ) -> Optional[HttpParser]: return request def handle_client_request( - self, request: HttpParser) -> Optional[HttpParser]: + self, request: HttpParser, + ) -> Optional[HttpParser]: # determine host request_host = None if request.host: @@ -68,7 +70,7 @@ def handle_client_request( # build URL url = b'%s%s' % ( request_host, - request.path + request.path, ) # check URL against list rule_number = 1 @@ -76,11 +78,13 @@ def handle_client_request( # if regex matches on URL if re.search(text_(blocked_entry['regex']), text_(url)): # log that the request has been filtered - logger.info("Blocked: %r with status_code '%r' by rule number '%r'" % ( - text_(url), - httpStatusCodes.NOT_FOUND, - rule_number, - )) + logger.info( + "Blocked: %r with status_code '%r' by rule number '%r'" % ( + text_(url), + httpStatusCodes.NOT_FOUND, + rule_number, + ), + ) # close the connection with the status code from the filter # list raise HttpRequestRejected( diff --git a/proxy/plugin/man_in_the_middle.py b/proxy/plugin/man_in_the_middle.py index cc3ab63e7e..d7723fae94 100644 --- a/proxy/plugin/man_in_the_middle.py +++ b/proxy/plugin/man_in_the_middle.py @@ -20,17 +20,23 @@ class ManInTheMiddlePlugin(HttpProxyBasePlugin): """Modifies upstream server responses.""" def before_upstream_connection( - self, request: HttpParser) -> Optional[HttpParser]: + self, request: HttpParser, + ) -> Optional[HttpParser]: return request def handle_client_request( - self, request: HttpParser) -> Optional[HttpParser]: + self, request: HttpParser, + ) -> Optional[HttpParser]: return request def handle_upstream_chunk(self, chunk: memoryview) -> memoryview: - return memoryview(build_http_response( - httpStatusCodes.OK, - reason=b'OK', body=b'Hello from man in the middle')) + return memoryview( + build_http_response( + httpStatusCodes.OK, + reason=b'OK', + body=b'Hello from man in the middle', + ), + ) def on_upstream_connection_close(self) -> None: pass diff --git a/proxy/plugin/mock_rest_api.py b/proxy/plugin/mock_rest_api.py index 270c864408..9768422133 100644 --- a/proxy/plugin/mock_rest_api.py +++ b/proxy/plugin/mock_rest_api.py @@ -50,35 +50,48 @@ class ProposedRestApiPlugin(HttpProxyBasePlugin): 'url': text_(API_SERVER) + '/v1/users/2/', 'username': 'someone', }, - ] + ], }, } def before_upstream_connection( - self, request: HttpParser) -> Optional[HttpParser]: + self, request: HttpParser, + ) -> Optional[HttpParser]: # Return None to disable establishing connection to upstream # Most likely our api.example.com won't even exist under development # scenario return None def handle_client_request( - self, request: HttpParser) -> Optional[HttpParser]: + self, request: HttpParser, + ) -> Optional[HttpParser]: if request.host != self.API_SERVER: return request assert request.path if request.path in self.REST_API_SPEC: - self.client.queue(memoryview(build_http_response( - httpStatusCodes.OK, - reason=b'OK', - headers={b'Content-Type': b'application/json'}, - body=bytes_(json.dumps( - self.REST_API_SPEC[request.path])) - ))) + self.client.queue( + memoryview( + build_http_response( + httpStatusCodes.OK, + reason=b'OK', + headers={b'Content-Type': b'application/json'}, + body=bytes_( + json.dumps( + self.REST_API_SPEC[request.path], + ), + ), + ), + ), + ) else: - self.client.queue(memoryview(build_http_response( - httpStatusCodes.NOT_FOUND, - reason=b'NOT FOUND', body=b'Not Found' - ))) + self.client.queue( + memoryview( + build_http_response( + httpStatusCodes.NOT_FOUND, + reason=b'NOT FOUND', body=b'Not Found', + ), + ), + ) return None def handle_upstream_chunk(self, chunk: memoryview) -> memoryview: diff --git a/proxy/plugin/modify_chunk_response.py b/proxy/plugin/modify_chunk_response.py index 707da5de5f..b3533784d9 100644 --- a/proxy/plugin/modify_chunk_response.py +++ b/proxy/plugin/modify_chunk_response.py @@ -30,11 +30,13 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self.response = HttpParser(httpParserTypes.RESPONSE_PARSER) def before_upstream_connection( - self, request: HttpParser) -> Optional[HttpParser]: + self, request: HttpParser, + ) -> Optional[HttpParser]: return request def handle_client_request( - self, request: HttpParser) -> Optional[HttpParser]: + self, request: HttpParser, + ) -> Optional[HttpParser]: return request def handle_upstream_chunk(self, chunk: memoryview) -> memoryview: diff --git a/proxy/plugin/modify_post_data.py b/proxy/plugin/modify_post_data.py index 98b89daf5c..5a8db81a0d 100644 --- a/proxy/plugin/modify_post_data.py +++ b/proxy/plugin/modify_post_data.py @@ -22,18 +22,22 @@ class ModifyPostDataPlugin(HttpProxyBasePlugin): MODIFIED_BODY = b'{"key": "modified"}' def before_upstream_connection( - self, request: HttpParser) -> Optional[HttpParser]: + self, request: HttpParser, + ) -> Optional[HttpParser]: return request def handle_client_request( - self, request: HttpParser) -> Optional[HttpParser]: + self, request: HttpParser, + ) -> Optional[HttpParser]: if request.method == httpMethods.POST: request.body = ModifyPostDataPlugin.MODIFIED_BODY # Update Content-Length header only when request is NOT chunked # encoded if not request.is_chunked_encoded(): - request.add_header(b'Content-Length', - bytes_(len(request.body))) + request.add_header( + b'Content-Length', + bytes_(len(request.body)), + ) # Enforce content-type json if request.has_header(b'Content-Type'): request.del_header(b'Content-Type') diff --git a/proxy/plugin/proxy_pool.py b/proxy/plugin/proxy_pool.py index c0a969e018..17877487a5 100644 --- a/proxy/plugin/proxy_pool.py +++ b/proxy/plugin/proxy_pool.py @@ -54,7 +54,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self.upstream: Optional[TcpServerConnection] = None # Cached attributes to be used during access log override self.request_host_port_path_method: List[Any] = [ - None, None, None, None] + None, None, None, None, + ] self.total_size = 0 def get_descriptors(self) -> Tuple[List[socket.socket], List[socket.socket]]: @@ -88,7 +89,8 @@ def write_to_descriptors(self, w: Writables) -> bool: return False def before_upstream_connection( - self, request: HttpParser) -> Optional[HttpParser]: + self, request: HttpParser, + ) -> Optional[HttpParser]: """Avoids establishing the default connection to upstream server by returning None. """ @@ -100,7 +102,8 @@ def before_upstream_connection( endpoint = random.choice(self.UPSTREAM_PROXY_POOL) logger.debug('Using endpoint: {0}:{1}'.format(*endpoint)) self.upstream = TcpServerConnection( - endpoint[0], endpoint[1]) + endpoint[0], endpoint[1], + ) try: self.upstream.connect() except ConnectionRefusedError: @@ -112,14 +115,17 @@ def before_upstream_connection( # using a datastructure without having to spawn separate thread/process for health # check. logger.info( - 'Connection refused by upstream proxy {0}:{1}'.format(*endpoint)) + 'Connection refused by upstream proxy {0}:{1}'.format(*endpoint), + ) raise HttpProtocolException() logger.debug( - 'Established connection to upstream proxy {0}:{1}'.format(*endpoint)) + 'Established connection to upstream proxy {0}:{1}'.format(*endpoint), + ) return None def handle_client_request( - self, request: HttpParser) -> Optional[HttpParser]: + self, request: HttpParser, + ) -> Optional[HttpParser]: """Only invoked once after client original proxy request has been received completely.""" assert self.upstream # For log sanity (i.e. to avoid None:None), expose upstream host:port from headers @@ -137,7 +143,8 @@ def handle_client_request( port = '443' if request.method == httpMethods.CONNECT else '80' path = None if not request.path else request.path.decode() self.request_host_port_path_method = [ - host, port, path, request.method] + host, port, path, request.method, + ] # Queue original request to upstream proxy self.upstream.queue(memoryview(request.build(for_proxy=True))) return request @@ -158,7 +165,8 @@ def on_upstream_connection_close(self) -> None: def on_access_log(self, context: Dict[str, Any]) -> Optional[Dict[str, Any]]: addr, port = ( - self.upstream.addr[0], self.upstream.addr[1]) if self.upstream else (None, None) + self.upstream.addr[0], self.upstream.addr[1], + ) if self.upstream else (None, None) context.update({ 'upstream_proxy_host': addr, 'upstream_proxy_port': port, diff --git a/proxy/plugin/redirect_to_custom_server.py b/proxy/plugin/redirect_to_custom_server.py index d2118e5fea..75d38359eb 100644 --- a/proxy/plugin/redirect_to_custom_server.py +++ b/proxy/plugin/redirect_to_custom_server.py @@ -22,7 +22,8 @@ class RedirectToCustomServerPlugin(HttpProxyBasePlugin): UPSTREAM_SERVER = b'http://localhost:8899/' def before_upstream_connection( - self, request: HttpParser) -> Optional[HttpParser]: + self, request: HttpParser, + ) -> Optional[HttpParser]: # Redirect all non-https requests to inbuilt WebServer. if request.method != httpMethods.CONNECT: request.set_url(self.UPSTREAM_SERVER) @@ -31,11 +32,14 @@ def before_upstream_connection( request.del_header(b'Host') request.add_header( b'Host', urlparse.urlsplit( - self.UPSTREAM_SERVER).netloc) + self.UPSTREAM_SERVER, + ).netloc, + ) return request def handle_client_request( - self, request: HttpParser) -> Optional[HttpParser]: + self, request: HttpParser, + ) -> Optional[HttpParser]: return request def handle_upstream_chunk(self, chunk: memoryview) -> memoryview: diff --git a/proxy/plugin/reverse_proxy.py b/proxy/plugin/reverse_proxy.py index 1a80d1a474..3095282d9c 100644 --- a/proxy/plugin/reverse_proxy.py +++ b/proxy/plugin/reverse_proxy.py @@ -45,13 +45,13 @@ class ReverseProxyPlugin(HttpWebServerBasePlugin): REVERSE_PROXY_LOCATION: str = r'/get$' REVERSE_PROXY_PASS = [ - b'http://httpbin.org/get' + b'http://httpbin.org/get', ] def routes(self) -> List[Tuple[int, str]]: return [ (httpProtocolTypes.HTTP, ReverseProxyPlugin.REVERSE_PROXY_LOCATION), - (httpProtocolTypes.HTTPS, ReverseProxyPlugin.REVERSE_PROXY_LOCATION) + (httpProtocolTypes.HTTPS, ReverseProxyPlugin.REVERSE_PROXY_LOCATION), ] # TODO(abhinavsingh): Upgrade to use non-blocking get/read/write API. diff --git a/proxy/plugin/shortlink.py b/proxy/plugin/shortlink.py index 309fc1fbc2..43554dcaaa 100644 --- a/proxy/plugin/shortlink.py +++ b/proxy/plugin/shortlink.py @@ -47,33 +47,43 @@ class ShortLinkPlugin(HttpProxyBasePlugin): } def before_upstream_connection( - self, request: HttpParser) -> Optional[HttpParser]: + self, request: HttpParser, + ) -> Optional[HttpParser]: if request.host and request.host != b'localhost' and DOT not in request.host: # Avoid connecting to upstream return None return request def handle_client_request( - self, request: HttpParser) -> Optional[HttpParser]: + self, request: HttpParser, + ) -> Optional[HttpParser]: if request.host and request.host != b'localhost' and DOT not in request.host: if request.host in self.SHORT_LINKS: path = SLASH if not request.path else request.path - self.client.queue(memoryview(build_http_response( - httpStatusCodes.SEE_OTHER, reason=b'See Other', - headers={ - b'Location': b'http://' + self.SHORT_LINKS[request.host] + path, - b'Content-Length': b'0', - b'Connection': b'close', - } - ))) + self.client.queue( + memoryview( + build_http_response( + httpStatusCodes.SEE_OTHER, reason=b'See Other', + headers={ + b'Location': b'http://' + self.SHORT_LINKS[request.host] + path, + b'Content-Length': b'0', + b'Connection': b'close', + }, + ), + ), + ) else: - self.client.queue(memoryview(build_http_response( - httpStatusCodes.NOT_FOUND, reason=b'NOT FOUND', - headers={ - b'Content-Length': b'0', - b'Connection': b'close', - } - ))) + self.client.queue( + memoryview( + build_http_response( + httpStatusCodes.NOT_FOUND, reason=b'NOT FOUND', + headers={ + b'Content-Length': b'0', + b'Connection': b'close', + }, + ), + ), + ) return None return request diff --git a/proxy/plugin/web_server_route.py b/proxy/plugin/web_server_route.py index c8b4731a44..00163c253f 100644 --- a/proxy/plugin/web_server_route.py +++ b/proxy/plugin/web_server_route.py @@ -32,11 +32,21 @@ def routes(self) -> List[Tuple[int, str]]: def handle_request(self, request: HttpParser) -> None: if request.path == b'/http-route-example': - self.client.queue(memoryview(build_http_response( - httpStatusCodes.OK, body=b'HTTP route response'))) + self.client.queue( + memoryview( + build_http_response( + httpStatusCodes.OK, body=b'HTTP route response', + ), + ), + ) elif request.path == b'/https-route-example': - self.client.queue(memoryview(build_http_response( - httpStatusCodes.OK, body=b'HTTPS route response'))) + self.client.queue( + memoryview( + build_http_response( + httpStatusCodes.OK, body=b'HTTPS route response', + ), + ), + ) def on_websocket_open(self) -> None: logger.info('Websocket open') diff --git a/proxy/proxy.py b/proxy/proxy.py index 950bdf793f..5b18d34b03 100644 --- a/proxy/proxy.py +++ b/proxy/proxy.py @@ -56,29 +56,32 @@ '--pid-file', type=str, default=DEFAULT_PID_FILE, - help='Default: None. Save parent process ID to a file.') + help='Default: None. Save parent process ID to a file.', +) flags.add_argument( '--version', '-v', action='store_true', default=DEFAULT_VERSION, - help='Prints proxy.py version.') + help='Prints proxy.py version.', +) flags.add_argument( '--disable-http-proxy', action='store_true', default=DEFAULT_DISABLE_HTTP_PROXY, - help='Default: False. Whether to disable proxy.HttpProxyPlugin.') + help='Default: False. Whether to disable proxy.HttpProxyPlugin.', +) flags.add_argument( '--enable-dashboard', action='store_true', default=DEFAULT_ENABLE_DASHBOARD, - help='Default: False. Enables proxy.py dashboard.' + help='Default: False. Enables proxy.py dashboard.', ) flags.add_argument( '--enable-devtools', action='store_true', default=DEFAULT_ENABLE_DEVTOOLS, - help='Default: False. Enables integration with Chrome Devtool Frontend. Also see --devtools-ws-path.' + help='Default: False. Enables integration with Chrome Devtool Frontend. Also see --devtools-ws-path.', ) flags.add_argument( '--enable-static-server', @@ -87,19 +90,20 @@ help='Default: False. Enable inbuilt static file server. ' 'Optionally, also use --static-server-dir to serve static content ' 'from custom directory. By default, static file server serves ' - 'out of installed proxy.py python module folder.' + 'out of installed proxy.py python module folder.', ) flags.add_argument( '--enable-web-server', action='store_true', default=DEFAULT_ENABLE_WEB_SERVER, - help='Default: False. Whether to enable proxy.HttpWebServerPlugin.') + help='Default: False. Whether to enable proxy.HttpWebServerPlugin.', +) flags.add_argument( '--enable-events', action='store_true', default=DEFAULT_ENABLE_EVENTS, help='Default: False. Enables core to dispatch lifecycle events. ' - 'Plugins can be used to subscribe for core events.' + 'Plugins can be used to subscribe for core events.', ) flags.add_argument( '--log-level', @@ -107,28 +111,33 @@ default=DEFAULT_LOG_LEVEL, help='Valid options: DEBUG, INFO (default), WARNING, ERROR, CRITICAL. ' 'Both upper and lowercase values are allowed. ' - 'You may also simply use the leading character e.g. --log-level d') + 'You may also simply use the leading character e.g. --log-level d', +) flags.add_argument( '--log-file', type=str, default=DEFAULT_LOG_FILE, - help='Default: sys.stdout. Log file destination.') + help='Default: sys.stdout. Log file destination.', +) flags.add_argument( '--log-format', type=str, default=DEFAULT_LOG_FORMAT, - help='Log format for Python logger.') + help='Log format for Python logger.', +) flags.add_argument( '--open-file-limit', type=int, default=DEFAULT_OPEN_FILE_LIMIT, help='Default: 1024. Maximum number of files (TCP connections) ' - 'that proxy.py can open concurrently.') + 'that proxy.py can open concurrently.', +) flags.add_argument( '--plugins', type=str, default=DEFAULT_PLUGINS, - help='Comma separated plugins') + help='Comma separated plugins', +) class Proxy: @@ -170,7 +179,7 @@ def __enter__(self) -> 'Proxy': self.pool = AcceptorPool( flags=self.flags, work_klass=self.work_klass, - event_queue=self.event_manager.event_queue if self.event_manager is not None else None + event_queue=self.event_manager.event_queue if self.event_manager is not None else None, ) self.pool.setup() self.write_pid_file() @@ -180,7 +189,8 @@ def __exit__( self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType]) -> None: + exc_tb: Optional[TracebackType], + ) -> None: assert self.pool self.pool.shutdown() if self.flags.enable_events: @@ -189,8 +199,10 @@ def __exit__( self.delete_pid_file() @staticmethod - def initialize(input_args: Optional[List[str]] - = None, **opts: Any) -> argparse.Namespace: + def initialize( + input_args: Optional[List[str]] + = None, **opts: Any, + ) -> argparse.Namespace: if input_args is None: input_args = [] @@ -224,8 +236,10 @@ def initialize(input_args: Optional[List[str]] # Load default plugins along with user provided --plugins plugins = Proxy.load_plugins( [bytes_(p) for p in collections.OrderedDict(default_plugins).keys()] + - [p if isinstance(p, type) else bytes_(p) for p in opts.get( - 'plugins', args.plugins.split(text_(COMMA)))] + [ + p if isinstance(p, type) else bytes_(p) + for p in opts.get('plugins', args.plugins.split(text_(COMMA))) + ], ) # proxy.py currently cannot serve over HTTPS and also perform TLS interception @@ -233,8 +247,10 @@ def initialize(input_args: Optional[List[str]] # at the same time. if (args.cert_file and args.key_file) and \ (args.ca_key_file and args.ca_cert_file and args.ca_signing_key_file): - print('You can either enable end-to-end encryption OR TLS interception,' - 'not both together.') + print( + 'You can either enable end-to-end encryption OR TLS interception,' + 'not both together.', + ) sys.exit(1) # Generate auth_code required for basic authentication if enabled @@ -252,83 +268,124 @@ def initialize(input_args: Optional[List[str]] Optional[bytes], opts.get( 'auth_code', - auth_code)) + auth_code, + ), + ) args.server_recvbuf_size = cast( int, opts.get( 'server_recvbuf_size', - args.server_recvbuf_size)) + args.server_recvbuf_size, + ), + ) args.client_recvbuf_size = cast( int, opts.get( 'client_recvbuf_size', - args.client_recvbuf_size)) + args.client_recvbuf_size, + ), + ) args.pac_file = cast( Optional[str], opts.get( 'pac_file', bytes_( - args.pac_file))) + args.pac_file, + ), + ), + ) args.pac_file_url_path = cast( Optional[bytes], opts.get( 'pac_file_url_path', bytes_( - args.pac_file_url_path))) - disabled_headers = cast(Optional[List[bytes]], opts.get('disable_headers', [ - header.lower() for header in bytes_( - args.disable_headers).split(COMMA) if header.strip() != b''])) + args.pac_file_url_path, + ), + ), + ) + disabled_headers = cast( + Optional[List[bytes]], opts.get( + 'disable_headers', [ + header.lower() + for header in bytes_(args.disable_headers).split(COMMA) + if header.strip() != b'' + ], + ), + ) args.disable_headers = disabled_headers if disabled_headers is not None else DEFAULT_DISABLE_HEADERS args.certfile = cast( Optional[str], opts.get( - 'cert_file', args.cert_file)) + 'cert_file', args.cert_file, + ), + ) args.keyfile = cast(Optional[str], opts.get('key_file', args.key_file)) args.ca_key_file = cast( Optional[str], opts.get( - 'ca_key_file', args.ca_key_file)) + 'ca_key_file', args.ca_key_file, + ), + ) args.ca_cert_file = cast( Optional[str], opts.get( - 'ca_cert_file', args.ca_cert_file)) + 'ca_cert_file', args.ca_cert_file, + ), + ) args.ca_signing_key_file = cast( Optional[str], opts.get( 'ca_signing_key_file', - args.ca_signing_key_file)) + args.ca_signing_key_file, + ), + ) args.ca_file = cast( Optional[str], opts.get( 'ca_file', - args.ca_file)) - args.hostname = cast(IpAddress, - opts.get('hostname', ipaddress.ip_address(args.hostname))) + args.ca_file, + ), + ) + args.hostname = cast( + IpAddress, + opts.get('hostname', ipaddress.ip_address(args.hostname)), + ) args.family = socket.AF_INET6 if args.hostname.version == 6 else socket.AF_INET args.port = cast(int, opts.get('port', args.port)) args.backlog = cast(int, opts.get('backlog', args.backlog)) num_workers = opts.get('num_workers', args.num_workers) num_workers = num_workers if num_workers is not None else DEFAULT_NUM_WORKERS args.num_workers = cast( - int, num_workers if num_workers > 0 else multiprocessing.cpu_count()) + int, num_workers if num_workers > 0 else multiprocessing.cpu_count(), + ) args.static_server_dir = cast( str, opts.get( 'static_server_dir', - args.static_server_dir)) + args.static_server_dir, + ), + ) args.enable_static_server = cast( bool, opts.get( 'enable_static_server', - args.enable_static_server)) + args.enable_static_server, + ), + ) args.devtools_ws_path = cast( bytes, opts.get( 'devtools_ws_path', - getattr(args, 'devtools_ws_path', DEFAULT_DEVTOOLS_WS_PATH))) + getattr(args, 'devtools_ws_path', DEFAULT_DEVTOOLS_WS_PATH), + ), + ) args.timeout = cast(int, opts.get('timeout', args.timeout)) args.threadless = cast(bool, opts.get('threadless', args.threadless)) args.enable_events = cast( bool, opts.get( 'enable_events', - args.enable_events)) + args.enable_events, + ), + ) args.pid_file = cast( Optional[str], opts.get( - 'pid_file', args.pid_file)) + 'pid_file', args.pid_file, + ), + ) args.proxy_py_data_dir = DEFAULT_DATA_DIRECTORY_PATH os.makedirs(args.proxy_py_data_dir, exist_ok=True) @@ -337,21 +394,23 @@ def initialize(input_args: Optional[List[str]] args.ca_cert_dir = cast(Optional[str], ca_cert_dir) if args.ca_cert_dir is None: args.ca_cert_dir = os.path.join( - args.proxy_py_data_dir, 'certificates') + args.proxy_py_data_dir, 'certificates', + ) os.makedirs(args.ca_cert_dir, exist_ok=True) return args @staticmethod - def load_plugins(plugins: List[Union[bytes, type]] - ) -> Dict[bytes, List[type]]: + def load_plugins( + plugins: List[Union[bytes, type]], + ) -> Dict[bytes, List[type]]: """Accepts a comma separated list of Python modules and returns a list of respective Python classes.""" p: Dict[bytes, List[type]] = { b'HttpProtocolHandlerPlugin': [], b'HttpProxyBasePlugin': [], b'HttpWebServerBasePlugin': [], - b'ProxyDashboardWebsocketPlugin': [] + b'ProxyDashboardWebsocketPlugin': [], } for plugin_ in plugins: klass, module_name = Proxy.import_plugin(plugin_) @@ -381,13 +440,17 @@ def import_plugin(plugin: Union[bytes, type]) -> Any: klass = getattr( importlib.import_module( module_name.replace( - os.path.sep, text_(DOT))), - klass_name) + os.path.sep, text_(DOT), + ), + ), + klass_name, + ) return (klass, module_name) @staticmethod def get_default_plugins( - args: argparse.Namespace) -> List[Tuple[str, bool]]: + args: argparse.Namespace, + ) -> List[Tuple[str, bool]]: # Prepare list of plugins to load based upon # --enable-*, --disable-* and --basic-auth flags. default_plugins: List[Tuple[str, bool]] = [] @@ -423,18 +486,22 @@ def set_open_file_limit(soft_limit: int) -> None: """Configure open file description soft limit on supported OS.""" if os.name != 'nt': # resource module not available on Windows OS curr_soft_limit, curr_hard_limit = resource.getrlimit( - resource.RLIMIT_NOFILE) + resource.RLIMIT_NOFILE, + ) if curr_soft_limit < soft_limit < curr_hard_limit: resource.setrlimit( - resource.RLIMIT_NOFILE, (soft_limit, curr_hard_limit)) + resource.RLIMIT_NOFILE, (soft_limit, curr_hard_limit), + ) logger.debug( - 'Open file soft limit set to %d', soft_limit) + 'Open file soft limit set to %d', soft_limit, + ) @contextlib.contextmanager def start( input_args: Optional[List[str]] = None, - **opts: Any) -> Generator[Proxy, None, None]: + **opts: Any, +) -> Generator[Proxy, None, None]: """Deprecated. Kept for backward compatibility. New users must directly use proxy.Proxy context manager class.""" @@ -447,12 +514,15 @@ def start( def main( input_args: Optional[List[str]] = None, - **opts: Any) -> None: + **opts: Any, +) -> None: try: with Proxy(input_args=input_args, **opts) as proxy: assert proxy.pool is not None - logger.info('Listening on %s:%d' % - (proxy.pool.flags.hostname, proxy.pool.flags.port)) + logger.info( + 'Listening on %s:%d' % + (proxy.pool.flags.hostname, proxy.pool.flags.port), + ) # TODO: Introduce cron feature # https://github.com/abhinavsingh/proxy.py/issues/392 # diff --git a/proxy/testing/test_case.py b/proxy/testing/test_case.py index 55f4b952fd..573ffa4697 100644 --- a/proxy/testing/test_case.py +++ b/proxy/testing/test_case.py @@ -45,20 +45,24 @@ def setUpClass(cls) -> None: cls.PROXY = Proxy(input_args=cls.INPUT_ARGS) cls.PROXY.flags.plugins[b'HttpProxyBasePlugin'].append( - CacheResponsesPlugin) + CacheResponsesPlugin, + ) cls.PROXY.__enter__() cls.wait_for_server(cls.PROXY_PORT) @staticmethod - def wait_for_server(proxy_port: int, - wait_for_seconds: int = DEFAULT_TIMEOUT) -> None: + def wait_for_server( + proxy_port: int, + wait_for_seconds: int = DEFAULT_TIMEOUT, + ) -> None: """Wait for proxy.py server to come up.""" start_time = time.time() while True: try: conn = new_socket_connection( - ('localhost', proxy_port)) + ('localhost', proxy_port), + ) conn.close() break except ConnectionRefusedError: @@ -66,7 +70,8 @@ def wait_for_server(proxy_port: int, if time.time() - start_time > wait_for_seconds: raise TimeoutError( - 'Timed out while waiting for proxy.py to start...') + 'Timed out while waiting for proxy.py to start...', + ) @classmethod def tearDownClass(cls) -> None: diff --git a/tests/common/test_flags.py b/tests/common/test_flags.py index c36f8e6355..a11c9ecee3 100644 --- a/tests/common/test_flags.py +++ b/tests/common/test_flags.py @@ -27,23 +27,30 @@ def assert_plugins(self, expected: Dict[str, List[type]]) -> None: for p in expected[k]: self.assertIn(p, self.flags.plugins[k.encode()]) self.assertEqual( - len([o for o in self.flags.plugins[k.encode()] if o == p]), 1) + len([o for o in self.flags.plugins[k.encode()] if o == p]), 1, + ) def test_load_plugin_from_bytes(self) -> None: - self.flags = Proxy.initialize([], plugins=[ - b'proxy.plugin.CacheResponsesPlugin', - ]) + self.flags = Proxy.initialize( + [], plugins=[ + b'proxy.plugin.CacheResponsesPlugin', + ], + ) self.assert_plugins({'HttpProxyBasePlugin': [CacheResponsesPlugin]}) def test_load_plugins_from_bytes(self) -> None: - self.flags = Proxy.initialize([], plugins=[ - b'proxy.plugin.CacheResponsesPlugin', - b'proxy.plugin.FilterByUpstreamHostPlugin', - ]) - self.assert_plugins({'HttpProxyBasePlugin': [ - CacheResponsesPlugin, - FilterByUpstreamHostPlugin, - ]}) + self.flags = Proxy.initialize( + [], plugins=[ + b'proxy.plugin.CacheResponsesPlugin', + b'proxy.plugin.FilterByUpstreamHostPlugin', + ], + ) + self.assert_plugins({ + 'HttpProxyBasePlugin': [ + CacheResponsesPlugin, + FilterByUpstreamHostPlugin, + ], + }) def test_load_plugin_from_args(self) -> None: self.flags = Proxy.initialize([ @@ -55,60 +62,82 @@ def test_load_plugins_from_args(self) -> None: self.flags = Proxy.initialize([ '--plugins', 'proxy.plugin.CacheResponsesPlugin,proxy.plugin.FilterByUpstreamHostPlugin', ]) - self.assert_plugins({'HttpProxyBasePlugin': [ - CacheResponsesPlugin, - FilterByUpstreamHostPlugin, - ]}) + self.assert_plugins({ + 'HttpProxyBasePlugin': [ + CacheResponsesPlugin, + FilterByUpstreamHostPlugin, + ], + }) def test_load_plugin_from_class(self) -> None: - self.flags = Proxy.initialize([], plugins=[ - CacheResponsesPlugin, - ]) + self.flags = Proxy.initialize( + [], plugins=[ + CacheResponsesPlugin, + ], + ) self.assert_plugins({'HttpProxyBasePlugin': [CacheResponsesPlugin]}) def test_load_plugins_from_class(self) -> None: - self.flags = Proxy.initialize([], plugins=[ - CacheResponsesPlugin, - FilterByUpstreamHostPlugin, - ]) - self.assert_plugins({'HttpProxyBasePlugin': [ - CacheResponsesPlugin, - FilterByUpstreamHostPlugin, - ]}) + self.flags = Proxy.initialize( + [], plugins=[ + CacheResponsesPlugin, + FilterByUpstreamHostPlugin, + ], + ) + self.assert_plugins({ + 'HttpProxyBasePlugin': [ + CacheResponsesPlugin, + FilterByUpstreamHostPlugin, + ], + }) def test_load_plugins_from_bytes_and_class(self) -> None: - self.flags = Proxy.initialize([], plugins=[ - CacheResponsesPlugin, - b'proxy.plugin.FilterByUpstreamHostPlugin', - ]) - self.assert_plugins({'HttpProxyBasePlugin': [ - CacheResponsesPlugin, - FilterByUpstreamHostPlugin, - ]}) + self.flags = Proxy.initialize( + [], plugins=[ + CacheResponsesPlugin, + b'proxy.plugin.FilterByUpstreamHostPlugin', + ], + ) + self.assert_plugins({ + 'HttpProxyBasePlugin': [ + CacheResponsesPlugin, + FilterByUpstreamHostPlugin, + ], + }) def test_unique_plugin_from_bytes(self) -> None: - self.flags = Proxy.initialize([], plugins=[ - bytes_(PLUGIN_HTTP_PROXY), - ]) - self.assert_plugins({'HttpProtocolHandlerPlugin': [ - HttpProxyPlugin, - ]}) + self.flags = Proxy.initialize( + [], plugins=[ + bytes_(PLUGIN_HTTP_PROXY), + ], + ) + self.assert_plugins({ + 'HttpProtocolHandlerPlugin': [ + HttpProxyPlugin, + ], + }) def test_unique_plugin_from_args(self) -> None: self.flags = Proxy.initialize([ '--plugins', PLUGIN_HTTP_PROXY, ]) - self.assert_plugins({'HttpProtocolHandlerPlugin': [ - HttpProxyPlugin, - ]}) + self.assert_plugins({ + 'HttpProtocolHandlerPlugin': [ + HttpProxyPlugin, + ], + }) def test_unique_plugin_from_class(self) -> None: - self.flags = Proxy.initialize([], plugins=[ - HttpProxyPlugin, - ]) - self.assert_plugins({'HttpProtocolHandlerPlugin': [ - HttpProxyPlugin, - ]}) + self.flags = Proxy.initialize( + [], plugins=[ + HttpProxyPlugin, + ], + ) + self.assert_plugins({ + 'HttpProtocolHandlerPlugin': [ + HttpProxyPlugin, + ], + }) if __name__ == '__main__': diff --git a/tests/common/test_pki.py b/tests/common/test_pki.py index e55c063795..76dd723837 100644 --- a/tests/common/test_pki.py +++ b/tests/common/test_pki.py @@ -28,7 +28,8 @@ def test_run_openssl_command(self, mock_popen: mock.Mock) -> None: mock_popen.assert_called_with( command, stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + stderr=subprocess.PIPE, + ) def test_get_ext_config(self) -> None: self.assertEqual(pki.get_ext_config(None, None), b'') @@ -36,17 +37,25 @@ def test_get_ext_config(self) -> None: self.assertEqual( pki.get_ext_config( ['proxy.py'], - None), - b'\nsubjectAltName=DNS:proxy.py') + None, + ), + b'\nsubjectAltName=DNS:proxy.py', + ) self.assertEqual( pki.get_ext_config( None, - 'serverAuth'), - b'\nextendedKeyUsage=serverAuth') - self.assertEqual(pki.get_ext_config(['proxy.py'], 'serverAuth'), - b'\nsubjectAltName=DNS:proxy.py\nextendedKeyUsage=serverAuth') - self.assertEqual(pki.get_ext_config(['proxy.py', 'www.proxy.py'], 'serverAuth'), - b'\nsubjectAltName=DNS:proxy.py,DNS:www.proxy.py\nextendedKeyUsage=serverAuth') + 'serverAuth', + ), + b'\nextendedKeyUsage=serverAuth', + ) + self.assertEqual( + pki.get_ext_config(['proxy.py'], 'serverAuth'), + b'\nsubjectAltName=DNS:proxy.py\nextendedKeyUsage=serverAuth', + ) + self.assertEqual( + pki.get_ext_config(['proxy.py', 'www.proxy.py'], 'serverAuth'), + b'\nsubjectAltName=DNS:proxy.py,DNS:www.proxy.py\nextendedKeyUsage=serverAuth', + ) def test_ssl_config_no_ext(self) -> None: with pki.ssl_config() as (config_path, has_extension): @@ -61,7 +70,8 @@ def test_ssl_config(self) -> None: self.assertEqual( config.read(), pki.DEFAULT_CONFIG + - b'\n[PROXY]\nsubjectAltName=DNS:proxy.py') + b'\n[PROXY]\nsubjectAltName=DNS:proxy.py', + ) def test_extfile_no_ext(self) -> None: with pki.ext_file() as config_path: @@ -73,7 +83,8 @@ def test_extfile(self) -> None: with open(config_path, 'rb') as config: self.assertEqual( config.read(), - b'\nsubjectAltName=DNS:proxy.py') + b'\nsubjectAltName=DNS:proxy.py', + ) def test_gen_private_key(self) -> None: key_path, nopass_key_path = self._gen_private_key() @@ -114,7 +125,8 @@ def _gen_private_key(self) -> Tuple[str, str]: key_path = os.path.join(tempfile.gettempdir(), 'test_gen_private.key') nopass_key_path = os.path.join( tempfile.gettempdir(), - 'test_gen_private_nopass.key') + 'test_gen_private_nopass.key', + ) pki.gen_private_key(key_path, 'password') pki.remove_passphrase(key_path, 'password', nopass_key_path) return (key_path, nopass_key_path) diff --git a/tests/common/test_utils.py b/tests/common/test_utils.py index b75b2b569c..655ec74ef5 100644 --- a/tests/common/test_utils.py +++ b/tests/common/test_utils.py @@ -37,7 +37,8 @@ def test_new_socket_connection_ipv6(self, mock_socket: mock.Mock) -> None: mock_socket.assert_called_with(socket.AF_INET6, socket.SOCK_STREAM, 0) self.assertEqual(conn, mock_socket.return_value) mock_socket.return_value.connect.assert_called_with( - (self.addr_ipv6[0], self.addr_ipv6[1], 0, 0)) + (self.addr_ipv6[0], self.addr_ipv6[1], 0, 0), + ) @mock.patch('socket.create_connection') def test_new_socket_connection_dual(self, mock_socket: mock.Mock) -> None: @@ -54,6 +55,7 @@ def dummy(conn: socket.socket) -> None: @mock.patch('proxy.common.utils.new_socket_connection') def test_context_manager( - self, mock_new_socket_connection: mock.Mock) -> None: + self, mock_new_socket_connection: mock.Mock, + ) -> None: with socket_connection(self.addr_ipv4) as conn: self.assertEqual(conn, mock_new_socket_connection.return_value) diff --git a/tests/core/test_acceptor.py b/tests/core/test_acceptor.py index aae3de5caf..ceb863d2ba 100644 --- a/tests/core/test_acceptor.py +++ b/tests/core/test_acceptor.py @@ -30,7 +30,8 @@ def setUp(self) -> None: work_queue=self.pipe[1], flags=self.flags, lock=multiprocessing.Lock(), - work_klass=self.mock_protocol_handler) + work_klass=self.mock_protocol_handler, + ) @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') @@ -39,7 +40,8 @@ def test_continues_when_no_events( self, mock_recv_handle: mock.Mock, mock_fromfd: mock.Mock, - mock_selector: mock.Mock) -> None: + mock_selector: mock.Mock, + ) -> None: fileno = 10 conn = mock.MagicMock() addr = mock.MagicMock() @@ -66,7 +68,8 @@ def test_accepts_client_from_server_socket( mock_fromfd: mock.Mock, mock_selector: mock.Mock, mock_thread: mock.Mock, - mock_client: mock.Mock) -> None: + mock_client: mock.Mock, + ) -> None: fileno = 10 conn = mock.MagicMock() addr = mock.MagicMock() @@ -87,7 +90,7 @@ def test_accepts_client_from_server_socket( mock_fromfd.assert_called_with( fileno, family=socket.AF_INET6, - type=socket.SOCK_STREAM + type=socket.SOCK_STREAM, ) self.mock_protocol_handler.assert_called_with( mock_client.return_value, @@ -95,6 +98,7 @@ def test_accepts_client_from_server_socket( event_queue=None, ) mock_thread.assert_called_with( - target=self.mock_protocol_handler.return_value.run) + target=self.mock_protocol_handler.return_value.run, + ) mock_thread.return_value.start.assert_called() sock.close.assert_called() diff --git a/tests/core/test_acceptor_pool.py b/tests/core/test_acceptor_pool.py index 6b007ea9fc..c7f46bf4ac 100644 --- a/tests/core/test_acceptor_pool.py +++ b/tests/core/test_acceptor_pool.py @@ -27,7 +27,8 @@ def test_setup_and_shutdown( mock_acceptor: mock.Mock, mock_socket: mock.Mock, mock_pipe: mock.Mock, - mock_send_handle: mock.Mock) -> None: + mock_send_handle: mock.Mock, + ) -> None: acceptor1 = mock.MagicMock() acceptor2 = mock.MagicMock() mock_acceptor.side_effect = [acceptor1, acceptor2] @@ -44,12 +45,14 @@ def test_setup_and_shutdown( work_klass.assert_not_called() mock_socket.assert_called_with( socket.AF_INET6 if pool.flags.hostname.version == 6 else socket.AF_INET, - socket.SOCK_STREAM + socket.SOCK_STREAM, ) sock.setsockopt.assert_called_with( - socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + socket.SOL_SOCKET, socket.SO_REUSEADDR, 1, + ) sock.bind.assert_called_with( - (str(pool.flags.hostname), 8899)) + (str(pool.flags.hostname), 8899), + ) sock.listen.assert_called_with(pool.flags.backlog) sock.setblocking.assert_called_with(False) diff --git a/tests/core/test_connection.py b/tests/core/test_connection.py index 3cd63ad129..905ab56d2b 100644 --- a/tests/core/test_connection.py +++ b/tests/core/test_connection.py @@ -22,8 +22,10 @@ class TestTcpConnection(unittest.TestCase): class TcpConnectionToTest(TcpConnection): - def __init__(self, conn: Optional[Union[ssl.SSLSocket, socket.socket]] = None, - tag: int = tcpConnectionTypes.CLIENT) -> None: + def __init__( + self, conn: Optional[Union[ssl.SSLSocket, socket.socket]] = None, + tag: int = tcpConnectionTypes.CLIENT, + ) -> None: super().__init__(tag) self._conn = conn @@ -65,48 +67,60 @@ def testFlushReturnsIfNoBuffer(self) -> None: @mock.patch('socket.socket') def testTcpServerEstablishesIPv6Connection( - self, mock_socket: mock.Mock) -> None: + self, mock_socket: mock.Mock, + ) -> None: conn = TcpServerConnection( - str(DEFAULT_IPV6_HOSTNAME), DEFAULT_PORT) + str(DEFAULT_IPV6_HOSTNAME), DEFAULT_PORT, + ) conn.connect() mock_socket.assert_called() mock_socket.return_value.connect.assert_called_with( - (str(DEFAULT_IPV6_HOSTNAME), DEFAULT_PORT, 0, 0)) + (str(DEFAULT_IPV6_HOSTNAME), DEFAULT_PORT, 0, 0), + ) @mock.patch('proxy.core.connection.server.new_socket_connection') def testTcpServerIgnoresDoubleConnectSilently( self, - mock_new_socket_connection: mock.Mock) -> None: + mock_new_socket_connection: mock.Mock, + ) -> None: conn = TcpServerConnection( - str(DEFAULT_IPV6_HOSTNAME), DEFAULT_PORT) + str(DEFAULT_IPV6_HOSTNAME), DEFAULT_PORT, + ) conn.connect() conn.connect() mock_new_socket_connection.assert_called_once() @mock.patch('socket.socket') def testTcpServerEstablishesIPv4Connection( - self, mock_socket: mock.Mock) -> None: + self, mock_socket: mock.Mock, + ) -> None: conn = TcpServerConnection( - str(DEFAULT_IPV4_HOSTNAME), DEFAULT_PORT) + str(DEFAULT_IPV4_HOSTNAME), DEFAULT_PORT, + ) conn.connect() mock_socket.assert_called() mock_socket.return_value.connect.assert_called_with( - (str(DEFAULT_IPV4_HOSTNAME), DEFAULT_PORT)) + (str(DEFAULT_IPV4_HOSTNAME), DEFAULT_PORT), + ) @mock.patch('proxy.core.connection.server.new_socket_connection') def testTcpServerConnectionProperty( self, - mock_new_socket_connection: mock.Mock) -> None: + mock_new_socket_connection: mock.Mock, + ) -> None: conn = TcpServerConnection( - str(DEFAULT_IPV6_HOSTNAME), DEFAULT_PORT) + str(DEFAULT_IPV6_HOSTNAME), DEFAULT_PORT, + ) conn.connect() self.assertEqual( conn.connection, - mock_new_socket_connection.return_value) + mock_new_socket_connection.return_value, + ) def testTcpServerRaisesTcpConnectionUninitializedException(self) -> None: conn = TcpServerConnection( - str(DEFAULT_IPV6_HOSTNAME), DEFAULT_PORT) + str(DEFAULT_IPV6_HOSTNAME), DEFAULT_PORT, + ) with self.assertRaises(TcpConnectionUninitializedException): _ = conn.connection diff --git a/tests/core/test_event_dispatcher.py b/tests/core/test_event_dispatcher.py index bb17a709ba..eb890b3b0c 100644 --- a/tests/core/test_event_dispatcher.py +++ b/tests/core/test_event_dispatcher.py @@ -27,7 +27,8 @@ def setUp(self) -> None: self.event_queue = EventQueue(multiprocessing.Manager().Queue()) self.dispatcher = EventDispatcher( shutdown=self.dispatcher_shutdown, - event_queue=self.event_queue) + event_queue=self.event_queue, + ) def tearDown(self) -> None: self.dispatcher_shutdown.set() @@ -37,7 +38,7 @@ def test_empties_queue(self) -> None: request_id='1234', event_name=eventNames.WORK_STARTED, event_payload={'hello': 'events'}, - publisher_id=self.__class__.__name__ + publisher_id=self.__class__.__name__, ) self.dispatcher.run_once() with self.assertRaises(queue.Empty): @@ -53,18 +54,20 @@ def subscribe(self, mock_time: mock.Mock) -> DictQueueType: request_id='1234', event_name=eventNames.WORK_STARTED, event_payload={'hello': 'events'}, - publisher_id=self.__class__.__name__ + publisher_id=self.__class__.__name__, ) self.dispatcher.run_once() - self.assertEqual(q.get(), { - 'request_id': '1234', - 'process_id': os.getpid(), - 'thread_id': threading.get_ident(), - 'event_timestamp': 1234567, - 'event_name': eventNames.WORK_STARTED, - 'event_payload': {'hello': 'events'}, - 'publisher_id': self.__class__.__name__, - }) + self.assertEqual( + q.get(), { + 'request_id': '1234', + 'process_id': os.getpid(), + 'thread_id': threading.get_ident(), + 'event_timestamp': 1234567, + 'event_name': eventNames.WORK_STARTED, + 'event_payload': {'hello': 'events'}, + 'publisher_id': self.__class__.__name__, + }, + ) return q def test_subscribe(self) -> None: @@ -78,7 +81,7 @@ def test_unsubscribe(self) -> None: request_id='1234', event_name=eventNames.WORK_STARTED, event_payload={'hello': 'events'}, - publisher_id=self.__class__.__name__ + publisher_id=self.__class__.__name__, ) self.dispatcher.run_once() with self.assertRaises(queue.Empty): diff --git a/tests/core/test_event_queue.py b/tests/core/test_event_queue.py index 18f7527163..1955f918e7 100644 --- a/tests/core/test_event_queue.py +++ b/tests/core/test_event_queue.py @@ -30,17 +30,19 @@ def test_publish(self, mock_time: mock.Mock) -> None: request_id='1234', event_name=eventNames.WORK_STARTED, event_payload={'hello': 'events'}, - publisher_id=self.__class__.__name__ + publisher_id=self.__class__.__name__, + ) + self.assertEqual( + evq.queue.get(), { + 'request_id': '1234', + 'process_id': os.getpid(), + 'thread_id': threading.get_ident(), + 'event_timestamp': 1234567, + 'event_name': eventNames.WORK_STARTED, + 'event_payload': {'hello': 'events'}, + 'publisher_id': self.__class__.__name__, + }, ) - self.assertEqual(evq.queue.get(), { - 'request_id': '1234', - 'process_id': os.getpid(), - 'thread_id': threading.get_ident(), - 'event_timestamp': 1234567, - 'event_name': eventNames.WORK_STARTED, - 'event_payload': {'hello': 'events'}, - 'publisher_id': self.__class__.__name__, - }) def test_subscribe(self) -> None: evq = EventQueue(MANAGER.Queue()) diff --git a/tests/core/test_event_subscriber.py b/tests/core/test_event_subscriber.py index 30e67b39df..8c440a8866 100644 --- a/tests/core/test_event_subscriber.py +++ b/tests/core/test_event_subscriber.py @@ -30,7 +30,8 @@ def test_event_subscriber(self, mock_time: mock.Mock) -> None: self.event_queue = EventQueue(multiprocessing.Manager().Queue()) self.dispatcher = EventDispatcher( shutdown=self.dispatcher_shutdown, - event_queue=self.event_queue) + event_queue=self.event_queue, + ) self.subscriber = EventSubscriber(self.event_queue) self.subscriber.subscribe(self.callback) @@ -40,7 +41,7 @@ def test_event_subscriber(self, mock_time: mock.Mock) -> None: request_id='1234', event_name=eventNames.WORK_STARTED, event_payload={'hello': 'events'}, - publisher_id=self.__class__.__name__ + publisher_id=self.__class__.__name__, ) self.dispatcher.run_once() @@ -49,12 +50,14 @@ def test_event_subscriber(self, mock_time: mock.Mock) -> None: self.dispatcher_shutdown.set() def callback(self, ev: Dict[str, Any]) -> None: - self.assertEqual(ev, { - 'request_id': '1234', - 'process_id': os.getpid(), - 'thread_id': PUBLISHER_ID, - 'event_timestamp': 1234567, - 'event_name': eventNames.WORK_STARTED, - 'event_payload': {'hello': 'events'}, - 'publisher_id': self.__class__.__name__, - }) + self.assertEqual( + ev, { + 'request_id': '1234', + 'process_id': os.getpid(), + 'thread_id': PUBLISHER_ID, + 'event_timestamp': 1234567, + 'event_name': eventNames.WORK_STARTED, + 'event_payload': {'hello': 'events'}, + 'publisher_id': self.__class__.__name__, + }, + ) diff --git a/tests/http/exceptions/test_http_proxy_auth_failed.py b/tests/http/exceptions/test_http_proxy_auth_failed.py index dc37b3b613..3faff4d504 100644 --- a/tests/http/exceptions/test_http_proxy_auth_failed.py +++ b/tests/http/exceptions/test_http_proxy_auth_failed.py @@ -23,9 +23,11 @@ class TestHttpProxyAuthFailed(unittest.TestCase): @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') - def setUp(self, - mock_fromfd: mock.Mock, - mock_selector: mock.Mock) -> None: + def setUp( + self, + mock_fromfd: mock.Mock, + mock_selector: mock.Mock, + ) -> None: self.mock_fromfd = mock_fromfd self.mock_selector = mock_selector @@ -35,7 +37,8 @@ def setUp(self, self._conn = mock_fromfd.return_value self.protocol_handler = HttpProtocolHandler( TcpClientConnection(self._conn, self._addr), - flags=self.flags) + flags=self.flags, + ) self.protocol_handler.initialize() @mock.patch('proxy.http.proxy.server.TcpServerConnection') @@ -43,20 +46,27 @@ def test_proxy_auth_fails_without_cred(self, mock_server_conn: mock.Mock) -> Non self._conn.recv.return_value = build_http_request( b'GET', b'http://upstream.host/not-found.html', headers={ - b'Host': b'upstream.host' - }) + b'Host': b'upstream.host', + }, + ) self.mock_selector.return_value.select.side_effect = [ - [(selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, - events=selectors.EVENT_READ, - data=None), selectors.EVENT_READ)], ] + [( + selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + )], + ] self.protocol_handler.run_once() mock_server_conn.assert_not_called() self.assertEqual(self.protocol_handler.client.has_buffer(), True) self.assertEqual( - self.protocol_handler.client.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT) + self.protocol_handler.client.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT, + ) self._conn.send.assert_not_called() @mock.patch('proxy.http.proxy.server.TcpServerConnection') @@ -66,19 +76,26 @@ def test_proxy_auth_fails_with_invalid_cred(self, mock_server_conn: mock.Mock) - headers={ b'Host': b'upstream.host', b'Proxy-Authorization': b'Basic hello', - }) + }, + ) self.mock_selector.return_value.select.side_effect = [ - [(selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, - events=selectors.EVENT_READ, - data=None), selectors.EVENT_READ)], ] + [( + selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + )], + ] self.protocol_handler.run_once() mock_server_conn.assert_not_called() self.assertEqual(self.protocol_handler.client.has_buffer(), True) self.assertEqual( - self.protocol_handler.client.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT) + self.protocol_handler.client.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT, + ) self._conn.send.assert_not_called() @mock.patch('proxy.http.proxy.server.TcpServerConnection') @@ -88,13 +105,19 @@ def test_proxy_auth_works_with_valid_cred(self, mock_server_conn: mock.Mock) -> headers={ b'Host': b'upstream.host', b'Proxy-Authorization': b'Basic dXNlcjpwYXNz', - }) + }, + ) self.mock_selector.return_value.select.side_effect = [ - [(selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, - events=selectors.EVENT_READ, - data=None), selectors.EVENT_READ)], ] + [( + selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + )], + ] self.protocol_handler.run_once() mock_server_conn.assert_called_once() @@ -107,13 +130,19 @@ def test_proxy_auth_works_with_mixed_case_basic_string(self, mock_server_conn: m headers={ b'Host': b'upstream.host', b'Proxy-Authorization': b'bAsIc dXNlcjpwYXNz', - }) + }, + ) self.mock_selector.return_value.select.side_effect = [ - [(selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, - events=selectors.EVENT_READ, - data=None), selectors.EVENT_READ)], ] + [( + selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + )], + ] self.protocol_handler.run_once() mock_server_conn.assert_called_once() diff --git a/tests/http/exceptions/test_http_request_rejected.py b/tests/http/exceptions/test_http_request_rejected.py index 59eac81c3b..457b8dc84d 100644 --- a/tests/http/exceptions/test_http_request_rejected.py +++ b/tests/http/exceptions/test_http_request_rejected.py @@ -28,15 +28,19 @@ def test_empty_response(self) -> None: def test_status_code_response(self) -> None: e = HttpRequestRejected(status_code=200, reason=b'OK') - self.assertEqual(e.response(self.request), CRLF.join([ - b'HTTP/1.1 200 OK', - CRLF - ])) + self.assertEqual( + e.response(self.request), CRLF.join([ + b'HTTP/1.1 200 OK', + CRLF, + ]), + ) def test_body_response(self) -> None: e = HttpRequestRejected( status_code=httpStatusCodes.NOT_FOUND, reason=b'NOT FOUND', - body=b'Nothing here') + body=b'Nothing here', + ) self.assertEqual( e.response(self.request), - build_http_response(httpStatusCodes.NOT_FOUND, reason=b'NOT FOUND', body=b'Nothing here')) + build_http_response(httpStatusCodes.NOT_FOUND, reason=b'NOT FOUND', body=b'Nothing here'), + ) diff --git a/tests/http/test_chunk_parser.py b/tests/http/test_chunk_parser.py index 94b71afb6d..eb47c264d1 100644 --- a/tests/http/test_chunk_parser.py +++ b/tests/http/test_chunk_parser.py @@ -19,16 +19,18 @@ def setUp(self) -> None: self.parser = ChunkParser() def test_chunk_parse_basic(self) -> None: - self.parser.parse(b''.join([ - b'4\r\n', - b'Wiki\r\n', - b'5\r\n', - b'pedia\r\n', - b'E\r\n', - b' in\r\n\r\nchunks.\r\n', - b'0\r\n', - b'\r\n' - ])) + self.parser.parse( + b''.join([ + b'4\r\n', + b'Wiki\r\n', + b'5\r\n', + b'pedia\r\n', + b'E\r\n', + b' in\r\n\r\nchunks.\r\n', + b'0\r\n', + b'\r\n', + ]), + ) self.assertEqual(self.parser.chunk, b'') self.assertEqual(self.parser.size, None) self.assertEqual(self.parser.body, b'Wikipedia in\r\n\r\nchunks.') @@ -42,42 +44,48 @@ def test_chunk_parse_issue_27(self) -> None: self.assertEqual(self.parser.body, b'') self.assertEqual( self.parser.state, - chunkParserStates.WAITING_FOR_SIZE) + chunkParserStates.WAITING_FOR_SIZE, + ) self.parser.parse(b'\r\n') self.assertEqual(self.parser.chunk, b'') self.assertEqual(self.parser.size, 3) self.assertEqual(self.parser.body, b'') self.assertEqual( self.parser.state, - chunkParserStates.WAITING_FOR_DATA) + chunkParserStates.WAITING_FOR_DATA, + ) self.parser.parse(b'abc') self.assertEqual(self.parser.chunk, b'') self.assertEqual(self.parser.size, None) self.assertEqual(self.parser.body, b'abc') self.assertEqual( self.parser.state, - chunkParserStates.WAITING_FOR_SIZE) + chunkParserStates.WAITING_FOR_SIZE, + ) self.parser.parse(b'\r\n') self.assertEqual(self.parser.chunk, b'') self.assertEqual(self.parser.size, None) self.assertEqual(self.parser.body, b'abc') self.assertEqual( self.parser.state, - chunkParserStates.WAITING_FOR_SIZE) + chunkParserStates.WAITING_FOR_SIZE, + ) self.parser.parse(b'4\r\n') self.assertEqual(self.parser.chunk, b'') self.assertEqual(self.parser.size, 4) self.assertEqual(self.parser.body, b'abc') self.assertEqual( self.parser.state, - chunkParserStates.WAITING_FOR_DATA) + chunkParserStates.WAITING_FOR_DATA, + ) self.parser.parse(b'defg\r\n0') self.assertEqual(self.parser.chunk, b'0') self.assertEqual(self.parser.size, None) self.assertEqual(self.parser.body, b'abcdefg') self.assertEqual( self.parser.state, - chunkParserStates.WAITING_FOR_SIZE) + chunkParserStates.WAITING_FOR_SIZE, + ) self.parser.parse(b'\r\n\r\n') self.assertEqual(self.parser.chunk, b'') self.assertEqual(self.parser.size, None) @@ -87,4 +95,5 @@ def test_chunk_parse_issue_27(self) -> None: def test_to_chunks(self) -> None: self.assertEqual( b'f\r\n{"key":"value"}\r\n0\r\n\r\n', - ChunkParser.to_chunks(b'{"key":"value"}')) + ChunkParser.to_chunks(b'{"key":"value"}'), + ) diff --git a/tests/http/test_http_parser.py b/tests/http/test_http_parser.py index f037966b75..ded49c8283 100644 --- a/tests/http/test_http_parser.py +++ b/tests/http/test_http_parser.py @@ -30,63 +30,81 @@ def test_urlparse(self) -> None: def test_build_request(self) -> None: self.assertEqual( build_http_request( - b'GET', b'http://localhost:12345', b'HTTP/1.1'), + b'GET', b'http://localhost:12345', b'HTTP/1.1', + ), CRLF.join([ b'GET http://localhost:12345 HTTP/1.1', - CRLF - ])) + CRLF, + ]), + ) self.assertEqual( - build_http_request(b'GET', b'http://localhost:12345', b'HTTP/1.1', - headers={b'key': b'value'}), + build_http_request( + b'GET', b'http://localhost:12345', b'HTTP/1.1', + headers={b'key': b'value'}, + ), CRLF.join([ b'GET http://localhost:12345 HTTP/1.1', b'key: value', - CRLF - ])) + CRLF, + ]), + ) self.assertEqual( - build_http_request(b'GET', b'http://localhost:12345', b'HTTP/1.1', - headers={b'key': b'value'}, - body=b'Hello from py'), + build_http_request( + b'GET', b'http://localhost:12345', b'HTTP/1.1', + headers={b'key': b'value'}, + body=b'Hello from py', + ), CRLF.join([ b'GET http://localhost:12345 HTTP/1.1', b'key: value', - CRLF - ]) + b'Hello from py') + CRLF, + ]) + b'Hello from py', + ) def test_build_response(self) -> None: self.assertEqual( build_http_response( - 200, reason=b'OK', protocol_version=b'HTTP/1.1'), + 200, reason=b'OK', protocol_version=b'HTTP/1.1', + ), CRLF.join([ b'HTTP/1.1 200 OK', - CRLF - ])) + CRLF, + ]), + ) self.assertEqual( - build_http_response(200, reason=b'OK', protocol_version=b'HTTP/1.1', - headers={b'key': b'value'}), + build_http_response( + 200, reason=b'OK', protocol_version=b'HTTP/1.1', + headers={b'key': b'value'}, + ), CRLF.join([ b'HTTP/1.1 200 OK', b'key: value', - CRLF - ])) + CRLF, + ]), + ) def test_build_response_adds_content_length_header(self) -> None: body = b'Hello world!!!' self.assertEqual( - build_http_response(200, reason=b'OK', protocol_version=b'HTTP/1.1', - headers={b'key': b'value'}, - body=body), + build_http_response( + 200, reason=b'OK', protocol_version=b'HTTP/1.1', + headers={b'key': b'value'}, + body=body, + ), CRLF.join([ b'HTTP/1.1 200 OK', b'key: value', b'Content-Length: ' + bytes_(len(body)), - CRLF - ]) + body) + CRLF, + ]) + body, + ) def test_build_header(self) -> None: self.assertEqual( build_http_header( - b'key', b'value'), b'key: value') + b'key', b'value', + ), b'key: value', + ) def test_header_raises(self) -> None: with self.assertRaises(KeyError): @@ -104,15 +122,22 @@ def test_set_host_port_raises(self) -> None: def test_find_line(self) -> None: self.assertEqual( find_http_line( - b'CONNECT python.org:443 HTTP/1.0\r\n\r\n'), - (b'CONNECT python.org:443 HTTP/1.0', - CRLF)) + b'CONNECT python.org:443 HTTP/1.0\r\n\r\n', + ), + ( + b'CONNECT python.org:443 HTTP/1.0', + CRLF, + ), + ) def test_find_line_returns_None(self) -> None: self.assertEqual( find_http_line(b'CONNECT python.org:443 HTTP/1.0'), - (None, - b'CONNECT python.org:443 HTTP/1.0')) + ( + None, + b'CONNECT python.org:443 HTTP/1.0', + ), + ) def test_connect_request_with_crlf_as_separate_chunk(self) -> None: """See https://github.com/abhinavsingh/py/issues/70 for background.""" @@ -126,10 +151,12 @@ def test_get_full_parse(self) -> None: raw = CRLF.join([ b'GET %s HTTP/1.1', b'Host: %s', - CRLF + CRLF, ]) - pkt = raw % (b'https://example.com/path/dir/?a=b&c=d#p=q', - b'example.com') + pkt = raw % ( + b'https://example.com/path/dir/?a=b&c=d#p=q', + b'example.com', + ) self.parser.parse(pkt) self.assertEqual(self.parser.total_size, len(pkt)) self.assertEqual(self.parser.build_path(), b'/path/dir/?a=b&c=d#p=q') @@ -140,14 +167,18 @@ def test_get_full_parse(self) -> None: self.assertEqual(self.parser.version, b'HTTP/1.1') self.assertEqual(self.parser.state, httpParserStates.COMPLETE) self.assertEqual( - self.parser.headers[b'host'], (b'Host', b'example.com')) + self.parser.headers[b'host'], (b'Host', b'example.com'), + ) self.parser.del_headers([b'host']) self.parser.add_headers([(b'Host', b'example.com')]) self.assertEqual( raw % - (b'/path/dir/?a=b&c=d#p=q', - b'example.com'), - self.parser.build()) + ( + b'/path/dir/?a=b&c=d#p=q', + b'example.com', + ), + self.parser.build(), + ) def test_build_url_none(self) -> None: self.assertEqual(self.parser.build_path(), b'/None') @@ -159,11 +190,12 @@ def test_line_rcvd_to_rcving_headers_state_change(self) -> None: self.assert_state_change_with_crlf( httpParserStates.INITIALIZED, httpParserStates.LINE_RCVD, - httpParserStates.COMPLETE) + httpParserStates.COMPLETE, + ) def test_get_partial_parse1(self) -> None: pkt = CRLF.join([ - b'GET http://localhost:8080 HTTP/1.1' + b'GET http://localhost:8080 HTTP/1.1', ]) self.parser.parse(pkt) self.assertEqual(self.parser.total_size, len(pkt)) @@ -172,7 +204,8 @@ def test_get_partial_parse1(self) -> None: self.assertEqual(self.parser.version, None) self.assertEqual( self.parser.state, - httpParserStates.INITIALIZED) + httpParserStates.INITIALIZED, + ) self.parser.parse(CRLF) self.assertEqual(self.parser.total_size, len(pkt) + len(CRLF)) @@ -185,26 +218,35 @@ def test_get_partial_parse1(self) -> None: host_hdr = b'Host: localhost:8080' self.parser.parse(host_hdr) - self.assertEqual(self.parser.total_size, - len(pkt) + len(CRLF) + len(host_hdr)) + self.assertEqual( + self.parser.total_size, + len(pkt) + len(CRLF) + len(host_hdr), + ) self.assertDictEqual(self.parser.headers, {}) self.assertEqual(self.parser.buffer, b'Host: localhost:8080') self.assertEqual(self.parser.state, httpParserStates.LINE_RCVD) self.parser.parse(CRLF * 2) - self.assertEqual(self.parser.total_size, len(pkt) + - (3 * len(CRLF)) + len(host_hdr)) + self.assertEqual( + self.parser.total_size, len(pkt) + + (3 * len(CRLF)) + len(host_hdr), + ) self.assertEqual( self.parser.headers[b'host'], - (b'Host', - b'localhost:8080')) + ( + b'Host', + b'localhost:8080', + ), + ) self.assertEqual(self.parser.state, httpParserStates.COMPLETE) def test_get_partial_parse2(self) -> None: - self.parser.parse(CRLF.join([ - b'GET http://localhost:8080 HTTP/1.1', - b'Host: ' - ])) + self.parser.parse( + CRLF.join([ + b'GET http://localhost:8080 HTTP/1.1', + b'Host: ', + ]), + ) self.assertEqual(self.parser.method, b'GET') assert self.parser.url self.assertEqual(self.parser.url.hostname, b'localhost') @@ -216,20 +258,26 @@ def test_get_partial_parse2(self) -> None: self.parser.parse(b'localhost:8080' + CRLF) self.assertEqual( self.parser.headers[b'host'], - (b'Host', - b'localhost:8080')) + ( + b'Host', + b'localhost:8080', + ), + ) self.assertEqual(self.parser.buffer, b'') self.assertEqual( self.parser.state, - httpParserStates.RCVING_HEADERS) + httpParserStates.RCVING_HEADERS, + ) self.parser.parse(b'Content-Type: text/plain' + CRLF) self.assertEqual(self.parser.buffer, b'') self.assertEqual( - self.parser.headers[b'content-type'], (b'Content-Type', b'text/plain')) + self.parser.headers[b'content-type'], (b'Content-Type', b'text/plain'), + ) self.assertEqual( self.parser.state, - httpParserStates.RCVING_HEADERS) + httpParserStates.RCVING_HEADERS, + ) self.parser.parse(CRLF) self.assertEqual(self.parser.state, httpParserStates.COMPLETE) @@ -240,7 +288,7 @@ def test_post_full_parse(self) -> None: b'Host: localhost', b'Content-Length: 7', b'Content-Type: application/x-www-form-urlencoded' + CRLF, - b'a=b&c=d' + b'a=b&c=d', ]) self.parser.parse(raw % b'http://localhost') self.assertEqual(self.parser.method, b'POST') @@ -248,19 +296,25 @@ def test_post_full_parse(self) -> None: self.assertEqual(self.parser.url.hostname, b'localhost') self.assertEqual(self.parser.url.port, None) self.assertEqual(self.parser.version, b'HTTP/1.1') - self.assertEqual(self.parser.headers[b'content-type'], - (b'Content-Type', b'application/x-www-form-urlencoded')) - self.assertEqual(self.parser.headers[b'content-length'], - (b'Content-Length', b'7')) + self.assertEqual( + self.parser.headers[b'content-type'], + (b'Content-Type', b'application/x-www-form-urlencoded'), + ) + self.assertEqual( + self.parser.headers[b'content-length'], + (b'Content-Length', b'7'), + ) self.assertEqual(self.parser.body, b'a=b&c=d') self.assertEqual(self.parser.buffer, b'') self.assertEqual(self.parser.state, httpParserStates.COMPLETE) self.assertEqual(len(self.parser.build()), len(raw % b'/')) - def assert_state_change_with_crlf(self, - initial_state: int, - next_state: int, - final_state: int) -> None: + def assert_state_change_with_crlf( + self, + initial_state: int, + next_state: int, + final_state: int, + ) -> None: self.assertEqual(self.parser.state, initial_state) self.parser.parse(CRLF) self.assertEqual(self.parser.state, next_state) @@ -268,12 +322,14 @@ def assert_state_change_with_crlf(self, self.assertEqual(self.parser.state, final_state) def test_post_partial_parse(self) -> None: - self.parser.parse(CRLF.join([ - b'POST http://localhost HTTP/1.1', - b'Host: localhost', - b'Content-Length: 7', - b'Content-Type: application/x-www-form-urlencoded' - ])) + self.parser.parse( + CRLF.join([ + b'POST http://localhost HTTP/1.1', + b'Host: localhost', + b'Content-Length: 7', + b'Content-Type: application/x-www-form-urlencoded', + ]), + ) self.assertEqual(self.parser.method, b'POST') assert self.parser.url self.assertEqual(self.parser.url.hostname, b'localhost') @@ -282,12 +338,14 @@ def test_post_partial_parse(self) -> None: self.assert_state_change_with_crlf( httpParserStates.RCVING_HEADERS, httpParserStates.RCVING_HEADERS, - httpParserStates.HEADERS_COMPLETE) + httpParserStates.HEADERS_COMPLETE, + ) self.parser.parse(b'a=b') self.assertEqual( self.parser.state, - httpParserStates.RCVING_BODY) + httpParserStates.RCVING_BODY, + ) self.assertEqual(self.parser.body, b'a=b') self.assertEqual(self.parser.buffer, b'') @@ -321,12 +379,14 @@ def test_request_parse_without_content_length(self) -> None: See https://github.com/abhinavsingh/py/issues/20 for details. """ - self.parser.parse(CRLF.join([ - b'POST http://localhost HTTP/1.1', - b'Host: localhost', - b'Content-Type: application/x-www-form-urlencoded', - CRLF - ])) + self.parser.parse( + CRLF.join([ + b'POST http://localhost HTTP/1.1', + b'Host: localhost', + b'Content-Type: application/x-www-form-urlencoded', + CRLF, + ]), + ) self.assertEqual(self.parser.method, b'POST') self.assertEqual(self.parser.state, httpParserStates.COMPLETE) @@ -350,33 +410,38 @@ def test_response_parse_without_content_length(self) -> None: self.assertEqual(self.parser.code, b'200') self.assertEqual(self.parser.version, b'HTTP/1.0') self.assertEqual(self.parser.state, httpParserStates.LINE_RCVD) - self.parser.parse(CRLF.join([ - b'Server: BaseHTTP/0.3 Python/2.7.10', - b'Date: Thu, 13 Dec 2018 16:24:09 GMT', - CRLF - ])) + self.parser.parse( + CRLF.join([ + b'Server: BaseHTTP/0.3 Python/2.7.10', + b'Date: Thu, 13 Dec 2018 16:24:09 GMT', + CRLF, + ]), + ) self.assertEqual( self.parser.state, - httpParserStates.COMPLETE) + httpParserStates.COMPLETE, + ) def test_response_parse(self) -> None: self.parser.type = httpParserTypes.RESPONSE_PARSER - self.parser.parse(b''.join([ - b'HTTP/1.1 301 Moved Permanently\r\n', - b'Location: http://www.google.com/\r\n', - b'Content-Type: text/html; charset=UTF-8\r\n', - b'Date: Wed, 22 May 2013 14:07:29 GMT\r\n', - b'Expires: Fri, 21 Jun 2013 14:07:29 GMT\r\n', - b'Cache-Control: public, max-age=2592000\r\n', - b'Server: gws\r\n', - b'Content-Length: 219\r\n', - b'X-XSS-Protection: 1; mode=block\r\n', - b'X-Frame-Options: SAMEORIGIN\r\n\r\n', - b'\n' + - b'301 Moved', - b'\n

301 Moved

\nThe document has moved\n' + - b'here.\r\n\r\n' - ])) + self.parser.parse( + b''.join([ + b'HTTP/1.1 301 Moved Permanently\r\n', + b'Location: http://www.google.com/\r\n', + b'Content-Type: text/html; charset=UTF-8\r\n', + b'Date: Wed, 22 May 2013 14:07:29 GMT\r\n', + b'Expires: Fri, 21 Jun 2013 14:07:29 GMT\r\n', + b'Cache-Control: public, max-age=2592000\r\n', + b'Server: gws\r\n', + b'Content-Length: 219\r\n', + b'X-XSS-Protection: 1; mode=block\r\n', + b'X-Frame-Options: SAMEORIGIN\r\n\r\n', + b'\n' + + b'301 Moved', + b'\n

301 Moved

\nThe document has moved\n' + + b'here.\r\n\r\n', + ]), + ) self.assertEqual(self.parser.code, b'301') self.assertEqual(self.parser.reason, b'Moved Permanently') self.assertEqual(self.parser.version, b'HTTP/1.1') @@ -384,63 +449,77 @@ def test_response_parse(self) -> None: self.parser.body, b'\n' + b'301 Moved\n

301 Moved

\nThe document has moved\n' + - b'here.\r\n\r\n') - self.assertEqual(self.parser.headers[b'content-length'], - (b'Content-Length', b'219')) + b'here.\r\n\r\n', + ) + self.assertEqual( + self.parser.headers[b'content-length'], + (b'Content-Length', b'219'), + ) self.assertEqual(self.parser.state, httpParserStates.COMPLETE) def test_response_partial_parse(self) -> None: self.parser.type = httpParserTypes.RESPONSE_PARSER - self.parser.parse(b''.join([ - b'HTTP/1.1 301 Moved Permanently\r\n', - b'Location: http://www.google.com/\r\n', - b'Content-Type: text/html; charset=UTF-8\r\n', - b'Date: Wed, 22 May 2013 14:07:29 GMT\r\n', - b'Expires: Fri, 21 Jun 2013 14:07:29 GMT\r\n', - b'Cache-Control: public, max-age=2592000\r\n', - b'Server: gws\r\n', - b'Content-Length: 219\r\n', - b'X-XSS-Protection: 1; mode=block\r\n', - b'X-Frame-Options: SAMEORIGIN\r\n' - ])) - self.assertEqual(self.parser.headers[b'x-frame-options'], - (b'X-Frame-Options', b'SAMEORIGIN')) + self.parser.parse( + b''.join([ + b'HTTP/1.1 301 Moved Permanently\r\n', + b'Location: http://www.google.com/\r\n', + b'Content-Type: text/html; charset=UTF-8\r\n', + b'Date: Wed, 22 May 2013 14:07:29 GMT\r\n', + b'Expires: Fri, 21 Jun 2013 14:07:29 GMT\r\n', + b'Cache-Control: public, max-age=2592000\r\n', + b'Server: gws\r\n', + b'Content-Length: 219\r\n', + b'X-XSS-Protection: 1; mode=block\r\n', + b'X-Frame-Options: SAMEORIGIN\r\n', + ]), + ) + self.assertEqual( + self.parser.headers[b'x-frame-options'], + (b'X-Frame-Options', b'SAMEORIGIN'), + ) self.assertEqual( self.parser.state, - httpParserStates.RCVING_HEADERS) + httpParserStates.RCVING_HEADERS, + ) self.parser.parse(b'\r\n') self.assertEqual( self.parser.state, - httpParserStates.HEADERS_COMPLETE) + httpParserStates.HEADERS_COMPLETE, + ) self.parser.parse( b'\n' + - b'301 Moved') + b'301 Moved', + ) self.assertEqual( self.parser.state, - httpParserStates.RCVING_BODY) + httpParserStates.RCVING_BODY, + ) self.parser.parse( b'\n

301 Moved

\nThe document has moved\n' + - b'here.\r\n\r\n') + b'here.\r\n\r\n', + ) self.assertEqual(self.parser.state, httpParserStates.COMPLETE) def test_chunked_response_parse(self) -> None: self.parser.type = httpParserTypes.RESPONSE_PARSER - self.parser.parse(b''.join([ - b'HTTP/1.1 200 OK\r\n', - b'Content-Type: application/json\r\n', - b'Date: Wed, 22 May 2013 15:08:15 GMT\r\n', - b'Server: gunicorn/0.16.1\r\n', - b'transfer-encoding: chunked\r\n', - b'Connection: keep-alive\r\n\r\n', - b'4\r\n', - b'Wiki\r\n', - b'5\r\n', - b'pedia\r\n', - b'E\r\n', - b' in\r\n\r\nchunks.\r\n', - b'0\r\n', - b'\r\n' - ])) + self.parser.parse( + b''.join([ + b'HTTP/1.1 200 OK\r\n', + b'Content-Type: application/json\r\n', + b'Date: Wed, 22 May 2013 15:08:15 GMT\r\n', + b'Server: gunicorn/0.16.1\r\n', + b'transfer-encoding: chunked\r\n', + b'Connection: keep-alive\r\n\r\n', + b'4\r\n', + b'Wiki\r\n', + b'5\r\n', + b'pedia\r\n', + b'E\r\n', + b' in\r\n\r\nchunks.\r\n', + b'0\r\n', + b'\r\n', + ]), + ) self.assertEqual(self.parser.body, b'Wikipedia in\r\n\r\nchunks.') self.assertEqual(self.parser.state, httpParserStates.COMPLETE) @@ -448,7 +527,7 @@ def test_pipelined_response_parse(self) -> None: response = build_http_response( httpStatusCodes.OK, reason=b'OK', headers={ - b'Content-Length': b'15' + b'Content-Length': b'15', }, body=b'{"key":"value"}', ) @@ -461,7 +540,7 @@ def test_pipelined_chunked_response_parse(self) -> None: b'Transfer-Encoding': b'chunked', b'Content-Type': b'application/json', }, - body=b'f\r\n{"key":"value"}\r\n0\r\n\r\n' + body=b'f\r\n{"key":"value"}\r\n0\r\n\r\n', ) self.assert_pipeline_response(response) @@ -480,52 +559,69 @@ def assert_pipeline_response(self, response: bytes) -> None: self.assertEqual(parser.buffer, b'') def test_chunked_request_parse(self) -> None: - self.parser.parse(build_http_request( - httpMethods.POST, b'http://example.org/', - headers={ - b'Transfer-Encoding': b'chunked', - b'Content-Type': b'application/json', - }, - body=b'f\r\n{"key":"value"}\r\n0\r\n\r\n')) + self.parser.parse( + build_http_request( + httpMethods.POST, + b'http://example.org/', + headers={ + b'Transfer-Encoding': b'chunked', + b'Content-Type': b'application/json', + }, + body=b'f\r\n{"key":"value"}\r\n0\r\n\r\n', + ), + ) self.assertEqual(self.parser.body, b'{"key":"value"}') self.assertEqual(self.parser.state, httpParserStates.COMPLETE) - self.assertEqual(self.parser.build(), build_http_request( - httpMethods.POST, b'/', - headers={ - b'Transfer-Encoding': b'chunked', - b'Content-Type': b'application/json', - }, - body=b'f\r\n{"key":"value"}\r\n0\r\n\r\n')) + self.assertEqual( + self.parser.build(), build_http_request( + httpMethods.POST, + b'/', + headers={ + b'Transfer-Encoding': b'chunked', + b'Content-Type': b'application/json', + }, + body=b'f\r\n{"key":"value"}\r\n0\r\n\r\n', + ), + ) def test_is_http_1_1_keep_alive(self) -> None: - self.parser.parse(build_http_request( - httpMethods.GET, b'/' - )) + self.parser.parse( + build_http_request( + httpMethods.GET, b'/', + ), + ) self.assertTrue(self.parser.is_http_1_1_keep_alive()) def test_is_http_1_1_keep_alive_with_non_close_connection_header( - self) -> None: - self.parser.parse(build_http_request( - httpMethods.GET, b'/', - headers={ - b'Connection': b'keep-alive', - } - )) + self, + ) -> None: + self.parser.parse( + build_http_request( + httpMethods.GET, b'/', + headers={ + b'Connection': b'keep-alive', + }, + ), + ) self.assertTrue(self.parser.is_http_1_1_keep_alive()) def test_is_not_http_1_1_keep_alive_with_close_header(self) -> None: - self.parser.parse(build_http_request( - httpMethods.GET, b'/', - headers={ - b'Connection': b'close', - } - )) + self.parser.parse( + build_http_request( + httpMethods.GET, b'/', + headers={ + b'Connection': b'close', + }, + ), + ) self.assertFalse(self.parser.is_http_1_1_keep_alive()) def test_is_not_http_1_1_keep_alive_for_http_1_0(self) -> None: - self.parser.parse(build_http_request( - httpMethods.GET, b'/', protocol_version=b'HTTP/1.0', - )) + self.parser.parse( + build_http_request( + httpMethods.GET, b'/', protocol_version=b'HTTP/1.0', + ), + ) self.assertFalse(self.parser.is_http_1_1_keep_alive()) def test_paramiko_doc(self) -> None: diff --git a/tests/http/test_http_proxy.py b/tests/http/test_http_proxy.py index 3d85ed52ae..f00737ed05 100644 --- a/tests/http/test_http_proxy.py +++ b/tests/http/test_http_proxy.py @@ -25,9 +25,11 @@ class TestHttpProxyPlugin(unittest.TestCase): @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') - def setUp(self, - mock_fromfd: mock.Mock, - mock_selector: mock.Mock) -> None: + def setUp( + self, + mock_fromfd: mock.Mock, + mock_selector: mock.Mock, + ) -> None: self.mock_fromfd = mock_fromfd self.mock_selector = mock_selector @@ -37,12 +39,13 @@ def setUp(self, self.plugin = mock.MagicMock() self.flags.plugins = { b'HttpProtocolHandlerPlugin': [HttpProxyPlugin], - b'HttpProxyBasePlugin': [self.plugin] + b'HttpProxyBasePlugin': [self.plugin], } self._conn = mock_fromfd.return_value self.protocol_handler = HttpProtocolHandler( TcpClientConnection(self._conn, self._addr), - flags=self.flags) + flags=self.flags, + ) self.protocol_handler.initialize() def test_proxy_plugin_initialized(self) -> None: @@ -51,7 +54,8 @@ def test_proxy_plugin_initialized(self) -> None: @mock.patch('proxy.http.proxy.server.TcpServerConnection') def test_proxy_plugin_on_and_before_upstream_connection( self, - mock_server_conn: mock.Mock) -> None: + mock_server_conn: mock.Mock, + ) -> None: self.plugin.return_value.write_to_descriptors.return_value = False self.plugin.return_value.read_from_descriptors.return_value = False self.plugin.return_value.before_upstream_connection.side_effect = lambda r: r @@ -60,14 +64,20 @@ def test_proxy_plugin_on_and_before_upstream_connection( self._conn.recv.return_value = build_http_request( b'GET', b'http://upstream.host/not-found.html', headers={ - b'Host': b'upstream.host' - }) + b'Host': b'upstream.host', + }, + ) self.mock_selector.return_value.select.side_effect = [ - [(selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, - events=selectors.EVENT_READ, - data=None), selectors.EVENT_READ)], ] + [( + selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + )], + ] self.protocol_handler.run_once() mock_server_conn.assert_called_with('upstream.host', DEFAULT_HTTP_PORT) @@ -77,7 +87,8 @@ def test_proxy_plugin_on_and_before_upstream_connection( @mock.patch('proxy.http.proxy.server.TcpServerConnection') def test_proxy_plugin_before_upstream_connection_can_teardown( self, - mock_server_conn: mock.Mock) -> None: + mock_server_conn: mock.Mock, + ) -> None: self.plugin.return_value.write_to_descriptors.return_value = False self.plugin.return_value.read_from_descriptors.return_value = False self.plugin.return_value.before_upstream_connection.side_effect = HttpProtocolException() @@ -85,14 +96,20 @@ def test_proxy_plugin_before_upstream_connection_can_teardown( self._conn.recv.return_value = build_http_request( b'GET', b'http://upstream.host/not-found.html', headers={ - b'Host': b'upstream.host' - }) + b'Host': b'upstream.host', + }, + ) self.mock_selector.return_value.select.side_effect = [ - [(selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, - events=selectors.EVENT_READ, - data=None), selectors.EVENT_READ)], ] + [( + selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + )], + ] self.protocol_handler.run_once() mock_server_conn.assert_not_called() diff --git a/tests/http/test_http_proxy_tls_interception.py b/tests/http/test_http_proxy_tls_interception.py index 96564fcf3d..43989a5d61 100644 --- a/tests/http/test_http_proxy_tls_interception.py +++ b/tests/http/test_http_proxy_tls_interception.py @@ -44,7 +44,8 @@ def test_e2e( mock_gen_public_key: mock.Mock, mock_server_conn: mock.Mock, mock_ssl_context: mock.Mock, - mock_ssl_wrap: mock.Mock) -> None: + mock_ssl_wrap: mock.Mock, + ) -> None: host, port = uuid.uuid4().hex, 443 netloc = '{0}:{1}'.format(host, port) @@ -74,7 +75,8 @@ def mock_connection() -> Any: # Do not mock the original wrap method self.mock_server_conn.return_value.wrap.side_effect = \ lambda x, y: TcpServerConnection.wrap( - self.mock_server_conn.return_value, x, y) + self.mock_server_conn.return_value, x, y, + ) type(self.mock_server_conn.return_value).connection = \ mock.PropertyMock(side_effect=mock_connection) @@ -84,7 +86,7 @@ def mock_connection() -> Any: self.flags = Proxy.initialize( ca_cert_file='ca-cert.pem', ca_key_file='ca-key.pem', - ca_signing_key_file='ca-signing-key.pem' + ca_signing_key_file='ca-signing-key.pem', ) self.plugin = mock.MagicMock() self.proxy_plugin = mock.MagicMock() @@ -95,7 +97,8 @@ def mock_connection() -> Any: self._conn = mock_fromfd.return_value self.protocol_handler = HttpProtocolHandler( TcpClientConnection(self._conn, self._addr), - flags=self.flags) + flags=self.flags, + ) self.protocol_handler.initialize() self.plugin.assert_called() @@ -105,13 +108,15 @@ def mock_connection() -> Any: self.assertEqual(self.proxy_plugin.call_args[0][1], self.flags) self.assertEqual( self.proxy_plugin.call_args[0][2].connection, - self._conn) + self._conn, + ) connect_request = build_http_request( httpMethods.CONNECT, bytes_(netloc), headers={ b'Host': bytes_(netloc), - }) + }, + ) self._conn.recv.return_value = connect_request # Prepare mocked HttpProtocolHandlerPlugin @@ -130,11 +135,16 @@ def mock_connection() -> Any: self.proxy_plugin.return_value.handle_client_request.side_effect = lambda r: r self.mock_selector.return_value.select.side_effect = [ - [(selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, - events=selectors.EVENT_READ, - data=None), selectors.EVENT_READ)], ] + [( + selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + )], + ] self.protocol_handler.run_once() @@ -142,52 +152,63 @@ def mock_connection() -> Any: self.plugin.return_value.get_descriptors.assert_called() self.plugin.return_value.write_to_descriptors.assert_called_with([]) self.plugin.return_value.on_client_data.assert_called_with( - connect_request) + connect_request, + ) self.plugin.return_value.on_request_complete.assert_called() self.plugin.return_value.read_from_descriptors.assert_called_with([ - self._conn]) + self._conn, + ]) self.proxy_plugin.return_value.before_upstream_connection.assert_called() self.proxy_plugin.return_value.handle_client_request.assert_called() self.mock_server_conn.assert_called_with(host, port) self.mock_server_conn.return_value.connection.setblocking.assert_called_with( - False) + False, + ) self.mock_ssl_context.assert_called_with( - ssl.Purpose.SERVER_AUTH, cafile=None) + ssl.Purpose.SERVER_AUTH, cafile=None, + ) # self.assertEqual(self.mock_ssl_context.return_value.options, # ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 | # ssl.OP_NO_TLSv1_1) self.assertEqual(plain_connection.setblocking.call_count, 2) self.mock_ssl_context.return_value.wrap_socket.assert_called_with( - plain_connection, server_hostname=host) + plain_connection, server_hostname=host, + ) self.assertEqual(self.mock_sign_csr.call_count, 1) self.assertEqual(self.mock_gen_csr.call_count, 1) self.assertEqual(self.mock_gen_public_key.call_count, 1) self.assertEqual(ssl_connection.setblocking.call_count, 1) self.assertEqual( self.mock_server_conn.return_value._conn, - ssl_connection) + ssl_connection, + ) self._conn.send.assert_called_with( - HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT) + HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT, + ) assert self.flags.ca_cert_dir is not None self.mock_ssl_wrap.assert_called_with( self._conn, server_side=True, keyfile=self.flags.ca_signing_key_file, certfile=HttpProxyPlugin.generated_cert_file_path( - self.flags.ca_cert_dir, host), - ssl_version=ssl.PROTOCOL_TLS + self.flags.ca_cert_dir, host, + ), + ssl_version=ssl.PROTOCOL_TLS, ) self.assertEqual(self._conn.setblocking.call_count, 2) self.assertEqual( self.protocol_handler.client.connection, - self.mock_ssl_wrap.return_value) + self.mock_ssl_wrap.return_value, + ) # Assert connection references for all other plugins is updated self.assertEqual( self.plugin.return_value.client._conn, - self.mock_ssl_wrap.return_value) + self.mock_ssl_wrap.return_value, + ) self.assertEqual( self.proxy_plugin.return_value.client._conn, - self.mock_ssl_wrap.return_value) + self.mock_ssl_wrap.return_value, + ) diff --git a/tests/http/test_protocol_handler.py b/tests/http/test_protocol_handler.py index a79e6b4ad8..fe365e6a95 100644 --- a/tests/http/test_protocol_handler.py +++ b/tests/http/test_protocol_handler.py @@ -31,9 +31,11 @@ class TestHttpProtocolHandler(unittest.TestCase): @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') - def setUp(self, - mock_fromfd: mock.Mock, - mock_selector: mock.Mock) -> None: + def setUp( + self, + mock_fromfd: mock.Mock, + mock_selector: mock.Mock, + ) -> None: self.fileno = 10 self._addr = ('127.0.0.1', 54382) self._conn = mock_fromfd.return_value @@ -47,7 +49,8 @@ def setUp(self, self.mock_selector = mock_selector self.protocol_handler = HttpProtocolHandler( - TcpClientConnection(self._conn, self._addr), flags=self.flags) + TcpClientConnection(self._conn, self._addr), flags=self.flags, + ) self.protocol_handler.initialize() @mock.patch('proxy.http.proxy.server.TcpServerConnection') @@ -56,19 +59,24 @@ def test_http_get(self, mock_server_connection: mock.Mock) -> None: server.connect.return_value = True server.buffer_size.return_value = 0 self.mock_selector_for_client_read_read_server_write( - self.mock_selector, server) + self.mock_selector, server, + ) # Send request line assert self.http_server_port is not None - self._conn.recv.return_value = (b'GET http://localhost:%d HTTP/1.1' % - self.http_server_port) + CRLF + self._conn.recv.return_value = ( + b'GET http://localhost:%d HTTP/1.1' % + self.http_server_port + ) + CRLF self.protocol_handler.run_once() self.assertEqual( self.protocol_handler.request.state, - httpParserStates.LINE_RCVD) + httpParserStates.LINE_RCVD, + ) self.assertNotEqual( self.protocol_handler.request.state, - httpParserStates.COMPLETE) + httpParserStates.COMPLETE, + ) # Send headers and blank line, thus completing HTTP request assert self.http_server_port is not None @@ -77,20 +85,23 @@ def test_http_get(self, mock_server_connection: mock.Mock) -> None: b'Host: localhost:%d' % self.http_server_port, b'Accept: */*', b'Proxy-Connection: Keep-Alive', - CRLF + CRLF, ]) self.assert_data_queued(mock_server_connection, server) self.protocol_handler.run_once() server.flush.assert_called_once() def assert_tunnel_response( - self, mock_server_connection: mock.Mock, server: mock.Mock) -> None: + self, mock_server_connection: mock.Mock, server: mock.Mock, + ) -> None: self.protocol_handler.run_once() self.assertTrue( - cast(HttpProxyPlugin, self.protocol_handler.plugins['HttpProxyPlugin']).server is not None) + cast(HttpProxyPlugin, self.protocol_handler.plugins['HttpProxyPlugin']).server is not None, + ) self.assertEqual( self.protocol_handler.client.buffer[0], - HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT) + HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT, + ) mock_server_connection.assert_called_once() server.connect.assert_called_once() server.queue.assert_not_called() @@ -112,26 +123,50 @@ def has_buffer() -> bool: server.has_buffer.side_effect = has_buffer self.mock_selector.return_value.select.side_effect = [ - [(selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, - events=selectors.EVENT_READ, - data=None), selectors.EVENT_READ), ], - [(selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, - events=0, - data=None), selectors.EVENT_WRITE), ], - [(selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, - events=selectors.EVENT_READ, - data=None), selectors.EVENT_READ), ], - [(selectors.SelectorKey( - fileobj=server.connection, - fd=server.connection.fileno, - events=0, - data=None), selectors.EVENT_WRITE), ], + [ + ( + selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + ), + ], + [ + ( + selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=0, + data=None, + ), + selectors.EVENT_WRITE, + ), + ], + [ + ( + selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + ), + ], + [ + ( + selectors.SelectorKey( + fileobj=server.connection, + fd=server.connection.fileno, + events=0, + data=None, + ), + selectors.EVENT_WRITE, + ), + ], ] assert self.http_server_port is not None @@ -140,7 +175,7 @@ def has_buffer() -> bool: b'Host: localhost:%d' % self.http_server_port, b'User-Agent: proxy.py/%s' % bytes_(__version__), b'Proxy-Connection: Keep-Alive', - CRLF + CRLF, ]) self.assert_tunnel_response(mock_server_connection, server) @@ -157,40 +192,45 @@ def test_proxy_connection_failed(self) -> None: self._conn.recv.return_value = CRLF.join([ b'GET http://unknown.domain HTTP/1.1', b'Host: unknown.domain', - CRLF + CRLF, ]) self.protocol_handler.run_once() self.assertEqual( self.protocol_handler.client.buffer[0], - ProxyConnectionFailed.RESPONSE_PKT) + ProxyConnectionFailed.RESPONSE_PKT, + ) @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') def test_proxy_authentication_failed( self, mock_fromfd: mock.Mock, - mock_selector: mock.Mock) -> None: + mock_selector: mock.Mock, + ) -> None: self._conn = mock_fromfd.return_value self.mock_selector_for_client_read(mock_selector) flags = Proxy.initialize( - auth_code=base64.b64encode(b'user:pass')) + auth_code=base64.b64encode(b'user:pass'), + ) flags.plugins = Proxy.load_plugins([ bytes_(PLUGIN_HTTP_PROXY), bytes_(PLUGIN_WEB_SERVER), bytes_(PLUGIN_PROXY_AUTH), ]) self.protocol_handler = HttpProtocolHandler( - TcpClientConnection(self._conn, self._addr), flags=flags) + TcpClientConnection(self._conn, self._addr), flags=flags, + ) self.protocol_handler.initialize() self._conn.recv.return_value = CRLF.join([ b'GET http://abhinavsingh.com HTTP/1.1', b'Host: abhinavsingh.com', - CRLF + CRLF, ]) self.protocol_handler.run_once() self.assertEqual( self.protocol_handler.client.buffer[0], - ProxyAuthenticationFailed.RESPONSE_PKT) + ProxyAuthenticationFailed.RESPONSE_PKT, + ) @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') @@ -198,7 +238,8 @@ def test_proxy_authentication_failed( def test_authenticated_proxy_http_get( self, mock_server_connection: mock.Mock, mock_fromfd: mock.Mock, - mock_selector: mock.Mock) -> None: + mock_selector: mock.Mock, + ) -> None: self._conn = mock_fromfd.return_value self.mock_selector_for_client_read(mock_selector) @@ -207,14 +248,16 @@ def test_authenticated_proxy_http_get( server.buffer_size.return_value = 0 flags = Proxy.initialize( - auth_code=base64.b64encode(b'user:pass')) + auth_code=base64.b64encode(b'user:pass'), + ) flags.plugins = Proxy.load_plugins([ bytes_(PLUGIN_HTTP_PROXY), bytes_(PLUGIN_WEB_SERVER), ]) self.protocol_handler = HttpProtocolHandler( - TcpClientConnection(self._conn, self._addr), flags=flags) + TcpClientConnection(self._conn, self._addr), flags=flags, + ) self.protocol_handler.initialize() assert self.http_server_port is not None @@ -222,13 +265,15 @@ def test_authenticated_proxy_http_get( self.protocol_handler.run_once() self.assertEqual( self.protocol_handler.request.state, - httpParserStates.INITIALIZED) + httpParserStates.INITIALIZED, + ) self._conn.recv.return_value = CRLF self.protocol_handler.run_once() self.assertEqual( self.protocol_handler.request.state, - httpParserStates.LINE_RCVD) + httpParserStates.LINE_RCVD, + ) assert self.http_server_port is not None self._conn.recv.return_value = CRLF.join([ @@ -237,7 +282,7 @@ def test_authenticated_proxy_http_get( b'Accept: */*', b'Proxy-Connection: Keep-Alive', b'Proxy-Authorization: Basic dXNlcjpwYXNz', - CRLF + CRLF, ]) self.assert_data_queued(mock_server_connection, server) @@ -247,23 +292,27 @@ def test_authenticated_proxy_http_get( def test_authenticated_proxy_http_tunnel( self, mock_server_connection: mock.Mock, mock_fromfd: mock.Mock, - mock_selector: mock.Mock) -> None: + mock_selector: mock.Mock, + ) -> None: server = mock_server_connection.return_value server.connect.return_value = True server.buffer_size.return_value = 0 self._conn = mock_fromfd.return_value self.mock_selector_for_client_read_read_server_write( - mock_selector, server) + mock_selector, server, + ) flags = Proxy.initialize( - auth_code=base64.b64encode(b'user:pass')) + auth_code=base64.b64encode(b'user:pass'), + ) flags.plugins = Proxy.load_plugins([ bytes_(PLUGIN_HTTP_PROXY), - bytes_(PLUGIN_WEB_SERVER) + bytes_(PLUGIN_WEB_SERVER), ]) self.protocol_handler = HttpProtocolHandler( - TcpClientConnection(self._conn, self._addr), flags=flags) + TcpClientConnection(self._conn, self._addr), flags=flags, + ) self.protocol_handler.initialize() assert self.http_server_port is not None @@ -273,7 +322,7 @@ def test_authenticated_proxy_http_tunnel( b'User-Agent: proxy.py/%s' % bytes_(__version__), b'Proxy-Connection: Keep-Alive', b'Proxy-Authorization: Basic dXNlcjpwYXNz', - CRLF + CRLF, ]) self.assert_tunnel_response(mock_server_connection, server) self.protocol_handler.client.flush() @@ -283,31 +332,52 @@ def test_authenticated_proxy_http_tunnel( server.flush.assert_called_once() def mock_selector_for_client_read_read_server_write( - self, mock_selector: mock.Mock, server: mock.Mock) -> None: + self, mock_selector: mock.Mock, server: mock.Mock, + ) -> None: mock_selector.return_value.select.side_effect = [ - [(selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, - events=selectors.EVENT_READ, - data=None), selectors.EVENT_READ), ], - [(selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, - events=0, - data=None), selectors.EVENT_READ), ], - [(selectors.SelectorKey( - fileobj=server.connection, - fd=server.connection.fileno, - events=0, - data=None), selectors.EVENT_WRITE), ], + [ + ( + selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + ), + ], + [ + ( + selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=0, + data=None, + ), + selectors.EVENT_READ, + ), + ], + [ + ( + selectors.SelectorKey( + fileobj=server.connection, + fd=server.connection.fileno, + events=0, + data=None, + ), + selectors.EVENT_WRITE, + ), + ], ] def assert_data_queued( - self, mock_server_connection: mock.Mock, server: mock.Mock) -> None: + self, mock_server_connection: mock.Mock, server: mock.Mock, + ) -> None: self.protocol_handler.run_once() self.assertEqual( self.protocol_handler.request.state, - httpParserStates.COMPLETE) + httpParserStates.COMPLETE, + ) mock_server_connection.assert_called_once() server.connect.assert_called_once() server.closed = False @@ -318,7 +388,7 @@ def assert_data_queued( b'Host: localhost:%d' % self.http_server_port, b'Accept: */*', b'Via: 1.1 proxy.py v%s' % bytes_(__version__), - CRLF + CRLF, ]) server.queue.assert_called_once_with(pkt) server.buffer_size.return_value = len(pkt) @@ -327,13 +397,14 @@ def assert_data_queued_to_server(self, server: mock.Mock) -> None: assert self.http_server_port is not None self.assertEqual( self._conn.send.call_args[0][0], - HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT) + HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT, + ) pkt = CRLF.join([ b'GET / HTTP/1.1', b'Host: localhost:%d' % self.http_server_port, b'User-Agent: proxy.py/%s' % bytes_(__version__), - CRLF + CRLF, ]) self._conn.recv.return_value = pkt @@ -344,9 +415,14 @@ def assert_data_queued_to_server(self, server: mock.Mock) -> None: server.flush.assert_not_called() def mock_selector_for_client_read(self, mock_selector: mock.Mock) -> None: - mock_selector.return_value.select.return_value = [( - selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, - events=selectors.EVENT_READ, - data=None), selectors.EVENT_READ), ] + mock_selector.return_value.select.return_value = [ + ( + selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + ), + ] diff --git a/tests/http/test_web_server.py b/tests/http/test_web_server.py index cdb3592633..b95bd7c671 100644 --- a/tests/http/test_web_server.py +++ b/tests/http/test_web_server.py @@ -40,36 +40,43 @@ def setUp(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None: ]) self.protocol_handler = HttpProtocolHandler( TcpClientConnection(self._conn, self._addr), - flags=self.flags) + flags=self.flags, + ) self.protocol_handler.initialize() @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') def test_pac_file_served_from_disk( - self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None: + self, mock_fromfd: mock.Mock, mock_selector: mock.Mock, + ) -> None: pac_file = os.path.join( os.path.dirname(PROXY_PY_DIR), 'helper', - 'proxy.pac') + 'proxy.pac', + ) self._conn = mock_fromfd.return_value self.mock_selector_for_client_read(mock_selector) self.init_and_make_pac_file_request(pac_file) self.protocol_handler.run_once() self.assertEqual( self.protocol_handler.request.state, - httpParserStates.COMPLETE) + httpParserStates.COMPLETE, + ) with open(pac_file, 'rb') as f: - self._conn.send.called_once_with(build_http_response( - 200, reason=b'OK', headers={ - b'Content-Type': b'application/x-ns-proxy-autoconfig', - b'Connection': b'close' - }, body=f.read() - )) + self._conn.send.called_once_with( + build_http_response( + 200, reason=b'OK', headers={ + b'Content-Type': b'application/x-ns-proxy-autoconfig', + b'Connection': b'close', + }, body=f.read(), + ), + ) @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') def test_pac_file_served_from_buffer( - self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None: + self, mock_fromfd: mock.Mock, mock_selector: mock.Mock, + ) -> None: self._conn = mock_fromfd.return_value self.mock_selector_for_client_read(mock_selector) pac_file_content = b'function FindProxyForURL(url, host) { return "PROXY localhost:8899; DIRECT"; }' @@ -77,25 +84,34 @@ def test_pac_file_served_from_buffer( self.protocol_handler.run_once() self.assertEqual( self.protocol_handler.request.state, - httpParserStates.COMPLETE) - self._conn.send.called_once_with(build_http_response( - 200, reason=b'OK', headers={ - b'Content-Type': b'application/x-ns-proxy-autoconfig', - b'Connection': b'close' - }, body=pac_file_content - )) + httpParserStates.COMPLETE, + ) + self._conn.send.called_once_with( + build_http_response( + 200, reason=b'OK', headers={ + b'Content-Type': b'application/x-ns-proxy-autoconfig', + b'Connection': b'close', + }, body=pac_file_content, + ), + ) @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') def test_default_web_server_returns_404( - self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None: + self, mock_fromfd: mock.Mock, mock_selector: mock.Mock, + ) -> None: self._conn = mock_fromfd.return_value - mock_selector.return_value.select.return_value = [( - selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, - events=selectors.EVENT_READ, - data=None), selectors.EVENT_READ), ] + mock_selector.return_value.select.return_value = [ + ( + selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + ), + ] flags = Proxy.initialize() flags.plugins = Proxy.load_plugins([ bytes_(PLUGIN_HTTP_PROXY), @@ -103,7 +119,8 @@ def test_default_web_server_returns_404( ]) self.protocol_handler = HttpProtocolHandler( TcpClientConnection(self._conn, self._addr), - flags=flags) + flags=flags, + ) self.protocol_handler.initialize() self._conn.recv.return_value = CRLF.join([ b'GET /hello HTTP/1.1', @@ -112,17 +129,22 @@ def test_default_web_server_returns_404( self.protocol_handler.run_once() self.assertEqual( self.protocol_handler.request.state, - httpParserStates.COMPLETE) + httpParserStates.COMPLETE, + ) self.assertEqual( self.protocol_handler.client.buffer[0], - HttpWebServerPlugin.DEFAULT_404_RESPONSE) + HttpWebServerPlugin.DEFAULT_404_RESPONSE, + ) - @unittest.skipIf(os.environ.get('GITHUB_ACTIONS', False), - 'Disabled on GitHub actions because this test is flaky on GitHub infrastructure.') + @unittest.skipIf( + os.environ.get('GITHUB_ACTIONS', False), + 'Disabled on GitHub actions because this test is flaky on GitHub infrastructure.', + ) @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') def test_static_web_server_serves( - self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None: + self, mock_fromfd: mock.Mock, mock_selector: mock.Mock, + ) -> None: # Setup a static directory static_server_dir = os.path.join(tempfile.gettempdir(), 'static') index_file_path = os.path.join(static_server_dir, 'index.html') @@ -133,23 +155,34 @@ def test_static_web_server_serves( self._conn = mock_fromfd.return_value self._conn.recv.return_value = build_http_request( - b'GET', b'/index.html') + b'GET', b'/index.html', + ) mock_selector.return_value.select.side_effect = [ - [(selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, - events=selectors.EVENT_READ, - data=None), selectors.EVENT_READ)], - [(selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, - events=selectors.EVENT_WRITE, - data=None), selectors.EVENT_WRITE)], ] + [( + selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + )], + [( + selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=selectors.EVENT_WRITE, + data=None, + ), + selectors.EVENT_WRITE, + )], + ] flags = Proxy.initialize( enable_static_server=True, - static_server_dir=static_server_dir) + static_server_dir=static_server_dir, + ) flags.plugins = Proxy.load_plugins([ bytes_(PLUGIN_HTTP_PROXY), bytes_(PLUGIN_WEB_SERVER), @@ -157,7 +190,8 @@ def test_static_web_server_serves( self.protocol_handler = HttpProtocolHandler( TcpClientConnection(self._conn, self._addr), - flags=flags) + flags=flags, + ) self.protocol_handler.initialize() self.protocol_handler.run_once() @@ -166,38 +200,51 @@ def test_static_web_server_serves( self.assertEqual(mock_selector.return_value.select.call_count, 2) self.assertEqual(self._conn.send.call_count, 1) encoded_html_file_content = gzip.compress(html_file_content) - self.assertEqual(self._conn.send.call_args[0][0], build_http_response( - 200, reason=b'OK', headers={ - b'Content-Type': b'text/html', - b'Cache-Control': b'max-age=86400', - b'Content-Encoding': b'gzip', - b'Connection': b'close', - b'Content-Length': bytes_(len(encoded_html_file_content)), - }, - body=encoded_html_file_content - )) + self.assertEqual( + self._conn.send.call_args[0][0], build_http_response( + 200, reason=b'OK', headers={ + b'Content-Type': b'text/html', + b'Cache-Control': b'max-age=86400', + b'Content-Encoding': b'gzip', + b'Connection': b'close', + b'Content-Length': bytes_(len(encoded_html_file_content)), + }, + body=encoded_html_file_content, + ), + ) @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') def test_static_web_server_serves_404( self, mock_fromfd: mock.Mock, - mock_selector: mock.Mock) -> None: + mock_selector: mock.Mock, + ) -> None: self._conn = mock_fromfd.return_value self._conn.recv.return_value = build_http_request( - b'GET', b'/not-found.html') + b'GET', b'/not-found.html', + ) mock_selector.return_value.select.side_effect = [ - [(selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, - events=selectors.EVENT_READ, - data=None), selectors.EVENT_READ)], - [(selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, - events=selectors.EVENT_WRITE, - data=None), selectors.EVENT_WRITE)], ] + [( + selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + )], + [( + selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=selectors.EVENT_WRITE, + data=None, + ), + selectors.EVENT_WRITE, + )], + ] flags = Proxy.initialize(enable_static_server=True) flags.plugins = Proxy.load_plugins([ @@ -207,7 +254,8 @@ def test_static_web_server_serves_404( self.protocol_handler = HttpProtocolHandler( TcpClientConnection(self._conn, self._addr), - flags=flags) + flags=flags, + ) self.protocol_handler.initialize() self.protocol_handler.run_once() @@ -215,19 +263,23 @@ def test_static_web_server_serves_404( self.assertEqual(mock_selector.return_value.select.call_count, 2) self.assertEqual(self._conn.send.call_count, 1) - self.assertEqual(self._conn.send.call_args[0][0], - HttpWebServerPlugin.DEFAULT_404_RESPONSE) + self.assertEqual( + self._conn.send.call_args[0][0], + HttpWebServerPlugin.DEFAULT_404_RESPONSE, + ) @mock.patch('socket.fromfd') def test_on_client_connection_called_on_teardown( - self, mock_fromfd: mock.Mock) -> None: + self, mock_fromfd: mock.Mock, + ) -> None: flags = Proxy.initialize() plugin = mock.MagicMock() flags.plugins = {b'HttpProtocolHandlerPlugin': [plugin]} self._conn = mock_fromfd.return_value self.protocol_handler = HttpProtocolHandler( TcpClientConnection(self._conn, self._addr), - flags=flags) + flags=flags, + ) self.protocol_handler.initialize() plugin.assert_called() with mock.patch.object(self.protocol_handler, 'run_once') as mock_run_once: @@ -245,7 +297,8 @@ def init_and_make_pac_file_request(self, pac_file: str) -> None: ]) self.protocol_handler = HttpProtocolHandler( TcpClientConnection(self._conn, self._addr), - flags=flags) + flags=flags, + ) self.protocol_handler.initialize() self._conn.recv.return_value = CRLF.join([ b'GET / HTTP/1.1', @@ -253,9 +306,14 @@ def init_and_make_pac_file_request(self, pac_file: str) -> None: ]) def mock_selector_for_client_read(self, mock_selector: mock.Mock) -> None: - mock_selector.return_value.select.return_value = [( - selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, - events=selectors.EVENT_READ, - data=None), selectors.EVENT_READ), ] + mock_selector.return_value.select.return_value = [ + ( + selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + ), + ] diff --git a/tests/http/test_websocket_client.py b/tests/http/test_websocket_client.py index faf18b2ac6..ef500c5b97 100644 --- a/tests/http/test_websocket_client.py +++ b/tests/http/test_websocket_client.py @@ -21,18 +21,21 @@ class TestWebsocketClient(unittest.TestCase): @mock.patch('proxy.http.websocket.client.socket.gethostbyname') @mock.patch('base64.b64encode') @mock.patch('proxy.http.websocket.client.new_socket_connection') - def test_handshake(self, mock_connect: mock.Mock, - mock_b64encode: mock.Mock, - mock_gethostbyname: mock.Mock) -> None: + def test_handshake( + self, mock_connect: mock.Mock, + mock_b64encode: mock.Mock, + mock_gethostbyname: mock.Mock, + ) -> None: key = b'MySecretKey' mock_b64encode.return_value = key mock_gethostbyname.return_value = '127.0.0.1' mock_connect.return_value.recv.return_value = \ build_websocket_handshake_response( - WebsocketFrame.key_to_accept(key)) + WebsocketFrame.key_to_accept(key), + ) client = WebsocketClient(b'localhost', DEFAULT_PORT) mock_connect.return_value.send.assert_not_called() client.handshake() mock_connect.return_value.send.assert_called_with( - build_websocket_handshake_request(key) + build_websocket_handshake_request(key), ) diff --git a/tests/plugin/test_http_proxy_plugins.py b/tests/plugin/test_http_proxy_plugins.py index cf6c351cf2..ab62147a8d 100644 --- a/tests/plugin/test_http_proxy_plugins.py +++ b/tests/plugin/test_http_proxy_plugins.py @@ -34,16 +34,22 @@ class TestHttpProxyPluginExamples(unittest.TestCase): @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') - def setUp(self, - mock_fromfd: mock.Mock, - mock_selector: mock.Mock) -> None: + def setUp( + self, + mock_fromfd: mock.Mock, + mock_selector: mock.Mock, + ) -> None: self.fileno = 10 self._addr = ('127.0.0.1', 54382) adblock_json_path = Path( - __file__).parent.parent.parent / "proxy" / "plugin" / "adblock.json" + __file__, + ).parent.parent.parent / "proxy" / "plugin" / "adblock.json" self.flags = Proxy.initialize( - input_args=["--filtered-url-regex-config", - str(adblock_json_path)]) + input_args=[ + "--filtered-url-regex-config", + str(adblock_json_path), + ], + ) self.plugin = mock.MagicMock() self.mock_fromfd = mock_fromfd @@ -58,12 +64,14 @@ def setUp(self, self._conn = mock_fromfd.return_value self.protocol_handler = HttpProtocolHandler( TcpClientConnection(self._conn, self._addr), - flags=self.flags) + flags=self.flags, + ) self.protocol_handler.initialize() @mock.patch('proxy.http.proxy.server.TcpServerConnection') def test_modify_post_data_plugin( - self, mock_server_conn: mock.Mock) -> None: + self, mock_server_conn: mock.Mock, + ) -> None: original = b'{"key": "value"}' modified = b'{"key": "modified"}' @@ -74,14 +82,19 @@ def test_modify_post_data_plugin( b'Content-Type': b'application/x-www-form-urlencoded', b'Content-Length': bytes_(len(original)), }, - body=original + body=original, ) self.mock_selector.return_value.select.side_effect = [ - [(selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, - events=selectors.EVENT_READ, - data=None), selectors.EVENT_READ)], ] + [( + selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + )], + ] self.protocol_handler.run_once() mock_server_conn.assert_called_with('httpbin.org', DEFAULT_HTTP_PORT) @@ -94,27 +107,34 @@ def test_modify_post_data_plugin( b'Content-Type': b'application/json', b'Via': b'1.1 %s' % PROXY_AGENT_HEADER_VALUE, }, - body=modified - ) + body=modified, + ), ) @mock.patch('proxy.http.proxy.server.TcpServerConnection') def test_proposed_rest_api_plugin( - self, mock_server_conn: mock.Mock) -> None: + self, mock_server_conn: mock.Mock, + ) -> None: path = b'/v1/users/' self._conn.recv.return_value = build_http_request( b'GET', b'http://%s%s' % ( - ProposedRestApiPlugin.API_SERVER, path), + ProposedRestApiPlugin.API_SERVER, path, + ), headers={ b'Host': ProposedRestApiPlugin.API_SERVER, - } + }, ) self.mock_selector.return_value.select.side_effect = [ - [(selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, - events=selectors.EVENT_READ, - data=None), selectors.EVENT_READ)], ] + [( + selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + )], + ] self.protocol_handler.run_once() mock_server_conn.assert_not_called() @@ -125,29 +145,39 @@ def test_proposed_rest_api_plugin( headers={b'Content-Type': b'application/json'}, body=bytes_( json.dumps( - ProposedRestApiPlugin.REST_API_SPEC[path])) - )) + ProposedRestApiPlugin.REST_API_SPEC[path], + ), + ), + ), + ) @mock.patch('proxy.http.proxy.server.TcpServerConnection') def test_redirect_to_custom_server_plugin( - self, mock_server_conn: mock.Mock) -> None: + self, mock_server_conn: mock.Mock, + ) -> None: request = build_http_request( b'GET', b'http://example.org/get', headers={ b'Host': b'example.org', - } + }, ) self._conn.recv.return_value = request self.mock_selector.return_value.select.side_effect = [ - [(selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, - events=selectors.EVENT_READ, - data=None), selectors.EVENT_READ)], ] + [( + selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + )], + ] self.protocol_handler.run_once() upstream = urlparse.urlsplit( - RedirectToCustomServerPlugin.UPSTREAM_SERVER) + RedirectToCustomServerPlugin.UPSTREAM_SERVER, + ) mock_server_conn.assert_called_with('localhost', 8899) mock_server_conn.return_value.queue.assert_called_with( build_http_request( @@ -155,26 +185,32 @@ def test_redirect_to_custom_server_plugin( headers={ b'Host': upstream.netloc, b'Via': b'1.1 %s' % PROXY_AGENT_HEADER_VALUE, - } - ) + }, + ), ) @mock.patch('proxy.http.proxy.server.TcpServerConnection') def test_filter_by_upstream_host_plugin( - self, mock_server_conn: mock.Mock) -> None: + self, mock_server_conn: mock.Mock, + ) -> None: request = build_http_request( b'GET', b'http://facebook.com/', headers={ b'Host': b'facebook.com', - } + }, ) self._conn.recv.return_value = request self.mock_selector.return_value.select.side_effect = [ - [(selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, - events=selectors.EVENT_READ, - data=None), selectors.EVENT_READ)], ] + [( + selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + )], + ] self.protocol_handler.run_once() mock_server_conn.assert_not_called() @@ -184,19 +220,20 @@ def test_filter_by_upstream_host_plugin( status_code=httpStatusCodes.I_AM_A_TEAPOT, reason=b'I\'m a tea pot', headers={ - b'Connection': b'close' + b'Connection': b'close', }, - ) + ), ) @mock.patch('proxy.http.proxy.server.TcpServerConnection') def test_man_in_the_middle_plugin( - self, mock_server_conn: mock.Mock) -> None: + self, mock_server_conn: mock.Mock, + ) -> None: request = build_http_request( b'GET', b'http://super.secure/', headers={ b'Host': b'super.secure', - } + }, ) self._conn.recv.return_value = request @@ -213,21 +250,34 @@ def closed() -> bool: type(server).closed = mock.PropertyMock(side_effect=closed) self.mock_selector.return_value.select.side_effect = [ - [(selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, - events=selectors.EVENT_READ, - data=None), selectors.EVENT_READ)], - [(selectors.SelectorKey( - fileobj=server.connection, - fd=server.connection.fileno, - events=selectors.EVENT_WRITE, - data=None), selectors.EVENT_WRITE)], - [(selectors.SelectorKey( - fileobj=server.connection, - fd=server.connection.fileno, - events=selectors.EVENT_READ, - data=None), selectors.EVENT_READ)], ] + [( + selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + )], + [( + selectors.SelectorKey( + fileobj=server.connection, + fd=server.connection.fileno, + events=selectors.EVENT_WRITE, + data=None, + ), + selectors.EVENT_WRITE, + )], + [( + selectors.SelectorKey( + fileobj=server.connection, + fd=server.connection.fileno, + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + )], + ] # Client read self.protocol_handler.run_once() @@ -238,8 +288,8 @@ def closed() -> bool: b'GET', b'/', headers={ b'Host': b'super.secure', - b'Via': b'1.1 %s' % PROXY_AGENT_HEADER_VALUE - } + b'Via': b'1.1 %s' % PROXY_AGENT_HEADER_VALUE, + }, ) server.queue.assert_called_once_with(queued_request) @@ -251,31 +301,39 @@ def closed() -> bool: server.recv.return_value = \ build_http_response( httpStatusCodes.OK, - reason=b'OK', body=b'Original Response From Upstream') + reason=b'OK', body=b'Original Response From Upstream', + ) self.protocol_handler.run_once() self.assertEqual( self.protocol_handler.client.buffer[0].tobytes(), build_http_response( httpStatusCodes.OK, - reason=b'OK', body=b'Hello from man in the middle') + reason=b'OK', body=b'Hello from man in the middle', + ), ) @mock.patch('proxy.http.proxy.server.TcpServerConnection') def test_filter_by_url_regex_plugin( - self, mock_server_conn: mock.Mock) -> None: + self, mock_server_conn: mock.Mock, + ) -> None: request = build_http_request( b'GET', b'http://www.facebook.com/tr/', headers={ b'Host': b'www.facebook.com', - } + }, ) self._conn.recv.return_value = request self.mock_selector.return_value.select.side_effect = [ - [(selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, - events=selectors.EVENT_READ, - data=None), selectors.EVENT_READ)], ] + [( + selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + )], + ] self.protocol_handler.run_once() self.assertEqual( @@ -284,5 +342,5 @@ def test_filter_by_url_regex_plugin( status_code=httpStatusCodes.NOT_FOUND, reason=b'Blocked', headers={b'Connection': b'close'}, - ) + ), ) diff --git a/tests/plugin/test_http_proxy_plugins_with_tls_interception.py b/tests/plugin/test_http_proxy_plugins_with_tls_interception.py index 39311b22aa..164edcfee9 100644 --- a/tests/plugin/test_http_proxy_plugins_with_tls_interception.py +++ b/tests/plugin/test_http_proxy_plugins_with_tls_interception.py @@ -38,15 +38,17 @@ class TestHttpProxyPluginExamplesWithTlsInterception(unittest.TestCase): @mock.patch('proxy.http.proxy.server.sign_csr') @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') - def setUp(self, - mock_fromfd: mock.Mock, - mock_selector: mock.Mock, - mock_sign_csr: mock.Mock, - mock_gen_csr: mock.Mock, - mock_gen_public_key: mock.Mock, - mock_server_conn: mock.Mock, - mock_ssl_context: mock.Mock, - mock_ssl_wrap: mock.Mock) -> None: + def setUp( + self, + mock_fromfd: mock.Mock, + mock_selector: mock.Mock, + mock_sign_csr: mock.Mock, + mock_gen_csr: mock.Mock, + mock_gen_public_key: mock.Mock, + mock_server_conn: mock.Mock, + mock_ssl_context: mock.Mock, + mock_ssl_wrap: mock.Mock, + ) -> None: self.mock_fromfd = mock_fromfd self.mock_selector = mock_selector self.mock_sign_csr = mock_sign_csr @@ -65,7 +67,8 @@ def setUp(self, self.flags = Proxy.initialize( ca_cert_file='ca-cert.pem', ca_key_file='ca-key.pem', - ca_signing_key_file='ca-signing-key.pem',) + ca_signing_key_file='ca-signing-key.pem', + ) self.plugin = mock.MagicMock() plugin = get_plugin_by_test_name(self._testMethodName) @@ -77,7 +80,8 @@ def setUp(self, self._conn = mock.MagicMock(spec=socket.socket) mock_fromfd.return_value = self._conn self.protocol_handler = HttpProtocolHandler( - TcpClientConnection(self._conn, self._addr), flags=self.flags) + TcpClientConnection(self._conn, self._addr), flags=self.flags, + ) self.protocol_handler.initialize() self.server = self.mock_server_conn.return_value @@ -105,30 +109,49 @@ def mock_connection() -> Any: self.server.has_buffer.side_effect = has_buffer type(self.server).closed = mock.PropertyMock(side_effect=closed) type( - self.server).connection = mock.PropertyMock( - side_effect=mock_connection) + self.server, + ).connection = mock.PropertyMock( + side_effect=mock_connection, + ) self.mock_selector.return_value.select.side_effect = [ - [(selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, - events=selectors.EVENT_READ, - data=None), selectors.EVENT_READ)], - [(selectors.SelectorKey( - fileobj=self.client_ssl_connection, - fd=self.client_ssl_connection.fileno, - events=selectors.EVENT_READ, - data=None), selectors.EVENT_READ)], - [(selectors.SelectorKey( - fileobj=self.server_ssl_connection, - fd=self.server_ssl_connection.fileno, - events=selectors.EVENT_WRITE, - data=None), selectors.EVENT_WRITE)], - [(selectors.SelectorKey( - fileobj=self.server_ssl_connection, - fd=self.server_ssl_connection.fileno, - events=selectors.EVENT_READ, - data=None), selectors.EVENT_READ)], ] + [( + selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + )], + [( + selectors.SelectorKey( + fileobj=self.client_ssl_connection, + fd=self.client_ssl_connection.fileno, + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + )], + [( + selectors.SelectorKey( + fileobj=self.server_ssl_connection, + fd=self.server_ssl_connection.fileno, + events=selectors.EVENT_WRITE, + data=None, + ), + selectors.EVENT_WRITE, + )], + [( + selectors.SelectorKey( + fileobj=self.server_ssl_connection, + fd=self.server_ssl_connection.fileno, + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + )], + ] # Connect def send(raw: bytes) -> int: @@ -136,7 +159,7 @@ def send(raw: bytes) -> int: self._conn.send.side_effect = send self._conn.recv.return_value = build_http_request( - httpMethods.CONNECT, b'uni.corn:443' + httpMethods.CONNECT, b'uni.corn:443', ) self.protocol_handler.run_once() @@ -148,10 +171,11 @@ def send(raw: bytes) -> int: self.server.connect.assert_called() self.assertEqual( self.protocol_handler.client.connection, - self.client_ssl_connection) + self.client_ssl_connection, + ) self.assertEqual(self.server.connection, self.server_ssl_connection) self._conn.send.assert_called_with( - HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT + HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT, ) self.assertFalse(self.protocol_handler.client.has_buffer()) @@ -165,7 +189,7 @@ def test_modify_post_data_plugin(self) -> None: b'Content-Type': b'application/x-www-form-urlencoded', b'Content-Length': bytes_(len(original)), }, - body=original + body=original, ) self.protocol_handler.run_once() self.server.queue.assert_called_with( @@ -176,8 +200,8 @@ def test_modify_post_data_plugin(self) -> None: b'Content-Length': bytes_(len(modified)), b'Content-Type': b'application/json', }, - body=modified - ) + body=modified, + ), ) def test_man_in_the_middle_plugin(self) -> None: @@ -185,7 +209,7 @@ def test_man_in_the_middle_plugin(self) -> None: b'GET', b'/', headers={ b'Host': b'uni.corn', - } + }, ) self.client_ssl_connection.recv.return_value = request @@ -201,11 +225,13 @@ def test_man_in_the_middle_plugin(self) -> None: self.server.recv.return_value = \ build_http_response( httpStatusCodes.OK, - reason=b'OK', body=b'Original Response From Upstream') + reason=b'OK', body=b'Original Response From Upstream', + ) self.protocol_handler.run_once() self.assertEqual( self.protocol_handler.client.buffer[0].tobytes(), build_http_response( httpStatusCodes.OK, - reason=b'OK', body=b'Hello from man in the middle') + reason=b'OK', body=b'Hello from man in the middle', + ), ) diff --git a/tests/test_main.py b/tests/test_main.py index 598a9ed45e..ec99edbc58 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -78,7 +78,8 @@ def test_init_with_no_arguments( mock_acceptor_pool: mock.Mock, mock_event_manager: mock.Mock, mock_initialize: mock.Mock, - mock_sleep: mock.Mock) -> None: + mock_sleep: mock.Mock, + ) -> None: mock_sleep.side_effect = KeyboardInterrupt() input_args: List[str] = [] @@ -88,7 +89,7 @@ def test_init_with_no_arguments( mock_acceptor_pool.assert_called_with( flags=mock_initialize.return_value, work_klass=HttpProtocolHandler, - event_queue=None + event_queue=None, ) mock_acceptor_pool.return_value.setup.assert_called() mock_acceptor_pool.return_value.shutdown.assert_called() @@ -109,7 +110,8 @@ def test_pid_file_is_written_and_removed( mock_open: mock.Mock, mock_exists: mock.Mock, mock_remove: mock.Mock, - mock_sleep: mock.Mock) -> None: + mock_sleep: mock.Mock, + ) -> None: pid_file = get_temp_file('pid') mock_sleep.side_effect = KeyboardInterrupt() mock_args = mock_parse_args.return_value @@ -123,7 +125,8 @@ def test_pid_file_is_written_and_removed( mock_event_manager.assert_not_called() mock_open.assert_called_with(pid_file, 'wb') mock_open.return_value.__enter__.return_value.write.assert_called_with( - bytes_(os.getpid())) + bytes_(os.getpid()), + ) mock_exists.assert_called_with(pid_file) mock_remove.assert_called_with(pid_file) @@ -134,7 +137,8 @@ def test_basic_auth( self, mock_acceptor_pool: mock.Mock, mock_event_manager: mock.Mock, - mock_sleep: mock.Mock) -> None: + mock_sleep: mock.Mock, + ) -> None: mock_sleep.side_effect = KeyboardInterrupt() input_args = ['--basic-auth', 'user:pass'] @@ -145,7 +149,8 @@ def test_basic_auth( mock_acceptor_pool.assert_called_once() self.assertEqual( flgs.auth_code, - b'dXNlcjpwYXNz') + b'dXNlcjpwYXNz', + ) @mock.patch('time.sleep') @mock.patch('builtins.print') @@ -158,7 +163,8 @@ def test_main_py3_runs( mock_acceptor_pool: mock.Mock, mock_event_manager: mock.Mock, mock_print: mock.Mock, - mock_sleep: mock.Mock) -> None: + mock_sleep: mock.Mock, + ) -> None: mock_sleep.side_effect = KeyboardInterrupt() input_args = ['--basic-auth', 'user:pass'] @@ -178,7 +184,8 @@ def test_main_py3_runs( def test_main_py2_exit( self, mock_is_py3: mock.Mock, - mock_print: mock.Mock) -> None: + mock_print: mock.Mock, + ) -> None: mock_is_py3.return_value = False with self.assertRaises(SystemExit) as e: main(num_workers=1) @@ -189,7 +196,8 @@ def test_main_py2_exit( @mock.patch('builtins.print') def test_main_version( self, - mock_print: mock.Mock) -> None: + mock_print: mock.Mock, + ) -> None: with self.assertRaises(SystemExit) as e: main(['--version']) mock_print.assert_called_with(__version__) diff --git a/tests/test_set_open_file_limit.py b/tests/test_set_open_file_limit.py index 3bae38cfe2..1163564bed 100644 --- a/tests/test_set_open_file_limit.py +++ b/tests/test_set_open_file_limit.py @@ -20,7 +20,8 @@ @unittest.skipIf( os.name == 'nt', - 'Open file limit tests disabled for Windows') + 'Open file limit tests disabled for Windows', +) class TestSetOpenFileLimit(unittest.TestCase): @mock.patch('resource.getrlimit', return_value=(128, 1024)) @@ -28,7 +29,8 @@ class TestSetOpenFileLimit(unittest.TestCase): def test_set_open_file_limit( self, mock_set_rlimit: mock.Mock, - mock_get_rlimit: mock.Mock) -> None: + mock_get_rlimit: mock.Mock, + ) -> None: Proxy.set_open_file_limit(256) mock_get_rlimit.assert_called_with(resource.RLIMIT_NOFILE) mock_set_rlimit.assert_called_with(resource.RLIMIT_NOFILE, (256, 1024)) @@ -38,7 +40,8 @@ def test_set_open_file_limit( def test_set_open_file_limit_not_called( self, mock_set_rlimit: mock.Mock, - mock_get_rlimit: mock.Mock) -> None: + mock_get_rlimit: mock.Mock, + ) -> None: Proxy.set_open_file_limit(256) mock_get_rlimit.assert_called_with(resource.RLIMIT_NOFILE) mock_set_rlimit.assert_not_called() @@ -48,7 +51,8 @@ def test_set_open_file_limit_not_called( def test_set_open_file_limit_not_called_coz_upper_bound_check( self, mock_set_rlimit: mock.Mock, - mock_get_rlimit: mock.Mock) -> None: + mock_get_rlimit: mock.Mock, + ) -> None: Proxy.set_open_file_limit(1024) mock_get_rlimit.assert_called_with(resource.RLIMIT_NOFILE) mock_set_rlimit.assert_not_called() diff --git a/tests/testing/test_embed.py b/tests/testing/test_embed.py index e69c5ffe56..87609c86ce 100644 --- a/tests/testing/test_embed.py +++ b/tests/testing/test_embed.py @@ -22,7 +22,8 @@ @unittest.skipIf( - os.name == 'nt', 'Disabled for Windows due to weird permission issues.') + os.name == 'nt', 'Disabled for Windows due to weird permission issues.', +) class TestProxyPyEmbedded(TestCase): """This test case is a demonstration of proxy.TestCase and also serves as integration test suite for proxy.py.""" @@ -39,7 +40,8 @@ def test_with_proxy(self) -> None: httpMethods.GET, b'http://localhost:%d/' % self.PROXY_PORT, headers={ b'Host': b'localhost:%d' % self.PROXY_PORT, - }) + }, + ), ) response = conn.recv(DEFAULT_CLIENT_RECVBUF_SIZE) self.assertEqual( @@ -48,9 +50,9 @@ def test_with_proxy(self) -> None: httpStatusCodes.NOT_FOUND, reason=b'NOT FOUND', headers={ b'Server': PROXY_AGENT_HEADER_VALUE, - b'Connection': b'close' - } - ) + b'Connection': b'close', + }, + ), ) def test_proxy_vcr(self) -> None: @@ -77,7 +79,8 @@ def make_http_request_using_proxy(self) -> None: with self.assertRaises(urllib.error.HTTPError): r: http.client.HTTPResponse = opener.open( 'http://localhost:%d/' % - self.PROXY_PORT, timeout=10) + self.PROXY_PORT, timeout=10, + ) self.assertEqual(r.status, 404) self.assertEqual(r.headers.get('server'), PROXY_AGENT_HEADER_VALUE) self.assertEqual(r.headers.get('connection'), b'close') diff --git a/tests/testing/test_test_case.py b/tests/testing/test_test_case.py index c1dafa07bb..0d4ee869ce 100644 --- a/tests/testing/test_test_case.py +++ b/tests/testing/test_test_case.py @@ -19,4 +19,5 @@ class TestTestCase(unittest.TestCase): def test_wait_for_server(self) -> None: with self.assertRaises(TimeoutError): proxy.TestCase.wait_for_server( - get_available_port(), wait_for_seconds=1) + get_available_port(), wait_for_seconds=1, + ) diff --git a/version-check.py b/version-check.py index 27e2b30a67..a957a40e02 100644 --- a/version-check.py +++ b/version-check.py @@ -21,10 +21,14 @@ # Version is also hardcoded in README.md flags section readme_version_cmd = 'cat README.md | grep "proxy.py v" | tail -2 | head -1 | cut -d " " -f 2 | cut -c2-' readme_version_output = subprocess.check_output( - ['bash', '-c', readme_version_cmd]) + ['bash', '-c', readme_version_cmd], +) readme_version = readme_version_output.decode().strip() if readme_version != lib_version: - print('Version mismatch found. {0} (readme) vs {1} (lib).'.format( - readme_version, lib_version)) + print( + 'Version mismatch found. {0} (readme) vs {1} (lib).'.format( + readme_version, lib_version, + ), + ) sys.exit(1)