diff --git a/Lib/http/server.py b/Lib/http/server.py index f42e9a375e479a..a641550ce1d5d1 100644 --- a/Lib/http/server.py +++ b/Lib/http/server.py @@ -997,7 +997,7 @@ def test(HandlerClass=BaseHTTPRequestHandler, sys.exit(0) -if __name__ == '__main__': +def _main(args=None): import argparse import contextlib @@ -1021,7 +1021,7 @@ def test(HandlerClass=BaseHTTPRequestHandler, parser.add_argument('port', default=8000, type=int, nargs='?', help='bind to this port ' '(default: %(default)s)') - args = parser.parse_args() + args = parser.parse_args(args) if not args.tls_cert and args.tls_key: parser.error("--tls-key requires --tls-cert to be set") @@ -1061,3 +1061,7 @@ def finish_request(self, request, client_address): tls_key=args.tls_key, tls_password=tls_key_password, ) + + +if __name__ == '__main__': + _main() diff --git a/Lib/test/test_httpservers.py b/Lib/test/test_httpservers.py index 11c74a02bf2903..7437b1a4e2c205 100644 --- a/Lib/test/test_httpservers.py +++ b/Lib/test/test_httpservers.py @@ -8,8 +8,10 @@ SimpleHTTPRequestHandler from http import server, HTTPStatus +import contextlib import os import socket +import subprocess import sys import re import ntpath @@ -20,6 +22,7 @@ import html import http, http.client import urllib.parse +import urllib.request import tempfile import time import datetime @@ -32,6 +35,8 @@ from test.support import ( is_apple, import_helper, os_helper, threading_helper ) +from test.support.script_helper import kill_python, spawn_python +from test.support.socket_helper import find_unused_port try: import ssl @@ -1280,6 +1285,254 @@ def test_server_test_ipv4(self, _): self.assertEqual(mock_server.address_family, socket.AF_INET) +class CommandLineTestCase(unittest.TestCase): + default_port = 8000 + default_bind = None + default_protocol = 'HTTP/1.0' + default_handler = SimpleHTTPRequestHandler + default_server = unittest.mock.ANY + tls_cert = certdata_file('ssl_cert.pem') + tls_key = certdata_file('ssl_key.pem') + tls_password = 'somepass' + tls_cert_options = ['--tls-cert'] + tls_key_options = ['--tls-key'] + tls_password_options = ['--tls-password-file'] + args = { + 'HandlerClass': default_handler, + 'ServerClass': default_server, + 'protocol': default_protocol, + 'port': default_port, + 'bind': default_bind, + 'tls_cert': None, + 'tls_key': None, + 'tls_password': None, + } + + def setUp(self): + super().setUp() + self.tls_password_file = tempfile.mktemp() + with open(self.tls_password_file, 'wb') as f: + f.write(self.tls_password.encode()) + self.addCleanup(os_helper.unlink, self.tls_password_file) + + def invoke_httpd(self, *args, stdout=None, stderr=None): + stdout = StringIO() if stdout is None else stdout + stderr = StringIO() if stderr is None else stderr + with contextlib.redirect_stdout(stdout), \ + contextlib.redirect_stderr(stderr): + server._main(args) + return stdout.getvalue(), stderr.getvalue() + + @mock.patch('http.server.test') + def test_port_flag(self, mock_func): + ports = [8000, 65535] + for port in ports: + with self.subTest(port=port): + self.invoke_httpd(str(port)) + call_args = self.args | dict(port=port) + mock_func.assert_called_once_with(**call_args) + mock_func.reset_mock() + + @mock.patch('http.server.test') + def test_directory_flag(self, mock_func): + options = ['-d', '--directory'] + directories = ['.', '/foo', '\\bar', '/', + 'C:\\', 'C:\\foo', 'C:\\bar', + '/home/user', './foo/foo2', 'D:\\foo\\bar'] + for flag in options: + for directory in directories: + with self.subTest(flag=flag, directory=directory): + self.invoke_httpd(flag, directory) + mock_func.assert_called_once_with(**self.args) + mock_func.reset_mock() + + @mock.patch('http.server.test') + def test_bind_flag(self, mock_func): + options = ['-b', '--bind'] + bind_addresses = ['localhost', '127.0.0.1', '::1', + '0.0.0.0', '8.8.8.8',] + for flag in options: + for bind_address in bind_addresses: + with self.subTest(flag=flag, bind_address=bind_address): + self.invoke_httpd(flag, bind_address) + call_args = self.args | dict(bind=bind_address) + mock_func.assert_called_once_with(**call_args) + mock_func.reset_mock() + + @mock.patch('http.server.test') + def test_protocol_flag(self, mock_func): + options = ['-p', '--protocol'] + protocols = ['HTTP/1.0', 'HTTP/1.1', 'HTTP/2.0', 'HTTP/3.0'] + for flag in options: + for protocol in protocols: + with self.subTest(flag=flag, protocol=protocol): + self.invoke_httpd(flag, protocol) + call_args = self.args | dict(protocol=protocol) + mock_func.assert_called_once_with(**call_args) + mock_func.reset_mock() + + @unittest.skipIf(ssl is None, "requires ssl") + @mock.patch('http.server.test') + def test_tls_cert_and_key_flags(self, mock_func): + for tls_cert_option in self.tls_cert_options: + for tls_key_option in self.tls_key_options: + self.invoke_httpd(tls_cert_option, self.tls_cert, + tls_key_option, self.tls_key) + call_args = { + 'tls_cert': self.tls_cert, + 'tls_key': self.tls_key, + } + call_args = self.args | call_args + mock_func.assert_called_once_with(**call_args) + mock_func.reset_mock() + + @unittest.skipIf(ssl is None, "requires ssl") + @mock.patch('http.server.test') + def test_tls_cert_and_key_and_password_flags(self, mock_func): + for tls_cert_option in self.tls_cert_options: + for tls_key_option in self.tls_key_options: + for tls_password_option in self.tls_password_options: + self.invoke_httpd(tls_cert_option, + self.tls_cert, + tls_key_option, + self.tls_key, + tls_password_option, + self.tls_password_file) + call_args = { + 'tls_cert': self.tls_cert, + 'tls_key': self.tls_key, + 'tls_password': self.tls_password, + } + call_args = self.args | call_args + mock_func.assert_called_once_with(**call_args) + mock_func.reset_mock() + + @unittest.skipIf(ssl is None, "requires ssl") + @mock.patch('http.server.test') + def test_missing_tls_cert_flag(self, mock_func): + for tls_key_option in self.tls_key_options: + with self.assertRaises(SystemExit): + self.invoke_httpd(tls_key_option, self.tls_key) + mock_func.reset_mock() + + for tls_password_option in self.tls_password_options: + with self.assertRaises(SystemExit): + self.invoke_httpd(tls_password_option, self.tls_password) + mock_func.reset_mock() + + @unittest.skipIf(ssl is None, "requires ssl") + @mock.patch('http.server.test') + def test_invalid_password_file(self, mock_func): + non_existent_file = 'non_existent_file' + for tls_password_option in self.tls_password_options: + for tls_cert_option in self.tls_cert_options: + with self.assertRaises(SystemExit): + self.invoke_httpd(tls_cert_option, + self.tls_cert, + tls_password_option, + non_existent_file) + + @mock.patch('http.server.test') + def test_no_arguments(self, mock_func): + self.invoke_httpd() + mock_func.assert_called_once_with(**self.args) + mock_func.reset_mock() + + @mock.patch('http.server.test') + def test_help_flag(self, _): + options = ['-h', '--help'] + stdout, stderr = StringIO(), StringIO() + for option in options: + with self.assertRaises(SystemExit): + self.invoke_httpd(option, stdout=stdout, stderr=stderr) + self.assertIn('usage', stdout.getvalue()) + self.assertEqual('', stderr.getvalue()) + + @mock.patch('http.server.test') + def test_unknown_flag(self, _): + stdout, stderr = StringIO(), StringIO() + with self.assertRaises(SystemExit): + self.invoke_httpd('--unknown-flag', stdout=stdout, stderr=stderr) + self.assertEqual('', stdout.getvalue()) + self.assertIn('error', stderr.getvalue()) + +class CommandLineRunTimeTestCase(unittest.TestCase): + random_data = os.urandom(32) + random_file_name = 'served_filename' + tls_cert = certdata_file('ssl_cert.pem') + tls_key = certdata_file('ssl_key.pem') + tls_password = 'somepass' + + def setUp(self): + super().setUp() + with open(self.random_file_name, 'wb') as f: + f.write(self.random_data) + self.addCleanup(os_helper.unlink, self.random_file_name) + self.tls_password_file = tempfile.mktemp() + with open(self.tls_password_file, 'wb') as f: + f.write(self.tls_password.encode()) + self.addCleanup(os_helper.unlink, self.tls_password_file) + + def fetch_file(self, path, allow_self_signed_cert=True): + context = ssl.create_default_context() + if allow_self_signed_cert: + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + req = urllib.request.Request(path, method='GET') + with urllib.request.urlopen(req, context=context) as res: + return res.read() + + def parse_cli_output(self, output): + matches = re.search(r'\((https?)://([^/:]+):(\d+)/?\)', output) + if matches is None: + return None, None, None + return matches.group(1), matches.group(2), int(matches.group(3)) + + def wait_for_server(self, proc, protocol, port, bind, timeout=50): + """Parses the output of the server process by lines and returns True if + the server is listening on the given port and bind address.""" + while timeout > 0: + line = proc.stdout.readline() + if not line: + time.sleep(0.1) + timeout -= 1 + continue + protocol_, host_, port_ = self.parse_cli_output(line) + if not protocol_ or not host_ or not port_: + time.sleep(0.1) + timeout -= 1 + continue + if protocol_ == protocol and host_ == bind and port_ == port: + return True + else: + break + return False + + def test_http_client(self): + port = find_unused_port() + bind = '127.0.0.1' + proc = spawn_python('-u', '-m', 'http.server', str(port), '-b', bind, + bufsize=1, text=True) + self.addCleanup(kill_python, proc) + self.addCleanup(proc.terminate) + self.assertTrue(self.wait_for_server(proc, 'http', port, bind)) + res = self.fetch_file(f'http://{bind}:{port}/{self.random_file_name}') + self.assertEqual(res, self.random_data) + + def test_https_client(self): + port = find_unused_port() + bind = '127.0.0.1' + proc = spawn_python('-u', '-m', 'http.server', str(port), '-b', bind, + '--tls-cert', self.tls_cert, + '--tls-key', self.tls_key, + '--tls-password-file', self.tls_password_file, + bufsize=1, text=True) + self.addCleanup(kill_python, proc) + self.addCleanup(proc.terminate) + self.assertTrue(self.wait_for_server(proc, 'https', port, bind)) + res = self.fetch_file(f'https://{bind}:{port}/{self.random_file_name}') + self.assertEqual(res, self.random_data) + def setUpModule(): unittest.addModuleCleanup(os.chdir, os.getcwd())