Skip to content

Commit

Permalink
Merge pull request #296 from eltoder/feature/forkserver
Browse files Browse the repository at this point in the history
Use forkserver start method for multiprocessing
  • Loading branch information
sodul authored Apr 24, 2024
2 parents 4de285b + 7899692 commit 8cabab8
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 24 deletions.
29 changes: 13 additions & 16 deletions green/cmdline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

from __future__ import annotations


import atexit
import os
import shutil
import sys
import tempfile
from typing import Sequence
Expand Down Expand Up @@ -87,22 +88,18 @@ def _main(argv: Sequence[str] | None, testing: bool) -> int:
def main(argv: Sequence[str] | None = None, testing: bool = False) -> int:
# create the temp dir only once (i.e., not while in the recursed call)
if not os.environ.get("TMPDIR"): # pragma: nocover
# Use `atexit` to cleanup `temp_dir_for_tests` so that multiprocessing can run its
# own cleanup before its temp directory is deleted.
temp_dir_for_tests = tempfile.mkdtemp()
atexit.register(lambda: shutil.rmtree(temp_dir_for_tests, ignore_errors=True))
os.environ["TMPDIR"] = temp_dir_for_tests
prev_tempdir = tempfile.tempdir
tempfile.tempdir = temp_dir_for_tests
try:
with tempfile.TemporaryDirectory() as temp_dir_for_tests:
try:
os.environ["TMPDIR"] = temp_dir_for_tests
tempfile.tempdir = temp_dir_for_tests
return _main(argv, testing)
finally:
del os.environ["TMPDIR"]
tempfile.tempdir = None
except OSError as os_error:
if os_error.errno == 39:
# "Directory not empty" when trying to delete the temp dir can just be a warning
print(f"warning: {os_error.strerror}")
return 0
else:
raise os_error
return _main(argv, testing)
finally:
del os.environ["TMPDIR"]
tempfile.tempdir = prev_tempdir
else:
return _main(argv, testing)

Expand Down
3 changes: 1 addition & 2 deletions green/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import copy # pragma: no cover
import functools # pragma: no cover
import logging # pragma: no cover
import multiprocessing # pragma: no cover
import os # pragma: no cover
import pathlib # pragma: no cover
import sys # pragma: no cover
Expand All @@ -36,7 +35,7 @@ def get_default_args() -> argparse.Namespace:
"""
return argparse.Namespace( # pragma: no cover
targets=["."], # Not in configs
processes=multiprocessing.cpu_count(),
processes=os.cpu_count(),
initializer="",
finalizer="",
maxtasksperchild=None,
Expand Down
13 changes: 11 additions & 2 deletions green/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,21 @@ def run(
# The call to toParallelTargets needs to happen before pool stuff so we can crash if there
# are, for example, syntax errors in the code to be loaded.
parallel_targets = toParallelTargets(suite, args.targets)
# Use "forkserver" method when available to avoid problems with "fork". See, for example,
# https://github.com/python/cpython/issues/84559
if "forkserver" in multiprocessing.get_all_start_methods():
mp_method = "forkserver"
else:
mp_method = None
mp_context = multiprocessing.get_context(mp_method)
pool = LoggingDaemonlessPool(
processes=args.processes or None,
initializer=InitializerOrFinalizer(args.initializer),
finalizer=InitializerOrFinalizer(args.finalizer),
maxtasksperchild=args.maxtasksperchild,
context=mp_context,
)
manager: SyncManager = multiprocessing.Manager()
manager: SyncManager = mp_context.Manager()
targets: list[tuple[str, Queue]] = [
(target, manager.Queue()) for target in parallel_targets
]
Expand Down Expand Up @@ -165,10 +173,11 @@ def run(

pool.close()
pool.join()
manager.shutdown()

result.stopTestRun()

# Ignore the type mismatch untile we make GreenTestResult a subclass of unittest.TestResult.
# Ignore the type mismatch until we make GreenTestResult a subclass of unittest.TestResult.
removeResult(result) # type: ignore

return result
11 changes: 7 additions & 4 deletions green/test/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from textwrap import dedent
import unittest
from unittest import mock
import warnings
import weakref

from green.config import get_default_args
Expand Down Expand Up @@ -114,7 +115,7 @@ def setUp(self):
self.loader = GreenTestLoader()

def tearDown(self):
del self.tmpdir
shutil.rmtree(self.tmpdir, ignore_errors=True)
del self.stream

def test_stdout(self):
Expand Down Expand Up @@ -162,7 +163,7 @@ def test01(self):

def test_warnings(self):
"""
setting warnings='always' doesn't crash
test runner does not generate warnings
"""
self.args.warnings = "always"
sub_tmpdir = pathlib.Path(tempfile.mkdtemp(dir=self.tmpdir))
Expand All @@ -177,10 +178,12 @@ def test01(self):
(sub_tmpdir / "test_warnings.py").write_text(content, encoding="utf-8")
os.chdir(sub_tmpdir)
try:
tests = self.loader.loadTargets("test_warnings")
result = run(tests, self.stream, self.args)
with warnings.catch_warnings(record=True) as recorded:
tests = self.loader.loadTargets("test_warnings")
result = run(tests, self.stream, self.args)
finally:
os.chdir(self.startdir)
self.assertEqual(recorded, [])
self.assertEqual(result.testsRun, 1)
self.assertIn("OK", self.stream.getvalue())

Expand Down

0 comments on commit 8cabab8

Please sign in to comment.