diff --git a/Lib/logging/config.py b/Lib/logging/config.py index 3bc63b78621aba..752c8cbf025ca8 100644 --- a/Lib/logging/config.py +++ b/Lib/logging/config.py @@ -29,6 +29,7 @@ import logging import logging.handlers import re +import socket import struct import sys import threading @@ -885,7 +886,11 @@ class ConfigSocketReceiver(ThreadingTCPServer): def __init__(self, host='localhost', port=DEFAULT_LOGGING_CONFIG_PORT, handler=None, ready=None, verify=None): - ThreadingTCPServer.__init__(self, (host, port), handler) + try: + ThreadingTCPServer.__init__(self, (host, port), handler) + except OSError as err: + self.address_family = socket.AF_INET6 + ThreadingTCPServer.__init__(self, (host, port), handler) logging._acquireLock() self.abort = 0 logging._releaseLock() diff --git a/Lib/multiprocessing/connection.py b/Lib/multiprocessing/connection.py index 510e4b5aba44a6..85f371bb47c024 100644 --- a/Lib/multiprocessing/connection.py +++ b/Lib/multiprocessing/connection.py @@ -51,6 +51,9 @@ default_family = 'AF_UNIX' families += ['AF_UNIX'] +if hasattr(socket, 'AF_INET6') and socket.has_ipv6: + families.append('AF_INET6') + if sys.platform == 'win32': default_family = 'AF_PIPE' families += ['AF_PIPE'] @@ -70,7 +73,7 @@ def arbitrary_address(family): ''' Return an arbitrary free address for the given family ''' - if family == 'AF_INET': + if family in {'AF_INET', 'AF_INET6'}: return ('localhost', 0) elif family == 'AF_UNIX': # Prefer abstract sockets if possible to avoid problems with the address @@ -101,9 +104,16 @@ def address_type(address): ''' Return the types of the address - This can be 'AF_INET', 'AF_UNIX', or 'AF_PIPE' + This can be 'AF_INET', 'AF_INET6', 'AF_UNIX', or 'AF_PIPE' ''' if type(address) == tuple: + if '.' in address[0]: + return 'AF_INET' + if ':' in address[0]: + return 'AF_INET6' + addr_info = socket.getaddrinfo(*address[:2]) + if addr_info: + return addr_info[0][0].name return 'AF_INET' elif type(address) is str and address.startswith('\\\\'): return 'AF_PIPE' diff --git a/Lib/test/_test_eintr.py b/Lib/test/_test_eintr.py index e43b59d064f55a..dd73a8f7c66bdb 100644 --- a/Lib/test/_test_eintr.py +++ b/Lib/test/_test_eintr.py @@ -285,28 +285,28 @@ def test_sendmsg(self): self._test_send(lambda sock, data: sock.sendmsg([data])) def test_accept(self): - sock = socket.create_server((socket_helper.HOST, 0)) - self.addCleanup(sock.close) - port = sock.getsockname()[1] - - code = '\n'.join(( - 'import socket, time', - '', - 'host = %r' % socket_helper.HOST, - 'port = %s' % port, - 'sleep_time = %r' % self.sleep_time, - '', - '# let parent block on accept()', - 'time.sleep(sleep_time)', - 'with socket.create_connection((host, port)):', - ' time.sleep(sleep_time)', - )) - - proc = self.subprocess(code) - with kill_on_error(proc): - client_sock, _ = sock.accept() - client_sock.close() - self.assertEqual(proc.wait(), 0) + with socket_helper.bind_ip_socket_and_port() as sock_port: + sock, port = sock_port + sock.listen() + + code = '\n'.join(( + 'import socket, time', + '', + 'host = %r' % socket_helper.HOST, + 'port = %s' % port, + 'sleep_time = %r' % self.sleep_time, + '', + '# let parent block on accept()', + 'time.sleep(sleep_time)', + 'with socket.create_connection((host, port)):', + ' time.sleep(sleep_time)', + )) + + proc = self.subprocess(code) + with kill_on_error(proc): + client_sock, _ = sock.accept() + client_sock.close() + self.assertEqual(proc.wait(), 0) # Issue #25122: There is a race condition in the FreeBSD kernel on # handling signals in the FIFO device. Skip the test until the bug is diff --git a/Lib/test/_test_multiprocessing.py b/Lib/test/_test_multiprocessing.py index ead92cfa2abfea..02b24f41a428c3 100644 --- a/Lib/test/_test_multiprocessing.py +++ b/Lib/test/_test_multiprocessing.py @@ -184,6 +184,14 @@ class BaseTestCase(object): ALLOWED_TYPES = ('processes', 'manager', 'threads') + def get_families(self): + fams = set(self.connection.families) + if not socket_helper.IPV6_ENABLED: + fams -= {'AF_INET6'} + if not socket_helper.IPV4_ENABLED: + fams -= {'AF_INET'} + return fams + def assertTimingAlmostEqual(self, a, b): if CHECK_TIMINGS: self.assertAlmostEqual(a, b, 1) @@ -3284,7 +3292,7 @@ class _TestListener(BaseTestCase): ALLOWED_TYPES = ('processes',) def test_multiple_bind(self): - for family in self.connection.families: + for family in self.get_families(): l = self.connection.Listener(family=family) self.addCleanup(l.close) self.assertRaises(OSError, self.connection.Listener, @@ -3324,7 +3332,7 @@ def _test(cls, address): conn.close() def test_listener_client(self): - for family in self.connection.families: + for family in self.get_families(): l = self.connection.Listener(family=family) p = self.Process(target=self._test, args=(l.address,)) p.daemon = True @@ -3351,7 +3359,7 @@ def test_issue14725(self): l.close() def test_issue16955(self): - for fam in self.connection.families: + for fam in self.get_families(): l = self.connection.Listener(family=fam) c = self.connection.Client(l.address) a = l.accept() @@ -3464,7 +3472,8 @@ def _listener(cls, conn, families): new_conn.close() l.close() - l = socket.create_server((socket_helper.HOST, 0)) + l = socket.create_server((socket_helper.HOST, 0), + family=socket_helper.get_family()) conn.send(l.getsockname()) new_conn, addr = l.accept() conn.send(new_conn) @@ -3481,7 +3490,7 @@ def _remote(cls, conn): client.close() address, msg = conn.recv() - client = socket.socket() + client = socket_helper.tcp_socket() client.connect(address) client.sendall(msg.upper()) client.close() @@ -3489,7 +3498,7 @@ def _remote(cls, conn): conn.close() def test_pickling(self): - families = self.connection.families + families = self.get_families() lconn, lconn0 = self.Pipe() lp = self.Process(target=self._listener, args=(lconn0, families)) @@ -4638,7 +4647,7 @@ def test_wait(self, slow=False): @classmethod def _child_test_wait_socket(cls, address, slow): - s = socket.socket() + s = socket_helper.tcp_socket() s.connect(address) for i in range(10): if slow: @@ -4648,7 +4657,8 @@ def _child_test_wait_socket(cls, address, slow): def test_wait_socket(self, slow=False): from multiprocessing.connection import wait - l = socket.create_server((socket_helper.HOST, 0)) + l = socket.create_server((socket_helper.HOST, 0), + family=socket_helper.get_family()) addr = l.getsockname() readers = [] procs = [] @@ -4836,7 +4846,8 @@ def test_timeout(self): try: socket.setdefaulttimeout(0.1) parent, child = multiprocessing.Pipe(duplex=True) - l = multiprocessing.connection.Listener(family='AF_INET') + l = multiprocessing.connection.Listener( + family=socket_helper.get_family().name) p = multiprocessing.Process(target=self._test_timeout, args=(child, l.address)) p.start() @@ -4910,11 +4921,11 @@ def get_high_socket_fd(self): # The child process will not have any socket handles, so # calling socket.fromfd() should produce WSAENOTSOCK even # if there is a handle of the same number. - return socket.socket().detach() + return socket_helper.tcp_socket().detach() else: # We want to produce a socket with an fd high enough that a # freshly created child process will not have any fds as high. - fd = socket.socket().detach() + fd = socket_helper.tcp_socket().detach() to_close = [] while fd < 50: to_close.append(fd) @@ -4925,7 +4936,7 @@ def get_high_socket_fd(self): def close(self, fd): if WIN32: - socket.socket(socket.AF_INET, socket.SOCK_STREAM, fileno=fd).close() + socket.socket(socket_helper.get_family(), socket.SOCK_STREAM, fileno=fd).close() else: os.close(fd) diff --git a/Lib/test/ssl_servers.py b/Lib/test/ssl_servers.py index a4bd7455d47e76..d3b0f4315573dd 100644 --- a/Lib/test/ssl_servers.py +++ b/Lib/test/ssl_servers.py @@ -20,6 +20,8 @@ class HTTPSServer(_HTTPServer): + address_family = socket_helper.get_family() + def __init__(self, server_address, handler_class, context): _HTTPServer.__init__(self, server_address, handler_class) self.context = context diff --git a/Lib/test/support/socket_helper.py b/Lib/test/support/socket_helper.py index e78712b74b1377..c3e430f8ec1fed 100644 --- a/Lib/test/support/socket_helper.py +++ b/Lib/test/support/socket_helper.py @@ -12,7 +12,7 @@ HOSTv6 = "::1" -def find_unused_port(family=socket.AF_INET, socktype=socket.SOCK_STREAM): +def find_unused_port(family=None, socktype=socket.SOCK_STREAM): """Returns an unused port that should be suitable for binding. This is achieved by creating a temporary socket with the same family and type as the 'sock' parameter (default is AF_INET, SOCK_STREAM), and binding it to @@ -20,17 +20,20 @@ def find_unused_port(family=socket.AF_INET, socktype=socket.SOCK_STREAM): eliciting an unused ephemeral port from the OS. The temporary socket is then closed and deleted, and the ephemeral port is returned. + When family is None, we use to the result of get_family() instead. + Either this method or bind_port() should be used for any tests where a server socket needs to be bound to a particular port for the duration of the test. Which one to use depends on whether the calling code is creating a python socket, or if an unused port needs to be provided in a constructor or passed to an external program (i.e. the -accept argument to openssl's - s_server mode). Always prefer bind_port() over find_unused_port() where - possible. Hard coded ports should *NEVER* be used. As soon as a server - socket is bound to a hard coded port, the ability to run multiple instances - of the test simultaneously on the same host is compromised, which makes the - test a ticking time bomb in a buildbot environment. On Unix buildbots, this - may simply manifest as a failed test, which can be recovered from without + s_server mode). Always prefer bind_port(), bind_ip_socket_and_port(), + and get_bound_ip_socket_and_port() over find_unused_port() where possible. + Hard coded ports should *NEVER* be used. As soon as a server socket is + bound to a hard coded port, the ability to run multiple instances of the + test simultaneously on the same host is compromised, which makes the test a + ticking time bomb in a buildbot environment. On Unix buildbots, this may + simply manifest as a failed test, which can be recovered from without intervention in most cases, but on Windows, the entire python process can completely and utterly wedge, requiring someone to log in to the buildbot and manually kill the affected process. @@ -66,19 +69,26 @@ def find_unused_port(family=socket.AF_INET, socktype=socket.SOCK_STREAM): other process when we close and delete our temporary socket but before our calling code has a chance to bind the returned port. We can deal with this issue if/when we come across it. + + TODO(gpshead): We should support a https://pypi.org/project/portpicker/ + portserver or equivalent running on our buildbot workers and use that + that for more reliability at avoiding conflicts between parallel tests. """ + if family is None: + family = get_family() with socket.socket(family, socktype) as tempsock: port = bind_port(tempsock) del tempsock return port + def bind_port(sock, host=HOST): """Bind the socket to a free port and return the port number. Relies on ephemeral ports in order to ensure we are using an unbound port. This is important as many tests may be running simultaneously, especially in a buildbot environment. This method raises an exception if the sock.family - is AF_INET and sock.type is SOCK_STREAM, *and* the socket has SO_REUSEADDR + is AF_INET* and sock.type is SOCK_STREAM, *and* the socket has SO_REUSEADDR or SO_REUSEPORT set on it. Tests should *never* set these socket options for TCP/IP sockets. The only case for setting these options is testing multicasting via multiple UDP sockets. @@ -88,7 +98,8 @@ def bind_port(sock, host=HOST): from bind()'ing to our host/port for the duration of the test. """ - if sock.family == socket.AF_INET and sock.type == socket.SOCK_STREAM: + if (sock.family in {socket.AF_INET, socket.AF_INET6} and + sock.type == socket.SOCK_STREAM): if hasattr(socket, 'SO_REUSEADDR'): if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) == 1: raise support.TestFailed("tests should never set the " @@ -112,6 +123,60 @@ def bind_port(sock, host=HOST): port = sock.getsockname()[1] return port + +def get_family(): + """Get a host appropriate socket AF_INET or AF_INET6 family.""" + if IPV4_ENABLED: + return socket.AF_INET + if IPV6_ENABLED: + return socket.AF_INET6 + raise unittest.SkipTest('Neither IPv4 or IPv6 is enabled.') + + +def tcp_socket(): + """Get a new host appropriate IPv4 or IPv6 TCP STREAM socket.socket().""" + return socket.socket(get_family(), socket.SOCK_STREAM) + + +def udp_socket(proto=-1): + """Get a new host appropriate IPv4 or IPv6 UDP DGRAM socket.socket().""" + return socket.socket(get_family(), socket.SOCK_DGRAM, proto) + + +def get_bound_ip_socket_and_port(*, hostname=HOST, socktype=socket.SOCK_STREAM): + """Get an IP socket bound to a port as a sock, port tuple. + + Creates a socket of socktype bound to hostname using whichever of IPv6 or + IPv4 is available. Context is a (socket, port) tuple. Exiting the context + closes the socket. + + Prefer the bind_ip_socket_and_port context manager within a test method. + """ + family = get_family() + sock = socket.socket(family, socktype) + try: + port = bind_port(sock) + except support.TestFailed: + sock.close() + raise + return sock, port + + +@contextlib.contextmanager +def bind_ip_socket_and_port(*, hostname=HOST, socktype=socket.SOCK_STREAM): + """A context manager that creates a socket of socktype. + + It uses whichever of IPv6 or IPv4 is available based on get_family(). + Context is a (socket, port) tuple. The socket is closed on context exit. + """ + sock, port = get_bound_ip_socket_and_port( + hostname=hostname, socktype=socktype) + try: + yield sock, port + finally: + sock.close() + + def bind_unix_socket(sock, addr): """Bind a unix socket, raising SkipTest if PermissionError is raised.""" assert sock.family == socket.AF_UNIX @@ -139,6 +204,22 @@ def _is_ipv6_enabled(): IPV6_ENABLED = _is_ipv6_enabled() +def _is_ipv4_enabled(): + """Check whether IPv4 is enabled on this host.""" + sock = None + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind((HOSTv4, 0)) + return True + except OSError: + return False + finally: + if sock: + sock.close() + +IPV4_ENABLED = _is_ipv4_enabled() + + _bind_nix_socket_error = None def skip_unless_bind_unix_socket(test): """Decorator for tests requiring a functional bind() for unix sockets.""" diff --git a/Lib/test/test_asynchat.py b/Lib/test/test_asynchat.py index b32edddc7d5505..bafab802a5593e 100644 --- a/Lib/test/test_asynchat.py +++ b/Lib/test/test_asynchat.py @@ -26,8 +26,7 @@ class echo_server(threading.Thread): def __init__(self, event): threading.Thread.__init__(self) self.event = event - self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.port = socket_helper.bind_port(self.sock) + self.sock, self.port = socket_helper.get_bound_ip_socket_and_port() # This will be set if the client wants us to wait before echoing # data back. self.start_resend_event = None @@ -69,7 +68,7 @@ class echo_client(asynchat.async_chat): def __init__(self, terminator, server_port): asynchat.async_chat.__init__(self) self.contents = [] - self.create_socket(socket.AF_INET, socket.SOCK_STREAM) + self.create_socket(socket_helper.get_family(), socket.SOCK_STREAM) self.connect((HOST, server_port)) self.set_terminator(terminator) self.buffer = b"" diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 6eaa2899442184..b8a49699b4b956 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -56,7 +56,7 @@ def _basetest_open_connection(self, open_connection_fut): def test_open_connection(self): with test_utils.run_test_server() as httpd: - conn_fut = asyncio.open_connection(*httpd.address) + conn_fut = asyncio.open_connection(*httpd.address[:2]) self._basetest_open_connection(conn_fut) @socket_helper.skip_unless_bind_unix_socket @@ -84,8 +84,8 @@ def _basetest_open_connection_no_loop_ssl(self, open_connection_fut): def test_open_connection_no_loop_ssl(self): with test_utils.run_test_server(use_ssl=True) as httpd: conn_fut = asyncio.open_connection( - *httpd.address, - ssl=test_utils.dummy_ssl_context()) + *httpd.address[:2], + ssl=test_utils.dummy_ssl_context()) self._basetest_open_connection_no_loop_ssl(conn_fut) @@ -115,7 +115,7 @@ def _basetest_open_connection_error(self, open_connection_fut): def test_open_connection_error(self): with test_utils.run_test_server() as httpd: - conn_fut = asyncio.open_connection(*httpd.address) + conn_fut = asyncio.open_connection(*httpd.address[:2]) self._basetest_open_connection_error(conn_fut) @socket_helper.skip_unless_bind_unix_socket @@ -582,19 +582,23 @@ async def handle_client(self, client_reader, client_writer): await client_writer.wait_closed() def start(self): - sock = socket.create_server(('127.0.0.1', 0)) + sock = socket.create_server( + (socket_helper.HOST, 0), + family=socket_helper.get_family()) self.server = self.loop.run_until_complete( asyncio.start_server(self.handle_client, sock=sock)) - return sock.getsockname() + return sock.getsockname()[:2] def handle_client_callback(self, client_reader, client_writer): self.loop.create_task(self.handle_client(client_reader, client_writer)) def start_callback(self): - sock = socket.create_server(('127.0.0.1', 0)) - addr = sock.getsockname() + sock = socket.create_server( + (socket_helper.HOST, 0), + family=socket_helper.get_family()) + addr = sock.getsockname()[:2] sock.close() self.server = self.loop.run_until_complete( asyncio.start_server(self.handle_client_callback, @@ -815,8 +819,10 @@ def test_drain_raises(self): def server(): # Runs in a separate thread. - with socket.create_server(('localhost', 0)) as sock: - addr = sock.getsockname() + with socket_helper.bind_ip_socket_and_port() as sock_port: + sock = sock_port[0] + sock.listen() + addr = sock.getsockname()[:2] q.put(addr) clt, _ = sock.accept() clt.close() @@ -907,7 +913,7 @@ def test_LimitOverrunError_pickleable(self): def test_wait_closed_on_close(self): with test_utils.run_test_server() as httpd: rd, wr = self.loop.run_until_complete( - asyncio.open_connection(*httpd.address)) + asyncio.open_connection(*httpd.address[:2])) wr.write(b'GET / HTTP/1.0\r\n\r\n') f = rd.readline() @@ -924,7 +930,7 @@ def test_wait_closed_on_close(self): def test_wait_closed_on_close_with_unread_data(self): with test_utils.run_test_server() as httpd: rd, wr = self.loop.run_until_complete( - asyncio.open_connection(*httpd.address)) + asyncio.open_connection(*httpd.address[:2])) wr.write(b'GET / HTTP/1.0\r\n\r\n') f = rd.readline() @@ -935,7 +941,7 @@ def test_wait_closed_on_close_with_unread_data(self): def test_async_writer_api(self): async def inner(httpd): - rd, wr = await asyncio.open_connection(*httpd.address) + rd, wr = await asyncio.open_connection(*httpd.address[:2]) wr.write(b'GET / HTTP/1.0\r\n\r\n') data = await rd.readline() @@ -955,7 +961,7 @@ async def inner(httpd): def test_async_writer_api_exception_after_close(self): async def inner(httpd): - rd, wr = await asyncio.open_connection(*httpd.address) + rd, wr = await asyncio.open_connection(*httpd.address[:2]) wr.write(b'GET / HTTP/1.0\r\n\r\n') data = await rd.readline() @@ -982,7 +988,7 @@ def test_eof_feed_when_closing_writer(self): with test_utils.run_test_server() as httpd: rd, wr = self.loop.run_until_complete( - asyncio.open_connection(*httpd.address)) + asyncio.open_connection(*httpd.address[:2])) wr.close() f = wr.wait_closed() diff --git a/Lib/test/test_asyncio/utils.py b/Lib/test/test_asyncio/utils.py index 3765194cd0dd27..9b5e4ceb0e05a3 100644 --- a/Lib/test/test_asyncio/utils.py +++ b/Lib/test/test_asyncio/utils.py @@ -18,6 +18,7 @@ import weakref from unittest import mock +from test.support import socket_helper from http.server import HTTPServer from wsgiref.simple_server import WSGIRequestHandler, WSGIServer @@ -140,6 +141,7 @@ def log_message(self, format, *args): class SilentWSGIServer(WSGIServer): + address_family = socket_helper.get_family() request_timeout = support.LOOPBACK_TIMEOUT def get_request(self): @@ -215,7 +217,7 @@ class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer): def server_bind(self): socketserver.UnixStreamServer.server_bind(self) - self.server_name = '127.0.0.1' + self.server_name = socket_helper.HOST self.server_port = 80 @@ -236,7 +238,7 @@ def get_request(self): # as the second return value will be a path; # hence we return some fake data sufficient # to get the tests going - return request, ('127.0.0.1', '') + return request, (socket_helper.HOST, '') class SilentUnixWSGIServer(UnixWSGIServer): @@ -275,7 +277,7 @@ def run_test_unix_server(*, use_ssl=False): @contextlib.contextmanager -def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False): +def run_test_server(*, host=socket_helper.HOST, port=0, use_ssl=False): yield from _run_test_server(address=(host, port), use_ssl=use_ssl, server_cls=SilentWSGIServer, server_ssl_cls=SSLWSGIServer) diff --git a/Lib/test/test_asyncore.py b/Lib/test/test_asyncore.py index 3bd904d1774bc3..47a8f8da057c6e 100644 --- a/Lib/test/test_asyncore.py +++ b/Lib/test/test_asyncore.py @@ -329,9 +329,8 @@ def tearDown(self): @threading_helper.reap_threads def test_send(self): evt = threading.Event() - sock = socket.socket() + sock, port = socket_helper.get_bound_ip_socket_and_port() sock.settimeout(3) - port = socket_helper.bind_port(sock) cap = BytesIO() args = (evt, cap, sock) @@ -344,7 +343,7 @@ def test_send(self): data = b"Suppose there isn't a 16-ton weight?" d = dispatcherwithsend_noread() - d.create_socket() + d.create_socket(family=sock.family) d.connect((socket_helper.HOST, port)) # give time for socket to connect @@ -793,6 +792,7 @@ def test_quick_connect(self): finally: threading_helper.join_thread(t) +@unittest.skipUnless(socket_helper.IPV4_ENABLED, 'IPv4 support required') class TestAPI_UseIPv4Sockets(BaseTestAPI): family = socket.AF_INET addr = (socket_helper.HOST, 0) diff --git a/Lib/test/test_docxmlrpc.py b/Lib/test/test_docxmlrpc.py index 7d3e30cbee964a..f5ff914701b590 100644 --- a/Lib/test/test_docxmlrpc.py +++ b/Lib/test/test_docxmlrpc.py @@ -1,7 +1,9 @@ from xmlrpc.server import DocXMLRPCServer import http.client import re +import socket import sys +from test.support import socket_helper import threading import unittest @@ -20,7 +22,14 @@ def make_request_and_skip(self): def make_server(): - serv = DocXMLRPCServer(("localhost", 0), logRequests=False) + try: + serv = DocXMLRPCServer((socket_helper.HOST, 0), logRequests=False) + except OSError: + if not socket_helper.IPV6_ENABLED: + raise + class IPv6DocXMLRPCServer(DocXMLRPCServer): + address_family = socket.AF_INET6 + serv = IPv6DocXMLRPCServer((socket_helper.HOST, 0), logRequests=False) try: # Add some documentation @@ -74,7 +83,7 @@ def setUp(self): self.thread.start() PORT = self.serv.server_address[1] - self.client = http.client.HTTPConnection("localhost:%d" % PORT) + self.client = http.client.HTTPConnection(f"{socket_helper.HOST}:{PORT}") def tearDown(self): self.client.close() diff --git a/Lib/test/test_epoll.py b/Lib/test/test_epoll.py index b623852f9eb4ee..f035e6e754e727 100644 --- a/Lib/test/test_epoll.py +++ b/Lib/test/test_epoll.py @@ -25,6 +25,7 @@ import os import select import socket +from test.support import socket_helper import time import unittest @@ -41,7 +42,8 @@ class TestEPoll(unittest.TestCase): def setUp(self): - self.serverSocket = socket.create_server(('127.0.0.1', 0)) + self.serverSocket, _ = socket_helper.get_bound_ip_socket_and_port() + self.serverSocket.listen() self.connections = [self.serverSocket] def tearDown(self): @@ -49,10 +51,10 @@ def tearDown(self): skt.close() def _connected_pair(self): - client = socket.socket() + client = socket.socket(self.serverSocket.family) client.setblocking(False) try: - client.connect(('127.0.0.1', self.serverSocket.getsockname()[1])) + client.connect((socket_helper.HOST, self.serverSocket.getsockname()[1])) except OSError as e: self.assertEqual(e.args[0], errno.EINPROGRESS) else: diff --git a/Lib/test/test_ftplib.py b/Lib/test/test_ftplib.py index a48b429ca38027..93e64b1ec27841 100644 --- a/Lib/test/test_ftplib.py +++ b/Lib/test/test_ftplib.py @@ -265,9 +265,15 @@ class DummyFTPServer(asyncore.dispatcher, threading.Thread): handler = DummyFTPHandler - def __init__(self, address, af=socket.AF_INET, encoding=DEFAULT_ENCODING): + def __init__(self, address, af=None, encoding=DEFAULT_ENCODING): threading.Thread.__init__(self) asyncore.dispatcher.__init__(self) + if af is None and address[0] == socket_helper.HOST: + if socket_helper.IPV4_ENABLED: + af = socket.AF_INET + else: + assert socket_helper.IPV6_ENABLED, 'no IPv4 or IPv6?' + af = socket.AF_INET6 self.daemon = True self.create_socket(af, socket.SOCK_STREAM) self.bind(address) @@ -699,19 +705,22 @@ def test_entry(line, type=None, perm=None, unique=None, name=None): for x in self.client.mlsd(): self.fail("unexpected data %s" % x) - def test_makeport(self): + @skipUnless(socket_helper.IPV4_ENABLED, "IPv4 required") + def test_makeport_ipv4(self): with self.client.makeport(): # IPv4 is in use, just make sure send_eprt has not been used self.assertEqual(self.server.handler_instance.last_received_cmd, - 'port') + 'port') - def test_makepasv(self): + @skipUnless(socket_helper.IPV4_ENABLED, "IPv4 required") + def test_makepasv_ipv4(self): host, port = self.client.makepasv() conn = socket.create_connection((host, port), timeout=TIMEOUT) conn.close() # IPv4 is in use, just make sure send_epsv has not been used self.assertEqual(self.server.handler_instance.last_received_cmd, 'pasv') + @skipUnless(socket_helper.IPV4_ENABLED, "IPv4 required") def test_makepasv_issue43285_security_disabled(self): """Test the opt-in to the old vulnerable behavior.""" self.client.trust_server_pasv_ipv4_address = True @@ -1033,9 +1042,8 @@ class TestTimeouts(TestCase): def setUp(self): self.evt = threading.Event() - self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.sock, self.port = socket_helper.get_bound_ip_socket_and_port() self.sock.settimeout(20) - self.port = socket_helper.bind_port(self.sock) self.server_thread = threading.Thread(target=self.server) self.server_thread.daemon = True self.server_thread.start() diff --git a/Lib/test/test_httplib.py b/Lib/test/test_httplib.py index e9272569ecc531..790d8feb842a87 100644 --- a/Lib/test/test_httplib.py +++ b/Lib/test/test_httplib.py @@ -1331,7 +1331,8 @@ def test_read1_bound_content_length(self): def test_response_fileno(self): # Make sure fd returned by fileno is valid. - serv = socket.create_server((HOST, 0)) + serv = socket_helper.get_bound_ip_socket_and_port()[0] + serv.listen() self.addCleanup(serv.close) result = None @@ -1350,7 +1351,7 @@ def run_server(): thread = threading.Thread(target=run_server) thread.start() self.addCleanup(thread.join, float(1)) - conn = client.HTTPConnection(*serv.getsockname()) + conn = client.HTTPConnection(*serv.getsockname()[:2]) conn.request("CONNECT", "dummy:1234") response = conn.getresponse() try: @@ -1673,9 +1674,8 @@ def test_client_constants(self): class SourceAddressTest(TestCase): def setUp(self): - self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.port = socket_helper.bind_port(self.serv) - self.source_port = socket_helper.find_unused_port() + self.serv, self.port = socket_helper.get_bound_ip_socket_and_port() + self.source_port = socket_helper.find_unused_port(family=self.serv.family) self.serv.listen() self.conn = None @@ -1706,8 +1706,8 @@ class TimeoutTest(TestCase): PORT = None def setUp(self): - self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - TimeoutTest.PORT = socket_helper.bind_port(self.serv) + self.serv = socket_helper.tcp_socket() + self.PORT = socket_helper.bind_port(self.serv) self.serv.listen() def tearDown(self): @@ -1722,7 +1722,7 @@ def testTimeoutAttribute(self): self.assertIsNone(socket.getdefaulttimeout()) socket.setdefaulttimeout(30) try: - httpConn = client.HTTPConnection(HOST, TimeoutTest.PORT) + httpConn = client.HTTPConnection(HOST, self.PORT) httpConn.connect() finally: socket.setdefaulttimeout(None) @@ -1733,7 +1733,7 @@ def testTimeoutAttribute(self): self.assertIsNone(socket.getdefaulttimeout()) socket.setdefaulttimeout(30) try: - httpConn = client.HTTPConnection(HOST, TimeoutTest.PORT, + httpConn = client.HTTPConnection(HOST, self.PORT, timeout=None) httpConn.connect() finally: @@ -1742,7 +1742,7 @@ def testTimeoutAttribute(self): httpConn.close() # a value - httpConn = client.HTTPConnection(HOST, TimeoutTest.PORT, timeout=30) + httpConn = client.HTTPConnection(HOST, self.PORT, timeout=30) httpConn.connect() self.assertEqual(httpConn.sock.gettimeout(), 30) httpConn.close() diff --git a/Lib/test/test_httpservers.py b/Lib/test/test_httpservers.py index cb0a3aa9e40451..eceb364a607d82 100644 --- a/Lib/test/test_httpservers.py +++ b/Lib/test/test_httpservers.py @@ -31,6 +31,7 @@ import unittest from test import support from test.support import os_helper +from test.support import socket_helper from test.support import threading_helper @@ -42,6 +43,9 @@ def log_message(self, *args): def read(self, n=None): return '' +class IPvWhateverHTTPServer(HTTPServer): + address_family = socket_helper.get_family() + class TestServerThread(threading.Thread): def __init__(self, test_object, request_handler): @@ -50,8 +54,8 @@ def __init__(self, test_object, request_handler): self.test_object = test_object def run(self): - self.server = HTTPServer(('localhost', 0), self.request_handler) - self.test_object.HOST, self.test_object.PORT = self.server.socket.getsockname() + self.server = IPvWhateverHTTPServer(('localhost', 0), self.request_handler) + self.test_object.HOST, self.test_object.PORT = self.server.socket.getsockname()[:2] self.test_object.server_started.set() self.test_object = None try: diff --git a/Lib/test/test_imaplib.py b/Lib/test/test_imaplib.py index c2b935f58164e5..8fb3403f0b41d3 100644 --- a/Lib/test/test_imaplib.py +++ b/Lib/test/test_imaplib.py @@ -13,6 +13,7 @@ from test.support import (verbose, run_with_tz, run_with_locale, cpython_only) from test.support import hashlib_helper +from test.support import socket_helper from test.support import threading_helper from test.support import warnings_helper import unittest @@ -27,6 +28,15 @@ CAFILE = os.path.join(os.path.dirname(__file__) or os.curdir, "pycacert.pem") +if socket_helper.IPV4_ENABLED: + TCPServer = socketserver.TCPServer +elif socket_helper.IPV6_ENABLED: + class TCPServer(socketserver.TCPServer): + address_family = socket.AF_INET6 +else: + raise unittest.SkipTest('IPv4 or IPv6 required.') + + class TestImaplib(unittest.TestCase): def test_Internaldate2tuple(self): @@ -92,7 +102,7 @@ def test_imap4_host_default_value(self): if ssl: - class SecureTCPServer(socketserver.TCPServer): + class SecureTCPServer(TCPServer): def get_request(self): newsocket, fromaddr = self.socket.accept() @@ -238,7 +248,7 @@ def handle_error(self, request, client_address): self.thread.start() if connect: - self.client = self.imap_class(*self.server.server_address) + self.client = self.imap_class(*self.server.server_address[:2]) return self.client, self.server @@ -265,7 +275,7 @@ def handle(self): self.wfile.write(b'* OK') _, server = self._setup(EOFHandler, connect=False) self.assertRaises(imaplib.IMAP4.abort, self.imap_class, - *server.server_address) + *server.server_address[:2]) def test_line_termination(self): class BadNewlineHandler(SimpleIMAPHandler): @@ -274,7 +284,7 @@ def cmd_CAPABILITY(self, tag, args): self._send_tagged(tag, 'OK', 'CAPABILITY completed') _, server = self._setup(BadNewlineHandler, connect=False) self.assertRaises(imaplib.IMAP4.abort, self.imap_class, - *server.server_address) + *server.server_address[:2]) def test_enable_raises_error_if_not_AUTH(self): class EnableHandler(SimpleIMAPHandler): @@ -449,11 +459,11 @@ def handle(self): _, server = self._setup(TooLongHandler, connect=False) with self.assertRaisesRegex(imaplib.IMAP4.error, 'got more than 10 bytes'): - self.imap_class(*server.server_address) + self.imap_class(*server.server_address[:2]) def test_simple_with_statement(self): _, server = self._setup(SimpleIMAPHandler, connect=False) - with self.imap_class(*server.server_address): + with self.imap_class(*server.server_address[:2]): pass def test_imaplib_timeout_test(self): @@ -481,7 +491,7 @@ def handle(self): def test_with_statement(self): _, server = self._setup(SimpleIMAPHandler, connect=False) - with self.imap_class(*server.server_address) as imap: + with self.imap_class(*server.server_address[:2]) as imap: imap.login('user', 'pass') self.assertEqual(server.logged, 'user') self.assertIsNone(server.logged) @@ -489,7 +499,7 @@ def test_with_statement(self): def test_with_statement_logout(self): # It is legal to log out explicitly inside the with block _, server = self._setup(SimpleIMAPHandler, connect=False) - with self.imap_class(*server.server_address) as imap: + with self.imap_class(*server.server_address[:2]) as imap: imap.login('user', 'pass') self.assertEqual(server.logged, 'user') imap.logout() @@ -541,7 +551,7 @@ def test_unselect(self): class NewIMAPTests(NewIMAPTestsMixin, unittest.TestCase): imap_class = imaplib.IMAP4 - server_class = socketserver.TCPServer + server_class = TCPServer @unittest.skipUnless(ssl, "SSL not available") @@ -557,9 +567,9 @@ def test_ssl_raises(self): with self.assertRaisesRegex(ssl.CertificateError, "IP address mismatch, certificate is not valid for " - "'127.0.0.1'"): + f"'({socket_helper.HOSTv4}|{socket_helper.HOSTv6})'"): _, server = self._setup(SimpleIMAPHandler) - client = self.imap_class(*server.server_address, + client = self.imap_class(*server.server_address[:2], ssl_context=ssl_context) client.shutdown() @@ -582,7 +592,7 @@ def test_certfile_arg_warn(self): self.imap_class('localhost', 143, certfile=CERTFILE) class ThreadedNetworkedTests(unittest.TestCase): - server_class = socketserver.TCPServer + server_class = TCPServer imap_class = imaplib.IMAP4 def make_server(self, addr, hdlr): @@ -637,7 +647,7 @@ def reaped_server(self, hdlr): @contextmanager def reaped_pair(self, hdlr): with self.reaped_server(hdlr) as server: - client = self.imap_class(*server.server_address) + client = self.imap_class(*server.server_address[:2]) try: yield server, client finally: @@ -646,7 +656,7 @@ def reaped_pair(self, hdlr): @threading_helper.reap_threads def test_connect(self): with self.reaped_server(SimpleIMAPHandler) as server: - client = self.imap_class(*server.server_address) + client = self.imap_class(*server.server_address[:2]) client.shutdown() @threading_helper.reap_threads @@ -708,7 +718,7 @@ def handle(self): with self.reaped_server(EOFHandler) as server: self.assertRaises(imaplib.IMAP4.abort, - self.imap_class, *server.server_address) + self.imap_class, *server.server_address[:2]) @threading_helper.reap_threads def test_line_termination(self): @@ -721,7 +731,7 @@ def cmd_CAPABILITY(self, tag, args): with self.reaped_server(BadNewlineHandler) as server: self.assertRaises(imaplib.IMAP4.abort, - self.imap_class, *server.server_address) + self.imap_class, *server.server_address[:2]) class UTF8Server(SimpleIMAPHandler): capabilities = 'AUTH ENABLE UTF8=ACCEPT' @@ -906,19 +916,19 @@ def handle(self): with self.reaped_server(TooLongHandler) as server: self.assertRaises(imaplib.IMAP4.error, - self.imap_class, *server.server_address) + self.imap_class, *server.server_address[:2]) @threading_helper.reap_threads def test_simple_with_statement(self): # simplest call with self.reaped_server(SimpleIMAPHandler) as server: - with self.imap_class(*server.server_address): + with self.imap_class(*server.server_address[:2]): pass @threading_helper.reap_threads def test_with_statement(self): with self.reaped_server(SimpleIMAPHandler) as server: - with self.imap_class(*server.server_address) as imap: + with self.imap_class(*server.server_address[:2]) as imap: imap.login('user', 'pass') self.assertEqual(server.logged, 'user') self.assertIsNone(server.logged) @@ -927,7 +937,7 @@ def test_with_statement(self): def test_with_statement_logout(self): # what happens if already logout in the block? with self.reaped_server(SimpleIMAPHandler) as server: - with self.imap_class(*server.server_address) as imap: + with self.imap_class(*server.server_address[:2]) as imap: imap.login('user', 'pass') self.assertEqual(server.logged, 'user') imap.logout() @@ -941,7 +951,7 @@ def test_dump_ur(self): untagged_resp_dict = {'READ-WRITE': [b'']} with self.reaped_server(SimpleIMAPHandler) as server: - with self.imap_class(*server.server_address) as imap: + with self.imap_class(*server.server_address[:2]) as imap: with mock.patch.object(imap, '_mesg') as mock_mesg: imap._dump_ur(untagged_resp_dict) mock_mesg.assert_called_with( @@ -962,9 +972,9 @@ def test_ssl_verified(self): with self.assertRaisesRegex( ssl.CertificateError, "IP address mismatch, certificate is not valid for " - "'127.0.0.1'"): + f"'({socket_helper.HOSTv4}|{socket_helper.HOSTv6})'"): with self.reaped_server(SimpleIMAPHandler) as server: - client = self.imap_class(*server.server_address, + client = self.imap_class(*server.server_address[:2], ssl_context=ssl_context) client.shutdown() diff --git a/Lib/test/test_largefile.py b/Lib/test/test_largefile.py index 8f6bec16200534..ce2920dc6702f3 100644 --- a/Lib/test/test_largefile.py +++ b/Lib/test/test_largefile.py @@ -221,10 +221,11 @@ def run(sock): # bit more tolerance. @skip_no_disk_space(TESTFN, size * 2.5) def test_it(self): - port = socket_helper.find_unused_port() - with socket.create_server(("", port)) as sock: + with socket_helper.bind_ip_socket_and_port() as sock_port: + sock, port = sock_port + sock.listen() self.tcp_server(sock) - with socket.create_connection(("127.0.0.1", port)) as client: + with socket.create_connection((socket_helper.HOST, port)) as client: with open(TESTFN, 'rb') as f: client.sendfile(f) self.tearDown() diff --git a/Lib/test/test_logging.py b/Lib/test/test_logging.py index ee00a32026f65e..21323c722db925 100644 --- a/Lib/test/test_logging.py +++ b/Lib/test/test_logging.py @@ -823,6 +823,8 @@ class TestSMTPServer(smtpd.SMTPServer): :mod:`asyncore` module's global state. """ + address_family = socket_helper.get_family() + def __init__(self, addr, handler, poll_interval, sockmap): smtpd.SMTPServer.__init__(self, addr, None, map=sockmap, decode_data=True) @@ -937,6 +939,9 @@ class TestHTTPServer(ControlMixin, HTTPServer): :param poll_interval: The polling interval in seconds. :param log: Pass ``True`` to enable log messages. """ + + address_family = socket_helper.get_family() + def __init__(self, addr, handler, poll_interval=0.5, log=False, sslctx=None): class DelegatingHTTPRequestHandler(BaseHTTPRequestHandler): @@ -3231,9 +3236,9 @@ def setup_via_listener(self, text, verify=None): port = t.port t.ready.clear() try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock = socket_helper.tcp_socket() sock.settimeout(2.0) - sock.connect(('localhost', port)) + sock.connect((socket_helper.HOST, port)) slen = struct.pack('>L', len(text)) s = slen + text diff --git a/Lib/test/test_nntplib.py b/Lib/test/test_nntplib.py index 4f0592188f8443..f565cc9ba27f2f 100644 --- a/Lib/test/test_nntplib.py +++ b/Lib/test/test_nntplib.py @@ -1585,8 +1585,8 @@ def nntp_class(*pos, **kw): class LocalServerTests(unittest.TestCase): def setUp(self): - sock = socket.socket() - port = socket_helper.bind_port(sock) + sock, port = socket_helper.get_bound_ip_socket_and_port() + self.addCleanup(sock.close) sock.listen() self.background = threading.Thread( target=self.run_server, args=(sock,)) diff --git a/Lib/test/test_os.py b/Lib/test/test_os.py index 8b3d1feb78fe36..74e486efef8fa6 100644 --- a/Lib/test/test_os.py +++ b/Lib/test/test_os.py @@ -3234,7 +3234,11 @@ def handle_error(self): def __init__(self, address): threading.Thread.__init__(self) asyncore.dispatcher.__init__(self) - self.create_socket(socket.AF_INET, socket.SOCK_STREAM) + if socket_helper.IPV4_ENABLED: + family = socket.AF_INET + elif socket_helper.IPV6_ENABLED: + family = socket.AF_INET6 + self.create_socket(family, socket.SOCK_STREAM) self.bind(address) self.listen(5) self.host, self.port = self.socket.getsockname()[:2] @@ -3316,7 +3320,7 @@ def tearDownClass(cls): def setUp(self): self.server = SendfileTestServer((socket_helper.HOST, 0)) self.server.start() - self.client = socket.socket() + self.client = socket.socket(self.server.socket.family) self.client.connect((self.server.host, self.server.port)) self.client.settimeout(1) # synchronize by waiting for "220 ready" response diff --git a/Lib/test/test_poplib.py b/Lib/test/test_poplib.py index c5ae9f77e4f006..b538c9ec6651d4 100644 --- a/Lib/test/test_poplib.py +++ b/Lib/test/test_poplib.py @@ -204,7 +204,7 @@ class DummyPOP3Server(asyncore.dispatcher, threading.Thread): handler = DummyPOP3Handler - def __init__(self, address, af=socket.AF_INET): + def __init__(self, address, af=socket_helper.get_family()): threading.Thread.__init__(self) asyncore.dispatcher.__init__(self) self.daemon = True @@ -481,9 +481,8 @@ class TestTimeouts(TestCase): def setUp(self): self.evt = threading.Event() - self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.sock, self.port = socket_helper.get_bound_ip_socket_and_port() self.sock.settimeout(60) # Safety net. Look issue 11812 - self.port = socket_helper.bind_port(self.sock) self.thread = threading.Thread(target=self.server, args=(self.evt, self.sock)) self.thread.daemon = True self.thread.start() diff --git a/Lib/test/test_robotparser.py b/Lib/test/test_robotparser.py index b0bed431d4b059..b877092ae15442 100644 --- a/Lib/test/test_robotparser.py +++ b/Lib/test/test_robotparser.py @@ -1,5 +1,6 @@ import io import os +import socket import threading import unittest import urllib.robotparser @@ -314,7 +315,14 @@ def setUp(self): # clear _opener global variable self.addCleanup(urllib.request.urlcleanup) - self.server = HTTPServer((socket_helper.HOST, 0), RobotHandler) + try: + self.server = HTTPServer((socket_helper.HOST, 0), RobotHandler) + except OSError: + if not socket_helper.IPV6_ENABLED: + raise + class IPv6HTTPServer(HTTPServer): + address_family = socket.AF_INET6 + self.server = IPv6HTTPServer((socket_helper.HOST, 0), RobotHandler) self.t = threading.Thread( name='HTTPServer serving', diff --git a/Lib/test/test_smtplib.py b/Lib/test/test_smtplib.py index f3d33ab0772dd3..37ace956c9a152 100644 --- a/Lib/test/test_smtplib.py +++ b/Lib/test/test_smtplib.py @@ -283,7 +283,7 @@ def testBasic(self): def testSourceAddress(self): # connect - src_port = socket_helper.find_unused_port() + src_port = socket_helper.find_unused_port(family=self.serv.socket.family) try: smtp = smtplib.SMTP(self.host, self.port, local_hostname='localhost', timeout=support.LOOPBACK_TIMEOUT, @@ -721,9 +721,8 @@ def setUp(self): sys.stdout = self.output self.evt = threading.Event() - self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.sock, self.port = socket_helper.get_bound_ip_socket_and_port() self.sock.settimeout(15) - self.port = socket_helper.bind_port(self.sock) servargs = (self.evt, self.respdata, self.sock) self.thread = threading.Thread(target=server, args=servargs) self.thread.start() @@ -739,8 +738,8 @@ def tearDown(self): threading_helper.threading_cleanup(*self.thread_key) def testLineTooLong(self): - self.assertRaises(smtplib.SMTPResponseException, smtplib.SMTP, - HOST, self.port, 'localhost', 3) + with self.assertRaises(smtplib.SMTPResponseException): + smtplib.SMTP(HOST, self.port, 'localhost', 3) sim_users = {'Mr.A@somewhere.com':'John A', diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index 828d1f3dcc6701..949105783ac5f8 100755 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -175,8 +175,8 @@ def socket_setdefaulttimeout(timeout): class SocketTCPTest(unittest.TestCase): def setUp(self): - self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.port = socket_helper.bind_port(self.serv) + self.serv, self.port = socket_helper.get_bound_ip_socket_and_port( + socktype=socket.SOCK_STREAM) self.serv.listen() def tearDown(self): @@ -186,8 +186,8 @@ def tearDown(self): class SocketUDPTest(unittest.TestCase): def setUp(self): - self.serv = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - self.port = socket_helper.bind_port(self.serv) + self.serv, self.port = socket_helper.get_bound_ip_socket_and_port( + socktype=socket.SOCK_DGRAM) def tearDown(self): self.serv.close() @@ -196,7 +196,7 @@ def tearDown(self): class SocketUDPLITETest(SocketUDPTest): def setUp(self): - self.serv = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDPLITE) + self.serv = socket_helper.udp_socket(socket.IPPROTO_UDPLITE) self.port = socket_helper.bind_port(self.serv) class ThreadSafeCleanupTestCase(unittest.TestCase): @@ -409,7 +409,7 @@ def __init__(self, methodName='runTest'): ThreadableTest.__init__(self) def clientSetUp(self): - self.cli = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.cli = socket.socket(socket_helper.get_family(), socket.SOCK_STREAM) def clientTearDown(self): self.cli.close() @@ -423,7 +423,7 @@ def __init__(self, methodName='runTest'): ThreadableTest.__init__(self) def clientSetUp(self): - self.cli = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.cli = socket.socket(socket_helper.get_family(), socket.SOCK_DGRAM) def clientTearDown(self): self.cli.close() @@ -720,23 +720,39 @@ def bindSock(self, sock): socket_helper.bind_port(sock, host=self.host) class TCPTestBase(InetTestBase): + """Base class for TCP tests.""" + + def newSocket(self): + return socket_helper.tcp_socket() + +@unittest.skipUnless(socket_helper.IPV4_ENABLED, 'Requires IPv4') +class TCP4TestBase(InetTestBase): """Base class for TCP-over-IPv4 tests.""" def newSocket(self): return socket.socket(socket.AF_INET, socket.SOCK_STREAM) class UDPTestBase(InetTestBase): + """Base class for UDP tests.""" + + def newSocket(self): + return socket_helper.udp_socket() + +@unittest.skipUnless(socket_helper.IPV4_ENABLED, 'Requires IPv4') +class UDP4TestBase(InetTestBase): """Base class for UDP-over-IPv4 tests.""" def newSocket(self): return socket.socket(socket.AF_INET, socket.SOCK_DGRAM) -class UDPLITETestBase(InetTestBase): +@unittest.skipUnless(socket_helper.IPV4_ENABLED, 'Requires IPv4') +class UDPLITE4TestBase(InetTestBase): """Base class for UDPLITE-over-IPv4 tests.""" def newSocket(self): return socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDPLITE) +@unittest.skipUnless(socket_helper.IPV4_ENABLED, 'Requires IPv4') class SCTPStreamBase(InetTestBase): """Base class for SCTP tests in one-to-one (SOCK_STREAM) mode.""" @@ -839,14 +855,15 @@ def test_SocketType_is_socketobject(self): s.close() def test_repr(self): - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + family = socket_helper.get_family() + s = socket.socket(family, socket.SOCK_STREAM) with s: self.assertIn('fd=%i' % s.fileno(), repr(s)) - self.assertIn('family=%s' % socket.AF_INET, repr(s)) + self.assertIn('family=%s' % family, repr(s)) self.assertIn('type=%s' % socket.SOCK_STREAM, repr(s)) self.assertIn('proto=0', repr(s)) self.assertNotIn('raddr', repr(s)) - s.bind(('127.0.0.1', 0)) + s.bind((socket_helper.HOST, 0)) self.assertIn('laddr', repr(s)) self.assertIn(str(s.getsockname()), repr(s)) self.assertIn('[closed]', repr(s)) @@ -854,7 +871,8 @@ def test_repr(self): @unittest.skipUnless(_socket is not None, 'need _socket module') def test_csocket_repr(self): - s = _socket.socket(_socket.AF_INET, _socket.SOCK_STREAM) + family = socket_helper.get_family() + s = _socket.socket(family, _socket.SOCK_STREAM) try: expected = ('' % (s.fileno(), s.family, s.type, s.proto)) @@ -866,7 +884,7 @@ def test_csocket_repr(self): self.assertEqual(repr(s), expected) def test_weakref(self): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + with socket_helper.tcp_socket() as s: p = proxy(s) self.assertEqual(p.fileno(), s.fileno()) s = None @@ -889,10 +907,10 @@ def testSocketError(self): def testSendtoErrors(self): # Testing that sendto doesn't mask failures. See #10169. - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s = socket_helper.udp_socket() self.addCleanup(s.close) s.bind(('', 0)) - sockname = s.getsockname() + sockname = s.getsockname()[:2] # 2 args with self.assertRaises(TypeError) as cm: s.sendto('\u2620', sockname) @@ -1009,6 +1027,7 @@ def testHostnameRes(self): if not fqhn in all_host_names: self.fail("Error testing host resolution mechanisms. (fqdn: %s, all: %s)" % (fqhn, repr(all_host_names))) + @unittest.skipUnless(socket_helper.IPV4_ENABLED, 'IPv4 required') def test_host_resolution(self): for addr in [socket_helper.HOSTv4, '10.0.0.1', '255.255.255.255']: self.assertEqual(socket.gethostbyname(addr), addr) @@ -1375,9 +1394,10 @@ def testStringToIPv6(self): # XXX The following don't test module-level functionality... + @unittest.skipUnless(socket_helper.IPV4_ENABLED, 'IPv4 required') def testSockName(self): # Testing getsockname() - port = socket_helper.find_unused_port() + port = socket_helper.find_unused_port(family=socket.AF_INET) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.addCleanup(sock.close) sock.bind(("0.0.0.0", port)) @@ -1416,8 +1436,8 @@ def testSendAfterClose(self): self.assertRaises(OSError, sock.send, b"spam") def testCloseException(self): - sock = socket.socket() - sock.bind((socket._LOCALHOST, 0)) + sock = socket_helper.tcp_socket() + sock.bind((socket_helper.HOST, 0)) socket.socket(fileno=sock.fileno()).close() try: sock.close() @@ -1430,8 +1450,9 @@ def testCloseException(self): def testNewAttributes(self): # testing .family, .type and .protocol - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - self.assertEqual(sock.family, socket.AF_INET) + family = socket_helper.get_family() + with socket.socket(family, socket.SOCK_STREAM) as sock: + self.assertEqual(sock.family, family) if hasattr(socket, 'SOCK_CLOEXEC'): self.assertIn(sock.type, (socket.SOCK_STREAM | socket.SOCK_CLOEXEC, @@ -1441,17 +1462,19 @@ def testNewAttributes(self): self.assertEqual(sock.proto, 0) def test_getsockaddrarg(self): - sock = socket.socket() + sock = socket_helper.tcp_socket() self.addCleanup(sock.close) - port = socket_helper.find_unused_port() + port = socket_helper.find_unused_port(family=sock.family) big_port = port + 65536 neg_port = port - 65536 - self.assertRaises(OverflowError, sock.bind, (HOST, big_port)) - self.assertRaises(OverflowError, sock.bind, (HOST, neg_port)) + with self.assertRaises(OverflowError): + sock.bind((HOST, big_port)) + with self.assertRaises(OverflowError): + sock.bind((HOST, neg_port)) # Since find_unused_port() is inherently subject to race conditions, we # call it a couple times if necessary. for i in itertools.count(): - port = socket_helper.find_unused_port() + port = socket_helper.find_unused_port(family=sock.family) try: sock.bind((HOST, port)) except OSError as e: @@ -1488,6 +1511,7 @@ def test_sio_loopback_fast_path(self): raise self.assertRaises(TypeError, s.ioctl, socket.SIO_LOOPBACK_FAST_PATH, None) + @unittest.skipUnless(socket_helper.IPV4_ENABLED, 'IPv4 required') def testGetaddrinfo(self): try: socket.getaddrinfo('localhost', 80) @@ -1625,14 +1649,14 @@ def test_sendall_interrupted_with_timeout(self): self.check_sendall_interrupted(True) def test_dealloc_warn(self): - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock = socket_helper.tcp_socket() r = repr(sock) with self.assertWarns(ResourceWarning) as cm: sock = None support.gc_collect() self.assertIn(r, str(cm.warning.args[0])) # An open socket file object gets dereferenced after the socket - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock = socket_helper.tcp_socket() f = sock.makefile('rb') r = repr(sock) sock = None @@ -1642,13 +1666,13 @@ def test_dealloc_warn(self): support.gc_collect() def test_name_closed_socketio(self): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + with socket_helper.tcp_socket() as sock: fp = sock.makefile("rb") fp.close() self.assertEqual(repr(fp), "<_io.BufferedReader name=-1>") def test_unusable_closed_socketio(self): - with socket.socket() as sock: + with socket_helper.tcp_socket() as sock: fp = sock.makefile("rb", buffering=0) self.assertTrue(fp.readable()) self.assertFalse(fp.writable()) @@ -1659,7 +1683,7 @@ def test_unusable_closed_socketio(self): self.assertRaises(ValueError, fp.seekable) def test_socket_close(self): - sock = socket.socket() + sock = socket_helper.tcp_socket() try: sock.bind((HOST, 0)) socket.close(sock.fileno()) @@ -1702,11 +1726,11 @@ def test_pickle(self): def test_listen_backlog(self): for backlog in 0, -1: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv: + with socket_helper.tcp_socket() as srv: srv.bind((HOST, 0)) srv.listen(backlog) - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv: + with socket_helper.tcp_socket() as srv: srv.bind((HOST, 0)) srv.listen() @@ -1714,7 +1738,7 @@ def test_listen_backlog(self): def test_listen_backlog_overflow(self): # Issue 15989 import _testcapi - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv: + with socket_helper.tcp_socket() as srv: srv.bind((HOST, 0)) self.assertRaises(OverflowError, srv.listen, _testcapi.INT_MAX + 1) @@ -1861,7 +1885,7 @@ def _test_socket_fileno(self, s, family, stype): self.assertEqual(s.type, stype) fd = s.fileno() - s2 = socket.socket(fileno=fd) + s2 = socket.socket(family, fileno=fd) self.addCleanup(s2.close) # detach old fd to avoid double close s.detach() @@ -1869,36 +1893,32 @@ def _test_socket_fileno(self, s, family, stype): self.assertEqual(s2.type, stype) self.assertEqual(s2.fileno(), fd) - def test_socket_fileno(self): - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + def test_socket_fileno_tcp(self): + s = socket_helper.tcp_socket() self.addCleanup(s.close) s.bind((socket_helper.HOST, 0)) - self._test_socket_fileno(s, socket.AF_INET, socket.SOCK_STREAM) + self._test_socket_fileno(s, s.family, socket.SOCK_STREAM) - if hasattr(socket, "SOCK_DGRAM"): - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - self.addCleanup(s.close) - s.bind((socket_helper.HOST, 0)) - self._test_socket_fileno(s, socket.AF_INET, socket.SOCK_DGRAM) + @unittest.skipUnless(hasattr(socket, "SOCK_DGRAM"), "SOCK_DGRAM required") + def test_socket_fileno_udp(self): + s = socket_helper.udp_socket() + self.addCleanup(s.close) + s.bind((socket_helper.HOST, 0)) + self._test_socket_fileno(s, s.family, socket.SOCK_DGRAM) - if socket_helper.IPV6_ENABLED: - s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - self.addCleanup(s.close) - s.bind((socket_helper.HOSTv6, 0, 0, 0)) - self._test_socket_fileno(s, socket.AF_INET6, socket.SOCK_STREAM) - - if hasattr(socket, "AF_UNIX"): - tmpdir = tempfile.mkdtemp() - self.addCleanup(shutil.rmtree, tmpdir) - s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - self.addCleanup(s.close) - try: - s.bind(os.path.join(tmpdir, 'socket')) - except PermissionError: - pass - else: - self._test_socket_fileno(s, socket.AF_UNIX, - socket.SOCK_STREAM) + @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "AF_UNIX required") + def test_socket_fileno_unix(self): + tmpdir = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, tmpdir) + s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.addCleanup(s.close) + try: + s.bind(os.path.join(tmpdir, 'socket')) + except PermissionError: + pass + else: + self._test_socket_fileno(s, socket.AF_UNIX, + socket.SOCK_STREAM) def test_socket_fileno_rejects_float(self): with self.assertRaises(TypeError): @@ -2514,7 +2534,7 @@ def _testSendAll(self): def testFromFd(self): # Testing fromfd() fd = self.cli_conn.fileno() - sock = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM) + sock = socket.fromfd(fd, socket_helper.get_family(), socket.SOCK_STREAM) self.addCleanup(sock.close) self.assertIsInstance(sock, socket.socket) msg = sock.recv(1024) @@ -2570,7 +2590,7 @@ def testDetach(self): self.cli_conn.close() # ...but we can create another socket using the (still open) # file descriptor - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, fileno=f) + sock = socket.socket(socket_helper.get_family(), socket.SOCK_STREAM, fileno=f) self.addCleanup(sock.close) msg = sock.recv(1024) self.assertEqual(msg, MSG) @@ -4208,21 +4228,21 @@ def _testSecondCmsgTruncInData(self): # Derive concrete test classes for different socket types. -class SendrecvmsgUDPTestBase(SendrecvmsgDgramFlagsBase, +class SendrecvmsgUDP4TestBase(SendrecvmsgDgramFlagsBase, SendrecvmsgConnectionlessBase, - ThreadedSocketTestMixin, UDPTestBase): + ThreadedSocketTestMixin, UDP4TestBase): pass @requireAttrs(socket.socket, "sendmsg") -class SendmsgUDPTest(SendmsgConnectionlessTests, SendrecvmsgUDPTestBase): +class SendmsgUDP4Test(SendmsgConnectionlessTests, SendrecvmsgUDP4TestBase): pass @requireAttrs(socket.socket, "recvmsg") -class RecvmsgUDPTest(RecvmsgTests, SendrecvmsgUDPTestBase): +class RecvmsgUDP4Test(RecvmsgTests, SendrecvmsgUDP4TestBase): pass @requireAttrs(socket.socket, "recvmsg_into") -class RecvmsgIntoUDPTest(RecvmsgIntoTests, SendrecvmsgUDPTestBase): +class RecvmsgIntoUDP4Test(RecvmsgIntoTests, SendrecvmsgUDP4TestBase): pass @@ -4273,27 +4293,27 @@ class RecvmsgIntoRFC3542AncillaryUDP6Test(RecvmsgIntoMixin, @unittest.skipUnless(HAVE_SOCKET_UDPLITE, 'UDPLITE sockets required for this test.') -class SendrecvmsgUDPLITETestBase(SendrecvmsgDgramFlagsBase, +class SendrecvmsgUDPLITE4TestBase(SendrecvmsgDgramFlagsBase, SendrecvmsgConnectionlessBase, - ThreadedSocketTestMixin, UDPLITETestBase): + ThreadedSocketTestMixin, UDPLITE4TestBase): pass @unittest.skipUnless(HAVE_SOCKET_UDPLITE, 'UDPLITE sockets required for this test.') @requireAttrs(socket.socket, "sendmsg") -class SendmsgUDPLITETest(SendmsgConnectionlessTests, SendrecvmsgUDPLITETestBase): +class SendmsgUDPLITE4Test(SendmsgConnectionlessTests, SendrecvmsgUDPLITE4TestBase): pass @unittest.skipUnless(HAVE_SOCKET_UDPLITE, 'UDPLITE sockets required for this test.') @requireAttrs(socket.socket, "recvmsg") -class RecvmsgUDPLITETest(RecvmsgTests, SendrecvmsgUDPLITETestBase): +class RecvmsgUDPLITE4Test(RecvmsgTests, SendrecvmsgUDPLITE4TestBase): pass @unittest.skipUnless(HAVE_SOCKET_UDPLITE, 'UDPLITE sockets required for this test.') @requireAttrs(socket.socket, "recvmsg_into") -class RecvmsgIntoUDPLITETest(RecvmsgIntoTests, SendrecvmsgUDPLITETestBase): +class RecvmsgIntoUDPLITE4Test(RecvmsgIntoTests, SendrecvmsgUDPLITE4TestBase): pass @@ -4380,13 +4400,13 @@ class SendrecvmsgSCTPStreamTestBase(SendrecvmsgSCTPFlagsBase, @requireAttrs(socket.socket, "sendmsg") @unittest.skipIf(AIX, "IPPROTO_SCTP: [Errno 62] Protocol not supported on AIX") -@requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP") +@requireSocket(socket_helper.get_family(), "SOCK_STREAM", "IPPROTO_SCTP") class SendmsgSCTPStreamTest(SendmsgStreamTests, SendrecvmsgSCTPStreamTestBase): pass @requireAttrs(socket.socket, "recvmsg") @unittest.skipIf(AIX, "IPPROTO_SCTP: [Errno 62] Protocol not supported on AIX") -@requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP") +@requireSocket(socket_helper.get_family(), "SOCK_STREAM", "IPPROTO_SCTP") class RecvmsgSCTPStreamTest(RecvmsgTests, RecvmsgGenericStreamTests, SendrecvmsgSCTPStreamTestBase): @@ -4400,7 +4420,7 @@ def testRecvmsgEOF(self): @requireAttrs(socket.socket, "recvmsg_into") @unittest.skipIf(AIX, "IPPROTO_SCTP: [Errno 62] Protocol not supported on AIX") -@requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP") +@requireSocket(socket_helper.get_family(), "SOCK_STREAM", "IPPROTO_SCTP") class RecvmsgIntoSCTPStreamTest(RecvmsgIntoTests, RecvmsgGenericStreamTests, SendrecvmsgSCTPStreamTestBase): @@ -5142,7 +5162,7 @@ def mocked_socket_module(self): def test_connect(self): port = socket_helper.find_unused_port() - cli = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + cli = socket_helper.tcp_socket() self.addCleanup(cli.close) with self.assertRaises(OSError) as cm: cli.connect((HOST, port)) @@ -5210,7 +5230,7 @@ def _testFamily(self): self.cli = socket.create_connection((HOST, self.port), timeout=support.LOOPBACK_TIMEOUT) self.addCleanup(self.cli.close) - self.assertEqual(self.cli.family, 2) + self.assertEqual(self.cli.family, socket_helper.get_family()) testSourceAddress = _justAccept def _testSourceAddress(self): @@ -5724,7 +5744,7 @@ def testCreateConnectionBase(self): conn.sendall(data) def _testCreateConnectionBase(self): - address = self.serv.getsockname() + address = self.serv.getsockname()[:2] with socket.create_connection(address) as sock: self.assertFalse(sock._closed) sock.sendall(b'foo') @@ -5738,7 +5758,7 @@ def testCreateConnectionClose(self): conn.sendall(data) def _testCreateConnectionClose(self): - address = self.serv.getsockname() + address = self.serv.getsockname()[:2] with socket.create_connection(address) as sock: sock.close() self.assertTrue(sock._closed) @@ -6034,7 +6054,7 @@ def meth_from_sock(self, sock): # regular file def _testRegularFile(self): - address = self.serv.getsockname() + address = self.serv.getsockname()[:2] file = open(os_helper.TESTFN, 'rb') with socket.create_connection(address) as sock, file as file: meth = self.meth_from_sock(sock) @@ -6051,7 +6071,7 @@ def testRegularFile(self): # non regular file def _testNonRegularFile(self): - address = self.serv.getsockname() + address = self.serv.getsockname()[:2] file = io.BytesIO(self.FILEDATA) with socket.create_connection(address) as sock, file as file: sent = sock.sendfile(file) @@ -6069,7 +6089,7 @@ def testNonRegularFile(self): # empty file def _testEmptyFileSend(self): - address = self.serv.getsockname() + address = self.serv.getsockname()[:2] filename = os_helper.TESTFN + "2" with open(filename, 'wb'): self.addCleanup(os_helper.unlink, filename) @@ -6088,7 +6108,7 @@ def testEmptyFileSend(self): # offset def _testOffset(self): - address = self.serv.getsockname() + address = self.serv.getsockname()[:2] file = open(os_helper.TESTFN, 'rb') with socket.create_connection(address) as sock, file as file: meth = self.meth_from_sock(sock) @@ -6105,7 +6125,7 @@ def testOffset(self): # count def _testCount(self): - address = self.serv.getsockname() + address = self.serv.getsockname()[:2] file = open(os_helper.TESTFN, 'rb') sock = socket.create_connection(address, timeout=support.LOOPBACK_TIMEOUT) @@ -6126,7 +6146,7 @@ def testCount(self): # count small def _testCountSmall(self): - address = self.serv.getsockname() + address = self.serv.getsockname()[:2] file = open(os_helper.TESTFN, 'rb') sock = socket.create_connection(address, timeout=support.LOOPBACK_TIMEOUT) @@ -6147,7 +6167,7 @@ def testCountSmall(self): # count + offset def _testCountWithOffset(self): - address = self.serv.getsockname() + address = self.serv.getsockname()[:2] file = open(os_helper.TESTFN, 'rb') with socket.create_connection(address, timeout=2) as sock, file as file: count = 100007 @@ -6166,7 +6186,7 @@ def testCountWithOffset(self): # non blocking sockets are not supposed to work def _testNonBlocking(self): - address = self.serv.getsockname() + address = self.serv.getsockname()[:2] file = open(os_helper.TESTFN, 'rb') with socket.create_connection(address) as sock, file as file: sock.setblocking(False) @@ -6182,7 +6202,7 @@ def testNonBlocking(self): # timeout (non-triggered) def _testWithTimeout(self): - address = self.serv.getsockname() + address = self.serv.getsockname()[:2] file = open(os_helper.TESTFN, 'rb') sock = socket.create_connection(address, timeout=support.LOOPBACK_TIMEOUT) @@ -6200,7 +6220,7 @@ def testWithTimeout(self): # timeout (triggered) def _testWithTimeoutTriggeredSend(self): - address = self.serv.getsockname() + address = self.serv.getsockname()[:2] with open(os_helper.TESTFN, 'rb') as file: with socket.create_connection(address) as sock: sock.settimeout(0.01) @@ -6471,35 +6491,45 @@ def test_new_tcp_flags(self): class CreateServerTest(unittest.TestCase): - def test_address(self): - port = socket_helper.find_unused_port() + @unittest.skipUnless(socket_helper.IPV4_ENABLED, 'IPv4 required') + def test_address_ipv4(self): + port = socket_helper.find_unused_port(family=socket.AF_INET) with socket.create_server(("127.0.0.1", port)) as sock: self.assertEqual(sock.getsockname()[0], "127.0.0.1") self.assertEqual(sock.getsockname()[1], port) - if socket_helper.IPV6_ENABLED: - with socket.create_server(("::1", port), - family=socket.AF_INET6) as sock: - self.assertEqual(sock.getsockname()[0], "::1") - self.assertEqual(sock.getsockname()[1], port) - def test_family_and_type(self): + @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required') + def test_address_ipv6(self): + port = socket_helper.find_unused_port(family=socket.AF_INET6) + with socket.create_server(("::1", port), + family=socket.AF_INET6) as sock: + self.assertEqual(sock.getsockname()[0], "::1") + self.assertEqual(sock.getsockname()[1], port) + + @unittest.skipUnless(socket_helper.IPV4_ENABLED, 'IPv4 required') + def test_family_and_type_ipv4(self): with socket.create_server(("127.0.0.1", 0)) as sock: self.assertEqual(sock.family, socket.AF_INET) self.assertEqual(sock.type, socket.SOCK_STREAM) - if socket_helper.IPV6_ENABLED: - with socket.create_server(("::1", 0), family=socket.AF_INET6) as s: - self.assertEqual(s.family, socket.AF_INET6) - self.assertEqual(sock.type, socket.SOCK_STREAM) + + @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required') + def test_family_and_type_ipv6(self): + with socket.create_server(("::1", 0), family=socket.AF_INET6) as sock: + self.assertEqual(sock.family, socket.AF_INET6) + self.assertEqual(sock.type, socket.SOCK_STREAM) def test_reuse_port(self): + fam = socket_helper.get_family() if not hasattr(socket, "SO_REUSEPORT"): with self.assertRaises(ValueError): - socket.create_server(("localhost", 0), reuse_port=True) + socket.create_server( + ("localhost", 0), family=fam,reuse_port=True) else: - with socket.create_server(("localhost", 0)) as sock: + with socket.create_server(("localhost", 0), family=fam) as sock: opt = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) self.assertEqual(opt, 0) - with socket.create_server(("localhost", 0), reuse_port=True) as sock: + with socket.create_server( + ("localhost", 0), family=fam, reuse_port=True) as sock: opt = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) self.assertNotEqual(opt, 0) @@ -6554,15 +6584,16 @@ def echo_client(self, addr, family): sock.sendall(b'foo') self.assertEqual(sock.recv(1024), b'foo') + @unittest.skipUnless(socket_helper.IPV4_ENABLED, 'IPv4 required') def test_tcp4(self): - port = socket_helper.find_unused_port() + port = socket_helper.find_unused_port(family=socket.AF_INET) with socket.create_server(("", port)) as sock: self.echo_server(sock) self.echo_client(("127.0.0.1", port), socket.AF_INET) @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test') def test_tcp6(self): - port = socket_helper.find_unused_port() + port = socket_helper.find_unused_port(family=socket.AF_INET6) with socket.create_server(("", port), family=socket.AF_INET6) as sock: self.echo_server(sock) @@ -6573,10 +6604,11 @@ def test_tcp6(self): @unittest.skipIf(not socket.has_dualstack_ipv6(), "dualstack_ipv6 not supported") @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test') + @unittest.skipUnless(socket_helper.IPV4_ENABLED, 'IPv4 required.') def test_dual_stack_client_v4(self): - port = socket_helper.find_unused_port() - with socket.create_server(("", port), family=socket.AF_INET6, + with socket.create_server(("", 0), family=socket.AF_INET6, dualstack_ipv6=True) as sock: + port = sock.getsockname()[1] self.echo_server(sock) self.echo_client(("127.0.0.1", port), socket.AF_INET) @@ -6584,9 +6616,9 @@ def test_dual_stack_client_v4(self): "dualstack_ipv6 not supported") @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test') def test_dual_stack_client_v6(self): - port = socket_helper.find_unused_port() - with socket.create_server(("", port), family=socket.AF_INET6, + with socket.create_server(("", 0), family=socket.AF_INET6, dualstack_ipv6=True) as sock: + port = sock.getsockname()[1] self.echo_server(sock) self.echo_client(("::1", port), socket.AF_INET6) @@ -6669,17 +6701,17 @@ def test_main(): tests.append(BasicBluetoothTest) tests.extend([ CmsgMacroTests, - SendmsgUDPTest, - RecvmsgUDPTest, - RecvmsgIntoUDPTest, + SendmsgUDP4Test, + RecvmsgUDP4Test, + RecvmsgIntoUDP4Test, SendmsgUDP6Test, RecvmsgUDP6Test, RecvmsgRFC3542AncillaryUDP6Test, RecvmsgIntoRFC3542AncillaryUDP6Test, RecvmsgIntoUDP6Test, - SendmsgUDPLITETest, - RecvmsgUDPLITETest, - RecvmsgIntoUDPLITETest, + SendmsgUDPLITE4Test, + RecvmsgUDPLITE4Test, + RecvmsgIntoUDPLITE4Test, SendmsgUDPLITE6Test, RecvmsgUDPLITE6Test, RecvmsgRFC3542AncillaryUDPLITE6Test, diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index 00d5eff81537d1..78a2bdfca8eb9d 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -366,7 +366,7 @@ def test_ssl_types(self): def test_private_init(self): with self.assertRaisesRegex(TypeError, "public constructor"): - with socket.socket() as s: + with socket_helper.tcp_socket() as s: ssl.SSLSocket(s) def test_str_for_enums(self): @@ -550,7 +550,7 @@ def test_openssl_version(self): def test_refcycle(self): # Issue #7943: an SSL object doesn't create reference cycles with # itself. - s = socket.socket(socket.AF_INET) + s = socket.socket(socket_helper.get_family()) ss = test_wrap_socket(s) wr = weakref.ref(ss) with warnings_helper.check_warnings(("", ResourceWarning)): @@ -560,7 +560,7 @@ def test_refcycle(self): def test_wrapped_unconnected(self): # Methods on an unconnected SSLSocket propagate the original # OSError raise by the underlying socket object. - s = socket.socket(socket.AF_INET) + s = socket.socket(socket_helper.get_family()) with test_wrap_socket(s) as ss: self.assertRaises(OSError, ss.recv, 1) self.assertRaises(OSError, ss.recv_into, bytearray(b'x')) @@ -579,14 +579,14 @@ def test_timeout(self): # Issue #8524: when creating an SSL socket, the timeout of the # original socket should be retained. for timeout in (None, 0.0, 5.0): - s = socket.socket(socket.AF_INET) + s = socket.socket(socket_helper.get_family()) s.settimeout(timeout) with test_wrap_socket(s) as ss: self.assertEqual(timeout, ss.gettimeout()) @ignore_deprecation def test_errors_sslwrap(self): - sock = socket.socket() + sock = socket_helper.tcp_socket() self.assertRaisesRegex(ValueError, "certfile must be specified", ssl.wrap_socket, sock, keyfile=CERTFILE) @@ -600,16 +600,16 @@ def test_errors_sslwrap(self): self.assertRaisesRegex(ValueError, "can't connect in server-side mode", s.connect, (HOST, 8080)) with self.assertRaises(OSError) as cm: - with socket.socket() as sock: + with socket_helper.tcp_socket() as sock: ssl.wrap_socket(sock, certfile=NONEXISTINGCERT) self.assertEqual(cm.exception.errno, errno.ENOENT) with self.assertRaises(OSError) as cm: - with socket.socket() as sock: + with socket_helper.tcp_socket() as sock: ssl.wrap_socket(sock, certfile=CERTFILE, keyfile=NONEXISTINGCERT) self.assertEqual(cm.exception.errno, errno.ENOENT) with self.assertRaises(OSError) as cm: - with socket.socket() as sock: + with socket_helper.tcp_socket() as sock: ssl.wrap_socket(sock, certfile=NONEXISTINGCERT, keyfile=NONEXISTINGCERT) self.assertEqual(cm.exception.errno, errno.ENOENT) @@ -618,7 +618,7 @@ def bad_cert_test(self, certfile): """Check that trying to use the given client certificate fails""" certfile = os.path.join(os.path.dirname(__file__) or os.curdir, certfile) - sock = socket.socket() + sock = socket_helper.tcp_socket() self.addCleanup(sock.close) with self.assertRaises(ssl.SSLError): test_wrap_socket(sock, @@ -838,34 +838,35 @@ def fail(cert, hostname): def test_server_side(self): # server_hostname doesn't work for server sockets ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - with socket.socket() as sock: + with socket_helper.tcp_socket() as sock: self.assertRaises(ValueError, ctx.wrap_socket, sock, True, server_hostname="some.hostname") def test_unknown_channel_binding(self): # should raise ValueError for unknown type - s = socket.create_server(('127.0.0.1', 0)) - c = socket.socket(socket.AF_INET) - c.connect(s.getsockname()) - with test_wrap_socket(c, do_handshake_on_connect=False) as ss: - with self.assertRaises(ValueError): - ss.get_channel_binding("unknown-type") - s.close() + with socket_helper.bind_ip_socket_and_port() as sock_port: + s = sock_port[0] + s.listen() + c = socket.socket(s.family) + c.connect(s.getsockname()) + with test_wrap_socket(c, do_handshake_on_connect=False) as ss: + with self.assertRaises(ValueError): + ss.get_channel_binding("unknown-type") @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES, "'tls-unique' channel binding not available") def test_tls_unique_channel_binding(self): # unconnected should return None for known type - s = socket.socket(socket.AF_INET) + s = socket.socket(socket_helper.get_family()) with test_wrap_socket(s) as ss: self.assertIsNone(ss.get_channel_binding("tls-unique")) # the same for server-side - s = socket.socket(socket.AF_INET) + s = socket.socket(socket_helper.get_family()) with test_wrap_socket(s, server_side=True, certfile=CERTFILE) as ss: self.assertIsNone(ss.get_channel_binding("tls-unique")) def test_dealloc_warn(self): - ss = test_wrap_socket(socket.socket(socket.AF_INET)) + ss = test_wrap_socket(socket.socket(socket_helper.get_family())) r = repr(ss) with self.assertWarns(ResourceWarning) as cm: ss = None @@ -981,7 +982,7 @@ def test_purpose_enum(self): '1.3.6.1.5.5.7.3.2') def test_unsupported_dtls(self): - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s = socket.socket(socket_helper.get_family(), socket.SOCK_DGRAM) self.addCleanup(s.close) with self.assertRaises(NotImplementedError) as cx: test_wrap_socket(s, cert_reqs=ssl.CERT_NONE) @@ -1057,10 +1058,10 @@ def local_february_name(): self.cert_time_fail(local_february_name() + " 9 00:00:00 2007 GMT") def test_connect_ex_error(self): - server = socket.socket(socket.AF_INET) + server = socket.socket(socket_helper.get_family()) self.addCleanup(server.close) port = socket_helper.bind_port(server) # Reserve port but don't listen - s = test_wrap_socket(socket.socket(socket.AF_INET), + s = test_wrap_socket(socket.socket(server.family), cert_reqs=ssl.CERT_REQUIRED) self.addCleanup(s.close) rc = s.connect_ex((HOST, port)) @@ -1077,7 +1078,7 @@ def test_read_write_zero(self): client_context, server_context, hostname = testing_context() server = ThreadedEchoServer(context=server_context) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) self.assertEqual(s.recv(0), b"") @@ -1752,7 +1753,7 @@ class MySSLObject(ssl.SSLObject): ctx.sslsocket_class = MySSLSocket ctx.sslobject_class = MySSLObject - with ctx.wrap_socket(socket.socket(), server_side=True) as sock: + with ctx.wrap_socket(socket_helper.tcp_socket(), server_side=True) as sock: self.assertIsInstance(sock, MySSLSocket) obj = ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO()) self.assertIsInstance(obj, MySSLObject) @@ -1804,8 +1805,10 @@ def test_subclass(self): ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.check_hostname = False ctx.verify_mode = ssl.CERT_NONE - with socket.create_server(("127.0.0.1", 0)) as s: - c = socket.create_connection(s.getsockname()) + with socket_helper.bind_ip_socket_and_port() as sock_port: + s = sock_port[0] + s.listen() + c = socket.create_connection(s.getsockname()[:2]) c.setblocking(False) with ctx.wrap_socket(c, False, do_handshake_on_connect=False) as c: with self.assertRaises(ssl.SSLWantReadError) as cm: @@ -1947,19 +1950,20 @@ def setUp(self): self.server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) self.server_context.load_cert_chain(SIGNED_CERTFILE) server = ThreadedEchoServer(context=self.server_context) + self.family = server.sock.family self.server_addr = (HOST, server.port) server.__enter__() self.addCleanup(server.__exit__, None, None, None) def test_connect(self): - with test_wrap_socket(socket.socket(socket.AF_INET), + with test_wrap_socket(socket.socket(self.family), cert_reqs=ssl.CERT_NONE) as s: s.connect(self.server_addr) self.assertEqual({}, s.getpeercert()) self.assertFalse(s.server_side) # this should succeed because we specify the root cert - with test_wrap_socket(socket.socket(socket.AF_INET), + with test_wrap_socket(socket.socket(self.family), cert_reqs=ssl.CERT_REQUIRED, ca_certs=SIGNING_CA) as s: s.connect(self.server_addr) @@ -1970,7 +1974,7 @@ def test_connect_fail(self): # This should fail because we have no verification certs. Connection # failure crashes ThreadedEchoServer, so run this in an independent # test method. - s = test_wrap_socket(socket.socket(socket.AF_INET), + s = test_wrap_socket(socket.socket(self.family), cert_reqs=ssl.CERT_REQUIRED) self.addCleanup(s.close) self.assertRaisesRegex(ssl.SSLError, "certificate verify failed", @@ -1978,7 +1982,7 @@ def test_connect_fail(self): def test_connect_ex(self): # Issue #11326: check connect_ex() implementation - s = test_wrap_socket(socket.socket(socket.AF_INET), + s = test_wrap_socket(socket.socket(self.family), cert_reqs=ssl.CERT_REQUIRED, ca_certs=SIGNING_CA) self.addCleanup(s.close) @@ -1988,7 +1992,7 @@ def test_connect_ex(self): def test_non_blocking_connect_ex(self): # Issue #11326: non-blocking connect_ex() should allow handshake # to proceed after the socket gets ready. - s = test_wrap_socket(socket.socket(socket.AF_INET), + s = test_wrap_socket(socket.socket(self.family), cert_reqs=ssl.CERT_REQUIRED, ca_certs=SIGNING_CA, do_handshake_on_connect=False) @@ -2016,17 +2020,17 @@ def test_connect_with_context(self): ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.check_hostname = False ctx.verify_mode = ssl.CERT_NONE - with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s: + with ctx.wrap_socket(socket.socket(self.family)) as s: s.connect(self.server_addr) self.assertEqual({}, s.getpeercert()) # Same with a server hostname - with ctx.wrap_socket(socket.socket(socket.AF_INET), + with ctx.wrap_socket(socket.socket(self.family), server_hostname="dummy") as s: s.connect(self.server_addr) ctx.verify_mode = ssl.CERT_REQUIRED # This should succeed because we specify the root cert ctx.load_verify_locations(SIGNING_CA) - with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s: + with ctx.wrap_socket(socket.socket(self.family)) as s: s.connect(self.server_addr) cert = s.getpeercert() self.assertTrue(cert) @@ -2037,7 +2041,7 @@ def test_connect_with_context_fail(self): # test method. ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) s = ctx.wrap_socket( - socket.socket(socket.AF_INET), + socket.socket(self.family), server_hostname=SIGNED_CERTFILE_HOSTNAME ) self.addCleanup(s.close) @@ -2052,7 +2056,7 @@ def test_connect_capath(self): # filename) for this test to be portable across OpenSSL releases. ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.load_verify_locations(capath=CAPATH) - with ctx.wrap_socket(socket.socket(socket.AF_INET), + with ctx.wrap_socket(socket.socket(self.family), server_hostname=SIGNED_CERTFILE_HOSTNAME) as s: s.connect(self.server_addr) cert = s.getpeercert() @@ -2061,7 +2065,7 @@ def test_connect_capath(self): # Same with a bytes `capath` argument ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.load_verify_locations(capath=BYTES_CAPATH) - with ctx.wrap_socket(socket.socket(socket.AF_INET), + with ctx.wrap_socket(socket.socket(self.family), server_hostname=SIGNED_CERTFILE_HOSTNAME) as s: s.connect(self.server_addr) cert = s.getpeercert() @@ -2073,7 +2077,7 @@ def test_connect_cadata(self): der = ssl.PEM_cert_to_DER_cert(pem) ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.load_verify_locations(cadata=pem) - with ctx.wrap_socket(socket.socket(socket.AF_INET), + with ctx.wrap_socket(socket.socket(self.family), server_hostname=SIGNED_CERTFILE_HOSTNAME) as s: s.connect(self.server_addr) cert = s.getpeercert() @@ -2082,7 +2086,7 @@ def test_connect_cadata(self): # same with DER ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.load_verify_locations(cadata=der) - with ctx.wrap_socket(socket.socket(socket.AF_INET), + with ctx.wrap_socket(socket.socket(self.family), server_hostname=SIGNED_CERTFILE_HOSTNAME) as s: s.connect(self.server_addr) cert = s.getpeercert() @@ -2093,7 +2097,7 @@ def test_makefile_close(self): # Issue #5238: creating a file-like object with makefile() shouldn't # delay closing the underlying "real socket" (here tested with its # file descriptor, hence skipping the test under Windows). - ss = test_wrap_socket(socket.socket(socket.AF_INET)) + ss = test_wrap_socket(socket.socket(self.family)) ss.connect(self.server_addr) fd = ss.fileno() f = ss.makefile() @@ -2108,7 +2112,7 @@ def test_makefile_close(self): self.assertEqual(e.exception.errno, errno.EBADF) def test_non_blocking_handshake(self): - s = socket.socket(socket.AF_INET) + s = socket.socket(self.family) s.connect(self.server_addr) s.setblocking(False) s = test_wrap_socket(s, @@ -2167,15 +2171,15 @@ def servername_cb(ssl_sock, server_name, initial_context): timeout=0.1) def test_ciphers(self): - with test_wrap_socket(socket.socket(socket.AF_INET), + with test_wrap_socket(socket.socket(self.family), cert_reqs=ssl.CERT_NONE, ciphers="ALL") as s: s.connect(self.server_addr) - with test_wrap_socket(socket.socket(socket.AF_INET), + with test_wrap_socket(socket.socket(self.family), cert_reqs=ssl.CERT_NONE, ciphers="DEFAULT") as s: s.connect(self.server_addr) # Error checking can happen at instantiation or when connecting with self.assertRaisesRegex(ssl.SSLError, "No cipher can be selected"): - with socket.socket(socket.AF_INET) as sock: + with socket.socket(self.family) as sock: s = test_wrap_socket(sock, cert_reqs=ssl.CERT_NONE, ciphers="^$:,;?*'dorothyx") s.connect(self.server_addr) @@ -2185,7 +2189,7 @@ def test_get_ca_certs_capath(self): ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.load_verify_locations(capath=CAPATH) self.assertEqual(ctx.get_ca_certs(), []) - with ctx.wrap_socket(socket.socket(socket.AF_INET), + with ctx.wrap_socket(socket.socket(self.family), server_hostname='localhost') as s: s.connect(self.server_addr) cert = s.getpeercert() @@ -2198,7 +2202,7 @@ def test_context_setget(self): ctx1.load_verify_locations(capath=CAPATH) ctx2 = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx2.load_verify_locations(capath=CAPATH) - s = socket.socket(socket.AF_INET) + s = socket.socket(self.family) with ctx1.wrap_socket(s, server_hostname='localhost') as ss: ss.connect(self.server_addr) self.assertIs(ss.context, ctx1) @@ -2245,7 +2249,7 @@ def ssl_io_loop(self, sock, incoming, outgoing, func, *args, **kwargs): return ret def test_bio_handshake(self): - sock = socket.socket(socket.AF_INET) + sock = socket.socket(self.family) self.addCleanup(sock.close) sock.connect(self.server_addr) incoming = ssl.MemoryBIO() @@ -2279,7 +2283,7 @@ def test_bio_handshake(self): self.assertRaises(ssl.SSLError, sslobj.write, b'foo') def test_bio_read_write_data(self): - sock = socket.socket(socket.AF_INET) + sock = socket.socket(self.family) self.addCleanup(sock.close) sock.connect(self.server_addr) incoming = ssl.MemoryBIO() @@ -2298,6 +2302,9 @@ def test_bio_read_write_data(self): class NetworkedTests(unittest.TestCase): + @unittest.skipUnless( + socket_helper.IPV4_ENABLED, + f"{REMOTE_HOST} was IPv4 only at the time of this writing.") def test_timeout_connect_ex(self): # Issue #12065: on a timeout, connect_ex() should return the original # errno (mimicking the behaviour of non-SSL sockets). @@ -2563,8 +2570,7 @@ def __init__(self, certificate=None, ssl_version=None, self.chatty = chatty self.connectionchatty = connectionchatty self.starttls_server = starttls_server - self.sock = socket.socket() - self.port = socket_helper.bind_port(self.sock) + self.sock, self.port = socket_helper.get_bound_ip_socket_and_port() self.flag = None self.active = False self.selected_alpn_protocols = [] @@ -2681,17 +2687,18 @@ def handle_error(self): def __init__(self, certfile): self.certfile = certfile - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.port = socket_helper.bind_port(sock, '') + sock, self.port = socket_helper.get_bound_ip_socket_and_port() asyncore.dispatcher.__init__(self, sock) self.listen(5) def handle_accepted(self, sock_obj, addr): if support.verbose: - sys.stdout.write(" server: new connection from %s:%s\n" %addr) + sys.stdout.write(" server: new connection from %s:%s\n" % addr[:2]) self.ConnectionHandler(sock_obj, self.certfile) def handle_error(self): + if support.verbose: + sys.stdout.write(" server: error:\n%s\n" % traceback.format_exc()) raise def __init__(self, certfile): @@ -2752,7 +2759,7 @@ def server_params_test(client_context, server_context, indata=b"FOO\n", chatty=chatty, connectionchatty=False) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=sni_name, session=session) as s: s.connect((HOST, server.port)) for arg in [indata, bytearray(indata), memoryview(indata)]: @@ -2914,7 +2921,7 @@ def test_getpeercert(self): client_context, server_context, hostname = testing_context() server = ThreadedEchoServer(context=server_context, chatty=False) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), do_handshake_on_connect=False, server_hostname=hostname) as s: s.connect((HOST, server.port)) @@ -2955,7 +2962,7 @@ def test_crl_check(self): # VERIFY_DEFAULT should pass server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) cert = s.getpeercert() @@ -2966,7 +2973,7 @@ def test_crl_check(self): server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: with self.assertRaisesRegex(ssl.SSLError, "certificate verify failed"): @@ -2977,7 +2984,7 @@ def test_crl_check(self): server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) cert = s.getpeercert() @@ -2992,7 +2999,7 @@ def test_check_hostname(self): # correct hostname should verify server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) cert = s.getpeercert() @@ -3001,7 +3008,7 @@ def test_check_hostname(self): # incorrect hostname should raise an exception server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname="invalid") as s: with self.assertRaisesRegex( ssl.CertificateError, @@ -3011,7 +3018,7 @@ def test_check_hostname(self): # missing server_hostname arg should cause an exception, too server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with socket.socket() as s: + with socket_helper.tcp_socket() as s: with self.assertRaisesRegex(ValueError, "check_hostname requires server_hostname"): client_context.wrap_socket(s) @@ -3027,7 +3034,7 @@ def test_hostname_checks_common_name(self): # default cert has a SAN server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) @@ -3035,7 +3042,7 @@ def test_hostname_checks_common_name(self): client_context.hostname_checks_common_name = False server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: with self.assertRaises(ssl.SSLCertVerificationError): s.connect((HOST, server.port)) @@ -3053,7 +3060,7 @@ def test_ecc_cert(self): # correct hostname should verify server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) cert = s.getpeercert() @@ -3079,7 +3086,7 @@ def test_dual_rsa_ecc(self): # correct hostname should verify server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) cert = s.getpeercert() @@ -3127,7 +3134,7 @@ def test_check_hostname_idn(self): for server_hostname, expected_hostname in idn_hostnames: server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with context.wrap_socket(socket.socket(), + with context.wrap_socket(socket_helper.tcp_socket(), server_hostname=server_hostname) as s: self.assertEqual(s.server_hostname, expected_hostname) s.connect((HOST, server.port)) @@ -3138,7 +3145,7 @@ def test_check_hostname_idn(self): # incorrect hostname should raise an exception server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with context.wrap_socket(socket.socket(), + with context.wrap_socket(socket_helper.tcp_socket(), server_hostname="python.example.org") as s: with self.assertRaises(ssl.CertificateError): s.connect((HOST, server.port)) @@ -3162,7 +3169,7 @@ def test_wrong_cert_tls12(self): ) with server, \ - client_context.wrap_socket(socket.socket(), + client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: try: # Expect either an SSL error about the server rejecting @@ -3193,7 +3200,7 @@ def test_wrong_cert_tls13(self): context=server_context, chatty=True, connectionchatty=True, ) with server, \ - client_context.wrap_socket(socket.socket(), + client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: # TLS 1.3 perform client cert exchange after handshake s.connect((HOST, server.port)) @@ -3226,7 +3233,7 @@ def test_rude_shutdown(self): listener_ready = threading.Event() listener_gone = threading.Event() - s = socket.socket() + s = socket_helper.tcp_socket() port = socket_helper.bind_port(s, HOST) # `listener` runs in a thread. It sits in an accept() until @@ -3243,7 +3250,7 @@ def listener(): def connector(): listener_ready.wait() - with socket.socket() as c: + with socket_helper.tcp_socket() as c: c.connect((HOST, port)) listener_gone.wait() try: @@ -3271,7 +3278,7 @@ def test_ssl_cert_verify_error(self): server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with context.wrap_socket(socket.socket(), + with context.wrap_socket(socket_helper.tcp_socket(), server_hostname=SIGNED_CERTFILE_HOSTNAME) as s: try: s.connect((HOST, server.port)) @@ -3422,7 +3429,7 @@ def test_starttls(self): connectionchatty=True) wrapped = False with server: - s = socket.socket() + s = socket_helper.tcp_socket() s.setblocking(True) s.connect((HOST, server.port)) if support.verbose: @@ -3503,8 +3510,8 @@ def test_asyncore_server(self): indata = b"FOO\n" server = AsyncoreEchoServer(CERTFILE) with server: - s = test_wrap_socket(socket.socket()) - s.connect(('127.0.0.1', server.port)) + s = test_wrap_socket(socket_helper.tcp_socket()) + s.connect((socket_helper.HOST, server.port)) if support.verbose: sys.stdout.write( " client: sending %r...\n" % indata) @@ -3536,7 +3543,7 @@ def test_recv_send(self): chatty=True, connectionchatty=False) with server: - s = test_wrap_socket(socket.socket(), + s = test_wrap_socket(socket_helper.tcp_socket(), server_side=False, certfile=CERTFILE, ca_certs=CERTFILE, @@ -3688,7 +3695,7 @@ def test_nonblocking_send(self): chatty=True, connectionchatty=False) with server: - s = test_wrap_socket(socket.socket(), + s = test_wrap_socket(socket_helper.tcp_socket(), server_side=False, certfile=CERTFILE, ca_certs=CERTFILE, @@ -3711,9 +3718,8 @@ def fill_buffer(): def test_handshake_timeout(self): # Issue #5103: SSL handshake must respect the socket timeout - server = socket.socket(socket.AF_INET) - host = "127.0.0.1" - port = socket_helper.bind_port(server) + host = socket_helper.HOST + server, port = socket_helper.get_bound_ip_socket_and_port(hostname=host) started = threading.Event() finish = False @@ -3736,7 +3742,7 @@ def serve(): try: try: - c = socket.socket(socket.AF_INET) + c = socket.socket(server.family) c.settimeout(0.2) c.connect((host, port)) # Will attempt handshake and time out @@ -3745,7 +3751,7 @@ def serve(): finally: c.close() try: - c = socket.socket(socket.AF_INET) + c = socket.socket(server.family) c = test_wrap_socket(c) c.settimeout(0.2) # Will attempt handshake and time out @@ -3762,9 +3768,9 @@ def test_server_accept(self): # Issue #16357: accept() on a SSLSocket created through # SSLContext.wrap_socket(). client_ctx, server_ctx, hostname = testing_context() - server = socket.socket(socket.AF_INET) - host = "127.0.0.1" - port = socket_helper.bind_port(server) + host = socket_helper.HOST + server, port = socket_helper.get_bound_ip_socket_and_port( + hostname=host) server = server_ctx.wrap_socket(server, server_side=True) self.assertTrue(server.server_side) @@ -3784,7 +3790,7 @@ def serve(): # Client wait until server setup and perform a connect. evt.wait() client = client_ctx.wrap_socket( - socket.socket(), server_hostname=hostname + socket_helper.tcp_socket(), server_hostname=hostname ) client.connect((hostname, port)) client.send(b'data') @@ -3801,7 +3807,7 @@ def serve(): def test_getpeercert_enotconn(self): context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) context.check_hostname = False - with context.wrap_socket(socket.socket()) as sock: + with context.wrap_socket(socket_helper.tcp_socket()) as sock: with self.assertRaises(OSError) as cm: sock.getpeercert() self.assertEqual(cm.exception.errno, errno.ENOTCONN) @@ -3809,7 +3815,7 @@ def test_getpeercert_enotconn(self): def test_do_handshake_enotconn(self): context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) context.check_hostname = False - with context.wrap_socket(socket.socket()) as sock: + with context.wrap_socket(socket_helper.tcp_socket()) as sock: with self.assertRaises(OSError) as cm: sock.do_handshake() self.assertEqual(cm.exception.errno, errno.ENOTCONN) @@ -3822,7 +3828,7 @@ def test_no_shared_ciphers(self): client_context.set_ciphers("AES128") server_context.set_ciphers("AES256") with ThreadedEchoServer(context=server_context) as server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: with self.assertRaises(OSError): s.connect((HOST, server.port)) @@ -3839,7 +3845,7 @@ def test_version_basic(self): with ThreadedEchoServer(CERTFILE, ssl_version=ssl.PROTOCOL_TLS_SERVER, chatty=False) as server: - with context.wrap_socket(socket.socket()) as s: + with context.wrap_socket(socket_helper.tcp_socket()) as s: self.assertIs(s.version(), None) self.assertIs(s._sslobj, None) s.connect((HOST, server.port)) @@ -3852,7 +3858,7 @@ def test_tls1_3(self): client_context, server_context, hostname = testing_context() client_context.minimum_version = ssl.TLSVersion.TLSv1_3 with ThreadedEchoServer(context=server_context) as server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) self.assertIn(s.cipher()[0], { @@ -3875,7 +3881,7 @@ def test_min_max_version_tlsv1_2(self): server_context.maximum_version = ssl.TLSVersion.TLSv1_2 with ThreadedEchoServer(context=server_context) as server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) self.assertEqual(s.version(), 'TLSv1.2') @@ -3892,7 +3898,7 @@ def test_min_max_version_tlsv1_1(self): seclevel_workaround(client_context, server_context) with ThreadedEchoServer(context=server_context) as server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) self.assertEqual(s.version(), 'TLSv1.1') @@ -3910,7 +3916,7 @@ def test_min_max_version_mismatch(self): seclevel_workaround(client_context, server_context) with ThreadedEchoServer(context=server_context) as server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: with self.assertRaises(ssl.SSLError) as e: s.connect((HOST, server.port)) @@ -3925,7 +3931,7 @@ def test_min_max_version_sslv3(self): seclevel_workaround(client_context, server_context) with ThreadedEchoServer(context=server_context) as server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) self.assertEqual(s.version(), 'SSLv3') @@ -3942,7 +3948,7 @@ def test_default_ecdh_curve(self): # our default cipher list should prefer ECDH-based ciphers # automatically. with ThreadedEchoServer(context=server_context) as server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) self.assertIn("ECDH", s.cipher()[0]) @@ -3962,7 +3968,7 @@ def test_tls_unique_channel_binding(self): with server: with client_context.wrap_socket( - socket.socket(), + socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) # get the data @@ -3986,7 +3992,7 @@ def test_tls_unique_channel_binding(self): # now, again with client_context.wrap_socket( - socket.socket(), + socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) new_cb_data = s.get_channel_binding("tls-unique") @@ -4255,7 +4261,7 @@ def test_read_write_after_close_raises_valuerror(self): server = ThreadedEchoServer(context=server_context, chatty=False) with server: - s = client_context.wrap_socket(socket.socket(), + s = client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) s.connect((HOST, server.port)) s.close() @@ -4271,7 +4277,7 @@ def test_sendfile(self): client_context, server_context, hostname = testing_context() server = ThreadedEchoServer(context=server_context, chatty=False) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) with open(os_helper.TESTFN, 'rb') as file: @@ -4345,7 +4351,7 @@ def test_session_handling(self): server = ThreadedEchoServer(context=server_context, chatty=False) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: # session is None before handshake self.assertEqual(s.session, None) @@ -4357,7 +4363,7 @@ def test_session_handling(self): s.session = object self.assertEqual(str(e.exception), 'Value is not a SSLSession.') - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) # cannot set session after handshake @@ -4366,7 +4372,7 @@ def test_session_handling(self): self.assertEqual(str(e.exception), 'Cannot set session after handshake.') - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: # can set session before handshake and before the # connection was established @@ -4376,7 +4382,7 @@ def test_session_handling(self): self.assertEqual(s.session, session) self.assertEqual(s.session_reused, True) - with client_context2.wrap_socket(socket.socket(), + with client_context2.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: # cannot re-use session with a different SSLContext with self.assertRaises(ValueError) as e: @@ -4421,7 +4427,7 @@ def test_pha_required(self): server = ThreadedEchoServer(context=server_context, chatty=False) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) s.write(b'HASCERT') @@ -4453,7 +4459,7 @@ def msg_cb(conn, direction, version, content_type, msg_type, data): server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) s.write(b'PHA') @@ -4484,7 +4490,7 @@ def test_pha_optional(self): server_context.verify_mode = ssl.CERT_OPTIONAL server = ThreadedEchoServer(context=server_context, chatty=False) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) s.write(b'HASCERT') @@ -4505,7 +4511,7 @@ def test_pha_optional_nocert(self): server = ThreadedEchoServer(context=server_context, chatty=False) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) s.write(b'HASCERT') @@ -4524,7 +4530,7 @@ def test_pha_no_pha_client(self): server = ThreadedEchoServer(context=server_context, chatty=False) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) with self.assertRaisesRegex(ssl.SSLError, 'not server'): @@ -4541,7 +4547,7 @@ def test_pha_no_pha_server(self): server = ThreadedEchoServer(context=server_context, chatty=False) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) s.write(b'HASCERT') @@ -4562,7 +4568,7 @@ def test_pha_not_tls13(self): server = ThreadedEchoServer(context=server_context, chatty=False) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) # PHA fails for TLS != 1.3 @@ -4588,7 +4594,7 @@ def test_bpo37428_pha_cert_none(self): server = ThreadedEchoServer(context=server_context, chatty=False) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) s.write(b'HASCERT') @@ -4607,7 +4613,7 @@ def test_internal_chain_client(self): server = ThreadedEchoServer(context=server_context, chatty=False) with server: with client_context.wrap_socket( - socket.socket(), + socket_helper.tcp_socket(), server_hostname=hostname ) as s: s.connect((HOST, server.port)) @@ -4646,7 +4652,7 @@ def test_internal_chain_server(self): server = ThreadedEchoServer(context=server_context, chatty=False) with server: with client_context.wrap_socket( - socket.socket(), + socket_helper.tcp_socket(), server_hostname=hostname ) as s: s.connect((HOST, server.port)) @@ -4701,7 +4707,7 @@ def test_keylog_filename(self): client_context.keylog_filename = os_helper.TESTFN server = ThreadedEchoServer(context=server_context, chatty=False) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) # header, 5 lines for TLS 1.3 @@ -4711,7 +4717,7 @@ def test_keylog_filename(self): server_context.keylog_filename = os_helper.TESTFN server = ThreadedEchoServer(context=server_context, chatty=False) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) self.assertGreaterEqual(self.keylog_lines(), 11) @@ -4720,7 +4726,7 @@ def test_keylog_filename(self): server_context.keylog_filename = os_helper.TESTFN server = ThreadedEchoServer(context=server_context, chatty=False) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) self.assertGreaterEqual(self.keylog_lines(), 21) @@ -4775,7 +4781,7 @@ def msg_cb(conn, direction, version, content_type, msg_type, data): server = ThreadedEchoServer(context=server_context, chatty=False) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) @@ -4805,10 +4811,10 @@ def sni_cb(sock, servername, ctx): server = ThreadedEchoServer(context=server_context, chatty=False) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) diff --git a/Lib/test/test_support.py b/Lib/test/test_support.py index 55d78b733353d2..d0231e783e2834 100644 --- a/Lib/test/test_support.py +++ b/Lib/test/test_support.py @@ -94,21 +94,45 @@ def test_forget(self): os_helper.unlink(mod_filename) os_helper.rmtree('__pycache__') - def test_HOST(self): - s = socket.create_server((socket_helper.HOST, 0)) - s.close() + def test_bind_ip_socket_and_port_HOST(self): + """This also tests get_bound_ip_socket_and_port() indirectly.""" + with socket_helper.bind_ip_socket_and_port( + hostname=socket_helper.HOST): + pass - def test_find_unused_port(self): - port = socket_helper.find_unused_port() + @unittest.skipUnless(socket_helper.IPV4_ENABLED, "IPv4 required") + def test_find_unused_port_ipv4(self): + port = socket_helper.find_unused_port(family=socket.AF_INET) s = socket.create_server((socket_helper.HOST, port)) s.close() - def test_bind_port(self): - s = socket.socket() - socket_helper.bind_port(s) - s.listen() + @unittest.skipUnless(socket_helper.IPV6_ENABLED, "IPv6 required") + def test_find_unused_port_ipv6(self): + port = socket_helper.find_unused_port(family=socket.AF_INET6) + s = socket.create_server( + (socket_helper.HOST, port), + family=socket.AF_INET6) s.close() + def test_find_unused_port_noargs(self): + port = socket_helper.find_unused_port() + s = socket.create_server( + (socket_helper.HOST, port), + family=socket_helper.get_family()) + s.close() + + @unittest.skipUnless(socket_helper.IPV4_ENABLED, "IPv4 required") + def test_bind_port_ipv4(self): + with socket.socket(socket.AF_INET) as s: + socket_helper.bind_port(s) + s.listen() + + @unittest.skipUnless(socket_helper.IPV6_ENABLED, "IPv6 required") + def test_bind_port_ipv6(self): + with socket.socket(socket.AF_INET6) as s: + socket_helper.bind_port(s) + s.listen() + # Tests for temp_dir() def test_temp_dir(self): diff --git a/Lib/test/test_telnetlib.py b/Lib/test/test_telnetlib.py index 41c4fcd4195e3a..7a7212be7c1252 100644 --- a/Lib/test/test_telnetlib.py +++ b/Lib/test/test_telnetlib.py @@ -25,7 +25,8 @@ class GeneralTests(unittest.TestCase): def setUp(self): self.evt = threading.Event() - self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + family = socket_helper.get_family() + self.sock = socket.socket(family, socket.SOCK_STREAM) self.sock.settimeout(60) # Safety net. Look issue 11812 self.port = socket_helper.bind_port(self.sock) self.thread = threading.Thread(target=server, args=(self.evt,self.sock)) diff --git a/Lib/test/test_urllib2_localnet.py b/Lib/test/test_urllib2_localnet.py index ebb43c30b4d505..bb4da9e813528e 100644 --- a/Lib/test/test_urllib2_localnet.py +++ b/Lib/test/test_urllib2_localnet.py @@ -9,6 +9,7 @@ import hashlib from test.support import hashlib_helper +from test.support import socket_helper from test.support import threading_helper from test.support import warnings_helper @@ -30,6 +31,7 @@ class LoopbackHttpServer(http.server.HTTPServer): """HTTP server w/ a few modifications that make it useful for loopback testing purposes. """ + address_family = socket_helper.get_family() def __init__(self, server_address, RequestHandlerClass): http.server.HTTPServer.__init__(self, @@ -60,7 +62,7 @@ def __init__(self, request_handler): self._stop_server = False self.ready = threading.Event() request_handler.protocol_version = "HTTP/1.0" - self.httpd = LoopbackHttpServer(("127.0.0.1", 0), + self.httpd = LoopbackHttpServer((socket_helper.HOST, 0), request_handler) self.port = self.httpd.server_port @@ -290,7 +292,7 @@ def http_server_with_basic_auth_handler(*args, **kwargs): return BasicAuthHandler(*args, **kwargs) self.server = LoopbackHttpServerThread(http_server_with_basic_auth_handler) self.addCleanup(self.stop_server) - self.server_url = 'http://127.0.0.1:%s' % self.server.port + self.server_url = f'http://{socket_helper.HOST}:{self.server.port}' self.server.start() self.server.ready.wait() @@ -346,7 +348,7 @@ def create_fake_proxy_handler(*args, **kwargs): self.addCleanup(self.stop_server) self.server.start() self.server.ready.wait() - proxy_url = "http://127.0.0.1:%d" % self.server.port + proxy_url = f"http://{socket_helper.HOST}:{self.server.port}" handler = urllib.request.ProxyHandler({"http" : proxy_url}) self.proxy_digest_handler = urllib.request.ProxyDigestAuthHandler() self.opener = urllib.request.build_opener( diff --git a/Lib/test/test_wsgiref.py b/Lib/test/test_wsgiref.py index 93ca6b99a92c9c..a92e2250b5a2c1 100644 --- a/Lib/test/test_wsgiref.py +++ b/Lib/test/test_wsgiref.py @@ -265,7 +265,12 @@ def app(environ, start_response): class WsgiHandler(NoLogRequestHandler, WSGIRequestHandler): pass - server = make_server(socket_helper.HOST, 0, app, handler_class=WsgiHandler) + class IPStackWSGIServer(WSGIServer): + address_family = socket_helper.get_family() + + server = make_server( + socket_helper.HOST, 0, app, + server_class=IPStackWSGIServer, handler_class=WsgiHandler) self.addCleanup(server.server_close) interrupted = threading.Event() @@ -278,7 +283,7 @@ def signal_handler(signum, frame): main_thread = threading.get_ident() def run_client(): - http = HTTPConnection(*server.server_address) + http = HTTPConnection(*server.server_address[:2]) http.request("GET", "/") with http.getresponse() as response: response.read(100) diff --git a/Lib/test/test_xmlrpc.py b/Lib/test/test_xmlrpc.py index a9f67466071bc6..34b8a154af3f9a 100644 --- a/Lib/test/test_xmlrpc.py +++ b/Lib/test/test_xmlrpc.py @@ -337,7 +337,10 @@ def run_server(): server.handle_request() # First request and attempt at second server.handle_request() # Retried second request - server = http.server.HTTPServer((socket_helper.HOST, 0), RequestHandler) + class IPvWhateverHTTPServer(http.server.HTTPServer): + address_family = socket_helper.get_family() + + server = IPvWhateverHTTPServer((socket_helper.HOST, 0), RequestHandler) self.addCleanup(server.server_close) thread = threading.Thread(target=run_server) thread.start() @@ -606,6 +609,9 @@ def getData(): return '42' class MyXMLRPCServer(xmlrpc.server.SimpleXMLRPCServer): + + address_family = socket_helper.get_family() + def get_request(self): # Ensure the socket is always non-blocking. On Linux, socket # attributes are not inherited like they are on *BSD and Windows. @@ -615,13 +621,13 @@ def get_request(self): if not requestHandler: requestHandler = xmlrpc.server.SimpleXMLRPCRequestHandler - serv = MyXMLRPCServer(("localhost", 0), requestHandler, + serv = MyXMLRPCServer((socket_helper.HOST, 0), requestHandler, encoding=encoding, logRequests=False, bind_and_activate=False) try: serv.server_bind() global ADDR, PORT, URL - ADDR, PORT = serv.socket.getsockname() + ADDR, PORT = serv.socket.getsockname()[:2] #connect to IP address directly. This avoids socket.create_connection() #trying to connect to "localhost" using all address families, which #causes slowdown e.g. on vista which supports AF_INET6. The server listens @@ -669,6 +675,9 @@ def my_function(): return True class MyXMLRPCServer(xmlrpc.server.MultiPathXMLRPCServer): + + address_family = socket_helper.get_family() + def get_request(self): # Ensure the socket is always non-blocking. On Linux, socket # attributes are not inherited like they are on *BSD and Windows. @@ -685,13 +694,13 @@ class BrokenDispatcher: def _marshaled_dispatch(self, data, dispatch_method=None, path=None): raise RuntimeError("broken dispatcher") - serv = MyXMLRPCServer(("localhost", 0), MyRequestHandler, + serv = MyXMLRPCServer((socket_helper.HOST, 0), MyRequestHandler, logRequests=False, bind_and_activate=False) serv.socket.settimeout(3) serv.server_bind() try: global ADDR, PORT, URL - ADDR, PORT = serv.socket.getsockname() + ADDR, PORT = serv.socket.getsockname()[:2] #connect to IP address directly. This avoids socket.create_connection() #trying to connect to "localhost" using all address families, which #causes slowdown e.g. on vista which supports AF_INET6. The server listens @@ -1498,7 +1507,11 @@ def test_cgihandler_has_use_builtin_types_flag(self): self.assertTrue(handler.use_builtin_types) def test_xmlrpcserver_has_use_builtin_types_flag(self): - server = xmlrpc.server.SimpleXMLRPCServer(("localhost", 0), + + class IPvWhateverSimpleXMLRPCServer(xmlrpc.server.SimpleXMLRPCServer): + address_family = socket_helper.get_family() + + server = IPvWhateverSimpleXMLRPCServer((socket_helper.HOST, 0), use_builtin_types=True) server.server_close() self.assertTrue(server.use_builtin_types)