|
12 | 12 | import sys
|
13 | 13 | import os
|
14 | 14 | import gc
|
| 15 | +import importlib |
15 | 16 | import errno
|
16 | 17 | import functools
|
17 | 18 | import signal
|
|
20 | 21 | import socket
|
21 | 22 | import random
|
22 | 23 | import logging
|
| 24 | +import shutil |
23 | 25 | import subprocess
|
24 | 26 | import struct
|
| 27 | +import tempfile |
25 | 28 | import operator
|
26 | 29 | import pickle
|
27 | 30 | import weakref
|
@@ -6343,6 +6346,81 @@ def test_atexit(self):
|
6343 | 6346 | self.assertEqual(f.read(), 'deadbeef')
|
6344 | 6347 |
|
6345 | 6348 |
|
| 6349 | +class _TestSpawnedSysPath(BaseTestCase): |
| 6350 | + """Test that sys.path is setup in forkserver and spawn processes.""" |
| 6351 | + |
| 6352 | + ALLOWED_TYPES = ('processes',) |
| 6353 | + |
| 6354 | + def setUp(self): |
| 6355 | + self._orig_sys_path = list(sys.path) |
| 6356 | + self._temp_dir = tempfile.mkdtemp(prefix="test_sys_path-") |
| 6357 | + self._mod_name = "unique_test_mod" |
| 6358 | + module_path = os.path.join(self._temp_dir, f"{self._mod_name}.py") |
| 6359 | + with open(module_path, "w", encoding="utf-8") as mod: |
| 6360 | + mod.write("# A simple test module\n") |
| 6361 | + sys.path[:] = [p for p in sys.path if p] # remove any existing ""s |
| 6362 | + sys.path.insert(0, self._temp_dir) |
| 6363 | + sys.path.insert(0, "") # Replaced with an abspath in child. |
| 6364 | + try: |
| 6365 | + self._ctx_forkserver = multiprocessing.get_context("forkserver") |
| 6366 | + except ValueError: |
| 6367 | + self._ctx_forkserver = None |
| 6368 | + self._ctx_spawn = multiprocessing.get_context("spawn") |
| 6369 | + |
| 6370 | + def tearDown(self): |
| 6371 | + sys.path[:] = self._orig_sys_path |
| 6372 | + shutil.rmtree(self._temp_dir, ignore_errors=True) |
| 6373 | + |
| 6374 | + @staticmethod |
| 6375 | + def enq_imported_module_names(queue): |
| 6376 | + queue.put(tuple(sys.modules)) |
| 6377 | + |
| 6378 | + def test_forkserver_preload_imports_sys_path(self): |
| 6379 | + ctx = self._ctx_forkserver |
| 6380 | + if not ctx: |
| 6381 | + self.skipTest("requires forkserver start method.") |
| 6382 | + self.assertNotIn(self._mod_name, sys.modules) |
| 6383 | + multiprocessing.forkserver._forkserver._stop() # Must be fresh. |
| 6384 | + ctx.set_forkserver_preload( |
| 6385 | + ["test.test_multiprocessing_forkserver", self._mod_name]) |
| 6386 | + q = ctx.Queue() |
| 6387 | + proc = ctx.Process(target=self.enq_imported_module_names, args=(q,)) |
| 6388 | + proc.start() |
| 6389 | + proc.join() |
| 6390 | + child_imported_modules = q.get() |
| 6391 | + q.close() |
| 6392 | + self.assertIn(self._mod_name, child_imported_modules) |
| 6393 | + |
| 6394 | + @staticmethod |
| 6395 | + def enq_sys_path_and_import(queue, mod_name): |
| 6396 | + queue.put(sys.path) |
| 6397 | + try: |
| 6398 | + importlib.import_module(mod_name) |
| 6399 | + except ImportError as exc: |
| 6400 | + queue.put(exc) |
| 6401 | + else: |
| 6402 | + queue.put(None) |
| 6403 | + |
| 6404 | + def test_child_sys_path(self): |
| 6405 | + for ctx in (self._ctx_spawn, self._ctx_forkserver): |
| 6406 | + if not ctx: |
| 6407 | + continue |
| 6408 | + with self.subTest(f"{ctx.get_start_method()} start method"): |
| 6409 | + q = ctx.Queue() |
| 6410 | + proc = ctx.Process(target=self.enq_sys_path_and_import, |
| 6411 | + args=(q, self._mod_name)) |
| 6412 | + proc.start() |
| 6413 | + proc.join() |
| 6414 | + child_sys_path = q.get() |
| 6415 | + import_error = q.get() |
| 6416 | + q.close() |
| 6417 | + self.assertNotIn("", child_sys_path) # replaced by an abspath |
| 6418 | + self.assertIn(self._temp_dir, child_sys_path) # our addition |
| 6419 | + # ignore the first element, it is the absolute "" replacement |
| 6420 | + self.assertEqual(child_sys_path[1:], sys.path[1:]) |
| 6421 | + self.assertIsNone(import_error, msg=f"child could not import {self._mod_name}") |
| 6422 | + |
| 6423 | + |
6346 | 6424 | class MiscTestCase(unittest.TestCase):
|
6347 | 6425 | def test__all__(self):
|
6348 | 6426 | # Just make sure names in not_exported are excluded
|
|
0 commit comments