diff --git a/bettermap/bettermap.py b/bettermap/bettermap.py index 1f546d2..1947508 100644 --- a/bettermap/bettermap.py +++ b/bettermap/bettermap.py @@ -7,6 +7,7 @@ import itertools import multiprocessing as mp import multiprocessing.connection +from multiprocessing.context import ForkProcess from typing import Iterable, List, Optional, Any, Dict import dill @@ -14,6 +15,10 @@ from queue import Queue from threading import Thread + +mpctx = mp.get_context("fork") + + def threaded_generator(g, maxsize: int = 16): q = Queue(maxsize=maxsize) @@ -55,7 +60,7 @@ def map_per_process( input_sequence: Iterable, *, serialization_items: Optional[List[Any]] = None, - parallelism: int = mp.cpu_count() + parallelism: int = mpctx.cpu_count() ) -> Iterable: if serialization_items is not None and len(serialization_items) > 0: serialization_ids = [id(o) for o in serialization_items] @@ -82,7 +87,7 @@ def unpickle(b: bytes) -> Any: return unpickler.load() pipeno_to_pipe: Dict[int, multiprocessing.connection.Connection] = {} - pipeno_to_process: Dict[int, mp.Process] = {} + pipeno_to_process: Dict[int, ForkProcess] = {} def process_one_item(send_pipe: multiprocessing.connection.Connection, item): try: @@ -114,8 +119,8 @@ def yield_from_pipes(pipes: List[multiprocessing.connection.Connection]): try: for item in input_sequence: - receive_pipe, send_pipe = mp.Pipe(duplex=False) - process = mp.Process(target=process_one_item, args=(send_pipe, item)) + receive_pipe, send_pipe = mpctx.Pipe(duplex=False) + process = mpctx.Process(target=process_one_item, args=(send_pipe, item)) pipeno_to_pipe[receive_pipe.fileno()] = receive_pipe pipeno_to_process[receive_pipe.fileno()] = process process.start() @@ -173,7 +178,7 @@ def ordered_map_per_thread( fn, input_sequence: Iterable, *, - parallelism: int = mp.cpu_count() + parallelism: int = mpctx.cpu_count() ) -> Iterable: executor = ThreadPoolExecutor(max_workers=parallelism) input_sequence = (executor.submit(fn, item) for item in input_sequence)