diff --git a/.gitignore b/.gitignore index e104f19b..c1db31ec 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ secrets.env **/__pycache__ test-file.py /src/cogs/TestCog.py -.vscode \ No newline at end of file +.vscode +.idea \ No newline at end of file diff --git a/README.md b/README.md index 4a7f7d13..c3477f67 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,8 @@ TWITCH_CLIENT_ID= TWITCH_CLIENT_SECRET= GOOGLE_API= ENABLE_MUSIC=TRUE +PGADMIN_DEFAULT_EMAIL=example@email.com +PGADMIN_DEFAULT_PASSWORD=changeme ```
diff --git a/docker-compose.yml b/docker-compose.yml index 48f8f57c..3707df9b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -23,5 +23,14 @@ services: - ENABLE_TWITCH=False restart: unless-stopped + pg_admin: + image: "dpage/pgadmin4" + restart: unless-stopped + env_file: secrets.env + environment: + - PGADMIN_LISTEN_PORT=80 + ports: + - "8080:80" + volumes: db_data: diff --git a/src/esportsbot/base_functions.py b/src/esportsbot/base_functions.py index 37c94e58..89ea9acb 100644 --- a/src/esportsbot/base_functions.py +++ b/src/esportsbot/base_functions.py @@ -1,10 +1,11 @@ -from .db_gateway import db_gateway +from esportsbot.db_gateway import DBGatewayActions +from esportsbot.models import Voicemaster_master, Voicemaster_slave, Guild_info async def send_to_log_channel(self, guild_id, msg): - db_logging_call = db_gateway().get('guild_info', params={'guild_id': guild_id}) - if db_logging_call and db_logging_call[0]['log_channel_id']: - await self.bot.get_channel(db_logging_call[0]['log_channel_id']).send(msg) + db_logging_call = DBGatewayActions().get(Guild_info, guild_id=guild_id) + if db_logging_call and db_logging_call.log_channel_id is not None: + await self.bot.get_channel(db_logging_call.log_channel_id).send(msg) def role_id_from_mention(pre_clean_data: str) -> int: @@ -48,10 +49,10 @@ def user_id_from_mention(pre_clean_data: str) -> int: def get_whether_in_vm_master(guild_id, channel_id): - in_master = db_gateway().get('voicemaster_master', params={'guild_id': guild_id, 'channel_id': channel_id}) + in_master = DBGatewayActions().get(Voicemaster_master, guild_id=guild_id, channel_id=channel_id) return bool(in_master) def get_whether_in_vm_slave(guild_id, channel_id): - in_slave = db_gateway().get('voicemaster_slave', params={'guild_id': guild_id, 'channel_id': channel_id}) + in_slave = DBGatewayActions().get(Voicemaster_slave, guild_id=guild_id, channel_id=channel_id) return bool(in_slave) diff --git a/src/esportsbot/bot.py b/src/esportsbot/bot.py index e2050889..ee72b47d 100644 --- a/src/esportsbot/bot.py +++ b/src/esportsbot/bot.py @@ -1,9 +1,10 @@ from typing import Dict, Any -from dotenv import load_dotenv -from . import lib -from .base_functions import get_whether_in_vm_master, get_whether_in_vm_slave -from .generate_schema import generate_schema -from .db_gateway import db_gateway +from esportsbot import lib +from esportsbot.base_functions import get_whether_in_vm_master, get_whether_in_vm_slave + +from esportsbot.db_gateway import DBGatewayActions +from esportsbot.models import Guild_info, Voicemaster_slave, Pingable_roles + from discord.ext import commands from discord.ext.commands import CommandNotFound, MissingRequiredArgument from discord.ext.commands.context import Context @@ -22,30 +23,12 @@ # EsportsBot client instance client = lib.client.instance() -# TODO -client.remove_command('help') - - -def make_guild_init_data(guild: discord.Guild) -> Dict[str, Any]: - """Construct default data for a guild database registration. - - :param discord.Guild guild: The guild to be registered - :return: A dictionary with default guild attributes, including the guild ID - :rtype: Dict[str, Any] - """ - return { - 'guild_id': guild.id, - 'num_running_polls': 0, - 'role_ping_cooldown_seconds': int(DEFAULT_ROLE_PING_COOLDOWN.total_seconds()), - "pingme_create_threshold": DEFAULT_PINGME_CREATE_THRESHOLD, - "pingme_create_poll_length_seconds": int(DEFAULT_PINGME_CREATE_POLL_LENGTH.total_seconds()) - } async def send_to_log_channel(guild_id, msg): - db_logging_call = db_gateway().get('guild_info', params={'guild_id': guild_id}) - if db_logging_call and db_logging_call[0]['log_channel_id']: - await client.get_channel(db_logging_call[0]['log_channel_id']).send(msg) + db_logging_call = DBGatewayActions().get(Guild_info, guild_id=guild_id) + if db_logging_call and db_logging_call.log_channel_id is not None: + await client.get_channel(db_logging_call.log_channel_id).send(msg) @client.event @@ -63,21 +46,32 @@ async def on_ready(): @client.event async def on_guild_join(guild): print(f"Joined the guild: {guild.name}") - db_gateway().insert('guild_info', params=make_guild_init_data(guild)) + DBGatewayActions().create( + Guild_info( + guild_id=guild.id, + num_running_polls=0, + role_ping_cooldown_seconds=int(DEFAULT_ROLE_PING_COOLDOWN.total_seconds()), + pingme_create_threshold=DEFAULT_PINGME_CREATE_THRESHOLD, + pingme_create_poll_length_seconds=int(DEFAULT_PINGME_CREATE_POLL_LENGTH.total_seconds()) + ) + ) @client.event async def on_guild_remove(guild): - print(f"Left the guild: {guild.name}") - db_gateway().delete('guild_info', where_params={'guild_id': guild.id}) + guild_from_db = DBGatewayActions().get(Guild_info, guild_id=guild.id) + if guild_from_db: + DBGatewayActions().delete(guild_from_db) + print(f"Left the guild: {guild.name}") @client.event async def on_member_join(member): - default_role_exists = db_gateway().get('guild_info', params={'guild_id': member.guild.id}) + guild = DBGatewayActions().get(Guild_info, guild_id=member.guild.id) + default_role_exists = guild.default_role_id is not None - if default_role_exists[0]['default_role_id']: - default_role = member.guild.get_role(default_role_exists[0]['default_role_id']) + if default_role_exists: + default_role = member.guild.get_role(guild.default_role_id) await member.add_roles(default_role) await send_to_log_channel( member.guild.id, @@ -93,41 +87,27 @@ async def on_voice_state_update(member, before, after): after_channel_id = after.channel.id if after.channel != None else False if before_channel_id and get_whether_in_vm_slave(member.guild.id, before_channel_id): + vm_slave = DBGatewayActions().get(Voicemaster_slave, guild_id=member.guild.id, channel_id=before_channel_id) # If you were in a slave VM VC if not before.channel.members: # Nobody else in VC await before.channel.delete() - db_gateway().delete( - 'voicemaster_slave', - where_params={ - 'guild_id': member.guild.id, - 'channel_id': before_channel_id - } - ) + DBGatewayActions().delete(vm_slave) await send_to_log_channel(member.guild.id, f"{member.mention} has deleted a VM slave") else: # Still others in VC await before.channel.edit(name=f"{before.channel.members[0].display_name}'s VC") - db_gateway().update( - 'voicemaster_slave', - set_params={'owner_id': before.channel.members[0].id}, - where_params={ - 'guild_id': member.guild.id, - 'channel_id': before_channel_id - } - ) + vm_slave.owner_id = before.channel.members[0].id + DBGatewayActions().update(vm_slave) elif after_channel_id and get_whether_in_vm_master(member.guild.id, after_channel_id): # Moved into a master VM VC slave_channel_name = f"{member.display_name}'s VC" new_slave_channel = await member.guild.create_voice_channel(slave_channel_name, category=after.channel.category) - db_gateway().insert( - 'voicemaster_slave', - params={ - 'guild_id': member.guild.id, - 'channel_id': new_slave_channel.id, - 'owner_id': member.id, - 'locked': False, - } + DBGatewayActions().create( + Voicemaster_slave(guild_id=member.guild.id, + channel_id=new_slave_channel.id, + owner_id=member.id, + locked=False) ) await member.move_to(new_slave_channel) await send_to_log_channel(member.guild.id, f"{member.mention} has created a VM slave") @@ -262,7 +242,7 @@ async def on_message(message): # Handle music channel messages guild_id = message.guild.id music_channel_in_db = client.MUSIC_CHANNELS.get(guild_id) - if music_channel_in_db == message.channel.id: + if music_channel_in_db: # The message was in a music channel and a song should be found music_cog_instance = client.cogs.get('MusicCog') await music_cog_instance.on_message_handle(message) @@ -284,11 +264,19 @@ async def on_message(message): @client.command() @commands.has_permissions(administrator=True) async def initialsetup(ctx): - already_in_db = db_gateway().get('guild_info', params={'guild_id': ctx.author.guild.id}) + already_in_db = DBGatewayActions().get(Guild_info, guild_id=ctx.author.guild.id) if already_in_db: await ctx.channel.send("This server is already set up") else: - db_gateway().insert('guild_info', make_guild_init_data(ctx.guild)) + DBGatewayActions().create( + Guild_info( + guild_id=ctx.author.guild.id, + num_running_polls=0, + role_ping_cooldown_seconds=int(DEFAULT_ROLE_PING_COOLDOWN.total_seconds()), + pingme_create_threshold=DEFAULT_PINGME_CREATE_THRESHOLD, + pingme_create_poll_length_seconds=int(DEFAULT_PINGME_CREATE_POLL_LENGTH.total_seconds()) + ) + ) await ctx.channel.send("This server has now been initialised") @@ -298,9 +286,9 @@ async def on_guild_role_delete(role: discord.Role): :param Role role: The role which was removed """ - db = db_gateway() - if db.get("pingable_roles", {"role_id": role.id}): - db.delete("pingable_roles", {"role_id": role.id}) + pingable_role = DBGatewayActions().get(Pingable_roles, role_id=role.id) + if pingable_role: + DBGatewayActions().delete(pingable_role) logEmbed = discord.Embed() logEmbed.set_author(icon_url=client.user.avatar_url_as(size=64), name="Admin Log") logEmbed.set_footer(text=datetime.now().strftime("%m/%d/%Y, %H:%M:%S")) @@ -311,11 +299,9 @@ async def on_guild_role_delete(role: discord.Role): def launch(): - load_dotenv() + TOKEN = os.getenv('DISCORD_TOKEN') - # Generate Database Schema - generate_schema() client.update_music_channels() client.load_extension('esportsbot.cogs.VoicemasterCog') diff --git a/src/esportsbot/cogs/AdminCog.py b/src/esportsbot/cogs/AdminCog.py index 9238f442..67d1bdfd 100644 --- a/src/esportsbot/cogs/AdminCog.py +++ b/src/esportsbot/cogs/AdminCog.py @@ -1,7 +1,6 @@ import toml from discord.ext import commands -from ..db_gateway import db_gateway -from ..base_functions import send_to_log_channel +from esportsbot.base_functions import send_to_log_channel class AdminCog(commands.Cog): diff --git a/src/esportsbot/cogs/DefaultRoleCog.py b/src/esportsbot/cogs/DefaultRoleCog.py index 86e470c8..f73073a9 100644 --- a/src/esportsbot/cogs/DefaultRoleCog.py +++ b/src/esportsbot/cogs/DefaultRoleCog.py @@ -1,8 +1,8 @@ import toml from discord.ext import commands -from ..db_gateway import db_gateway -from ..base_functions import role_id_from_mention -from ..base_functions import send_to_log_channel +from esportsbot.db_gateway import DBGatewayActions +from esportsbot.models import Guild_info +from esportsbot.base_functions import role_id_from_mention, send_to_log_channel class DefaultRoleCog(commands.Cog): @@ -15,11 +15,9 @@ def __init__(self, bot): async def setdefaultrole(self, ctx, given_role_id=None): cleaned_role_id = role_id_from_mention(given_role_id) if given_role_id else False if cleaned_role_id: - db_gateway().update( - 'guild_info', - set_params={'default_role_id': cleaned_role_id}, - where_params={'guild_id': ctx.author.guild.id} - ) + guild = DBGatewayActions().get(Guild_info, guild_id=ctx.author.guild.id) + guild.default_role_id = cleaned_role_id + DBGatewayActions().update(guild) await ctx.channel.send(self.STRINGS['default_role_set'].format(role_id=cleaned_role_id)) default_role = ctx.author.guild.get_role(cleaned_role_id) await send_to_log_channel( @@ -34,24 +32,23 @@ async def setdefaultrole(self, ctx, given_role_id=None): @commands.command() @commands.has_permissions(administrator=True) async def getdefaultrole(self, ctx): - default_role_exists = db_gateway().get('guild_info', params={'guild_id': ctx.author.guild.id}) + guild = DBGatewayActions().get(Guild_info, guild_id=ctx.author.guild.id) + default_role_exists = guild.default_role_id is not None - if default_role_exists[0]['default_role_id']: - await ctx.channel.send(self.STRINGS['default_role_get'].format(role_id=default_role_exists[0]['default_role_id'])) + if default_role_exists: + await ctx.channel.send(self.STRINGS['default_role_get'].format(role_id=guild.default_role_id)) else: await ctx.channel.send(self.STRINGS['default_role_missing']) @commands.command() @commands.has_permissions(administrator=True) async def removedefaultrole(self, ctx): - default_role_exists = db_gateway().get('guild_info', params={'guild_id': ctx.author.guild.id}) + guild = DBGatewayActions().get(Guild_info, guild_id=ctx.author.guild.id) + default_role_exists = guild.default_role_id is not None - if default_role_exists[0]['default_role_id']: - db_gateway().update( - 'guild_info', - set_params={'default_role_id': 'NULL'}, - where_params={'guild_id': ctx.author.guild.id} - ) + if default_role_exists: + guild.default_role_id = None + DBGatewayActions().update(guild) await ctx.channel.send(self.STRINGS['default_role_removed']) await send_to_log_channel( self, diff --git a/src/esportsbot/cogs/EventCategoriesCog.py b/src/esportsbot/cogs/EventCategoriesCog.py index 7e04efa0..913f660c 100644 --- a/src/esportsbot/cogs/EventCategoriesCog.py +++ b/src/esportsbot/cogs/EventCategoriesCog.py @@ -3,11 +3,12 @@ from discord.ext import commands from discord.ext.commands.context import Context from discord import PartialMessage, Forbidden, PermissionOverwrite, RawReactionActionEvent, Colour, Embed -from ..db_gateway import db_gateway +from esportsbot.db_gateway import DBGatewayActions +from esportsbot.models import Guild_info, Event_categories import asyncio -from .. import lib -from ..lib.client import EsportsBot, StringTable -from ..reactionMenus.reactionRoleMenu import ReactionRoleMenu, ReactionRoleMenuOption +from esportsbot import lib +from esportsbot.lib.client import EsportsBot, StringTable +from esportsbot.reactionMenus.reactionRoleMenu import ReactionRoleMenu, ReactionRoleMenuOption # Permissions overrides assigned to the shared role in closed event signin channels CLOSED_EVENT_SIGNIN_CHANNEL_SHARED_PERMS = PermissionOverwrite( @@ -74,28 +75,26 @@ async def getGuildEventSettings(self, ctx: Context, eventName: str) -> Tuple[dic :return: A tuple with the guild and event db entries if the guild has a shared role and an event named eventName, () otherwise :rtype: Tuple[dict, dict] if the guild has a shared role and an event named eventName, Tuple[] otherwise """ - db = db_gateway() - guildData = db.get("guild_info", params={"guild_id": ctx.guild.id})[0] - if not guildData["shared_role_id"]: + guild = DBGatewayActions().get(Guild_info, guild_id=ctx.guild.id) + no_shared_role = guild.shared_role_id is None + if no_shared_role: await ctx.message.reply(self.STRINGS['no_shared_role'].format(command_prefix=self.bot.command_prefix)) else: - eventData = db.get("event_categories", params={"guild_id": ctx.guild.id, "event_name": eventName})[0] + eventData = DBGatewayActions().get(Event_categories, guild_id=ctx.guild.id, event_name=eventName) if not eventData: - if not (allEvents := db.get("event_categories", params={"guild_id": ctx.guild.id})): + if not (allEvents := DBGatewayActions().list(Event_categories, guild_id=ctx.guild.id)): await ctx.message.reply(self.STRINGS['no_event_categories']) else: await ctx.message.reply( - self.STRINGS['unrecognised_event'].format( - events=", ".join(e["event_name"].title() for e in allEvents) - ) + self.STRINGS['unrecognised_event'].format(events=", ".join(e.event_name.title() for e in allEvents)) ) else: - return (guildData, eventData) + return (guild, eventData) return () @commands.command( name="open-event", - usage="", + usage="open-event ", help="Reveal the signin channel for the named event channel." ) @commands.has_permissions(administrator=True) @@ -110,12 +109,12 @@ async def admin_cmd_open_event(self, ctx: Context, *, args): elif allData := await self.getGuildEventSettings(ctx, args.lower()): guildData, eventData = allData eventName = args.lower() - signinMenu = self.bot.reactionMenus[eventData["signin_menu_id"]] + signinMenu = self.bot.reactionMenus[eventData.signin_menu_id] eventChannel = signinMenu.msg.channel if not eventChannel.permissions_for(ctx.guild.me).manage_permissions: await ctx.send(self.STRINGS['no_channel_edit_perms'].format(channel_id=eventChannel.id)) else: - sharedRole = ctx.guild.get_role(guildData["shared_role_id"]) + sharedRole = ctx.guild.get_role(guildData.shared_role_id) if not eventChannel.overwrites_for(sharedRole).read_messages: reason = self.STRINGS['event_channel_open_reason'].format( author=ctx.author.name, @@ -137,7 +136,7 @@ async def admin_cmd_open_event(self, ctx: Context, *, args): @commands.command( name="close-event", - usage="", + usage="close-event ", help="Hide the signin channel for the named event, reset the signin menu, and remove the event's role from users." ) @commands.has_permissions(administrator=True) @@ -152,7 +151,7 @@ async def admin_cmd_close_event(self, ctx: Context, *, args): await ctx.message.reply(self.STRINGS['request_event_name']) elif allData := await self.getGuildEventSettings(ctx, args.lower()): guildData, eventData = allData - signinMenu = self.bot.reactionMenus[eventData["signin_menu_id"]] + signinMenu = self.bot.reactionMenus[eventData.signin_menu_id] eventChannel = signinMenu.msg.channel myPerms = eventChannel.permissions_for(ctx.guild.me) if not myPerms.manage_permissions: @@ -160,7 +159,7 @@ async def admin_cmd_close_event(self, ctx: Context, *, args): elif not myPerms.manage_roles: await ctx.send(self.STRINGS['no_role_edit_perms']) else: - eventRole = ctx.guild.get_role(eventData["role_id"]) + eventRole = ctx.guild.get_role(eventData.role_id) if eventRole.position >= ctx.guild.self_role.position: await ctx.send( self.STRINGS['role_edit_perms_bad_order'].format( @@ -170,7 +169,7 @@ async def admin_cmd_close_event(self, ctx: Context, *, args): ) else: eventName = args.lower() - sharedRole = ctx.guild.get_role(guildData["shared_role_id"]) + sharedRole = ctx.guild.get_role(guildData.shared_role_id) channelEdited = eventChannel.overwrites_for(sharedRole).read_messages usersEdited = len(eventRole.members) # signinMenu.msg = await signinMenu.msg.channel.fetch_message(signinMenu.msg.id) @@ -241,7 +240,7 @@ async def admin_cmd_close_event(self, ctx: Context, *, args): @commands.command( name="set-event-signin-menu", - usage=" ", + usage="set-event-signin-menu ", help="Change the event signin menu to use with `open-event` and `close-event`." ) @commands.has_permissions(administrator=True) @@ -262,18 +261,17 @@ async def admin_cmd_set_event_signin_menu(self, ctx: Context, *, args: str): await ctx.send(self.STRINGS['unrecognised_menu_id'].format(menu_id=menuID)) else: eventName = args[len(menuID) + 1:].lower() - db = db_gateway() - if not (eventData := db.get("event_categories", {"guild_id": ctx.guild.id, "event_name": eventName})): - if not (allEvents := db.get("event_categoriesevent_channels", params={"guild_id": ctx.guild.id})): + if not (eventData := DBGatewayActions().get(Event_categories, guild_id=ctx.guild.id, event_name=eventName)): + if not (allEvents := DBGatewayActions().list(Event_categories, guild_id=ctx.guild.id)): await ctx.message.reply(self.STRINGS['no_event_categories']) else: await ctx.message.reply( self.STRINGS['unrecognised_event'].format( - events=", ".join(e["event_name"].title() for e in allEvents) + events=", ".join(e.event_name.title() for e in allEvents) ) ) else: - eventRole = ctx.guild.get_role(eventData["role_id"]) + eventRole = ctx.guild.get_role(eventData.role_id) menu = self.bot.reactionMenus[int(menuID)] if not isinstance(menu, ReactionRoleMenu): await ctx.message.reply( @@ -287,14 +285,13 @@ async def admin_cmd_set_event_signin_menu(self, ctx: Context, *, args: str): self.STRINGS['invalid_signin_menu'].format(role_name=eventRole.name if eventRole else 'event') ) else: - db.update( - 'event_categories', - set_params={"signin_menu_id": menu.msg.id}, - where_params={ - "guild_id": ctx.guild.id, - "event_name": eventName - } + event_category = DBGatewayActions().get( + Event_categories, + guild_id=ctx.guild.id, + event_name=eventName ) + event_category.signin_menu_id = menu.msg.id + DBGatewayActions().update(event_category) await ctx.send( self.STRINGS['success_menu'].format(event_name=eventName.title, menu_url=menu.msg.jump_url) @@ -309,7 +306,7 @@ async def admin_cmd_set_event_signin_menu(self, ctx: Context, *, args: str): @commands.command( name="set-shared-role", - usage="", + usage="set-shared-role ", help= "Change the role to admit/deny into *all* event signin menus. This should NOT be the same as any event role. Role can be given as either a mention or an ID." ) @@ -332,11 +329,9 @@ async def admin_cmd_set_shared_role(self, ctx: Context, *, args: str): if role is None: await ctx.send(self.STRINGS['unrecognised_role']) else: - db_gateway().update( - 'guild_info', - set_params={"shared_role_id": roleID}, - where_params={"guild_id": ctx.guild.id} - ) + guild = DBGatewayActions().get(Guild_info, guild_id=ctx.guild.id) + guild.shared_role_id = roleID + DBGatewayActions().update(guild) await ctx.send(self.STRINGS['success_shared_role'].format(role_name=role.name)) await self.bot.adminLog( ctx.message, @@ -348,7 +343,7 @@ async def admin_cmd_set_shared_role(self, ctx: Context, *, args: str): @commands.command( name="set-event-role", - usage="", + usage="set-event-role ", help= "Change the role to remove during `close-event`. This should NOT be the same as your shared role. Role can be given as either a mention or an ID." ) @@ -372,25 +367,19 @@ async def admin_cmd_set_event_role(self, ctx: Context, *, args: str): await ctx.send(self.STRINGS['unrecognised_role']) else: eventName = args[len(roleStr) + 1:].lower() - db = db_gateway() - if not db.get("event_categories", {"guild_id": ctx.guild.id, "event_name": eventName}): - if not (allEvents := db.get("event_categories", params={"guild_id": ctx.guild.id})): + if not DBGatewayActions().get(Event_categories, guild_id=ctx.guild.id, event_name=eventName): + if not (allEvents := DBGatewayActions().list(Event_categories, guild_id=ctx.guild.id)): await ctx.message.reply(self.STRINGS['no_event_categories']) else: await ctx.message.reply( self.STRINGS['unrecognised_event'].format( - events=", ".join(e["event_name"].title() for e in allEvents) + events=", ".join(e.event_name.title() for e in allEvents) ) ) else: - db.update( - 'event_categories', - set_params={"role_id": roleID}, - where_params={ - "guild_id": ctx.guild.id, - "event_name": eventName - } - ) + event_category = DBGatewayActions().get(Event_categories, guild_id=ctx.guild.id, event_name=eventName) + event_category.role_id = roleID + DBGatewayActions().update(event_category) await ctx.send( self.STRINGS['success_event_role'].format(event_name=eventName.title(), role_name=role.name) @@ -405,7 +394,7 @@ async def admin_cmd_set_event_role(self, ctx: Context, *, args: str): @commands.command( name="register-event-category", - usage=" ", + usage="register-event-category ", help= "Register an existing event category, menu, and role, for use with `open-event` and `close-event`. This does not setup permissions for the category or channels." ) @@ -437,19 +426,17 @@ async def admin_cmd_register_event_category(self, ctx: Context, *, args: str): await ctx.send(self.STRINGS['unrecognised_role']) else: eventName = args[len(roleStr) + len(menuIDStr) + 2:].lower() - db = db_gateway() - if db.get("event_categories", {"guild_id": ctx.guild.id, "event_name": eventName}): + if DBGatewayActions().get(Event_categories, guild_id=ctx.guild.id, event_name=eventName): await ctx.message.reply(self.STRINGS['event_exists'].format(event_name=eventName.title())) else: menu = self.bot.reactionMenus[int(menuIDStr)] - db.insert( - 'event_categories', - { - "guild_id": ctx.guild.id, - "event_name": eventName, - "role_id": roleID, - "signin_menu_id": menu.msg.id - } + DBGatewayActions().create( + Event_categories( + guild_id=ctx.guild.id, + event_name=eventName, + role_id=roleID, + signin_menu_id=menu.msg.id + ) ) await ctx.send(self.STRINGS['success_event_category'].format(event_name=eventName.title())) admin_message = self.STRINGS['admin_existing_event_registered'][1].format( @@ -465,7 +452,7 @@ async def admin_cmd_register_event_category(self, ctx: Context, *, args: str): @commands.command( name="create-event-category", - usage="", + usage="create-event-category ", help= "Create a new event category with a signin channel and menu, event role, general channel and correct permissions, and automatically register them for use with `open-event` and `close-event`." ) @@ -479,16 +466,15 @@ async def admin_cmd_create_event_category(self, ctx: Context, *, args: str): if not args: await ctx.send(self.STRINGS['request_event_name']) else: - db = db_gateway() eventName = args.lower() - if db.get("event_categories", {"guild_id": ctx.guild.id, "event_name": eventName}): + if DBGatewayActions().get(Event_categories, guild_id=ctx.guild.id, event_name=eventName): await ctx.message.reply(self.STRINGS['event_exists'].format(event_name=eventName.title())) else: - guildData = db.get("guild_info", {"guild_id": ctx.guild.id})[0] - if not guildData["shared_role_id"]: + guildData = DBGatewayActions().get(Guild_info, guild_id=ctx.guild.id) + if not guildData.shared_role_id: await ctx.message.reply(self.STRINGS['no_shared_role'].format(command_prefix=self.bot.command_prefix)) else: - if not (sharedRole := ctx.guild.get_role(guildData["shared_role_id"])): + if not (sharedRole := ctx.guild.get_role(guildData.shared_role_id)): await ctx.message.reply( self.STRINGS['missing_shared_role'].format(command_prefix=self.bot.command_prefix) ) @@ -551,7 +537,7 @@ def emojiSelectorCheck(data: RawReactionActionEvent) -> bool: category=newCategory, overwrites=categoryOverwrites ) - signinMenuMsg = await signinChannel.send("​") + signinMenuMsg = await signinChannel.send(embed=Embed()) signinMenu = ReactionRoleMenu( signinMenuMsg, self.bot, @@ -562,14 +548,13 @@ def emojiSelectorCheck(data: RawReactionActionEvent) -> bool: ) await signinMenu.updateMessage() self.bot.reactionMenus.add(signinMenu) - db.insert( - 'event_categories', - { - "guild_id": ctx.guild.id, - "event_name": eventName, - "role_id": eventRole.id, - "signin_menu_id": signinMenuMsg.id - } + DBGatewayActions().create( + Event_categories( + guild_id=ctx.guild.id, + event_name=eventName, + role_id=eventRole.id, + signin_menu_id=signinMenuMsg.id + ) ) await ctx.send( self.STRINGS['success_event'].format( @@ -579,8 +564,7 @@ def emojiSelectorCheck(data: RawReactionActionEvent) -> bool: shared_role_name=sharedRole.name, command_prefix=self.bot.command_prefix, event_name=eventName, - event_general_mention=eventGeneral.mention, - event_role_mention=eventRole.mention + event_general_mention=eventGeneral.mention ) ) admin_message = self.STRINGS['admin_event_category_updated'][1].format( @@ -596,7 +580,7 @@ def emojiSelectorCheck(data: RawReactionActionEvent) -> bool: @commands.command( name="unregister-event-category", - usage="", + usage="unregister-event-category ", help= "Unregister an event category and role so that it can no longer be used with `open-event` and `close-event`, but without deleting the channels." ) @@ -611,26 +595,24 @@ async def admin_cmd_unregister_event_category(self, ctx: Context, *, args: str): if not args: await ctx.send(self.STRINGS['request_event_name']) else: - db = db_gateway() eventName = args.lower() - if not db.get("event_categories", {"guild_id": ctx.guild.id, "event_name": eventName}): - if not (allEvents := db.get("event_categories", params={"guild_id": ctx.guild.id})): + event_category = DBGatewayActions().get(Event_categories, guild_id=ctx.guild.id, event_name=eventName) + if not event_category: + if not (allEvents := DBGatewayActions().list(Event_categories, guild_id=ctx.guild.id)): await ctx.message.reply(self.STRINGS['no_event_categories']) else: await ctx.message.reply( - self.STRINGS['unrecognised_event'].format( - events=", ".join(e["event_name"].title() for e in allEvents) - ) + self.STRINGS['unrecognised_event'].format(events=", ".join(e.event_name.title() for e in allEvents)) ) else: - db.delete("event_categories", {"guild_id": ctx.guild.id, "event_name": eventName}) + DBGatewayActions().delete(event_category) await ctx.message.reply(self.STRINGS['success_event_role_unregister'].format(event_title=eventName.title())) admin_message = self.STRINGS['admin_event_category_unregistered'].format(event_title=eventName.title()) await self.bot.adminLog(ctx.message, {self.STRINGS['admin_event_category_unregistered'][0]: admin_message}) @commands.command( name="delete-event-category", - usage="", + usage="delete-event-category ", help="Delete an event category and its role and channels from the server." ) @commands.has_permissions(administrator=True) @@ -643,22 +625,19 @@ async def admin_cmd_delete_event_category(self, ctx: Context, *, args: str): if not args: await ctx.send(self.STRINGS['request_event_name']) else: - db = db_gateway() eventName = args.lower() - if not (eventData := db.get("event_categories", {"guild_id": ctx.guild.id, "event_name": eventName})): - if not (allEvents := db.get("event_categories", params={"guild_id": ctx.guild.id})): + if not (eventData := DBGatewayActions().get(Event_categories, guild_id=ctx.guild.id, event_name=eventName)): + if not (allEvents := DBGatewayActions().list(Event_categories, guild_id=ctx.guild.id)): await ctx.message.reply(self.STRINGS['no_event_categories']) else: await ctx.message.reply( - self.STRINGS['unrecognised_event'].format( - events=", ".join(e["event_name"].title() for e in allEvents) - ) + self.STRINGS['unrecognised_event'].format(events=", ".join(e.event_name.title() for e in allEvents)) ) else: - signinMenuID = eventData[0]["signin_menu_id"] + signinMenuID = eventData.signin_menu_id eventCategory = self.bot.reactionMenus[signinMenuID].msg.channel.category numChannels = len(eventCategory.channels) - eventRole = ctx.guild.get_role(eventData[0]["role_id"]) + eventRole = ctx.guild.get_role(eventData.role_id) confirmMsg = await ctx.message.reply( self.STRINGS['react_delete_confirm'].format( event_title=eventName.title(), @@ -695,12 +674,13 @@ def confirmCheck(data: RawReactionActionEvent) -> bool: deletionTasks.add(asyncio.create_task(currentCategory.delete(reason=deletionReason))) deletionTasks.add(asyncio.create_task(eventCategory.delete(reason=deletionReason))) await asyncio.wait(deletionTasks) - db.delete("event_categories", {"guild_id": ctx.guild.id, "event_name": eventName}) + event_category = DBGatewayActions().get(Event_categories, guild_id=ctx.guild.id, event_name=eventName) + DBGatewayActions().delete(event_category) await ctx.message.reply(self.STRINGS['success_event_deleted'].format(event_title=eventName.title())) admin_message = self.STRINGS['admin_event_category_deleted'][1].format( event_title=eventName.title(), num_channels=numChannels - ) + (f"\nRole deleted: #{eventData[0]['role_id']!s}" if eventData[0]['role_id'] else "") + ) + (f"\nRole deleted: #{eventData.role_id}" if eventData.role_id else "") await self.bot.adminLog(ctx.message, {self.STRINGS['admin_event_category_deleted'][0]: admin_message}) diff --git a/src/esportsbot/cogs/LogChannelCog.py b/src/esportsbot/cogs/LogChannelCog.py index 3b874fc3..492a9d92 100644 --- a/src/esportsbot/cogs/LogChannelCog.py +++ b/src/esportsbot/cogs/LogChannelCog.py @@ -1,8 +1,8 @@ import toml from discord.ext import commands -from ..db_gateway import db_gateway -from ..base_functions import channel_id_from_mention -from ..base_functions import send_to_log_channel +from esportsbot.db_gateway import DBGatewayActions +from esportsbot.models import Guild_info +from esportsbot.base_functions import channel_id_from_mention, send_to_log_channel class LogChannelCog(commands.Cog): @@ -14,14 +14,12 @@ def __init__(self, bot): @commands.has_permissions(administrator=True) async def setlogchannel(self, ctx, given_channel_id=None): cleaned_channel_id = channel_id_from_mention(given_channel_id) if given_channel_id else ctx.channel.id - log_channel_exists = db_gateway().get('guild_info', params={'guild_id': ctx.author.guild.id}) - if bool(log_channel_exists): - if log_channel_exists[0]['log_channel_id'] != cleaned_channel_id: - db_gateway().update( - 'guild_info', - set_params={'log_channel_id': cleaned_channel_id}, - where_params={'guild_id': ctx.author.guild.id} - ) + guild = DBGatewayActions().get(Guild_info, guild_id=ctx.author.guild.id) + log_channel_exists = guild.log_channel_id is not None + if log_channel_exists: + if guild.log_channel_id != cleaned_channel_id: + guild.log_channel_id = cleaned_channel_id + DBGatewayActions().update(guild) await ctx.channel.send(self.STRINGS["channel_set"].format(channel_id=cleaned_channel_id)) await send_to_log_channel( self, @@ -31,7 +29,8 @@ async def setlogchannel(self, ctx, given_channel_id=None): else: await ctx.channel.send(self.STRINGS["channel_set_already"]) else: - db_gateway().insert('guild_info', params={'guild_id': ctx.author.guild.id, 'log_channel_id': cleaned_channel_id}) + guild.log_channel_id = cleaned_channel_id + DBGatewayActions().update(guild) await ctx.channel.send(self.STRINGS["channel_set"].format(channel_id=cleaned_channel_id)) await send_to_log_channel( self, @@ -42,24 +41,23 @@ async def setlogchannel(self, ctx, given_channel_id=None): @commands.command() @commands.has_permissions(administrator=True) async def getlogchannel(self, ctx): - log_channel_exists = db_gateway().get('guild_info', params={'guild_id': ctx.author.guild.id}) + guild = DBGatewayActions().get(Guild_info, guild_id=ctx.author.guild.id) + log_channel_exists = guild.log_channel_id is not None - if (channel_id := log_channel_exists[0]['log_channel_id']) is not None: - await ctx.channel.send(self.STRINGS["channel_get"].format(channel_id=channel_id)) + if log_channel_exists: + await ctx.channel.send(self.STRINGS["channel_get"].format(channel_id=guild.log_channel_id)) else: await ctx.channel.send(self.STRINGS["channel_get_notfound"]) @commands.command() @commands.has_permissions(administrator=True) async def removelogchannel(self, ctx): - log_channel_exists = db_gateway().get('guild_info', params={'guild_id': ctx.author.guild.id}) + guild = DBGatewayActions().get(Guild_info, guild_id=ctx.author.guild.id) + log_channel_exists = guild.log_channel_id is not None - if log_channel_exists[0]['log_channel_id']: - db_gateway().update( - 'guild_info', - set_params={'log_channel_id': 'NULL'}, - where_params={'guild_id': ctx.author.guild.id} - ) + if log_channel_exists: + guild.log_channel_id = None + DBGatewayActions().update(guild) await ctx.channel.send(self.STRINGS["channel_removed"]) else: await ctx.channel.send(self.STRINGS["channel_get_notfound"]) diff --git a/src/esportsbot/cogs/MenusCog.py b/src/esportsbot/cogs/MenusCog.py index 4fdde171..5c9f5a08 100644 --- a/src/esportsbot/cogs/MenusCog.py +++ b/src/esportsbot/cogs/MenusCog.py @@ -1,10 +1,11 @@ from discord.ext import commands from discord.ext.commands.context import Context from discord import Embed -from ..db_gateway import db_gateway -from .. import lib -from ..lib.client import EsportsBot -from ..reactionMenus import reactionRoleMenu, reactionPollMenu +from esportsbot.db_gateway import DBGatewayActions +from esportsbot.models import Guild_info +from esportsbot import lib +from esportsbot.lib.client import EsportsBot +from esportsbot.reactionMenus import reactionRoleMenu, reactionPollMenu from datetime import timedelta # Maximum number of polls which can be running at once in a given guild, for performance @@ -350,7 +351,8 @@ async def cmd_poll(self, ctx: Context, *, args: str): :param Context ctx: A context summarising the message which called this command :param str args: a string containing the poll configuration as defined in this method's docstring """ - currentPollsNum = db_gateway().get('guild_info', params={'guild_id': ctx.author.guild.id})[0]['num_running_polls'] - 1 + guild = DBGatewayActions().get(Guild_info, guild_id=ctx.author.guild.id) + currentPollsNum = (guild.num_running_polls) - 1 if currentPollsNum >= MAX_POLLS_PER_GUILD: await ctx.message.reply("This server already has " + str(currentPollsNum) \ + " polls running! Please wait for one to finish before starting another.") @@ -447,14 +449,16 @@ async def cmd_poll(self, ctx: Context, *, args: str): + lib.timeUtil.td_format_noYM(timeoutTD) + ".") # Update guild polls counter - runningPolls = db_gateway().get("guild_info", {"guild_id": ctx.guild.id})[0]["num_running_polls"] - db_gateway().update("guild_info", {"num_running_polls": runningPolls + 1}, {"guild_id": ctx.guild.id}) + guild = DBGatewayActions().get(Guild_info, guild_id=ctx.guild.id) + guild.num_running_polls += 1 + DBGatewayActions().update(guild) await menu.doMenu() # Allow the creation of another poll - runningPolls = db_gateway().get("guild_info", {"guild_id": ctx.guild.id})[0]["num_running_polls"] - db_gateway().update("guild_info", {"num_running_polls": runningPolls - 1}, {"guild_id": ctx.guild.id}) + guild = DBGatewayActions().get(Guild_info, guild_id=ctx.guild.id) + guild.num_running_polls -= 1 + DBGatewayActions().update(guild) await reactionPollMenu.showPollResults(menu) diff --git a/src/esportsbot/cogs/MusicCog.py b/src/esportsbot/cogs/MusicCog.py index c27d2092..2069033c 100644 --- a/src/esportsbot/cogs/MusicCog.py +++ b/src/esportsbot/cogs/MusicCog.py @@ -14,17 +14,18 @@ from discord.ext import commands, tasks from discord.ext.commands import Context -from ..base_functions import channel_id_from_mention -from ..db_gateway import db_gateway -from ..lib.client import EsportsBot +from esportsbot.base_functions import channel_id_from_mention +from esportsbot.db_gateway import DBGatewayActions +from esportsbot.models import Music_channels +from esportsbot.lib.client import EsportsBot import googleapiclient.discovery from urllib.parse import parse_qs, urlparse from random import shuffle -from ..lib.discordUtil import send_timed_message -from ..lib.stringTyping import strIsInt +from esportsbot.lib.discordUtil import send_timed_message +from esportsbot.lib.stringTyping import strIsInt class EmbedColours: @@ -80,7 +81,7 @@ def __init__(self, bot: EsportsBot, max_search_results=100): self.__check_loops_alive() - self.__db_accessor = db_gateway() + self.__db_accessor = DBGatewayActions() self.user_strings: dict = bot.STRINGS["music"] @@ -132,24 +133,16 @@ async def setmusicchannel(self, ctx: Context, args: str = None, given_channel_id await send_timed_message(ctx.channel, embed=message, timer=30) return False - current_channel_for_guild = self.__db_accessor.get('music_channels', params={'guild_id': ctx.guild.id}) + current_channel_for_guild = self.__db_accessor.get(Music_channels, guild_id=ctx.guild.id) - if len(current_channel_for_guild) > 0: + if current_channel_for_guild: # There is already a channel set.. update - self.__db_accessor.update( - 'music_channels', - set_params={'channel_id': cleaned_channel_id}, - where_params={'guild_id': ctx.guild.id} - ) + music_channel = self.__db_accessor.get(Music_channels, guild_id=ctx.guild.id) + music_channel.channel_id = cleaned_channel_id + self.__db_accessor.update(music_channel) else: # No channel for guild.. insert - self.__db_accessor.insert( - 'music_channels', - params={ - 'channel_id': int(cleaned_channel_id), - 'guild_id': int(ctx.guild.id) - } - ) + self.__db_accessor.create(Music_channels(guild_id=ctx.guild.id, channel_id=cleaned_channel_id)) await self.__setup_channel(ctx, int(cleaned_channel_id), args) self._bot.update_music_channels() @@ -166,11 +159,11 @@ async def getmusicchannel(self, ctx: Context) -> Message: :rtype: discord.Message """ - current_channel_for_guild = self.__db_accessor.get('music_channels', params={'guild_id': ctx.guild.id}) + current_channel_for_guild = self.__db_accessor.get(Music_channels, guild_id=ctx.guild.id) - if current_channel_for_guild and current_channel_for_guild[0].get('channel_id'): + if current_channel_for_guild and current_channel_for_guild.channel_id: # If the music channel has been set in the guild - id_as_channel = ctx.guild.get_channel(current_channel_for_guild[0].get('channel_id')) + id_as_channel = ctx.guild.get_channel(current_channel_for_guild.channel_id) message = self.user_strings["music_channel_get"].format(music_channel=id_as_channel.mention) return await ctx.channel.send(message) else: @@ -188,15 +181,15 @@ async def resetmusicchannel(self, ctx: Context) -> Message: :rtype: discord.Message """ - current_channel_for_guild = self.__db_accessor.get('music_channels', params={'guild_id': ctx.guild.id}) + current_channel_for_guild = self.__db_accessor.get(Music_channels, guild_id=ctx.guild.id) - if current_channel_for_guild and current_channel_for_guild[0].get('channel_id'): + if current_channel_for_guild and current_channel_for_guild.channel_id: # If the music channel has been set for the guild - channel_id = current_channel_for_guild[0].get('channel_id') + channel_id = current_channel_for_guild.channel_id await self.__setup_channel(ctx, arg='-c', channel_id=channel_id) - channel = self._bot.get_channel(current_channel_for_guild[0].get('channel_id')) + channel = self._bot.get_channel(channel_id) if channel is None: - channel = self._bot.fetch_channel(current_channel_for_guild[0].get('channel_id')) + channel = self._bot.fetch_channel(channel_id) message = self.user_strings["music_channel_reset"].format(music_channel=channel.mention) return await ctx.channel.send(message) else: @@ -502,8 +495,8 @@ async def listqueue(self, ctx: Context) -> str: return "" # We don't want the song channel to be filled with the queue as it already shows it - music_channel_in_db = self.__db_accessor.get('music_channels', params={'guild_id': ctx.guild.id}) - if ctx.message.channel.id == music_channel_in_db[0].get('channel_id'): + music_channel_in_db = self.__db_accessor.get(Music_channels, guild_id=ctx.guild.id) + if ctx.message.channel.id == music_channel_in_db.channel_id: # Message is in the songs channel message_title = self.user_strings["music_channel_wrong_channel"].format(command_option="cannot") await send_timed_message( @@ -679,17 +672,10 @@ async def __setup_channel(self, ctx: Context, channel_id: int, arg: str): default_queue_message = await channel_instance.send(EMPTY_QUEUE_MESSAGE) default_preview_message = await channel_instance.send(embed=temp_default_preview) - self.__db_accessor.update( - 'music_channels', - set_params={'queue_message_id': int(default_queue_message.id)}, - where_params={'guild_id': ctx.author.guild.id} - ) - - self.__db_accessor.update( - 'music_channels', - set_params={'preview_message_id': int(default_preview_message.id)}, - where_params={'guild_id': ctx.author.guild.id} - ) + music_channel = self.__db_accessor.get(Music_channels, guild_id=ctx.author.guild.id) + music_channel.queue_message_id = default_queue_message.id + music_channel.preview_message_id = default_preview_message.id + self.__db_accessor.update(music_channel) async def __remove_active_channel(self, guild_id: int) -> bool: """ @@ -964,17 +950,17 @@ async def __update_channel_messages(self, guild_id: int): :rtype: NoneType """ - guild_db_data = self.__db_accessor.get('music_channels', params={'guild_id': guild_id})[0] + guild_db_data = self.__db_accessor.get(Music_channels, guild_id=guild_id) # Get the ids of the queue and preview messages - queue_message_id = guild_db_data.get('queue_message_id') - preview_message_id = guild_db_data.get('preview_message_id') + queue_message_id = guild_db_data.queue_message_id + preview_message_id = guild_db_data.preview_message_id # Create the updated messages queue_message = self.__make_updated_queue_message(guild_id) preview_message = self.__make_update_preview_message(guild_id) - music_channel_id = guild_db_data.get('channel_id') + music_channel_id = guild_db_data.channel_id # Get the music channel id as a discord.TextChannel object music_channel_instance = self._bot.get_channel(music_channel_id) if music_channel_instance is None: @@ -1124,8 +1110,8 @@ async def __check_valid_user_vc(self, ctx: Context) -> bool: :rtype: bool """ - music_channel_in_db = self.__db_accessor.get('music_channels', params={'guild_id': ctx.guild.id}) - if ctx.message.channel.id != music_channel_in_db[0].get('channel_id'): + music_channel_in_db = self.__db_accessor.get(Music_channels, guild_id=ctx.guild.id) + if ctx.message.channel.id != music_channel_in_db.channel_id: # Message is not in the songs channel message_title = self.user_strings["music_channel_wrong_channel"].format(command_option="can only") await send_timed_message( diff --git a/src/esportsbot/cogs/PingablesCog.py b/src/esportsbot/cogs/PingablesCog.py index d0ddd450..42aa4d8c 100644 --- a/src/esportsbot/cogs/PingablesCog.py +++ b/src/esportsbot/cogs/PingablesCog.py @@ -2,11 +2,12 @@ import asyncio from discord.ext import commands from discord.ext.commands.context import Context -from ..db_gateway import db_gateway -from ..lib.client import EsportsBot -from .. import lib +from esportsbot.db_gateway import DBGatewayActions +from esportsbot.models import Pingable_roles, Guild_info +from esportsbot.lib.client import EsportsBot +from esportsbot import lib from datetime import timedelta -from ..reactionMenus import reactionPollMenu +from esportsbot.reactionMenus import reactionPollMenu # The default role colour for pingable roles DEFAULT_PINGABLE_COLOUR = 0x15e012 # green @@ -54,7 +55,11 @@ def __init__(self, bot: "EsportsBot"): """ self.bot: "EsportsBot" = bot - @commands.group(name="pingme", help="Get and create custom, cooldown-limited, pingable roles.", invoke_without_command=True) + @commands.group( + name="pingme", + help="Get and create custom, cooldown-limited, pingable roles.", + invoke_without_command=True + ) async def pingme(self, ctx: Context): """Non-functional command, for heirarchical command grouping. @@ -63,11 +68,7 @@ async def pingme(self, ctx: Context): """ pass - @pingme.command( - name="register", - usage="<@role> ", - help="Convert an existing role into a !pingme role" - ) + @pingme.command(name="register", usage="<@role> ", help="Convert an existing role into a !pingme role") @commands.has_permissions(administrator=True) async def admin_cmd_add_pingable_role(self, ctx: Context, *, args: str): """Admin command: Register an existing role for use with pingme. The role defaults to pingable (i.e not on cooldown) @@ -85,25 +86,23 @@ async def admin_cmd_add_pingable_role(self, ctx: Context, *, args: str): else: roleName = args[len(argsSplit[0]) + 1:].lower() role = ctx.message.role_mentions[0] - db = db_gateway() - if db.get("pingable_roles", {"role_id": role.id}): + if DBGatewayActions().get(Pingable_roles, role_id=role.id): await ctx.message.reply("that role is already pingable!") - elif db.get("pingable_roles", {"name": roleName}): - await ctx.message.reply("A `!pingme` role already exists with the name '" + roleName + "'!") + elif DBGatewayActions().get(Pingable_roles, name=roleName): + await ctx.message.reply(f"A `!pingme` role already exists with the name '{roleName}'!") else: - db.insert( - "pingable_roles", - { - "guild_id": ctx.guild.id, - "role_id": role.id, - "on_cooldown": False, - "last_ping": -1, - "ping_count": 0, - "monthly_ping_count": 0, - "creator_id": ctx.author.id, - "colour": DEFAULT_PINGABLE_COLOUR, - "name": roleName - } + DBGatewayActions().create( + Pingable_roles( + guild_id=ctx.guild.id, + role_id=role.id, + on_cooldown=False, + last_ping=-1, + ping_count=0, + monthly_ping_count=0, + creator_id=ctx.author.id, + colour=DEFAULT_PINGABLE_COLOUR, + name=roleName + ) ) if not role.mentionable: await role.edit(mentionable=True, colour=discord.Colour.green(), reason="setting up new pingable role") @@ -111,7 +110,7 @@ async def admin_cmd_add_pingable_role(self, ctx: Context, *, args: str): await self.bot.adminLog( ctx.message, {"New !pingme Role Registered", - "Name: " + roleName + "\nRole: " + role.mention} + f"Name: {roleName}\nRole: {role.mention}"} ) @pingme.command( @@ -129,14 +128,14 @@ async def admin_cmd_remove_pingable_role(self, ctx: Context): if len(ctx.message.role_mentions) != 1: await ctx.message.reply("Please give one role mention!") else: - db = db_gateway() role = ctx.message.role_mentions[0] - if not db.get("pingable_roles", {"role_id": role.id}): + pingable_role = DBGatewayActions().get(Pingable_roles, role_id=role.id) + if not pingable_role: await ctx.message.reply("that role is not pingable!") else: - db.delete("pingable_roles", {"role_id": role.id}) + DBGatewayActions().delete(pingable_role) await ctx.message.reply("✅ Role successfully unregistered for `!pingme`.") - await self.bot.adminLog(ctx.message, {"!pingme Role Unregistered", "Role: " + role.mention}) + await self.bot.adminLog(ctx.message, {"!pingme Role Unregistered", f"Role: {role.mention}"}) @pingme.command(name="delete", usage="<@role>", help="Delete a !pingme role from the server") @commands.has_permissions(administrator=True) @@ -149,15 +148,15 @@ async def admin_cmd_delete_pingable_role(self, ctx: Context): if len(ctx.message.role_mentions) != 1: await ctx.message.reply("Please give one role mention!") else: - db = db_gateway() role = ctx.message.role_mentions[0] - if not db.get("pingable_roles", {"role_id": role.id}): + pingable_role = DBGatewayActions().get(Pingable_roles, role_id=role.id) + if not pingable_role: await ctx.message.reply("that role is not pingable!") else: - db.delete("pingable_roles", {"role_id": role.id}) + DBGatewayActions().delete(pingable_role) await role.delete(reason="role deletion requested via admin command") await ctx.message.reply("The role as been deleted!") - await self.bot.adminLog(ctx.message, {"!pingme Role Deleted", "Name: " + role.name + "\nID: " + str(role.id)}) + await self.bot.adminLog(ctx.message, {"!pingme Role Deleted", f"Name: {role.name}\nID: {role.id}"}) @pingme.command( name="reset-cooldown", @@ -175,21 +174,22 @@ async def admin_cmd_reset_role_ping_cooldown(self, ctx: Context): await ctx.message.reply("please mention one role") else: role = ctx.message.role_mentions[0] - db = db_gateway() - roleData = db.get("pingable_roles", {"role_id": role.id}) + roleData = DBGatewayActions().get(Pingable_roles, role_id=role.id) if not roleData: await ctx.message.reply("that role is not pingable!") - elif not roleData[0]["on_cooldown"]: + elif not roleData.on_cooldown: await ctx.message.reply("that role is not on cooldown!") else: - db.update("pingable_roles", {"on_cooldown": False}, {"role_id": role.id}) + pingable_role = DBGatewayActions().get(Pingable_roles, role_id=role.id) + pingable_role.on_cooldown = False + DBGatewayActions().update(pingable_role) if not role.mentionable: await role.edit( mentionable=True, colour=discord.Colour.green(), - reason="manual cooldown reset by user " + str(ctx.author.name) + "#" + str(ctx.author.id) + reason=f"manual cooldown reset by user {ctx.author.name}#{ctx.author.id}" ) - await ctx.message.reply("The " + role.name + " role is now pingable again!") + await ctx.message.reply(f"The {role.name} role is now pingable again!") await self.bot.adminLog(ctx.message, {"Ping Cooldown Manually Reset For !pingme Role": role.mention}) @pingme.command( @@ -222,22 +222,20 @@ async def admin_cmd_set_role_ping_cooldown(self, ctx: Context, *, args: str): for timeName in ["days", "hours", "minutes", "seconds"]: if timeName in kwArgs: if not lib.stringTyping.strIsInt(kwArgs[timeName]) or int(kwArgs[timeName]) < 1: - await ctx.message.reply(":x: Invalid number of " + timeName + "!") + await ctx.message.reply(f":x: Invalid number of {timeName}!") return timeoutDict[timeName] = int(kwArgs[timeName]) timeoutTD = lib.timeUtil.timeDeltaFromDict(timeoutDict) if timeoutTD > MAX_ROLE_PING_TIMEOUT: - await ctx.message.reply(":x: The maximum ping cooldown is " + lib.timeUtil.td_format_noYM(MAX_ROLE_PING_TIMEOUT)) + await ctx.message.reply(f":x: The maximum ping cooldown is {lib.timeUtil.td_format_noYM(MAX_ROLE_PING_TIMEOUT)}") return - db_gateway().update( - "guild_info", - {"role_ping_cooldown_seconds": int(timeoutTD.total_seconds())}, - {"guild_id": ctx.guild.id} - ) - await ctx.message.reply("Cooldown for !pingme roles now updated to " + lib.timeUtil.td_format_noYM(timeoutTD) + "!") + guild = DBGatewayActions().get(Guild_info, guild_id=ctx.guild.id) + guild.role_ping_cooldown_seconds = int(timeoutTD.total_seconds()) + DBGatewayActions().update(guild) + await ctx.message.reply(f"Cooldown for !pingme roles now updated to {lib.timeUtil.td_format_noYM(timeoutTD)}!") await self.bot.adminLog( ctx.message, {"Cooldown For !pingme Role Pings Updated": lib.timeUtil.td_format_noYM(timeoutTD)} @@ -261,14 +259,16 @@ async def admin_cmd_set_pingme_create_threshold(self, ctx: Context, *, args: str await ctx.message.reply(":x: Invalid threshold! It must be a number.") elif int(args) < 1 or int(args) > MAX_PINGME_CREATE_THRESHOLD: await ctx.message.reply( - ":x: Invalid threshold! It must be between 1 and " + str(MAX_PINGME_CREATE_THRESHOLD) + ", inclusive." + f":x: Invalid threshold! It must be between 1 and {MAX_PINGME_CREATE_THRESHOLD}, inclusive." ) else: - db_gateway().update("guild_info", {"pingme_create_threshold": int(args)}, {"guild_id": ctx.guild.id}) - await ctx.message.reply("✅ Minimum votes for `!pingme create` successfully updated to " + args + " users.") + guild = DBGatewayActions().get(Guild_info, guild_id=ctx.guild.id) + guild.pingme_create_threshold = int(args) + DBGatewayActions().update(guild) + await ctx.message.reply(f"✅ Minimum votes for `!pingme create` successfully updated to {args} users.") await self.bot.adminLog( ctx.message, - {"Votes Required For !pingme create Updated": "Minimum votes for new roles: " + args} + {f"Votes Required For !pingme create Updated": "Minimum votes for new roles: {args}"} ) @pingme.command( @@ -301,7 +301,7 @@ async def admin_cmd_set_pingme_create_poll_length(self, ctx: Context, *, args: s for timeName in ["days", "hours", "minutes", "seconds"]: if timeName in kwArgs: if not lib.stringTyping.strIsInt(kwArgs[timeName]) or int(kwArgs[timeName]) < 1: - await ctx.message.reply(":x: Invalid number of " + timeName + "!") + await ctx.message.reply(f":x: Invalid number of {timeName}!") return timeoutDict[timeName] = int(kwArgs[timeName]) @@ -309,17 +309,15 @@ async def admin_cmd_set_pingme_create_poll_length(self, ctx: Context, *, args: s timeoutTD = lib.timeUtil.timeDeltaFromDict(timeoutDict) if timeoutTD > MAX_PINGME_CREATE_POLL_LENGTH: await ctx.message.reply( - ":x: The maximum `!pingme create` poll length is " + lib.timeUtil.td_format_noYM(MAX_ROLE_PING_TIMEOUT) + f":x: The maximum `!pingme create` poll length is {lib.timeUtil.td_format_noYM(MAX_ROLE_PING_TIMEOUT)}" ) return - db_gateway().update( - "guild_info", - {"pingme_create_poll_length_seconds": int(timeoutTD.total_seconds())}, - {"guild_id": ctx.guild.id} - ) + guild = DBGatewayActions().get(Guild_info, guild_id=ctx.guild.id) + guild.pingme_create_poll_length_seconds = int(timeoutTD.total_seconds()) + DBGatewayActions().update(guild) await ctx.message.reply( - "✅ Poll length for `!pingme create` successfully updated to " + lib.timeUtil.td_format_noYM(timeoutTD) + "." + f"✅ Poll length for `!pingme create` successfully updated to {lib.timeUtil.td_format_noYM(timeoutTD)}." ) await self.bot.adminLog( ctx.message, @@ -344,20 +342,21 @@ async def admin_cmd_set_pingme_role_emoji(self, ctx: Context, *, args: str): elif not lib.emotes.strIsUnicodeEmoji(args): await ctx.message.reply(":x: Invalid emoji! Make sure it's a built in one rather than custom.") else: - db = db_gateway() - db.update("guild_info", {"pingme_role_emoji": args}, {"guild_id": ctx.guild.id}) - rolesData = db.get("pingable_roles", {"guild_id": ctx.guild.id}) + guild = DBGatewayActions().get(Guild_info, guild_id=ctx.guild.id) + guild.pingme_role_emoji = args + DBGatewayActions().update(guild) + rolesData = DBGatewayActions().list(Pingable_roles, guild_id=ctx.guild.id) if rolesData: - progressMsg = await ctx.send("Renaming " + str(len(rolesData)) + " roles... ⏳") + progressMsg = await ctx.send(f"Renaming {len(rolesData)} roles... ⏳") renamerTasks = set() for roleData in rolesData: - renamerTasks.add(changePingablePrefix(args, ctx.guild, roleData["role_id"], roleData["name"])) + renamerTasks.add(changePingablePrefix(args, ctx.guild, roleData.role_id, roleData.name)) await asyncio.wait(renamerTasks) - await progressMsg.edit(content="Renaming " + str(len(rolesData)) + " roles... ✅") + await progressMsg.edit(content=f"Renaming {len(rolesData)} roles... ✅") - await ctx.message.reply("Emoji prefix for `!pingme create` roles now updated to " + args + "!") - await self.bot.adminLog(ctx.message, {"Emoji Prefix For !pingme roles Updated": "New emoji: " + args}) + await ctx.message.reply(f"Emoji prefix for `!pingme create` roles now updated to {args}!") + await self.bot.adminLog(ctx.message, {"Emoji Prefix For !pingme roles Updated": f"New emoji: {args}"}) @pingme.command( name="remove-role-emoji", @@ -372,30 +371,26 @@ async def admin_cmd_remove_pingme_role_emoji(self, ctx: Context): :param Context ctx: A context summarising the message which called this command :param str args: ignored """ - db = db_gateway() - guildData = db.get("guild_info", {"guild_id": ctx.guild.id}) - if guildData["pingme_role_emoji"] is None: + guildData = DBGatewayActions().get(Guild_info, guild_id=ctx.guild.id) + if guildData.pingme_role_emoji is None: await ctx.message.reply(":x: There is no `!pingme` role emoji set!") else: - db.update("guild_info", {"pingme_role_emoji": None}, {"guild_id": ctx.guild.id}) - rolesData = db.get("pingable_roles", {"guild_id": ctx.guild.id}) + guildData.pingme_role_emoji = None + DBGatewayActions().update(guildData) + rolesData = DBGatewayActions().list(Pingable_roles, guild_id=ctx.guild.id) if rolesData: - progressMsg = await ctx.send("Renaming " + str(len(rolesData)) + " roles... ⏳") + progressMsg = await ctx.send(f"Renaming {len(rolesData)} roles... ⏳") renamerTasks = set() for roleData in rolesData: - renamerTasks.add(changePingablePrefix("", ctx.guild, roleData["role_id"], roleData["name"])) + renamerTasks.add(changePingablePrefix("", ctx.guild, roleData.role_id, roleData.name)) await asyncio.wait(renamerTasks) - await progressMsg.edit(content="Renaming " + str(len(rolesData)) + " roles... ✅") + await progressMsg.edit(content=f"Renaming {len(rolesData)} roles... ✅") await ctx.message.reply("Emoji prefix for `!pingme create` roles has been removed!") await self.bot.adminLog(ctx.message, {"Emoji Prefix For !pingme roles Removed": "‎"}) - @pingme.command( - name="create", - usage="", - help="Start a poll for the creation of a new !pingme role" - ) + @pingme.command(name="create", usage="", help="Start a poll for the creation of a new !pingme role") async def pingme_create(self, ctx: Context, *, args: str): """User command: Trigger a poll for the creation of a new pingme role with the given name. If the guild's configured minimum number of votes is reached, then the role will be created automatically. If the poll @@ -408,54 +403,47 @@ async def pingme_create(self, ctx: Context, *, args: str): if not args: await ctx.message.reply(":x: Please give the name of your new role!") else: - db = db_gateway() - roleData = db.get("pingable_roles", {"name": args.lower()}) - if roleData and roleData[0]["guild_id"] == ctx.guild.id: + roleData = DBGatewayActions().get(Pingable_roles, name=args.lower()) + if roleData and roleData.guild_id == ctx.guild.id: await ctx.message.reply(":x: A `!pingme` role already exists with that name!") else: pollMsg = await ctx.send("‎") - guildData = db.get("guild_info", {"guild_id": ctx.guild.id})[0] - requiredVotes = guildData["pingme_create_threshold"] + guildData = DBGatewayActions().get(Guild_info, guild_id=ctx.guild.id) + requiredVotes = guildData.pingme_create_threshold rolePoll = reactionPollMenu.InlineSingleOptionPollMenu( pollMsg, - guildData["pingme_create_poll_length_seconds"], + guildData.pingme_create_poll_length_seconds, requiredVotes, pollStarter=ctx.author, - authorName=ctx.author.display_name + " wants to make a new !pingme role!", - desc="Name: " + args + "\nRequired votes: " + str(requiredVotes) - + "\n\nReact if you want the role to be created!", - footerTxt="This menu will expire in " - + lib.timeUtil.td_format_noYM(timedelta(seconds=guildData["pingme_create_poll_length_seconds"])) + "." + authorName=f"{ctx.author.display_name} wants to make a new !pingme role!", + desc=f"Name: {args}\nRequired votes: {requiredVotes}\n\nReact if you want the role to be created!", + footerTxt= + f"This menu will expire in {lib.timeUtil.td_format_noYM(timedelta(seconds=guildData.pingme_create_poll_length_seconds))}." ) await rolePoll.doMenu() if rolePoll.yesesReceived >= requiredVotes: - roleName = (guildData["pingme_role_emoji"] - + args.title()) if guildData["pingme_role_emoji"] else args.title() + roleName = (guildData.pingme_role_emoji + args.title()) if guildData.pingme_role_emoji else args.title() newRole = await ctx.guild.create_role( name=roleName, colour=DEFAULT_PINGABLE_COLOUR, mentionable=True, reason="New !pingme role creation requested via poll" ) - db.insert( - "pingable_roles", - { - "name": args.lower(), - "guild_id": ctx.guild.id, - "role_id": newRole.id, - "on_cooldown": False, - "last_ping": -1, - "ping_count": 0, - "monthly_ping_count": 0, - "creator_id": ctx.author.id, - "colour": DEFAULT_PINGABLE_COLOUR - } + DBGatewayActions().create( + Pingable_roles( + name=args.lower(), + guild_id=ctx.guild.id, + role_id=newRole.id, + on_cooldown=False, + last_ping=-1, + ping_count=0, + monthly_ping_count=0, + creator_id=ctx.author.id, + colour=DEFAULT_PINGABLE_COLOUR + ) ) await ctx.message.reply("✅ The role has been created! Get it with `!pingme for " + args.lower() + "`") - await self.bot.adminLog( - pollMsg, - {"New !pingme Role Created": "Role: " + newRole.mention + "\nName: " + args} - ) + await self.bot.adminLog(pollMsg, {"New !pingme Role Created": f"Role: {newRole.mention}\nName: {args}"}) else: await pollMsg.reply( ctx.author.mention + " The role has not been created, as the poll did not receive enough votes." @@ -477,19 +465,19 @@ async def pingme_for(self, ctx: Context, *, args: str): if not args: await ctx.message.reply(":x: Please give the name of the role you would like!") else: - roleData = db_gateway().get("pingable_roles", {"name": args.lower()}) - if not roleData or roleData[0]["guild_id"] != ctx.guild.id: + roleData = DBGatewayActions().get(Pingable_roles, name=args.lower()) + if not roleData or roleData.guild_id != ctx.guild.id: await ctx.message.reply(":x: Unrecognised role name!") else: - role = ctx.guild.get_role(roleData[0]["role_id"]) + role = ctx.guild.get_role(roleData.role_id) if role is None: await ctx.message.reply(":x: I couldn't find the role! Please contact an administrator.") elif role in ctx.author.roles: await ctx.author.remove_roles(role, reason="User unsubscribed from !pingme role via command") - await ctx.message.reply("✅ You removed the " + role.name + " role!") + await ctx.message.reply(f"✅ You removed the {role.name} role!") else: await ctx.author.add_roles(role, reason="User subscribed to !pingme role via command") - await ctx.message.reply("✅ You got the " + role.name + " role!") + await ctx.message.reply(f"✅ You got the {role.name} role!") @pingme.command(name="list", help="List all available `!pingme` roles") async def pingme_list(self, ctx: Context): @@ -499,7 +487,7 @@ async def pingme_list(self, ctx: Context): :param Context ctx: A context summarising the message which called this command :param str args: ignored """ - allRolesData = db_gateway().get("pingable_roles", {"guild_id": ctx.guild.id}) + allRolesData = DBGatewayActions().list(Pingable_roles, guild_id=ctx.guild.id) if not allRolesData: await ctx.message.reply( f":x: This guild has no `!pingme` roles! Make a new one with `{self.bot.command_prefix}pingme create`." @@ -508,19 +496,18 @@ async def pingme_list(self, ctx: Context): reportEmbed = discord.Embed(title="All !pingme Roles", desc=ctx.guild.name) reportEmbed.colour = discord.Colour.random() reportEmbed.set_thumbnail(url=self.bot.user.avatar_url_as(size=128)) - for roleData in sorted(allRolesData, key=lambda x: x["ping_count"], reverse=True): - if r := ctx.guild.get(roleData["role_id"]): + for roleData in sorted(allRolesData, key=lambda x: x.ping_count, reverse=True): + if r := ctx.guild.get(roleData.role_id): reportEmbed.add_field( - name=roleData["name"].title(), - value="<@&" + str(roleData["role_id"]) + ">\nCreated by: <@" + str(roleData["creator_id"]) - + ">\nTotal pings: " + str(roleData["ping_count"]) + "\nTotal members: " + str(len(r.members)) + name=roleData.name.title(), + value= + f"<@&{roleData.role_id}>\nCreated by: <@{roleData.creator_id}>\nTotal pings: {roleData.ping_count}\nTotal members: {len(r.members)}" ) else: - await self.bot.adminLog(ctx.message, - { - "Unknown !Pingme Role", - f"Failed to find the '{roleData['name'].title()}' role. Was it deleted?" - } + await self.bot.adminLog( + ctx.message, + {"Unknown !Pingme Role", + f"Failed to find the '{roleData.name.title()}' role. Was it deleted?"} ) await ctx.reply(embed=reportEmbed) @@ -531,14 +518,13 @@ async def pingme_clear(self, ctx: Context): :param Context ctx: A context summarising the message which called this command :param str args: ignored """ - db = db_gateway() rolesToRemove = [] for role in ctx.author.roles: - if db.get("pingable_roles", {"role_id": role.id}): + if DBGatewayActions().get(Pingable_roles, role_id=role.id): rolesToRemove.append(role) if rolesToRemove: await ctx.author.remove_roles(*rolesToRemove, reason="User unsubscribed from !pingme role via command") - await ctx.message.reply("✅ You unsubscribed from " + str(len(rolesToRemove)) + " roles!") + await ctx.message.reply(f"✅ You unsubscribed from {len(rolesToRemove)} roles!") else: await ctx.message.reply(":x: You are not subsribed to any `!pingme` roles!") diff --git a/src/esportsbot/cogs/TwitchIntegrationCog.py b/src/esportsbot/cogs/TwitchIntegrationCog.py deleted file mode 100644 index 6be87d57..00000000 --- a/src/esportsbot/cogs/TwitchIntegrationCog.py +++ /dev/null @@ -1,313 +0,0 @@ -from discord.ext import commands, tasks -from ..db_gateway import db_gateway -from ..base_functions import channel_id_from_mention -import requests -import aiohttp -import asyncio -import time -import os - - -class TwitchIntegrationCog(commands.Cog): - def __init__(self, bot): - self.bot = bot - self.twitch_handler = TwitchAPIHandler() - self.live_checker.start() - - def cog_unload(self): - self.live_checker.cancel() - - @commands.command() - async def addtwitch(self, ctx, twitch_handle=None, announce_channel=None): - if twitch_handle is not None and announce_channel is not None: - # Check if Twitch channel has already been added - twitch_in_db = db_gateway().get( - 'twitch_info', - params={ - 'guild_id': ctx.author.guild.id, - 'twitch_handle': twitch_handle.lower() - } - ) - cleaned_channel_id = channel_id_from_mention(announce_channel) - channel_mention = "<#" + str(cleaned_channel_id) + ">" - if not twitch_in_db: - # Check user exists - user_exists = bool(self.twitch_handler.request_user(twitch_handle)) - if user_exists: - # Get live status of the channel - live_status = bool(self.twitch_handler.request_data([twitch_handle])) - # Insert Twitch channel into DB - db_gateway().insert( - 'twitch_info', - params={ - 'guild_id': ctx.author.guild.id, - 'channel_id': cleaned_channel_id, - 'twitch_handle': twitch_handle.lower(), - 'currently_live': live_status - } - ) - await ctx.channel.send( - f"{twitch_handle} is valid and has been added, their notifications will be placed in {channel_mention}" - ) - else: - await ctx.channel.send(f"{twitch_handle} is not a valid Twitch handle") - else: - await ctx.channel.send(f"{twitch_handle} is already configured to {channel_mention}") - else: - await ctx.channel.send("You need to provide a Twitch handle and a channel") - - @commands.command() - async def addcustomtwitch(self, ctx, twitch_handle=None, announce_channel=None, custom_message=None): - if None not in (twitch_handle, announce_channel, custom_message): - # Check if Twitch channel has already been added - twitch_in_db = db_gateway().get( - 'twitch_info', - params={ - 'guild_id': ctx.author.guild.id, - 'twitch_handle': twitch_handle.lower() - } - ) - cleaned_channel_id = channel_id_from_mention(announce_channel) - channel_mention = "<#" + str(cleaned_channel_id) + ">" - if not twitch_in_db: - # Check user exists - user_exists = bool(self.twitch_handler.request_user(twitch_handle)) - if user_exists: - # Get live status of the channel - live_status = bool(self.twitch_handler.request_data([twitch_handle])) - # Insert Twitch channel into DB - db_gateway().insert( - 'twitch_info', - params={ - 'guild_id': ctx.author.guild.id, - 'channel_id': cleaned_channel_id, - 'twitch_handle': twitch_handle.lower(), - 'currently_live': live_status, - 'custom_message': custom_message - } - ) - await ctx.channel.send( - f"{twitch_handle} is valid and has been added, their notifications will be placed in {channel_mention}" - ) - sample_message = custom_message.format( - handle="TwitchHandle", - game="Game/Genre", - link="StreamLink", - title="Title" - ) - await ctx.channel.send(f"Sample custom message below\n {sample_message}") - else: - await ctx.channel.send(f"{twitch_handle} is not a valid Twitch handle") - else: - await ctx.channel.send(f"{twitch_handle} is already configured to {channel_mention}") - else: - await ctx.channel.send("You need to provide a Twitch handle, text channel and custom message") - - @commands.command() - async def editcustomtwitch(self, ctx, twitch_handle=None, custom_message=None): - if twitch_handle is not None and custom_message is not None: - # Check if Twitch channel has already been added - twitch_in_db = db_gateway().get( - 'twitch_info', - params={ - 'guild_id': ctx.author.guild.id, - 'twitch_handle': twitch_handle.lower() - } - ) - if twitch_in_db: - # Make DB edit - db_gateway().update( - 'twitch_info', - set_params={'custom_message': custom_message}, - where_params={ - 'guild_id': ctx.author.guild.id, - 'twitch_handle': twitch_handle.lower() - } - ) - sample_message = custom_message.format( - handle="TwitchHandle", - game="Game/Genre", - link="StreamLink", - title="Title" - ) - await ctx.channel.send(f"Sample custom message below\n {sample_message}") - else: - await ctx.channel.send("That Twitch handle is not configured in this server") - else: - await ctx.channel.send("You need to provide a Twitch handle, text channel and custom message") - - @commands.command() - async def edittwitch(self, ctx, twitch_handle=None, announce_channel=None): - if twitch_handle is not None and announce_channel is not None: - # Check if Twitch channel has already been added - twitch_in_db = db_gateway().get( - 'twitch_info', - params={ - 'guild_id': ctx.author.guild.id, - 'twitch_handle': twitch_handle.lower() - } - ) - cleaned_channel_id = channel_id_from_mention(announce_channel) - channel_mention = "<#" + str(cleaned_channel_id) + ">" - if twitch_in_db: - # Make DB edit - db_gateway().update( - 'twitch_info', - set_params={'channel_id': cleaned_channel_id}, - where_params={ - 'guild_id': ctx.author.guild.id, - 'twitch_handle': twitch_handle.lower() - } - ) - await ctx.channel.send(f"Changed the alerts for {twitch_handle} to {channel_mention}") - else: - await ctx.channel.send("The Twitch user mentioned is not configured in this server") - else: - await ctx.channel.send("You need to provide a Twitch handle and a channel") - - @commands.command() - async def removetwitch(self, ctx, twitch_handle=None): - if twitch_handle is not None: - # Entered a Twitter handle - twitch_handle = twitch_handle.lower() - handle_exists = db_gateway().get( - 'twitch_info', - params={ - 'guild_id': ctx.author.guild.id, - 'twitch_handle': twitch_handle.lower() - } - ) - if handle_exists: - # Handle exists - db_gateway().delete( - 'twitch_info', - where_params={ - 'guild_id': ctx.author.guild.id, - 'twitch_handle': twitch_handle.lower() - } - ) - await ctx.channel.send(f"Alerts for {twitch_handle} have been removed from this server") - else: - await ctx.channel.send("Entered Twitch handle is not configured in this server") - else: - await ctx.channel.send("You need to provide a Twitch handle") - - @commands.command() - async def removealltwitch(self, ctx): - db_gateway().delete('twitch_info', where_params={'guild_id': ctx.author.guild.id}) - await ctx.channel.send("Removed all Twitch alerts from this server") - - @commands.command() - async def getalltwitch(self, ctx): - returned_val = db_gateway().get('twitch_info', params={'guild_id': ctx.author.guild.id}) - all_handles = "** **\n__**Twitch Alerts**__\n" - for each in returned_val: - channel_mention = "<#" + str(each['channel_id']) + ">" - all_handles += f"{each['twitch_handle']} is set to alert in {channel_mention}\n" - await ctx.channel.send(all_handles) - - @tasks.loop(seconds=50) - async def live_checker(self): - print('TWITCH: Retrieving current statuses') - time_taken = await self.get_and_compare_statuses(True) - print(f'TWITCH: Retrieved current statuses in {time_taken}s') - - @live_checker.before_loop - async def before_live_checker(self): - print('TWITCH: Waiting until bot is ready') - await self.bot.wait_until_ready() - print('TWITCH: Updating current statuses') - time_taken = await self.get_and_compare_statuses(False) - print(f'TWITCH: Updated current statuses in {time_taken}s') - - async def get_and_compare_statuses(self, alert): - start_time = time.time() - all_twitch_handles = db_gateway().pure_return('SELECT DISTINCT twitch_handle FROM "twitch_info"') - if all_twitch_handles: - # Create list of all twitch handles in the database - twitch_handle_arr = list(x['twitch_handle'] for x in all_twitch_handles) - # Create dict consisting of twitch handles and live statuses - twitch_status_dict = dict() - all_twitch_statuses = db_gateway().pure_return('SELECT DISTINCT twitch_handle, currently_live FROM "twitch_info"') - for twitch_user in all_twitch_statuses: - twitch_status_dict[twitch_user['twitch_handle']] = twitch_user['currently_live'] - # Query Twitch to receive array of all live users - returned_data = self.twitch_handler.request_data(twitch_handle_arr) - # Loop through all users comparing them to the live list - for twitch_handle in twitch_handle_arr: - # if any(obj['user_name'].lower() == twitch_handle for obj in returned_data): - handle_live = (next((obj for obj in returned_data if obj['user_name'].lower() == twitch_handle), False)) - print(handle_live) - if handle_live: - # User is live - if not twitch_status_dict[f'{twitch_handle}']: - # User was not live before but now is - db_gateway().update( - 'twitch_info', - set_params={'currently_live': True}, - where_params={'twitch_handle': twitch_handle.lower()} - ) - if alert: - # Grab all channels to be alerted - all_channels = db_gateway().get('twitch_info', params={'twitch_handle': twitch_handle.lower()}) - for each in all_channels: - # Send alert to specified channel to each['channel_id'] - custom_message = each['custom_message'].format( - handle=handle_live['user_name'], - game=handle_live['game_name'], - link=f"https://twitch.tv/{handle_live['user_name']}", - title=handle_live['title'] - ) if each[ - 'custom_message' - ] != '' else f"{handle_live['user_name']} has just gone live with {handle_live['game_name']}, check them out here: https://twitch.tv/{handle_live['user_name']}" - await self.bot.get_channel(each['channel_id']).send(custom_message) - else: - # User is not live - db_gateway().update( - 'twitch_info', - set_params={'currently_live': False}, - where_params={'twitch_handle': twitch_handle.lower()} - ) - return round(time.time() - start_time, 3) - - -class TwitchAPIHandler: - def __init__(self): - self.client_id = os.getenv('TWITCH_CLIENT_ID') - self.client_secret = os.getenv('TWITCH_CLIENT_SECRET') - self.params = {'client_id': self.client_id, 'client_secret': self.client_secret, 'grant_type': 'client_credentials'} - self.token = None - - def base_headers(self): - return {'Authorization': f'Bearer {self.token.get("access_token")}', 'Client-ID': self.client_id} - - def generate_new_oauth(self): - OAuthURL = 'https://id.twitch.tv/oauth2/token' - params = {'client_id': self.client_id, 'client_secret': self.client_secret, 'grant_type': 'client_credentials'} - oauth_response = requests.post(OAuthURL, params) - if oauth_response.status_code == 200: - oauth_response_json = oauth_response.json() - oauth_response_json['expires_in'] += time.time() - self.token = oauth_response_json - print("TWITCH: Generated new OAuth token") - return self.token - - def request_data(self, twitch_handles): - if self.token is None or self.token['expires_in'] < time.time(): - self.generate_new_oauth() - data_url = 'https://api.twitch.tv/helix/streams?' - data_url = data_url + "user_login=" + ("&user_login=".join(twitch_handles)) - data_response = requests.get(data_url, headers=self.base_headers(), params=self.params) - return data_response.json()['data'] - - def request_user(self, twitch_handle): - if self.token is None or self.token['expires_in'] < time.time(): - self.generate_new_oauth() - data_url = f'https://api.twitch.tv/helix/users?login={twitch_handle}' - #data_url = data_url+"user_login="+("&user_login=".join(twitch_handles)) - data_response = requests.get(data_url, headers=self.base_headers(), params=self.params) - return data_response.json()['data'] - - -def setup(bot): - bot.add_cog(TwitchIntegrationCog(bot)) diff --git a/src/esportsbot/cogs/TwitterIntegrationCog.py b/src/esportsbot/cogs/TwitterIntegrationCog.py deleted file mode 100644 index 0c1c9074..00000000 --- a/src/esportsbot/cogs/TwitterIntegrationCog.py +++ /dev/null @@ -1,184 +0,0 @@ -from discord.ext import tasks, commands -from ..db_gateway import db_gateway -from ..base_functions import channel_id_from_mention -import snscrape.modules.twitter as sntwitter -import time - - -class TwitterIntegrationCog(commands.Cog): - def __init__(self, bot): - self.bot = bot - self.tweet_checker.start() - - def cog_unload(self): - self.tweet_checker.cancel() - - @commands.command() - @commands.has_permissions(administrator=True) - async def addtwitter(self, ctx, twitter_handle=None, announce_channel=None): - if twitter_handle is not None and announce_channel is not None: - if (twitter_handle.replace('_', '')).isalnum(): - twitter_in_db = db_gateway().get( - 'twitter_info', - params={ - 'guild_id': ctx.author.guild.id, - 'twitter_handle': twitter_handle.lower() - } - ) - if not bool(twitter_in_db): - cleaned_channel_id = channel_id_from_mention(announce_channel) - channel_mention = "<#" + str(cleaned_channel_id) + ">" - previous_tweet_id = self.get_tweets(twitter_handle)[0]['id'] - db_gateway().insert( - 'twitter_info', - params={ - 'guild_id': ctx.author.guild.id, - 'channel_id': cleaned_channel_id, - 'twitter_handle': twitter_handle.lower(), - 'previous_tweet_id': previous_tweet_id - } - ) - await ctx.channel.send( - f"{twitter_handle} is valid and has been added, their Tweets will be placed in {channel_mention}" - ) - else: - await ctx.channel.send( - f"{twitter_handle} is already configured to output to <#{str(twitter_in_db['channel_id'])}>" - ) - else: - await ctx.channel.send("You need to provide a correct Twitter handle") - else: - await ctx.channel.send("You need to provide a Twitter handle and a channel") - - @commands.command() - @commands.has_permissions(administrator=True) - async def removetwitter(self, ctx, twitter_handle=None): - if twitter_handle is not None: - if (twitter_handle.replace('_', '')).isalnum(): - twitter_in_db = db_gateway().get( - 'twitter_info', - params={ - 'guild_id': ctx.author.guild.id, - 'twitter_handle': twitter_handle.lower() - } - ) - if bool(twitter_in_db): - db_gateway().delete( - 'twitter_info', - where_params={ - 'guild_id': ctx.author.guild.id, - 'twitter_handle': twitter_handle.lower() - } - ) - await ctx.channel.send(f"Removed alerts for @{twitter_handle}") - else: - await ctx.channel.send(f"No alerts set for @{twitter_handle}") - else: - await ctx.channel.send("You need to provide a correct Twitter handle") - else: - await ctx.channel.send("You need to provide a Twitter handle") - - @commands.command() - @commands.has_permissions(administrator=True) - async def changetwitterchannel(self, ctx, twitter_handle=None, announce_channel=None): - if twitter_handle is not None and announce_channel is not None: - if (twitter_handle.replace('_', '')).isalnum(): - twitter_in_db = db_gateway().get( - 'twitter_info', - params={ - 'guild_id': ctx.author.guild.id, - 'twitter_handle': twitter_handle.lower() - } - ) - if bool(twitter_in_db): - # In DB - cleaned_channel_id = channel_id_from_mention(announce_channel) - channel_mention = "<#" + str(cleaned_channel_id) + ">" - db_gateway().update( - 'twitter_info', - set_params={'channel_id': cleaned_channel_id}, - where_params={ - 'guild_id': ctx.author.guild.id, - 'twitter_handle': twitter_handle.lower() - } - ) - await ctx.channel.send(f"{twitter_handle} has been updated and will now notify in {channel_mention}") - else: - # Not set up - await ctx.channel.send(f"{twitter_handle} is not configured in this server") - else: - await ctx.channel.send("You need to provide a correct Twitter handle") - else: - await ctx.channel.send("You need to provide a Twitter handle and a channel") - - @commands.command() - @commands.has_permissions(administrator=True) - async def getalltwitters(self, ctx): - all_guild_twitters = db_gateway().get('twitter_info', params={'guild_id': ctx.author.guild.id}) - if all_guild_twitters: - all_twitters_str = str() - for twitter in all_guild_twitters: - all_twitters_str += f"{twitter['twitter_handle']} is set to notify in <#{str(twitter['channel_id'])}>\n" - - await ctx.channel.send(f"Current Twitters set in this server:\n{all_twitters_str}") - else: - await ctx.channel.send("No Twitters have currently been set in this server") - - def get_tweets(self, given_username, tweet_number=1): - # Using TwitterSearchScraper to scrape data and append tweets to list - tweets_list = list() - for index, tweet_data in enumerate(sntwitter.TwitterSearchScraper(f'from:{given_username}').get_items()): - #tweet_is_reply = True if tweet_data.content[0] == '@' else False - if tweet_data.content[0] != '@': - tweets_list.append({'id': tweet_data.id, 'content': tweet_data.content, 'link': str(tweet_data)}) - if len(tweets_list) == tweet_number: - break - return tweets_list - - # https://discordpy.readthedocs.io/en/latest/ext/tasks/ - - @tasks.loop(seconds=50) - async def tweet_checker(self): - start_time = time.time() - print("** Checking all saved handles **") - returned_val = db_gateway().getall('twitter_info') - for each in returned_val: - single_tweet = self.get_tweets(each['twitter_handle'], 1) - if single_tweet[0]['id'] == each['previous_tweet_id']: - print(f"{each['twitter_handle']} - Same") - else: - print(f"{each['twitter_handle']} - Different") - await self.bot.get_channel( - each['channel_id'] - ).send(f"@{each['twitter_handle']} has just tweeted! Link - {single_tweet[0]['link']}") - db_gateway().update( - 'twitter_info', - set_params={'previous_tweet_id': int(single_tweet[0]['id'])}, - where_params={ - 'guild_id': each['guild_id'], - 'twitter_handle': each['twitter_handle'] - } - ) - end_time = time.time() - print(f'Checking tweets took: {round(end_time-start_time, 3)}s') - - @tweet_checker.before_loop - async def before_tweet_checker(self): - print('Updating tweets in DB') - returned_val = db_gateway().getall('twitter_info') - for each in returned_val: - single_tweet = self.get_tweets(each['twitter_handle'], 1) - db_gateway().update( - 'twitter_info', - set_params={'previous_tweet_id': int(single_tweet[0]['id'])}, - where_params={ - 'guild_id': each['guild_id'], - 'twitter_handle': each['twitter_handle'] - } - ) - print('Waiting on bot to become ready before start Twitter cog') - await self.bot.wait_until_ready() - - -def setup(bot): - bot.add_cog(TwitterIntegrationCog(bot)) diff --git a/src/esportsbot/cogs/VoicemasterCog.py b/src/esportsbot/cogs/VoicemasterCog.py index 391e3140..715d3d30 100644 --- a/src/esportsbot/cogs/VoicemasterCog.py +++ b/src/esportsbot/cogs/VoicemasterCog.py @@ -1,6 +1,7 @@ from discord.ext import commands -from ..db_gateway import db_gateway -from ..base_functions import send_to_log_channel +from esportsbot.db_gateway import DBGatewayActions +from esportsbot.models import Voicemaster_master, Voicemaster_slave +from esportsbot.base_functions import send_to_log_channel class VoicemasterCog(commands.Cog): @@ -14,31 +15,13 @@ async def setvmmaster(self, ctx, given_channel_id=None): is_a_valid_id = given_channel_id and given_channel_id.isdigit() and len(given_channel_id) == 18 if is_a_valid_id: - is_a_master = db_gateway().get( - 'voicemaster_master', - params={ - 'guild_id': ctx.author.guild.id, - 'channel_id': given_channel_id - } - ) + is_a_master = DBGatewayActions().get(Voicemaster_master, guild_id=ctx.author.guild.id, channel_id=given_channel_id) is_voice_channel = hasattr(self.bot.get_channel(int(given_channel_id)), 'voice_states') - is_a_slave = db_gateway().get( - 'voicemaster_slave', - params={ - 'guild_id': ctx.author.guild.id, - 'channel_id': given_channel_id - } - ) + is_a_slave = DBGatewayActions().get(Voicemaster_slave, guild_id=ctx.author.guild.id, channel_id=given_channel_id) if is_voice_channel and not (is_a_master or is_a_slave): # Not currently a Master and is voice channel, add it - db_gateway().insert( - 'voicemaster_master', - params={ - 'guild_id': ctx.author.guild.id, - 'channel_id': given_channel_id - } - ) + DBGatewayActions().create(Voicemaster_master(guild_id=ctx.author.guild.id, channel_id=given_channel_id)) await ctx.channel.send("This VC has now been set as a VM master") new_vm_master_channel = self.bot.get_channel(int(given_channel_id)) await send_to_log_channel( @@ -66,12 +49,12 @@ async def setvmmaster(self, ctx, given_channel_id=None): @commands.command() @commands.has_permissions(administrator=True) async def getvmmasters(self, ctx): - master_vm_exists = db_gateway().get('voicemaster_master', params={'guild_id': ctx.author.guild.id}) + master_vm_exists = DBGatewayActions().list(Voicemaster_master, guild_id=ctx.author.guild.id) if master_vm_exists: master_vm_str = str() for record in master_vm_exists: - master_vm_str += f"{self.bot.get_channel(record['channel_id']).name} - {str(record['channel_id'])}\n" + master_vm_str += f"{self.bot.get_channel(record.channel_id).name} - {record.channel_id}\n" await ctx.channel.send(self.STRINGS['show_current_vcs'].format(master_vms=master_vm_str)) else: await ctx.channel.send(self.STRINGS['error_no_vms']) @@ -80,18 +63,24 @@ async def getvmmasters(self, ctx): @commands.has_permissions(administrator=True) async def removevmmaster(self, ctx, given_channel_id=None): if given_channel_id: - channel_exists = db_gateway().get( - 'voicemaster_master', - params={ - 'guild_id': ctx.author.guild.id, - 'channel_id': given_channel_id - } + channel_exists = DBGatewayActions().get( + Voicemaster_master, + guild_id=ctx.author.guild.id, + channel_id=given_channel_id ) if channel_exists: - db_gateway().delete('voicemaster_master', where_params={ - 'guild_id': ctx.author.guild.id, 'channel_id': given_channel_id}) + DBGatewayActions().delete(channel_exists) await ctx.channel.send(self.STRINGS['success_vm_unset']) - await send_to_log_channel(self, ctx.author.guild.id, self.STRINGS['log_vm_master_removed'].format(mention=ctx.author.guild.id, channel_name=new_vm_master_channel.name, channel_id=new_vm_master_channel.id)) + removed_vm_master = self.bot.get_channel(given_channel_id) + await send_to_log_channel( + self, + ctx.author.guild.id, + self.STRINGS['log_vm_master_removed'].format( + mention=ctx.author.guild.id, + channel_name=removed_vm_master.name, + channel_id=removed_vm_master.id + ) + ) else: await ctx.channel.send(self.STRINGS['error_not_vm']) else: @@ -100,50 +89,52 @@ async def removevmmaster(self, ctx, given_channel_id=None): @commands.command() @commands.has_permissions(administrator=True) async def removeallmasters(self, ctx): - all_vm_masters = db_gateway().get('voicemaster_master', params={'guild_id': ctx.author.guild.id}) + all_vm_masters = DBGatewayActions().list(Voicemaster_master, guild_id=ctx.author.guild.id) for vm_master in all_vm_masters: - db_gateway().delete('voicemaster_master', where_params={ - 'channel_id': vm_master['channel_id']}) + DBGatewayActions().delete(vm_master) await ctx.channel.send(self.STRINGS['success_vm_masters_cleared']) - await send_to_log_channel(self, ctx.author.guild.id, self.STRINGS['log_vm_masters_cleared'].format(mention=ctx.author.mention)) + await send_to_log_channel( + self, + ctx.author.guild.id, + self.STRINGS['log_vm_masters_cleared'].format(mention=ctx.author.mention) + ) @commands.command() @commands.has_permissions(administrator=True) async def killallslaves(self, ctx): - all_vm_slaves = db_gateway().get('voicemaster_slave', params={'guild_id': ctx.author.guild.id}) + all_vm_slaves = DBGatewayActions().list(Voicemaster_slave, guild_id=ctx.author.guild.id) for vm_slave in all_vm_slaves: - vm_slave_channel = self.bot.get_channel(vm_slave['channel_id']) + vm_slave_channel = self.bot.get_channel(vm_slave.channel_id) if vm_slave_channel: await vm_slave_channel.delete() - db_gateway().delete('voicemaster_slave', where_params={ - 'channel_id': vm_slave['channel_id']}) + DBGatewayActions().delete(vm_slave) await ctx.channel.send(self.STRINGS['success_vm_slaves_cleared']) - await send_to_log_channel(self, ctx.author.guild.id, self.STRINGS['log_vm_slaves_cleared'].format(mention=ctx.author.mention)) + await send_to_log_channel( + self, + ctx.author.guild.id, + self.STRINGS['log_vm_slaves_cleared'].format(mention=ctx.author.mention) + ) @commands.command() async def lockvm(self, ctx): - in_vm_slave = db_gateway().get( - 'voicemaster_slave', - params={ - 'guild_id': ctx.author.guild.id, - 'channel_id': ctx.author.voice.channel.id - } + in_vm_slave = DBGatewayActions().get( + Voicemaster_slave, + guild_id=ctx.author.guild.id, + channel_id=ctx.author.voice.channel.id ) if in_vm_slave: - if in_vm_slave[0]['owner_id'] == ctx.author.id: - if not in_vm_slave[0]['locked']: - db_gateway().update( - 'voicemaster_slave', - set_params={'locked': True}, - where_params={ - 'guild_id': ctx.author.guild.id, - 'channel_id': ctx.author.voice.channel.id - } - ) + if in_vm_slave.owner_id == ctx.author.id: + if not in_vm_slave.locked: + in_vm_slave.locked = True + DBGatewayActions().update(in_vm_slave) await ctx.author.voice.channel.edit(user_limit=len(ctx.author.voice.channel.members)) await ctx.channel.send(self.STRINGS['success_slave_locked']) - await send_to_log_channel(self, ctx.author.guild.id, self.STRINGS['log_slave_locked'].format(mention=ctx.author.mention)) + await send_to_log_channel( + self, + ctx.author.guild.id, + self.STRINGS['log_slave_locked'].format(mention=ctx.author.mention) + ) else: await ctx.channel.send(self.STRINGS['error_already_locked']) else: @@ -153,27 +144,23 @@ async def lockvm(self, ctx): @commands.command() async def unlockvm(self, ctx): - in_vm_slave = db_gateway().get( - 'voicemaster_slave', - params={ - 'guild_id': ctx.author.guild.id, - 'channel_id': ctx.author.voice.channel.id - } + in_vm_slave = DBGatewayActions().get( + Voicemaster_slave, + guild_id=ctx.author.guild.id, + channel_id=ctx.author.voice.channel.id ) if in_vm_slave: - if in_vm_slave[0]['owner_id'] == ctx.author.id: - if in_vm_slave[0]['locked']: - db_gateway().update( - 'voicemaster_slave', - set_params={'locked': False}, - where_params={ - 'guild_id': ctx.author.guild.id, - 'channel_id': ctx.author.voice.channel.id - } - ) + if in_vm_slave.owner_id == ctx.author.id: + if in_vm_slave.locked: + in_vm_slave.locked = False + DBGatewayActions().update(in_vm_slave) await ctx.author.voice.channel.edit(user_limit=0) - await send_to_log_channel(self, ctx.author.guild.id, self.STRINGS['log_slave_unlocked'].format(mention=ctx.author.mention)) + await send_to_log_channel( + self, + ctx.author.guild.id, + self.STRINGS['log_slave_unlocked'].format(mention=ctx.author.mention) + ) else: await ctx.channel.send(self.STRINGS['error_already_unlocked']) else: diff --git a/src/esportsbot/db_gateway.py b/src/esportsbot/db_gateway.py index d6f0e49c..5915e103 100644 --- a/src/esportsbot/db_gateway.py +++ b/src/esportsbot/db_gateway.py @@ -1,148 +1,96 @@ -import psycopg2 -from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT +from sqlalchemy import create_engine, inspect +from sqlalchemy.orm import sessionmaker +from esportsbot.models import * import os -from .lib.exceptions import print_exception_trace - - -class db_connection(): - def __init__(self, database=None): - self.conn = psycopg2.connect( - host=os.getenv('PG_HOST'), - database=os.getenv('PG_DATABASE') if database is None else database, - user=os.getenv('PG_USER'), - password=os.getenv('PG_PWD') - ) - self.conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) - self.cur = self.conn.cursor() - - def commit_query(self, query): - self.cur.execute(query) - self.conn.commit() - - def return_query(self, query): - self.cur.execute(query) - columns = [desc[0] for desc in self.cur.description] - real_dict = [dict(zip(columns, row)) for row in self.cur.fetchall()] - self.conn.commit() - return real_dict - - def close(self): - self.cur.close() - self.conn.close() - - -class db_gateway(): - @staticmethod - def get_param_insert_str(params): - key_string = str() - val_string = str() - for key, val in params.items(): - key_string += f"{key}, " - val_string += f"'{val}', " - return (key_string[:-2], val_string[:-2]) - - @staticmethod - def get_param_select_str(params): - key_val_string = str() - for key, val in params.items(): - if val == 'NULL': - key_val_string += f"{key}={val} AND " - else: - key_val_string += f"{key}='{val}' AND " - return key_val_string[:-5] - - def insert(self, table, params): - # Example usage: - # db_gateway().insert('voicemaster', params={'guild_id': '11111131111111111', - # 'owner_id': '222222222222222222', - # 'channel_id': '333333333333333333' - # }) - try: - db = db_connection() - query_vals = self.get_param_insert_str(params) - query_string = f'INSERT INTO {table}({query_vals[0]}) VALUES ({query_vals[1]})' - db.commit_query(query_string) - db.close() - return True - except Exception as err: - print_exception_trace(err) - raise RuntimeError('Error occurred using INSERT') from err - - def get(self, table, params): - # Example usage: - # returned_val = db_gateway().get('voicemaster', params={ - # 'channel_id': '333333333333333333' - # }) - try: - db = db_connection() - query_string = f'SELECT * FROM {table} WHERE {self.get_param_select_str(params)}' - returned_data = db.return_query(query_string) - db.close() - return returned_data - except Exception as err: - print_exception_trace(err) - raise RuntimeError('Error occurred using SELECT') from err +from dotenv import load_dotenv + +load_dotenv() + +db_string = f"postgresql://{os.getenv('PG_USER')}:{os.getenv('PG_PWD')}@{os.getenv('PG_HOST')}:5432/{os.getenv('PG_DATABASE')}" + +db = create_engine(db_string) + +Session = sessionmaker(db) +session = Session() + +base.metadata.create_all(db) - def getall(self, table): - # Example usage: - # returned_val = db_gateway().getall('voicemaster') +print("[DATABASE] - Models created") + + +class DBGatewayActions(): + """ + Base class for handling database queries + """ + def list(self, db_model, **args): + """ + Method to return a list of results that suit the model criteria + + Args: + db_model (database_model): [The model to query in the database] + **args (model_attributes): [The attributes specified for the query] + + Returns: + [list]: [Returns a list of all models that fit the input models criteria] + """ try: - db = db_connection() - query_string = f'SELECT * FROM {table}' - returned_data = db.return_query(query_string) - db.close() - return returned_data + query = session.query(db_model).filter_by(**args).all() + return query except Exception as err: - print_exception_trace(err) - raise RuntimeError('Error occurred using SELECT ALL') from err + raise Exception(f"Error occured when using list - {err}") + + def get(self, db_model, **args): + """ + Method to return a record that suits the model criteria - def update(self, table, set_params, where_params): - # Example usage: - # db_gateway().update('loggingchannel', set_params={'guild_id': '44'}, where_params={'channel_id': '795761577705078808'}) + Args: + db_model (database_model): [The model to query in the database] + **args (model_attributes): [The attributes specified for the query] + + Returns: + [list]: [Returns a list of all models that fit the input models criteria] + """ try: - db = db_connection() - query_string = f'UPDATE {table} SET {self.get_param_select_str(set_params)} WHERE {self.get_param_select_str(where_params)}' - db.commit_query(query_string) - db.close() - return True + query = session.query(db_model).filter_by(**args).all() + return query[0] if query != [] else query except Exception as err: - print_exception_trace(err) - raise RuntimeError('Error occurred using UPDATE') from err + raise Exception(f"Error occured when using get - {err}") + + def update(self, model): + """ + Method for updating a record in the database - def delete(self, table, where_params): - # Example usage: - # db_gateway().delete('loggingchannel', where_params={'guild_id': 44}) + Args: + model (database_model): [A class that contains the necessary information for an entry] + """ try: - db = db_connection() - query_string = f'DELETE FROM {table} WHERE {self.get_param_select_str(where_params)}' - db.commit_query(query_string) - db.close() - return True - # return query_string + session.add(model) + session.commit() except Exception as err: - print_exception_trace(err) - raise RuntimeError('Error occurred using DELETE') from err + raise Exception(f"Error occured when using update - {err}") - def pure_return(self, sql_query, database=None): - # Example usage: - # db_gateway().pure("SELECT * FROM 'guild_info'"") + def delete(self, model): + """ + Method for deleting a record from the database + + Args: + model (database_model): [A class that contains the necessary information for an entry] + """ try: - db = db_connection(database) - returned_data = db.return_query(sql_query) - db.close() - return returned_data + session.delete(model) + session.commit() except Exception as err: - print_exception_trace(err) - raise RuntimeError('Error occurred using PURE') from err + raise Exception(f"Error occured when using delete - {err}") + + def create(self, model): + """ + Method for adding a record to the database - def pure_query(self, sql_query, database=None): - # Example usage: - # db_gateway().pure('SELECT * FROM 'guild_info'') + Args: + model (database_model): [A class that contains the necessary information for an entry] + """ try: - db = db_connection(database) - returned_data = db.commit_query(sql_query) - db.close() - return returned_data + session.add(model) + session.commit() except Exception as err: - print_exception_trace(err) - raise RuntimeError('Error occurred using PURE') from err + raise Exception(f"Error occured when using create - {err}") \ No newline at end of file diff --git a/src/esportsbot/generate_schema.py b/src/esportsbot/generate_schema.py deleted file mode 100644 index 3985dbce..00000000 --- a/src/esportsbot/generate_schema.py +++ /dev/null @@ -1,226 +0,0 @@ -from .db_gateway import db_gateway - - -def generate_schema(): - # Does the esportsbot DB exist? - esportsbot_exists = db_gateway().pure_return( - "SELECT datname FROM pg_catalog.pg_database WHERE lower(datname) = lower('esportsbot')", - "postgres" - ) - if not esportsbot_exists: - # Esportsbot DB doesn't exist - db_gateway().pure_query("CREATE DATABASE esportsbot", "postgres") - - # Does the guild_id table exist? - guild_id_exists = db_gateway( - ).pure_return("SELECT true::BOOLEAN FROM pg_catalog.pg_tables WHERE schemaname = 'public' AND tablename = 'guild_info'") - if not guild_id_exists: - # Does not exist - query_string = """ - CREATE TABLE guild_info( - guild_id bigint NOT NULL, - log_channel_id bigint, - default_role_id bigint, - num_running_polls int NOT NULL, - role_ping_cooldown_seconds bigint NOT NULL, - pingme_create_threshold int NOT NULL, - pingme_create_poll_length_seconds bigint NOT NULL, - pingme_role_emoji text, - shared_role_id bigint - ); - ALTER TABLE ONLY guild_info - ADD CONSTRAINT loggingchannel_pkey PRIMARY KEY(guild_id); - """ - db_gateway().pure_query(query_string) - - # Does the pingable_roles table exist? - pingable_roles_exists = db_gateway().pure_return( - "SELECT true::BOOLEAN FROM pg_catalog.pg_tables WHERE schemaname = 'public' AND tablename = 'pingable_roles'" - ) - if not pingable_roles_exists: - # Does not exist - query_string = """ - CREATE TABLE pingable_roles( - name text NOT NULL, - guild_id bigint NOT NULL, - role_id bigint NOT NULL, - on_cooldown boolean NOT NULL, - last_ping float NOT NULL, - ping_count int NOT NULL, - monthly_ping_count int NOT NULL, - creator_id bigint NOT NULL, - colour int NOT NULL - ); - ALTER TABLE ONLY pingable_roles - ADD CONSTRAINT roleid_pkey PRIMARY KEY(role_id); - ALTER TABLE ONLY pingable_roles - ADD CONSTRAINT guildid_fkey FOREIGN KEY(guild_id) REFERENCES guild_info (guild_id); - """ - db_gateway().pure_query(query_string) - - # Does the event_categories exist? - event_categories_exists = db_gateway().pure_return( - "SELECT true::BOOLEAN FROM pg_catalog.pg_tables WHERE schemaname = 'public' AND tablename = 'event_categories'" - ) - if not event_categories_exists: - # Does not exist - query_string = """ - CREATE TABLE event_categories( - guild_id bigint NOT NULL, - event_name text NOT NULL, - role_id bigint NOT NULL, - signin_menu_id bigint NOT NULL - ); - ALTER TABLE ONLY event_categories - ADD CONSTRAINT eventname_pkey PRIMARY KEY(guild_id, event_name); - ALTER TABLE ONLY event_categories - ADD CONSTRAINT guildid_fkey FOREIGN KEY(guild_id) REFERENCES guild_info(guild_id); - """ - db_gateway().pure_query(query_string) - - # Does the reaction_menus table exist? - reaction_menus_exists = db_gateway().pure_return( - "SELECT true::BOOLEAN FROM pg_catalog.pg_tables WHERE schemaname = 'public' AND tablename = 'reaction_menus'" - ) - if not reaction_menus_exists: - # Does not exist - query_string = """ - CREATE TABLE reaction_menus( - message_id bigint NOT NULL, - menu jsonb - ); - ALTER TABLE ONLY reaction_menus - ADD CONSTRAINT menu_pkey PRIMARY KEY(message_id); - """ - db_gateway().pure_query(query_string) - - # Does the voicemaster_master table exist? - voicemaster_master_exists = db_gateway().pure_return( - "SELECT true::BOOLEAN FROM pg_catalog.pg_tables WHERE schemaname = 'public' AND tablename = 'voicemaster_master'" - ) - if not voicemaster_master_exists: - # Does not exist - query_string = """ - CREATE TABLE voicemaster_master ( - master_id bigint NOT NULL, - guild_id bigint NOT NULL, - channel_id bigint NOT NULL - ); - ALTER TABLE voicemaster_master ALTER COLUMN master_id ADD GENERATED ALWAYS AS IDENTITY ( - SEQUENCE NAME voicemaster_master_master_id_seq - START WITH 1 - INCREMENT BY 1 - NO MINVALUE - NO MAXVALUE - CACHE 1 - ); - ALTER TABLE ONLY voicemaster_master - ADD CONSTRAINT voicemaster_master_pkey PRIMARY KEY (master_id); - """ - db_gateway().pure_query(query_string) - - # Does the voicemaster_slave table exist? - voicemaster_slave_exists = db_gateway().pure_return( - "SELECT true::BOOLEAN FROM pg_catalog.pg_tables WHERE schemaname = 'public' AND tablename = 'voicemaster_slave'" - ) - if not voicemaster_slave_exists: - # Does not exist - query_string = """ - CREATE TABLE voicemaster_slave ( - vc_id bigint NOT NULL, - guild_id bigint NOT NULL, - channel_id bigint NOT NULL, - owner_id bigint NOT NULL, - locked boolean NOT NULL - ); - ALTER TABLE voicemaster_slave ALTER COLUMN vc_id ADD GENERATED ALWAYS AS IDENTITY ( - SEQUENCE NAME voicemaster_vc_id_seq - START WITH 1 - INCREMENT BY 1 - NO MINVALUE - NO MAXVALUE - CACHE 1 - ); - ALTER TABLE ONLY voicemaster_slave - ADD CONSTRAINT voicemaster_pkey PRIMARY KEY (vc_id); - """ - db_gateway().pure_query(query_string) - - # Does the twitch_info table exist? - twitch_info_exists = db_gateway( - ).pure_return("SELECT true::BOOLEAN FROM pg_catalog.pg_tables WHERE schemaname = 'public' AND tablename = 'twitch_info'") - if not twitch_info_exists: - # Does not exist - query_string = """ - CREATE TABLE public.twitch_info( - id bigint NOT NULL, - guild_id bigint NOT NULL, - channel_id bigint NOT NULL, - twitch_handle character varying NOT NULL, - currently_live boolean NOT NULL, - custom_message character varying - ); - ALTER TABLE public.twitch_info ALTER COLUMN id ADD GENERATED ALWAYS AS IDENTITY( - SEQUENCE NAME public.twitch_info_id_seq - START WITH 1 - INCREMENT BY 1 - NO MINVALUE - NO MAXVALUE - CACHE 1 - ); - ALTER TABLE ONLY public.twitch_info - ADD CONSTRAINT twitch_info_pkey PRIMARY KEY(id); - """ - db_gateway().pure_query(query_string) - - # Does the twitter_info table exist? - twitter_info_exists = db_gateway( - ).pure_return("SELECT true::BOOLEAN FROM pg_catalog.pg_tables WHERE schemaname = 'public' AND tablename = 'twitter_info'") - if not twitter_info_exists: - # Does not exist - query_string = """ - CREATE TABLE public.twitter_info( - id bigint NOT NULL, - guild_id bigint NOT NULL, - channel_id bigint NOT NULL, - twitter_handle character varying NOT NULL, - previous_tweet_id bigint NOT NULL - ); - ALTER TABLE public.twitter_info ALTER COLUMN id ADD GENERATED ALWAYS AS IDENTITY( - SEQUENCE NAME public.twitter_info_id_seq - START WITH 1 - INCREMENT BY 1 - NO MINVALUE - NO MAXVALUE - CACHE 1 - ); - ALTER TABLE ONLY public.twitter_info - ADD CONSTRAINT twitter_info_pkey PRIMARY KEY(id); - """ - db_gateway().pure_query(query_string) - - # Does the music_channels_info table exist? - music_channels_info_exists = db_gateway().pure_return( - "SELECT true::BOOLEAN FROM pg_catalog.pg_tables WHERE schemaname = 'public' AND tablename = 'music_channels'" - ) - if not music_channels_info_exists: - query_string = """ - CREATE TABLE public.music_channels( - id bigint NOT NULL, - guild_id bigint NOT NULL, - channel_id bigint NOT NULL, - queue_message_id bigint, - preview_message_id bigint - ); - ALTER TABLE public.music_channels ALTER COLUMN id ADD GENERATED ALWAYS AS IDENTITY ( - SEQUENCE NAME public.music_channels_id_seq - START WITH 1 - INCREMENT BY 1 - NO MINVALUE - NO MAXVALUE - CACHE 1 - ); - ALTER TABLE ONLY public.music_channels - ADD CONSTRAINT music_channels_pkey PRIMARY KEY(id); - """ - db_gateway().pure_query(query_string) diff --git a/src/esportsbot/lib/client.py b/src/esportsbot/lib/client.py index ca479e39..55b81c05 100644 --- a/src/esportsbot/lib/client.py +++ b/src/esportsbot/lib/client.py @@ -1,10 +1,11 @@ from types import FrameType, FunctionType from discord.ext import commands, tasks from discord import Intents, Embed, Message, Colour, Role -from ..reactionMenus.reactionMenuDB import ReactionMenuDB -from ..reactionMenus import reactionMenu -from ..db_gateway import db_gateway -from . import exceptions +from esportsbot.reactionMenus.reactionMenuDB import ReactionMenuDB +from esportsbot.reactionMenus import reactionMenu +from esportsbot.db_gateway import DBGatewayActions +from esportsbot.models import Music_channels, Pingable_roles, Guild_info, Reaction_menus +from esportsbot.lib import exceptions from typing import Dict, MutableMapping, Set, Union, List from datetime import datetime, timedelta import os @@ -12,8 +13,8 @@ import asyncio import toml -from .exceptions import UnrecognisedReactionMenuMessage -from .emotes import Emote +from esportsbot.lib.exceptions import UnrecognisedReactionMenuMessage +from esportsbot.lib.emotes import Emote # Type alias to be used for user facing strings. Allows for multi-level tables. StringTable = MutableMapping[str, Union[str, "StringTable"]] @@ -49,9 +50,9 @@ def __init__(self, command_prefix: str, unknownCommandEmoji: Emote, userStringsF def update_music_channels(self): self.MUSIC_CHANNELS = {} - temp_channels = db_gateway().pure_return("SELECT guild_id, channel_id FROM music_channels") + temp_channels = DBGatewayActions().list(Music_channels) for item in temp_channels: - self.MUSIC_CHANNELS[item.get("guild_id")] = item.get("channel_id") + self.MUSIC_CHANNELS[item.guild_id] = item.channel_id return self.MUSIC_CHANNELS def interruptReceived(self, signum: signal.Signals, frame: FrameType): @@ -78,12 +79,12 @@ async def rolePingCooldown(self, role: Role, cooldownSeconds: int): :param int cooldownSeconds: The number of seconds to wait asynchronously before updating role """ await asyncio.sleep(cooldownSeconds) - db = db_gateway() - roleData = db.get("pingable_roles", {"role_id": role.id}) - if roleData and roleData[0]["on_cooldown"]: - db.update('pingable_roles', set_params={'on_cooldown': False}, where_params={'role_id': role.id}) + roleData = DBGatewayActions().get(Pingable_roles, role_id=role.id) + if roleData and roleData.on_cooldown: + roleData.on_cooldown = False + DBGatewayActions().update(roleData) if role.guild.get_role(role.id) is not None: - await role.edit(mentionable=True, colour=roleData[0]["colour"], reason="role ping cooldown complete") + await role.edit(mentionable=True, colour=roleData.colour, reason="role ping cooldown complete") @tasks.loop(hours=24) async def monthlyPingablesReport(self): @@ -100,35 +101,35 @@ async def monthlyPingablesReport(self): baseEmbed.colour = Colour.random() baseEmbed.set_thumbnail(url=self.user.avatar_url_as(size=128)) baseEmbed.set_footer(text=datetime.now().strftime("%m/%d/%Y")) - db = db_gateway() - for guildData in db.getall("guild_info"): - pingableRoles = db.get("pingable_roles", {"guild_id": guildData["guild_id"]}) + for guildData in DBGatewayActions().list(Guild_info): + pingableRoles = DBGatewayActions().list(Pingable_roles, guild_id=guildData.guild_id) if pingableRoles: - guild = self.get_guild(guildData["guild_id"]) + guild = self.get_guild(guildData.guild_id) if guild is None: print( - "[Esportsbot.monthlyPingablesReport] Unknown guild id in guild_info table: #" - + str(guildData["guild_id"]) + f"[Esportsbot.monthlyPingablesReport] Unknown guild id in guild_info table: #{guildData.guild_id}" ) - elif guildData["log_channel_id"] is not None: + elif guildData.log_channel_id is not None: reportEmbed = baseEmbed.copy() rolesAdded = False for roleData in pingableRoles: - role = guild.get_role(roleData["role_id"]) + role = guild.get_role(roleData.role_id) if role is None: - print("[Esportsbot.monthlyPingablesReport] Unknown pingable role id in pingable_roles table. Removing from the table: role #" \ - + str(roleData["role_id"]) + " in guild #" + str(guildData["guild_id"])) - db.delete("pingable_roles", {"role_id": roleData["role_id"]}) + print( + f"[Esportsbot.monthlyPingablesReport] Unknown pingable role id in pingable_roles table. Removing from the table: role #{roleData.role_id} in guild #{guildData.guild_id}" + ) + DBGatewayActions().delete(roleData) else: reportEmbed.add_field( name=role.name, - value=role.mention + "\n" + str(roleData["monthly_ping_count"]) + " pings" + value=f"{role.mention}\n{roleData.monthly_ping_count} pings" ) - db.update("pingable_roles", {"monthly_ping_count": 0}, {"role_id": role.id}) + roleData.monthly_ping_count = 0 + DBGatewayActions().update(roleData) rolesAdded = True if rolesAdded: loggingTasks.add( - asyncio.create_task(guild.get_channel(guildData['log_channel_id']).send(embed=reportEmbed)) + asyncio.create_task(guild.get_channel(guildData.log_channel_id).send(embed=reportEmbed)) ) if loggingTasks: @@ -142,16 +143,15 @@ async def init(self): which were interrupted by bot shutdown. This method must be called upon bot.on_ready, since these tasks cannot be performed synchronously during EsportsBot.__init__. """ - db = db_gateway() if not self.reactionMenus.initializing: raise RuntimeError("This bot's ReactionMenuDB has already been initialized.") try: - menusData = db.getall('reaction_menus') + menusData = DBGatewayActions().list(Reaction_menus) except Exception as e: print("failed to load menus from SQL", e) raise e for menuData in menusData: - msgID, menuDict = menuData['message_id'], menuData['menu'] + msgID, menuDict = menuData.message_id, menuData.menu if 'type' in menuDict: if reactionMenu.isSaveableMenuTypeName(menuDict['type']): try: @@ -160,15 +160,13 @@ async def init(self): menuDict) ) except UnrecognisedReactionMenuMessage: - print( - "Unrecognised message for " + menuDict['type'] + ", removing from the database: " - + str(menuDict["msg"]) - ) - db.delete('reaction_menus', where_params={'message_id': msgID}) + print(f"Unrecognised message for {menuDict['type']}, removing from the database: {menuDict['msg']}") + reaction_menu = DBGatewayActions().get(Reaction_menus, message_id=msgID) + DBGatewayActions().delete(reaction_menu) else: print("Non saveable menu in database:", msgID, menuDict["type"]) else: - print("no type for menu " + str(msgID)) + print(f"no type for menu {msgID}") self.reactionMenus.initializing = False if "UNKNOWN_COMMAND_EMOJI" in os.environ: @@ -176,22 +174,23 @@ async def init(self): now = datetime.now() roleUpdateTasks = set() - for guildData in db.getall("guild_info"): - guild = self.get_guild(guildData["guild_id"]) + for guildData in DBGatewayActions().list(Guild_info): + guild = self.get_guild(guildData.guild_id) if guild is None: - print("[Esportsbot.init] Unknown guild id in guild_info table: #" + str(guildData["guild_id"])) + print(f"[Esportsbot.init] Unknown guild id in guild_info table: #{guildData.guild_id}") else: - guildPingCooldown = timedelta(seconds=guildData["role_ping_cooldown_seconds"]) - for roleData in db.get("pingable_roles", {"guild_id": guildData["guild_id"]}): - role = guild.get_role(roleData["role_id"]) + guildPingCooldown = timedelta(seconds=guildData.role_ping_cooldown_seconds) + for roleData in DBGatewayActions().list(Pingable_roles, guild_id=guildData.guild_id): + role = guild.get_role(roleData.role_id) if role is None: - print("[Esportsbot.init] Unknown pingable role id in pingable_roles table. Removing from the table: role #" \ - + str(roleData["role_id"]) + " in guild #" + str(guildData["guild_id"])) - db.delete("pingable_roles", {"role_id": roleData["role_id"]}) + print( + f"[Esportsbot.init] Unknown pingable role id in pingable_roles table. Removing from the table: role #{roleData.role_id} in guild #{guildData.guild_id}" + ) + DBGatewayActions().delete(roleData) else: remainingCooldown = max( 0, - int((datetime.fromtimestamp(roleData["last_ping"]) + guildPingCooldown - now).total_seconds()) + int((datetime.fromtimestamp(roleData.last_ping) + guildPingCooldown - now).total_seconds()) ) roleUpdateTasks.add(asyncio.create_task(self.rolePingCooldown(role, remainingCooldown))) @@ -220,8 +219,8 @@ async def adminLog(self, message: Message, actions: Dict[str, str], *args, guild raise ValueError("Must give at least one of message or guildID") else: guildID = message.guild.id - db_logging_call = db_gateway().get('guild_info', params={'guild_id': guildID}) - if db_logging_call and db_logging_call[0]['log_channel_id']: + db_logging_call = DBGatewayActions().get(Guild_info, guild_id=guildID) + if db_logging_call and db_logging_call.log_channel_id: if "embed" not in kwargs: if message is None: logEmbed = Embed(description="Responsible user unknown. Check the server's audit log.") @@ -238,7 +237,7 @@ async def adminLog(self, message: Message, actions: Dict[str, str], *args, guild for aTitle, aDesc in actions.items(): logEmbed.add_field(name=str(aTitle), value=str(aDesc), inline=False) kwargs["embed"] = logEmbed - await self.get_channel(db_logging_call[0]['log_channel_id']).send(*args, **kwargs) + await self.get_channel(db_logging_call.log_channel_id).send(*args, **kwargs) def handleRoleMentions(self, message: Message) -> Set[asyncio.Task]: """Handle !pingme behaviour for the given message. @@ -249,13 +248,12 @@ def handleRoleMentions(self, message: Message) -> Set[asyncio.Task]: :return: A potentially empty set of already scheduled tasks handling role ping cooldown :rtype: Set[asyncio.Task] """ - db = db_gateway() - guildInfo = db.get('guild_info', params={'guild_id': message.guild.id}) + guildInfo = DBGatewayActions().get(Guild_info, guild_id=message.guild.id) roleUpdateTasks = set() if guildInfo: for role in message.role_mentions: - roleData = db.get('pingable_roles', params={'role_id': role.id}) - if roleData and not roleData[0]["on_cooldown"]: + roleData = DBGatewayActions().get(Pingable_roles, role_id=role.id) + if roleData and not roleData.on_cooldown: roleUpdateTasks.add( asyncio.create_task( role.edit( @@ -265,30 +263,23 @@ def handleRoleMentions(self, message: Message) -> Set[asyncio.Task]: ) ) ) - db.update('pingable_roles', {'on_cooldown': True}, {'role_id': role.id}) - db.update('pingable_roles', {"last_ping": datetime.now().timestamp()}, {'role_id': role.id}) - db.update('pingable_roles', {"ping_count": roleData[0]["ping_count"] + 1}, {'role_id': role.id}) - db.update( - 'pingable_roles', - {"monthly_ping_count": roleData[0]["monthly_ping_count"] + 1}, - {'role_id': role.id} - ) - roleUpdateTasks.add( - asyncio.create_task(self.rolePingCooldown(role, - guildInfo[0]["role_ping_cooldown_seconds"])) - ) + roleData.on_cooldown = True + roleData.last_ping = datetime.now().timestamp() + roleData.ping_count = roleData.ping_count + 1 + roleData.monthly_ping_count = roleData.monthly_ping_count + 1 + DBGatewayActions().update(roleData) + + roleUpdateTasks.add(asyncio.create_task(self.rolePingCooldown(role, guildInfo.role_ping_cooldown_seconds))) roleUpdateTasks.add( asyncio.create_task( self.adminLog( message, - {"!pingme Role Pinged": "Role: " + role.mention + "\nUser: " + message.author.mention} + {"!pingme Role Pinged": f"Role: {role.mention}\nUser: {message.author.mention}"} ) ) ) return roleUpdateTasks - - async def multiWaitFor(self, eventTypes: List[str], timeout: int, check: FunctionType = None): """Performs discord.Client.wait_for, but with multiple possible event types. @@ -337,7 +328,8 @@ def instance() -> EsportsBot: intents = Intents.default() intents.members = True _instance = EsportsBot( - os.environ.get("COMMAND_PREFIX", "!"), + os.environ.get("COMMAND_PREFIX", + "!"), Emote.fromStr("⁉"), "esportsbot/user_strings.toml", intents=intents diff --git a/src/esportsbot/models.py b/src/esportsbot/models.py new file mode 100644 index 00000000..46f3c28b --- /dev/null +++ b/src/esportsbot/models.py @@ -0,0 +1,89 @@ +from sqlalchemy import Column, String, BigInteger, Boolean, Float, ForeignKey +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.ext.declarative import declarative_base + +base = declarative_base() + + +class Guild_info(base): + __tablename__ = 'guild_info' + guild_id = Column(BigInteger, primary_key=True, nullable=False) + log_channel_id = Column(BigInteger, nullable=True) + default_role_id = Column(BigInteger, nullable=True) + num_running_polls = Column(BigInteger, nullable=False) + role_ping_cooldown_seconds = Column(BigInteger, nullable=False) + pingme_create_threshold = Column(BigInteger, nullable=False) + pingme_create_poll_length_seconds = Column(BigInteger, nullable=False) + pingme_role_emoji = Column(String, nullable=True) + shared_role_id = Column(BigInteger, nullable=True) + + +class Pingable_roles(base): + __tablename__ = 'pingable_roles' + name = Column(String, nullable=False) + guild_id = Column(BigInteger, ForeignKey("guild_info.guild_id"), nullable=False) + role_id = Column(BigInteger, primary_key=True, nullable=False) + on_cooldown = Column(Boolean, nullable=False) + last_ping = Column(Float, nullable=False) + ping_count = Column(BigInteger, nullable=False) + monthly_ping = Column(BigInteger, nullable=False) + creator_id = Column(BigInteger, nullable=False) + colour = Column(BigInteger, nullable=False) + + +class Event_categories(base): + __tablename__ = 'event_categories' + guild_id = Column(BigInteger, ForeignKey("guild_info.guild_id"), primary_key=True, nullable=False) + event_name = Column(String, primary_key=True, nullable=False) + role_id = Column(BigInteger, nullable=False) + signin_menu = Column(BigInteger, nullable=False) + + +class Reaction_menus(base): + __tablename__ = 'reaction_menus' + message_id = Column(BigInteger, primary_key=True, nullable=False) + menu = Column(JSONB, nullable=False) + + +class Voicemaster_master(base): + __tablename__ = 'voicemaster_master' + master_id = Column(BigInteger, primary_key=True, autoincrement=True, nullable=False) + guild_id = Column(BigInteger, nullable=False) + channel_id = Column(BigInteger, nullable=False) + + +class Voicemaster_slave(base): + __tablename__ = 'voicemaster_slave' + vc_id = Column(BigInteger, primary_key=True, autoincrement=True, nullable=False) + guild_id = Column(BigInteger, nullable=False) + channel_id = Column(BigInteger, nullable=False) + owner_id = Column(BigInteger, nullable=False) + locked = Column(Boolean, nullable=False) + + +class Twitch_info(base): + __tablename__ = 'twitch_info' + id = Column(BigInteger, primary_key=True, autoincrement=True, nullable=False) + guild_id = Column(BigInteger, nullable=False) + channel_id = Column(BigInteger, nullable=False) + twitch_handle = Column(String, nullable=False) + currently_live = Column(Boolean, nullable=False) + custom_message = Column(String, nullable=False) + # Will most likely change after Benji switch + + +class Twitter_info(base): + __tablename__ = 'twitter_info' + id = Column(BigInteger, primary_key=True, autoincrement=True, nullable=False) + guild_id = Column(BigInteger, nullable=False) + twitter_user_id = Column(String, nullable=False) + twitter_handle = Column(String, nullable=False) + + +class Music_channels(base): + __tablename__ = 'music_channels' + id = Column(BigInteger, primary_key=True, autoincrement=True, nullable=False) + guild_id = Column(BigInteger, nullable=False) + channel_id = Column(BigInteger, nullable=False) + queue_message_id = Column(BigInteger, nullable=False) + preview_message_id = Column(BigInteger, nullable=False) diff --git a/src/esportsbot/reactionMenus/reactionMenu.py b/src/esportsbot/reactionMenus/reactionMenu.py index 495df480..821d0eac 100644 --- a/src/esportsbot/reactionMenus/reactionMenu.py +++ b/src/esportsbot/reactionMenus/reactionMenu.py @@ -7,7 +7,7 @@ import inspect from discord import Embed, Colour, NotFound, HTTPException, Forbidden, Member, User, Message, Role, RawReactionActionEvent, Client -from .. import lib +from esportsbot import lib from abc import abstractmethod from typing import Union, Dict, List, Any import asyncio diff --git a/src/esportsbot/reactionMenus/reactionMenuDB.py b/src/esportsbot/reactionMenus/reactionMenuDB.py index 6551efe3..ef6635c0 100644 --- a/src/esportsbot/reactionMenus/reactionMenuDB.py +++ b/src/esportsbot/reactionMenus/reactionMenuDB.py @@ -6,9 +6,10 @@ """ from typing import Union -from .reactionMenu import ReactionMenu, isSaveableMenuInstance +from esportsbot.reactionMenus.reactionMenu import ReactionMenu, isSaveableMenuInstance from psycopg2.extras import Json -from ..db_gateway import db_gateway +from esportsbot.db_gateway import DBGatewayActions +from esportsbot.models import Reaction_menus class ReactionMenuDB(dict): @@ -66,22 +67,18 @@ def __setitem__(self, menuID: int, menu: ReactionMenu) -> None: """ if menuID != menu.msg.id: raise ValueError( - "Attempted to register a menu with key " + str(menuID) + ", but the message ID for the given menu is " - + str(menu.msg.id) + f"Attempted to register a menu with key {menuID}, but the message ID for the given menu is {menu.msg.id}" ) if menu.msg.id in self: - raise KeyError("A menu is already registered with the given ID: " + str(menu.msg.id)) + raise KeyError(f"A menu is already registered with the given ID: {menu.msg.id}") super().__setitem__(menuID, menu) if not self.initializing and isSaveableMenuInstance(menu): - db_gateway().insert( - 'reaction_menus', - params={ - 'message_id': menu.msg.id, - 'menu': str(Json(menu.toDict())).lstrip("'").rstrip("'") - } + DBGatewayActions().create( + Reaction_menus(message_id=menu.msg.id, + menu=str(Json(menu.toDict())).lstrip("'").rstrip("'")) ) def __delitem__(self, menu: Union[ReactionMenu, int]) -> None: @@ -94,16 +91,17 @@ def __delitem__(self, menu: Union[ReactionMenu, int]) -> None: """ if isinstance(menu, int): if menu not in self: - raise KeyError("No menu is registered with the given ID: " + str(menu)) + raise KeyError(f"No menu is registered with the given ID: {menu}") menu = self[menu] elif menu.msg.id not in self: - raise KeyError("The given menu is not registered: " + str(menu.msg.id)) + raise KeyError(f"The given menu is not registered: {menu.msg.id}") super().__delitem__(menu.msg.id) if isSaveableMenuInstance(menu): - db_gateway().delete('reaction_menus', where_params={'message_id': menu.msg.id}) + reaction_menu = DBGatewayActions().get(Reaction_menus, message_id=menu.msg.id) + DBGatewayActions().delete(reaction_menu) def add(self, menu: ReactionMenu): """Register a ReactionMenu with the database, and save to SQL. @@ -128,7 +126,7 @@ def removeID(self, menuID: int): :raise KeyError: When the given menu is not registered """ if menuID not in self: - raise KeyError("No menu is registered with the given ID: " + str(menuID)) + raise KeyError(f"No menu is registered with the given ID: {menuID}") self.remove(self[menuID]) def updateDB(self, menu: ReactionMenu): @@ -140,8 +138,6 @@ def updateDB(self, menu: ReactionMenu): if menu.msg.id not in self: raise KeyError("The given menu is not registered: " + str(menu.msg.id)) if isSaveableMenuInstance(menu): - db_gateway().update( - 'reaction_menus', - set_params={'menu': str(Json(menu.toDict())).lstrip("'").rstrip("'")}, - where_params={'message_id': menu.msg.id} - ) + reaction_menu = DBGatewayActions().get(Reaction_menus, message_id=menu.msg.id) + reaction_menu.menu = str(Json(menu.toDict())).lstrip("'").rstrip("'") + DBGatewayActions().update(reaction_menu) diff --git a/src/esportsbot/reactionMenus/reactionPollMenu.py b/src/esportsbot/reactionMenus/reactionPollMenu.py index 4a135293..9a44562f 100644 --- a/src/esportsbot/reactionMenus/reactionPollMenu.py +++ b/src/esportsbot/reactionMenus/reactionPollMenu.py @@ -311,8 +311,12 @@ def __init__( authorName=authorName ) - - def reactionClosesMenu(self, reactPL: Union[RawReactionActionEvent, RawMessageDeleteEvent, RawBulkMessageDeleteEvent]) -> bool: + def reactionClosesMenu( + self, + reactPL: Union[RawReactionActionEvent, + RawMessageDeleteEvent, + RawBulkMessageDeleteEvent] + ) -> bool: """An InlineReactionMenu override which checks the number of yes votes received. :param discord.RawReactionActionEvent reactPL: The raw payload representing the reaction addition or removal @@ -328,7 +332,7 @@ def reactionClosesMenu(self, reactPL: Union[RawReactionActionEvent, RawMessageDe if self.msg.id in reactPL.message_ids: raise lib.exceptions.UnrecognisedReactionMenuMessage(self.msg.guild.id, self.msg.channel.id, self.msg.id) return False - + try: if reactPL.message_id == self.msg.id and reactPL.user_id != lib.client.instance().user.id and \ lib.emotes.Emote.fromPartial(reactPL.emoji, rejectInvalid=True) == self.yesOption.emoji: @@ -342,7 +346,6 @@ def reactionClosesMenu(self, reactPL: Union[RawReactionActionEvent, RawMessageDe except lib.exceptions.UnrecognisedCustomEmoji: return False - async def doMenu(self) -> None: """Overload that also handles reaction removal to allow for correct vote counting, and message deletion. This overload does not return anything, unlike the original method, which returns the emotes reacted with @@ -351,7 +354,10 @@ async def doMenu(self) -> None: await self.updateMessage() try: await lib.client.instance().multiWaitFor( - ["raw_reaction_add", "raw_reaction_remove", "raw_message_delete", "raw_bulk_message_delete"], + ["raw_reaction_add", + "raw_reaction_remove", + "raw_message_delete", + "raw_bulk_message_delete"], check=self.reactionClosesMenu, timeout=self.timeoutSeconds ) diff --git a/src/esportsbot/reactionMenus/reactionRoleMenu.py b/src/esportsbot/reactionMenus/reactionRoleMenu.py index f6a489f4..ec0902b1 100644 --- a/src/esportsbot/reactionMenus/reactionRoleMenu.py +++ b/src/esportsbot/reactionMenus/reactionRoleMenu.py @@ -4,8 +4,8 @@ .. codeauthor:: Trimatix """ -from . import reactionMenu -from .. import lib +from esportsbot.reactionMenus import reactionMenu +from esportsbot import lib from discord import Colour, Guild, Role, Message, User, Client, Member, PartialMessage from typing import List, Union, Dict diff --git a/src/requirements.txt b/src/requirements.txt index 113ff4bc..86ece21e 100644 --- a/src/requirements.txt +++ b/src/requirements.txt @@ -1,12 +1,12 @@ psycopg2-binary>=2.8 -discord.py +sqlalchemy +discord.py[voice] python-dotenv emoji -git+git://github.com/JustAnotherArchivist/snscrape.git#egg=snscrape lxml google-api-python-client youtube-dl youtube-search-python PyNaCl aiohttp[speedups] -toml +toml \ No newline at end of file