From bbd8fb62aa22ff2730897b8224bf94b8e6e41dfe Mon Sep 17 00:00:00 2001 From: petabyte-imo Date: Wed, 4 Sep 2024 17:48:13 +0100 Subject: [PATCH] Added a set_tempban_expired and get_expired_tempbans function and changed up tempban a bit --- prisma/schema.prisma | 20 ++------- tux/cogs/moderation/tempban.py | 70 ++++++++++++++++---------------- tux/database/controllers/case.py | 41 +++++++++++++++++++ 3 files changed, 81 insertions(+), 50 deletions(-) diff --git a/prisma/schema.prisma b/prisma/schema.prisma index 7fd1a6d..0eece64 100644 --- a/prisma/schema.prisma +++ b/prisma/schema.prisma @@ -11,14 +11,13 @@ datasource db { } model Guild { - guild_id BigInt @id + guild_id BigInt @id @unique guild_joined_at DateTime? @default(now()) cases Case[] - snippets Snippet[] + guild_config GuildConfig? notes Note[] reminders Reminder[] - guild_config GuildConfig[] - AFK AFKModel[] + snippets Snippet[] Starboard Starboard? StarboardMessage StarboardMessage[] case_count BigInt @default(0) @@ -65,6 +64,7 @@ model Case { case_number BigInt? case_created_at DateTime? @default(now()) case_expires_at DateTime? + case_tempban_expired Boolean? @default(false) guild_id BigInt guild Guild @relation(fields: [guild_id], references: [guild_id]) @@ -115,18 +115,6 @@ model Reminder { @@index([reminder_id, guild_id]) } -model AFKModel { - member_id BigInt @id - nickname String - reason String - since DateTime @default(now()) - guild_id BigInt - guild Guild @relation(fields: [guild_id], references: [guild_id]) - - @@unique([member_id, guild_id]) - @@index([member_id]) -} - model Starboard { guild_id BigInt @id @unique starboard_channel_id BigInt diff --git a/tux/cogs/moderation/tempban.py b/tux/cogs/moderation/tempban.py index 1f788a8..b6009ca 100644 --- a/tux/cogs/moderation/tempban.py +++ b/tux/cogs/moderation/tempban.py @@ -54,7 +54,7 @@ class TempBan(ModerationCogBase): return moderator = ctx.author - duration = parse_time_string(f"{flags.expires_at}d") + duration = parse_time_string(f"{flags.expires_at}m") expires_at = datetime.now(UTC) + duration if not await self.check_conditions(ctx, member, moderator, "temporarily ban"): @@ -76,48 +76,50 @@ class TempBan(ModerationCogBase): case_reason=flags.reason, guild_id=ctx.guild.id, case_expires_at=expires_at, + case_tempban_expired=False ) await self.handle_case_response(ctx, CaseType.TEMPBAN, case.case_number, flags.reason, member) - @tasks.loop(minutes=30) + @tasks.loop(seconds=30) async def tempban_check(self) -> None: - # Fetch all guilds and fetch all tempbans for each guild's ID - guilds = await self.db.guild.get_all_guilds() - tempbans = [await self.db.case.get_all_cases_by_type(guild.guild_id, CaseType.TEMPBAN) for guild in guilds] - # Here, we have 3 nested for loops because for some odd reason, tempbans is a list of a list of lists, very confusing ikr + # Get all expired tempbans + expired_temp_bans = await self.db.case.get_expired_tempbans() + logger.debug(f"Checking {len(expired_temp_bans)} expired tempbans. {expired_temp_bans}") + for temp_ban in expired_temp_bans: + #Debug Print + logger.debug(f"Unbanning user with ID {temp_ban.case_user_id} | Case number {temp_ban.case_number}") + guild = self.bot.get_guild(temp_ban.guild_id) - for tempban in tempbans: - for cases in tempbans: - for case in cases: - # Check if the case has expired - if case.case_expires_at < datetime.now(UTC): - # Get the guild, if that doesnt work then fetch it instead + if guild is None: + #Debug Print + logger.debug(f"Fetching guild with ID {temp_ban.guild_id}") + try: + guild = await self.bot.fetch_guild(temp_ban.guild_id) - guild = self.bot.get_guild(case.guild_id) - if guild is None: - try: - guild = await self.bot.fetch_guild(case.guild_id) + except (discord.Forbidden, discord.HTTPException) as e: + logger.error( + f"Failed to unban user with ID {temp_ban.case_user_id} | Case number {temp_ban.case_number} | Issue: Failed to get guild with ID {temp_ban.guild_id}. {e}", + ) + return + else: + logger.debug(f"Found guild with ID {temp_ban.guild_id}") + try: + # Unban the user - except (discord.Forbidden, discord.HTTPException) as e: - logger.error( - f"Failed to unban user with ID {case.case_user_id} | Case number {case.case_number} | Issue: Failed to get guild with ID {case.guild_id}. {e}", - ) - return - else: - try: - # Unban the user + guild_bans = guild.bans() + async for ban_entry in guild_bans: + if ban_entry.user.id == temp_ban.case_user_id: + await guild.unban(ban_entry.user, reason="Tempban expired") + await self.db.case.set_tempban_expired(temp_ban.case_number, temp_ban.guild_id) + except (discord.Forbidden, discord.HTTPException) as e: + logger.error( + f"Faile+d to unban user with ID {temp_ban.case_user_id} | Case number {temp_ban.case_number} Issue: Failed to unban user. {e}", + ) + return + #Debug Print + logger.debug(f"Unbanned user with ID {temp_ban.case_user_id} | Case number {temp_ban.case_number}") - guild_bans = guild.bans() - async for ban_entry in guild_bans: - if ban_entry.user.id == case.case_user_id: - await guild.unban(ban_entry.user, reason="Tempban expired") - - except (discord.Forbidden, discord.HTTPException) as e: - logger.error( - f"Failed to unban user with ID {case.case_user_id} | Case number {case.case_number} Issue: Failed to unban user. {e}", - ) - return async def setup(bot: Tux) -> None: diff --git a/tux/database/controllers/case.py b/tux/database/controllers/case.py index b1bd373..46fe028 100644 --- a/tux/database/controllers/case.py +++ b/tux/database/controllers/case.py @@ -82,6 +82,7 @@ class CaseController: case_reason: str, case_user_roles: list[int] | None = None, case_expires_at: datetime | None = None, + case_tempban_expired: bool = False, ) -> Case: """ Insert a case into the database. @@ -102,6 +103,8 @@ class CaseController: The roles of the target of the case. case_expires_at : datetime | None The expiration date of the case. + case_tempban_expired : bool + Whether the tempban has expired (Use only for tempbans). Returns ------- @@ -121,6 +124,7 @@ class CaseController: "case_reason": case_reason, "case_expires_at": case_expires_at, "case_user_roles": case_user_roles if case_user_roles is not None else [], + "case_tempban_expired": case_tempban_expired, }, ) @@ -307,3 +311,40 @@ class CaseController: if case is not None: return await self.table.delete(where={"case_id": case.case_id}) return None + + async def get_expired_tempbans(self) -> list[Case]: + """ + Get all cases that have expired tempbans. + + Returns + ------- + list[Case] + A list of cases of the type in the guild. + """ + return await self.table.find_many( + where={ + "case_type": CaseType.TEMPBAN, + "case_expires_at": {"lt": datetime.now()}, + "case_tempban_expired": False + } + ) + async def set_tempban_expired(self, case_number: int, guild_id: int) -> Case | None: + """ + Set a tempban case as expired. + + Parameters + ---------- + case_number : int + The number of the case to delete. + guild_id : int + The ID of the guild to delete the case in. + + Returns + ------- + Case | None + The case if found and deleted, otherwise None. + """ + return await self.table.update( + where={"case_number": case_number, "guild_id": guild_id}, + data={"case_tempban_expired": True} + )