1
Fork 0
mirror of https://github.com/allthingslinux/tux.git synced 2024-10-02 16:43:12 +00:00

feat(poetry.toml): add poetry.toml file to create virtual environments within the project directory

refactor(dev.py): rearrange error handling methods for better readability and maintainability
refactor(avatar.py): extract common logic into send_avatar method to reduce code duplication
refactor(remindme.py): simplify database access by directly using the reminder table
feat(case.py, note.py, reminder.py): add ensure_guild_exists method to check and create guild if not exists
refactor(case.py, note.py, reminder.py): call ensure_guild_exists before creating a new entry to ensure guild exists

feat(snippet.py): add ensure_guild_exists method to check and create guild if not exists
refactor(snippet.py): modify create_snippet method to call ensure_guild_exists before creating a snippet, ensuring guild existence
This commit is contained in:
kzndotsh 2024-07-02 19:19:07 +00:00
parent fe57c02b99
commit 53a7493345
8 changed files with 107 additions and 53 deletions

2
poetry.toml Normal file
View file

@ -0,0 +1,2 @@
[virtualenvs]
in-project = true

View file

@ -50,6 +50,13 @@ class Dev(commands.Cog):
await self.bot.tree.sync(guild=ctx.guild) await self.bot.tree.sync(guild=ctx.guild)
await ctx.reply("Application command tree synced.") await ctx.reply("Application command tree synced.")
@sync_tree.error
async def sync_error(self, ctx: commands.Context[commands.Bot], error: Exception) -> None:
if isinstance(error, commands.MissingRequiredArgument):
await ctx.send(f"Please specify a guild to sync application commands to. {error}")
else:
logger.error(f"Error syncing application commands: {error}")
@commands.has_guild_permissions(administrator=True) @commands.has_guild_permissions(administrator=True)
@dev.command( @dev.command(
name="clear_tree", name="clear_tree",
@ -131,6 +138,22 @@ class Dev(commands.Cog):
await ctx.send(f"Cog {cog} loaded.") await ctx.send(f"Cog {cog} loaded.")
logger.info(f"Cog {cog} loaded.") logger.info(f"Cog {cog} loaded.")
@load_cog.error
async def load_error(self, ctx: commands.Context[commands.Bot], error: Exception) -> None:
if isinstance(error, commands.MissingRequiredArgument):
await ctx.send(f"Please specify an cog to load. {error}")
elif isinstance(error, commands.ExtensionAlreadyLoaded):
await ctx.send(f"The specified cog is already loaded. {error}")
elif isinstance(error, commands.ExtensionNotFound):
await ctx.send(f"The specified cog is not found. {error}")
elif isinstance(error, commands.ExtensionFailed):
await ctx.send(f"Failed to load cog: {error}")
elif isinstance(error, commands.NoEntryPointError):
await ctx.send(f"The specified cog does not have a setup function. {error}")
else:
await ctx.send(f"Failed to load cog: {error}")
logger.error(f"Failed to load cog: {error}")
@commands.has_guild_permissions(administrator=True) @commands.has_guild_permissions(administrator=True)
@dev.command( @dev.command(
name="unload_cog", name="unload_cog",
@ -169,6 +192,15 @@ class Dev(commands.Cog):
logger.info(f"Cog {cog} unloaded.") logger.info(f"Cog {cog} unloaded.")
await ctx.send(f"Cog {cog} unloaded.") await ctx.send(f"Cog {cog} unloaded.")
@unload_cog.error
async def unload_error(self, ctx: commands.Context[commands.Bot], error: Exception) -> None:
if isinstance(error, commands.MissingRequiredArgument):
await ctx.send(f"Please specify an extension to unload. {error}")
elif isinstance(error, commands.ExtensionNotLoaded):
await ctx.send(f"That cog is not loaded. {error}")
else:
logger.error(f"Error unloading cog: {error}")
@commands.has_guild_permissions(administrator=True) @commands.has_guild_permissions(administrator=True)
@dev.command( @dev.command(
name="reload_cog", name="reload_cog",
@ -208,14 +240,6 @@ class Dev(commands.Cog):
await ctx.send(f"Cog {cog} reloaded.") await ctx.send(f"Cog {cog} reloaded.")
logger.info(f"Cog {cog} reloaded.") logger.info(f"Cog {cog} reloaded.")
@sync_tree.error
async def sync_error(self, ctx: commands.Context[commands.Bot], error: Exception) -> None:
if isinstance(error, commands.MissingRequiredArgument):
await ctx.send(f"Please specify a guild to sync application commands to. {error}")
else:
logger.error(f"Error syncing application commands: {error}")
@reload_cog.error @reload_cog.error
async def reload_error(self, ctx: commands.Context[commands.Bot], error: Exception) -> None: async def reload_error(self, ctx: commands.Context[commands.Bot], error: Exception) -> None:
if isinstance(error, commands.MissingRequiredArgument): if isinstance(error, commands.MissingRequiredArgument):
@ -226,31 +250,6 @@ class Dev(commands.Cog):
await ctx.send(f"Error reloading cog: {error}") await ctx.send(f"Error reloading cog: {error}")
logger.error(f"Error reloading cog: {error}") logger.error(f"Error reloading cog: {error}")
@unload_cog.error
async def unload_error(self, ctx: commands.Context[commands.Bot], error: Exception) -> None:
if isinstance(error, commands.MissingRequiredArgument):
await ctx.send(f"Please specify an extension to unload. {error}")
elif isinstance(error, commands.ExtensionNotLoaded):
await ctx.send(f"That cog is not loaded. {error}")
else:
logger.error(f"Error unloading cog: {error}")
@load_cog.error
async def load_error(self, ctx: commands.Context[commands.Bot], error: Exception) -> None:
if isinstance(error, commands.MissingRequiredArgument):
await ctx.send(f"Please specify an cog to load. {error}")
elif isinstance(error, commands.ExtensionAlreadyLoaded):
await ctx.send(f"The specified cog is already loaded. {error}")
elif isinstance(error, commands.ExtensionNotFound):
await ctx.send(f"The specified cog is not found. {error}")
elif isinstance(error, commands.ExtensionFailed):
await ctx.send(f"Failed to load cog: {error}")
elif isinstance(error, commands.NoEntryPointError):
await ctx.send(f"The specified cog does not have a setup function. {error}")
else:
await ctx.send(f"Failed to load cog: {error}")
logger.error(f"Failed to load cog: {error}")
async def setup(bot: commands.Bot) -> None: async def setup(bot: commands.Bot) -> None:
await bot.add_cog(Dev(bot)) await bot.add_cog(Dev(bot))

View file

@ -28,15 +28,7 @@ class Avatar(commands.Cog):
member : discord.Member member : discord.Member
The member to get the avatar of. The member to get the avatar of.
""" """
guild_avatar = member.guild_avatar.url if member.guild_avatar else None await self.send_avatar(ctx, member)
profile_avatar = member.avatar.url if member.avatar else None
files = [await self.create_avatar_file(avatar) for avatar in [guild_avatar, profile_avatar] if avatar]
if files:
await ctx.reply(files=files)
else:
await ctx.reply("Member has no avatar.")
@app_commands.command(name="avatar", description="Get the avatar of a member.") @app_commands.command(name="avatar", description="Get the avatar of a member.")
@app_commands.describe(member="The member to get the avatar of.") @app_commands.describe(member="The member to get the avatar of.")
@ -51,15 +43,40 @@ class Avatar(commands.Cog):
member : discord.Member member : discord.Member
The member to get the avatar of. The member to get the avatar of.
""" """
await self.send_avatar(interaction, member)
async def send_avatar(
self,
source: commands.Context[commands.Bot] | discord.Interaction,
member: discord.Member,
) -> None:
"""
Send the avatar of a member.
Parameters
----------
source : commands.Context[commands.Bot] | discord.Interaction
The source object for sending the message.
member : discord.Member
The member to get the avatar of.
"""
guild_avatar = member.guild_avatar.url if member.guild_avatar else None guild_avatar = member.guild_avatar.url if member.guild_avatar else None
profile_avatar = member.avatar.url if member.avatar else None profile_avatar = member.avatar.url if member.avatar else None
files = [await self.create_avatar_file(avatar) for avatar in [guild_avatar, profile_avatar] if avatar] files = [await self.create_avatar_file(avatar) for avatar in [guild_avatar, profile_avatar] if avatar]
if files: if files:
await interaction.response.send_message(files=files) if isinstance(source, discord.Interaction):
await source.response.send_message(files=files)
else:
await source.reply(files=files)
else: else:
await interaction.response.send_message(content="Member has no avatar.") message = "Member has no avatar."
if isinstance(source, discord.Interaction):
await source.response.send_message(content=message)
else:
await source.reply(content=message)
@staticmethod @staticmethod
async def create_avatar_file(url: str) -> discord.File: async def create_avatar_file(url: str) -> discord.File:

View file

@ -34,7 +34,7 @@ def get_closest_reminder(reminders: list[Reminder]) -> Reminder | None:
class RemindMe(commands.Cog): class RemindMe(commands.Cog):
def __init__(self, bot: commands.Bot) -> None: def __init__(self, bot: commands.Bot) -> None:
self.bot = bot self.bot = bot
self.db_controller = DatabaseController() self.db = DatabaseController().reminder
self.bot.loop.create_task(self.update()) self.bot.loop.create_task(self.update())
async def send_reminders(self, reminder: Reminder) -> None: async def send_reminders(self, reminder: Reminder) -> None:
@ -88,7 +88,7 @@ class RemindMe(commands.Cog):
logger.error(f"Failed to send reminder to {reminder.reminder_user_id}, user not found.") logger.error(f"Failed to send reminder to {reminder.reminder_user_id}, user not found.")
# Delete the reminder after sending # Delete the reminder after sending
await self.db_controller.reminders.delete_reminder_by_id(reminder.reminder_id) await self.db.delete_reminder_by_id(reminder.reminder_id)
# wait for a second so that the reminder is deleted before checking for more reminders # wait for a second so that the reminder is deleted before checking for more reminders
# who knows if this works, it seems to # who knows if this works, it seems to
@ -120,7 +120,7 @@ class RemindMe(commands.Cog):
try: try:
# Get all reminders # Get all reminders
reminders = await self.db_controller.reminders.get_all_reminders() reminders = await self.db.get_all_reminders()
# Get the closest reminder # Get the closest reminder
closest_reminder = get_closest_reminder(reminders) closest_reminder = get_closest_reminder(reminders)
@ -170,7 +170,7 @@ class RemindMe(commands.Cog):
seconds = datetime.datetime.now(datetime.UTC) + datetime.timedelta(seconds=seconds) seconds = datetime.datetime.now(datetime.UTC) + datetime.timedelta(seconds=seconds)
try: try:
await self.db_controller.reminders.insert_reminder( await self.db.insert_reminder(
reminder_user_id=interaction.user.id, reminder_user_id=interaction.user.id,
reminder_content=reminder, reminder_content=reminder,
reminder_expires_at=seconds, reminder_expires_at=seconds,

View file

@ -1,13 +1,20 @@
from datetime import datetime from datetime import datetime
from prisma.enums import CaseType from prisma.enums import CaseType
from prisma.models import Case from prisma.models import Case, Guild
from tux.database.client import db from tux.database.client import db
class CaseController: class CaseController:
def __init__(self): def __init__(self):
self.table = db.case self.table = db.case
self.guild_table = db.guild
async def ensure_guild_exists(self, guild_id: int) -> Guild | None:
guild = await self.guild_table.find_first(where={"guild_id": guild_id})
if guild is None:
return await self.guild_table.create(data={"guild_id": guild_id})
return guild
async def get_all_cases(self) -> list[Case]: async def get_all_cases(self) -> list[Case]:
return await self.table.find_many() return await self.table.find_many()
@ -24,13 +31,15 @@ class CaseController:
case_reason: str, case_reason: str,
case_expires_at: datetime | None = None, case_expires_at: datetime | None = None,
) -> Case | None: ) -> Case | None:
await self.ensure_guild_exists(guild_id)
return await self.table.create( return await self.table.create(
data={ data={
"guild_id": guild_id, "guild_id": guild_id,
"case_target_id": case_target_id, "case_target_id": case_target_id,
"case_moderator_id": case_moderator_id, "case_moderator_id": case_moderator_id,
"case_reason": case_reason,
"case_type": case_type, "case_type": case_type,
"case_reason": case_reason,
"case_expires_at": case_expires_at, "case_expires_at": case_expires_at,
}, },
) )

View file

@ -1,10 +1,17 @@
from prisma.models import Note from prisma.models import Guild, Note
from tux.database.client import db from tux.database.client import db
class NoteController: class NoteController:
def __init__(self): def __init__(self):
self.table = db.note self.table = db.note
self.guild_table = db.guild
async def ensure_guild_exists(self, guild_id: int) -> Guild | None:
guild = await self.guild_table.find_first(where={"guild_id": guild_id})
if guild is None:
return await self.guild_table.create(data={"guild_id": guild_id})
return guild
async def get_all_notes(self) -> list[Note]: async def get_all_notes(self) -> list[Note]:
return await self.table.find_many() return await self.table.find_many()
@ -19,6 +26,8 @@ class NoteController:
note_content: str, note_content: str,
guild_id: int, guild_id: int,
) -> Note: ) -> Note:
await self.ensure_guild_exists(guild_id)
return await self.table.create( return await self.table.create(
data={ data={
"note_target_id": note_target_id, "note_target_id": note_target_id,

View file

@ -1,12 +1,19 @@
from datetime import datetime from datetime import datetime
from prisma.models import Reminder from prisma.models import Guild, Reminder
from tux.database.client import db from tux.database.client import db
class ReminderController: class ReminderController:
def __init__(self) -> None: def __init__(self) -> None:
self.table = db.reminder self.table = db.reminder
self.guild_table = db.guild
async def ensure_guild_exists(self, guild_id: int) -> Guild | None:
guild = await self.guild_table.find_first(where={"guild_id": guild_id})
if guild is None:
return await self.guild_table.create(data={"guild_id": guild_id})
return guild
async def get_all_reminders(self) -> list[Reminder]: async def get_all_reminders(self) -> list[Reminder]:
return await self.table.find_many() return await self.table.find_many()
@ -22,6 +29,8 @@ class ReminderController:
reminder_channel_id: int, reminder_channel_id: int,
guild_id: int, guild_id: int,
) -> Reminder: ) -> Reminder:
await self.ensure_guild_exists(guild_id)
return await self.table.create( return await self.table.create(
data={ data={
"reminder_user_id": reminder_user_id, "reminder_user_id": reminder_user_id,

View file

@ -1,12 +1,19 @@
import datetime import datetime
from prisma.models import Snippet from prisma.models import Guild, Snippet
from tux.database.client import db from tux.database.client import db
class SnippetController: class SnippetController:
def __init__(self) -> None: def __init__(self) -> None:
self.table = db.snippet self.table = db.snippet
self.guild_table = db.guild
async def ensure_guild_exists(self, guild_id: int) -> Guild | None:
guild = await self.guild_table.find_first(where={"guild_id": guild_id})
if guild is None:
return await self.guild_table.create(data={"guild_id": guild_id})
return guild
async def get_all_snippets(self) -> list[Snippet]: async def get_all_snippets(self) -> list[Snippet]:
return await self.table.find_many() return await self.table.find_many()
@ -36,6 +43,8 @@ class SnippetController:
snippet_user_id: int, snippet_user_id: int,
guild_id: int, guild_id: int,
) -> Snippet: ) -> Snippet:
await self.ensure_guild_exists(guild_id)
return await self.table.create( return await self.table.create(
data={ data={
"snippet_name": snippet_name, "snippet_name": snippet_name,