[batch] maybe reduce average JVMJob "connecting to jvm" time #13870
```diff
@@ -42,7 +42,6 @@
 import orjson
 from aiodocker.exceptions import DockerError  # type: ignore
 from aiohttp import web
-from sortedcontainers import SortedSet
 
 from gear import json_request, json_response
 from hailtop import aiotools, httpx
```
```diff
@@ -2373,6 +2372,7 @@ async def cleanup(self):
         except asyncio.CancelledError:
             raise
         except Exception as e:
+            self.worker.return_jvm(await self.worker.recreate_jvm(self.jvm))
             raise IncompleteJVMCleanupError(
                 f'while unmounting fuse blob storage {bucket} from {mount_path} for {self.jvm_name} for job {self.id}'
             ) from e
```
```diff
@@ -2910,6 +2910,17 @@ async def get_job_resource_usage(self) -> bytes:
         return await self.container.get_job_resource_usage()
 
 
+class JVMQueue:
+    def __init__(self, n_cores):
+        self.queue: asyncio.Queue[JVM] = asyncio.Queue()
+        self.total = 0
+        self.target = CORES // n_cores
+        self.n_cores = n_cores
+
+    def __repr__(self):
+        return f'JVMQueue({repr(self.queue)}, {self.total}, {self.target}, {self.n_cores})'
+
+
 class Worker:
     def __init__(self, client_session: httpx.ClientSession):
         self.active = False
```
```diff
@@ -2942,39 +2953,54 @@ def __init__(self, client_session: httpx.ClientSession):
 
         self.cloudfuse_mount_manager = ReadOnlyCloudfuseManager()
 
+        self._jvms_by_cores: Dict[int, JVMQueue] = {
+            n_cores: JVMQueue(n_cores) for n_cores in (1, 2, 4, 8)
+        }
+        self._jvm_waiters: asyncio.Queue[int] = asyncio.Queue()
         self._jvm_initializer_task = asyncio.create_task(self._initialize_jvms())
-        self._jvms = SortedSet([], key=lambda jvm: jvm.n_cores)
 
     async def _initialize_jvms(self):
         assert instance_config
         if instance_config.worker_type() in ('standard', 'D', 'highmem', 'E'):
-            jvms: List[Awaitable[JVM]] = []
-            for jvm_cores in (1, 2, 4, 8):
-                for _ in range(CORES // jvm_cores):
-                    jvms.append(JVM.create(len(jvms), jvm_cores, self))
-            assert len(jvms) == N_JVM_CONTAINERS
-            self._jvms.update(await asyncio.gather(*jvms))
-        log.info(f'JVMs initialized {self._jvms}')
+            global_jvm_index = 0
+            while True:
+                try:
+                    n_cores = self._jvm_waiters.get_nowait()
+                    jvmqueue = self._jvms_by_cores[n_cores]
+                    jvmqueue.queue.put_nowait(await JVM.create(global_jvm_index, n_cores, self))
+                    jvmqueue.total += 1
+                    global_jvm_index += 1
+                except asyncio.QueueEmpty:
+                    for n_cores, jvmqueue in self._jvms_by_cores.items():
+                        while jvmqueue.target != jvmqueue.total:
+                            jvmqueue.queue.put_nowait(await JVM.create(global_jvm_index, n_cores, self))
+                            jvmqueue.total += 1
+                            global_jvm_index += 1
+                    break
+            assert self._jvm_waiters.empty()
+        log.info(f'JVMs initialized {self._jvms_by_cores}')
 
     async def borrow_jvm(self, n_cores: int) -> JVM:
         assert instance_config
         if instance_config.worker_type() not in ('standard', 'D', 'highmem', 'E'):
             raise ValueError(f'no JVMs available on {instance_config.worker_type()}')
-        await self._jvm_initializer_task
-        assert self._jvms
-        index = self._jvms.bisect_key_left(n_cores)
-        assert index < len(self._jvms), index
-        return self._jvms.pop(index)
+        jvmqueue = self._jvms_by_cores[n_cores]
+        try:
+            return jvmqueue.queue.get_nowait()
+        except asyncio.QueueEmpty:
+            assert not self._jvm_initializer_task.done(), (CORES, n_cores, self._jvms_by_cores)
+            self._jvm_waiters.put_nowait(n_cores)
+            return await jvmqueue.queue.get()
 
     def return_jvm(self, jvm: JVM):
         jvm.reset()
-        self._jvms.add(jvm)
+        self._jvms_by_cores[jvm.n_cores].queue.put_nowait(jvm)
 
     async def recreate_jvm(self, jvm: JVM):
-        self._jvms.remove(jvm)
```
Review comment on `self._jvms.remove(jvm)`: I think this is a latent bug; the JVM is still owned by the job when `recreate_jvm` is called, so it won't be in this set. If this ever happened it would fail.
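To illustrate the reviewer's point, here is a minimal sketch (not the Hail code; plain ints stand in for JVM objects): once `borrow_jvm` has popped a JVM out of the `SortedSet`, a later `remove` of that same object raises `KeyError`, so the old `self._jvms.remove(jvm)` line could only fail when reached.

```python
# Minimal sketch of the latent bug described above; ints stand in for JVMs.
from sortedcontainers import SortedSet

idle = SortedSet([1, 2, 4, 8])            # idle JVMs, keyed here by core count
borrowed = idle.pop(idle.bisect_left(4))  # borrow_jvm pops the JVM out of the set

try:
    idle.remove(borrowed)                 # what the old recreate_jvm attempted
except KeyError:
    print('KeyError: a borrowed JVM is no longer in the idle set')
```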
```diff
-        log.info(f'quarantined {jvm} and recreated a new jvm')
-        new_jvm = await JVM.create(jvm.index, jvm.n_cores, self)
-        self._jvms.add(new_jvm)
+        await jvm.kill()  # is this OK to do? Seems like we ought to, no?
+        log.info(f'killed {jvm} and recreated a new jvm')
+        return await JVM.create(jvm.index, jvm.n_cores, self)
 
     @property
     def headers(self) -> Dict[str, str]:
```
```diff
@@ -2984,8 +3010,9 @@ async def shutdown(self):
         log.info('Worker.shutdown')
         self._jvm_initializer_task.cancel()
         async with AsyncExitStack() as cleanup:
-            for jvm in self._jvms:
```
Review comment on the `for jvm in self._jvms` loop: this seems a bit odd; the JVMs might still be held by jobs, right?
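A toy illustration of the gap the comment points at (invented names, not the Hail code): draining the idle queues at shutdown only reaches JVMs that have been returned, so a JVM a job is still holding never gets a kill callback registered.

```python
# Toy sketch: shutdown drains only the idle queue, missing borrowed JVMs.
import asyncio

async def main():
    idle: asyncio.Queue = asyncio.Queue()
    idle.put_nowait('jvm-0')
    borrowed = idle.get_nowait()       # a running job still holds jvm-0

    to_kill = []
    while not idle.empty():            # the shutdown loop in the diff
        to_kill.append(idle.get_nowait())
    print(to_kill, borrowed)           # [] jvm-0 -> jvm-0 is never killed

asyncio.run(main())
```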
```diff
-                cleanup.push_async_callback(jvm.kill)
+            for jvmqueue in self._jvms_by_cores.values():
+                while not jvmqueue.queue.empty():
+                    cleanup.push_async_callback(jvmqueue.queue.get_nowait().kill)
             cleanup.push_async_callback(self.task_manager.shutdown_and_wait)
             if self.file_store:
                 cleanup.push_async_callback(self.file_store.close)
```
```diff
@@ -3000,11 +3027,6 @@ async def run_job(self, job):
             raise
         except JVMCreationError:
             self.stop_event.set()
-        except IncompleteJVMCleanupError:
-            assert isinstance(job, JVMJob)
-            assert job.jvm is not None
-            await self.recreate_jvm(job.jvm)
-            log.exception(f'while running {job}, ignoring')
         except Exception as e:
             if not user_error(e):
                 log.exception(f'while running {job}, ignoring')
```
Review comment: I'm having a hard time following this code and why the outer queue of `jvm_waiters` is necessary.
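For what it's worth, here is a standalone toy of the mechanism as the diff appears to intend it (all names are invented for illustration; this is not the Hail code). JVM creation is slow and the initializer builds JVMs one at a time, so `borrow_jvm` publishes the core size it is blocked on via the waiters queue; the initializer checks that queue before continuing its background fill, so a blocked borrower is served ahead of queues nobody is waiting on.

```python
# Standalone toy of the waiters-queue idea; names invented for illustration.
import asyncio

async def create_jvm(n_cores: int) -> str:
    await asyncio.sleep(0.01)                 # stand-in for slow JVM startup
    return f'jvm-{n_cores}c'

async def borrow(queues, waiters, n_cores: int) -> str:
    try:
        return queues[n_cores].get_nowait()
    except asyncio.QueueEmpty:
        waiters.put_nowait(n_cores)           # tell the initializer what we need
        return await queues[n_cores].get()

async def initialize(queues, targets, waiters):
    totals = {n: 0 for n in queues}
    while True:
        try:
            n = waiters.get_nowait()          # a borrower is blocked: serve it first
            queues[n].put_nowait(await create_jvm(n))
            totals[n] += 1
        except asyncio.QueueEmpty:
            for n, q in queues.items():       # nobody waiting: fill toward targets
                while totals[n] != targets[n]:
                    q.put_nowait(await create_jvm(n))
                    totals[n] += 1
            break

async def main():
    queues = {n: asyncio.Queue() for n in (1, 8)}
    waiters: asyncio.Queue = asyncio.Queue()
    init = asyncio.create_task(initialize(queues, {1: 8, 8: 1}, waiters))
    # An immediate request for an 8-core JVM is satisfied without first
    # waiting for all eight 1-core JVMs to be built.
    print(await borrow(queues, waiters, 8))
    await init

asyncio.run(main())
```

Without the waiters queue, the initializer would fill the queues in a fixed order, and a job needing, say, an 8-core JVM could sit behind every smaller JVM still being created, which is presumably the "connecting to jvm" time the PR title wants to reduce.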