mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2026-01-30 20:13:00 -05:00
Another tiny PR :awesome:
This commit is contained in:
1
redbot/core/apis/audio/nodes/__init__.py
Normal file
1
redbot/core/apis/audio/nodes/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .events import AudioAPIEvents
|
||||
45
redbot/core/apis/audio/nodes/connection.py
Normal file
45
redbot/core/apis/audio/nodes/connection.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import typing
|
||||
|
||||
from .managed import managed_lavalink_connect_task_event
|
||||
from .. import constants
|
||||
from .. import config
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from redbot.core.bot import Red
|
||||
|
||||
__all__ = ["start_nodes"]
|
||||
|
||||
|
||||
async def start_nodes(bot: Red, identifier: typing.Optional[str] = None) -> None:
|
||||
"""Connect and initiate nodes."""
|
||||
await bot.wait_until_ready()
|
||||
if identifier:
|
||||
if bot.wavelink.nodes:
|
||||
previous = bot.wavelink.nodes.copy()
|
||||
if node := previous.get(identifier):
|
||||
await node.destroy()
|
||||
async with config._config.nodes.all() as node_data:
|
||||
if identifier in node_data:
|
||||
node_copy = copy.copy(node_data[identifier])
|
||||
elif identifier in constants.DEFAULT_COG_LAVALINK_SETTINGS:
|
||||
node_copy = copy.copy(constants.DEFAULT_COG_LAVALINK_SETTINGS[identifier])
|
||||
else:
|
||||
return
|
||||
node_copy["region"] = bot.wavelink.get_valid_region(node_copy["region"])
|
||||
await bot.wavelink.initiate_node(**node_copy)
|
||||
else:
|
||||
if bot.wavelink.nodes:
|
||||
previous = bot.wavelink.nodes.copy()
|
||||
for node in previous.values():
|
||||
await node.destroy()
|
||||
use_managed_lavalink = await config.config_cache.managed_lavalink_server.get_global()
|
||||
if use_managed_lavalink:
|
||||
await managed_lavalink_attempt_connect(timeout=120)
|
||||
await managed_lavalink_connect_task_event.wait()
|
||||
nodes = await config._config.nodes()
|
||||
for n in nodes.values():
|
||||
n["region"] = bot.wavelink.get_valid_region(n["region"])
|
||||
await bot.wavelink.initiate_node(**n)
|
||||
120
redbot/core/apis/audio/nodes/events.py
Normal file
120
redbot/core/apis/audio/nodes/events.py
Normal file
@@ -0,0 +1,120 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import wavelink
|
||||
|
||||
from redbot.core import commands
|
||||
from redbot.core.bot import Red
|
||||
from ..wavelink.events import QueueEnd
|
||||
from ..wavelink.overwrites import RedNode
|
||||
|
||||
log = logging.getLogger("red.core.apis.audio.nodes")
|
||||
|
||||
__all__ = ["AudioAPIEvents"]
|
||||
|
||||
|
||||
class AudioAPIEvents(commands.Cog, wavelink.WavelinkMixin):
|
||||
def __init__(self, bot: Red):
|
||||
self.bot = bot
|
||||
|
||||
@wavelink.WavelinkMixin.listener()
|
||||
async def on_wavelink_error(self, listener, error: Exception):
|
||||
"""Event dispatched when an error is raised during mixin listener dispatch.
|
||||
|
||||
Parameters
|
||||
------------
|
||||
listener:
|
||||
The listener where an exception was raised.
|
||||
error: Exception
|
||||
The exception raised when dispatching a mixin listener.
|
||||
"""
|
||||
log.warning(f"Ignoring exception in listener {listener}")
|
||||
traceback.print_exception(type(error), error, error.__traceback__, file=sys.stderr)
|
||||
|
||||
@wavelink.WavelinkMixin.listener()
|
||||
async def on_node_ready(self, node: RedNode):
|
||||
"""Listener dispatched when a :class:`wavelink.node.Node` is connected and ready.
|
||||
|
||||
Parameters
|
||||
------------
|
||||
node: Node
|
||||
The node associated with the listener event.
|
||||
"""
|
||||
|
||||
@wavelink.WavelinkMixin.listener()
|
||||
async def on_track_start(self, node: RedNode, payload: wavelink.TrackStart):
|
||||
"""Listener dispatched when a track starts.
|
||||
|
||||
Parameters
|
||||
------------
|
||||
node: Node
|
||||
The node associated with the listener event.
|
||||
payload: TrackStart
|
||||
The :class:`wavelink.events.TrackStart` payload.
|
||||
"""
|
||||
|
||||
@wavelink.WavelinkMixin.listener()
|
||||
async def on_track_end(self, node: RedNode, payload: wavelink.TrackEnd):
|
||||
"""Listener dispatched when a track ends.
|
||||
|
||||
Parameters
|
||||
------------
|
||||
node: Node
|
||||
The node associated with the listener event.
|
||||
payload: TrackEnd
|
||||
The :class:`wavelink.events.TrackEnd` payload.
|
||||
"""
|
||||
await payload.player.do_next()
|
||||
|
||||
@wavelink.WavelinkMixin.listener()
|
||||
async def on_track_stuck(self, node: RedNode, payload: wavelink.TrackStuck):
|
||||
"""Listener dispatched when a track is stuck.
|
||||
|
||||
Parameters
|
||||
------------
|
||||
node: Node
|
||||
The node associated with the listener event.
|
||||
payload: TrackStuck
|
||||
The :class:`wavelink.events.TrackStuck` payload.
|
||||
"""
|
||||
await payload.player.do_next()
|
||||
|
||||
@wavelink.WavelinkMixin.listener()
|
||||
async def on_track_exception(self, node: RedNode, payload: wavelink.TrackException):
|
||||
"""Listener dispatched when a track errors.
|
||||
|
||||
Parameters
|
||||
------------
|
||||
node: Node
|
||||
The node associated with the listener event.
|
||||
payload: TrackException
|
||||
The :class:`wavelink.events.TrackException` payload.
|
||||
"""
|
||||
await payload.player.do_next()
|
||||
|
||||
@wavelink.WavelinkMixin.listener()
|
||||
async def on_websocket_closed(self, node: RedNode, payload: wavelink.WebsocketClosed):
|
||||
"""Listener dispatched when a node websocket is closed by lavalink.
|
||||
|
||||
Parameters
|
||||
------------
|
||||
node: Node
|
||||
The node associated with the listener event.
|
||||
payload: WebsocketClosed
|
||||
The :class:`wavelink.events.WebsocketClosed` payload.
|
||||
"""
|
||||
|
||||
@wavelink.WavelinkMixin.listener()
|
||||
async def on_queue_end(self, node: RedNode, payload: QueueEnd):
|
||||
"""Listener dispatched when a player queue ends.
|
||||
|
||||
Parameters
|
||||
------------
|
||||
node: Node
|
||||
The node associated with the listener event.
|
||||
payload: QueueEnd
|
||||
The :class:`QueueEnd` payload.
|
||||
"""
|
||||
368
redbot/core/apis/audio/nodes/managed.py
Normal file
368
redbot/core/apis/audio/nodes/managed.py
Normal file
@@ -0,0 +1,368 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import itertools
|
||||
import logging
|
||||
import platform
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from typing import ClassVar, Optional, Tuple, List
|
||||
|
||||
import aiohttp
|
||||
from tqdm import tqdm
|
||||
|
||||
from redbot.core import Config
|
||||
from .. import constants, regex, errors
|
||||
|
||||
__all__ = [
|
||||
"managed_lavalink_connect_task_event",
|
||||
"get_latest_lavalink_release",
|
||||
"LavalinkServerManager",
|
||||
]
|
||||
|
||||
log = logging.getLogger("red.core.apis.audio.nodes.managed")
|
||||
|
||||
|
||||
managed_lavalink_connect_task_event = asyncio.Event()
|
||||
|
||||
|
||||
async def get_latest_lavalink_release():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(constants.LAVALINK_JAR_ENDPOINT) as result:
|
||||
data = await result.json()
|
||||
return (
|
||||
data.get("name"),
|
||||
data.get("tag_name"),
|
||||
next(
|
||||
(
|
||||
i.get("browser_download_url")
|
||||
for i in data.get("assets", [])
|
||||
if i.get("name") == "Lavalink.jar"
|
||||
),
|
||||
None,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class LavalinkServerManager:
|
||||
|
||||
_java_available: ClassVar[Optional[bool]] = None
|
||||
_java_version: ClassVar[Optional[Tuple[int, int]]] = None
|
||||
_up_to_date: ClassVar[Optional[bool]] = None
|
||||
_blacklisted_archs: List[str] = []
|
||||
_jar_build: ClassVar[int] = constants.JAR_BUILD
|
||||
_jar_version: ClassVar[str] = constants.JAR_VERSION
|
||||
_jar_name: ClassVar[str] = f"{constants.JAR_VERSION}_{constants.JAR_BUILD}"
|
||||
_jar_download_url: ClassVar[str] = constants.LAVALINK_DOWNLOAD_URL
|
||||
|
||||
_lavaplayer: ClassVar[Optional[str]] = None
|
||||
_lavalink_build: ClassVar[Optional[int]] = None
|
||||
_jvm: ClassVar[Optional[str]] = None
|
||||
_lavalink_branch: ClassVar[Optional[str]] = None
|
||||
_buildtime: ClassVar[Optional[str]] = None
|
||||
|
||||
_java_exc: ClassVar[str] = "java"
|
||||
|
||||
def __init__(self, config: Config) -> None:
|
||||
self.ready: asyncio.Event = asyncio.Event()
|
||||
self._proc: Optional[asyncio.subprocess.Process] = None # pylint:disable=no-member
|
||||
self._monitor_task: Optional[asyncio.Task] = None
|
||||
self._shutdown: bool = False
|
||||
self._config: Config = config
|
||||
|
||||
@property
|
||||
def jvm(self) -> Optional[str]:
|
||||
return self._jvm
|
||||
|
||||
@property
|
||||
def lavaplayer(self) -> Optional[str]:
|
||||
return self._lavaplayer
|
||||
|
||||
@property
|
||||
def ll_build(self) -> Optional[int]:
|
||||
return self._lavalink_build
|
||||
|
||||
@property
|
||||
def ll_branch(self) -> Optional[str]:
|
||||
return self._lavalink_branch
|
||||
|
||||
@property
|
||||
def build_time(self) -> Optional[str]:
|
||||
return self._buildtime
|
||||
|
||||
async def start(self, java_path: str) -> None:
|
||||
arch_name = platform.machine()
|
||||
self._java_exc = java_path
|
||||
if arch_name in self._blacklisted_archs:
|
||||
raise asyncio.CancelledError(
|
||||
"You are attempting to run Lavalink audio on an unsupported machine architecture."
|
||||
)
|
||||
|
||||
if (jar_url := await self._config.lavalink.jar_url()) is not None:
|
||||
self._jar_name = jar_url
|
||||
self._jar_download_url = jar_url
|
||||
self._jar_build = await self._config.lavalink.jar_build() or self._jar_build
|
||||
else:
|
||||
if await self._config.lavalink.autoupdate():
|
||||
with contextlib.suppress(Exception):
|
||||
name, tag, url = await get_latest_lavalink_release()
|
||||
if name and "_" in name:
|
||||
tag = name
|
||||
version, build = name.split("_")
|
||||
build = int(build)
|
||||
elif tag and "_" in tag:
|
||||
name = tag
|
||||
version, build = name.split("_")
|
||||
build = int(build)
|
||||
else:
|
||||
name = tag = version = build = None
|
||||
self._jar_name = name or tag or self._jar_name
|
||||
self._jar_download_url = url or self._jar_download_url
|
||||
self._jar_build = build or self._jar_build
|
||||
self._jar_version = version or self._jar_version
|
||||
|
||||
if self._proc is not None:
|
||||
if self._proc.returncode is None:
|
||||
raise RuntimeError("Internal Lavalink server is already running")
|
||||
elif self._shutdown:
|
||||
raise RuntimeError("Server manager has already been used - create another one")
|
||||
|
||||
await self.maybe_download_jar()
|
||||
|
||||
# Copy the application.yml across.
|
||||
# For people to customise their Lavalink server configuration they need to run it
|
||||
# externally
|
||||
shutil.copyfile(str(constants.BUNDLED_APP_YML), str(constants.LAVALINK_APP_YML))
|
||||
|
||||
args = await self._get_jar_args()
|
||||
self._proc = await asyncio.subprocess.create_subprocess_exec( # pylint:disable=no-member
|
||||
*args,
|
||||
cwd=str(constants.LAVALINK_DOWNLOAD_DIR),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.STDOUT,
|
||||
)
|
||||
|
||||
log.info("Internal Lavalink server started. PID: %s", self._proc.pid)
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(self._wait_for_launcher(), timeout=120)
|
||||
except asyncio.TimeoutError:
|
||||
log.warning("Timeout occurred whilst waiting for internal Lavalink server to be ready")
|
||||
|
||||
self._monitor_task = asyncio.create_task(self._monitor())
|
||||
|
||||
async def _get_jar_args(self) -> List[str]:
|
||||
(java_available, java_version) = await self._has_java()
|
||||
if java_version is None:
|
||||
raise RuntimeError(
|
||||
f"`{self._java_exc}` is not a valid java executable in your machine."
|
||||
)
|
||||
if not java_available:
|
||||
raise RuntimeError("You must install Java 11+ for Lavalink to run.")
|
||||
|
||||
if java_version >= (14, 0):
|
||||
raise errors.UnsupportedJavaVersion(version=java_version)
|
||||
elif java_version >= (13, 0):
|
||||
extra_flags = []
|
||||
elif java_version >= (12, 0):
|
||||
raise errors.UnsupportedJavaVersion(version=java_version)
|
||||
elif java_version >= (11, 0):
|
||||
extra_flags = ["-Djdk.tls.client.protocols=TLSv1.2"]
|
||||
else:
|
||||
raise errors.UnsupportedJavaVersion(version=java_version)
|
||||
|
||||
return [self._java_exc, *extra_flags, "-jar", str(constants.LAVALINK_JAR_FILE)]
|
||||
|
||||
async def _has_java(self) -> Tuple[bool, Optional[Tuple[int, int]]]:
|
||||
if self._java_available is not None:
|
||||
# Return cached value if we've checked this before
|
||||
return self._java_available, self._java_version
|
||||
java_available = shutil.which(self._java_exc) is not None
|
||||
if not java_available:
|
||||
self.java_available = False
|
||||
self.java_version = None
|
||||
else:
|
||||
self._java_version = version = await self._get_java_version()
|
||||
self._java_available = (11, 0) <= version < (12, 0) or (13, 0) <= version < (14, 0)
|
||||
return self._java_available, self._java_version
|
||||
|
||||
async def _get_java_version(self) -> Tuple[int, int]:
|
||||
"""This assumes we've already checked that java exists."""
|
||||
_proc: asyncio.subprocess.Process = await asyncio.create_subprocess_exec( # pylint:disable=no-member
|
||||
self._java_exc,
|
||||
"-version",
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
# java -version outputs to stderr
|
||||
_, err = await _proc.communicate()
|
||||
|
||||
version_info: str = err.decode("utf-8")
|
||||
# We expect the output to look something like:
|
||||
# $ java -version
|
||||
# ...
|
||||
# ... version "MAJOR.MINOR.PATCH[_BUILD]" ...
|
||||
# ...
|
||||
# We only care about the major and minor parts though.
|
||||
|
||||
lines = version_info.splitlines()
|
||||
for line in lines:
|
||||
match = regex.JAVA_VERSION_LINE.search(line)
|
||||
short_match = regex.JAVA_SHORT_VERSION.search(line)
|
||||
if match:
|
||||
return int(match["major"]), int(match["minor"])
|
||||
elif short_match:
|
||||
return int(short_match["major"]), 0
|
||||
|
||||
raise RuntimeError(f"The output of `{self._java_exc} -version` was unexpected.")
|
||||
|
||||
async def _wait_for_launcher(self) -> None:
|
||||
log.debug("Waiting for Lavalink server to be ready")
|
||||
lastmessage = 0
|
||||
for i in itertools.cycle(range(50)):
|
||||
line = await self._proc.stdout.readline()
|
||||
if regex.LAVALINK_READY_LINE.search(line):
|
||||
self.ready.set()
|
||||
break
|
||||
if regex.LAVALINK_READY_LINE.search(line):
|
||||
raise RuntimeError(f"Lavalink failed to start: {line.decode().strip()}")
|
||||
if self._proc.returncode is not None and lastmessage + 5 < time.perf_counter():
|
||||
# Avoid Console spam only print once every 2 seconds
|
||||
lastmessage = time.perf_counter()
|
||||
log.critical("Internal lavalink server exited early")
|
||||
if i == 49:
|
||||
# Sleep after 50 lines to prevent busylooping
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def _monitor(self) -> None:
|
||||
while self._proc.returncode is None:
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# This task hasn't been cancelled - Lavalink was shut down by something else
|
||||
log.warning("Internal Lavalink jar shutdown unexpectedly")
|
||||
if not self._has_java_error():
|
||||
log.info("Restarting internal Lavalink server")
|
||||
await self.start(self._java_exc)
|
||||
else:
|
||||
log.critical(
|
||||
"Your Java is borked. Please find the hs_err_pid%d.log file"
|
||||
" in the Audio data folder and report this issue.",
|
||||
self._proc.pid,
|
||||
)
|
||||
|
||||
def _has_java_error(self) -> bool:
|
||||
poss_error_file = constants.LAVALINK_DOWNLOAD_DIR / "hs_err_pid{}.log".format(
|
||||
self._proc.pid
|
||||
)
|
||||
return poss_error_file.exists()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if self._shutdown is True or self._proc is None:
|
||||
# For convenience, calling this method more than once or calling it before starting it
|
||||
# does nothing.
|
||||
return
|
||||
log.info("Shutting down internal Lavalink server")
|
||||
if self._monitor_task is not None:
|
||||
self._monitor_task.cancel()
|
||||
self._proc.terminate()
|
||||
await self._proc.wait()
|
||||
self._shutdown = True
|
||||
|
||||
async def _download_jar(self) -> None:
|
||||
log.info("Downloading Lavalink.jar...")
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(self._jar_download_url) as response:
|
||||
if response.status == 404:
|
||||
# A 404 means our LAVALINK_DOWNLOAD_URL is invalid, so likely the jar version
|
||||
# hasn't been published yet
|
||||
raise errors.LavalinkDownloadFailed(
|
||||
f"Lavalink jar version {self._jar_name} hasn't been published " f"yet",
|
||||
response=response,
|
||||
should_retry=False,
|
||||
)
|
||||
elif 400 <= response.status < 600:
|
||||
# Other bad responses should be raised but we should retry just incase
|
||||
raise errors.LavalinkDownloadFailed(response=response, should_retry=True)
|
||||
fd, path = tempfile.mkstemp()
|
||||
file = open(fd, "wb")
|
||||
nbytes = 0
|
||||
with tqdm(
|
||||
desc="Lavalink.jar",
|
||||
total=response.content_length,
|
||||
file=sys.stdout,
|
||||
unit="B",
|
||||
unit_scale=True,
|
||||
miniters=1,
|
||||
dynamic_ncols=True,
|
||||
leave=False,
|
||||
) as progress_bar:
|
||||
try:
|
||||
chunk = await response.content.read(1024)
|
||||
while chunk:
|
||||
chunk_size = file.write(chunk)
|
||||
nbytes += chunk_size
|
||||
progress_bar.update(chunk_size)
|
||||
chunk = await response.content.read(1024)
|
||||
file.flush()
|
||||
finally:
|
||||
file.close()
|
||||
|
||||
shutil.move(path, str(constants.LAVALINK_JAR_FILE), copy_function=shutil.copyfile)
|
||||
|
||||
log.info("Successfully downloaded Lavalink.jar (%s bytes written)", format(nbytes, ","))
|
||||
await self._is_up_to_date()
|
||||
|
||||
async def _is_up_to_date(self):
|
||||
if self._up_to_date is True:
|
||||
# Return cached value if we've checked this before
|
||||
return True
|
||||
args = await self._get_jar_args()
|
||||
args.append("--version")
|
||||
_proc = await asyncio.subprocess.create_subprocess_exec( # pylint:disable=no-member
|
||||
*args,
|
||||
cwd=str(constants.LAVALINK_DOWNLOAD_DIR),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.STDOUT,
|
||||
)
|
||||
stdout = (await _proc.communicate())[0]
|
||||
if (build := regex.LAVALINK_BUILD_LINE.search(stdout)) is None:
|
||||
# Output is unexpected, suspect corrupted jarfile
|
||||
return False
|
||||
if (branch := regex.LAVALINK_BRANCH_LINE.search(stdout)) is None:
|
||||
# Output is unexpected, suspect corrupted jarfile
|
||||
return False
|
||||
if (java := regex.LAVALINK_JAVA_LINE.search(stdout)) is None:
|
||||
# Output is unexpected, suspect corrupted jarfile
|
||||
return False
|
||||
if (lavaplayer := regex.LAVALINK_LAVAPLAYER_LINE.search(stdout)) is None:
|
||||
# Output is unexpected, suspect corrupted jarfile
|
||||
return False
|
||||
if (buildtime := regex.LAVALINK_BUILD_TIME_LINE.search(stdout)) is None:
|
||||
# Output is unexpected, suspect corrupted jarfile
|
||||
return False
|
||||
|
||||
build = int(build["build"])
|
||||
date = buildtime["build_time"].decode()
|
||||
date = date.replace(".", "/")
|
||||
self._lavalink_build = build
|
||||
self._lavalink_branch = branch["branch"].decode()
|
||||
self._jvm = java["jvm"].decode()
|
||||
self._lavaplayer = lavaplayer["lavaplayer"].decode()
|
||||
self._buildtime = date
|
||||
|
||||
self._up_to_date = build >= self._jar_build
|
||||
return self._up_to_date
|
||||
|
||||
async def maybe_download_jar(self):
|
||||
if not (constants.LAVALINK_JAR_FILE.exists() and await self._is_up_to_date()):
|
||||
await self._download_jar()
|
||||
if not await self._is_up_to_date():
|
||||
raise errors.LavalinkDownloadFailed(
|
||||
f"Download of Lavalink build {self.ll_build} from {self.ll_branch} "
|
||||
f"({self._jar_download_url}) failed, Excepted build {self._jar_build} "
|
||||
f"But downloaded {self._lavalink_build}",
|
||||
should_retry=False,
|
||||
)
|
||||
Reference in New Issue
Block a user