diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml new file mode 100644 index 0000000000..6309360fe6 --- /dev/null +++ b/.github/workflows/python-package.yml @@ -0,0 +1,33 @@ +name: Proxy.py + +on: [push] + +jobs: + build: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, macOS-latest] + python: [3.6, 3.7] + max-parallel: 4 + fail-fast: false + steps: + - uses: actions/checkout@v1 + - name: Set up Python ${{ matrix.python }} + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements-testing.txt + - name: Lint Checker + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 proxy.py tests.py --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 proxy.py tests.py --count --exit-zero --max-line-length=127 --statistics + - name: Run Tests + run: | + pytest tests.py diff --git a/Dockerfile b/Dockerfile index c3ecff36be..b86a0e8b03 100644 --- a/Dockerfile +++ b/Dockerfile @@ -10,5 +10,4 @@ EXPOSE 8899/tcp WORKDIR /app ENTRYPOINT [ "./proxy.py" ] -CMD [ "--host=0.0.0.0", \ - "--port=8899" ] +CMD [ "--port=8899" ] diff --git a/Makefile b/Makefile index ebb8616af7..90890243ee 100644 --- a/Makefile +++ b/Makefile @@ -6,7 +6,7 @@ VERSION ?= v$(shell python proxy.py --version) LATEST_TAG := $(NS)/$(IMAGE_NAME):latest IMAGE_TAG := $(NS)/$(IMAGE_NAME):$(VERSION) -.PHONY: all clean test package test-release release coverage flake8 container run-container release-container +.PHONY: all clean test package test-release release coverage lint container run-container release-container all: clean test @@ -35,9 +35,11 @@ coverage: coverage3 html open htmlcov/index.html -flake8: +lint: flake8 --ignore=E501,W504 --builtins="unicode" proxy.py flake8 --ignore=E501,W504 tests.py + autopep8 --recursive --in-place --aggressive --aggressive proxy.py + autopep8 --recursive --in-place --aggressive --aggressive tests.py container: docker build -t $(LATEST_TAG) -t $(IMAGE_TAG) . diff --git a/README.md b/README.md index 138a27e059..71af63698d 100644 --- a/README.md +++ b/README.md @@ -1,28 +1,37 @@ [![Proxy.Py](ProxyPy.png)](https://github.com/abhinavsingh/proxy.py) -[![License](https://img.shields.io/badge/License-BSD%203--Clause-blue.svg)](https://opensource.org/licenses/BSD-3-Clause) [![alt text](https://travis-ci.org/abhinavsingh/proxy.py.svg?branch=develop "Build Status")](https://travis-ci.org/abhinavsingh/proxy.py/) [![Coverage Status](https://coveralls.io/repos/github/abhinavsingh/proxy.py/badge.svg?branch=develop)](https://coveralls.io/github/abhinavsingh/proxy.py?branch=develop) +[![License](https://img.shields.io/badge/License-BSD%203--Clause-blue.svg)](https://opensource.org/licenses/BSD-3-Clause) +[![PyPi Downloads](https://img.shields.io/pypi/dm/proxy.py.svg)](https://pypi.org/project/proxy.py/) +[![Build Status](https://travis-ci.org/abhinavsingh/proxy.py.svg?branch=develop)](https://travis-ci.org/abhinavsingh/proxy.py/) +[![Coverage](https://coveralls.io/repos/github/abhinavsingh/proxy.py/badge.svg?branch=develop)](https://coveralls.io/github/abhinavsingh/proxy.py?branch=develop) + +[![Twitter](https://img.shields.io/twitter/follow/imoracle?label=Follow%20on%20Twitter&style=social)](https://twitter.com/imoracle) Features -------- - Distributed as a single file module -- Optionally enable builtin Web Server -- Customize proxy and http routing via plugins - No external dependency other than standard Python library - Support for `http`, `https`, `http2` and `websockets` request proxy - Optimized for large file uploads and downloads - IPv4 and IPv6 support - Basic authentication support - Can serve a [PAC (Proxy Auto-configuration)](https://en.wikipedia.org/wiki/Proxy_auto-config) file +- Optionally enable builtin Web Server +- Customize proxy and http routing via [plugins](https://github.com/abhinavsingh/proxy.py/blob/develop/plugin_examples.py) Install ------- -To install proxy.py, simply: +#### Stable version $ pip install --upgrade proxy.py -Using docker: +#### Development version + + $ pip install git+https://github.com/abhinavsingh/proxy.py.git@develop + +#### Docker image $ docker run -it -p 8899:8899 --rm abhinavsingh/proxy.py diff --git a/plugin_examples.py b/plugin_examples.py index 56ae127562..f86593d1f1 100644 --- a/plugin_examples.py +++ b/plugin_examples.py @@ -13,6 +13,12 @@ def before_upstream_connection(self): # Redirect all non-https requests to inbuilt WebServer. self.request.url = urlparse.urlsplit(b'http://localhost:8899') + def on_upstream_connection(self): + pass + + def handle_upstream_response(self, raw): + return raw + class FilterByTargetDomainPlugin(proxy.HttpProxyBasePlugin): """Only accepts specific requests dropping all other requests.""" @@ -26,7 +32,13 @@ def before_upstream_connection(self): # are not consistent between CONNECT and non-CONNECT requests. if (self.request.method != b'CONNECT' and self.filtered_domain in self.request.url.hostname) or \ (self.request.method == b'CONNECT' and self.filtered_domain in self.request.url.path): - raise proxy.HttpRequestRejected(status_code=418, body='I\'m a tea pot') + raise proxy.HttpRequestRejected(status_code=418, body=b'I\'m a tea pot') + + def on_upstream_connection(self): + pass + + def handle_upstream_response(self, raw): + return raw class SaveHttpResponses(proxy.HttpProxyBasePlugin): @@ -37,3 +49,9 @@ def __init__(self, config, client, request): def handle_upstream_response(self, chunk): return chunk + + def before_upstream_connection(self): + pass + + def on_upstream_connection(self): + pass diff --git a/proxy.py b/proxy.py index 4cc236df8c..29ae3fcd57 100755 --- a/proxy.py +++ b/proxy.py @@ -14,16 +14,17 @@ import errno import importlib import inspect +import ipaddress import logging import multiprocessing import os -import queue import socket import sys import threading -import ipaddress +from abc import ABC, abstractmethod from collections import namedtuple -from typing import Dict, List, Tuple, Optional +from multiprocessing import connection +from typing import Dict, List, Tuple, Optional, Union from urllib import parse as urlparse import select @@ -117,14 +118,17 @@ def recv(self, buffer_size: int = DEFAULT_BUFFER_SIZE) -> Optional[bytes]: try: data: bytes = self.conn.recv(buffer_size) if len(data) > 0: - logger.debug('received %d bytes from %s' % (len(data), self.what)) + logger.debug( + 'received %d bytes from %s' % + (len(data), self.what)) return data except socket.error as e: if e.errno == errno.ECONNRESET: logger.debug('%r' % e) else: logger.exception( - 'Exception while receiving from connection %s %r with reason %r' % (self.what, self.conn, e)) + 'Exception while receiving from connection %s %r with reason %r' % + (self.what, self.conn, e)) return None def close(self) -> bool: @@ -154,7 +158,7 @@ class TcpServerConnection(TcpConnection): """Establishes connection to destination server.""" def __init__(self, host: str, port: int): - super(TcpServerConnection, self).__init__(b'server') + super().__init__(b'server') self.addr: Tuple[str, int] = (host, int(port)) def __del__(self): @@ -165,9 +169,11 @@ def connect(self) -> None: try: ip = ipaddress.ip_address(text_(self.addr[0])) if ip.version == 4: - self.conn = socket.create_connection((self.addr[0], self.addr[1])) + self.conn = socket.create_connection( + (self.addr[0], self.addr[1])) else: - self.conn = socket.socket(socket.AF_INET6, socket.SOCK_STREAM, 0) + self.conn = socket.socket( + socket.AF_INET6, socket.SOCK_STREAM, 0) self.conn.connect((self.addr[0], self.addr[1], 0, 0)) except ValueError: # Not a valid IP address, most likely its a domain name. @@ -178,12 +184,12 @@ class TcpClientConnection(TcpConnection): """Accepted client connection.""" def __init__(self, conn: socket.socket, addr: Tuple[str, int]): - super(TcpClientConnection, self).__init__(b'client') + super().__init__(b'client') self.conn: socket.socket = conn self.addr: Tuple[str, int] = addr -class TcpServer: +class TcpServer(ABC): """TcpServer server implementation. Inheritor MUST implement `handle` method. It accepts an instance of `TcpClientConnection`. @@ -191,7 +197,12 @@ class TcpServer: down internal state. """ - def __init__(self, hostname=DEFAULT_IPV4_HOSTNAME, port=DEFAULT_PORT, backlog=DEFAULT_BACKLOG, ipv4=DEFAULT_IPV4): + def __init__( + self, + hostname=DEFAULT_IPV4_HOSTNAME, + port=DEFAULT_PORT, + backlog=DEFAULT_BACKLOG, + ipv4=DEFAULT_IPV4): self.port: int = port self.backlog: int = backlog self.ipv4: bool = ipv4 @@ -202,12 +213,15 @@ def __init__(self, hostname=DEFAULT_IPV4_HOSTNAME, port=DEFAULT_PORT, backlog=DE DEFAULT_IPV6_HOSTNAME] \ else DEFAULT_IPV4_HOSTNAME if self.ipv4 else DEFAULT_IPV6_HOSTNAME + @abstractmethod def setup(self) -> None: pass + @abstractmethod def handle(self, client: TcpClientConnection): raise NotImplementedError() + @abstractmethod def shutdown(self) -> None: pass @@ -245,15 +259,21 @@ class MultiCoreRequestDispatcher(TcpServer): client request. """ - def __init__(self, hostname=DEFAULT_IPV4_HOSTNAME, port=DEFAULT_PORT, backlog=DEFAULT_BACKLOG, - num_workers=DEFAULT_NUM_WORKERS, ipv4=DEFAULT_IPV4, config=None): - super(MultiCoreRequestDispatcher, self).__init__(hostname, port, backlog, ipv4) + def __init__( + self, + hostname=DEFAULT_IPV4_HOSTNAME, + port=DEFAULT_PORT, + backlog=DEFAULT_BACKLOG, + num_workers=DEFAULT_NUM_WORKERS, + ipv4=DEFAULT_IPV4, + config=None): + super().__init__(hostname, port, backlog, ipv4) self.num_workers: int = multiprocessing.cpu_count() if num_workers > 0: self.num_workers = num_workers self.workers: List[Worker] = [] - self.work_queues: List[multiprocessing.Queue] = [] + self.work_queues: List[multiprocessing.Pipe] = [] self.current_worker_id = 0 self.config: HttpProtocolConfig = config @@ -261,9 +281,9 @@ def __init__(self, hostname=DEFAULT_IPV4_HOSTNAME, port=DEFAULT_PORT, backlog=DE def setup(self): logger.info('Starting %d workers' % self.num_workers) for worker_id in range(self.num_workers): - work_queue = multiprocessing.Queue() + work_queue = multiprocessing.Pipe() - worker = Worker(work_queue, self.config) + worker = Worker(work_queue[1], self.config) worker.daemon = True worker.start() @@ -273,15 +293,18 @@ def setup(self): def handle(self, client: TcpClientConnection): # Dispatch in round robin fashion work_queue = self.work_queues[self.current_worker_id] - logging.debug('Dispatched client request to worker id %d', self.current_worker_id) + logging.debug( + 'Dispatched client request to worker id %d', + self.current_worker_id) self.current_worker_id += 1 self.current_worker_id %= self.num_workers - work_queue.put((Worker.operations.HTTP_PROTOCOL, client)) + work_queue[0].send((Worker.operations.HTTP_PROTOCOL, client)) def shutdown(self): logger.info('Shutting down %d workers' % self.num_workers) for work_queue in self.work_queues: - work_queue.put((Worker.operations.SHUTDOWN, None)) + work_queue[0].send((Worker.operations.SHUTDOWN, None)) + work_queue[0].close() for worker in self.workers: worker.join() @@ -298,24 +321,23 @@ class Worker(multiprocessing.Process): 'SHUTDOWN', ))(1, 2) - def __init__(self, work_queue, config=None): - super(Worker, self).__init__() - self.work_queue: multiprocessing.Queue = work_queue + def __init__(self, work_queue: connection.Connection, config=None): + super().__init__() + self.work_queue: connection.Connection = work_queue self.config: HttpProtocolConfig = config def run(self): while True: try: - op, payload = self.work_queue.get(True, 1) + op, payload = self.work_queue.recv() if op == Worker.operations.HTTP_PROTOCOL: proxy = HttpProtocolHandler(payload, config=self.config) proxy.setDaemon(True) proxy.start() elif op == Worker.operations.SHUTDOWN: + logging.debug('Worker shutting down....') + self.work_queue.close() break - except queue.Empty: - pass - # Safeguard against https://gist.github.com/abhinavsingh/b8d4266ff4f38b6057f9c50075e8cd75 except ConnectionRefusedError: pass except KeyboardInterrupt: @@ -335,7 +357,8 @@ def __init__(self): self.state = ChunkParser.states.WAITING_FOR_SIZE self.body: bytes = b'' # Parsed chunks self.chunk: bytes = b'' # Partial chunk received - self.size: int = None # Expected size of next following chunk + # Expected size of next following chunk + self.size: Optional[int] = None def parse(self, raw: bytes): more = True if len(raw) > 0 else False @@ -389,7 +412,9 @@ class HttpParser: ))(1, 2) def __init__(self, parser_type): - assert parser_type in (HttpParser.types.REQUEST_PARSER, HttpParser.types.RESPONSE_PARSER) + assert parser_type in ( + HttpParser.types.REQUEST_PARSER, + HttpParser.types.RESPONSE_PARSER) self.type: HttpParser.types = parser_type self.state: HttpParser.states = HttpParser.states.INITIALIZED @@ -432,7 +457,7 @@ def set_host_port(self): def is_chunked_encoded_response(self): return self.type == HttpParser.types.RESPONSE_PARSER and b'transfer-encoding' in self.headers and \ - self.headers[b'transfer-encoding'][1].lower() == b'chunked' + self.headers[b'transfer-encoding'][1].lower() == b'chunked' def parse(self, raw): self.bytes += raw @@ -448,10 +473,11 @@ def parse(self, raw): self.buffer = raw def process(self, raw): - if self.state in (HttpParser.states.HEADERS_COMPLETE, - HttpParser.states.RCVING_BODY, - HttpParser.states.COMPLETE) and \ - (self.method == b'POST' or self.type == HttpParser.types.RESPONSE_PARSER): + if self.state in ( + HttpParser.states.HEADERS_COMPLETE, + HttpParser.states.RCVING_BODY, + HttpParser.states.COMPLETE) and ( + self.method == b'POST' or self.type == HttpParser.types.RESPONSE_PARSER): if not self.body: self.body = b'' @@ -480,7 +506,9 @@ def process(self, raw): self.process_header(line) # When connect request is received without a following host header - # See `TestHttpParser.test_connect_request_without_host_header_request_parse` for details + # See + # `TestHttpParser.test_connect_request_without_host_header_request_parse` + # for details if self.state == HttpParser.states.LINE_RCVD and \ self.type == HttpParser.types.REQUEST_PARSER and \ self.method == b'CONNECT' and \ @@ -557,7 +585,8 @@ def build(self, disable_headers=None): for k in self.headers: if k.lower() not in disable_headers: - req += self.build_header(self.headers[k][0], self.headers[k][1]) + CRLF + req += self.build_header(self.headers[k] + [0], self.headers[k][1]) + CRLF req += CRLF if self.body: @@ -578,11 +607,11 @@ def split(raw): raw = raw[pos + len(CRLF):] return line, raw - ################################################################################### + ########################################################################## # HttpParser was originally written to parse the incoming raw Http requests. # Since request / response objects passed to HttpProtocolBasePlugin methods # are also HttpParser objects, methods below were added to simplify developer API. - #################################################################################### + ########################################################################## def has_upstream_server(self): """Host field SHOULD be None for incoming local WebServer requests.""" @@ -624,8 +653,11 @@ class HttpRequestRejected(HttpProtocolException): Connections can either be dropped/closed or optionally an HTTP status code can be returned.""" - def __init__(self, status_code: bytes = None, body: bytes = None): - super(HttpRequestRejected, self).__init__() + def __init__(self, + status_code: Union[bytes, + int] = None, + body: bytes = None): + super().__init__() self.status_code: bytes = status_code self.body: bytes = body @@ -650,9 +682,15 @@ class HttpProtocolConfig: This config class helps us avoid passing around bunch of key/value pairs across methods. """ - def __init__(self, auth_code=DEFAULT_BASIC_AUTH, server_recvbuf_size=DEFAULT_SERVER_RECVBUF_SIZE, - client_recvbuf_size=DEFAULT_CLIENT_RECVBUF_SIZE, pac_file=DEFAULT_PAC_FILE, - pac_file_url_path=DEFAULT_PAC_FILE_URL_PATH, plugins=None, disable_headers=None): + def __init__( + self, + auth_code=DEFAULT_BASIC_AUTH, + server_recvbuf_size=DEFAULT_SERVER_RECVBUF_SIZE, + client_recvbuf_size=DEFAULT_CLIENT_RECVBUF_SIZE, + pac_file=DEFAULT_PAC_FILE, + pac_file_url_path=DEFAULT_PAC_FILE_URL_PATH, + plugins=None, + disable_headers=None): self.auth_code = auth_code self.server_recvbuf_size = server_recvbuf_size self.client_recvbuf_size = client_recvbuf_size @@ -666,15 +704,20 @@ def __init__(self, auth_code=DEFAULT_BASIC_AUTH, server_recvbuf_size=DEFAULT_SER self.disable_headers = disable_headers -class HttpProtocolBasePlugin: +class HttpProtocolBasePlugin(ABC): """Base HttpProtocolHandler Plugin class. Implement various lifecycle event methods to customize behavior.""" - def __init__(self, config: HttpProtocolConfig, client: TcpClientConnection, request: HttpParser): + def __init__( + self, + config: HttpProtocolConfig, + client: TcpClientConnection, + request: HttpParser): self.config: HttpProtocolConfig = config self.client: TcpClientConnection = client self.request: HttpParser = request + super().__init__() def name(self) -> str: """A unique name for your plugin. @@ -683,31 +726,39 @@ def name(self) -> str: access a specific plugin by its name.""" return self.__class__.__name__ + @abstractmethod def get_descriptors(self) -> Tuple[List, List, List]: return [], [], [] + @abstractmethod def flush_to_descriptors(self, w) -> None: pass + @abstractmethod def read_from_descriptors(self, r) -> None: pass + @abstractmethod def on_client_data(self, raw: bytes) -> bytes: return raw + @abstractmethod def on_request_complete(self) -> None: """Called right after client request parser has reached COMPLETE state.""" pass + @abstractmethod def handle_response_chunk(self, chunk: bytes) -> bytes: """Handle data chunks as received from the server. Return optionally modified chunk to return back to client.""" return chunk + @abstractmethod def access_log(self) -> None: pass + @abstractmethod def on_client_connection_close(self) -> None: pass @@ -732,7 +783,8 @@ def response(self, _request: HttpParser) -> bytes: return self.RESPONSE_PKT def __str__(self) -> str: - return '' % (self.host, self.port, self.reason) + return '' % ( + self.host, self.port, self.reason) class ProxyAuthenticationFailed(HttpProtocolException): @@ -752,12 +804,16 @@ def response(self, _request: HttpParser) -> bytes: return self.RESPONSE_PKT -class HttpProxyBasePlugin: +class HttpProxyBasePlugin(ABC): """Base HttpProxyPlugin Plugin class. Implement various lifecycle event methods to customize behavior.""" - def __init__(self, config: HttpProtocolConfig, client: TcpClientConnection, request: HttpParser): + def __init__( + self, + config: HttpProtocolConfig, + client: TcpClientConnection, + request: HttpParser): self.config = config self.client = client self.request = request @@ -769,16 +825,19 @@ def name(self) -> str: access a specific plugin by its name.""" return self.__class__.__name__ + @abstractmethod def before_upstream_connection(self): """Handler called just before Proxy upstream connection is established. Raise HttpRequestRejected to drop the connection.""" pass + @abstractmethod def on_upstream_connection(self): """Handler called right after upstream connection has been established.""" pass + @abstractmethod def handle_upstream_response(self, raw): """Handled called right after reading response from upstream server and before queuing that response to client. @@ -795,8 +854,12 @@ class HttpProxyPlugin(HttpProtocolBasePlugin): CRLF ]) - def __init__(self, config: HttpProtocolConfig, client: TcpClientConnection, request: HttpParser): - super(HttpProxyPlugin, self).__init__(config, client, request) + def __init__( + self, + config: HttpProtocolConfig, + client: TcpClientConnection, + request: HttpParser): + super().__init__(config, client, request) self.server = None self.response = HttpParser(HttpParser.types.RESPONSE_PARSER) @@ -826,7 +889,8 @@ def flush_to_descriptors(self, w): try: self.server.flush() except BrokenPipeError: - logging.error('BrokenPipeError when flushing buffer for server') + logging.error( + 'BrokenPipeError when flushing buffer for server') return True def read_from_descriptors(self, r): @@ -859,10 +923,14 @@ def on_client_connection_close(self): if self.server: logger.debug( - 'Closed server connection with pending server buffer size %d bytes' % self.server.buffer_size()) + 'Closed server connection with pending server buffer size %d bytes' % + self.server.buffer_size()) if not self.server.closed: self.server.close() + def handle_response_chunk(self, chunk: bytes) -> bytes: + return chunk + def on_client_data(self, raw): if not self.request.has_upstream_server(): return raw @@ -890,14 +958,19 @@ def on_request_complete(self): # queue appropriate response for client # notifying about established connection if self.request.method == b'CONNECT': - self.client.queue(HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT) + self.client.queue( + HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT) # for general http requests, re-build request packet # and queue for the server with appropriate headers else: # remove args.disable_headers before dispatching to upstream - self.request.add_headers([(b'Via', b'1.1 proxy.py v%s' % version), (b'Connection', b'Close')]) - self.request.del_headers([b'proxy-authorization', b'proxy-connection', b'connection', b'keep-alive']) - self.server.queue(self.request.build(disable_headers=self.config.disable_headers)) + self.request.add_headers( + [(b'Via', b'1.1 proxy.py v%s' % version), (b'Connection', b'Close')]) + self.request.del_headers( + [b'proxy-authorization', b'proxy-connection', b'connection', b'keep-alive']) + self.server.queue( + self.request.build( + disable_headers=self.config.disable_headers)) def access_log(self): if not self.request.has_upstream_server(): @@ -906,14 +979,22 @@ def access_log(self): host, port = self.server.addr if self.server else (None, None) if self.request.method == b'CONNECT': logger.info( - '%s:%s - %s %s:%s - %s bytes' % (self.client.addr[0], self.client.addr[1], - text_(self.request.method), text_(host), - text_(port), self.response.total_size)) + '%s:%s - %s %s:%s - %s bytes' % + (self.client.addr[0], + self.client.addr[1], + text_( + self.request.method), + text_(host), + text_(port), + self.response.total_size)) elif self.request.method: - logger.info('%s:%s - %s %s:%s%s - %s %s - %s bytes' % ( - self.client.addr[0], self.client.addr[1], text_(self.request.method), text_(host), port, - text_(self.request.build_url()), text_(self.response.code), text_(self.response.reason), - self.response.total_size)) + logger.info( + '%s:%s - %s %s:%s%s - %s %s - %s bytes' % + (self.client.addr[0], self.client.addr[1], text_( + self.request.method), text_(host), port, text_( + self.request.build_url()), text_( + self.response.code), text_( + self.response.reason), self.response.total_size)) def authenticate(self, headers): if self.config.auth_code: @@ -949,8 +1030,12 @@ class HttpWebServerPlugin(HttpProtocolBasePlugin): CRLF ]) - def __init__(self, config: HttpProtocolConfig, client: TcpClientConnection, request: HttpParser): - super(HttpWebServerPlugin, self).__init__(config, client, request) + def __init__( + self, + config: HttpProtocolConfig, + client: TcpClientConnection, + request: HttpParser): + super().__init__(config, client, request) if self.config.pac_file: try: with open(self.config.pac_file, 'rb') as f: @@ -980,8 +1065,29 @@ def on_request_complete(self): def access_log(self): if self.request.has_upstream_server(): return - logger.info('%s:%s - %s %s' % (self.client.addr[0], self.client.addr[1], - text_(self.request.method), text_(self.request.build_url()))) + logger.info( + '%s:%s - %s %s' % + (self.client.addr[0], self.client.addr[1], text_( + self.request.method), text_( + self.request.build_url()))) + + def flush_to_descriptors(self, w) -> None: + pass + + def read_from_descriptors(self, r) -> None: + pass + + def on_client_data(self, raw: bytes) -> bytes: + return raw + + def handle_response_chunk(self, chunk: bytes) -> bytes: + return chunk + + def on_client_connection_close(self) -> None: + pass + + def get_descriptors(self) -> Tuple[List, List, List]: + return [], [], [] class HttpProtocolHandler(threading.Thread): @@ -990,14 +1096,15 @@ class HttpProtocolHandler(threading.Thread): Accepts `Client` connection object and manages HttpProtocolBasePlugin invocations. """ - def __init__(self, client, config=None): - super(HttpProtocolHandler, self).__init__() - self.start_time = self.now() - self.last_activity = self.start_time + def __init__(self, client: TcpClientConnection, + config: HttpProtocolConfig = None): + super().__init__() + self.start_time: datetime.datetime = self.now() + self.last_activity: datetime.datetime = self.start_time - self.client = client - self.config = config if config else HttpProtocolConfig() - self.request = HttpParser(HttpParser.types.REQUEST_PARSER) + self.client: TcpClientConnection = client + self.config: HttpProtocolConfig = config if config else HttpProtocolConfig() + self.request: HttpParser = HttpParser(HttpParser.types.REQUEST_PARSER) self.plugins: Dict[str, HttpProtocolBasePlugin] = {} if 'HttpProtocolBasePlugin' in self.config.plugins: @@ -1006,7 +1113,7 @@ def __init__(self, client, config=None): self.plugins[instance.name()] = instance @staticmethod - def now(): + def now() -> datetime.datetime: return datetime.datetime.utcnow() def connection_inactive_for(self): @@ -1029,7 +1136,8 @@ def run_once(self): write_desc += plugin_write_desc err_desc += plugin_err_desc - readable, writable, errored = select.select(read_desc, write_desc, err_desc, 1) + readable, writable, errored = select.select( + read_desc, write_desc, err_desc, 1) # Flush buffer for ready to write sockets if self.client.conn in writable: @@ -1037,7 +1145,8 @@ def run_once(self): try: self.client.flush() except BrokenPipeError: - logging.error('BrokenPipeError when flushing buffer for client') + logging.error( + 'BrokenPipeError when flushing buffer for client') return True for plugin in self.plugins.values(): @@ -1066,9 +1175,10 @@ def run_once(self): if self.request.state == HttpParser.states.COMPLETE: # HttpProtocolBasePlugin.on_request_complete for plugin in self.plugins.values(): - # TODO: Cleanup by not returning True for teardown cases + # TODO: Cleanup by not returning True for teardown + # cases plugin_response = plugin.on_request_complete() - if type(plugin_response) is bool: + if isinstance(plugin_response, bool): return True # ProxyAuthenticationFailed, ProxyConnectionFailed, HttpRequestRejected except HttpProtocolException as e: @@ -1089,8 +1199,9 @@ def run_once(self): # Teardown if client buffer is empty and connection is inactive if self.client.buffer_size() == 0: if self.is_connection_inactive(): - logger.debug('Client buffer is empty and maximum inactivity has reached ' - 'between client and server connection, tearing down...') + logger.debug( + 'Client buffer is empty and maximum inactivity has reached ' + 'between client and server connection, tearing down...') return True def run(self): @@ -1103,19 +1214,25 @@ def run(self): except KeyboardInterrupt: pass except Exception as e: - logger.exception('Exception while handling connection %r with reason %r' % (self.client.conn, e)) + logger.exception( + 'Exception while handling connection %r with reason %r' % + (self.client.conn, e)) finally: for plugin in self.plugins.values(): plugin.access_log() self.client.close() - logger.debug('Closed client connection with pending ' - 'client buffer size %d bytes' % self.client.buffer_size()) + logger.debug( + 'Closed client connection with pending ' + 'client buffer size %d bytes' % + self.client.buffer_size()) for plugin in self.plugins.values(): plugin.on_client_connection_close() - logger.debug('Closed proxy for connection %r ' - 'at address %r' % (self.client.conn, self.client.addr)) + logger.debug( + 'Closed proxy for connection %r ' + 'at address %r' % + (self.client.conn, self.client.addr)) def is_py3() -> bool: @@ -1126,10 +1243,14 @@ def is_py3() -> bool: def set_open_file_limit(soft_limit): """Configure open file description soft limit on supported OS.""" if os.name != 'nt': # resource module not available on Windows OS - curr_soft_limit, curr_hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE) + curr_soft_limit, curr_hard_limit = resource.getrlimit( + resource.RLIMIT_NOFILE) if curr_soft_limit < soft_limit < curr_hard_limit: - resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, curr_hard_limit)) - logger.debug('Open file descriptor soft limit set to %d' % soft_limit) + resource.setrlimit( + resource.RLIMIT_NOFILE, (soft_limit, curr_hard_limit)) + logger.debug( + 'Open file descriptor soft limit set to %d' % + soft_limit) def load_plugins(plugins: str) -> Dict[str, List]: @@ -1146,13 +1267,16 @@ def load_plugins(plugins: str) -> Dict[str, List]: module_name, klass_name = plugin.rsplit('.', 1) module = importlib.import_module(module_name) klass = getattr(module, klass_name) - base_klass = inspect.getmro(klass)[::-1][1:][0] + base_klass = inspect.getmro(klass)[::-1][2:][0] p[base_klass.__name__].append(klass) logging.info('Loaded plugin %s', klass) return p -def setup_logger(log_file=DEFAULT_LOG_FILE, log_level=DEFAULT_LOG_LEVEL, log_format=DEFAULT_LOG_FORMAT): +def setup_logger( + log_file=DEFAULT_LOG_FILE, + log_level=DEFAULT_LOG_LEVEL, + log_format=DEFAULT_LOG_FORMAT): ll = getattr( logging, {'D': 'DEBUG', @@ -1161,7 +1285,11 @@ def setup_logger(log_file=DEFAULT_LOG_FILE, log_level=DEFAULT_LOG_LEVEL, log_for 'E': 'ERROR', 'C': 'CRITICAL'}[log_level.upper()[0]]) if log_file: - logging.basicConfig(filename=log_file, filemode='a', level=ll, format=log_format) + logging.basicConfig( + filename=log_file, + filemode='a', + level=ll, + format=log_format) else: logging.basicConfig(level=ll, format=log_format) @@ -1173,59 +1301,104 @@ def init_parser() -> argparse.ArgumentParser: epilog='Proxy.py not working? Report at: %s/issues/new' % __homepage__ ) # Argument names are ordered alphabetically. - parser.add_argument('--backlog', type=int, default=DEFAULT_BACKLOG, - help='Default: 100. Maximum number of pending connections to proxy server') - parser.add_argument('--basic-auth', type=str, default=DEFAULT_BASIC_AUTH, - help='Default: No authentication. Specify colon separated user:password ' - 'to enable basic authentication.') - parser.add_argument('--client-recvbuf-size', type=int, default=DEFAULT_CLIENT_RECVBUF_SIZE, - help='Default: 1 MB. Maximum amount of data received from the ' - 'client in a single recv() operation. Bump this ' - 'value for faster uploads at the expense of ' - 'increased RAM.') - parser.add_argument('--disable-headers', type=str, default=COMMA.join(DEFAULT_DISABLE_HEADERS), - help='Default: None. Comma separated list of headers to remove before ' - 'dispatching client request to upstream server.') - parser.add_argument('--disable-http-proxy', action='store_true', default=DEFAULT_DISABLE_HTTP_PROXY, - help='Default: False. Whether to disable proxy.HttpProxyPlugin.') + parser.add_argument( + '--backlog', + type=int, + default=DEFAULT_BACKLOG, + help='Default: 100. Maximum number of pending connections to proxy server') + parser.add_argument( + '--basic-auth', + type=str, + default=DEFAULT_BASIC_AUTH, + help='Default: No authentication. Specify colon separated user:password ' + 'to enable basic authentication.') + parser.add_argument( + '--client-recvbuf-size', + type=int, + default=DEFAULT_CLIENT_RECVBUF_SIZE, + help='Default: 1 MB. Maximum amount of data received from the ' + 'client in a single recv() operation. Bump this ' + 'value for faster uploads at the expense of ' + 'increased RAM.') + parser.add_argument( + '--disable-headers', + type=str, + default=COMMA.join(DEFAULT_DISABLE_HEADERS), + help='Default: None. Comma separated list of headers to remove before ' + 'dispatching client request to upstream server.') + parser.add_argument( + '--disable-http-proxy', + action='store_true', + default=DEFAULT_DISABLE_HTTP_PROXY, + help='Default: False. Whether to disable proxy.HttpProxyPlugin.') parser.add_argument('--hostname', type=str, default=DEFAULT_IPV4_HOSTNAME, help='Default: 127.0.0.1. Server IP address.') parser.add_argument('--ipv4', action='store_true', default=DEFAULT_IPV4, help='Whether to listen on IPv4 address. ' 'By default server only listens on IPv6.') - parser.add_argument('--enable-web-server', action='store_true', default=DEFAULT_ENABLE_WEB_SERVER, - help='Default: False. Whether to enable proxy.HttpWebServerPlugin.') - parser.add_argument('--log-level', type=str, default=DEFAULT_LOG_LEVEL, - help='Valid options: DEBUG, INFO (default), WARNING, ERROR, CRITICAL. ' - 'Both upper and lowercase values are allowed. ' - 'You may also simply use the leading character e.g. --log-level d') + parser.add_argument( + '--enable-web-server', + action='store_true', + default=DEFAULT_ENABLE_WEB_SERVER, + help='Default: False. Whether to enable proxy.HttpWebServerPlugin.') + parser.add_argument( + '--log-level', + type=str, + default=DEFAULT_LOG_LEVEL, + help='Valid options: DEBUG, INFO (default), WARNING, ERROR, CRITICAL. ' + 'Both upper and lowercase values are allowed. ' + 'You may also simply use the leading character e.g. --log-level d') parser.add_argument('--log-file', type=str, default=DEFAULT_LOG_FILE, help='Default: sys.stdout. Log file destination.') parser.add_argument('--log-format', type=str, default=DEFAULT_LOG_FORMAT, help='Log format for Python logger.') parser.add_argument('--num-workers', type=int, default=DEFAULT_NUM_WORKERS, help='Defaults to number of CPU cores.') - parser.add_argument('--open-file-limit', type=int, default=DEFAULT_OPEN_FILE_LIMIT, - help='Default: 1024. Maximum number of files (TCP connections) ' - 'that proxy.py can open concurrently.') - parser.add_argument('--pac-file', type=str, default=DEFAULT_PAC_FILE, - help='A file (Proxy Auto Configuration) or string to serve when ' - 'the server receives a direct file request. ' - 'Using this option enables proxy.HttpWebServerPlugin.') - parser.add_argument('--pac-file-url-path', type=str, default=DEFAULT_PAC_FILE_URL_PATH, - help='Default: %s. Web server path to serve the PAC file.' % text_(DEFAULT_PAC_FILE_URL_PATH)) - parser.add_argument('--pid-file', type=str, default=DEFAULT_PID_FILE, - help='Default: None. Save parent process ID to a file.') - parser.add_argument('--plugins', type=str, default=DEFAULT_PLUGINS, help='Comma separated plugins') + parser.add_argument( + '--open-file-limit', + type=int, + default=DEFAULT_OPEN_FILE_LIMIT, + help='Default: 1024. Maximum number of files (TCP connections) ' + 'that proxy.py can open concurrently.') + parser.add_argument( + '--pac-file', + type=str, + default=DEFAULT_PAC_FILE, + help='A file (Proxy Auto Configuration) or string to serve when ' + 'the server receives a direct file request. ' + 'Using this option enables proxy.HttpWebServerPlugin.') + parser.add_argument( + '--pac-file-url-path', + type=str, + default=DEFAULT_PAC_FILE_URL_PATH, + help='Default: %s. Web server path to serve the PAC file.' % + text_(DEFAULT_PAC_FILE_URL_PATH)) + parser.add_argument( + '--pid-file', + type=str, + default=DEFAULT_PID_FILE, + help='Default: None. Save parent process ID to a file.') + parser.add_argument( + '--plugins', + type=str, + default=DEFAULT_PLUGINS, + help='Comma separated plugins') parser.add_argument('--port', type=int, default=DEFAULT_PORT, help='Default: 8899. Server port.') - parser.add_argument('--server-recvbuf-size', type=int, default=DEFAULT_SERVER_RECVBUF_SIZE, - help='Default: 1 MB. Maximum amount of data received from the ' - 'server in a single recv() operation. Bump this ' - 'value for faster downloads at the expense of ' - 'increased RAM.') - parser.add_argument('--version', '-v', action='store_true', default=DEFAULT_VERSION, - help='Prints proxy.py version.') + parser.add_argument( + '--server-recvbuf-size', + type=int, + default=DEFAULT_SERVER_RECVBUF_SIZE, + help='Default: 1 MB. Maximum amount of data received from the ' + 'server in a single recv() operation. Bump this ' + 'value for faster downloads at the expense of ' + 'increased RAM.') + parser.add_argument( + '--version', + '-v', + action='store_true', + default=DEFAULT_VERSION, + help='Prints proxy.py version.') return parser @@ -1255,13 +1428,14 @@ def main(args) -> None: if args.basic_auth: auth_code = b'Basic %s' % base64.b64encode(bytes_(args.basic_auth)) - config = HttpProtocolConfig(auth_code=auth_code, - server_recvbuf_size=args.server_recvbuf_size, - client_recvbuf_size=args.client_recvbuf_size, - pac_file=args.pac_file, - pac_file_url_path=args.pac_file_url_path, - disable_headers=[header.lower() for header in args.disable_headers.split(COMMA) if - header.strip() != '']) + config = HttpProtocolConfig( + auth_code=auth_code, + server_recvbuf_size=args.server_recvbuf_size, + client_recvbuf_size=args.client_recvbuf_size, + pac_file=args.pac_file, + pac_file_url_path=args.pac_file_url_path, + disable_headers=[ + header.lower() for header in args.disable_headers.split(COMMA) if header.strip() != '']) if config.pac_file is not None: args.enable_web_server = True diff --git a/requirements-testing.txt b/requirements-testing.txt index 92d29063fd..a1812c4c76 100644 --- a/requirements-testing.txt +++ b/requirements-testing.txt @@ -2,3 +2,5 @@ python-coveralls==2.9.3 coverage==4.5.2 flake8==3.7.8 twine==1.12.1 +pytest==5.1.2 +autopep8==1.4.4 diff --git a/tests.py b/tests.py index d165480800..a109e9b328 100644 --- a/tests.py +++ b/tests.py @@ -26,8 +26,9 @@ if os.name != 'nt': import resource -logging.basicConfig(level=logging.DEBUG, - format='%(asctime)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s') +logging.basicConfig( + level=logging.DEBUG, + format='%(asctime)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s') def get_available_port(): @@ -47,7 +48,8 @@ def testHandlesIOError(self): with mock.patch('proxy.logger') as mock_logger: self.conn.recv() mock_logger.exception.assert_called() - logging.info(mock_logger.exception.call_args[0][0].startswith('Exception while receiving from connection')) + logging.info(mock_logger.exception.call_args[0][0].startswith( + 'Exception while receiving from connection')) def testHandlesConnReset(self): self.conn = proxy.TcpConnection(proxy.TcpConnection.types.CLIENT) @@ -83,13 +85,15 @@ def testNoOpIfAlreadyClosed(self): def testTcpServerClosesConnOnGC(self, mock_create_connection): conn = mock.MagicMock() mock_create_connection.return_value = conn - self.conn = proxy.TcpServerConnection(proxy.DEFAULT_IPV4_HOSTNAME, proxy.DEFAULT_PORT) + self.conn = proxy.TcpServerConnection( + proxy.DEFAULT_IPV4_HOSTNAME, proxy.DEFAULT_PORT) self.conn.connect() del self.conn conn.close.assert_called() -@unittest.skipIf(os.getenv('TESTING_ON_TRAVIS', 0), 'Opening sockets not allowed on Travis') +@unittest.skipIf(os.getenv('TESTING_ON_TRAVIS', 0), + 'Opening sockets not allowed on Travis') class TestTcpServer(unittest.TestCase): ipv4_port = None ipv6_port = None @@ -106,13 +110,20 @@ def handle(self, client): client.conn.sendall(b'WORLD') client.close() + def setup(self) -> None: + pass + + def shutdown(self) -> None: + pass + @classmethod def setUpClass(cls): cls.ipv4_port = get_available_port() cls.ipv6_port = get_available_port() - cls.ipv4_server = TestTcpServer._TestTcpServer(port=cls.ipv4_port, ipv4=True) - cls.ipv6_server = TestTcpServer._TestTcpServer(hostname=proxy.DEFAULT_IPV6_HOSTNAME, port=cls.ipv6_port, - ipv4=False) + cls.ipv4_server = TestTcpServer._TestTcpServer( + port=cls.ipv4_port, ipv4=True) + cls.ipv6_server = TestTcpServer._TestTcpServer( + hostname=proxy.DEFAULT_IPV6_HOSTNAME, port=cls.ipv6_port, ipv4=False) cls.ipv4_thread = Thread(target=cls.ipv4_server.run) cls.ipv6_thread = Thread(target=cls.ipv6_server.run) cls.ipv4_thread.setDaemon(True) @@ -129,9 +140,13 @@ def baseTestCase(self, ipv4=True): while True: sock = None try: - sock = socket.socket(socket.AF_INET if ipv4 else socket.AF_INET6, socket.SOCK_STREAM, 0) - sock.connect((proxy.DEFAULT_IPV4_HOSTNAME if ipv4 else proxy.DEFAULT_IPV6_HOSTNAME, - self.ipv4_port if ipv4 else self.ipv6_port)) + sock = socket.socket( + socket.AF_INET if ipv4 else socket.AF_INET6, + socket.SOCK_STREAM, + 0) + sock.connect( + (proxy.DEFAULT_IPV4_HOSTNAME if ipv4 else proxy.DEFAULT_IPV6_HOSTNAME, + self.ipv4_port if ipv4 else self.ipv6_port)) sock.sendall(b'HELLO') data = sock.recv(proxy.DEFAULT_BUFFER_SIZE) self.assertEqual(data, b'WORLD') @@ -158,7 +173,8 @@ def setDaemon(self, _val): pass def start(self): - self.client.conn.sendall(proxy.CRLF.join([b'HTTP/1.1 200 OK', proxy.CRLF])) + self.client.conn.sendall(proxy.CRLF.join( + [b'HTTP/1.1 200 OK', proxy.CRLF])) self.client.conn.close() @@ -166,39 +182,45 @@ def mock_tcp_proxy_side_effect(client, **kwargs): return MockHttpProxy(client, **kwargs) -@unittest.skipIf(os.getenv('TESTING_ON_TRAVIS', 0), 'Opening sockets not allowed on Travis') +@unittest.skipIf(os.getenv('TESTING_ON_TRAVIS', 0), + 'Opening sockets not allowed on Travis') class TestMultiCoreRequestDispatcher(unittest.TestCase): tcp_port = None tcp_server = None tcp_thread = None - @mock.patch.object(proxy, 'HttpProtocolHandler', side_effect=mock_tcp_proxy_side_effect) + @mock.patch.object( + proxy, + 'HttpProtocolHandler', + side_effect=mock_tcp_proxy_side_effect) def testHttpProxyConnection(self, mock_tcp_proxy): try: self.tcp_port = get_available_port() - self.tcp_server = proxy.MultiCoreRequestDispatcher(hostname=proxy.DEFAULT_IPV4_HOSTNAME, port=self.tcp_port, - ipv4=True, num_workers=1) + self.tcp_server = proxy.MultiCoreRequestDispatcher( + hostname=proxy.DEFAULT_IPV4_HOSTNAME, + port=self.tcp_port, + ipv4=True, + num_workers=1) self.tcp_thread = Thread(target=self.tcp_server.run) self.tcp_thread.setDaemon(True) self.tcp_thread.start() while True: - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) - sock.connect((proxy.DEFAULT_IPV4_HOSTNAME, self.tcp_port)) - sock.send(proxy.CRLF.join([ - b'GET http://httpbin.org/get HTTP/1.1', - b'Host: httpbin.org', - proxy.CRLF - ])) - data = sock.recv(proxy.DEFAULT_BUFFER_SIZE) - self.assertEqual(data, proxy.CRLF.join([b'HTTP/1.1 200 OK', proxy.CRLF])) - self.tcp_server.shutdown() # explicit early call worker shutdown to avoid resource leak warnings - break - except ConnectionRefusedError: - time.sleep(0.1) - finally: - sock.close() + with socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) as sock: + try: + sock.connect( + (proxy.DEFAULT_IPV4_HOSTNAME, self.tcp_port)) + sock.send(proxy.CRLF.join([ + b'GET http://httpbin.org/get HTTP/1.1', + b'Host: httpbin.org', + proxy.CRLF + ])) + data = sock.recv(proxy.DEFAULT_BUFFER_SIZE) + self.assertEqual(data, proxy.CRLF.join( + [b'HTTP/1.1 200 OK', proxy.CRLF])) + break + except ConnectionRefusedError: + time.sleep(0.1) finally: self.tcp_server.stop() self.tcp_thread.join() @@ -231,32 +253,44 @@ def test_chunk_parse_issue_27(self): self.assertEqual(self.parser.chunk, b'3') self.assertEqual(self.parser.size, None) self.assertEqual(self.parser.body, b'') - self.assertEqual(self.parser.state, proxy.ChunkParser.states.WAITING_FOR_SIZE) + self.assertEqual( + self.parser.state, + proxy.ChunkParser.states.WAITING_FOR_SIZE) self.parser.parse(b'\r\n') self.assertEqual(self.parser.chunk, b'') self.assertEqual(self.parser.size, 3) self.assertEqual(self.parser.body, b'') - self.assertEqual(self.parser.state, proxy.ChunkParser.states.WAITING_FOR_DATA) + self.assertEqual( + self.parser.state, + proxy.ChunkParser.states.WAITING_FOR_DATA) self.parser.parse(b'abc') self.assertEqual(self.parser.chunk, b'') self.assertEqual(self.parser.size, None) self.assertEqual(self.parser.body, b'abc') - self.assertEqual(self.parser.state, proxy.ChunkParser.states.WAITING_FOR_SIZE) + self.assertEqual( + self.parser.state, + proxy.ChunkParser.states.WAITING_FOR_SIZE) self.parser.parse(b'\r\n') self.assertEqual(self.parser.chunk, b'') self.assertEqual(self.parser.size, None) self.assertEqual(self.parser.body, b'abc') - self.assertEqual(self.parser.state, proxy.ChunkParser.states.WAITING_FOR_SIZE) + self.assertEqual( + self.parser.state, + proxy.ChunkParser.states.WAITING_FOR_SIZE) self.parser.parse(b'4\r\n') self.assertEqual(self.parser.chunk, b'') self.assertEqual(self.parser.size, 4) self.assertEqual(self.parser.body, b'abc') - self.assertEqual(self.parser.state, proxy.ChunkParser.states.WAITING_FOR_DATA) + self.assertEqual( + self.parser.state, + proxy.ChunkParser.states.WAITING_FOR_DATA) self.parser.parse(b'defg\r\n0') self.assertEqual(self.parser.chunk, b'0') self.assertEqual(self.parser.size, None) self.assertEqual(self.parser.body, b'abcdefg') - self.assertEqual(self.parser.state, proxy.ChunkParser.states.WAITING_FOR_SIZE) + self.assertEqual( + self.parser.state, + proxy.ChunkParser.states.WAITING_FOR_SIZE) self.parser.parse(b'\r\n\r\n') self.assertEqual(self.parser.chunk, b'') self.assertEqual(self.parser.size, None) @@ -270,15 +304,21 @@ def setUp(self): self.parser = proxy.HttpParser(proxy.HttpParser.types.REQUEST_PARSER) def test_build_header(self): - self.assertEqual(proxy.HttpParser.build_header(b'key', b'value'), b'key: value') + self.assertEqual( + proxy.HttpParser.build_header( + b'key', b'value'), b'key: value') def test_split(self): - self.assertEqual(proxy.HttpParser.split(b'CONNECT python.org:443 HTTP/1.0\r\n\r\n'), - (b'CONNECT python.org:443 HTTP/1.0', b'\r\n')) + self.assertEqual( + proxy.HttpParser.split(b'CONNECT python.org:443 HTTP/1.0\r\n\r\n'), + (b'CONNECT python.org:443 HTTP/1.0', + b'\r\n')) def test_split_false_line(self): - self.assertEqual(proxy.HttpParser.split(b'CONNECT python.org:443 HTTP/1.0'), - (False, b'CONNECT python.org:443 HTTP/1.0')) + self.assertEqual( + proxy.HttpParser.split(b'CONNECT python.org:443 HTTP/1.0'), + (False, + b'CONNECT python.org:443 HTTP/1.0')) def test_get_full_parse(self): raw = proxy.CRLF.join([ @@ -286,7 +326,8 @@ def test_get_full_parse(self): b'Host: %s', proxy.CRLF ]) - pkt = raw % (b'https://example.com/path/dir/?a=b&c=d#p=q', b'example.com') + pkt = raw % (b'https://example.com/path/dir/?a=b&c=d#p=q', + b'example.com') self.parser.parse(pkt) self.assertEqual(self.parser.total_size, len(pkt)) self.assertEqual(self.parser.build_url(), b'/path/dir/?a=b&c=d#p=q') @@ -295,10 +336,15 @@ def test_get_full_parse(self): self.assertEqual(self.parser.url.port, None) self.assertEqual(self.parser.version, b'HTTP/1.1') self.assertEqual(self.parser.state, proxy.HttpParser.states.COMPLETE) - self.assertDictContainsSubset({b'host': (b'Host', b'example.com')}, self.parser.headers) + self.assertDictContainsSubset( + {b'host': (b'Host', b'example.com')}, self.parser.headers) self.parser.del_headers([b'host']) self.parser.add_headers([(b'Host', b'example.com')]) - self.assertEqual(raw % (b'/path/dir/?a=b&c=d#p=q', b'example.com'), self.parser.build()) + self.assertEqual( + raw % + (b'/path/dir/?a=b&c=d#p=q', + b'example.com'), + self.parser.build()) def test_build_url_none(self): self.assertEqual(self.parser.build_url(), b'/None') @@ -307,9 +353,10 @@ def test_line_rcvd_to_rcving_headers_state_change(self): pkt = b'GET http://localhost HTTP/1.1' self.parser.parse(pkt) self.assertEqual(self.parser.total_size, len(pkt)) - self.assert_state_change_with_crlf(proxy.HttpParser.states.INITIALIZED, - proxy.HttpParser.states.LINE_RCVD, - proxy.HttpParser.states.RCVING_HEADERS) + self.assert_state_change_with_crlf( + proxy.HttpParser.states.INITIALIZED, + proxy.HttpParser.states.LINE_RCVD, + proxy.HttpParser.states.RCVING_HEADERS) def test_get_partial_parse1(self): pkt = proxy.CRLF.join([ @@ -320,7 +367,9 @@ def test_get_partial_parse1(self): self.assertEqual(self.parser.method, None) self.assertEqual(self.parser.url, None) self.assertEqual(self.parser.version, None) - self.assertEqual(self.parser.state, proxy.HttpParser.states.INITIALIZED) + self.assertEqual( + self.parser.state, + proxy.HttpParser.states.INITIALIZED) self.parser.parse(proxy.CRLF) self.assertEqual(self.parser.total_size, len(pkt) + len(proxy.CRLF)) @@ -332,14 +381,17 @@ def test_get_partial_parse1(self): host_hdr = b'Host: localhost:8080' self.parser.parse(host_hdr) - self.assertEqual(self.parser.total_size, len(pkt) + len(proxy.CRLF) + len(host_hdr)) + self.assertEqual(self.parser.total_size, + len(pkt) + len(proxy.CRLF) + len(host_hdr)) self.assertDictEqual(self.parser.headers, dict()) self.assertEqual(self.parser.buffer, b'Host: localhost:8080') self.assertEqual(self.parser.state, proxy.HttpParser.states.LINE_RCVD) self.parser.parse(proxy.CRLF * 2) - self.assertEqual(self.parser.total_size, len(pkt) + (3 * len(proxy.CRLF)) + len(host_hdr)) - self.assertDictContainsSubset({b'host': (b'Host', b'localhost:8080')}, self.parser.headers) + self.assertEqual(self.parser.total_size, len(pkt) + + (3 * len(proxy.CRLF)) + len(host_hdr)) + self.assertDictContainsSubset( + {b'host': (b'Host', b'localhost:8080')}, self.parser.headers) self.assertEqual(self.parser.state, proxy.HttpParser.states.COMPLETE) def test_get_partial_parse2(self): @@ -355,14 +407,20 @@ def test_get_partial_parse2(self): self.assertEqual(self.parser.state, proxy.HttpParser.states.LINE_RCVD) self.parser.parse(b'localhost:8080' + proxy.CRLF) - self.assertDictContainsSubset({b'host': (b'Host', b'localhost:8080')}, self.parser.headers) + self.assertDictContainsSubset( + {b'host': (b'Host', b'localhost:8080')}, self.parser.headers) self.assertEqual(self.parser.buffer, b'') - self.assertEqual(self.parser.state, proxy.HttpParser.states.RCVING_HEADERS) + self.assertEqual( + self.parser.state, + proxy.HttpParser.states.RCVING_HEADERS) self.parser.parse(b'Content-Type: text/plain' + proxy.CRLF) self.assertEqual(self.parser.buffer, b'') - self.assertDictContainsSubset({b'content-type': (b'Content-Type', b'text/plain')}, self.parser.headers) - self.assertEqual(self.parser.state, proxy.HttpParser.states.RCVING_HEADERS) + self.assertDictContainsSubset( + {b'content-type': (b'Content-Type', b'text/plain')}, self.parser.headers) + self.assertEqual( + self.parser.state, + proxy.HttpParser.states.RCVING_HEADERS) self.parser.parse(proxy.CRLF) self.assertEqual(self.parser.state, proxy.HttpParser.states.COMPLETE) @@ -380,16 +438,19 @@ def test_post_full_parse(self): self.assertEqual(self.parser.url.hostname, b'localhost') self.assertEqual(self.parser.url.port, None) self.assertEqual(self.parser.version, b'HTTP/1.1') - self.assertDictContainsSubset({b'content-type': (b'Content-Type', b'application/x-www-form-urlencoded')}, - self.parser.headers) - self.assertDictContainsSubset({b'content-length': (b'Content-Length', b'7')}, self.parser.headers) + self.assertDictContainsSubset( + {b'content-type': (b'Content-Type', b'application/x-www-form-urlencoded')}, self.parser.headers) + self.assertDictContainsSubset( + {b'content-length': (b'Content-Length', b'7')}, self.parser.headers) self.assertEqual(self.parser.body, b'a=b&c=d') self.assertEqual(self.parser.buffer, b'') self.assertEqual(self.parser.state, proxy.HttpParser.states.COMPLETE) self.assertEqual(len(self.parser.build()), len(raw % b'/')) - def assert_state_change_with_crlf(self, initial_state: proxy.HttpParser.states, - next_state: proxy.HttpParser.states, final_state: proxy.HttpParser.states): + def assert_state_change_with_crlf(self, + initial_state: proxy.HttpParser.states, + next_state: proxy.HttpParser.states, + final_state: proxy.HttpParser.states): self.assertEqual(self.parser.state, initial_state) self.parser.parse(proxy.CRLF) self.assertEqual(self.parser.state, next_state) @@ -407,12 +468,15 @@ def test_post_partial_parse(self): self.assertEqual(self.parser.url.hostname, b'localhost') self.assertEqual(self.parser.url.port, None) self.assertEqual(self.parser.version, b'HTTP/1.1') - self.assert_state_change_with_crlf(proxy.HttpParser.states.RCVING_HEADERS, - proxy.HttpParser.states.RCVING_HEADERS, - proxy.HttpParser.states.HEADERS_COMPLETE) + self.assert_state_change_with_crlf( + proxy.HttpParser.states.RCVING_HEADERS, + proxy.HttpParser.states.RCVING_HEADERS, + proxy.HttpParser.states.HEADERS_COMPLETE) self.parser.parse(b'a=b') - self.assertEqual(self.parser.state, proxy.HttpParser.states.RCVING_BODY) + self.assertEqual( + self.parser.state, + proxy.HttpParser.states.RCVING_BODY) self.assertEqual(self.parser.body, b'a=b') self.assertEqual(self.parser.buffer, b'') @@ -475,7 +539,9 @@ def test_response_parse_without_content_length(self): b'Date: Thu, 13 Dec 2018 16:24:09 GMT', proxy.CRLF ])) - self.assertEqual(self.parser.state, proxy.HttpParser.states.HEADERS_COMPLETE) + self.assertEqual( + self.parser.state, + proxy.HttpParser.states.HEADERS_COMPLETE) def test_response_parse(self): self.parser.type = proxy.HttpParser.types.RESPONSE_PARSER @@ -498,11 +564,13 @@ def test_response_parse(self): self.assertEqual(self.parser.code, b'301') self.assertEqual(self.parser.reason, b'Moved Permanently') self.assertEqual(self.parser.version, b'HTTP/1.1') - self.assertEqual(self.parser.body, - b'\n' + - b'301 Moved\n

301 Moved

\nThe document has moved\n' + - b'here.\r\n\r\n') - self.assertDictContainsSubset({b'content-length': (b'Content-Length', b'219')}, self.parser.headers) + self.assertEqual( + self.parser.body, + b'\n' + + b'301 Moved\n

301 Moved

\nThe document has moved\n' + + b'here.\r\n\r\n') + self.assertDictContainsSubset( + {b'content-length': (b'Content-Length', b'219')}, self.parser.headers) self.assertEqual(self.parser.state, proxy.HttpParser.states.COMPLETE) def test_response_partial_parse(self): @@ -519,14 +587,21 @@ def test_response_partial_parse(self): b'X-XSS-Protection: 1; mode=block\r\n', b'X-Frame-Options: SAMEORIGIN\r\n' ])) - self.assertDictContainsSubset({b'x-frame-options': (b'X-Frame-Options', b'SAMEORIGIN')}, self.parser.headers) - self.assertEqual(self.parser.state, proxy.HttpParser.states.RCVING_HEADERS) + self.assertDictContainsSubset( + {b'x-frame-options': (b'X-Frame-Options', b'SAMEORIGIN')}, self.parser.headers) + self.assertEqual( + self.parser.state, + proxy.HttpParser.states.RCVING_HEADERS) self.parser.parse(b'\r\n') - self.assertEqual(self.parser.state, proxy.HttpParser.states.HEADERS_COMPLETE) + self.assertEqual( + self.parser.state, + proxy.HttpParser.states.HEADERS_COMPLETE) self.parser.parse( b'\n' + b'301 Moved') - self.assertEqual(self.parser.state, proxy.HttpParser.states.RCVING_BODY) + self.assertEqual( + self.parser.state, + proxy.HttpParser.states.RCVING_BODY) self.parser.parse( b'\n

301 Moved

\nThe document has moved\n' + b'here.\r\n\r\n') @@ -585,7 +660,8 @@ class HTTPRequestHandler(BaseHTTPRequestHandler): def do_GET(self): self.send_response(200) - # TODO(abhinavsingh): Proxy should work just fine even without content-length header + # TODO(abhinavsingh): Proxy should work just fine even without + # content-length header self.send_header('content-length', 2) self.end_headers() self.wfile.write(b'OK') @@ -600,12 +676,14 @@ class TestHttpProtocolHandler(unittest.TestCase): @classmethod def setUpClass(cls): cls.http_server_port = get_available_port() - cls.http_server = HTTPServer(('127.0.0.1', cls.http_server_port), HTTPRequestHandler) + cls.http_server = HTTPServer( + ('127.0.0.1', cls.http_server_port), HTTPRequestHandler) cls.http_server_thread = Thread(target=cls.http_server.serve_forever) cls.http_server_thread.setDaemon(True) cls.http_server_thread.start() cls.config = proxy.HttpProtocolConfig() - cls.config.plugins = proxy.load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin') + cls.config.plugins = proxy.load_plugins( + 'proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin') @classmethod def tearDownClass(cls): @@ -616,20 +694,29 @@ def tearDownClass(cls): def setUp(self): self._conn = MockTcpConnection() self._addr = ('127.0.0.1', 54382) - self.proxy = proxy.HttpProtocolHandler(proxy.TcpClientConnection(self._conn, self._addr), config=self.config) + self.proxy = proxy.HttpProtocolHandler( + proxy.TcpClientConnection( + self._conn, self._addr), config=self.config) @mock.patch('select.select') @mock.patch('proxy.TcpServerConnection') def test_http_get(self, mock_server_connection, mock_select): server = mock_server_connection.return_value server.connect.return_value = True - mock_select.side_effect = [([self._conn], [], []), ([self._conn], [], []), ([], [server.conn], [])] + mock_select.side_effect = [ + ([self._conn], [], []), ([self._conn], [], []), ([], [server.conn], [])] # Send request line - self.proxy.client.conn.queue((b'GET http://localhost:%d HTTP/1.1' % self.http_server_port) + proxy.CRLF) + self.proxy.client.conn.queue( + (b'GET http://localhost:%d HTTP/1.1' % + self.http_server_port) + proxy.CRLF) self.proxy.run_once() - self.assertEqual(self.proxy.request.state, proxy.HttpParser.states.LINE_RCVD) - self.assertNotEqual(self.proxy.request.state, proxy.HttpParser.states.COMPLETE) + self.assertEqual( + self.proxy.request.state, + proxy.HttpParser.states.LINE_RCVD) + self.assertNotEqual( + self.proxy.request.state, + proxy.HttpParser.states.COMPLETE) # Send headers and blank line, thus completing HTTP request self.proxy.client.conn.queue(proxy.CRLF.join([ @@ -646,7 +733,9 @@ def test_http_get(self, mock_server_connection, mock_select): def assert_tunnel_response(self, mock_server_connection, server): self.proxy.run_once() self.assertFalse(self.proxy.plugins['HttpProxyPlugin'].server is None) - self.assertEqual(self.proxy.client.buffer, proxy.HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT) + self.assertEqual( + self.proxy.client.buffer, + proxy.HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT) mock_server_connection.assert_called_once() server.connect.assert_called_once() server.queue.assert_not_called() @@ -697,10 +786,14 @@ def test_proxy_connection_failed(self, mock_select): @mock.patch('select.select') def test_proxy_authentication_failed(self, mock_select): mock_select.return_value = ([self._conn], [], []) - config = proxy.HttpProtocolConfig(auth_code=b'Basic %s' % base64.b64encode(b'user:pass')) - config.plugins = proxy.load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin') - self.proxy = proxy.HttpProtocolHandler(proxy.TcpClientConnection(self._conn, self._addr), - config=config) + config = proxy.HttpProtocolConfig( + auth_code=b'Basic %s' % + base64.b64encode(b'user:pass')) + config.plugins = proxy.load_plugins( + 'proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin') + self.proxy = proxy.HttpProtocolHandler( + proxy.TcpClientConnection( + self._conn, self._addr), config=config) self.proxy.client.conn.queue(proxy.CRLF.join([ b'GET http://abhinavsingh.com HTTP/1.1', b'Host: abhinavsingh.com', @@ -711,23 +804,33 @@ def test_proxy_authentication_failed(self, mock_select): @mock.patch('select.select') @mock.patch('proxy.TcpServerConnection') - def test_authenticated_proxy_http_get(self, mock_server_connection, mock_select): + def test_authenticated_proxy_http_get( + self, mock_server_connection, mock_select): mock_select.return_value = ([self._conn], [], []) server = mock_server_connection.return_value server.connect.return_value = True client = proxy.TcpClientConnection(self._conn, self._addr) - config = proxy.HttpProtocolConfig(auth_code=b'Basic %s' % base64.b64encode(b'user:pass')) - config.plugins = proxy.load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin') + config = proxy.HttpProtocolConfig( + auth_code=b'Basic %s' % + base64.b64encode(b'user:pass')) + config.plugins = proxy.load_plugins( + 'proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin') self.proxy = proxy.HttpProtocolHandler(client, config=config) - self.proxy.client.conn.queue(b'GET http://localhost:%d HTTP/1.1' % self.http_server_port) + self.proxy.client.conn.queue( + b'GET http://localhost:%d HTTP/1.1' % + self.http_server_port) self.proxy.run_once() - self.assertEqual(self.proxy.request.state, proxy.HttpParser.states.INITIALIZED) + self.assertEqual( + self.proxy.request.state, + proxy.HttpParser.states.INITIALIZED) self.proxy.client.conn.queue(proxy.CRLF) self.proxy.run_once() - self.assertEqual(self.proxy.request.state, proxy.HttpParser.states.LINE_RCVD) + self.assertEqual( + self.proxy.request.state, + proxy.HttpParser.states.LINE_RCVD) self.proxy.client.conn.queue(proxy.CRLF.join([ b'User-Agent: proxy.py/%s' % proxy.version, @@ -741,15 +844,21 @@ def test_authenticated_proxy_http_get(self, mock_server_connection, mock_select) @mock.patch('select.select') @mock.patch('proxy.TcpServerConnection') - def test_authenticated_proxy_http_tunnel(self, mock_server_connection, mock_select): + def test_authenticated_proxy_http_tunnel( + self, mock_server_connection, mock_select): server = mock_server_connection.return_value server.connect.return_value = True - mock_select.side_effect = [([self._conn], [], []), ([self._conn], [], []), ([], [server.conn], [])] - - config = proxy.HttpProtocolConfig(auth_code=b'Basic %s' % base64.b64encode(b'user:pass')) - config.plugins = proxy.load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin') - self.proxy = proxy.HttpProtocolHandler(proxy.TcpClientConnection(self._conn, self._addr), - config=config) + mock_select.side_effect = [ + ([self._conn], [], []), ([self._conn], [], []), ([], [server.conn], [])] + + config = proxy.HttpProtocolConfig( + auth_code=b'Basic %s' % + base64.b64encode(b'user:pass')) + config.plugins = proxy.load_plugins( + 'proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin') + self.proxy = proxy.HttpProtocolHandler( + proxy.TcpClientConnection( + self._conn, self._addr), config=config) self.proxy.client.conn.queue(proxy.CRLF.join([ b'CONNECT localhost:%d HTTP/1.1' % self.http_server_port, b'Host: localhost:%d' % self.http_server_port, @@ -771,9 +880,14 @@ def test_pac_file_served_from_disk(self, mock_select): config = proxy.HttpProtocolConfig(pac_file='proxy.pac') self.init_and_make_pac_file_request(config) self.proxy.run_once() - self.assertEqual(self.proxy.request.state, proxy.HttpParser.states.COMPLETE) + self.assertEqual( + self.proxy.request.state, + proxy.HttpParser.states.COMPLETE) with open('proxy.pac', 'rb') as pac_file: - self.assertEqual(self._conn.received, proxy.HttpWebServerPlugin.PAC_FILE_RESPONSE_PREFIX + pac_file.read()) + self.assertEqual( + self._conn.received, + proxy.HttpWebServerPlugin.PAC_FILE_RESPONSE_PREFIX + + pac_file.read()) @mock.patch('select.select') def test_pac_file_served_from_buffer(self, mock_select): @@ -782,31 +896,43 @@ def test_pac_file_served_from_buffer(self, mock_select): config = proxy.HttpProtocolConfig(pac_file=pac_file_content) self.init_and_make_pac_file_request(config) self.proxy.run_once() - self.assertEqual(self.proxy.request.state, proxy.HttpParser.states.COMPLETE) - self.assertEqual(self._conn.received, proxy.HttpWebServerPlugin.PAC_FILE_RESPONSE_PREFIX + pac_file_content) + self.assertEqual( + self.proxy.request.state, + proxy.HttpParser.states.COMPLETE) + self.assertEqual( + self._conn.received, + proxy.HttpWebServerPlugin.PAC_FILE_RESPONSE_PREFIX + + pac_file_content) @mock.patch('select.select') def test_default_web_server_returns_404(self, mock_select): mock_select.return_value = [self._conn], [], [] config = proxy.HttpProtocolConfig() - config.plugins = proxy.load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin') - self.proxy = proxy.HttpProtocolHandler(proxy.TcpClientConnection(self._conn, self._addr), - config=config) + config.plugins = proxy.load_plugins( + 'proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin') + self.proxy = proxy.HttpProtocolHandler( + proxy.TcpClientConnection( + self._conn, self._addr), config=config) self.proxy.client.conn.queue(proxy.CRLF.join([ b'GET /hello HTTP/1.1', proxy.CRLF, proxy.CRLF ])) self.proxy.run_once() - self.assertEqual(self.proxy.request.state, proxy.HttpParser.states.COMPLETE) - self.assertEqual(self._conn.received, proxy.HttpWebServerPlugin.DEFAULT_404_RESPONSE) + self.assertEqual( + self.proxy.request.state, + proxy.HttpParser.states.COMPLETE) + self.assertEqual( + self._conn.received, + proxy.HttpWebServerPlugin.DEFAULT_404_RESPONSE) def test_on_client_connection_called_on_teardown(self): config = proxy.HttpProtocolConfig() plugin = mock.MagicMock() config.plugins = {'HttpProtocolBasePlugin': [plugin]} - self.proxy = proxy.HttpProtocolHandler(proxy.TcpClientConnection(self._conn, self._addr), - config=config) + self.proxy = proxy.HttpProtocolHandler( + proxy.TcpClientConnection( + self._conn, self._addr), config=config) plugin.assert_called() with mock.patch.object(self.proxy, 'run_once') as mock_run_once: mock_run_once.return_value = True @@ -816,9 +942,11 @@ def test_on_client_connection_called_on_teardown(self): plugin.return_value.on_client_connection_close.assert_called() def init_and_make_pac_file_request(self, config): - config.plugins = proxy.load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin') - self.proxy = proxy.HttpProtocolHandler(proxy.TcpClientConnection(self._conn, self._addr), - config=config) + config.plugins = proxy.load_plugins( + 'proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin') + self.proxy = proxy.HttpProtocolHandler( + proxy.TcpClientConnection( + self._conn, self._addr), config=config) self.proxy.client.conn.queue(proxy.CRLF.join([ b'GET / HTTP/1.1', proxy.CRLF, @@ -827,7 +955,9 @@ def init_and_make_pac_file_request(self, config): def assert_data_queued(self, mock_server_connection, server): self.proxy.run_once() - self.assertEqual(self.proxy.request.state, proxy.HttpParser.states.COMPLETE) + self.assertEqual( + self.proxy.request.state, + proxy.HttpParser.states.COMPLETE) mock_server_connection.assert_called_once() server.connect.assert_called_once() server.closed = False @@ -863,19 +993,19 @@ def assert_data_queued_to_server(self, server): class TestWorker(unittest.TestCase): def setUp(self): - self.queue = multiprocessing.Queue() - self.worker = proxy.Worker(self.queue) + self.pipe = multiprocessing.Pipe() + self.worker = proxy.Worker(self.pipe[1]) @mock.patch('proxy.HttpProtocolHandler') def test_shutdown_op(self, mock_http_proxy): - self.queue.put((proxy.Worker.operations.SHUTDOWN, None)) + self.pipe[0].send((proxy.Worker.operations.SHUTDOWN, None)) self.worker.run() # Worker should consume the prior shutdown operation self.assertFalse(mock_http_proxy.called) @mock.patch('proxy.HttpProtocolHandler') def test_spawns_http_proxy_threads(self, mock_http_proxy): - self.queue.put((proxy.Worker.operations.HTTP_PROTOCOL, None)) - self.queue.put((proxy.Worker.operations.SHUTDOWN, None)) + self.pipe[0].send((proxy.Worker.operations.HTTP_PROTOCOL, None)) + self.pipe[0].send((proxy.Worker.operations.SHUTDOWN, None)) self.worker.run() self.assertTrue(mock_http_proxy.called) @@ -898,7 +1028,9 @@ def test_status_code_response(self): ])) def test_body_response(self): - e = proxy.HttpRequestRejected(status_code=b'404 NOT FOUND', body=b'Nothing here') + e = proxy.HttpRequestRejected( + status_code=b'404 NOT FOUND', + body=b'Nothing here') self.assertEqual(e.response(self.request), proxy.CRLF.join([ b'HTTP/1.1 404 NOT FOUND', proxy.PROXY_AGENT_HEADER, @@ -913,7 +1045,11 @@ class TestMain(unittest.TestCase): @mock.patch('proxy.set_open_file_limit') @mock.patch('proxy.MultiCoreRequestDispatcher') @mock.patch('proxy.logging.basicConfig') - def test_log_file_setup(self, mock_config, mock_multicore_dispatcher, mock_set_open_file_limit): + def test_log_file_setup( + self, + mock_config, + mock_multicore_dispatcher, + mock_set_open_file_limit): log_file = '/tmp/proxy.log' proxy.main(['--log-file', log_file]) mock_set_open_file_limit.assert_called() @@ -931,23 +1067,35 @@ def test_log_file_setup(self, mock_config, mock_multicore_dispatcher, mock_set_o @mock.patch('builtins.open') @mock.patch('proxy.set_open_file_limit') @mock.patch('proxy.MultiCoreRequestDispatcher') - @unittest.skipIf(True, 'This test passes while development on Intellij but fails via CLI :(') - def test_pid_file_is_written_and_removed(self, mock_multicore_dispatcher, mock_set_open_file_limit, - mock_open, mock_exists, mock_remove): + @unittest.skipIf( + True, + 'This test passes while development on Intellij but fails via CLI :(') + def test_pid_file_is_written_and_removed( + self, + mock_multicore_dispatcher, + mock_set_open_file_limit, + mock_open, + mock_exists, + mock_remove): pid_file = '/tmp/proxy.pid' proxy.main(['--pid-file', pid_file]) mock_set_open_file_limit.assert_called() mock_multicore_dispatcher.assert_called() mock_multicore_dispatcher.return_value.run.assert_called() mock_open.assert_called_with(pid_file, 'wb') - mock_open.return_value.__enter__.return_value.write.assert_called_with(proxy.bytes_(str(os.getpid()))) + mock_open.return_value.__enter__.return_value.write.assert_called_with( + proxy.bytes_(str(os.getpid()))) mock_exists.assert_called_with(pid_file) mock_remove.assert_called_with(pid_file) @mock.patch('proxy.HttpProtocolConfig') @mock.patch('proxy.set_open_file_limit') @mock.patch('proxy.MultiCoreRequestDispatcher') - def test_main(self, mock_multicore_dispatcher, mock_set_open_file_limit, mock_config): + def test_main( + self, + mock_multicore_dispatcher, + mock_set_open_file_limit, + mock_config): proxy.main(['--basic-auth', 'user:pass']) self.assertTrue(mock_set_open_file_limit.called) mock_multicore_dispatcher.assert_called_with( @@ -970,7 +1118,12 @@ def test_main(self, mock_multicore_dispatcher, mock_set_open_file_limit, mock_co @mock.patch('proxy.HttpProtocolConfig') @mock.patch('proxy.set_open_file_limit') @mock.patch('proxy.MultiCoreRequestDispatcher') - def test_main_version(self, mock_multicore_dispatcher, mock_set_open_file_limit, mock_config, mock_print): + def test_main_version( + self, + mock_multicore_dispatcher, + mock_set_open_file_limit, + mock_config, + mock_print): with self.assertRaises(SystemExit): proxy.main(['--version']) mock_print.assert_called_with(proxy.text_(proxy.version)) @@ -983,8 +1136,13 @@ def test_main_version(self, mock_multicore_dispatcher, mock_set_open_file_limit, @mock.patch('proxy.set_open_file_limit') @mock.patch('proxy.MultiCoreRequestDispatcher') @mock.patch('proxy.is_py3') - def test_main_py3_runs(self, mock_is_py3, mock_multicore_dispatcher, mock_set_open_file_limit, - mock_config, mock_print): + def test_main_py3_runs( + self, + mock_is_py3, + mock_multicore_dispatcher, + mock_set_open_file_limit, + mock_config, + mock_print): mock_is_py3.return_value = True proxy.main([]) mock_is_py3.assert_called() @@ -998,9 +1156,16 @@ def test_main_py3_runs(self, mock_is_py3, mock_multicore_dispatcher, mock_set_op @mock.patch('proxy.set_open_file_limit') @mock.patch('proxy.MultiCoreRequestDispatcher') @mock.patch('proxy.is_py3') - @unittest.skipIf(True, 'This test passes while development on Intellij but fails via CLI :(') - def test_main_py2_exit(self, mock_is_py3, mock_multicore_dispatcher, mock_set_open_file_limit, - mock_config, mock_print): + @unittest.skipIf( + True, + 'This test passes while development on Intellij but fails via CLI :(') + def test_main_py2_exit( + self, + mock_is_py3, + mock_multicore_dispatcher, + mock_set_open_file_limit, + mock_config, + mock_print): mock_is_py3.return_value = False with self.assertRaises(SystemExit): proxy.main([]) @@ -1022,7 +1187,9 @@ def test_bytes(self): def test_bytes_nochange(self): self.assertEqual(proxy.bytes_(b'hello'), b'hello') - @unittest.skipIf(os.name == 'nt', 'Open file limit tests disabled for Windows') + @unittest.skipIf( + os.name == 'nt', + 'Open file limit tests disabled for Windows') @mock.patch('resource.getrlimit', return_value=(128, 1024)) @mock.patch('resource.setrlimit', return_value=None) def test_set_open_file_limit(self, mock_set_rlimit, mock_get_rlimit): @@ -1030,18 +1197,24 @@ def test_set_open_file_limit(self, mock_set_rlimit, mock_get_rlimit): mock_get_rlimit.assert_called_with(resource.RLIMIT_NOFILE) mock_set_rlimit.assert_called_with(resource.RLIMIT_NOFILE, (256, 1024)) - @unittest.skipIf(os.name == 'nt', 'Open file limit tests disabled for Windows') + @unittest.skipIf( + os.name == 'nt', + 'Open file limit tests disabled for Windows') @mock.patch('resource.getrlimit', return_value=(256, 1024)) @mock.patch('resource.setrlimit', return_value=None) - def test_set_open_file_limit_not_called(self, mock_set_rlimit, mock_get_rlimit): + def test_set_open_file_limit_not_called( + self, mock_set_rlimit, mock_get_rlimit): proxy.set_open_file_limit(256) mock_get_rlimit.assert_called_with(resource.RLIMIT_NOFILE) mock_set_rlimit.assert_not_called() - @unittest.skipIf(os.name == 'nt', 'Open file limit tests disabled for Windows') + @unittest.skipIf( + os.name == 'nt', + 'Open file limit tests disabled for Windows') @mock.patch('resource.getrlimit', return_value=(256, 1024)) @mock.patch('resource.setrlimit', return_value=None) - def test_set_open_file_limit_not_called1(self, mock_set_rlimit, mock_get_rlimit): + def test_set_open_file_limit_not_called1( + self, mock_set_rlimit, mock_get_rlimit): proxy.set_open_file_limit(1024) mock_get_rlimit.assert_called_with(resource.RLIMIT_NOFILE) mock_set_rlimit.assert_not_called()