[batch] maybe reduce average JVMJob "connecting to jvm" time #13870

Merged 13 commits on Oct 25, 2023
74 changes: 48 additions & 26 deletions batch/batch/worker/worker.py
@@ -42,7 +42,6 @@
 import orjson
 from aiodocker.exceptions import DockerError  # type: ignore
 from aiohttp import web
-from sortedcontainers import SortedSet

 from gear import json_request, json_response
 from hailtop import aiotools, httpx
@@ -2373,6 +2372,7 @@ async def cleanup(self):
         except asyncio.CancelledError:
             raise
         except Exception as e:
+            self.worker.return_jvm(await self.worker.recreate_jvm(self.jvm))
             raise IncompleteJVMCleanupError(
                 f'while unmounting fuse blob storage {bucket} from {mount_path} for {self.jvm_name} for job {self.id}'
             ) from e
@@ -2910,6 +2910,17 @@ async def get_job_resource_usage(self) -> bytes:
         return await self.container.get_job_resource_usage()


+class JVMQueue:
+    def __init__(self, n_cores):
+        self.queue: asyncio.Queue[JVM] = asyncio.Queue()
+        self.total = 0
+        self.target = CORES // n_cores
+        self.n_cores = n_cores
+
+    def __repr__(self):
+        return f'JVMQueue({repr(self.queue)}, {self.total}, {self.target}, {self.n_cores})'
+
+
 class Worker:
     def __init__(self, client_session: httpx.ClientSession):
         self.active = False
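For a concrete sense of target: assuming a hypothetical 16-core worker (CORES = 16; the real value depends on the machine type), each size class is provisioned to tile the entire machine by itself:

# Illustrative only, assuming CORES = 16.
CORES = 16
targets = {n_cores: CORES // n_cores for n_cores in (1, 2, 4, 8)}
# targets == {1: 16, 2: 8, 4: 4, 8: 2}, i.e. 30 JVMs in total on such a worker
assert sum(targets.values()) == 30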
@@ -2942,39 +2953,54 @@ def __init__(self, client_session: httpx.ClientSession):

         self.cloudfuse_mount_manager = ReadOnlyCloudfuseManager()

+        self._jvms_by_cores: Dict[int, JVMQueue] = {
+            n_cores: JVMQueue(n_cores) for n_cores in (1, 2, 4, 8)
+        }
+        self._jvm_waiters: asyncio.Queue[int] = asyncio.Queue()
         self._jvm_initializer_task = asyncio.create_task(self._initialize_jvms())
-        self._jvms = SortedSet([], key=lambda jvm: jvm.n_cores)

     async def _initialize_jvms(self):
         assert instance_config
-        if instance_config.worker_type() in ('standard', 'D', 'highmem', 'E'):
-            jvms: List[Awaitable[JVM]] = []
-            for jvm_cores in (1, 2, 4, 8):
-                for _ in range(CORES // jvm_cores):
-                    jvms.append(JVM.create(len(jvms), jvm_cores, self))
-            assert len(jvms) == N_JVM_CONTAINERS
-            self._jvms.update(await asyncio.gather(*jvms))
-        log.info(f'JVMs initialized {self._jvms}')
+        global_jvm_index = 0
+        while True:
+            try:
+                n_cores = self._jvm_waiters.get_nowait()
+                jvmqueue = self._jvms_by_cores[n_cores]
+                jvmqueue.queue.put_nowait(await JVM.create(global_jvm_index, n_cores, self))
+                jvmqueue.total += 1

Contributor: I'm having a hard time following this code and why the outer queue of jvm_waiters is necessary.

+                global_jvm_index += 1
+            except asyncio.QueueEmpty:
+                for n_cores, jvmqueue in self._jvms_by_cores.items():
+                    while jvmqueue.target != jvmqueue.total:
+                        jvmqueue.queue.put_nowait(await JVM.create(global_jvm_index, n_cores, self))
+                        jvmqueue.total += 1
+                        global_jvm_index += 1
+                break
+        assert self._jvm_waiters.empty()
+        log.info(f'JVMs initialized {self._jvms_by_cores}')
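On the reviewer's question about _jvm_waiters: unlike the old gather-based version, this loop creates JVMs one at a time, and creation is slow, so the waiters queue lets a size class with a job already blocked in borrow_jvm jump ahead of the bulk fill. A minimal, self-contained sketch of that pattern, using simple stand-ins rather than the real JVM and Worker classes:

# Self-contained sketch of the demand-driven startup (stand-ins, not the Hail code).
import asyncio

CORES = 8
TARGETS = {n: CORES // n for n in (1, 2, 4, 8)}  # per-size JVM counts

async def create_jvm(index: int, n_cores: int) -> str:
    await asyncio.sleep(0.01)  # stand-in for the slow JVM.create
    return f'jvm-{index}-{n_cores}c'

async def initialize(queues, waiters):
    totals = dict.fromkeys(TARGETS, 0)
    index = 0
    while True:
        try:
            # Serve sizes that already have a blocked borrower first.
            n = waiters.get_nowait()
            queues[n].put_nowait(await create_jvm(index, n))
            totals[n] += 1
            index += 1
        except asyncio.QueueEmpty:
            # No one is waiting: fill every queue up to its target and stop.
            for n, target in TARGETS.items():
                while totals[n] < target:
                    queues[n].put_nowait(await create_jvm(index, n))
                    totals[n] += 1
                    index += 1
            break

async def borrow(queues, waiters, n_cores):
    try:
        return queues[n_cores].get_nowait()  # fast path: an idle JVM exists
    except asyncio.QueueEmpty:
        waiters.put_nowait(n_cores)          # record demand, then block
        return await queues[n_cores].get()

async def main():
    queues = {n: asyncio.Queue() for n in TARGETS}
    waiters = asyncio.Queue()
    init = asyncio.create_task(initialize(queues, waiters))
    print('borrowed', await borrow(queues, waiters, 8))  # served ahead of the bulk fill
    await init

asyncio.run(main())

Once the initializer finishes, every queue is at its target, and since each size class alone tiles all CORES, a worker should never be scheduled more than CORES // n jobs of size n at once; that appears to be the invariant behind the assert not self._jvm_initializer_task.done() in borrow_jvm below.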

     async def borrow_jvm(self, n_cores: int) -> JVM:
         assert instance_config
         if instance_config.worker_type() not in ('standard', 'D', 'highmem', 'E'):
             raise ValueError(f'no JVMs available on {instance_config.worker_type()}')
         await self._jvm_initializer_task
-        assert self._jvms
-        index = self._jvms.bisect_key_left(n_cores)
-        assert index < len(self._jvms), index
-        return self._jvms.pop(index)
+
+        jvmqueue = self._jvms_by_cores[n_cores]
+        try:
+            return jvmqueue.queue.get_nowait()
+        except asyncio.QueueEmpty:
+            assert not self._jvm_initializer_task.done(), (CORES, n_cores, self._jvms_by_cores)
+            self._jvm_waiters.put_nowait(n_cores)
+            return await jvmqueue.queue.get()

     def return_jvm(self, jvm: JVM):
         jvm.reset()
-        self._jvms.add(jvm)
+        self._jvms_by_cores[jvm.n_cores].queue.put_nowait(jvm)

     async def recreate_jvm(self, jvm: JVM):
-        self._jvms.remove(jvm)

Contributor (author): I think this is a latent bug; the JVM is still owned by the job when recreate_jvm is called, so it won't be in this set. If this ever happened, it would fail.

-        log.info(f'quarantined {jvm} and recreated a new jvm')
-        new_jvm = await JVM.create(jvm.index, jvm.n_cores, self)
-        self._jvms.add(new_jvm)
+        await jvm.kill()  # is this OK to do? Seems like we ought to, no?
+        log.info(f'killed {jvm} and recreated a new jvm')
+        return await JVM.create(jvm.index, jvm.n_cores, self)

     @property
     def headers(self) -> Dict[str, str]:
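Read together with the cleanup hunk near the top of this diff, the ownership model appears to be: a job keeps its JVM until it either returns it or quarantines it, and recreate_jvm now hands the caller a fresh replacement instead of mutating a shared pool. A hypothetical call site (names assumed), mirroring that cleanup handler:

# Hypothetical call site, mirroring the fuse-unmount failure path above:
# kill the possibly-corrupted JVM, build a fresh one, and hand the
# replacement back to its per-size queue.
replacement = await worker.recreate_jvm(job.jvm)  # kills job.jvm, creates a new JVM
worker.return_jvm(replacement)                    # reset + put_nowait on its size queue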
@@ -2984,8 +3010,9 @@ async def shutdown(self):
         log.info('Worker.shutdown')
         self._jvm_initializer_task.cancel()
         async with AsyncExitStack() as cleanup:
-            for jvm in self._jvms:
-                cleanup.push_async_callback(jvm.kill)

Contributor (author): this seems a bit odd; the JVMs might still be held by jobs, right?

+            for jvmqueue in self._jvms_by_cores.values():
+                while not jvmqueue.queue.empty():
+                    cleanup.push_async_callback(jvmqueue.queue.get_nowait().kill)
             cleanup.push_async_callback(self.task_manager.shutdown_and_wait)
             if self.file_store:
                 cleanup.push_async_callback(self.file_store.close)
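One property of the AsyncExitStack pattern worth noting: pushed callbacks run in LIFO order on exit, so the file store closes before the queued JVMs are killed. A toy illustration (not Hail code) of that ordering:

# Toy illustration of AsyncExitStack callback ordering (LIFO on exit).
import asyncio
from contextlib import AsyncExitStack

async def say(msg: str):
    print(msg)

async def main():
    async with AsyncExitStack() as cleanup:
        cleanup.push_async_callback(say, 'killed idle JVMs')   # pushed first, runs last
        cleanup.push_async_callback(say, 'closed file store')  # pushed last, runs first
    # exiting prints: 'closed file store', then 'killed idle JVMs'

asyncio.run(main())

Note the drain above only kills JVMs sitting idle in a queue; a JVM currently borrowed by a job is in no queue and is skipped, which is the situation the review comment is pointing at.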
@@ -3000,11 +3027,6 @@ async def run_job(self, job):
             raise
         except JVMCreationError:
             self.stop_event.set()
-        except IncompleteJVMCleanupError:
-            assert isinstance(job, JVMJob)
-            assert job.jvm is not None
-            await self.recreate_jvm(job.jvm)
-            log.exception(f'while running {job}, ignoring')
         except Exception as e:
             if not user_error(e):
                 log.exception(f'while running {job}, ignoring')