[Config] Rewrite (#869)

This commit is contained in:
Will
2017-07-30 19:40:31 -04:00
committed by Twentysix
parent 5c2be25dfc
commit 99bfb2fc7a
14 changed files with 636 additions and 724 deletions

View File

@@ -47,7 +47,7 @@ class Red(commands.Bot):
kwargs["owner_id"] = cli_flags.owner
if "owner_id" not in kwargs:
kwargs["owner_id"] = self.db.get("owner")
kwargs["owner_id"] = self.db.owner()
self.counter = Counter()
self.uptime = None
@@ -89,7 +89,7 @@ class Red(commands.Bot):
for package in self.extensions:
if package.startswith("cogs."):
loaded.append(package)
await self.db.set("packages", loaded)
await self.db.packages.set(loaded)
class ExitCodes(Enum):

View File

@@ -22,7 +22,7 @@ def interactive_config(red, token_set, prefix_set):
print("That doesn't look like a valid token.")
token = ""
if token:
loop.run_until_complete(red.db.set("token", token))
loop.run_until_complete(red.db.token.set(token))
if not prefix_set:
prefix = ""
@@ -39,7 +39,7 @@ def interactive_config(red, token_set, prefix_set):
if not confirm("> "):
prefix = ""
if prefix:
loop.run_until_complete(red.db.set("prefix", [prefix]))
loop.run_until_complete(red.db.prefix.set([prefix]))
ask_sentry(red)
@@ -55,9 +55,9 @@ def ask_sentry(red: Red):
" found issues in a timely manner. If you wish to opt in\n"
" the process please type \"yes\":\n")
if not confirm("> "):
loop.run_until_complete(red.db.set("enable_sentry", False))
loop.run_until_complete(red.db.enable_sentry.set(False))
else:
loop.run_until_complete(red.db.set("enable_sentry", True))
loop.run_until_complete(red.db.enable_sentry.set(True))
print("\nThank you for helping us with the development process!")

View File

