mirror of
https://github.com/allthingslinux/tux.git
synced 2024-10-02 16:43:12 +00:00
feat(poetry.toml): add poetry.toml file to create virtual environments within the project directory
refactor(dev.py): rearrange error handling methods for better readability and maintainability refactor(avatar.py): extract common logic into send_avatar method to reduce code duplication refactor(remindme.py): simplify database access by directly using the reminder table feat(case.py, note.py, reminder.py): add ensure_guild_exists method to check and create guild if not exists refactor(case.py, note.py, reminder.py): call ensure_guild_exists before creating a new entry to ensure guild exists feat(snippet.py): add ensure_guild_exists method to check and create guild if not exists refactor(snippet.py): modify create_snippet method to call ensure_guild_exists before creating a snippet, ensuring guild existence
This commit is contained in:
parent
fe57c02b99
commit
53a7493345
8 changed files with 107 additions and 53 deletions
2
poetry.toml
Normal file
2
poetry.toml
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
[virtualenvs]
|
||||||
|
in-project = true
|
|
@ -50,6 +50,13 @@ class Dev(commands.Cog):
|
||||||
await self.bot.tree.sync(guild=ctx.guild)
|
await self.bot.tree.sync(guild=ctx.guild)
|
||||||
await ctx.reply("Application command tree synced.")
|
await ctx.reply("Application command tree synced.")
|
||||||
|
|
||||||
|
@sync_tree.error
|
||||||
|
async def sync_error(self, ctx: commands.Context[commands.Bot], error: Exception) -> None:
|
||||||
|
if isinstance(error, commands.MissingRequiredArgument):
|
||||||
|
await ctx.send(f"Please specify a guild to sync application commands to. {error}")
|
||||||
|
else:
|
||||||
|
logger.error(f"Error syncing application commands: {error}")
|
||||||
|
|
||||||
@commands.has_guild_permissions(administrator=True)
|
@commands.has_guild_permissions(administrator=True)
|
||||||
@dev.command(
|
@dev.command(
|
||||||
name="clear_tree",
|
name="clear_tree",
|
||||||
|
@ -131,6 +138,22 @@ class Dev(commands.Cog):
|
||||||
await ctx.send(f"Cog {cog} loaded.")
|
await ctx.send(f"Cog {cog} loaded.")
|
||||||
logger.info(f"Cog {cog} loaded.")
|
logger.info(f"Cog {cog} loaded.")
|
||||||
|
|
||||||
|
@load_cog.error
|
||||||
|
async def load_error(self, ctx: commands.Context[commands.Bot], error: Exception) -> None:
|
||||||
|
if isinstance(error, commands.MissingRequiredArgument):
|
||||||
|
await ctx.send(f"Please specify an cog to load. {error}")
|
||||||
|
elif isinstance(error, commands.ExtensionAlreadyLoaded):
|
||||||
|
await ctx.send(f"The specified cog is already loaded. {error}")
|
||||||
|
elif isinstance(error, commands.ExtensionNotFound):
|
||||||
|
await ctx.send(f"The specified cog is not found. {error}")
|
||||||
|
elif isinstance(error, commands.ExtensionFailed):
|
||||||
|
await ctx.send(f"Failed to load cog: {error}")
|
||||||
|
elif isinstance(error, commands.NoEntryPointError):
|
||||||
|
await ctx.send(f"The specified cog does not have a setup function. {error}")
|
||||||
|
else:
|
||||||
|
await ctx.send(f"Failed to load cog: {error}")
|
||||||
|
logger.error(f"Failed to load cog: {error}")
|
||||||
|
|
||||||
@commands.has_guild_permissions(administrator=True)
|
@commands.has_guild_permissions(administrator=True)
|
||||||
@dev.command(
|
@dev.command(
|
||||||
name="unload_cog",
|
name="unload_cog",
|
||||||
|
@ -169,6 +192,15 @@ class Dev(commands.Cog):
|
||||||
logger.info(f"Cog {cog} unloaded.")
|
logger.info(f"Cog {cog} unloaded.")
|
||||||
await ctx.send(f"Cog {cog} unloaded.")
|
await ctx.send(f"Cog {cog} unloaded.")
|
||||||
|
|
||||||
|
@unload_cog.error
|
||||||
|
async def unload_error(self, ctx: commands.Context[commands.Bot], error: Exception) -> None:
|
||||||
|
if isinstance(error, commands.MissingRequiredArgument):
|
||||||
|
await ctx.send(f"Please specify an extension to unload. {error}")
|
||||||
|
elif isinstance(error, commands.ExtensionNotLoaded):
|
||||||
|
await ctx.send(f"That cog is not loaded. {error}")
|
||||||
|
else:
|
||||||
|
logger.error(f"Error unloading cog: {error}")
|
||||||
|
|
||||||
@commands.has_guild_permissions(administrator=True)
|
@commands.has_guild_permissions(administrator=True)
|
||||||
@dev.command(
|
@dev.command(
|
||||||
name="reload_cog",
|
name="reload_cog",
|
||||||
|
@ -208,14 +240,6 @@ class Dev(commands.Cog):
|
||||||
await ctx.send(f"Cog {cog} reloaded.")
|
await ctx.send(f"Cog {cog} reloaded.")
|
||||||
logger.info(f"Cog {cog} reloaded.")
|
logger.info(f"Cog {cog} reloaded.")
|
||||||
|
|
||||||
@sync_tree.error
|
|
||||||
async def sync_error(self, ctx: commands.Context[commands.Bot], error: Exception) -> None:
|
|
||||||
if isinstance(error, commands.MissingRequiredArgument):
|
|
||||||
await ctx.send(f"Please specify a guild to sync application commands to. {error}")
|
|
||||||
|
|
||||||
else:
|
|
||||||
logger.error(f"Error syncing application commands: {error}")
|
|
||||||
|
|
||||||
@reload_cog.error
|
@reload_cog.error
|
||||||
async def reload_error(self, ctx: commands.Context[commands.Bot], error: Exception) -> None:
|
async def reload_error(self, ctx: commands.Context[commands.Bot], error: Exception) -> None:
|
||||||
if isinstance(error, commands.MissingRequiredArgument):
|
if isinstance(error, commands.MissingRequiredArgument):
|
||||||
|
@ -226,31 +250,6 @@ class Dev(commands.Cog):
|
||||||
await ctx.send(f"Error reloading cog: {error}")
|
await ctx.send(f"Error reloading cog: {error}")
|
||||||
logger.error(f"Error reloading cog: {error}")
|
logger.error(f"Error reloading cog: {error}")
|
||||||
|
|
||||||
@unload_cog.error
|
|
||||||
async def unload_error(self, ctx: commands.Context[commands.Bot], error: Exception) -> None:
|
|
||||||
if isinstance(error, commands.MissingRequiredArgument):
|
|
||||||
await ctx.send(f"Please specify an extension to unload. {error}")
|
|
||||||
elif isinstance(error, commands.ExtensionNotLoaded):
|
|
||||||
await ctx.send(f"That cog is not loaded. {error}")
|
|
||||||
else:
|
|
||||||
logger.error(f"Error unloading cog: {error}")
|
|
||||||
|
|
||||||
@load_cog.error
|
|
||||||
async def load_error(self, ctx: commands.Context[commands.Bot], error: Exception) -> None:
|
|
||||||
if isinstance(error, commands.MissingRequiredArgument):
|
|
||||||
await ctx.send(f"Please specify an cog to load. {error}")
|
|
||||||
elif isinstance(error, commands.ExtensionAlreadyLoaded):
|
|
||||||
await ctx.send(f"The specified cog is already loaded. {error}")
|
|
||||||
elif isinstance(error, commands.ExtensionNotFound):
|
|
||||||
await ctx.send(f"The specified cog is not found. {error}")
|
|
||||||
elif isinstance(error, commands.ExtensionFailed):
|
|
||||||
await ctx.send(f"Failed to load cog: {error}")
|
|
||||||
elif isinstance(error, commands.NoEntryPointError):
|
|
||||||
await ctx.send(f"The specified cog does not have a setup function. {error}")
|
|
||||||
else:
|
|
||||||
await ctx.send(f"Failed to load cog: {error}")
|
|
||||||
logger.error(f"Failed to load cog: {error}")
|
|
||||||
|
|
||||||
|
|
||||||
async def setup(bot: commands.Bot) -> None:
|
async def setup(bot: commands.Bot) -> None:
|
||||||
await bot.add_cog(Dev(bot))
|
await bot.add_cog(Dev(bot))
|
||||||
|
|
|
@ -28,15 +28,7 @@ class Avatar(commands.Cog):
|
||||||
member : discord.Member
|
member : discord.Member
|
||||||
The member to get the avatar of.
|
The member to get the avatar of.
|
||||||
"""
|
"""
|
||||||
guild_avatar = member.guild_avatar.url if member.guild_avatar else None
|
await self.send_avatar(ctx, member)
|
||||||
profile_avatar = member.avatar.url if member.avatar else None
|
|
||||||
|
|
||||||
files = [await self.create_avatar_file(avatar) for avatar in [guild_avatar, profile_avatar] if avatar]
|
|
||||||
|
|
||||||
if files:
|
|
||||||
await ctx.reply(files=files)
|
|
||||||
else:
|
|
||||||
await ctx.reply("Member has no avatar.")
|
|
||||||
|
|
||||||
@app_commands.command(name="avatar", description="Get the avatar of a member.")
|
@app_commands.command(name="avatar", description="Get the avatar of a member.")
|
||||||
@app_commands.describe(member="The member to get the avatar of.")
|
@app_commands.describe(member="The member to get the avatar of.")
|
||||||
|
@ -51,15 +43,40 @@ class Avatar(commands.Cog):
|
||||||
member : discord.Member
|
member : discord.Member
|
||||||
The member to get the avatar of.
|
The member to get the avatar of.
|
||||||
"""
|
"""
|
||||||
|
await self.send_avatar(interaction, member)
|
||||||
|
|
||||||
|
async def send_avatar(
|
||||||
|
self,
|
||||||
|
source: commands.Context[commands.Bot] | discord.Interaction,
|
||||||
|
member: discord.Member,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Send the avatar of a member.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
source : commands.Context[commands.Bot] | discord.Interaction
|
||||||
|
The source object for sending the message.
|
||||||
|
member : discord.Member
|
||||||
|
The member to get the avatar of.
|
||||||
|
"""
|
||||||
|
|
||||||
guild_avatar = member.guild_avatar.url if member.guild_avatar else None
|
guild_avatar = member.guild_avatar.url if member.guild_avatar else None
|
||||||
profile_avatar = member.avatar.url if member.avatar else None
|
profile_avatar = member.avatar.url if member.avatar else None
|
||||||
|
|
||||||
files = [await self.create_avatar_file(avatar) for avatar in [guild_avatar, profile_avatar] if avatar]
|
files = [await self.create_avatar_file(avatar) for avatar in [guild_avatar, profile_avatar] if avatar]
|
||||||
|
|
||||||
if files:
|
if files:
|
||||||
await interaction.response.send_message(files=files)
|
if isinstance(source, discord.Interaction):
|
||||||
|
await source.response.send_message(files=files)
|
||||||
|
else:
|
||||||
|
await source.reply(files=files)
|
||||||
else:
|
else:
|
||||||
await interaction.response.send_message(content="Member has no avatar.")
|
message = "Member has no avatar."
|
||||||
|
if isinstance(source, discord.Interaction):
|
||||||
|
await source.response.send_message(content=message)
|
||||||
|
else:
|
||||||
|
await source.reply(content=message)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def create_avatar_file(url: str) -> discord.File:
|
async def create_avatar_file(url: str) -> discord.File:
|
||||||
|
|
|
@ -34,7 +34,7 @@ def get_closest_reminder(reminders: list[Reminder]) -> Reminder | None:
|
||||||
class RemindMe(commands.Cog):
|
class RemindMe(commands.Cog):
|
||||||
def __init__(self, bot: commands.Bot) -> None:
|
def __init__(self, bot: commands.Bot) -> None:
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
self.db_controller = DatabaseController()
|
self.db = DatabaseController().reminder
|
||||||
self.bot.loop.create_task(self.update())
|
self.bot.loop.create_task(self.update())
|
||||||
|
|
||||||
async def send_reminders(self, reminder: Reminder) -> None:
|
async def send_reminders(self, reminder: Reminder) -> None:
|
||||||
|
@ -88,7 +88,7 @@ class RemindMe(commands.Cog):
|
||||||
logger.error(f"Failed to send reminder to {reminder.reminder_user_id}, user not found.")
|
logger.error(f"Failed to send reminder to {reminder.reminder_user_id}, user not found.")
|
||||||
|
|
||||||
# Delete the reminder after sending
|
# Delete the reminder after sending
|
||||||
await self.db_controller.reminders.delete_reminder_by_id(reminder.reminder_id)
|
await self.db.delete_reminder_by_id(reminder.reminder_id)
|
||||||
|
|
||||||
# wait for a second so that the reminder is deleted before checking for more reminders
|
# wait for a second so that the reminder is deleted before checking for more reminders
|
||||||
# who knows if this works, it seems to
|
# who knows if this works, it seems to
|
||||||
|
@ -120,7 +120,7 @@ class RemindMe(commands.Cog):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get all reminders
|
# Get all reminders
|
||||||
reminders = await self.db_controller.reminders.get_all_reminders()
|
reminders = await self.db.get_all_reminders()
|
||||||
# Get the closest reminder
|
# Get the closest reminder
|
||||||
closest_reminder = get_closest_reminder(reminders)
|
closest_reminder = get_closest_reminder(reminders)
|
||||||
|
|
||||||
|
@ -170,7 +170,7 @@ class RemindMe(commands.Cog):
|
||||||
seconds = datetime.datetime.now(datetime.UTC) + datetime.timedelta(seconds=seconds)
|
seconds = datetime.datetime.now(datetime.UTC) + datetime.timedelta(seconds=seconds)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.db_controller.reminders.insert_reminder(
|
await self.db.insert_reminder(
|
||||||
reminder_user_id=interaction.user.id,
|
reminder_user_id=interaction.user.id,
|
||||||
reminder_content=reminder,
|
reminder_content=reminder,
|
||||||
reminder_expires_at=seconds,
|
reminder_expires_at=seconds,
|
||||||
|
|
|
@ -1,13 +1,20 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from prisma.enums import CaseType
|
from prisma.enums import CaseType
|
||||||
from prisma.models import Case
|
from prisma.models import Case, Guild
|
||||||
from tux.database.client import db
|
from tux.database.client import db
|
||||||
|
|
||||||
|
|
||||||
class CaseController:
|
class CaseController:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.table = db.case
|
self.table = db.case
|
||||||
|
self.guild_table = db.guild
|
||||||
|
|
||||||
|
async def ensure_guild_exists(self, guild_id: int) -> Guild | None:
|
||||||
|
guild = await self.guild_table.find_first(where={"guild_id": guild_id})
|
||||||
|
if guild is None:
|
||||||
|
return await self.guild_table.create(data={"guild_id": guild_id})
|
||||||
|
return guild
|
||||||
|
|
||||||
async def get_all_cases(self) -> list[Case]:
|
async def get_all_cases(self) -> list[Case]:
|
||||||
return await self.table.find_many()
|
return await self.table.find_many()
|
||||||
|
@ -24,13 +31,15 @@ class CaseController:
|
||||||
case_reason: str,
|
case_reason: str,
|
||||||
case_expires_at: datetime | None = None,
|
case_expires_at: datetime | None = None,
|
||||||
) -> Case | None:
|
) -> Case | None:
|
||||||
|
await self.ensure_guild_exists(guild_id)
|
||||||
|
|
||||||
return await self.table.create(
|
return await self.table.create(
|
||||||
data={
|
data={
|
||||||
"guild_id": guild_id,
|
"guild_id": guild_id,
|
||||||
"case_target_id": case_target_id,
|
"case_target_id": case_target_id,
|
||||||
"case_moderator_id": case_moderator_id,
|
"case_moderator_id": case_moderator_id,
|
||||||
"case_reason": case_reason,
|
|
||||||
"case_type": case_type,
|
"case_type": case_type,
|
||||||
|
"case_reason": case_reason,
|
||||||
"case_expires_at": case_expires_at,
|
"case_expires_at": case_expires_at,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,10 +1,17 @@
|
||||||
from prisma.models import Note
|
from prisma.models import Guild, Note
|
||||||
from tux.database.client import db
|
from tux.database.client import db
|
||||||
|
|
||||||
|
|
||||||
class NoteController:
|
class NoteController:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.table = db.note
|
self.table = db.note
|
||||||
|
self.guild_table = db.guild
|
||||||
|
|
||||||
|
async def ensure_guild_exists(self, guild_id: int) -> Guild | None:
|
||||||
|
guild = await self.guild_table.find_first(where={"guild_id": guild_id})
|
||||||
|
if guild is None:
|
||||||
|
return await self.guild_table.create(data={"guild_id": guild_id})
|
||||||
|
return guild
|
||||||
|
|
||||||
async def get_all_notes(self) -> list[Note]:
|
async def get_all_notes(self) -> list[Note]:
|
||||||
return await self.table.find_many()
|
return await self.table.find_many()
|
||||||
|
@ -19,6 +26,8 @@ class NoteController:
|
||||||
note_content: str,
|
note_content: str,
|
||||||
guild_id: int,
|
guild_id: int,
|
||||||
) -> Note:
|
) -> Note:
|
||||||
|
await self.ensure_guild_exists(guild_id)
|
||||||
|
|
||||||
return await self.table.create(
|
return await self.table.create(
|
||||||
data={
|
data={
|
||||||
"note_target_id": note_target_id,
|
"note_target_id": note_target_id,
|
||||||
|
|
|
@ -1,12 +1,19 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from prisma.models import Reminder
|
from prisma.models import Guild, Reminder
|
||||||
from tux.database.client import db
|
from tux.database.client import db
|
||||||
|
|
||||||
|
|
||||||
class ReminderController:
|
class ReminderController:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.table = db.reminder
|
self.table = db.reminder
|
||||||
|
self.guild_table = db.guild
|
||||||
|
|
||||||
|
async def ensure_guild_exists(self, guild_id: int) -> Guild | None:
|
||||||
|
guild = await self.guild_table.find_first(where={"guild_id": guild_id})
|
||||||
|
if guild is None:
|
||||||
|
return await self.guild_table.create(data={"guild_id": guild_id})
|
||||||
|
return guild
|
||||||
|
|
||||||
async def get_all_reminders(self) -> list[Reminder]:
|
async def get_all_reminders(self) -> list[Reminder]:
|
||||||
return await self.table.find_many()
|
return await self.table.find_many()
|
||||||
|
@ -22,6 +29,8 @@ class ReminderController:
|
||||||
reminder_channel_id: int,
|
reminder_channel_id: int,
|
||||||
guild_id: int,
|
guild_id: int,
|
||||||
) -> Reminder:
|
) -> Reminder:
|
||||||
|
await self.ensure_guild_exists(guild_id)
|
||||||
|
|
||||||
return await self.table.create(
|
return await self.table.create(
|
||||||
data={
|
data={
|
||||||
"reminder_user_id": reminder_user_id,
|
"reminder_user_id": reminder_user_id,
|
||||||
|
|
|
@ -1,12 +1,19 @@
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
from prisma.models import Snippet
|
from prisma.models import Guild, Snippet
|
||||||
from tux.database.client import db
|
from tux.database.client import db
|
||||||
|
|
||||||
|
|
||||||
class SnippetController:
|
class SnippetController:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.table = db.snippet
|
self.table = db.snippet
|
||||||
|
self.guild_table = db.guild
|
||||||
|
|
||||||
|
async def ensure_guild_exists(self, guild_id: int) -> Guild | None:
|
||||||
|
guild = await self.guild_table.find_first(where={"guild_id": guild_id})
|
||||||
|
if guild is None:
|
||||||
|
return await self.guild_table.create(data={"guild_id": guild_id})
|
||||||
|
return guild
|
||||||
|
|
||||||
async def get_all_snippets(self) -> list[Snippet]:
|
async def get_all_snippets(self) -> list[Snippet]:
|
||||||
return await self.table.find_many()
|
return await self.table.find_many()
|
||||||
|
@ -36,6 +43,8 @@ class SnippetController:
|
||||||
snippet_user_id: int,
|
snippet_user_id: int,
|
||||||
guild_id: int,
|
guild_id: int,
|
||||||
) -> Snippet:
|
) -> Snippet:
|
||||||
|
await self.ensure_guild_exists(guild_id)
|
||||||
|
|
||||||
return await self.table.create(
|
return await self.table.create(
|
||||||
data={
|
data={
|
||||||
"snippet_name": snippet_name,
|
"snippet_name": snippet_name,
|
||||||
|
|
Loading…
Reference in a new issue