diff --git a/unidist/core/backends/mpi/core/controller/api.py b/unidist/core/backends/mpi/core/controller/api.py
index ab67a30a..54a68f7b 100644
--- a/unidist/core/backends/mpi/core/controller/api.py
+++ b/unidist/core/backends/mpi/core/controller/api.py
@@ -172,10 +172,10 @@ def init():
     mpi_state = communication.MPIState.get_instance(
         comm, comm.Get_rank(), comm.Get_size()
     )
-    if rank == 0 and not threads and parent_comm == MPI.COMM_NULL:
-        thread = Poller(1, "Thread_Poll_Tasks", comm)
-        thread.start()
-        threads.append(thread)
+    # if rank == 0 and not threads and parent_comm == MPI.COMM_NULL:
+    #     thread = Poller(1, "Thread_Poll_Tasks", comm)
+    #     thread.start()
+    #     threads.append(thread)
 
     global topology
     if not topology:
diff --git a/unidist/core/backends/mpi/core/controller/common.py b/unidist/core/backends/mpi/core/controller/common.py
index e0fdb22e..932422ba 100644
--- a/unidist/core/backends/mpi/core/controller/common.py
+++ b/unidist/core/backends/mpi/core/controller/common.py
@@ -140,6 +140,12 @@ def decrement_tasks_on_worker(self, rank):
         """
         self.task_per_worker[rank] -= 1
 
+    def decrement_done_tasks(self, tasks_done):
+        self.task_per_worker = {
+            key: self.task_per_worker[key] - tasks_done.get(key, 0)
+            for key in self.task_per_worker
+        }
+
 
 def request_worker_data(data_id):
     """
diff --git a/unidist/core/backends/mpi/core/controller/garbage_collector.py b/unidist/core/backends/mpi/core/controller/garbage_collector.py
index 8bc74cad..86b210e3 100644
--- a/unidist/core/backends/mpi/core/controller/garbage_collector.py
+++ b/unidist/core/backends/mpi/core/controller/garbage_collector.py
@@ -11,7 +11,10 @@
 from unidist.core.backends.mpi.core.async_operations import AsyncOperations
 from unidist.core.backends.mpi.core.serialization import SimpleDataSerializer
 from unidist.core.backends.mpi.core.controller.object_store import object_store
-from unidist.core.backends.mpi.core.controller.common import initial_worker_number
+from unidist.core.backends.mpi.core.controller.common import (
+    initial_worker_number,
+    Scheduler,
+)
 
 logger = common.get_logger("utils", "utils.log")
 
@@ -131,6 +134,11 @@ def regular_cleanup(self):
                         mpi_state.comm,
                         communication.MPIRank.MONITOR,
                     )
+                    tasks_completed = communication.recv_simple_operation(
+                        mpi_state.comm,
+                        communication.MPIRank.MONITOR,
+                    )
+                    Scheduler.get_instance().decrement_done_tasks(tasks_completed)
 
                     logger.debug(
                         "Submitted task count {} vs executed task count {}".format(
diff --git a/unidist/core/backends/mpi/core/monitor.py b/unidist/core/backends/mpi/core/monitor.py
index eaf7fa1f..accf0488 100755
--- a/unidist/core/backends/mpi/core/monitor.py
+++ b/unidist/core/backends/mpi/core/monitor.py
@@ -19,12 +19,20 @@
 mpi4py.rc(recv_mprobe=False, initialize=False)
 from mpi4py import MPI  # noqa: E402
 
+initial_worker_number = 2
+
 
 class TaskCounter:
     __instance = None
 
     def __init__(self):
         self.task_counter = 0
+        self.task_done_per_worker_unsend = {
+            k: 0
+            for k in range(
+                initial_worker_number, communication.MPIState.get_instance().world_size
+            )
+        }
 
     @classmethod
     def get_instance(cls):
@@ -39,9 +47,10 @@ def get_instance(cls):
             cls.__instance = TaskCounter()
         return cls.__instance
 
-    def increment(self):
+    def increment(self, rank):
         """Increment task counter by one."""
         self.task_counter += 1
+        self.task_done_per_worker_unsend[rank] += 1
 
 
 def monitor_loop():
@@ -65,10 +74,8 @@ def monitor_loop():
 
         # Proceed the request
        if operation_type == common.Operation.TASK_DONE:
-            task_counter.increment()
-            communication.mpi_isend_object(
-                mpi_state.comm, source_rank, communication.MPIRank.ROOT, 1
-            )
+            task_counter.increment(source_rank)
+
         elif operation_type == common.Operation.GET_TASK_COUNT:
             # We use a blocking send here because the receiver is waiting for the result.
             communication.mpi_send_object(
@@ -76,6 +83,14 @@
                 task_counter.task_counter,
                 source_rank,
             )
+            communication.mpi_send_object(
+                mpi_state.comm,
+                task_counter.task_done_per_worker_unsend,
+                source_rank,
+            )
+            task_counter.task_done_per_worker_unsend = dict.fromkeys(
+                task_counter.task_done_per_worker_unsend, 0
+            )
         elif operation_type == common.Operation.CANCEL:
             async_operations.finish()
             if not MPI.Is_finalized():