Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Major rework of shutdown and reconnect logs to avoid zombie tasks and race conditions #117

Merged
1 change: 1 addition & 0 deletions lavalink/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,5 @@
"AbortingNodeConnection",
"NodeNotReady",
"PlayerNotFound",
"wait_until_ready",
]
27 changes: 24 additions & 3 deletions lavalink/lavalink.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import discord
from discord.ext.commands import Bot

from . import enums, log, node, player_manager, utils, errors

from . import enums, log, node, player_manager, errors

__all__ = [
"initialize",
Expand All @@ -21,6 +20,7 @@
"all_players",
"all_connected_players",
"active_players",
"wait_until_ready",
]

_event_listeners = []
Expand Down Expand Up @@ -344,7 +344,7 @@ def dispatch(op: enums.LavalinkIncomingOp, data, raw_data: dict):
return

for coro in listeners:
_loop.create_task(coro(*args)).add_done_callback(utils.task_callback_trace)
_loop.create_task(coro(*args))


async def close(bot):
Expand Down Expand Up @@ -378,3 +378,24 @@ def all_connected_players() -> Tuple[player_manager.Player]:
def active_players() -> Tuple[player_manager.Player]:
ps = all_connected_players()
return tuple(p for p in ps if p.is_playing)


async def wait_until_ready(timeout: Optional[float] = None, wait_if_no_node: Optional[int] = None):
    """
    Wait until every registered node that is not yet ready becomes ready.

    Parameters
    ----------
    timeout : Optional[float]
        Forwarded unchanged to each node's ``wait_until_ready``.
    wait_if_no_node : Optional[int]
        When truthy, poll once per second, up to ``abs(wait_if_no_node)``
        times, for at least one node to be registered before giving up.

    Raises
    ------
    asyncio.TimeoutError
        If no node is registered once the optional polling has finished.
    BaseException
        The first exception collected from any node's ``wait_until_ready``.
    """
    if wait_if_no_node:
        for _ in range(abs(wait_if_no_node)):
            if node._nodes:
                break
            await asyncio.sleep(1)
    # NOTE(review): indentation was lost in the view this was reviewed from; this
    # check is assumed to sit at function level (not inside the block above) - confirm.
    if not node._nodes:
        raise asyncio.TimeoutError
    # Nodes that are already ready, or on their way out, are not waited on.
    pending = [
        node_.wait_until_ready(timeout)
        for node_ in node._nodes
        if not node_.ready and node_.state != node.NodeState.DISCONNECTING
    ]
    # return_exceptions=True lets every wait finish before the first failure is surfaced.
    for result in await asyncio.gather(*pending, return_exceptions=True):
        # Only re-raise real exceptions; raising an ordinary return value would
        # itself fail with a TypeError.
        if isinstance(result, BaseException):
            raise result
130 changes: 90 additions & 40 deletions lavalink/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,16 @@
from discord.ext.commands import Bot

from . import __version__, ws_discord_log, ws_ll_log
from .enums import *
from .enums import (
LavalinkIncomingOp,
NodeState,
LavalinkEvents,
LavalinkOutgoingOp,
DiscordVoiceSocketResponses,
)
from .player_manager import PlayerManager
from .rest_api import Track
from .utils import task_callback_exception, task_callback_debug, task_callback_trace
from .errors import NodeNotReady, NodeNotFound
from .errors import AbortingNodeConnection, NodeNotReady, NodeNotFound

