Skip to content

Commit

Permalink
[batch] maybe reduce average JVMJob "connecting to jvm" time
Browse files: browse the repository at this point in the history
  • Loading branch information
Dan King committed Oct 20, 2023
1 parent 35994fb commit f2145c4
Showing 1 changed file with 42 additions and 25 deletions.
67 changes: 42 additions & 25 deletions batch/batch/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2373,6 +2373,7 @@ async def cleanup(self):
except asyncio.CancelledError:
raise
except Exception as e:
self.jvm = await self.worker.recreate_jvm(self.jvm)
raise IncompleteJVMCleanupError(
f'while unmounting fuse blob storage {bucket} from {mount_path} for {self.jvm_name} for job {self.id}'
) from e
Expand Down Expand Up @@ -2910,6 +2911,14 @@ async def get_job_resource_usage(self) -> bytes:
return await self.container.get_job_resource_usage()


class JVMQueue:
    """Bookkeeping for the JVMs of one core count.

    Holds the idle JVMs of a given size in an ``asyncio.Queue`` together with
    the counters used to decide whether more JVMs of that size still need to
    be created.

    Args:
        n_cores: Number of cores each JVM in this queue uses. Must be positive.

    Raises:
        ValueError: If ``n_cores`` is not positive. A zero value would
            otherwise surface as an incidental ``ZeroDivisionError``, and a
            negative value would yield a negative ``target`` that ``total``
            can never equal.
    """

    def __init__(self, n_cores: int):
        if n_cores <= 0:
            raise ValueError(f'n_cores must be a positive integer, got {n_cores!r}')
        # JVMs of this size that are currently idle and available.
        self.queue: asyncio.Queue[JVM] = asyncio.Queue()
        # Number of JVMs of this size created so far; starts at zero and is
        # maintained by whoever creates the JVMs.
        self.total = 0
        # Number of JVMs of this size that fit on the machine's CORES.
        self.target = CORES // n_cores
        self.n_cores = n_cores

    def __repr__(self) -> str:
        # Queues are interpolated into log messages; the default object repr
        # would hide how many JVMs exist and how many are idle.
        return (
            f'{type(self).__name__}(n_cores={self.n_cores}, total={self.total}, '
            f'target={self.target}, available={self.queue.qsize()})'
        )


class Worker:
def __init__(self, client_session: httpx.ClientSession):
self.active = False
Expand Down Expand Up @@ -2942,39 +2951,50 @@ def __init__(self, client_session: httpx.ClientSession):

self.cloudfuse_mount_manager = ReadOnlyCloudfuseManager()

self._jvms_by_cores: Dict[int, JVMQueue] = {
n_cores: JVMQueue(n_cores) for n_cores in (1, 2, 4, 8)
}
self._jvm_waiters: asyncio.Queue[int] = asyncio.Queue()
self._jvm_initializer_task = asyncio.create_task(self._initialize_jvms())
self._jvms = SortedSet([], key=lambda jvm: jvm.n_cores)

