Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unique Worker ID #2529

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
2 changes: 2 additions & 0 deletions docs/deployment.md
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ You can also manage child processes by sending specific signals to the main proc
- `SIGTTIN`: Increase the number of worker processes by one.
- `SIGTTOU`: Decrease the number of worker processes by one.

Additionally, if the built-in process manager is used uvicorn will provide you with an unique worker ID for each worker. This worker ID will be injected into the [state](https://asgi.readthedocs.io/en/latest/specs/lifespan.html#lifespan-state) of your application as the entry `'uvicorn_worker_id': int`. This ID is consistent across restarts and enables you to define idempotent startup- and shutdown-routines for each worker process.

### Gunicorn

!!! warning
Expand Down
4 changes: 2 additions & 2 deletions tests/protocols/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -1147,8 +1147,8 @@ async def open_connection(url: str):

async def test_lifespan_state(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
expected_states: list[dict[str, typing.Any]] = [
{"a": 123, "b": [1]},
{"a": 123, "b": [1, 2]},
{"a": 123, "b": [1], "uvicorn_worker_id": 1},
{"a": 123, "b": [1, 2], "uvicorn_worker_id": 1},
]

actual_states: list[dict[str, typing.Any]] = []
Expand Down
6 changes: 3 additions & 3 deletions tests/supervisors/test_multiprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,19 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
pass # pragma: no cover


def run(sockets: list[socket.socket] | None) -> None:
def run(sockets: list[socket.socket] | None, process_num: int) -> None:
while True: # pragma: no cover
time.sleep(1)


def test_process_ping_pong() -> None:
process = Process(Config(app=app), target=lambda x: None, sockets=[])
process = Process(Config(app=app), target=lambda x, y: None, sockets=[], process_num=0)
threading.Thread(target=process.always_pong, daemon=True).start()
assert process.ping()


def test_process_ping_pong_timeout() -> None:
process = Process(Config(app=app), target=lambda x: None, sockets=[])
process = Process(Config(app=app), target=lambda x, y: None, sockets=[], process_num=0)
assert not process.ping(0.1)


Expand Down
8 changes: 7 additions & 1 deletion uvicorn/lifespan/on.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
import logging
from asyncio import Queue
from typing import Any, Union
from typing import Any, Optional, Union, cast

from uvicorn import Config
from uvicorn._types import (
Expand Down Expand Up @@ -78,6 +78,12 @@ async def shutdown(self) -> None:
async def main(self) -> None:
try:
app = self.config.loaded_app

# inject worker id into app state
uvicorn_worker_id = cast(Optional[int], self.state.get("uvicorn_worker_id"))
if uvicorn_worker_id is not None and hasattr(app.app, "state"):
app.app.state.uvicorn_worker_id = uvicorn_worker_id

scope: LifespanScope = {
"type": "lifespan",
"asgi": {"version": self.config.asgi_version, "spec_version": "2.0"},
Expand Down
18 changes: 11 additions & 7 deletions uvicorn/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,15 @@ def __init__(self, config: Config) -> None:

self._captured_signals: list[int] = []

def run(self, sockets: list[socket.socket] | None = None) -> None:
def run(self, sockets: list[socket.socket] | None = None, process_num: int = 0) -> None:
self.config.setup_event_loop()
return asyncio.run(self.serve(sockets=sockets))
return asyncio.run(self.serve(sockets=sockets, process_num=process_num))

async def serve(self, sockets: list[socket.socket] | None = None) -> None:
async def serve(self, sockets: list[socket.socket] | None = None, process_num: int = 0) -> None:
with self.capture_signals():
await self._serve(sockets)
await self._serve(sockets, process_num)

async def _serve(self, sockets: list[socket.socket] | None = None) -> None:
async def _serve(self, sockets: list[socket.socket] | None = None, process_num: int = 0) -> None:
process_id = os.getpid()

config = self.config
Expand All @@ -81,7 +81,7 @@ async def _serve(self, sockets: list[socket.socket] | None = None) -> None:
color_message = "Started server process [" + click.style("%d", fg="cyan") + "]"
logger.info(message, process_id, extra={"color_message": color_message})

await self.startup(sockets=sockets)
await self.startup(sockets=sockets, process_num=process_num)
if self.should_exit:
return
await self.main_loop()
Expand All @@ -91,7 +91,11 @@ async def _serve(self, sockets: list[socket.socket] | None = None) -> None:
color_message = "Finished server process [" + click.style("%d", fg="cyan") + "]"
logger.info(message, process_id, extra={"color_message": color_message})

async def startup(self, sockets: list[socket.socket] | None = None) -> None:
async def startup(self, sockets: list[socket.socket] | None = None, process_num: int = 0) -> None:
# inject process_num as worker id into lifespan state
worker_id = process_num + 1
self.lifespan.state["uvicorn_worker_id"] = worker_id

await self.lifespan.startup()
if self.lifespan.should_exit:
self.should_exit = True
Expand Down
26 changes: 15 additions & 11 deletions uvicorn/supervisors/multiprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ class Process:
def __init__(
self,
config: Config,
target: Callable[[list[socket] | None], None],
target: Callable[[list[socket] | None, int], None],
sockets: list[socket],
process_num: int,
) -> None:
self.process_num = process_num
self.real_target = target

self.parent_conn, self.child_conn = Pipe()
Expand Down Expand Up @@ -60,7 +62,7 @@ def target(self, sockets: list[socket] | None = None) -> Any: # pragma: no cove
)

threading.Thread(target=self.always_pong, daemon=True).start()
return self.real_target(sockets)
return self.real_target(sockets, self.process_num)

def is_alive(self, timeout: float = 5) -> bool:
if not self.process.is_alive():
Expand Down Expand Up @@ -103,14 +105,13 @@ class Multiprocess:
def __init__(
self,
config: Config,
target: Callable[[list[socket] | None], None],
target: Callable[[list[socket] | None, int], None],
sockets: list[socket],
) -> None:
self.config = config
self.target = target
self.sockets = sockets

self.processes_num = config.workers
self.processes: list[Process] = []

self.should_exit = threading.Event()
Expand All @@ -119,9 +120,13 @@ def __init__(
for sig in SIGNALS:
signal.signal(sig, lambda sig, frame: self.signal_queue.append(sig))

@property
def processes_num(self) -> int:
return len(self.processes)

def init_processes(self) -> None:
for _ in range(self.processes_num):
process = Process(self.config, self.target, self.sockets)
for process_num in range(self.config.workers):
process = Process(self.config, self.target, self.sockets, process_num)
process.start()
self.processes.append(process)

Expand All @@ -137,7 +142,7 @@ def restart_all(self) -> None:
for idx, process in enumerate(self.processes):
process.terminate()
process.join()
new_process = Process(self.config, self.target, self.sockets)
new_process = Process(self.config, self.target, self.sockets, process.process_num)
new_process.start()
self.processes[idx] = new_process

Expand Down Expand Up @@ -174,7 +179,7 @@ def keep_subprocess_alive(self) -> None:
return # pragma: full coverage

logger.info(f"Child process [{process.pid}] died")
process = Process(self.config, self.target, self.sockets)
process = Process(self.config, self.target, self.sockets, process.process_num)
process.start()
self.processes[idx] = process

Expand Down Expand Up @@ -206,8 +211,8 @@ def handle_hup(self) -> None: # pragma: py-win32

def handle_ttin(self) -> None: # pragma: py-win32
logger.info("Received SIGTTIN, increasing the number of processes.")
self.processes_num += 1
process = Process(self.config, self.target, self.sockets)
process_num = self.processes_num
process = Process(self.config, self.target, self.sockets, process_num)
process.start()
self.processes.append(process)

Expand All @@ -216,7 +221,6 @@ def handle_ttou(self) -> None: # pragma: py-win32
if self.processes_num <= 1:
logger.info("Already reached one process, cannot decrease the number of processes anymore.")
return
self.processes_num -= 1
process = self.processes.pop()
process.terminate()
process.join()
Loading