Commit: Give this some proper package structure
===============================================
Showing 5 changed files with 236 additions and 0 deletions.
.gitignore
----------
@@ -0,0 +1,5 @@
/*.iml
/__pycache__/
/*.egg-info/
/build/
/dist/
README.md
---------
@@ -0,0 +1,5 @@
bettermap
=========

`bettermap` is a drop-in replacement for Python's map function. It parallelizes
the map function across all available processors.
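
A minimal usage sketch (the `square` worker and the inputs are illustrative assumptions, not part of the commit; `map_per_process` is defined in `bettermap/bettermap.py` below):

from bettermap import map_per_process

def square(x: int) -> int:
    return x * x

if __name__ == '__main__':
    # Results are yielded as worker processes finish, possibly out of input order.
    for result in map_per_process(square, range(10)):
        print(result)

The `__main__` guard matters because `multiprocessing` may re-import the calling module in child processes on spawn-based platforms.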
bettermap/__init__.py
---------------------
@@ -0,0 +1 @@
from .bettermap import *
bettermap/bettermap.py
----------------------
@@ -0,0 +1,211 @@
#!/usr/bin/python3
import io
from typing import *
import sys
from concurrent.futures import ThreadPoolExecutor
import collections

import itertools
import multiprocessing as mp
import multiprocessing.connection
import dill

from queue import Queue
from threading import Thread


def threaded_generator(g, maxsize: int = 16):
    """Runs the generator g in a background thread, buffering up to maxsize items."""
    q = Queue(maxsize=maxsize)

    sentinel = object()

    def fill_queue():
        try:
            for value in g:
                q.put(value)
        finally:
            q.put(sentinel)

    thread = Thread(name=repr(g), target=fill_queue, daemon=True)
    thread.start()

    yield from iter(q.get, sentinel)


def slices(n: int, i: Iterable) -> Iterable[List]:
    """Breaks the iterable i into lists of at most n items each."""
    i = iter(i)
    while True:
        s = list(itertools.islice(i, n))
        if len(s) > 0:
            yield s
        else:
            break


def window(seq: Iterable, n: int = 2) -> Iterable[List]:
    """Yields a sliding window of n consecutive items from seq."""
    win = collections.deque(maxlen=n)
    for e in seq:
        win.append(e)
        if len(win) == n:
            yield list(win)

def map_per_process(
    fn,
    input_sequence: Iterable,
    *,
    serialization_items: Optional[List[Any]] = None,
    parallelism: int = mp.cpu_count()
) -> Iterable:
    """Applies fn to every item in input_sequence, running one process per item with
    at most parallelism processes at a time. Results are yielded as they become
    available, so they may come back in a different order than the input."""
    if serialization_items is not None and len(serialization_items) > 0:
        # Objects in serialization_items are never pickled. They are replaced by
        # their index in the list and substituted back on the receiving side.
        serialization_ids = [id(o) for o in serialization_items]

        class MapPickler(dill.Pickler):
            def persistent_id(self, obj):
                try:
                    return serialization_ids.index(id(obj))
                except ValueError:
                    return None

        class MapUnpickler(dill.Unpickler):
            def persistent_load(self, pid):
                return serialization_items[pid]
    else:
        MapPickler = dill.Pickler
        MapUnpickler = dill.Unpickler

    def pickle(o: Any) -> bytes:
        with io.BytesIO() as buffer:
            pickler = MapPickler(buffer)
            pickler.dump(o)
            return buffer.getvalue()

    def unpickle(b: bytes) -> Any:
        with io.BytesIO(b) as buffer:
            unpickler = MapUnpickler(buffer)
            return unpickler.load()

    pipeno_to_pipe: Dict[int, multiprocessing.connection.Connection] = {}
    pipeno_to_process: Dict[int, mp.Process] = {}

    def process_one_item(send_pipe: multiprocessing.connection.Connection, item):
        # Runs in the child process: sends back either the pickled result or the
        # exception together with its formatted traceback.
        try:
            processed_item = fn(item)
        except Exception as e:
            import traceback
            send_pipe.send((None, (e, traceback.format_exc())))
        else:
            send_pipe.send((pickle(processed_item), None))
        send_pipe.close()

    def yield_from_pipes(pipes: List[multiprocessing.connection.Connection]):
        for pipe in pipes:
            result, error = pipe.recv()
            pipeno = pipe.fileno()
            del pipeno_to_pipe[pipeno]
            pipe.close()

            process = pipeno_to_process[pipeno]
            process.join()
            del pipeno_to_process[pipeno]

            if error is None:
                yield unpickle(result)
            else:
                e, tb = error
                sys.stderr.write("".join(tb))
                raise e

    try:
        for item in input_sequence:
            receive_pipe, send_pipe = mp.Pipe(duplex=False)
            process = mp.Process(target=process_one_item, args=(send_pipe, item))
            pipeno_to_pipe[receive_pipe.fileno()] = receive_pipe
            pipeno_to_process[receive_pipe.fileno()] = process
            process.start()

            # Read out the values. If we have fewer processes going than parallelism
            # allows, we just pick up the values that are done. If we are at the
            # process limit, we wait until one of them is done.
            timeout = 0 if len(pipeno_to_process) < parallelism else None
            ready_pipes = multiprocessing.connection.wait(pipeno_to_pipe.values(), timeout=timeout)
            yield from yield_from_pipes(ready_pipes)

        # yield the rest of the items
        while len(pipeno_to_process) > 0:
            ready_pipes = multiprocessing.connection.wait(pipeno_to_pipe.values(), timeout=None)
            yield from yield_from_pipes(ready_pipes)
    finally:
        for process in pipeno_to_process.values():
            if process.is_alive():
                process.terminate()

def ordered_map_per_process(
    fn,
    input_sequence: Iterable,
    *,
    serialization_items: Optional[List[Any]] = None
) -> Iterable:
    """Like map_per_process(), but yields results in the order of input_sequence.
    Out-of-order results are buffered until the next expected index arrives."""
    def process_item(item):
        index, item = item
        return index, fn(item)

    results_with_index = map_per_process(
        process_item,
        enumerate(input_sequence),
        serialization_items=serialization_items)

    expected_index = 0
    items_in_wait = []
    for item in results_with_index:
        index, result = item
        if index == expected_index:
            yield result
            expected_index = index + 1

            # Flush any buffered items that are now next in line.
            items_in_wait.sort(reverse=True)
            while len(items_in_wait) > 0 and items_in_wait[-1][0] == expected_index:
                index, result = items_in_wait.pop()
                yield result
                expected_index = index + 1
        else:
            items_in_wait.append(item)

def ordered_map_per_thread(
    fn,
    input_sequence: Iterable,
    *,
    parallelism: int = mp.cpu_count()
) -> Iterable:
    """Applies fn to every item in input_sequence in a thread pool, yielding results
    in input order. threaded_generator() keeps at most parallelism futures in flight."""
    executor = ThreadPoolExecutor(max_workers=parallelism)
    try:
        futures = (executor.submit(fn, item) for item in input_sequence)
        for future in threaded_generator(futures, maxsize=parallelism):
            yield future.result()
    finally:
        # Shut down the pool even if the consumer abandons the generator early.
        executor.shutdown()

def map_in_chunks(
    fn,
    input_sequence: Iterable,
    *,
    chunk_size: int = 10,
    serialization_items: Optional[List[Any]] = None
) -> Iterable:
    """Like map_per_process(), but processes chunk_size items per process to amortize
    the cost of starting processes. Results may come back out of order."""
    def process_chunk(chunk: List) -> List:
        return list(map(fn, chunk))

    processed_chunks = map_per_process(
        process_chunk,
        slices(chunk_size, input_sequence),
        serialization_items=serialization_items)
    for processed_chunk in processed_chunks:
        yield from processed_chunk


def ordered_map_in_chunks(
    fn,
    input_sequence: Iterable,
    *,
    chunk_size: int = 10,
    serialization_items: Optional[List[Any]] = None
) -> Iterable:
    """Like map_in_chunks(), but preserves the order of input_sequence."""
    def process_chunk(chunk: List) -> List:
        return list(map(fn, chunk))

    processed_chunks = ordered_map_per_process(
        process_chunk,
        slices(chunk_size, input_sequence),
        serialization_items=serialization_items)
    for processed_chunk in processed_chunks:
        yield from processed_chunk
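
For reference, a short, hedged sketch of how the module's entry points compose (the `slow_double` worker and the inputs are illustrative assumptions, not part of the commit):

from bettermap.bettermap import map_per_process, ordered_map_in_chunks

def slow_double(x: int) -> int:
    return 2 * x

if __name__ == '__main__':
    # Unordered: one process per item, results arrive as they finish.
    print(sorted(map_per_process(slow_double, range(8))))
    # Ordered and chunked: chunk_size items share one process, input order preserved.
    print(list(ordered_map_in_chunks(slow_double, range(8), chunk_size=4)))

The chunked variants trade latency for throughput: starting a process per item is expensive, so batching cheap calls into chunks usually wins.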
setup.py
--------
@@ -0,0 +1,14 @@
from setuptools import setup, find_packages

setup(
    name='bettermap',
    version='1.0.0',
    description="Drop-in replacements for Python's map function",
    url='https://github.com/allenai/bettermap',
    author="Dirk Groeneveld",
    author_email="[email protected]",
    packages=find_packages(),
    install_requires=['dill'],
    python_requires='>=3.6'
)