diff --git a/proxy/__init__.py b/proxy/__init__.py index 16cacc417d..a2e0fa77ad 100755 --- a/proxy/__init__.py +++ b/proxy/__init__.py @@ -9,7 +9,7 @@ :license: BSD, see LICENSE for more details. """ from .proxy import entry_point, main, Proxy -from .testing.test_case import TestCase +from .testing import TestCase __all__ = [ # PyPi package entry_point. See diff --git a/proxy/common/constants.py b/proxy/common/constants.py index 43c449be0f..738594036c 100644 --- a/proxy/common/constants.py +++ b/proxy/common/constants.py @@ -100,12 +100,18 @@ DEFAULT_DATA_DIRECTORY_PATH = os.path.join(str(pathlib.Path.home()), '.proxy') # Cor plugins enabled by default or via flags +DEFAULT_ABC_PLUGINS = [ + 'HttpProtocolHandlerPlugin', + 'HttpProxyBasePlugin', + 'HttpWebServerBasePlugin', + 'ProxyDashboardWebsocketPlugin', +] PLUGIN_HTTP_PROXY = 'proxy.http.proxy.HttpProxyPlugin' PLUGIN_WEB_SERVER = 'proxy.http.server.HttpWebServerPlugin' PLUGIN_PAC_FILE = 'proxy.http.server.HttpWebServerPacFilePlugin' PLUGIN_DEVTOOLS_PROTOCOL = 'proxy.http.inspector.DevtoolsProtocolPlugin' -PLUGIN_DASHBOARD = 'proxy.dashboard.dashboard.ProxyDashboard' -PLUGIN_INSPECT_TRAFFIC = 'proxy.dashboard.inspect_traffic.InspectTrafficPlugin' +PLUGIN_DASHBOARD = 'proxy.dashboard.ProxyDashboard' +PLUGIN_INSPECT_TRAFFIC = 'proxy.dashboard.InspectTrafficPlugin' PLUGIN_PROXY_AUTH = 'proxy.http.proxy.AuthPlugin' PY2_DEPRECATION_MESSAGE = '''DEPRECATION: proxy.py no longer supports Python 2.7. Kindly upgrade to Python 3+. ' diff --git a/proxy/common/flag.py b/proxy/common/flag.py index 33f4021db3..aace87c70e 100644 --- a/proxy/common/flag.py +++ b/proxy/common/flag.py @@ -19,14 +19,15 @@ from typing import Optional, List, Any, cast +from .plugins import Plugins from .types import IpAddress -from .utils import text_, bytes_, setup_logger, is_py2, set_open_file_limit -from .utils import import_plugin, load_plugins +from .utils import text_, bytes_, is_py2, set_open_file_limit from .constants import COMMA, DEFAULT_DATA_DIRECTORY_PATH, DEFAULT_NUM_WORKERS from .constants import DEFAULT_DEVTOOLS_WS_PATH, DEFAULT_DISABLE_HEADERS, PY2_DEPRECATION_MESSAGE from .constants import PLUGIN_DASHBOARD, PLUGIN_DEVTOOLS_PROTOCOL from .constants import PLUGIN_HTTP_PROXY, PLUGIN_INSPECT_TRAFFIC, PLUGIN_PAC_FILE from .constants import PLUGIN_WEB_SERVER, PLUGIN_PROXY_AUTH +from .logger import Logger from .version import __version__ @@ -94,10 +95,9 @@ def initialize( sys.exit(1) # Discover flags from requested plugin. - # This also surface external plugin flags under --help - for i, f in enumerate(input_args): - if f == '--plugin': - import_plugin(bytes_(input_args[i + 1])) + # This will also surface external plugin flags + # under --help. + Plugins.discover(input_args) # Parse flags args = flags.parse_args(input_args) @@ -108,7 +108,7 @@ def initialize( sys.exit(0) # Setup logging module - setup_logger(args.log_file, args.log_level, args.log_format) + Logger.setup_logger(args.log_file, args.log_level, args.log_format) # Setup limits set_open_file_limit(args.open_file_limit) @@ -125,7 +125,7 @@ def initialize( ] # Load default plugins along with user provided --plugins - plugins = load_plugins(default_plugins + extra_plugins) + plugins = Plugins.load(default_plugins + extra_plugins) # proxy.py currently cannot serve over HTTPS and also perform TLS interception # at the same time. Check if user is trying to enable both feature diff --git a/proxy/common/logger.py b/proxy/common/logger.py new file mode 100644 index 0000000000..e872ef18c9 --- /dev/null +++ b/proxy/common/logger.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import logging + +from typing import Optional + +from .constants import DEFAULT_LOG_FILE, DEFAULT_LOG_FORMAT, DEFAULT_LOG_LEVEL + +SINGLE_CHAR_TO_LEVEL = { + 'D': 'DEBUG', + 'I': 'INFO', + 'W': 'WARNING', + 'E': 'ERROR', + 'C': 'CRITICAL', +} + + +class Logger: + """Common logging utilities and setup.""" + + @staticmethod + def setup_logger( + log_file: Optional[str] = DEFAULT_LOG_FILE, + log_level: str = DEFAULT_LOG_LEVEL, + log_format: str = DEFAULT_LOG_FORMAT, + ) -> None: + ll = getattr(logging, SINGLE_CHAR_TO_LEVEL[log_level.upper()[0]]) + if log_file: + logging.basicConfig( + filename=log_file, + filemode='a', + level=ll, + format=log_format, + ) + else: + logging.basicConfig(level=ll, format=log_format) diff --git a/proxy/common/plugins.py b/proxy/common/plugins.py new file mode 100644 index 0000000000..2c04530b20 --- /dev/null +++ b/proxy/common/plugins.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import os +import abc +import logging +import inspect +import importlib + +from typing import Any, List, Dict, Optional, Union + +from .utils import bytes_, text_ +from .constants import DOT, DEFAULT_ABC_PLUGINS + +logger = logging.getLogger(__name__) + + +class Plugins: + """Common utilities for plugin discovery.""" + + @staticmethod + def discover(input_args: List[str]) -> None: + """Search for plugin and plugins flag in command line arguments, + then iterates over each value and discovers the plugin. + """ + for i, f in enumerate(input_args): + if f in ('--plugin', '--plugins'): + v = input_args[i + 1] + parts = v.split(',') + for part in parts: + Plugins.importer(bytes_(part)) + + @staticmethod + def load( + plugins: List[Union[bytes, type]], + abc_plugins: Optional[List[str]] = None, + ) -> Dict[bytes, List[type]]: + """Accepts a list Python modules, scans them to identify + if they are an implementation of abstract plugin classes and + returns a dictionary of matching plugins for each abstract class. + """ + p: Dict[bytes, List[type]] = {} + for abc_plugin in (abc_plugins or DEFAULT_ABC_PLUGINS): + p[bytes_(abc_plugin)] = [] + for plugin_ in plugins: + klass, module_name = Plugins.importer(plugin_) + assert klass and module_name + mro = list(inspect.getmro(klass)) + mro.reverse() + iterator = iter(mro) + while next(iterator) is not abc.ABC: + pass + base_klass = next(iterator) + if klass not in p[bytes_(base_klass.__name__)]: + p[bytes_(base_klass.__name__)].append(klass) + logger.info('Loaded plugin %s.%s', module_name, klass.__name__) + return p + + @staticmethod + def importer(plugin: Union[bytes, type]) -> Any: + """Import and returns the plugin.""" + if isinstance(plugin, type): + return (plugin, '__main__') + plugin_ = text_(plugin.strip()) + assert plugin_ != '' + module_name, klass_name = plugin_.rsplit(text_(DOT), 1) + klass = getattr( + importlib.import_module( + module_name.replace( + os.path.sep, text_(DOT), + ), + ), + klass_name, + ) + return (klass, module_name) diff --git a/proxy/common/utils.py b/proxy/common/utils.py index bcbc872a33..1abb2bdece 100644 --- a/proxy/common/utils.py +++ b/proxy/common/utils.py @@ -9,23 +9,18 @@ :license: BSD, see LICENSE for more details. """ import os -import abc import sys import ssl import socket import logging -import inspect -import importlib import functools import ipaddress import contextlib from types import TracebackType -from typing import Optional, Dict, Any, List, Tuple, Type, Callable, Union +from typing import Optional, Dict, Any, List, Tuple, Type, Callable from .constants import HTTP_1_1, COLON, WHITESPACE, CRLF, DEFAULT_TIMEOUT -from .constants import DEFAULT_LOG_FILE, DEFAULT_LOG_FORMAT, DEFAULT_LOG_LEVEL -from .constants import DOT if os.name != 'nt': import resource @@ -263,77 +258,6 @@ def get_available_port() -> int: return int(port) -def setup_logger( - log_file: Optional[str] = DEFAULT_LOG_FILE, - log_level: str = DEFAULT_LOG_LEVEL, - log_format: str = DEFAULT_LOG_FORMAT, -) -> None: - ll = getattr( - logging, - { - 'D': 'DEBUG', - 'I': 'INFO', - 'W': 'WARNING', - 'E': 'ERROR', - 'C': 'CRITICAL', - }[log_level.upper()[0]], - ) - if log_file: - logging.basicConfig( - filename=log_file, - filemode='a', - level=ll, - format=log_format, - ) - else: - logging.basicConfig(level=ll, format=log_format) - - -def load_plugins( - plugins: List[Union[bytes, type]], -) -> Dict[bytes, List[type]]: - """Accepts a comma separated list of Python modules and returns - a list of respective Python classes.""" - p: Dict[bytes, List[type]] = { - b'HttpProtocolHandlerPlugin': [], - b'HttpProxyBasePlugin': [], - b'HttpWebServerBasePlugin': [], - b'ProxyDashboardWebsocketPlugin': [], - } - for plugin_ in plugins: - klass, module_name = import_plugin(plugin_) - assert klass and module_name - mro = list(inspect.getmro(klass)) - mro.reverse() - iterator = iter(mro) - while next(iterator) is not abc.ABC: - pass - base_klass = next(iterator) - if klass not in p[bytes_(base_klass.__name__)]: - p[bytes_(base_klass.__name__)].append(klass) - logger.info('Loaded plugin %s.%s', module_name, klass.__name__) - return p - - -def import_plugin(plugin: Union[bytes, type]) -> Any: - if isinstance(plugin, type): - module_name = '__main__' - klass = plugin - else: - plugin_ = text_(plugin.strip()) - assert plugin_ != '' - module_name, klass_name = plugin_.rsplit(text_(DOT), 1) - klass = getattr( - importlib.import_module( - module_name.replace( - os.path.sep, text_(DOT), - ), - ), - klass_name, - ) - return (klass, module_name) - - def set_open_file_limit(soft_limit: int) -> None: """Configure open file description soft limit on supported OS.""" if os.name != 'nt': # resource module not available on Windows OS diff --git a/proxy/core/acceptor/acceptor.py b/proxy/core/acceptor/acceptor.py index 02ae41473f..cfa8045fc3 100644 --- a/proxy/core/acceptor/acceptor.py +++ b/proxy/core/acceptor/acceptor.py @@ -27,7 +27,7 @@ from ..event import EventQueue, eventNames from ...common.constants import DEFAULT_THREADLESS from ...common.flag import flags -from ...common.utils import setup_logger +from ...common.logger import Logger logger = logging.getLogger(__name__) @@ -159,7 +159,7 @@ def run_once(self) -> None: self._start_threaded_work(conn, addr) def run(self) -> None: - setup_logger( + Logger.setup_logger( self.flags.log_file, self.flags.log_level, self.flags.log_format, ) diff --git a/proxy/core/acceptor/threadless.py b/proxy/core/acceptor/threadless.py index 3cdd323cd7..7a957f4180 100644 --- a/proxy/core/acceptor/threadless.py +++ b/proxy/core/acceptor/threadless.py @@ -26,7 +26,7 @@ from ..connection import TcpClientConnection from ..event import EventQueue, eventNames -from ...common.utils import setup_logger +from ...common.logger import Logger from ...common.types import Readables, Writables from ...common.constants import DEFAULT_TIMEOUT @@ -204,7 +204,7 @@ def run_once(self) -> None: self.cleanup_inactive() def run(self) -> None: - setup_logger( + Logger.setup_logger( self.flags.log_file, self.flags.log_level, self.flags.log_format, ) diff --git a/proxy/core/ssh/__init__.py b/proxy/core/ssh/__init__.py index 232621f0b5..f793bc839a 100644 --- a/proxy/core/ssh/__init__.py +++ b/proxy/core/ssh/__init__.py @@ -8,3 +8,10 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. """ +from .client import SshClient +from .tunnel import Tunnel + +__all__ = [ + 'SshClient', + 'Tunnel', +] diff --git a/proxy/dashboard/__init__.py b/proxy/dashboard/__init__.py index 232621f0b5..0f5d329522 100644 --- a/proxy/dashboard/__init__.py +++ b/proxy/dashboard/__init__.py @@ -8,3 +8,12 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. """ +from .dashboard import ProxyDashboard +from .inspect_traffic import InspectTrafficPlugin +from .plugin import ProxyDashboardWebsocketPlugin + +__all__ = [ + 'ProxyDashboard', + 'InspectTrafficPlugin', + 'ProxyDashboardWebsocketPlugin', +] diff --git a/proxy/dashboard/dashboard.py b/proxy/dashboard/dashboard.py index 49ae0f2cfd..57fb5aa81c 100644 --- a/proxy/dashboard/dashboard.py +++ b/proxy/dashboard/dashboard.py @@ -14,9 +14,7 @@ from .plugin import ProxyDashboardWebsocketPlugin -from ..common.flag import flags from ..common.utils import build_http_response, bytes_ -from ..common.constants import DEFAULT_ENABLE_DASHBOARD from ..http.server import HttpWebServerPlugin, HttpWebServerBasePlugin, httpProtocolTypes from ..http.parser import HttpParser from ..http.websocket import WebsocketFrame @@ -25,14 +23,6 @@ logger = logging.getLogger(__name__) -flags.add_argument( - '--enable-dashboard', - action='store_true', - default=DEFAULT_ENABLE_DASHBOARD, - help='Default: False. Enables proxy.py dashboard.', -) - - class ProxyDashboard(HttpWebServerBasePlugin): """Proxy Dashboard.""" diff --git a/proxy/http/handler.py b/proxy/http/handler.py index 9154f9f59d..446dde3a34 100644 --- a/proxy/http/handler.py +++ b/proxy/http/handler.py @@ -286,7 +286,10 @@ def handle_readables(self, readables: Readables) -> bool: return False except socket.error as e: if e.errno == errno.ECONNRESET: - logger.warning('%r' % e) + # Most requests for mobile devices will end up + # with client closed connection. Using `debug` + # here to avoid flooding the logs. + logger.debug('%r' % e) else: logger.exception( 'Exception while receiving from %s connection %r with reason %r' % diff --git a/proxy/proxy.py b/proxy/proxy.py index 99956d45ae..7059fd2e78 100644 --- a/proxy/proxy.py +++ b/proxy/proxy.py @@ -25,6 +25,7 @@ from .common.flag import FlagParser, flags from .common.constants import DEFAULT_LOG_FILE, DEFAULT_LOG_FORMAT, DEFAULT_LOG_LEVEL from .common.constants import DEFAULT_OPEN_FILE_LIMIT, DEFAULT_PLUGINS, DEFAULT_VERSION +from .common.constants import DEFAULT_ENABLE_DASHBOARD logger = logging.getLogger(__name__) @@ -76,6 +77,20 @@ help='Comma separated plugins', ) +# TODO: Ideally all `--enable-*` flags must be at the top-level. +# --enable-dashboard is specially needed here because +# ProxyDashboard class is not imported by anyone. +# +# If we move this flag definition within dashboard.py, +# users will also have to explicitly enable dashboard plugin +# to also use flags provided by it. +flags.add_argument( + '--enable-dashboard', + action='store_true', + default=DEFAULT_ENABLE_DASHBOARD, + help='Default: False. Enables proxy.py dashboard.', +) + class Proxy: """Context manager to control core AcceptorPool server lifecycle. diff --git a/proxy/testing/__init__.py b/proxy/testing/__init__.py index 232621f0b5..e841545b30 100644 --- a/proxy/testing/__init__.py +++ b/proxy/testing/__init__.py @@ -8,3 +8,8 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. """ +from .test_case import TestCase + +__all__ = [ + 'TestCase', +] diff --git a/tests/common/test_pki.py b/tests/common/test_pki.py index 76dd723837..1abcb1c625 100644 --- a/tests/common/test_pki.py +++ b/tests/common/test_pki.py @@ -20,6 +20,10 @@ class TestPki(unittest.TestCase): + def setUp(self) -> None: + self._tempdir = tempfile.gettempdir() + return super().setUp() + @mock.patch('subprocess.Popen') def test_run_openssl_command(self, mock_popen: mock.Mock) -> None: command = ['my', 'custom', 'command'] @@ -103,7 +107,7 @@ def test_gen_public_key(self) -> None: def test_gen_csr(self) -> None: key_path, nopass_key_path, crt_path = self._gen_public_private_key() - csr_path = os.path.join(tempfile.gettempdir(), 'test_gen_public.csr') + csr_path = os.path.join(self._tempdir, 'test_gen_public.csr') pki.gen_csr(csr_path, key_path, 'password', crt_path) self.assertTrue(os.path.exists(csr_path)) # TODO: Assert CSR is valid for provided crt and key @@ -117,14 +121,14 @@ def test_sign_csr(self) -> None: def _gen_public_private_key(self) -> Tuple[str, str, str]: key_path, nopass_key_path = self._gen_private_key() - crt_path = os.path.join(tempfile.gettempdir(), 'test_gen_public.crt') + crt_path = os.path.join(self._tempdir, 'test_gen_public.crt') pki.gen_public_key(crt_path, key_path, 'password', '/CN=example.com') return (key_path, nopass_key_path, crt_path) def _gen_private_key(self) -> Tuple[str, str]: - key_path = os.path.join(tempfile.gettempdir(), 'test_gen_private.key') + key_path = os.path.join(self._tempdir, 'test_gen_private.key') nopass_key_path = os.path.join( - tempfile.gettempdir(), + self._tempdir, 'test_gen_private_nopass.key', ) pki.gen_private_key(key_path, 'password') diff --git a/tests/http/test_protocol_handler.py b/tests/http/test_protocol_handler.py index c5af7f3cff..93d821fdd6 100644 --- a/tests/http/test_protocol_handler.py +++ b/tests/http/test_protocol_handler.py @@ -15,9 +15,10 @@ from typing import cast from unittest import mock +from proxy.common.plugins import Plugins from proxy.common.flag import FlagParser from proxy.common.version import __version__ -from proxy.common.utils import bytes_, load_plugins +from proxy.common.utils import bytes_ from proxy.common.constants import CRLF, PLUGIN_HTTP_PROXY, PLUGIN_PROXY_AUTH, PLUGIN_WEB_SERVER from proxy.core.connection import TcpClientConnection from proxy.http.parser import HttpParser @@ -42,7 +43,7 @@ def setUp( self.http_server_port = 65535 self.flags = FlagParser.initialize() - self.flags.plugins = load_plugins([ + self.flags.plugins = Plugins.load([ bytes_(PLUGIN_HTTP_PROXY), bytes_(PLUGIN_WEB_SERVER), ]) @@ -215,7 +216,7 @@ def test_proxy_authentication_failed( flags = FlagParser.initialize( auth_code=base64.b64encode(b'user:pass'), ) - flags.plugins = load_plugins([ + flags.plugins = Plugins.load([ bytes_(PLUGIN_HTTP_PROXY), bytes_(PLUGIN_WEB_SERVER), bytes_(PLUGIN_PROXY_AUTH), @@ -253,7 +254,7 @@ def test_authenticated_proxy_http_get( flags = FlagParser.initialize( auth_code=base64.b64encode(b'user:pass'), ) - flags.plugins = load_plugins([ + flags.plugins = Plugins.load([ bytes_(PLUGIN_HTTP_PROXY), bytes_(PLUGIN_WEB_SERVER), ]) @@ -308,7 +309,7 @@ def test_authenticated_proxy_http_tunnel( flags = FlagParser.initialize( auth_code=base64.b64encode(b'user:pass'), ) - flags.plugins = load_plugins([ + flags.plugins = Plugins.load([ bytes_(PLUGIN_HTTP_PROXY), bytes_(PLUGIN_WEB_SERVER), ]) diff --git a/tests/http/test_web_server.py b/tests/http/test_web_server.py index 6333481f34..f7e48e5fb1 100644 --- a/tests/http/test_web_server.py +++ b/tests/http/test_web_server.py @@ -15,11 +15,12 @@ import selectors from unittest import mock +from proxy.common.plugins import Plugins from proxy.common.flag import FlagParser from proxy.core.connection import TcpClientConnection from proxy.http.handler import HttpProtocolHandler from proxy.http.parser import httpParserStates -from proxy.common.utils import build_http_response, build_http_request, bytes_, text_, load_plugins +from proxy.common.utils import build_http_response, build_http_request, bytes_, text_ from proxy.common.constants import CRLF, PLUGIN_HTTP_PROXY, PLUGIN_PAC_FILE, PLUGIN_WEB_SERVER, PROXY_PY_DIR from proxy.http.server import HttpWebServerPlugin @@ -34,7 +35,7 @@ def setUp(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None: self._conn = mock_fromfd.return_value self.mock_selector = mock_selector self.flags = FlagParser.initialize() - self.flags.plugins = load_plugins([ + self.flags.plugins = Plugins.load([ bytes_(PLUGIN_HTTP_PROXY), bytes_(PLUGIN_WEB_SERVER), ]) @@ -113,7 +114,7 @@ def test_default_web_server_returns_404( ), ] flags = FlagParser.initialize() - flags.plugins = load_plugins([ + flags.plugins = Plugins.load([ bytes_(PLUGIN_HTTP_PROXY), bytes_(PLUGIN_WEB_SERVER), ]) @@ -137,7 +138,7 @@ def test_default_web_server_returns_404( ) @unittest.skipIf( - os.environ.get('GITHUB_ACTIONS', True), + os.environ.get('GITHUB_ACTIONS', 'false') == 'true', 'Disabled on GitHub actions because this test is flaky on GitHub infrastructure.', ) @mock.patch('selectors.DefaultSelector') @@ -183,7 +184,7 @@ def test_static_web_server_serves( enable_static_server=True, static_server_dir=static_server_dir, ) - flags.plugins = load_plugins([ + flags.plugins = Plugins.load([ bytes_(PLUGIN_HTTP_PROXY), bytes_(PLUGIN_WEB_SERVER), ]) @@ -247,7 +248,7 @@ def test_static_web_server_serves_404( ] flags = FlagParser.initialize(enable_static_server=True) - flags.plugins = load_plugins([ + flags.plugins = Plugins.load([ bytes_(PLUGIN_HTTP_PROXY), bytes_(PLUGIN_WEB_SERVER), ]) @@ -290,7 +291,7 @@ def test_on_client_connection_called_on_teardown( def init_and_make_pac_file_request(self, pac_file: str) -> None: flags = FlagParser.initialize(pac_file=pac_file) - flags.plugins = load_plugins([ + flags.plugins = Plugins.load([ bytes_(PLUGIN_HTTP_PROXY), bytes_(PLUGIN_WEB_SERVER), bytes_(PLUGIN_PAC_FILE), diff --git a/tests/test_main.py b/tests/test_main.py index eb951fbcf6..b9115cff4e 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -27,13 +27,11 @@ from proxy.common.constants import DEFAULT_PAC_FILE, DEFAULT_PLUGINS, DEFAULT_PID_FILE, DEFAULT_PORT, DEFAULT_BASIC_AUTH from proxy.common.constants import DEFAULT_NUM_WORKERS, DEFAULT_OPEN_FILE_LIMIT, DEFAULT_IPV6_HOSTNAME from proxy.common.constants import DEFAULT_SERVER_RECVBUF_SIZE, DEFAULT_CLIENT_RECVBUF_SIZE, PY2_DEPRECATION_MESSAGE +from proxy.common.constants import PLUGIN_INSPECT_TRAFFIC, PLUGIN_DASHBOARD, PLUGIN_DEVTOOLS_PROTOCOL, PLUGIN_WEB_SERVER +from proxy.common.constants import PLUGIN_HTTP_PROXY from proxy.common.version import __version__ -def get_temp_file(name: str) -> str: - return os.path.join(tempfile.gettempdir(), name) - - class TestMain(unittest.TestCase): @staticmethod @@ -143,7 +141,7 @@ def test_enable_events( mock_sleep.assert_called() @mock.patch('time.sleep') - @mock.patch('proxy.common.flag.load_plugins') + @mock.patch('proxy.common.plugins.Plugins.load') @mock.patch('proxy.common.flag.FlagParser.parse_args') @mock.patch('proxy.proxy.EventManager') @mock.patch('proxy.proxy.AcceptorPool') @@ -163,11 +161,11 @@ def test_enable_dashboard( mock_load_plugins.assert_called() self.assertEqual( mock_load_plugins.call_args_list[0][0][0], [ - b'proxy.http.server.HttpWebServerPlugin', - b'proxy.dashboard.dashboard.ProxyDashboard', - b'proxy.dashboard.inspect_traffic.InspectTrafficPlugin', - b'proxy.http.inspector.DevtoolsProtocolPlugin', - b'proxy.http.proxy.HttpProxyPlugin', + bytes_(PLUGIN_WEB_SERVER), + bytes_(PLUGIN_DASHBOARD), + bytes_(PLUGIN_INSPECT_TRAFFIC), + bytes_(PLUGIN_DEVTOOLS_PROTOCOL), + bytes_(PLUGIN_HTTP_PROXY), ], ) mock_parse_args.assert_called_once() @@ -179,7 +177,7 @@ def test_enable_dashboard( mock_event_manager.return_value.stop_event_dispatcher.assert_called_once() @mock.patch('time.sleep') - @mock.patch('proxy.common.flag.load_plugins') + @mock.patch('proxy.common.plugins.Plugins.load') @mock.patch('proxy.common.flag.FlagParser.parse_args') @mock.patch('proxy.proxy.EventManager') @mock.patch('proxy.proxy.AcceptorPool') @@ -199,9 +197,9 @@ def test_enable_devtools( mock_load_plugins.assert_called() self.assertEqual( mock_load_plugins.call_args_list[0][0][0], [ - b'proxy.http.inspector.DevtoolsProtocolPlugin', - b'proxy.http.server.HttpWebServerPlugin', - b'proxy.http.proxy.HttpProxyPlugin', + bytes_(PLUGIN_DEVTOOLS_PROTOCOL), + bytes_(PLUGIN_WEB_SERVER), + bytes_(PLUGIN_HTTP_PROXY), ], ) mock_parse_args.assert_called_once() @@ -227,7 +225,7 @@ def test_pid_file_is_written_and_removed( mock_remove: mock.Mock, mock_sleep: mock.Mock, ) -> None: - pid_file = get_temp_file('pid') + pid_file = os.path.join(tempfile.gettempdir(), 'pid') mock_sleep.side_effect = KeyboardInterrupt() mock_args = mock_parse_args.return_value self.mock_default_args(mock_args)