From 53a7493345367901ded3911b04f10c02532e5515 Mon Sep 17 00:00:00 2001 From: kzndotsh Date: Tue, 2 Jul 2024 19:19:07 +0000 Subject: [PATCH] feat(poetry.toml): add poetry.toml file to create virtual environments within the project directory refactor(dev.py): rearrange error handling methods for better readability and maintainability refactor(avatar.py): extract common logic into send_avatar method to reduce code duplication refactor(remindme.py): simplify database access by directly using the reminder table feat(case.py, note.py, reminder.py): add ensure_guild_exists method to check and create guild if not exists refactor(case.py, note.py, reminder.py): call ensure_guild_exists before creating a new entry to ensure guild exists feat(snippet.py): add ensure_guild_exists method to check and create guild if not exists refactor(snippet.py): modify create_snippet method to call ensure_guild_exists before creating a snippet, ensuring guild existence --- poetry.toml | 2 + tux/cogs/admin/dev.py | 65 ++++++++++++++-------------- tux/cogs/utility/avatar.py | 39 ++++++++++++----- tux/cogs/utility/remindme.py | 8 ++-- tux/database/controllers/case.py | 13 +++++- tux/database/controllers/note.py | 11 ++++- tux/database/controllers/reminder.py | 11 ++++- tux/database/controllers/snippet.py | 11 ++++- 8 files changed, 107 insertions(+), 53 deletions(-) create mode 100644 poetry.toml diff --git a/poetry.toml b/poetry.toml new file mode 100644 index 0000000..ab1033b --- /dev/null +++ b/poetry.toml @@ -0,0 +1,2 @@ +[virtualenvs] +in-project = true diff --git a/tux/cogs/admin/dev.py b/tux/cogs/admin/dev.py index 868a387..8c86d17 100644 --- a/tux/cogs/admin/dev.py +++ b/tux/cogs/admin/dev.py @@ -50,6 +50,13 @@ class Dev(commands.Cog): await self.bot.tree.sync(guild=ctx.guild) await ctx.reply("Application command tree synced.") + @sync_tree.error + async def sync_error(self, ctx: commands.Context[commands.Bot], error: Exception) -> None: + if isinstance(error, commands.MissingRequiredArgument): + await ctx.send(f"Please specify a guild to sync application commands to. {error}") + else: + logger.error(f"Error syncing application commands: {error}") + @commands.has_guild_permissions(administrator=True) @dev.command( name="clear_tree", @@ -131,6 +138,22 @@ class Dev(commands.Cog): await ctx.send(f"Cog {cog} loaded.") logger.info(f"Cog {cog} loaded.") + @load_cog.error + async def load_error(self, ctx: commands.Context[commands.Bot], error: Exception) -> None: + if isinstance(error, commands.MissingRequiredArgument): + await ctx.send(f"Please specify an cog to load. {error}") + elif isinstance(error, commands.ExtensionAlreadyLoaded): + await ctx.send(f"The specified cog is already loaded. {error}") + elif isinstance(error, commands.ExtensionNotFound): + await ctx.send(f"The specified cog is not found. {error}") + elif isinstance(error, commands.ExtensionFailed): + await ctx.send(f"Failed to load cog: {error}") + elif isinstance(error, commands.NoEntryPointError): + await ctx.send(f"The specified cog does not have a setup function. {error}") + else: + await ctx.send(f"Failed to load cog: {error}") + logger.error(f"Failed to load cog: {error}") + @commands.has_guild_permissions(administrator=True) @dev.command( name="unload_cog", @@ -169,6 +192,15 @@ class Dev(commands.Cog): logger.info(f"Cog {cog} unloaded.") await ctx.send(f"Cog {cog} unloaded.") + @unload_cog.error + async def unload_error(self, ctx: commands.Context[commands.Bot], error: Exception) -> None: + if isinstance(error, commands.MissingRequiredArgument): + await ctx.send(f"Please specify an extension to unload. {error}") + elif isinstance(error, commands.ExtensionNotLoaded): + await ctx.send(f"That cog is not loaded. {error}") + else: + logger.error(f"Error unloading cog: {error}") + @commands.has_guild_permissions(administrator=True) @dev.command( name="reload_cog", @@ -208,14 +240,6 @@ class Dev(commands.Cog): await ctx.send(f"Cog {cog} reloaded.") logger.info(f"Cog {cog} reloaded.") - @sync_tree.error - async def sync_error(self, ctx: commands.Context[commands.Bot], error: Exception) -> None: - if isinstance(error, commands.MissingRequiredArgument): - await ctx.send(f"Please specify a guild to sync application commands to. {error}") - - else: - logger.error(f"Error syncing application commands: {error}") - @reload_cog.error async def reload_error(self, ctx: commands.Context[commands.Bot], error: Exception) -> None: if isinstance(error, commands.MissingRequiredArgument): @@ -226,31 +250,6 @@ class Dev(commands.Cog): await ctx.send(f"Error reloading cog: {error}") logger.error(f"Error reloading cog: {error}") - @unload_cog.error - async def unload_error(self, ctx: commands.Context[commands.Bot], error: Exception) -> None: - if isinstance(error, commands.MissingRequiredArgument): - await ctx.send(f"Please specify an extension to unload. {error}") - elif isinstance(error, commands.ExtensionNotLoaded): - await ctx.send(f"That cog is not loaded. {error}") - else: - logger.error(f"Error unloading cog: {error}") - - @load_cog.error - async def load_error(self, ctx: commands.Context[commands.Bot], error: Exception) -> None: - if isinstance(error, commands.MissingRequiredArgument): - await ctx.send(f"Please specify an cog to load. {error}") - elif isinstance(error, commands.ExtensionAlreadyLoaded): - await ctx.send(f"The specified cog is already loaded. {error}") - elif isinstance(error, commands.ExtensionNotFound): - await ctx.send(f"The specified cog is not found. {error}") - elif isinstance(error, commands.ExtensionFailed): - await ctx.send(f"Failed to load cog: {error}") - elif isinstance(error, commands.NoEntryPointError): - await ctx.send(f"The specified cog does not have a setup function. {error}") - else: - await ctx.send(f"Failed to load cog: {error}") - logger.error(f"Failed to load cog: {error}") - async def setup(bot: commands.Bot) -> None: await bot.add_cog(Dev(bot)) diff --git a/tux/cogs/utility/avatar.py b/tux/cogs/utility/avatar.py index 77fcdf8..6d70d17 100644 --- a/tux/cogs/utility/avatar.py +++ b/tux/cogs/utility/avatar.py @@ -28,15 +28,7 @@ class Avatar(commands.Cog): member : discord.Member The member to get the avatar of. """ - guild_avatar = member.guild_avatar.url if member.guild_avatar else None - profile_avatar = member.avatar.url if member.avatar else None - - files = [await self.create_avatar_file(avatar) for avatar in [guild_avatar, profile_avatar] if avatar] - - if files: - await ctx.reply(files=files) - else: - await ctx.reply("Member has no avatar.") + await self.send_avatar(ctx, member) @app_commands.command(name="avatar", description="Get the avatar of a member.") @app_commands.describe(member="The member to get the avatar of.") @@ -51,15 +43,40 @@ class Avatar(commands.Cog): member : discord.Member The member to get the avatar of. """ + await self.send_avatar(interaction, member) + + async def send_avatar( + self, + source: commands.Context[commands.Bot] | discord.Interaction, + member: discord.Member, + ) -> None: + """ + Send the avatar of a member. + + Parameters + ---------- + source : commands.Context[commands.Bot] | discord.Interaction + The source object for sending the message. + member : discord.Member + The member to get the avatar of. + """ + guild_avatar = member.guild_avatar.url if member.guild_avatar else None profile_avatar = member.avatar.url if member.avatar else None files = [await self.create_avatar_file(avatar) for avatar in [guild_avatar, profile_avatar] if avatar] if files: - await interaction.response.send_message(files=files) + if isinstance(source, discord.Interaction): + await source.response.send_message(files=files) + else: + await source.reply(files=files) else: - await interaction.response.send_message(content="Member has no avatar.") + message = "Member has no avatar." + if isinstance(source, discord.Interaction): + await source.response.send_message(content=message) + else: + await source.reply(content=message) @staticmethod async def create_avatar_file(url: str) -> discord.File: diff --git a/tux/cogs/utility/remindme.py b/tux/cogs/utility/remindme.py index e47daee..b0516a7 100644 --- a/tux/cogs/utility/remindme.py +++ b/tux/cogs/utility/remindme.py @@ -34,7 +34,7 @@ def get_closest_reminder(reminders: list[Reminder]) -> Reminder | None: class RemindMe(commands.Cog): def __init__(self, bot: commands.Bot) -> None: self.bot = bot - self.db_controller = DatabaseController() + self.db = DatabaseController().reminder self.bot.loop.create_task(self.update()) async def send_reminders(self, reminder: Reminder) -> None: @@ -88,7 +88,7 @@ class RemindMe(commands.Cog): logger.error(f"Failed to send reminder to {reminder.reminder_user_id}, user not found.") # Delete the reminder after sending - await self.db_controller.reminders.delete_reminder_by_id(reminder.reminder_id) + await self.db.delete_reminder_by_id(reminder.reminder_id) # wait for a second so that the reminder is deleted before checking for more reminders # who knows if this works, it seems to @@ -120,7 +120,7 @@ class RemindMe(commands.Cog): try: # Get all reminders - reminders = await self.db_controller.reminders.get_all_reminders() + reminders = await self.db.get_all_reminders() # Get the closest reminder closest_reminder = get_closest_reminder(reminders) @@ -170,7 +170,7 @@ class RemindMe(commands.Cog): seconds = datetime.datetime.now(datetime.UTC) + datetime.timedelta(seconds=seconds) try: - await self.db_controller.reminders.insert_reminder( + await self.db.insert_reminder( reminder_user_id=interaction.user.id, reminder_content=reminder, reminder_expires_at=seconds, diff --git a/tux/database/controllers/case.py b/tux/database/controllers/case.py index 9b16db8..7431f46 100644 --- a/tux/database/controllers/case.py +++ b/tux/database/controllers/case.py @@ -1,13 +1,20 @@ from datetime import datetime from prisma.enums import CaseType -from prisma.models import Case +from prisma.models import Case, Guild from tux.database.client import db class CaseController: def __init__(self): self.table = db.case + self.guild_table = db.guild + + async def ensure_guild_exists(self, guild_id: int) -> Guild | None: + guild = await self.guild_table.find_first(where={"guild_id": guild_id}) + if guild is None: + return await self.guild_table.create(data={"guild_id": guild_id}) + return guild async def get_all_cases(self) -> list[Case]: return await self.table.find_many() @@ -24,13 +31,15 @@ class CaseController: case_reason: str, case_expires_at: datetime | None = None, ) -> Case | None: + await self.ensure_guild_exists(guild_id) + return await self.table.create( data={ "guild_id": guild_id, "case_target_id": case_target_id, "case_moderator_id": case_moderator_id, - "case_reason": case_reason, "case_type": case_type, + "case_reason": case_reason, "case_expires_at": case_expires_at, }, ) diff --git a/tux/database/controllers/note.py b/tux/database/controllers/note.py index 72e3bfd..4c257af 100644 --- a/tux/database/controllers/note.py +++ b/tux/database/controllers/note.py @@ -1,10 +1,17 @@ -from prisma.models import Note +from prisma.models import Guild, Note from tux.database.client import db class NoteController: def __init__(self): self.table = db.note + self.guild_table = db.guild + + async def ensure_guild_exists(self, guild_id: int) -> Guild | None: + guild = await self.guild_table.find_first(where={"guild_id": guild_id}) + if guild is None: + return await self.guild_table.create(data={"guild_id": guild_id}) + return guild async def get_all_notes(self) -> list[Note]: return await self.table.find_many() @@ -19,6 +26,8 @@ class NoteController: note_content: str, guild_id: int, ) -> Note: + await self.ensure_guild_exists(guild_id) + return await self.table.create( data={ "note_target_id": note_target_id, diff --git a/tux/database/controllers/reminder.py b/tux/database/controllers/reminder.py index b1ce50b..d0c73fa 100644 --- a/tux/database/controllers/reminder.py +++ b/tux/database/controllers/reminder.py @@ -1,12 +1,19 @@ from datetime import datetime -from prisma.models import Reminder +from prisma.models import Guild, Reminder from tux.database.client import db class ReminderController: def __init__(self) -> None: self.table = db.reminder + self.guild_table = db.guild + + async def ensure_guild_exists(self, guild_id: int) -> Guild | None: + guild = await self.guild_table.find_first(where={"guild_id": guild_id}) + if guild is None: + return await self.guild_table.create(data={"guild_id": guild_id}) + return guild async def get_all_reminders(self) -> list[Reminder]: return await self.table.find_many() @@ -22,6 +29,8 @@ class ReminderController: reminder_channel_id: int, guild_id: int, ) -> Reminder: + await self.ensure_guild_exists(guild_id) + return await self.table.create( data={ "reminder_user_id": reminder_user_id, diff --git a/tux/database/controllers/snippet.py b/tux/database/controllers/snippet.py index 5df17bc..ba2a826 100644 --- a/tux/database/controllers/snippet.py +++ b/tux/database/controllers/snippet.py @@ -1,12 +1,19 @@ import datetime -from prisma.models import Snippet +from prisma.models import Guild, Snippet from tux.database.client import db class SnippetController: def __init__(self) -> None: self.table = db.snippet + self.guild_table = db.guild + + async def ensure_guild_exists(self, guild_id: int) -> Guild | None: + guild = await self.guild_table.find_first(where={"guild_id": guild_id}) + if guild is None: + return await self.guild_table.create(data={"guild_id": guild_id}) + return guild async def get_all_snippets(self) -> list[Snippet]: return await self.table.find_many() @@ -36,6 +43,8 @@ class SnippetController: snippet_user_id: int, guild_id: int, ) -> Snippet: + await self.ensure_guild_exists(guild_id) + return await self.table.create( data={ "snippet_name": snippet_name,