Skip to content

Commit

Permalink
add port conflict resolution
Browse files Browse the repository at this point in the history
Signed-off-by: Draper <[email protected]>
  • Loading branch information
Drapersniper committed Apr 2, 2022
1 parent 5d3bf55 commit 3ca042f
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 8 deletions.
1 change: 1 addition & 0 deletions redbot/cogs/audio/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(self, bot: Red):
self.skip_votes = {}
self.play_lock = {}
self.antispam: Dict[int, Dict[str, AntiSpam]] = defaultdict(lambda: defaultdict(AntiSpam))
self._runtime_external_node = False

self.lavalink_connect_task = None
self._restore_task = None
Expand Down
2 changes: 2 additions & 0 deletions redbot/cogs/audio/core/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ class MixinMeta(ABC):
_disconnected_players: MutableMapping[int, bool]
global_api_user: MutableMapping[str, Any]

_runtime_external_node: bool

cog_cleaned_up: bool
lavalink_connection_aborted: bool

Expand Down
2 changes: 2 additions & 0 deletions redbot/cogs/audio/core/commands/audioset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1115,6 +1115,8 @@ async def command_audioset_settings(self, ctx: commands.Context):
lavalink_version=lavalink.__version__,
use_external_lavalink=_("Enabled")
if global_data["use_external_lavalink"]
else _("Enabled (Temporary)")
if self._runtime_external_node
else _("Disabled"),
)
if (
Expand Down
58 changes: 51 additions & 7 deletions redbot/cogs/audio/core/tasks/lavalink.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import asyncio
from pathlib import Path
import pathlib

import lavalink
import yaml
from red_commons.logging import getLogger

from redbot.core import data_manager
Expand All @@ -11,7 +12,7 @@
from ..cog_utils import CompositeMetaClass

log = getLogger("red.cogs.Audio.cog.Tasks.lavalink")
_ = Translator("Audio", Path(__file__))
_ = Translator("Audio", pathlib.Path(__file__))


class LavalinkTasks(MixinMeta, metaclass=CompositeMetaClass):
Expand Down Expand Up @@ -41,9 +42,8 @@ async def lavalink_attempt_connect(self, timeout: int = 50, manual: bool = False
if self._restore_task:
self._restore_task.cancel()
if self.managed_node_controller is not None:
if not self.managed_node_controller._shutdown:
await self.managed_node_controller.shutdown()
await asyncio.sleep(5)
await self.managed_node_controller.shutdown()
await asyncio.sleep(5)
await lavalink.close(self.bot)
while retry_count < max_retries:
configs = await self.config.all()
Expand All @@ -65,6 +65,44 @@ async def lavalink_attempt_connect(self, timeout: int = 50, manual: bool = False
except asyncio.TimeoutError:
if self.managed_node_controller is not None:
await self.managed_node_controller.shutdown()
if self._runtime_external_node is True:
log.warning("Attempting to connect to existing Lavalink Node.")
self.lavalink_connection_aborted = False
matching_processes = (
await self.managed_node_controller.get_lavalink_process(
lazy_match=True
)
)
log.debug(
"Found %s processes with lavalink in the cmdline.",
len(matching_processes),
)
valid_working_dirs = [
cwd
for d in matching_processes
if d.get("name") == "java" and (cwd := d.get("cwd"))
]
log.debug(
"Found %s java processed with a cwd set.", len(valid_working_dirs)
)
for cwd in valid_working_dirs:
config = pathlib.Path(cwd) / "application.yml"
if config.exists() and config.is_file():
log.debug(
"The following config file exists for an unmanaged Lavalink node %s",
config,
)
try:
with config.open(mode="r") as config_data:
data = yaml.safe_load(config_data)
host = data["server"]["address"]
port = data["server"]["port"]
password = data["lavalink"]["server"]["password"]
break
except Exception:
log.verbose("Failed to read contents of %s", config)
continue
break
if self.lavalink_connection_aborted is not True:
log.critical(
"Managed node startup timeout, aborting managed node startup."
Expand Down Expand Up @@ -117,9 +155,15 @@ async def lavalink_attempt_connect(self, timeout: int = 50, manual: bool = False
return
except asyncio.TimeoutError:
await lavalink.close(self.bot)
log.warning("Connecting to Lavalink node timed out, retrying...")
retry_count += 1
await asyncio.sleep(1) # prevent busylooping
if self._runtime_external_node is True:
log.warning(
"Attempt to connect to existing Lavalink node failed, aborting future reconnects."
)
self.lavalink_connection_aborted = True
return
log.warning("Connecting to Lavalink node timed out, retrying...")
await asyncio.sleep(1)
except Exception as exc:
log.exception(
"Unhandled exception whilst connecting to Lavalink node, aborting...",
Expand Down
4 changes: 4 additions & 0 deletions redbot/cogs/audio/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ class ManagedLavalinkStartFailure(ManagedLavalinkNodeException):
"""Exception thrown when a managed Lavalink node fails to start"""


class PortAlreadyInUse(ManagedLavalinkStartFailure):
    """Raised when a managed Lavalink node fails to start because its configured port is already in use."""


class ManagedLavalinkPreviouslyShutdownException(ManagedLavalinkNodeException):
"""Exception thrown when a managed Lavalink node already has been shutdown"""

Expand Down
57 changes: 56 additions & 1 deletion redbot/cogs/audio/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
IncorrectProcessFound,
NoProcessFound,
NodeUnhealthy,
PortAlreadyInUse,
)
from .utils import (
change_dict_naming_convention,
Expand Down Expand Up @@ -127,6 +128,7 @@ class ServerManager:

def __init__(self, config: Config, cog: "Audio", timeout: Optional[int] = None) -> None:
self.ready: asyncio.Event = asyncio.Event()
self.abort_for_unmanaged: asyncio.Event = asyncio.Event()
self._config = config
self._proc: Optional[asyncio.subprocess.Process] = None # pylint:disable=no-member
self._node_pid: Optional[int] = None
Expand All @@ -135,6 +137,7 @@ def __init__(self, config: Config, cog: "Audio", timeout: Optional[int] = None)
self.timeout = timeout
self.cog = cog
self._args = []
self._current_config = {}

@property
def path(self) -> Optional[str]:
Expand Down Expand Up @@ -218,6 +221,7 @@ async def _start(self, java_path: str) -> None:

async def process_settings(self):
data = change_dict_naming_convention(await self._config.yaml.all())
self._current_config = data
with open(LAVALINK_APP_YML, "w") as f:
yaml.safe_dump(data, f)

Expand Down Expand Up @@ -320,6 +324,14 @@ async def _wait_for_launcher(self) -> None:
log.info("Managed Lavalink node is ready to receive requests.")
break
if _FAILED_TO_START.search(line):
if (
f"Port {self._current_config['server']['port']} was already in use".encode()
in line
):
raise PortAlreadyInUse(
f"Port {self._current_config['server']['port']} already in use. "
f"Managed Lavalink startup aborted."
)
raise ManagedLavalinkStartFailure(
f"Lavalink failed to start: {line.decode().strip()}"
)
Expand All @@ -333,6 +345,7 @@ async def _wait_for_launcher(self) -> None:
async def shutdown(self) -> None:
    """Stop monitoring the managed node and shut it down.

    Cancels the start-monitor task when one exists, clears the
    ``abort_for_unmanaged`` event, then awaits ``_partial_shutdown()``.
    """
    monitor_task = self.start_monitor_task
    if monitor_task is not None:
        monitor_task.cancel()
    # Reset the abort flag before tearing the node down.
    self.abort_for_unmanaged.clear()
    await self._partial_shutdown()

async def _partial_shutdown(self) -> None:
Expand Down Expand Up @@ -441,8 +454,44 @@ async def maybe_download_jar(self):
if not (LAVALINK_JAR_FILE.exists() and await self._is_up_to_date()):
await self._download_jar()

@staticmethod
async def get_lavalink_process(
    *matches: str, cwd: Optional[str] = None, lazy_match: bool = False
):
    """Return details of running processes that look like Lavalink nodes.

    Parameters
    ----------
    *matches : str
        Command-line arguments that must ALL be present in a process'
        cmdline for it to be selected.
    cwd : Optional[str]
        When given, only processes whose working directory equals this
        value are considered.
    lazy_match : bool
        When True, also select any process with an argument containing
        ``"lavalink"`` (case-insensitive).

    Returns
    -------
    list of dict
        One ``psutil`` ``as_dict`` mapping per selected process, with the
        keys ``pid``, ``name``, ``create_time``, ``status``, ``cmdline``
        and ``cwd``.
    """
    process_list = []
    async for proc in AsyncIter(psutil.process_iter()):
        try:
            # Skip processes running from a different directory when a cwd filter is set.
            if cwd and proc.cwd() != cwd:
                continue
            cmdline = proc.cmdline()
            if (matches and all(a in cmdline for a in matches)) or (
                lazy_match and any("lavalink" in arg.lower() for arg in cmdline)
            ):
                process_list.append(
                    proc.as_dict(
                        attrs=["pid", "name", "create_time", "status", "cmdline", "cwd"]
                    )
                )
        except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
            # Best-effort scan: a process that cannot be inspected is simply skipped.
            continue
    return process_list

async def wait_until_ready(self, timeout: Optional[float] = None):
    """Wait until the managed node is ready, or give up.

    Waits on both the ``ready`` and ``abort_for_unmanaged`` events and
    returns as soon as either is set or the timeout elapses.

    Parameters
    ----------
    timeout : Optional[float]
        Seconds to wait; falls back to ``self.timeout`` when falsy.

    Raises
    ------
    asyncio.TimeoutError
        If ``abort_for_unmanaged`` is set (even when ``ready`` is also
        set), or if ``ready`` was not set before the timeout elapsed.
    """
    # Wait on both events at once; waiting on ``ready`` alone first would stop
    # the abort event from ever cutting the wait short.
    waiters = [
        asyncio.create_task(c) for c in [self.ready.wait(), self.abort_for_unmanaged.wait()]
    ]
    try:
        await asyncio.wait(
            waiters, timeout=timeout or self.timeout, return_when=asyncio.FIRST_COMPLETED
        )
    finally:
        # asyncio.wait never cancels its awaitables, so clean up here; this also
        # covers the caller being cancelled mid-wait. Cancelling a finished task is a no-op.
        for waiter in waiters:
            waiter.cancel()
    if self.abort_for_unmanaged.is_set() or not self.ready.is_set():
        raise asyncio.TimeoutError

async def start_monitor(self, java_path: str):
retry_count = 0
Expand Down Expand Up @@ -528,6 +577,12 @@ async def start_monitor(self, java_path: str):
log.critical(exc)
self.cog.lavalink_connection_aborted = True
return await self.shutdown()
except PortAlreadyInUse as exc:
log.critical(exc)
self.cog.lavalink_connection_aborted = False
self.cog._runtime_external_node = True
self.abort_for_unmanaged.set()
return await self.shutdown()
except ManagedLavalinkNodeException as exc:
delay = backoff.delay()
log.critical(
Expand Down

0 comments on commit 3ca042f

Please sign in to comment.