[batch] maybe reduce average JVMJob "connecting to jvm" time #13870
Changes from 8 commits
@@ -42,7 +42,6 @@
import orjson
from aiodocker.exceptions import DockerError  # type: ignore
from aiohttp import web
from sortedcontainers import SortedSet

from gear import json_request, json_response
from hailtop import aiotools, httpx

@@ -2373,6 +2372,7 @@ async def cleanup(self):
        except asyncio.CancelledError:
            raise
        except Exception as e:
            await self.worker.return_broken_jvm(self.jvm)
            raise IncompleteJVMCleanupError(
                f'while unmounting fuse blob storage {bucket} from {mount_path} for {self.jvm_name} for job {self.id}'
            ) from e

@@ -2910,6 +2910,47 @@ async def get_job_resource_usage(self) -> bytes:
        return await self.container.get_job_resource_usage()


class JVMPool:
    global_jvm_index = 0

    def __init__(self, n_cores: int, worker: Worker):
        self.queue: asyncio.Queue[JVM] = asyncio.Queue()
        self.total_jvms_including_borrowed = 0
        self.max_jvms = CORES // n_cores
        self.n_cores = n_cores
        self.worker = worker

    def borrow_jvm_nowait(self) -> JVM:
        return self.queue.get_nowait()

    async def borrow_jvm(self) -> JVM:
        return await self.queue.get()

    def return_jvm(self, jvm: JVM):
        assert self.n_cores == jvm.n_cores
        assert self.queue.qsize() < self.max_jvms
        self.queue.put_nowait(jvm)

    async def return_broken_jvm(self, jvm: JVM):
        await jvm.kill()
        self.total_jvms_including_borrowed -= 1
        await self.create_jvm()
        log.info(f'killed {jvm} and recreated a new jvm')

    async def create_jvm(self):
        assert self.queue.qsize() < self.max_jvms
        assert self.total_jvms_including_borrowed < self.max_jvms
        self.queue.put_nowait(await JVM.create(JVMPool.global_jvm_index, self.n_cores, self.worker))
        self.total_jvms_including_borrowed += 1
        JVMPool.global_jvm_index += 1

    def full(self) -> bool:
        return self.total_jvms_including_borrowed == self.max_jvms

    def __repr__(self):
        return f'JVMPool({self.queue!r}, {self.total_jvms_including_borrowed!r}, {self.max_jvms!r}, {self.n_cores!r})'


class Worker:
    def __init__(self, client_session: httpx.ClientSession):
        self.active = False

@@ -2942,39 +2983,54 @@ def __init__(self, client_session: httpx.ClientSession):

        self.cloudfuse_mount_manager = ReadOnlyCloudfuseManager()

        self._jvmpools_by_cores: Dict[int, JVMPool] = {
            n_cores: JVMPool(n_cores, self) for n_cores in (1, 2, 4, 8)
        }
        self._waiting_for_jvm_with_n_cores: asyncio.Queue[int] = asyncio.Queue()
        self._jvm_initializer_task = asyncio.create_task(self._initialize_jvms())
        self._jvms = SortedSet([], key=lambda jvm: jvm.n_cores)

    async def _initialize_jvms(self):
        assert instance_config
        if instance_config.worker_type() in ('standard', 'D', 'highmem', 'E'):
            jvms: List[Awaitable[JVM]] = []
            for jvm_cores in (1, 2, 4, 8):
                for _ in range(CORES // jvm_cores):
                    jvms.append(JVM.create(len(jvms), jvm_cores, self))
            assert len(jvms) == N_JVM_CONTAINERS
            self._jvms.update(await asyncio.gather(*jvms))
            log.info(f'JVMs initialized {self._jvms}')
        if instance_config.worker_type() not in ('standard', 'D', 'highmem', 'E'):
            log.info('no JVMs initialized')

        while True:
            try:
                n_cores = self._waiting_for_jvm_with_n_cores.get_nowait()
                await self._jvmpools_by_cores[n_cores].create_jvm()
            except asyncio.QueueEmpty:
                next_unfull_jvmpool = None
                for jvmpool in self._jvmpools_by_cores.values():
                    if not jvmpool.full():
                        next_unfull_jvmpool = jvmpool
                        break

                if next_unfull_jvmpool is None:
                    break
                await next_unfull_jvmpool.create_jvm()

        assert self._waiting_for_jvm_with_n_cores.empty()
        assert all(jvmpool.full() for jvmpool in self._jvmpools_by_cores.values())
        log.info(f'JVMs initialized {self._jvmpools_by_cores}')

    async def borrow_jvm(self, n_cores: int) -> JVM:
        assert instance_config
        if instance_config.worker_type() not in ('standard', 'D', 'highmem', 'E'):
            raise ValueError(f'no JVMs available on {instance_config.worker_type()}')
        await self._jvm_initializer_task
        assert self._jvms
        index = self._jvms.bisect_key_left(n_cores)
        assert index < len(self._jvms), index
        return self._jvms.pop(index)

        jvmpool = self._jvmpools_by_cores[n_cores]
        try:
            return jvmpool.borrow_jvm_nowait()
        except asyncio.QueueEmpty:
            self._waiting_for_jvm_with_n_cores.put_nowait(n_cores)
            return await jvmpool.borrow_jvm()

    def return_jvm(self, jvm: JVM):
        jvm.reset()
        self._jvms.add(jvm)
        self._jvmpools_by_cores[jvm.n_cores].return_jvm(jvm)

    async def recreate_jvm(self, jvm: JVM):
        self._jvms.remove(jvm)
Reviewer comment: I think this is a latent bug; the JVM is still owned by the job when recreate_jvm is called, so it won't be in this set. If this ever happened it would fail. (A short sketch reproducing this follows the hunk below.)
            log.info(f'quarantined {jvm} and recreated a new jvm')
            new_jvm = await JVM.create(jvm.index, jvm.n_cores, self)
            self._jvms.add(new_jvm)
    async def return_broken_jvm(self, jvm: JVM):
        return await self._jvmpools_by_cores[jvm.n_cores].return_broken_jvm(jvm)

    @property
    def headers(self) -> Dict[str, str]:

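To make the latent bug above concrete, here is a minimal, hypothetical sketch (FakeJVM and the pool contents are stand-ins, not the Batch worker's real objects). It mimics the pre-PR borrow_jvm/recreate_jvm pair: borrow_jvm pops the JVM out of the SortedSet before handing it to a job, so a later recreate_jvm call tries to remove an element that is no longer in the set and raises.

from sortedcontainers import SortedSet


class FakeJVM:
    # Hypothetical stand-in for the real JVM class; only n_cores matters here.
    def __init__(self, index: int, n_cores: int):
        self.index = index
        self.n_cores = n_cores


jvms = SortedSet([], key=lambda jvm: jvm.n_cores)
jvms.update(FakeJVM(i, n) for i, n in enumerate((1, 2, 4, 8)))

# borrow_jvm pops the JVM out of the set and hands it to the job...
index = jvms.bisect_key_left(4)
borrowed = jvms.pop(index)

# ...so recreate_jvm's remove() runs against a set that no longer contains it.
try:
    jvms.remove(borrowed)
except (KeyError, ValueError) as e:  # SortedSet raises on a missing element
    print(f'latent bug reproduced: {e!r}')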
@@ -2984,8 +3040,9 @@ async def shutdown(self):
        log.info('Worker.shutdown')
        self._jvm_initializer_task.cancel()
        async with AsyncExitStack() as cleanup:
            for jvm in self._jvms:
Reviewer comment: this seems a bit odd; the JVMs might still be held by jobs, right? (A toy illustration of this follows the hunk below.)
                cleanup.push_async_callback(jvm.kill)
            for jvmpool in self._jvmpools_by_cores.values():
                while not jvmpool.queue.empty():
                    cleanup.push_async_callback(jvmpool.queue.get_nowait().kill)
            cleanup.push_async_callback(self.task_manager.shutdown_and_wait)
            if self.file_store:
                cleanup.push_async_callback(self.file_store.close)

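The reviewer's concern can be seen with a toy queue (a hedged sketch with made-up JVM names, not the worker's actual shutdown path): draining the pool's asyncio.Queue only reaches JVMs currently parked in it, so a JVM borrowed by a running job never gets a kill callback registered.

import asyncio


async def main():
    queue: asyncio.Queue[str] = asyncio.Queue()
    for name in ('jvm-0', 'jvm-1', 'jvm-2'):  # hypothetical JVM names
        queue.put_nowait(name)

    # A job borrows a JVM; it is no longer visible through the queue.
    borrowed = queue.get_nowait()

    # The shutdown pattern above drains only what is still queued.
    killed_at_shutdown = []
    while not queue.empty():
        killed_at_shutdown.append(queue.get_nowait())

    print(killed_at_shutdown)  # ['jvm-1', 'jvm-2']
    print(borrowed)            # 'jvm-0' is still held by the job, never killed here


asyncio.run(main())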
@@ -3000,11 +3057,6 @@ async def run_job(self, job):
            raise
        except JVMCreationError:
            self.stop_event.set()
        except IncompleteJVMCleanupError:
            assert isinstance(job, JVMJob)
            assert job.jvm is not None
            await self.recreate_jvm(job.jvm)
            log.exception(f'while running {job}, ignoring')
        except Exception as e:
            if not user_error(e):
                log.exception(f'while running {job}, ignoring')

Author reply: @jigold Thanks for pushing back! What do you think of it now?

To directly answer your question: one queue (jvmpool.queue) is a place for a consumer to borrow a JVM; the other queue (waiting_for_jvm_with_n_cores) is a place for the producer to learn that a consumer is waiting. Without waiting_for_jvm_with_n_cores, _initialize_jvms has no way to be told that someone is waiting for a JVM; asyncio.Queue doesn't expose a method like has_waiters().
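The two-queue handoff described in this reply can be sketched in isolation. The following is a simplified, hypothetical model (ToyJVMPool, the 8-core machine, and the sleep standing in for JVM startup are all invented for illustration, not the Batch worker's real code): a consumer first tries a non-blocking get, and only when its pool is empty does it push its core count onto the waiting queue before blocking, which lets the initializer prioritize creating a JVM of the size someone is actually waiting for.

import asyncio
from typing import Dict

CORES = 8  # assumed machine size for this sketch


class ToyJVMPool:
    def __init__(self, n_cores: int):
        self.queue: asyncio.Queue[str] = asyncio.Queue()
        self.n_cores = n_cores
        self.total = 0
        self.max_jvms = CORES // n_cores

    def full(self) -> bool:
        return self.total == self.max_jvms

    async def create_jvm(self) -> None:
        await asyncio.sleep(0.01)  # stand-in for slow JVM startup
        self.queue.put_nowait(f'jvm-{self.n_cores}c-{self.total}')
        self.total += 1


async def initialize_jvms(pools: Dict[int, ToyJVMPool], waiting: asyncio.Queue) -> None:
    # Fill every pool, but let a waiting consumer jump ahead: the waiting queue
    # is how this producer learns that someone is blocked in borrow_jvm.
    while True:
        try:
            n_cores = waiting.get_nowait()
            await pools[n_cores].create_jvm()
        except asyncio.QueueEmpty:
            unfull = next((p for p in pools.values() if not p.full()), None)
            if unfull is None:
                break
            await unfull.create_jvm()


async def borrow_jvm(pools: Dict[int, ToyJVMPool], waiting: asyncio.Queue, n_cores: int) -> str:
    pool = pools[n_cores]
    try:
        return pool.queue.get_nowait()
    except asyncio.QueueEmpty:
        # asyncio.Queue has no has_waiters(), so announce the wait explicitly.
        waiting.put_nowait(n_cores)
        return await pool.queue.get()


async def main() -> None:
    pools = {n: ToyJVMPool(n) for n in (1, 2, 4, 8)}
    waiting: asyncio.Queue = asyncio.Queue()
    initializer = asyncio.create_task(initialize_jvms(pools, waiting))
    # Served early even though the 8-core pool would otherwise be filled last.
    print('borrowed', await borrow_jvm(pools, waiting, 8))
    await initializer


asyncio.run(main())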