__all__ = [
"Stats",
Expand Down Expand Up @@ -172,6 +177,8 @@ def __init__(
self._ws = None
self._listener_task = None
self.session = aiohttp.ClientSession()
self.reconnect_task = None
self.try_connect_task = None

self._queue: List = []

Expand Down Expand Up @@ -216,7 +223,7 @@ def _gen_key(self):
self._resume_key.__repr__()
return self._resume_key

async def connect(self, timeout=None):
async def connect(self, timeout=None, shutdown=False):
"""
Connects to the Lavalink player event websocket.

Expand All @@ -226,34 +233,29 @@ async def connect(self, timeout=None):
Time after which to timeout on attempting to connect to the Lavalink websocket,
``None`` is considered never, but the underlying code may stop trying past a
certain point.

shutdown: bool
Whether the node was told to shut down
Raises
------
asyncio.TimeoutError
If the websocket failed to connect after the given time.
AbortingNodeConnection
Drapersniper marked this conversation as resolved.
Show resolved Hide resolved
If the connection attempt must be aborted during a reconnect attempt
"""
self._is_shutdown = False
self._is_shutdown = shutdown
if self.secured:
uri = f"wss://{self.host}:{self.port}"
else:
uri = f"ws://{self.host}:{self.port}"

ws_ll_log.info("Lavalink WS connecting to %s with headers %s", uri, self.headers)

await asyncio.wait_for(self._multi_try_connect(uri), timeout)

ws_ll_log.debug("Creating Lavalink WS listener.")
if self._listener_task is not None:
self._listener_task.cancel()
self._listener_task = self.loop.create_task(self.listener())
self._listener_task.add_done_callback(task_callback_exception)
self.loop.create_task(self._configure_resume()).add_done_callback(task_callback_debug)
if self._queue:
for data in self._queue:
await self.send(data)
self._queue.clear()
self._ready_event.set()
self.update_state(NodeState.READY)
if self.try_connect_task is not None:
self.try_connect_task.cancel()
self.try_connect_task = asyncio.create_task(self._multi_try_connect(uri))
try:
await asyncio.wait_for(self.try_connect_task, timeout=timeout)
except asyncio.CancelledError:
raise AbortingNodeConnection

async def _configure_resume(self):
if self._resuming_configured:
Expand Down Expand Up @@ -301,11 +303,16 @@ def ready(self) -> bool:
async def _multi_try_connect(self, uri):
backoff = ExponentialBackoff()
attempt = 1
if self._listener_task is not None:
self._listener_task.cancel()
if self._ws is not None:
await self._ws.close(code=4006, message=b"Reconnecting")

while self._is_shutdown is False and (self._ws is None or self._ws.closed):
self._retries += 1
if self._is_shutdown is True:
ws_ll_log.error("Lavalink node was shutdown during a connect attempt.")
raise asyncio.CancelledError
try:
ws = await self.session.ws_connect(url=uri, headers=self.headers, heartbeat=60)
except (OSError, aiohttp.ClientConnectionError):
Expand All @@ -319,11 +326,28 @@ async def _multi_try_connect(self, uri):
ws_ll_log.error("Failed connect WSServerHandshakeError")
raise asyncio.TimeoutError
else:
if self._is_shutdown is True:
ws_ll_log.error("Lavalink node was shutdown during a connect attempt.")
raise asyncio.CancelledError
self.session_resumed = ws._response.headers.get("Session-Resumed", False)
if self._ws is not None and self.session_resumed:
ws_ll_log.info("WEBSOCKET Resumed Session with key: %s", self._resume_key)
self._ws = ws
break
if self._is_shutdown is True:
raise asyncio.CancelledError
ws_ll_log.info("Lavalink WS connected to %s", uri)
ws_ll_log.debug("Creating Lavalink WS listener.")
if self._is_shutdown is False:
self._listener_task = self.loop.create_task(self.listener())
self.loop.create_task(self._configure_resume())
if self._queue:
temp = self._queue.copy()
self._queue.clear()
for data in temp:
await self.send(data)
self._ready_event.set()
self.update_state(NodeState.READY)

async def listener(self):
"""
Expand All @@ -334,10 +358,12 @@ async def listener(self):
if msg.type in self._closers:
if self._resuming_configured:
if self.state != NodeState.RECONNECTING:
if self.reconnect_task is not None:
self.reconnect_task.cancel()
ws_ll_log.info("[NODE] | NODE Resuming: %s", msg.extra)
self.update_state(NodeState.RECONNECTING)
self.loop.create_task(self._reconnect()).add_done_callback(
task_callback_debug
self.reconnect_task = self.loop.create_task(
self._reconnect(self._is_shutdown)
)
return
else:
Expand All @@ -351,9 +377,7 @@ async def listener(self):
ws_ll_log.verbose("[NODE] | Received unknown op: %s", data)
else:
ws_ll_log.trace("[NODE] | Received known op: %s", data)
self.loop.create_task(self._handle_op(op, data)).add_done_callback(
task_callback_trace
)
self.loop.create_task(self._handle_op(op, data))
elif msg.type == aiohttp.WSMsgType.ERROR:
exc = self._ws.exception()
ws_ll_log.warning(
Expand All @@ -367,12 +391,14 @@ async def listener(self):
msg.type,
msg.data,
)
if self.state != NodeState.RECONNECTING:
if self.state != NodeState.RECONNECTING and not self._is_shutdown:
ws_ll_log.warning(
"[NODE] | %s - WS %s SHUTDOWN %s.", self, not self._ws.closed, self._is_shutdown
)
if self.reconnect_task is not None:
self.reconnect_task.cancel()
self.update_state(NodeState.RECONNECTING)
self.loop.create_task(self._reconnect()).add_done_callback(task_callback_debug)
self.reconnect_task = self.loop.create_task(self._reconnect(self._is_shutdown))

async def _handle_op(self, op: LavalinkIncomingOp, data):
if op == LavalinkIncomingOp.EVENT:
Expand Down Expand Up @@ -403,10 +429,10 @@ async def _handle_op(self, op: LavalinkIncomingOp, data):
else:
ws_ll_log.verbose("Unknown op type: %r", data)

async def _reconnect(self):
async def _reconnect(self, shutdown: bool = False):
self._ready_event.clear()

if self._is_shutdown is True:
if self._is_shutdown is True or shutdown:
ws_ll_log.info("[NODE] | Shutting down Lavalink WS.")
return
if self.state != NodeState.CONNECTING:
Expand All @@ -417,16 +443,22 @@ async def _reconnect(self):
attempt = 1
while self.state == NodeState.RECONNECTING:
attempt += 1
if attempt > 10:
ws_ll_log.info("[NODE] | Failed reconnection attempt too many times, aborting ...")
asyncio.create_task(self.disconnect())
return
try:
await self.connect()
await self.connect(shutdown=shutdown)
except AbortingNodeConnection:
return
except asyncio.TimeoutError:
delay = backoff.delay()
ws_ll_log.warning(
"[NODE] | Lavalink WS reconnect connect attempt %s, retrying in %s",
"[NODE] | Lavalink WS reconnect attempt %s, retrying in %s",
attempt,
delay,
)

await asyncio.sleep(delay)
else:
ws_ll_log.info("[NODE] | Reconnect successful.")
self.dispatch_reconnect()
Expand Down Expand Up @@ -457,9 +489,7 @@ def update_state(self, next_state: NodeState):
ws_ll_log.debug("Event loop closed, not notifying state handlers.")
return
for handler in self._state_handlers:
self.loop.create_task(handler(next_state, old_state)).add_done_callback(
task_callback_trace
)
self.loop.create_task(handler(next_state, old_state))

def register_state_handler(self, func):
if not asyncio.iscoroutinefunction(func):
Expand All @@ -482,27 +512,47 @@ async def disconnect(self):
"""
Shuts down and disconnects the websocket.
"""
global _nodes
self._is_shutdown = True
self._ready_event.clear()
self._queue.clear()
if (
self.try_connect_task is not None
and not self.try_connect_task.cancelled()
and not self.loop.is_closed()
):
self.try_connect_task.cancel()
if (
self.reconnect_task is not None
and not self.reconnect_task.cancelled()
and not self.loop.is_closed()
):
self.reconnect_task.cancel()

self.update_state(NodeState.DISCONNECTING)

if self._resuming_configured:
if self._resuming_configured and not (self._ws is None or self._ws.closed):
await self.send(dict(op="configureResuming", key=None))
self._resuming_configured = False
await self.player_manager.disconnect()

if self._ws is not None and not self._ws.closed:
await self._ws.close()

if self._listener_task is not None and not self.loop.is_closed():
if (
self._listener_task is not None
and not self._listener_task.cancelled()
and not self.loop.is_closed()
):
self._listener_task.cancel()

await self.session.close()

self._state_handlers = []

_nodes.remove(self)
if len(_nodes) == 1:
_nodes = []
elif len(_nodes) > 1:
_nodes.remove(self)
ws_ll_log.info("Shutdown Lavalink WS.")

async def send(self, data):
Expand Down
3 changes: 2 additions & 1 deletion lavalink/player_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ async def disconnect(self, requested=True):
Disconnects this player from its voice channel.
"""
self._is_autoplaying = False
self._is_playing = False
self._auto_play_sent = False
self._connected = False
if self.state == PlayerState.DISCONNECTING:
Expand Down Expand Up @@ -261,7 +262,6 @@ async def handle_event(self, event: "node.LavalinkEvents", extra):
log.trace("Received player event for player: %r - %r - %r.", self, event, extra)

if event == LavalinkEvents.TRACK_END:
self._is_playing = False
if extra == TrackEndReason.FINISHED:
await self.play()
elif event == LavalinkEvents.WEBSOCKET_CLOSED:
Expand Down Expand Up @@ -352,6 +352,7 @@ async def resume(
self._is_playing = False
self._paused = True
await self.node.play(self.guild.id, track, start=start, replace=replace, pause=True)
await self.set_volume(self.volume)
await self.pause(True)
await self.pause(pause, timed=1)

Expand Down
32 changes: 1 addition & 31 deletions lavalink/utils.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,6 @@
import asyncio
import contextlib

from .log import log


def format_time(time):
"""Formats the given time into HH:MM:SS"""
h, r = divmod(time / 1000, 3600)
m, s = divmod(r, 60)

return "%02d:%02d:%02d" % (h, m, s)


def task_callback_exception(task: asyncio.Task) -> None:
with contextlib.suppress(asyncio.CancelledError, asyncio.InvalidStateError):
if exc := task.exception():
log.exception("%s raised an Exception", task.get_name(), exc_info=exc)


def task_callback_debug(task: asyncio.Task) -> None:
with contextlib.suppress(asyncio.CancelledError, asyncio.InvalidStateError):
if exc := task.exception():
log.debug("%s raised an Exception", task.get_name(), exc_info=exc)


def task_callback_verbose(task: asyncio.Task) -> None:
with contextlib.suppress(asyncio.CancelledError, asyncio.InvalidStateError):
if exc := task.exception():
log.verbose("%s raised an Exception", task.get_name(), exc_info=exc)


def task_callback_trace(task: asyncio.Task) -> None:
with contextlib.suppress(asyncio.CancelledError, asyncio.InvalidStateError):
if exc := task.exception():
log.trace("%s raised an Exception", task.get_name(), exc_info=exc)
return f"{h:02d}:{m:02d}:{s:02d}"