Give this some proper package structure
dirkgr committed Jun 27, 2019
1 parent 597acb4 commit 1a7e237
Showing 5 changed files with 236 additions and 0 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -0,0 +1,5 @@
/*.iml
/__pycache__/
/*.egg-info/
/build/
/dist/
5 changes: 5 additions & 0 deletions README.md
@@ -0,0 +1,5 @@
bettermap
=========

`bettermap` is a drop-in replacement for Python's `map` function. It parallelizes
the map across all available processors.
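
A minimal usage sketch (not part of this commit): it assumes the package is installed and a fork-based multiprocessing start method, as on Linux, and uses `ordered_map_per_process` from `bettermap/bettermap.py` below.

from bettermap import ordered_map_per_process

def square(x: int) -> int:
    return x * x

if __name__ == "__main__":
    # Each item is squared in its own process; results come back in input order.
    print(list(ordered_map_per_process(square, range(10))))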
1 change: 1 addition & 0 deletions bettermap/__init__.py
@@ -0,0 +1 @@
from .bettermap import *
211 changes: 211 additions & 0 deletions bettermap/bettermap.py
@@ -0,0 +1,211 @@
#!/usr/bin/python3
import io
from typing import Any, Dict, Iterable, List, Optional
import sys
from concurrent.futures import ThreadPoolExecutor
import collections

import itertools
import multiprocessing as mp
import multiprocessing.connection
import dill

from queue import Queue
from threading import Thread

def threaded_generator(g, maxsize: int = 16):
    """Runs the generator g in a background thread, yielding its items through a bounded queue."""
    q = Queue(maxsize=maxsize)

    sentinel = object()

    def fill_queue():
        try:
            for value in g:
                q.put(value)
        finally:
            q.put(sentinel)

    thread = Thread(name=repr(g), target=fill_queue, daemon=True)
    thread.start()

    yield from iter(q.get, sentinel)

def slices(n: int, i: Iterable) -> Iterable[List]:
    """Breaks the iterable i into lists of at most n items."""
    i = iter(i)
    while True:
        s = list(itertools.islice(i, n))
        if len(s) > 0:
            yield s
        else:
            break

def window(seq: Iterable, n: int = 2) -> Iterable[List]:
    """Yields a sliding window of n consecutive items from seq."""
    win = collections.deque(maxlen=n)
    for e in seq:
        win.append(e)
        if len(win) == n:
            yield list(win)

def map_per_process(
    fn,
    input_sequence: Iterable,
    *,
    serialization_items: Optional[List[Any]] = None,
    parallelism: int = mp.cpu_count()
) -> Iterable:
    """Applies fn to every item of input_sequence, running each item in its own process.

    Results are yielded in completion order, not input order. Objects listed in
    serialization_items are not pickled when results are sent back; the parent's
    original objects are substituted instead."""
    if serialization_items is not None and len(serialization_items) > 0:
        serialization_ids = [id(o) for o in serialization_items]

        class MapPickler(dill.Pickler):
            def persistent_id(self, obj):
                try:
                    return serialization_ids.index(id(obj))
                except ValueError:
                    return None

        class MapUnpickler(dill.Unpickler):
            def persistent_load(self, pid):
                return serialization_items[pid]
    else:
        MapPickler = dill.Pickler
        MapUnpickler = dill.Unpickler

    def pickle(o: Any) -> bytes:
        with io.BytesIO() as buffer:
            pickler = MapPickler(buffer)
            pickler.dump(o)
            return buffer.getvalue()

    def unpickle(b: bytes) -> Any:
        with io.BytesIO(b) as buffer:
            unpickler = MapUnpickler(buffer)
            return unpickler.load()

    pipeno_to_pipe: Dict[int, multiprocessing.connection.Connection] = {}
    pipeno_to_process: Dict[int, mp.Process] = {}

    def process_one_item(send_pipe: multiprocessing.connection.Connection, item):
        try:
            processed_item = fn(item)
        except Exception as e:
            import traceback
            send_pipe.send((None, (e, traceback.format_exc())))
        else:
            send_pipe.send((pickle(processed_item), None))
        send_pipe.close()

    def yield_from_pipes(pipes: List[multiprocessing.connection.Connection]):
        for pipe in pipes:
            result, error = pipe.recv()
            pipeno = pipe.fileno()
            del pipeno_to_pipe[pipeno]
            pipe.close()

            process = pipeno_to_process[pipeno]
            process.join()
            del pipeno_to_process[pipeno]

            if error is None:
                yield unpickle(result)
            else:
                e, tb = error
                sys.stderr.write("".join(tb))
                raise e

    try:
        for item in input_sequence:
            receive_pipe, send_pipe = mp.Pipe(duplex=False)
            process = mp.Process(target=process_one_item, args=(send_pipe, item))
            pipeno_to_pipe[receive_pipe.fileno()] = receive_pipe
            pipeno_to_process[receive_pipe.fileno()] = process
            process.start()

            # Read out the values. If we have fewer processes going than we have CPUs, we
            # just pick up the values that are done. If we are at the process limit, we
            # wait until one of them is done.
            timeout = 0 if len(pipeno_to_process) < parallelism else None
            ready_pipes = multiprocessing.connection.wait(pipeno_to_pipe.values(), timeout=timeout)
            yield from yield_from_pipes(ready_pipes)

        # yield the rest of the items
        while len(pipeno_to_process) > 0:
            ready_pipes = multiprocessing.connection.wait(pipeno_to_pipe.values(), timeout=None)
            yield from yield_from_pipes(ready_pipes)

    finally:
        for process in pipeno_to_process.values():
            if process.is_alive():
                process.terminate()

def ordered_map_per_process(
    fn,
    input_sequence: Iterable,
    *,
    serialization_items: Optional[List[Any]] = None
) -> Iterable:
    """Like map_per_process(), but yields results in the same order as the input sequence."""
    def process_item(item):
        index, item = item
        return index, fn(item)

    results_with_index = map_per_process(
        process_item,
        enumerate(input_sequence),
        serialization_items=serialization_items)

    expected_index = 0
    items_in_wait = []
    for item in results_with_index:
        index, result = item
        if index == expected_index:
            yield result
            expected_index = index + 1

            # Flush any buffered out-of-order results that are now due.
            items_in_wait.sort(reverse=True)
            while len(items_in_wait) > 0 and items_in_wait[-1][0] == expected_index:
                index, result = items_in_wait.pop()
                yield result
                expected_index = index + 1
        else:
            items_in_wait.append(item)

def ordered_map_per_thread(
    fn,
    input_sequence: Iterable,
    *,
    parallelism: int = mp.cpu_count()
) -> Iterable:
    """Applies fn to every item of input_sequence in a thread pool, yielding results in input order."""
    executor = ThreadPoolExecutor(max_workers=parallelism)
    input_sequence = (executor.submit(fn, item) for item in input_sequence)
    input_sequence = threaded_generator(input_sequence, maxsize=parallelism)
    for future in input_sequence:
        yield future.result()
    executor.shutdown()

def map_in_chunks(
    fn,
    input_sequence: Iterable,
    *,
    chunk_size: int = 10,
    serialization_items: Optional[List[Any]] = None
) -> Iterable:
    """Like map_per_process(), but sends chunk_size items at a time to each process."""
    def process_chunk(chunk: List) -> List:
        return list(map(fn, chunk))

    processed_chunks = map_per_process(
        process_chunk,
        slices(chunk_size, input_sequence),
        serialization_items=serialization_items)
    for processed_chunk in processed_chunks:
        yield from processed_chunk

def ordered_map_in_chunks(
    fn,
    input_sequence: Iterable,
    *,
    chunk_size: int = 10,
    serialization_items: Optional[List[Any]] = None
) -> Iterable:
    """Like ordered_map_per_process(), but sends chunk_size items at a time to each process."""
    def process_chunk(chunk: List) -> List:
        return list(map(fn, chunk))

    processed_chunks = ordered_map_per_process(
        process_chunk,
        slices(chunk_size, input_sequence),
        serialization_items=serialization_items)
    for processed_chunk in processed_chunks:
        yield from processed_chunk
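
A usage sketch for the chunked variants (a hypothetical example, not part of the commit; it assumes a fork-based multiprocessing start method, as on Linux, so the nested worker functions reach the child processes without pickling). `map_in_chunks` amortizes per-process startup cost by sending `chunk_size` items to each worker, and `ordered_map_in_chunks` does the same while preserving input order.

from bettermap import map_in_chunks, ordered_map_in_chunks

def tokenize(line: str) -> list:
    return line.split()

if __name__ == "__main__":
    lines = ["the quick brown fox jumps over the lazy dog"] * 1000
    # Unordered: items are yielded as their chunk's process finishes.
    token_counts = [len(tokens) for tokens in map_in_chunks(tokenize, lines, chunk_size=100)]
    # Ordered: items come back in the same order as `lines`.
    tokenized = list(ordered_map_in_chunks(tokenize, lines, chunk_size=100))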
14 changes: 14 additions & 0 deletions setup.py
@@ -0,0 +1,14 @@
from setuptools import setup, find_packages

setup(
    name='bettermap',
    version='1.0.0',
    description="Drop-in replacements for Python's map function",
    url='https://github.com/allenai/bettermap',
    author="Dirk Groeneveld",
    author_email="[email protected]",
    packages=find_packages(),
    install_requires=['dill'],
    python_requires='>=3.6'
)