@@ -1,521 +1,374 @@
from pathlib import Path
from core.drivers.red_json import JSON as JSONDriver
from core.drivers.red_mongo import Mongo
import logging
from typing import Callable
from typing import Callable, Union, Tuple
import discord
from copy import deepcopy
from pathlib import Path
from .drivers.red_json import JSON as JSONDriver
log = logging.getLogger("red.config")
class BaseConfig:
def __init__(self, cog_name, unique_identifier, driver_spawn, force_registration=False,
hash_uuid=True, collection="GLOBAL", collection_uuid=None,
defaults={}):
self.cog_name = cog_name
if hash_uuid:
self.uuid = str(hash(unique_identifier))
else:
self.uuid = unique_identifier
self.driver_spawn = driver_spawn
self._driver = None
self.collection = collection
self.collection_uuid = collection_uuid
self.force_registration = force_registration
class Value:
def __init__(self, identifiers: Tuple[str], default_value, spawner):
self._identifiers = identifiers
self.default = default_value
self.spawner = spawner
@property
def identifiers(self):
return tuple(str(i) for i in self._identifiers)
def __call__(self, default=None):
driver = self.spawner.get_driver()
try:
self.driver.maybe_add_ident(self.uuid)
except AttributeError:
pass
ret = driver.get(self.identifiers)
except KeyError:
return default or self.default
return ret
self.driver_getmap = {
"GLOBAL": self.driver.get_global,
"GUILD": self.driver.get_guild,
"CHANNEL": self.driver.get_channel,
"ROLE": self.driver.get_role,
"USER": self.driver.get_user
}
async def set(self, value):
driver = self.spawner.get_driver()
await driver.set(self.identifiers, value)
self.driver_setmap = {
"GLOBAL": self.driver.set_global,
"GUILD": self.driver.set_guild,
"CHANNEL": self.driver.set_channel,
"ROLE": self.driver.set_role,
"USER": self.driver.set_user
}
self.curr_key = None
class Group(Value):
def __init__(self, identifiers: Tuple[str],
defaults: dict,
spawner,
force_registration: bool=False):
self.defaults = defaults
self.force_registration = force_registration
self.spawner = spawner
self.unsettable_keys = ("cog_name", "cog_identifier", "_id",
"guild_id", "channel_id", "role_id",
"user_id", "uuid")
self.invalid_keys = (
"driver_spawn",
"_driver", "collection",
"collection_uuid", "force_registration"
super().__init__(identifiers, {}, self.spawner)
# noinspection PyTypeChecker
def __getattr__(self, item: str) -> Union["Group", Value]:
"""
Takes in the next accessible item. If it's found to be a Group
we return another Group object. If it's found to be a Value
we return a Value object. If it is not found and
force_registration is True then we raise AttributeException,
otherwise return a Value object.
:param item:
:return:
"""
is_group = self.is_group(item)
is_value = not is_group and self.is_value(item)
new_identifiers = self.identifiers + (item, )
if is_group:
return Group(
identifiers=new_identifiers,
defaults=self.defaults[item],
spawner=self.spawner,
force_registration=self.force_registration
)
elif is_value:
return Value(
identifiers=new_identifiers,
default_value=self.defaults[item],
spawner=self.spawner
)
elif self.force_registration:
raise AttributeError(
"'{}' is not a valid registered Group"
"or value.".format(item)
)
else:
return Value(
identifiers=new_identifiers,
default_value=None,
spawner=self.spawner
)
@property
def _super_group(self) -> 'Group':
super_group = Group(
self.identifiers[:-1],
defaults={},
spawner=self.spawner,
force_registration=self.force_registration
)
return super_group
self.defaults = defaults if defaults else {
"GLOBAL": {}, "GUILD": {}, "CHANNEL": {}, "ROLE": {},
"MEMBER": {}, "USER": {}}
def is_group(self, item: str) -> bool:
"""
Determines if an attribute access is pointing at a registered group.
:param item:
:return:
"""
default = self.defaults.get(item)
return isinstance(default, dict)
def is_value(self, item: str) -> bool:
"""
Determines if an attribute access is pointing at a registered value.
:param item:
:return:
"""
try:
default = self.defaults[item]
except KeyError:
return False
return not isinstance(default, dict)
def get_attr(self, item: str, default=None):
"""
You should avoid this function whenever possible.
:param item:
:param default:
:return:
"""
value = getattr(self, item)
return value(default=default)
def all(self) -> dict:
"""
Gets all entries of the given kind. If this kind is member
then this method returns all members from the same
server.
:return:
"""
# noinspection PyTypeChecker
return self._super_group()
async def set(self, value):
if not isinstance(value, dict):
raise ValueError(
"You may only set the value of a group to be a dict."
)
await super().set(value)
async def set_attr(self, item: str, value):
"""
You should avoid this function whenever possible.
:param item:
:param value:
:return:
"""
value_obj = getattr(self, item)
await value_obj.set(value)
async def clear(self):
"""
Wipes out data for the given entry in this category
e.g. Guild/Role/User
:return:
"""
await self.set({})
async def clear_all(self):
"""
Removes all data from all entries.
:return:
"""
await self._super_group.set({})
class MemberGroup(Group):
@property
def _super_group(self) -> Group:
new_identifiers = self.identifiers[:2]
group_obj = Group(
identifiers=new_identifiers,
defaults={},
spawner=self.spawner
)
return group_obj
@property
def _guild_group(self) -> Group:
new_identifiers = self.identifiers[:3]
group_obj = Group(
identifiers=new_identifiers,
defaults={},
spawner=self.spawner
)
return group_obj
def all_guilds(self) -> dict:
"""
Gets a dict of all guilds and members.
REMEMBER: ID's are stored in these dicts as STRINGS.
:return:
"""
# noinspection PyTypeChecker
return self._super_group()
def all(self) -> dict:
"""
Returns the dict of all members in the same guild.
:return:
"""
# noinspection PyTypeChecker
return self._guild_group()
class Config:
GLOBAL = "GLOBAL"
GUILD = "GUILD"
CHANNEL = "TEXTCHANNEL"
ROLE = "ROLE"
USER = "USER"
MEMBER = "MEMBER"
def __init__(self, cog_name: str, unique_identifier: str,
driver_spawn: Callable,
force_registration: bool=False,
defaults: dict=None):
self.cog_name = cog_name
self.unique_identifier = unique_identifier
self.spawner = driver_spawn
self.force_registration = force_registration
self.defaults = defaults or {}
@classmethod
def get_conf(cls, cog_instance: object, unique_identifier: int=0,
force_registration: bool=False):
def get_conf(cls, cog_instance, identifier: int,
force_registration=False):
"""
Gets a config object that cog's can use to safely store data. The
backend to this is totally modular and can easily switch between
JSON and a DB. However, when changed, all data will likely be lost
unless cogs write some converters for their data.
Positional Arguments:
cog_instance - The cog `self` object, can be passed in from your
cog's __init__ method.
Keyword Arguments:
unique_identifier - a random integer or string that is used to
differentiate your cog from any other named the same. This way we
can safely store data for multiple cogs that are named the same.
YOU SHOULD USE THIS.
force_registration - A flag which will cause the Config object to
throw exceptions if you try to get/set data keys that you have
not pre-registered. I highly recommend you ENABLE this as it
will help reduce dumb typo errors.
Returns a Config instance based on a simplified set of initial
variables.
:param cog_instance:
:param identifier: Any random integer, used to keep your data
distinct from any other cog with the same name.
:param force_registration: Should config require registration
of data keys before allowing you to get/set values?
:return:
"""
url = None # TODO: get mongo url
port = None # TODO: get mongo port
def spawn_mongo_driver():
return Mongo(url, port)
# TODO: Determine which backend users want, default to JSON
cog_name = cog_instance.__class__.__name__
uuid = str(hash(identifier))
driver_spawn = JSONDriver(cog_name)
return cls(cog_name=cog_name, unique_identifier=unique_identifier,
driver_spawn=driver_spawn, force_registration=force_registration)
spawner = JSONDriver(cog_name)
return cls(cog_name=cog_name, unique_identifier=uuid,
force_registration=force_registration,
driver_spawn=spawner)
@classmethod
def get_core_conf(cls, force_registration: bool=False):
core_data_path = Path.cwd() / 'core' / '.data'
driver_spawn = JSONDriver("Core", data_path_override=core_data_path)
return cls(cog_name="Core", driver_spawn=driver_spawn,
unique_identifier=0,
unique_identifier='0',
force_registration=force_registration)
@property
def driver(self):
if self._driver is None:
try:
self._driver = self.driver_spawn()
except TypeError:
return self.driver_spawn
def __getattr__(self, item: str) -> Union[Group, Value]:
"""
This is used to generate Value or Group objects for global
values.
:param item:
:return:
"""
global_group = self._get_base_group(self.GLOBAL)
return getattr(global_group, item)
return self._driver
def __getattr__(self, key):
"""This should be used to return config key data as determined by
`self.collection` and `self.collection_uuid`."""
raise NotImplemented
def __setattr__(self, key, value):
if 'defaults' in self.__dict__: # Necessary to let the cog load
restricted = list(self.defaults[self.collection].keys()) + \
list(self.unsettable_keys)
if key in restricted:
raise ValueError("Not allowed to dynamically set attributes of"
" unsettable_keys: {}".format(restricted))
@staticmethod
def _get_defaults_dict(key: str, value) -> dict:
"""
Since we're allowing nested config stuff now, not storing the
defaults as a flat dict sounds like a good idea. May turn
out to be an awful one but we'll see.
:param key:
:param value:
:return:
"""
ret = {}
partial = ret
splitted = key.split('__')
for i, k in enumerate(splitted, start=1):
if not k.isidentifier():
raise RuntimeError("'{}' is an invalid config key.".format(k))
if i == len(splitted):
partial[k] = value
else:
self.__dict__[key] = value
else:
self.__dict__[key] = value
def clear(self):
"""Clears all values in the current context ONLY."""
raise NotImplemented
def set(self, key, value):
"""This should set config key with value `value` in the
corresponding collection as defined by `self.collection` and
`self.collection_uuid`."""
raise NotImplemented
def guild(self, guild):
"""This should return a `BaseConfig` instance with the corresponding
`collection` and `collection_uuid`."""
raise NotImplemented
def channel(self, channel):
"""This should return a `BaseConfig` instance with the corresponding
`collection` and `collection_uuid`."""
raise NotImplemented
def role(self, role):
"""This should return a `BaseConfig` instance with the corresponding
`collection` and `collection_uuid`."""
raise NotImplemented
def member(self, member):
"""This should return a `BaseConfig` instance with the corresponding
`collection` and `collection_uuid`."""
raise NotImplemented
def user(self, user):
"""This should return a `BaseConfig` instance with the corresponding
`collection` and `collection_uuid`."""
raise NotImplemented
def register_global(self, **global_defaults):
"""
Registers a new dict of global defaults. This function should
be called EVERY TIME the cog loads (aka just do it in
__init__)!
:param global_defaults: Each key should be the key you want to
access data by and the value is the default value of that
key.
:return:
"""
for k, v in global_defaults.items():
try:
self._register_global(k, v)
except KeyError:
log.exception("Bad default global key.")
def _register_global(self, key, default=None):
"""Registers a global config key `key`"""
if key in self.unsettable_keys:
raise KeyError("Attempt to use restricted key: '{}'".format(key))
elif not key.isidentifier():
raise RuntimeError("Invalid key name, must be a valid python variable"
" name.")
self.defaults["GLOBAL"][key] = default
def register_guild(self, **guild_defaults):
"""
Registers a new dict of guild defaults. This function should
be called EVERY TIME the cog loads (aka just do it in
__init__)!
:param guild_defaults: Each key should be the key you want to
access data by and the value is the default value of that
key.
:return:
"""
for k, v in guild_defaults.items():
try:
self._register_guild(k, v)
except KeyError:
log.exception("Bad default guild key.")
def _register_guild(self, key, default=None):
"""Registers a guild config key `key`"""
if key in self.unsettable_keys:
raise KeyError("Attempt to use restricted key: '{}'".format(key))
elif not key.isidentifier():
raise RuntimeError("Invalid key name, must be a valid python variable"
" name.")
self.defaults["GUILD"][key] = default
def register_channel(self, **channel_defaults):
"""
Registers a new dict of channel defaults. This function should
be called EVERY TIME the cog loads (aka just do it in
__init__)!
:param channel_defaults: Each key should be the key you want to
access data by and the value is the default value of that
key.
:return:
"""
for k, v in channel_defaults.items():
try:
self._register_channel(k, v)
except KeyError:
log.exception("Bad default channel key.")
def _register_channel(self, key, default=None):
"""Registers a channel config key `key`"""
if key in self.unsettable_keys:
raise KeyError("Attempt to use restricted key: '{}'".format(key))
elif not key.isidentifier():
raise RuntimeError("Invalid key name, must be a valid python variable"
" name.")
self.defaults["CHANNEL"][key] = default
def register_role(self, **role_defaults):
"""
Registers a new dict of role defaults. This function should
be called EVERY TIME the cog loads (aka just do it in
__init__)!
:param role_defaults: Each key should be the key you want to
access data by and the value is the default value of that
key.
:return:
"""
for k, v in role_defaults.items():
try:
self._register_role(k, v)
except KeyError:
log.exception("Bad default role key.")
def _register_role(self, key, default=None):
"""Registers a role config key `key`"""
if key in self.unsettable_keys:
raise KeyError("Attempt to use restricted key: '{}'".format(key))
elif not key.isidentifier():
raise RuntimeError("Invalid key name, must be a valid python variable"
" name.")
self.defaults["ROLE"][key] = default
def register_member(self, **member_defaults):
"""
Registers a new dict of member defaults. This function should
be called EVERY TIME the cog loads (aka just do it in
__init__)!
:param member_defaults: Each key should be the key you want to
access data by and the value is the default value of that
key.
:return:
"""
for k, v in member_defaults.items():
try:
self._register_member(k, v)
except KeyError:
log.exception("Bad default member key.")
def _register_member(self, key, default=None):
"""Registers a member config key `key`"""
if key in self.unsettable_keys:
raise KeyError("Attempt to use restricted key: '{}'".format(key))
elif not key.isidentifier():
raise RuntimeError("Invalid key name, must be a valid python variable"
" name.")
self.defaults["MEMBER"][key] = default
def register_user(self, **user_defaults):
"""
Registers a new dict of user defaults. This function should
be called EVERY TIME the cog loads (aka just do it in
__init__)!
:param user_defaults: Each key should be the key you want to
access data by and the value is the default value of that
key.
:return:
"""
for k, v in user_defaults.items():
try:
self._register_user(k, v)
except KeyError:
log.exception("Bad default user key.")
def _register_user(self, key, default=None):
"""Registers a user config key `key`"""
if key in self.unsettable_keys:
raise KeyError("Attempt to use restricted key: '{}'".format(key))
elif not key.isidentifier():
raise RuntimeError("Invalid key name, must be a valid python variable"
" name.")
self.defaults["USER"][key] = default
class Config(BaseConfig):
"""
Config object created by `Config.get_conf()`
This configuration object is designed to make backend data
storage mechanisms pluggable. It also is designed to
help a cog developer make fewer mistakes (such as
typos) when dealing with cog data and to make those mistakes
apparent much faster in the design process.
It also has the capability to safely store data between cogs
that share the same name.
There are two main components to this config object. First,
you have the ability to get data on a level specific basis.
The seven levels available are: global, guild, channel, role,
member, user, and misc.
The second main component is registering default values for
data in each of the levels. This functionality is OPTIONAL
and must be explicitly enabled when creating the Config object
using the kwarg `force_registration=True`.
Basic Usage:
Creating a Config object:
Use the `Config.get_conf()` class method to create new
Config objects.
See the `Config.get_conf()` documentation for more
information.
Registering Default Values (optional):
You can register default values for data at all levels
EXCEPT misc.
Simply pass in the key/value pairs as keyword arguments to
the respective function.
e.g.: conf_obj.register_global(enabled=True)
conf_obj.register_guild(likes_red=True)
Retrieving data by attributes:
Since I registered the "enabled" key in the previous example
at the global level I can now do:
conf_obj.enabled()
which will retrieve the current value of the "enabled"
key, making use of the default of "True". I can also do
the same for the guild key "likes_red":
conf_obj.guild(guild_obj).likes_red()
If I elected to not register default values, you can provide them
when you try to access the key:
conf_obj.no_default(default=True)
However if you do not provide a default and you do not register
defaults, accessing the attribute will return "None".
Saving data:
This is accomplished by using the `set` function available at
every level.
e.g.: conf_obj.set("enabled", False)
conf_obj.guild(guild_obj).set("likes_red", False)
If `force_registration` was enabled when the config object
was created you will only be allowed to save keys that you
have registered.
Misc data is special, use `conf.misc()` and `conf.set_misc(value)`
respectively.
"""
def __getattr__(self, key) -> Callable:
"""
Until I've got a better way to do this I'm just gonna fake __call__
:param key:
:return: lambda function with kwarg
"""
return self._get_value_from_key(key)
def _get_value_from_key(self, key) -> Callable:
try:
default = self.defaults[self.collection][key]
except KeyError as e:
if self.force_registration:
raise AttributeError("Key '{}' not registered!".format(key)) from e
default = None
self.curr_key = key
if self.collection != "MEMBER":
ret = lambda default=default: self.driver_getmap[self.collection](
self.cog_name, self.uuid, self.collection_uuid, key,
default=default)
else:
mid, sid = self.collection_uuid
ret = lambda default=default: self.driver.get_member(
self.cog_name, self.uuid, mid, sid, key,
default=default)
partial[k] = {}
partial = partial[k]
return ret
def get(self, key, default=None):
@staticmethod
def _update_defaults(to_add: dict, _partial: dict):
"""
Included as an alternative to registering defaults.
:param key:
:param default:
:return:
This tries to update the defaults dictionary with the nested
partial dict generated by _get_defaults_dict. This WILL
throw an error if you try to have both a value and a group
registered under the same name.
:param to_add:
:param _partial:
:return:
"""
for k, v in to_add.items():
val_is_dict = isinstance(v, dict)
if k in _partial:
existing_is_dict = isinstance(_partial[k], dict)
if val_is_dict != existing_is_dict:
# != is XOR
raise KeyError("You cannot register a Group and a Value under"
" the same name.")
if val_is_dict:
Config._update_defaults(v, _partial=_partial[k])
else:
_partial[k] = v
else:
_partial[k] = v
if default is not None:
return self._get_value_from_key(key)(default)
else:
return self._get_value_from_key(key)()
def _register_default(self, key: str, **kwargs):
if key not in self.defaults:
self.defaults[key] = {}
async def set(self, key, value):
# Notice to future developers:
# This code was commented to allow users to set keys without having to register them.
# That being said, if they try to get keys without registering them
# things will blow up. I do highly recommend enforcing the key registration.
data = deepcopy(kwargs)
if key in self.unsettable_keys or key in self.invalid_keys:
raise KeyError("Restricted key name, please use another.")
for k, v in data.items():
to_add = self._get_defaults_dict(k, v)
self._update_defaults(to_add, self.defaults[key])
if self.force_registration and key not in self.defaults[self.collection]:
raise AttributeError("Key '{}' not registered!".format(key))
def register_global(self, **kwargs):
self._register_default(self.GLOBAL, **kwargs)
if not key.isidentifier():
raise RuntimeError("Invalid key name, must be a valid python variable"
" name.")
def register_guild(self, **kwargs):
self._register_default(self.GUILD, **kwargs)
if self.collection == "GLOBAL":
await self.driver.set_global(self.cog_name, self.uuid, key, value)
elif self.collection == "MEMBER":
mid, sid = self.collection_uuid
await self.driver.set_member(self.cog_name, self.uuid, mid, sid,
key, value)
elif self.collection in self.driver_setmap:
func = self.driver_setmap[self.collection]
await func(self.cog_name, self.uuid, self.collection_uuid, key, value)
def register_channel(self, **kwargs):
# We may need to add a voice channel category later
self._register_default(self.CHANNEL, **kwargs)
async def clear(self):
await self.driver_setmap[self.collection](
self.cog_name, self.uuid, self.collection_uuid, None, None,
clear=True)
def register_role(self, **kwargs):
self._register_default(self.ROLE, **kwargs)
def guild(self, guild):
new = type(self)(self.cog_name, self.uuid, self.driver,
hash_uuid=False, defaults=self.defaults)
new.collection = "GUILD"
new.collection_uuid = guild.id
new._driver = None
return new
def register_user(self, **kwargs):
self._register_default(self.USER, **kwargs)
def channel(self, channel):
new = type(self)(self.cog_name, self.uuid, self.driver,
hash_uuid=False, defaults=self.defaults)
new.collection = "CHANNEL"
new.collection_uuid = channel.id
new._driver = None
return new
def register_member(self, **kwargs):
self._register_default(self.MEMBER, **kwargs)
def role(self, role):
new = type(self)(self.cog_name, self.uuid, self.driver,
hash_uuid=False, defaults=self.defaults)
new.collection = "ROLE"
new.collection_uuid = role.id
new._driver = None
return new
def _get_base_group(self, key: str, *identifiers: str,
group_class=Group) -> Group:
# noinspection PyTypeChecker
return group_class(
identifiers=(self.unique_identifier, key) + identifiers,
defaults=self.defaults.get(key, {}),
spawner=self.spawner,
force_registration=self.force_registration
)
def member(self, member):
guild = member.guild
new = type(self)(self.cog_name, self.uuid, self.driver,
hash_uuid=False, defaults=self.defaults)
new.collection = "MEMBER"
new.collection_uuid = (member.id, guild.id)
new._driver = None
return new
def guild(self, guild: discord.Guild) -> Group:
return self._get_base_group(self.GUILD, guild.id)
def channel(self, channel: discord.TextChannel) -> Group:
return self._get_base_group(self.CHANNEL, channel.id)
def role(self, role: discord.Role) -> Group:
return self._get_base_group(self.ROLE, role.id)
def user(self, user: discord.User) -> Group:
return self._get_base_group(self.USER, user.id)
def member(self, member: discord.Member) -> MemberGroup:
return self._get_base_group(self.MEMBER, member.guild.id, member.id,
group_class=MemberGroup)
def user(self, user):
new = type(self)(self.cog_name, self.uuid, self.driver,
hash_uuid=False, defaults=self.defaults)
new.collection = "USER"
new.collection_uuid = user.id
new._driver = None
return new

View File

@@ -98,7 +98,7 @@ class Core:
@commands.guild_only()
async def adminrole(self, ctx, *, role: discord.Role):
"""Sets the admin role for this server"""
await ctx.bot.db.guild(ctx.guild).set("admin_role", role.id)
await ctx.bot.db.guild(ctx.guild).admin_role.set(role.id)
await ctx.send("The admin role for this server has been set.")
@_set.command()
@@ -106,7 +106,7 @@ class Core:
@commands.guild_only()
async def modrole(self, ctx, *, role: discord.Role):
"""Sets the mod role for this server"""
await ctx.bot.db.guild(ctx.guild).set("mod_role", role.id)
await ctx.bot.db.guild(ctx.guild).mod_role.set(role.id)
await ctx.send("The mod role for this server has been set.")
@_set.command()
@@ -225,7 +225,7 @@ class Core:
await ctx.bot.send_cmd_help(ctx)
return
prefixes = sorted(prefixes, reverse=True)
await ctx.bot.db.set("prefix", prefixes)
await ctx.bot.db.prefix.set(prefixes)
await ctx.send("Prefix set.")
@_set.command(aliases=["serverprefixes"])
@@ -234,11 +234,11 @@ class Core:
async def serverprefix(self, ctx, *prefixes):
"""Sets Red's server prefix(es)"""
if not prefixes:
await ctx.bot.db.guild(ctx.guild).set("prefix", [])
await ctx.bot.db.guild(ctx.guild).prefix.set([])
await ctx.send("Server prefixes have been reset.")
return
prefixes = sorted(prefixes, reverse=True)
await ctx.bot.db.guild(ctx.guild).set("prefix", prefixes)
await ctx.bot.db.guild(ctx.guild).prefix.set(prefixes)
await ctx.send("Prefix set.")
@_set.command()
@@ -276,7 +276,7 @@ class Core:
else:
if message.content.strip() == token:
self.owner.reset_cooldown(ctx)
await ctx.bot.db.set("owner", ctx.author.id)
await ctx.bot.db.owner.set(ctx.author.id)
ctx.bot.owner_id = ctx.author.id
await ctx.send("You have been set as owner.")
else:

View File

@@ -1,45 +1,12 @@
from typing import Tuple
class BaseDriver:
def get_global(self, cog_name, ident, collection_id, key, *, default=None):
raise NotImplementedError()
def get_driver(self):
raise NotImplementedError
def get_guild(self, cog_name, ident, guild_id, key, *, default=None):
raise NotImplementedError()
def get(self, identifiers: Tuple[str]):
raise NotImplementedError
def get_channel(self, cog_name, ident, channel_id, key, *, default=None):
raise NotImplementedError()
def get_role(self, cog_name, ident, role_id, key, *, default=None):
raise NotImplementedError()
def get_member(self, cog_name, ident, user_id, guild_id, key, *,
default=None):
raise NotImplementedError()
def get_user(self, cog_name, ident, user_id, key, *, default=None):
raise NotImplementedError()
def get_misc(self, cog_name, ident, *, default=None):
raise NotImplementedError()
async def set_global(self, cog_name, ident, key, value, clear=False):
raise NotImplementedError()
async def set_guild(self, cog_name, ident, guild_id, key, value, clear=False):
raise NotImplementedError()
async def set_channel(self, cog_name, ident, channel_id, key, value,
clear=False):
raise NotImplementedError()
async def set_role(self, cog_name, ident, role_id, key, value, clear=False):
raise NotImplementedError()
async def set_member(self, cog_name, ident, user_id, guild_id, key, value,
clear=False):
raise NotImplementedError()
async def set_user(self, cog_name, ident, user_id, key, value, clear=False):
raise NotImplementedError()
async def set_misc(self, cog_name, ident, value, clear=False):
raise NotImplementedError()
async def set(self, identifiers: Tuple[str], value):
raise NotImplementedError

View File

@@ -1,13 +1,15 @@
from typing import Tuple
from core.drivers.red_base import BaseDriver
from core.json_io import JsonIO
import os
from .red_base import BaseDriver
from pathlib import Path
class JSON(BaseDriver):
def __init__(self, cog_name, *args, data_path_override: Path=None,
file_name_override: str="settings.json", **kwargs):
def __init__(self, cog_name, *, data_path_override: Path=None,
file_name_override: str="settings.json"):
super().__init__()
self.cog_name = cog_name
self.file_name = file_name_override
if data_path_override:
@@ -25,111 +27,23 @@ class JSON(BaseDriver):
self.data = self.jsonIO._load_json()
except FileNotFoundError:
self.data = {}
def maybe_add_ident(self, ident: str):
if ident in self.data:
return
self.data[ident] = {}
for k in ("GLOBAL", "GUILD", "CHANNEL", "ROLE", "MEMBER", "USER"):
if k not in self.data[ident]:
self.data[ident][k] = {}
self.jsonIO._save_json(self.data)
def get_global(self, cog_name, ident, _, key, *, default=None):
return self.data[ident]["GLOBAL"].get(key, default)
def get_driver(self):
return self
def get_guild(self, cog_name, ident, guild_id, key, *, default=None):
guilddata = self.data[ident]["GUILD"].get(str(guild_id), {})
return guilddata.get(key, default)
def get(self, identifiers: Tuple[str]):
partial = self.data
for i in identifiers:
partial = partial[i]
return partial
def get_channel(self, cog_name, ident, channel_id, key, *, default=None):
channeldata = self.data[ident]["CHANNEL"].get(str(channel_id), {})
return channeldata.get(key, default)
async def set(self, identifiers, value):
partial = self.data
for i in identifiers[:-1]:
if i not in partial:
partial[i] = {}
partial = partial[i]
def get_role(self, cog_name, ident, role_id, key, *, default=None):
roledata = self.data[ident]["ROLE"].get(str(role_id), {})
return roledata.get(key, default)
def get_member(self, cog_name, ident, user_id, guild_id, key, *,
default=None):
userdata = self.data[ident]["MEMBER"].get(str(user_id), {})
guilddata = userdata.get(str(guild_id), {})
return guilddata.get(key, default)
def get_user(self, cog_name, ident, user_id, key, *, default=None):
userdata = self.data[ident]["USER"].get(str(user_id), {})
return userdata.get(key, default)
async def set_global(self, cog_name, ident, key, value, clear=False):
if clear:
self.data[ident]["GLOBAL"] = {}
else:
self.data[ident]["GLOBAL"][key] = value
await self.jsonIO._threadsafe_save_json(self.data)
async def set_guild(self, cog_name, ident, guild_id, key, value, clear=False):
guild_id = str(guild_id)
if clear:
self.data[ident]["GUILD"][guild_id] = {}
else:
try:
self.data[ident]["GUILD"][guild_id][key] = value
except KeyError:
self.data[ident]["GUILD"][guild_id] = {}
self.data[ident]["GUILD"][guild_id][key] = value
await self.jsonIO._threadsafe_save_json(self.data)
async def set_channel(self, cog_name, ident, channel_id, key, value, clear=False):
channel_id = str(channel_id)
if clear:
self.data[ident]["CHANNEL"][channel_id] = {}
else:
try:
self.data[ident]["CHANNEL"][channel_id][key] = value
except KeyError:
self.data[ident]["CHANNEL"][channel_id] = {}
self.data[ident]["CHANNEL"][channel_id][key] = value
await self.jsonIO._threadsafe_save_json(self.data)
async def set_role(self, cog_name, ident, role_id, key, value, clear=False):
role_id = str(role_id)
if clear:
self.data[ident]["ROLE"][role_id] = {}
else:
try:
self.data[ident]["ROLE"][role_id][key] = value
except KeyError:
self.data[ident]["ROLE"][role_id] = {}
self.data[ident]["ROLE"][role_id][key] = value
await self.jsonIO._threadsafe_save_json(self.data)
async def set_member(self, cog_name, ident, user_id, guild_id, key, value, clear=False):
user_id = str(user_id)
guild_id = str(guild_id)
if clear:
self.data[ident]["MEMBER"][user_id] = {}
else:
try:
self.data[ident]["MEMBER"][user_id][guild_id][key] = value
except KeyError:
if user_id not in self.data[ident]["MEMBER"]:
self.data[ident]["MEMBER"][user_id] = {}
if guild_id not in self.data[ident]["MEMBER"][user_id]:
self.data[ident]["MEMBER"][user_id][guild_id] = {}
self.data[ident]["MEMBER"][user_id][guild_id][key] = value
await self.jsonIO._threadsafe_save_json(self.data)
async def set_user(self, cog_name, ident, user_id, key, value, clear=False):
user_id = str(user_id)
if clear:
self.data[ident]["USER"][user_id] = {}
else:
try:
self.data[ident]["USER"][user_id][key] = value
except KeyError:
self.data[ident]["USER"][user_id] = {}
self.data[ident]["USER"][user_id][key] = value
partial[identifiers[-1]] = value
await self.jsonIO._threadsafe_save_json(self.data)