Update the Lavalink version parsing and add tests for it (#6093)

This commit is contained in:
Jakub Kuczys
2023-06-21 15:52:00 +02:00
committed by GitHub
parent 41204ccf77
commit 49bf103891
4 changed files with 139 additions and 27 deletions

View File

@@ -10,6 +10,7 @@ import shlex
import shutil
import tempfile
from typing import ClassVar, Final, List, Optional, Pattern, Tuple, Union, TYPE_CHECKING
from typing_extensions import Self
import aiohttp
import lavalink
@@ -47,9 +48,6 @@ if TYPE_CHECKING:
_ = Translator("Audio", pathlib.Path(__file__))
log = getLogger("red.Audio.manager")
LAVALINK_DOWNLOAD_DIR: Final[pathlib.Path] = data_manager.cog_data_path(raw_name="Audio")
LAVALINK_JAR_FILE: Final[pathlib.Path] = LAVALINK_DOWNLOAD_DIR / "Lavalink.jar"
LAVALINK_APP_YML: Final[pathlib.Path] = LAVALINK_DOWNLOAD_DIR / "application.yml"
_FAILED_TO_START: Final[Pattern] = re.compile(rb"Web server failed to start\. (.*)")
@@ -109,6 +107,9 @@ LAVALINK_VERSION_LINE_PRE35: Final[Pattern] = re.compile(
rb"^Version:\s+(?P<version>\S+)$", re.MULTILINE | re.VERBOSE
)
# used for LL 3.5-rc4 and newer
# This regex is limited to the realistic usage in the LL version number,
# not everything that could be a part of it according to the spec.
# We can easily release an update to this regex in the future if it ever becomes necessary.
LAVALINK_VERSION_LINE: Final[Pattern] = re.compile(
rb"""
^
@@ -117,9 +118,11 @@ LAVALINK_VERSION_LINE: Final[Pattern] = re.compile(
(?P<major>0|[1-9]\d*)\.(?P<minor>0|[1-9]\d*)
# Before LL 3.6, when patch version == 0, it was stripped from the version string
(?:\.(?P<patch>0|[1-9]\d*))?
(?:-rc(?P<rc>0|[1-9]\d*))?
# only used by our downstream Lavalink if we need to make a release before upstream
(?:_red(?P<red>[1-9]\d*))?
# Before LL 3.6, the dot in rc.N was optional
(?:-rc\.?(?P<rc>0|[1-9]\d*))?
# additional build metadata, can be used by our downstream Lavalink
# if we need to alter an upstream release
(?:\+red\.(?P<red>[1-9]\d*))?
)
$
""",
@@ -135,6 +138,19 @@ class LavalinkOldVersion:
def __str__(self) -> None:
return f"{self.raw_version}_{self.build_number}"
@classmethod
def from_version_output(cls, output: bytes) -> Self:
build_match = LAVALINK_BUILD_LINE.search(output)
if build_match is None:
raise ValueError("Could not find Build line in the given `--version` output.")
version_match = LAVALINK_VERSION_LINE_PRE35.search(output)
if version_match is None:
raise ValueError("Could not find Version line in the given `--version` output.")
return cls(
raw_version=version_match["version"].decode(),
build_number=int(build_match["build"]),
)
def __eq__(self, other: object) -> bool:
if isinstance(other, LavalinkOldVersion):
return self.build_number == other.build_number
@@ -195,6 +211,19 @@ class LavalinkVersion:
version += f"_red{self.red}"
return version
@classmethod
def from_version_output(cls, output: bytes) -> Self:
match = LAVALINK_VERSION_LINE.search(output)
if match is None:
raise ValueError("Could not find Version line in the given `--version` output.")
return LavalinkVersion(
major=int(match["major"]),
minor=int(match["minor"]),
patch=int(match["patch"] or 0),
rc=int(match["rc"]) if match["rc"] is not None else None,
red=int(match["red"] or 0),
)
def _get_comparison_tuple(self) -> Tuple[int, int, int, bool, int, int]:
return self.major, self.minor, self.patch, self.rc is None, self.rc or 0, self.red
@@ -265,6 +294,18 @@ class ServerManager:
self._args = []
self._pipe_task = None
@property
def lavalink_download_dir(self) -> pathlib.Path:
return data_manager.cog_data_path(raw_name="Audio")
@property
def lavalink_jar_file(self) -> pathlib.Path:
return self.lavalink_download_dir / "Lavalink.jar"
@property
def lavalink_app_yml(self) -> pathlib.Path:
return self.lavalink_download_dir / "application.yml"
@property
def path(self) -> Optional[str]:
return self._java_exc
@@ -330,7 +371,7 @@ class ServerManager:
self._proc = (
await asyncio.subprocess.create_subprocess_exec( # pylint:disable=no-member
*args,
cwd=str(LAVALINK_DOWNLOAD_DIR),
cwd=str(self.lavalink_download_dir),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
@@ -351,7 +392,7 @@ class ServerManager:
async def process_settings(self):
data = change_dict_naming_convention(await self._config.yaml.all())
with open(LAVALINK_APP_YML, "w") as f:
with open(self.lavalink_app_yml, "w") as f:
yaml.safe_dump(data, f)
async def _get_jar_args(self) -> Tuple[List[str], Optional[str]]:
@@ -392,7 +433,7 @@ class ServerManager:
"please fix this by setting the correct value with '[p]llset heapsize'.",
)
command_args.extend(["-jar", str(LAVALINK_JAR_FILE)])
command_args.extend(["-jar", str(self.lavalink_jar_file)])
self._args = command_args
return command_args, invalid
@@ -522,7 +563,7 @@ class ServerManager:
finally:
file.close()
shutil.move(path, str(LAVALINK_JAR_FILE), copy_function=shutil.copyfile)
shutil.move(path, str(self.lavalink_jar_file), copy_function=shutil.copyfile)
log.info("Successfully downloaded Lavalink.jar (%s bytes written)", format(nbytes, ","))
await self._is_up_to_date()
@@ -535,7 +576,7 @@ class ServerManager:
args.append("--version")
_proc = await asyncio.subprocess.create_subprocess_exec( # pylint:disable=no-member
*args,
cwd=str(LAVALINK_DOWNLOAD_DIR),
cwd=str(self.lavalink_download_dir),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
@@ -554,24 +595,17 @@ class ServerManager:
return False
if (build := LAVALINK_BUILD_LINE.search(stdout)) is not None:
if (version := LAVALINK_VERSION_LINE_PRE35.search(stdout)) is None:
try:
self._lavalink_version = LavalinkOldVersion.from_version_output(stdout)
except ValueError:
# Output is unexpected, suspect corrupted jarfile
return False
self._lavalink_version = LavalinkOldVersion(
raw_version=version["version"].decode(),
build_number=int(build["build"]),
)
elif (version := LAVALINK_VERSION_LINE.search(stdout)) is not None:
self._lavalink_version = LavalinkVersion(
major=int(version["major"]),
minor=int(version["minor"]),
patch=int(version["patch"] or 0),
rc=int(version["rc"]) if version["rc"] is not None else None,
red=int(version["red"] or 0),
)
else:
# Output is unexpected, suspect corrupted jarfile
return False
try:
self._lavalink_version = LavalinkVersion.from_version_output(stdout)
except ValueError:
# Output is unexpected, suspect corrupted jarfile
return False
date = buildtime["build_time"].decode()
date = date.replace(".", "/")
self._lavalink_branch = branch["branch"].decode()
@@ -582,7 +616,7 @@ class ServerManager:
return self._up_to_date
async def maybe_download_jar(self):
if not (LAVALINK_JAR_FILE.exists() and await self._is_up_to_date()):
if not (self.lavalink_jar_file.exists() and await self._is_up_to_date()):
await self._download_jar()
async def wait_until_ready(self, timeout: Optional[float] = None):