# Coroutine that creates the worker's JVMs; it is scheduled with
# asyncio.create_task and stored as self._jvm_initializer_task.
# NOTE(review): this span comes from a rendered diff. Lines from two
# implementations appear together and indentation is not preserved, so the
# notes below are hedged and should be confirmed against the repository.
async def _initialize_jvms(self):
assert instance_config
if instance_config.worker_type() in ('standard', 'D', 'highmem', 'E'):
# Up-front variant (presumably the pre-change code): build one JVM.create
# awaitable per slot -- CORES // jvm_cores JVMs for each of 1, 2, 4 and 8
# cores -- await them together with asyncio.gather, and store the results in
# the self._jvms sorted set.
jvms: List[Awaitable[JVM]] = []
for jvm_cores in (1, 2, 4, 8):
for _ in range(CORES // jvm_cores):
jvms.append(JVM.create(len(jvms), jvm_cores, self))
assert len(jvms) == N_JVM_CONTAINERS
self._jvms.update(await asyncio.gather(*jvms))
log.info(f'JVMs initialized {self._jvms}')
# Incremental variant (presumably the post-change code): JVMs are awaited one
# at a time and placed on the JVMQueue for their core count whenever that
# queue's total differs from its target. global_jvm_index supplies the index
# passed to JVM.create.
# NOTE(review): the value returned by self._jvm_waiters.get_nowait() is not
# used, and the nesting of `continue` / `break` cannot be determined from
# this view -- confirm the loop's exit condition.
global_jvm_index = 0
while True:
try:
self._jvm_waiters.get_nowait()
except asyncio.QueueEmpty:
for n_cores, jvmqueue in self._jvms_by_cores.items():
if jvmqueue.target != jvmqueue.total:
jvmqueue.queue.put_nowait(await JVM.create(global_jvm_index, n_cores, self))
jvmqueue.total += 1
global_jvm_index += 1
continue
break
log.info(f'JVMs initialized {self._jvms_by_cores}')

# Hand out a JVM for a job that needs n_cores cores.
# Raises ValueError when the instance's worker type is not one of
# 'standard', 'D', 'highmem', 'E'.
# NOTE(review): this span comes from a rendered diff. Lines from two
# implementations appear together and indentation is not preserved.
async def borrow_jvm(self, n_cores: int) -> JVM:
assert instance_config
if instance_config.worker_type() not in ('standard', 'D', 'highmem', 'E'):
raise ValueError(f'no JVMs available on {instance_config.worker_type()}')
# Sorted-set variant (presumably the pre-change code): wait for the whole
# initializer task to finish, then pop the JVM at the position
# bisect_key_left reports for n_cores in self._jvms (keyed by jvm.n_cores).
await self._jvm_initializer_task
assert self._jvms
index = self._jvms.bisect_key_left(n_cores)
assert index < len(self._jvms), index
return self._jvms.pop(index)

# Queue variant (presumably the post-change code): take an idle JVM from the
# queue for exactly n_cores without waiting when one is available.
jvmqueue = self._jvms_by_cores[n_cores]
try:
return jvmqueue.queue.get_nowait()
except asyncio.QueueEmpty:
# The queue is empty: the assert expects the initializer task to still be
# running. n_cores is then put on self._jvm_waiters and this coroutine waits
# for a JVM to appear on the queue.
# NOTE(review): presumably the queue could also be empty because every JVM of
# this size is currently borrowed -- confirm the assert cannot fire then.
assert not self._jvm_initializer_task.done()
self._jvm_waiters.put_nowait(n_cores)
return await jvmqueue.queue.get()

# Reset a borrowed JVM and make it available for borrowing again.
# NOTE(review): two destinations appear below -- the self._jvms sorted set and
# the per-core-count queue in self._jvms_by_cores. Presumably these are the
# old and new lines of a rendered diff rather than two sequential steps;
# confirm against the repository.
def return_jvm(self, jvm: JVM):
jvm.reset()
self._jvms.add(jvm)
self._jvms_by_cores[jvm.n_cores].queue.put_nowait(jvm)

# Replace a JVM with a newly created one that reuses the same index and core
# count.
# NOTE(review): this span comes from a rendered diff. Lines from two
# implementations appear together and indentation is not preserved.
async def recreate_jvm(self, jvm: JVM):
# Sorted-set variant (presumably the pre-change code): drop the JVM from
# self._jvms, create its replacement, and add the replacement to self._jvms.
self._jvms.remove(jvm)
log.info(f'quarantined {jvm} and recreated a new jvm')
new_jvm = await JVM.create(jvm.index, jvm.n_cores, self)
self._jvms.add(new_jvm)
# Kill-and-return variant (presumably the post-change code): kill the old JVM
# and return the replacement to the caller.
# NOTE(review): the log line is emitted before JVM.create is awaited, so
# "recreated" is logged ahead of the actual creation.
await jvm.kill() # is this OK to do? Seems like we ought to, no?
log.info(f'killed {jvm} and recreated a new jvm')
return await JVM.create(jvm.index, jvm.n_cores, self)

@property
def headers(self) -> Dict[str, str]:
Expand All @@ -2984,8 +3004,10 @@ async def shutdown(self):
log.info('Worker.shutdown')
self._jvm_initializer_task.cancel()
async with AsyncExitStack() as cleanup:
for jvm in self._jvms:
cleanup.push_async_callback(jvm.kill)
for _, jvmqueue in self._jvms_by_cores.items():
assert jvmqueue.queue.qsize() == jvmqueue.target
while not jvmqueue.queue.empty():
cleanup.push_async_callback(jvmqueue.queue.get_nowait().kill)
cleanup.push_async_callback(self.task_manager.shutdown_and_wait)
if self.file_store:
cleanup.push_async_callback(self.file_store.close)
Expand All @@ -3000,11 +3022,6 @@ async def run_job(self, job):
raise
except JVMCreationError:
self.stop_event.set()
except IncompleteJVMCleanupError:
assert isinstance(job, JVMJob)
assert job.jvm is not None
await self.recreate_jvm(job.jvm)
log.exception(f'while running {job}, ignoring')
except Exception as e:
if not user_error(e):
log.exception(f'while running {job}, ignoring')
Expand Down

0 comments on commit f2145c4

Please sign in to comment.