Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[batch] maybe reduce average JVMJob "connecting to jvm" time #13870

Merged
merged 13 commits into from
Oct 25, 2023
Merged
108 changes: 80 additions & 28 deletions batch/batch/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
import orjson
from aiodocker.exceptions import DockerError # type: ignore
from aiohttp import web
from sortedcontainers import SortedSet

from gear import json_request, json_response
from hailtop import aiotools, httpx
Expand Down Expand Up @@ -2373,6 +2372,7 @@ async def cleanup(self):
except asyncio.CancelledError:
raise
except Exception as e:
await self.worker.return_broken_jvm(self.jvm)
raise IncompleteJVMCleanupError(
f'while unmounting fuse blob storage {bucket} from {mount_path} for {self.jvm_name} for job {self.id}'
) from e
Expand Down Expand Up @@ -2910,6 +2910,47 @@ async def get_job_resource_usage(self) -> bytes:
return await self.container.get_job_resource_usage()


class JVMPool:
global_jvm_index = 0

def __init__(self, n_cores: int, worker: Worker):
self.queue: asyncio.Queue[JVM] = asyncio.Queue()
self.total_jvms_including_borrowed = 0
self.max_jvms = CORES // n_cores
self.n_cores = n_cores
self.worker = worker

def borrow_jvm_nowait(self) -> JVM:
return self.jvms.get_nowait()

async def borrow_jvm(self) -> JVM:
return await self.jvms.get()

def return_jvm(self, jvm: JVM):
assert self.n_cores == jvm.n_cores
assert self.queue.qsize() < self.max_jvms
self.queue.put_nowait(jvm)

async def return_broken_jvm(self, jvm: JVM):
jvm.kill()
self.total_jvms_including_borrowed -= 1
await self.create_jvm()
log.info(f'killed {jvm} and recreated a new jvm')

async def create_jvm(self):
assert self.queue.qsize() < self.max_jvms
assert self.total_jvms_including_borrowed < self.max_jvms
self.queue.put_nowait(await JVM.create(JVMPool.global_jvm_index, self.n_cores, self.worker))
self.total_jvms_including_borrowed += 1
JVMPool.global_jvm_index += 1

def full(self) -> bool:
return self.total_jvms_including_borrowed == self.max_jvms

def __repr__(self):
return f'JVMPool({self.jvms!r}, {self.total_jvms_including_borrowed!r}, {self.max_jvms!r}, {self.n_cores!r})'


class Worker:
def __init__(self, client_session: httpx.ClientSession):
self.active = False
Expand Down Expand Up @@ -2942,39 +2983,54 @@ def __init__(self, client_session: httpx.ClientSession):

self.cloudfuse_mount_manager = ReadOnlyCloudfuseManager()

self._jvmpools_by_cores: Dict[int, JVMQueue] = {
n_cores: JVMPool(n_cores) for n_cores in (1, 2, 4, 8)
}
self._waiting_for_jvm_with_n_cores: asyncio.Queue[int] = asyncio.Queue()
self._jvm_initializer_task = asyncio.create_task(self._initialize_jvms())
self._jvms = SortedSet([], key=lambda jvm: jvm.n_cores)

async def _initialize_jvms(self):
assert instance_config
if instance_config.worker_type() in ('standard', 'D', 'highmem', 'E'):
jvms: List[Awaitable[JVM]] = []
for jvm_cores in (1, 2, 4, 8):
for _ in range(CORES // jvm_cores):
jvms.append(JVM.create(len(jvms), jvm_cores, self))
assert len(jvms) == N_JVM_CONTAINERS
self._jvms.update(await asyncio.gather(*jvms))
log.info(f'JVMs initialized {self._jvms}')
if instance_config.worker_type() not in ('standard', 'D', 'highmem', 'E'):
log.info(f'no JVMs initialized')

while True:
try:
n_cores = self._waiting_for_jvm_with_n_cores.get_nowait()
await self._jvmpools_by_cores[n_cores].create_jvm()
except asyncio.QueueEmpty:
next_unfull_jvmpool = None
for jvmpool in self._jvmpools_by_cores.values():
if not jvmpool.full():
next_unfull_jvmpool = jvmpool
break

if next_unfull_jvmpool is None:
break
await next_unfull_jvmpool.create_jvm()

assert self._waiting_for_jvm_with_n_cores.empty()
assert all(jvmpool.full() for jvmpool in self._jvmpools_by_cores.values())
log.info(f'JVMs initialized {self._jvmpools_by_cores}')
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jigold Thanks for pushing back! What do you think of it now?

To directly answer your question: one queue (jvmpool.queue) is a place for a consumer to borrow a JVM, the other queue (waiting_for_jvm_with_n_cores) is a place for a producer to learn that a consumer is waiting.

Without waiting_for_jvm_with_n_cores, _initialize_jvms has no way to be told that someone is waiting for a JVM. asyncio.Queue doesn't expose a method like has_waiters().


async def borrow_jvm(self, n_cores: int) -> JVM:
assert instance_config
if instance_config.worker_type() not in ('standard', 'D', 'highmem', 'E'):
raise ValueError(f'no JVMs available on {instance_config.worker_type()}')
await self._jvm_initializer_task
assert self._jvms
index = self._jvms.bisect_key_left(n_cores)
assert index < len(self._jvms), index
return self._jvms.pop(index)

jvmpool = self._jvmpools_by_cores[n_cores]
try:
return jvmpool.borrow_jvm_nowait()
except asyncio.QueueEmpty:
self._waiting_for_jvm_with_n_cores.put_nowait(n_cores)
return await jvmpool.borrow_jvm()

def return_jvm(self, jvm: JVM):
jvm.reset()
self._jvms.add(jvm)
self._jvmpools_by_cores[jvm.n_cores].return_jvm(jvm)

async def recreate_jvm(self, jvm: JVM):
self._jvms.remove(jvm)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a latent bug; the JVM is still owned by the job when recreate_jvm is called, it won't be in this array. If this every happened it would fail.

log.info(f'quarantined {jvm} and recreated a new jvm')
new_jvm = await JVM.create(jvm.index, jvm.n_cores, self)
self._jvms.add(new_jvm)
async def return_broken_jvm(self, jvm: JVM):
return await self._jvmpools_by_cores[jvm.n_cores].return_broken_jvm(jvm)

@property
def headers(self) -> Dict[str, str]:
Expand All @@ -2984,8 +3040,9 @@ async def shutdown(self):
log.info('Worker.shutdown')
self._jvm_initializer_task.cancel()
async with AsyncExitStack() as cleanup:
for jvm in self._jvms:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems a bit odd; the JVMs might still be held by jobs, right?

cleanup.push_async_callback(jvm.kill)
for jvmqueue in self._jvms_by_cores.values():
while not jvmqueue.queue.empty():
cleanup.push_async_callback(jvmqueue.queue.get_nowait().kill)
cleanup.push_async_callback(self.task_manager.shutdown_and_wait)
if self.file_store:
cleanup.push_async_callback(self.file_store.close)
Expand All @@ -3000,11 +3057,6 @@ async def run_job(self, job):
raise
except JVMCreationError:
self.stop_event.set()
except IncompleteJVMCleanupError:
assert isinstance(job, JVMJob)
assert job.jvm is not None
await self.recreate_jvm(job.jvm)
log.exception(f'while running {job}, ignoring')
except Exception as e:
if not user_error(e):
log.exception(f'while running {job}, ignoring')
Expand Down