diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..0953ec4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +/*.iml +/__pycache__/ +/*.egg-info/ +/build/ +/dist/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..5678429 --- /dev/null +++ b/README.md @@ -0,0 +1,5 @@ +bettermap +========= + +`bettermap` is a drop-in replacement for Python's map function. It parallelizes +the map function across all available processors. diff --git a/bettermap/__init__.py b/bettermap/__init__.py new file mode 100644 index 0000000..415a5d1 --- /dev/null +++ b/bettermap/__init__.py @@ -0,0 +1 @@ +from .bettermap import * diff --git a/bettermap/bettermap.py b/bettermap/bettermap.py new file mode 100644 index 0000000..d1f9346 --- /dev/null +++ b/bettermap/bettermap.py @@ -0,0 +1,211 @@ +#!/usr/bin/python3 +import io +from typing import * +import sys +from concurrent.futures import ThreadPoolExecutor +import collections + +import itertools +import multiprocessing as mp +import multiprocessing.connection +import dill + +from queue import Queue +from threading import Thread + +def threaded_generator(g, maxsize:int = 16): + q = Queue(maxsize=maxsize) + + sentinel = object() + + def fill_queue(): + try: + for value in g: + q.put(value) + finally: + q.put(sentinel) + + thread = Thread(name=repr(g), target=fill_queue, daemon=True) + thread.start() + + yield from iter(q.get, sentinel) + +def slices(n: int, i: Iterable) -> Iterable[List]: + i = iter(i) + while True: + s = list(itertools.islice(i, n)) + if len(s) > 0: + yield s + else: + break + +def window(seq: Iterable, n:int = 2) -> Iterable[List]: + win = collections.deque(maxlen=n) + for e in seq: + win.append(e) + if len(win) == n: + yield list(win) + +def map_per_process( + fn, + input_sequence: Iterable, + *, + serialization_items: Optional[List[Any]] = None, + parallelism: int = mp.cpu_count() +) -> Iterable: + if serialization_items is not None and len(serialization_items) > 0: + serialization_ids = [id(o) for o in serialization_items] + class MapPickler(dill.Pickler): + def persistent_id(self, obj): + try: + return serialization_ids.index(id(obj)) + except ValueError: + return None + class MapUnpickler(dill.Unpickler): + def persistent_load(self, pid): + return serialization_items[pid] + else: + MapPickler = dill.Pickler + MapUnpickler = dill.Unpickler + def pickle(o: Any) -> bytes: + with io.BytesIO() as buffer: + pickler = MapPickler(buffer) + pickler.dump(o) + return buffer.getvalue() + def unpickle(b: bytes) -> Any: + with io.BytesIO(b) as buffer: + unpickler = MapUnpickler(buffer) + return unpickler.load() + + pipeno_to_pipe: Dict[int, multiprocessing.connection.Connection] = {} + pipeno_to_process: Dict[int, mp.Process] = {} + + def process_one_item(send_pipe: multiprocessing.connection.Connection, item): + try: + processed_item = fn(item) + except Exception as e: + import traceback + send_pipe.send((None, (e, traceback.format_exc()))) + else: + send_pipe.send((pickle(processed_item), None)) + send_pipe.close() + + def yield_from_pipes(pipes: List[multiprocessing.connection.Connection]): + for pipe in pipes: + result, error = pipe.recv() + pipeno = pipe.fileno() + del pipeno_to_pipe[pipeno] + pipe.close() + + process = pipeno_to_process[pipeno] + process.join() + del pipeno_to_process[pipeno] + + if error is None: + yield unpickle(result) + else: + e, tb = error + sys.stderr.write("".join(tb)) + raise e + + try: + for item in input_sequence: + receive_pipe, send_pipe = mp.Pipe(duplex=False) + process = mp.Process(target=process_one_item, args=(send_pipe, item)) + pipeno_to_pipe[receive_pipe.fileno()] = receive_pipe + pipeno_to_process[receive_pipe.fileno()] = process + process.start() + + # read out the values + timeout = 0 if len(pipeno_to_process) < parallelism else None + # If we have fewer processes going than we have CPUs, we just pick up the values + # that are done. If we are at the process limit, we wait until one of them is done. + ready_pipes = multiprocessing.connection.wait(pipeno_to_pipe.values(), timeout=timeout) + yield from yield_from_pipes(ready_pipes) + + # yield the rest of the items + while len(pipeno_to_process) > 0: + ready_pipes = multiprocessing.connection.wait(pipeno_to_pipe.values(), timeout=None) + yield from yield_from_pipes(ready_pipes) + + finally: + for process in pipeno_to_process.values(): + if process.is_alive(): + process.terminate() + +def ordered_map_per_process( + fn, + input_sequence: Iterable, + *, + serialization_items: Optional[List[Any]] = None +) -> Iterable: + def process_item(item): + index, item = item + return index, fn(item) + results_with_index = map_per_process( + process_item, + enumerate(input_sequence), + serialization_items=serialization_items) + + expected_index = 0 + items_in_wait = [] + for item in results_with_index: + index, result = item + if index == expected_index: + yield result + expected_index = index + 1 + + items_in_wait.sort(reverse=True) + while len(items_in_wait) > 0 and items_in_wait[-1][0] == expected_index: + index, result = items_in_wait.pop() + yield result + expected_index = index + 1 + else: + items_in_wait.append(item) + +def ordered_map_per_thread( + fn, + input_sequence: Iterable, + *, + parallelism: int = mp.cpu_count() +) -> Iterable: + executor = ThreadPoolExecutor(max_workers=parallelism) + input_sequence = (executor.submit(fn, item) for item in input_sequence) + input_sequence = threaded_generator(input_sequence, maxsize=parallelism) + for future in input_sequence: + yield future.result() + executor.shutdown() + +def map_in_chunks( + fn, + input_sequence: Iterable, + *, + chunk_size: int = 10, + serialization_items: Optional[List[Any]] = None +) -> Iterable: + def process_chunk(chunk: List) -> List: + return list(map(fn, chunk)) + + processed_chunks = map_per_process( + process_chunk, + slices(chunk_size, input_sequence), + serialization_items=serialization_items) + for processed_chunk in processed_chunks: + yield from processed_chunk + +def ordered_map_in_chunks( + fn, + input_sequence: Iterable, + *, + chunk_size: int = 10, + serialization_items: Optional[List[Any]] = None +) -> Iterable: + def process_chunk(chunk: List) -> List: + return list(map(fn, chunk)) + + processed_chunks = ordered_map_per_process( + process_chunk, + slices(chunk_size, input_sequence), + serialization_items=serialization_items) + for processed_chunk in processed_chunks: + yield from processed_chunk diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..a0040bc --- /dev/null +++ b/setup.py @@ -0,0 +1,14 @@ +from setuptools import setup, find_packages + +setup( + name='bettermap', + version='1.0.0', + description="Drop-in replacements for Python's map function", + url='https://github.com/allenai/bettermap', + author="Dirk Groeneveld", + author_email="dirkg@allenai.org", + packages=find_packages(), + py_modules=['pipette'], + install_requires=['dill'], + python_requires='>=3.6' +)