[Config] Asynchronous getters (#907)

* Make config get async

* Asyncify alias

* Asyncify bank

* Asyncify cog manager

* IT BOOTS

* Asyncify core commands

* Asyncify repo manager

* Asyncify downloader

* Asyncify economy

* Asyncify alias TESTS

* Asyncify economy TESTS

* Asyncify downloader TESTS

* Asyncify config TESTS

* A bank thing

* Asyncify Bank cog

* Warning message in docs

* Update docs with await syntax

* Update docs with await syntax
This commit is contained in:
Will
2017-08-11 21:43:21 -04:00
committed by GitHub
parent cf8e11238c
commit de912a3cfb
18 changed files with 371 additions and 296 deletions

View File

@@ -1,6 +1,6 @@
import datetime
from collections import namedtuple
from typing import Tuple, Generator, Union
from typing import Tuple, Generator, Union, List
import discord
from copy import deepcopy
@@ -78,17 +78,17 @@ def _decode_time(time: int) -> datetime.datetime:
return datetime.datetime.utcfromtimestamp(time)
def get_balance(member: discord.Member) -> int:
async def get_balance(member: discord.Member) -> int:
"""
Gets the current balance of a member.
:param member:
:return:
"""
acc = get_account(member)
acc = await get_account(member)
return acc.balance
def can_spend(member: discord.Member, amount: int) -> bool:
async def can_spend(member: discord.Member, amount: int) -> bool:
"""
Determines if a member can spend the given amount.
:param member:
@@ -97,7 +97,7 @@ def can_spend(member: discord.Member, amount: int) -> bool:
"""
if _invalid_amount(amount):
return False
return get_balance(member) > amount
return await get_balance(member) > amount
async def set_balance(member: discord.Member, amount: int) -> int:
@@ -111,17 +111,17 @@ async def set_balance(member: discord.Member, amount: int) -> int:
"""
if amount < 0:
raise ValueError("Not allowed to have negative balance.")
if is_global():
if await is_global():
group = _conf.user(member)
else:
group = _conf.member(member)
await group.balance.set(amount)
if group.created_at() == 0:
if await group.created_at() == 0:
time = _encoded_current_time()
await group.created_at.set(time)
if group.name() == "":
if await group.name() == "":
await group.name.set(member.display_name)
return amount
@@ -144,7 +144,7 @@ async def withdraw_credits(member: discord.Member, amount: int) -> int:
if _invalid_amount(amount):
raise ValueError("Invalid withdrawal amount {} <= 0".format(amount))
bal = get_balance(member)
bal = await get_balance(member)
if amount > bal:
raise ValueError("Insufficient funds {} > {}".format(amount, bal))
@@ -163,7 +163,7 @@ async def deposit_credits(member: discord.Member, amount: int) -> int:
if _invalid_amount(amount):
raise ValueError("Invalid withdrawal amount {} <= 0".format(amount))
bal = get_balance(member)
bal = await get_balance(member)
return await set_balance(member, amount + bal)
@@ -190,13 +190,13 @@ async def wipe_bank(user: Union[discord.User, discord.Member]):
Deletes all accounts from the bank.
:return:
"""
if is_global():
if await is_global():
await _conf.user(user).clear()
else:
await _conf.member(user).clear()
def get_guild_accounts(guild: discord.Guild) -> Generator[Account, None, None]:
async def get_guild_accounts(guild: discord.Guild) -> List[Account]:
"""
Gets all account data for the given guild.
@@ -207,14 +207,16 @@ def get_guild_accounts(guild: discord.Guild) -> Generator[Account, None, None]:
if is_global():
raise RuntimeError("The bank is currently global.")
accs = _conf.member(guild.owner).all_from_kind()
ret = []
accs = await _conf.member(guild.owner).all_from_kind()
for user_id, acc in accs.items():
acc_data = acc.copy() # There ya go kowlin
acc_data['created_at'] = _decode_time(acc_data['created_at'])
yield Account(**acc_data)
ret.append(Account(**acc_data))
return ret
def get_global_accounts(user: discord.User) -> Generator[Account, None, None]:
async def get_global_accounts(user: discord.User) -> List[Account]:
"""
Gets all global account data.
@@ -225,44 +227,47 @@ def get_global_accounts(user: discord.User) -> Generator[Account, None, None]:
if not is_global():
raise RuntimeError("The bank is not currently global.")
accs = _conf.user(user).all_from_kind() # this is a dict of user -> acc
ret = []
accs = await _conf.user(user).all_from_kind() # this is a dict of user -> acc
for user_id, acc in accs.items():
acc_data = acc.copy()
acc_data['created_at'] = _decode_time(acc_data['created_at'])
yield Account(**acc_data)
ret.append(Account(**acc_data))
return ret
def get_account(member: Union[discord.Member, discord.User]) -> Account:
async def get_account(member: Union[discord.Member, discord.User]) -> Account:
"""
Gets the appropriate account for the given member.
:param member:
:return:
"""
if is_global():
acc_data = _conf.user(member)().copy()
if await is_global():
acc_data = (await _conf.user(member)()).copy()
default = _DEFAULT_USER.copy()
else:
acc_data = _conf.member(member)().copy()
acc_data = (await _conf.member(member)()).copy()
default = _DEFAULT_MEMBER.copy()
if acc_data == {}:
acc_data = default
acc_data['name'] = member.display_name
try:
acc_data['balance'] = get_default_balance(member.guild)
acc_data['balance'] = await get_default_balance(member.guild)
except AttributeError:
acc_data['balance'] = get_default_balance()
acc_data['balance'] = await get_default_balance()
acc_data['created_at'] = _decode_time(acc_data['created_at'])
return Account(**acc_data)
def is_global() -> bool:
async def is_global() -> bool:
"""
Determines if the bank is currently global.
:return:
"""
return _conf.is_global()
return await _conf.is_global()
async def set_global(global_: bool, user: Union[discord.User, discord.Member]) -> bool:
@@ -272,7 +277,7 @@ async def set_global(global_: bool, user: Union[discord.User, discord.Member]) -
:param user: Must be a Member object if changing TO global mode.
:return: New bank mode, True is global.
"""
if is_global() is global_:
if (await is_global()) is global_:
return global_
if is_global():
@@ -287,7 +292,7 @@ async def set_global(global_: bool, user: Union[discord.User, discord.Member]) -
return global_
def get_bank_name(guild: discord.Guild=None) -> str:
async def get_bank_name(guild: discord.Guild=None) -> str:
"""
Gets the current bank name. If the bank is guild-specific the
guild parameter is required.
@@ -296,10 +301,10 @@ def get_bank_name(guild: discord.Guild=None) -> str:
:param guild:
:return:
"""
if is_global():
return _conf.bank_name()
if await is_global():
return await _conf.bank_name()
elif guild is not None:
return _conf.guild(guild).bank_name()
return await _conf.guild(guild).bank_name()
else:
raise RuntimeError("Guild parameter is required and missing.")
@@ -314,7 +319,7 @@ async def set_bank_name(name: str, guild: discord.Guild=None) -> str:
:param guild:
:return:
"""
if is_global():
if await is_global():
await _conf.bank_name.set(name)
elif guild is not None:
await _conf.guild(guild).bank_name.set(name)
@@ -324,7 +329,7 @@ async def set_bank_name(name: str, guild: discord.Guild=None) -> str:
return name
def get_currency_name(guild: discord.Guild=None) -> str:
async def get_currency_name(guild: discord.Guild=None) -> str:
"""
Gets the currency name of the bank. The guild parameter is required if
the bank is guild-specific.
@@ -333,10 +338,10 @@ def get_currency_name(guild: discord.Guild=None) -> str:
:param guild:
:return:
"""
if is_global():
return _conf.currency()
if await is_global():
return await _conf.currency()
elif guild is not None:
return _conf.guild(guild).currency()
return await _conf.guild(guild).currency()
else:
raise RuntimeError("Guild must be provided.")
@@ -351,7 +356,7 @@ async def set_currency_name(name: str, guild: discord.Guild=None) -> str:
:param guild:
:return:
"""
if is_global():
if await is_global():
await _conf.currency.set(name)
elif guild is not None:
await _conf.guild(guild).currency.set(name)
@@ -361,7 +366,7 @@ async def set_currency_name(name: str, guild: discord.Guild=None) -> str:
return name
def get_default_balance(guild: discord.Guild=None) -> int:
async def get_default_balance(guild: discord.Guild=None) -> int:
"""
Gets the current default balance amount. If the bank is guild-specific
you must pass guild.
@@ -370,10 +375,10 @@ def get_default_balance(guild: discord.Guild=None) -> int:
:param guild:
:return:
"""
if is_global():
return _conf.default_balance()
if await is_global():
return await _conf.default_balance()
elif guild is not None:
return _conf.guild(guild).default_balance()
return await _conf.guild(guild).default_balance()
else:
raise RuntimeError("Guild is missing and required!")
@@ -393,9 +398,11 @@ async def set_default_balance(amount: int, guild: discord.Guild=None) -> int:
if amount < 0:
raise ValueError("Amount must be greater than zero.")
if is_global():
if await is_global():
await _conf.default_balance.set(amount)
elif guild is not None:
await _conf.guild(guild).default_balance.set(amount)
else:
raise RuntimeError("Guild is missing and required.")
return amount

View File

@@ -1,3 +1,4 @@
import asyncio
import importlib.util
from importlib.machinery import ModuleSpec
@@ -39,14 +40,14 @@ class Red(commands.Bot):
mod_role=None
)
def prefix_manager(bot, message):
async def prefix_manager(bot, message):
if not cli_flags.prefix:
global_prefix = self.db.prefix()
global_prefix = await bot.db.prefix()
else:
global_prefix = cli_flags.prefix
if message.guild is None:
return global_prefix
server_prefix = self.db.guild(message.guild).prefix()
server_prefix = await bot.db.guild(message.guild).prefix()
return server_prefix if server_prefix else global_prefix
if "command_prefix" not in kwargs:
@@ -56,7 +57,8 @@ class Red(commands.Bot):
kwargs["owner_id"] = cli_flags.owner
if "owner_id" not in kwargs:
kwargs["owner_id"] = self.db.owner()
loop = asyncio.get_event_loop()
loop.run_until_complete(self._dict_abuse(kwargs))
self.counter = Counter()
self.uptime = None
@@ -68,6 +70,15 @@ class Red(commands.Bot):
super().__init__(**kwargs)
async def _dict_abuse(self, indict):
"""
Please blame <@269933075037814786> for this.
:param indict:
:return:
"""
indict['owner_id'] = await self.db.owner()
async def is_owner(self, user):
if user.id in self._co_owners:
return True
@@ -103,13 +114,13 @@ class Red(commands.Bot):
await self.db.packages.set(packages)
async def add_loaded_package(self, pkg_name: str):
curr_pkgs = self.db.packages()
curr_pkgs = await self.db.packages()
if pkg_name not in curr_pkgs:
curr_pkgs.append(pkg_name)
await self.save_packages_status(curr_pkgs)
async def remove_loaded_package(self, pkg_name: str):
curr_pkgs = self.db.packages()
curr_pkgs = await self.db.packages()
if pkg_name in curr_pkgs:
await self.save_packages_status([p for p in curr_pkgs if p != pkg_name])

View File

@@ -31,26 +31,29 @@ class CogManager:
install_path=str(bot_dir.resolve() / "cogs")
)
self._paths = set(list(self.conf.paths()) + list(paths))
self._paths = list(paths)
@property
def paths(self) -> Tuple[Path, ...]:
async def paths(self) -> Tuple[Path, ...]:
"""
This will return all currently valid path directories.
:return:
"""
paths = [Path(p) for p in self._paths]
conf_paths = await self.conf.paths()
other_paths = self._paths
all_paths = set(list(conf_paths) + list(other_paths))
paths = [Path(p) for p in all_paths]
if self.install_path not in paths:
paths.insert(0, self.install_path)
paths.insert(0, await self.install_path())
return tuple(p.resolve() for p in paths if p.is_dir())
@property
def install_path(self) -> Path:
async def install_path(self) -> Path:
"""
Returns the install path for 3rd party cogs.
:return:
"""
p = Path(self.conf.install_path())
p = Path(await self.conf.install_path())
return p.resolve()
async def set_install_path(self, path: Path) -> Path:
@@ -99,10 +102,10 @@ class CogManager:
if not path.is_dir():
raise InvalidPath("'{}' is not a valid directory.".format(path))
if path == self.install_path:
if path == await self.install_path():
raise ValueError("Cannot add the install path as an additional path.")
all_paths = set(self.paths + (path, ))
all_paths = set(await self.paths() + (path, ))
# noinspection PyTypeChecker
await self.set_paths(all_paths)
@@ -113,7 +116,7 @@ class CogManager:
:return:
"""
path = self._ensure_path_obj(path)
all_paths = list(self.paths)
all_paths = list(await self.paths())
if path in all_paths:
all_paths.remove(path) # Modifies in place
await self.set_paths(all_paths)
@@ -125,11 +128,10 @@ class CogManager:
:param paths_:
:return:
"""
self._paths = paths_
str_paths = [str(p) for p in paths_]
await self.conf.paths.set(str_paths)
def find_cog(self, name: str) -> ModuleSpec:
async def find_cog(self, name: str) -> ModuleSpec:
"""
Finds a cog in the list of available path.
@@ -137,7 +139,7 @@ class CogManager:
:param name:
:return:
"""
resolved_paths = [str(p.resolve()) for p in self.paths]
resolved_paths = [str(p.resolve()) for p in await self.paths()]
for finder, module_name, _ in pkgutil.iter_modules(resolved_paths):
if name == module_name:
spec = finder.find_spec(name)
@@ -166,7 +168,7 @@ class CogManagerUI:
"""
Lists current cog paths in order of priority.
"""
install_path = ctx.bot.cog_mgr.install_path
install_path = await ctx.bot.cog_mgr.install_path()
cog_paths = ctx.bot.cog_mgr.paths
cog_paths = [p for p in cog_paths if p != install_path]
@@ -204,7 +206,7 @@ class CogManagerUI:
Removes a path from the available cog paths given the path_number
from !paths
"""
cog_paths = ctx.bot.cog_mgr.paths
cog_paths = await ctx.bot.cog_mgr.paths()
try:
to_remove = cog_paths[path_number]
except IndexError:
@@ -224,7 +226,7 @@ class CogManagerUI:
from_ -= 1
to -= 1
all_paths = list(ctx.bot.cog_mgr.paths)
all_paths = list(await ctx.bot.cog_mgr.paths())
try:
to_move = all_paths.pop(from_)
except IndexError:
@@ -257,6 +259,6 @@ class CogManagerUI:
await ctx.send("That path does not exist.")
return
install_path = ctx.bot.cog_mgr.install_path
install_path = await ctx.bot.cog_mgr.install_path()
await ctx.send("The bot will install new cogs to the `{}`"
" directory.".format(install_path))

View File

@@ -38,6 +38,14 @@ class Value:
def identifiers(self):
return tuple(str(i) for i in self._identifiers)
async def _get(self, default):
driver = self.spawner.get_driver()
try:
ret = await driver.get(self.identifiers)
except KeyError:
return default or self.default
return ret
def __call__(self, default=None):
"""
Each :py:class:`Value` object is created by the :py:meth:`Group.__getattr__` method.
@@ -46,25 +54,26 @@ class Value:
For example::
foo = conf.guild(some_guild).foo()
foo = await conf.guild(some_guild).foo()
# Is equivalent to this
group_obj = conf.guild(some_guild)
value_obj = conf.foo
foo = value_obj()
foo = await value_obj()
.. important::
This is now, for all intents and purposes, a coroutine.
:param default:
This argument acts as an override for the registered default provided by :py:attr:`default`. This argument
is ignored if its value is :python:`None`.
:type default: Optional[object]
:return:
A coroutine object that must be awaited.
"""
driver = self.spawner.get_driver()
try:
ret = driver.get(self.identifiers)
except KeyError:
return default or self.default
return ret
return self._get(default)
async def set(self, value):
"""
@@ -182,7 +191,7 @@ class Group(Value):
return not isinstance(default, dict)
def get_attr(self, item: str, default=None, resolve=True):
async def get_attr(self, item: str, default=None, resolve=True):
"""
This is available to use as an alternative to using normal Python attribute access. It is required if you find
a need for dynamic attribute access.
@@ -198,7 +207,7 @@ class Group(Value):
user = ctx.author
# Where the value of item is the name of the data field in Config
await ctx.send(self.conf.user(user).get_attr(item))
await ctx.send(await self.conf.user(user).get_attr(item))
:param str item:
The name of the data field in :py:class:`.Config`.
@@ -211,20 +220,20 @@ class Group(Value):
"""
value = getattr(self, item)
if resolve:
return value(default=default)
return await value(default=default)
else:
return value
def all(self) -> dict:
async def all(self) -> dict:
"""
This method allows you to get "all" of a particular group of data. It will return the dictionary of all data
for a particular Guild/Channel/Role/User/Member etc.
:rtype: dict
"""
return self()
return await self()
def all_from_kind(self) -> dict:
async def all_from_kind(self) -> dict:
"""
This method allows you to get all data from all entries in a given Kind. It will return a dictionary of Kind
ID's -> data.
@@ -232,7 +241,7 @@ class Group(Value):
:rtype: dict
"""
# noinspection PyTypeChecker
return self._super_group()
return await self._super_group()
async def set(self, value):
if not isinstance(value, dict):
@@ -292,18 +301,18 @@ class MemberGroup(Group):
)
return group_obj
def all_guilds(self) -> dict:
async def all_guilds(self) -> dict:
"""
Returns a dict of :code:`GUILD_ID -> MEMBER_ID -> data`.
:rtype: dict
"""
# noinspection PyTypeChecker
return self._super_group()
return await self._super_group()
def all(self) -> dict:
async def all(self) -> dict:
# noinspection PyTypeChecker
return self._guild_group()
return await self._guild_group()
class Config:
@@ -315,7 +324,7 @@ class Config:
however the process for accessing global data is a bit different. There is no :python:`global` method
because global data is accessed by normal attribute access::
conf.foo()
await conf.foo()
.. py:attribute:: cog_name

View File

@@ -29,7 +29,7 @@ class Core:
async def load(self, ctx, *, cog_name: str):
"""Loads a package"""
try:
spec = ctx.bot.cog_mgr.find_cog(cog_name)
spec = await ctx.bot.cog_mgr.find_cog(cog_name)
except NoModuleFound:
await ctx.send("No module by that name was found in any"
" cog path.")
@@ -63,7 +63,7 @@ class Core:
ctx.bot.unload_extension(cog_name)
self.cleanup_and_refresh_modules(cog_name)
try:
spec = ctx.bot.cog_mgr.find_cog(cog_name)
spec = await ctx.bot.cog_mgr.find_cog(cog_name)
ctx.bot.load_extension(spec)
except Exception as e:
log.exception("Package reloading failed", exc_info=e)

View File

@@ -5,7 +5,7 @@ class BaseDriver:
def get_driver(self):
raise NotImplementedError
def get(self, identifiers: Tuple[str]):
async def get(self, identifiers: Tuple[str]):
raise NotImplementedError
async def set(self, identifiers: Tuple[str], value):

View File

@@ -32,7 +32,7 @@ class JSON(BaseDriver):
def get_driver(self):
return self
def get(self, identifiers: Tuple[str]):
async def get(self, identifiers: Tuple[str]):
partial = self.data
for i in identifiers:
partial = partial[i]

View File

@@ -34,11 +34,11 @@ def init_events(bot, cli_flags):
if cli_flags.no_cogs is False:
print("Loading packages...")
failed = []
packages = bot.db.packages()
packages = await bot.db.packages()
for package in packages:
try:
spec = bot.cog_mgr.find_cog(package)
spec = await bot.cog_mgr.find_cog(package)
bot.load_extension(spec)
except Exception as e:
log.exception("Failed to load package {}".format(package),