diff --git a/Lib/test/_test_multiprocessing.py b/Lib/test/_test_multiprocessing.py index 328cd5112cab7a..bcb024d8386fd1 100644 --- a/Lib/test/_test_multiprocessing.py +++ b/Lib/test/_test_multiprocessing.py @@ -258,6 +258,9 @@ def __call__(self, *args, **kwds): class BaseTestCase(object): ALLOWED_TYPES = ('processes', 'manager', 'threads') + # If not empty, limit which start method suites run this class. + START_METHODS: set[str] = set() + start_method = None # set by install_tests_in_module_dict() def assertTimingAlmostEqual(self, a, b): if CHECK_TIMINGS: @@ -6403,7 +6406,9 @@ def test_atexit(self): class _TestSpawnedSysPath(BaseTestCase): """Test that sys.path is setup in forkserver and spawn processes.""" - ALLOWED_TYPES = ('processes',) + ALLOWED_TYPES = {'processes'} + # Not applicable to fork which inherits everything from the process as is. + START_METHODS = {"forkserver", "spawn"} def setUp(self): self._orig_sys_path = list(sys.path) @@ -6415,11 +6420,8 @@ def setUp(self): sys.path[:] = [p for p in sys.path if p] # remove any existing ""s sys.path.insert(0, self._temp_dir) sys.path.insert(0, "") # Replaced with an abspath in child. - try: - self._ctx_forkserver = multiprocessing.get_context("forkserver") - except ValueError: - self._ctx_forkserver = None - self._ctx_spawn = multiprocessing.get_context("spawn") + self.assertIn(self.start_method, self.START_METHODS) + self._ctx = multiprocessing.get_context(self.start_method) def tearDown(self): sys.path[:] = self._orig_sys_path @@ -6430,15 +6432,15 @@ def enq_imported_module_names(queue): queue.put(tuple(sys.modules)) def test_forkserver_preload_imports_sys_path(self): - ctx = self._ctx_forkserver - if not ctx: - self.skipTest("requires forkserver start method.") + if self._ctx.get_start_method() != "forkserver": + self.skipTest("forkserver specific test.") self.assertNotIn(self._mod_name, sys.modules) multiprocessing.forkserver._forkserver._stop() # Must be fresh. - ctx.set_forkserver_preload( + self._ctx.set_forkserver_preload( ["test.test_multiprocessing_forkserver", self._mod_name]) - q = ctx.Queue() - proc = ctx.Process(target=self.enq_imported_module_names, args=(q,)) + q = self._ctx.Queue() + proc = self._ctx.Process( + target=self.enq_imported_module_names, args=(q,)) proc.start() proc.join() child_imported_modules = q.get() @@ -6456,23 +6458,19 @@ def enq_sys_path_and_import(queue, mod_name): queue.put(None) def test_child_sys_path(self): - for ctx in (self._ctx_spawn, self._ctx_forkserver): - if not ctx: - continue - with self.subTest(f"{ctx.get_start_method()} start method"): - q = ctx.Queue() - proc = ctx.Process(target=self.enq_sys_path_and_import, - args=(q, self._mod_name)) - proc.start() - proc.join() - child_sys_path = q.get() - import_error = q.get() - q.close() - self.assertNotIn("", child_sys_path) # replaced by an abspath - self.assertIn(self._temp_dir, child_sys_path) # our addition - # ignore the first element, it is the absolute "" replacement - self.assertEqual(child_sys_path[1:], sys.path[1:]) - self.assertIsNone(import_error, msg=f"child could not import {self._mod_name}") + q = self._ctx.Queue() + proc = self._ctx.Process( + target=self.enq_sys_path_and_import, args=(q, self._mod_name)) + proc.start() + proc.join() + child_sys_path = q.get() + import_error = q.get() + q.close() + self.assertNotIn("", child_sys_path) # replaced by an abspath + self.assertIn(self._temp_dir, child_sys_path) # our addition + # ignore the first element, it is the absolute "" replacement + self.assertEqual(child_sys_path[1:], sys.path[1:]) + self.assertIsNone(import_error, msg=f"child could not import {self._mod_name}") class MiscTestCase(unittest.TestCase): @@ -6669,6 +6667,8 @@ def install_tests_in_module_dict(remote_globs, start_method, if base is BaseTestCase: continue assert set(base.ALLOWED_TYPES) <= ALL_TYPES, base.ALLOWED_TYPES + if base.START_METHODS and start_method not in base.START_METHODS: + continue # class not intended for this start method. for type_ in base.ALLOWED_TYPES: if only_type and type_ != only_type: continue @@ -6682,6 +6682,7 @@ class Temp(base, Mixin, unittest.TestCase): Temp = hashlib_helper.requires_hashdigest('sha256')(Temp) Temp.__name__ = Temp.__qualname__ = newname Temp.__module__ = __module__ + Temp.start_method = start_method remote_globs[newname] = Temp elif issubclass(base, unittest.TestCase): if only_type: