1
Fork 0
mirror of https://github.com/wlinator/luminara.git synced 2024-10-02 20:23:12 +00:00

Add mod module and warn command

This commit is contained in:
wlinator 2024-08-29 06:21:55 -04:00
parent 31c80b9b00
commit d68368a30e
7 changed files with 553 additions and 5 deletions

30
lib/actionable.py Normal file
View file

@ -0,0 +1,30 @@
import discord
from lib.const import CONST
from lib.exceptions import LumiException
async def async_actionable(
target: discord.Member,
invoker: discord.Member,
bot_user: discord.Member,
) -> None:
"""
Checks if the invoker and client have a higher role than the target user.
Args:
target: The member object of the target user.
invoker: The member object of the user who invoked the command.
bot_user: The discord.Bot.user object representing the bot itself.
Returns:
True if the client's highest role AND the invoker's highest role are higher than the target.
"""
if target == invoker:
raise LumiException(CONST.STRINGS["error_actionable_self"])
if target.top_role >= invoker.top_role and invoker != invoker.guild.owner:
raise LumiException(CONST.STRINGS["error_actionable_hierarchy_user"])
if target.top_role >= bot_user.top_role:
raise LumiException(CONST.STRINGS["error_actionable_hierarchy_bot"])

152
lib/case_handler.py Normal file
View file

@ -0,0 +1,152 @@
import discord
from discord.ext import commands
from loguru import logger
from lib.exceptions import LumiException
from services.case_service import CaseService
from services.modlog_service import ModLogService
from ui.cases import create_case_embed
case_service = CaseService()
modlog_service = ModLogService()
async def create_case(
ctx: commands.Context[commands.Bot],
target: discord.User,
action_type: str,
reason: str | None = None,
duration: int | None = None,
expires_at: str | None = None,
):
"""
Creates a new moderation case and logs it to the modlog channel if configured.
Args:
ctx: The context of the command invocation.
target (discord.User): The user who is the subject of the moderation action.
action_type (str): The type of moderation action (e.g., "ban", "kick", "warn").
reason (Optional[str]): The reason for the moderation action. Defaults to None.
duration (Optional[int]): The duration of the action in seconds, if applicable. Defaults to None.
expires_at (Optional[str]): The expiration date of the action, if applicable. Defaults to None.
Returns:
None
Raises:
Exception: If there's an error sending the case to the modlog channel.
This function performs the following steps:
1. Creates a new case in the database using the CaseService.
2. Logs the case creation using the logger.
3. If a modlog channel is configured, it sends an embed with the case details to that channel.
4. If the embed is successfully sent to the modlog channel, it updates the case with the message ID for later edits.
"""
if not ctx.guild:
raise LumiException
guild_id = ctx.guild.id
moderator_id = ctx.author.id
target_id = target.id
# Create the case
case_number: int = case_service.create_case(
guild_id=guild_id,
target_id=target_id,
moderator_id=moderator_id,
action_type=action_type,
reason=reason,
duration=duration,
expires_at=expires_at,
modlog_message_id=None,
)
logger.info(f"Created case {case_number} for {target.name} in guild {guild_id}")
if mod_log_channel_id := modlog_service.fetch_modlog_channel_id(guild_id):
try:
mod_log_channel = await commands.TextChannelConverter().convert(
ctx,
str(mod_log_channel_id),
)
embed: discord.Embed = create_case_embed(
ctx=ctx,
target=target,
case_number=case_number,
action_type=action_type,
reason=reason,
timestamp=None,
duration=duration,
)
message = await mod_log_channel.send(embed=embed)
# Update the case with the modlog_message_id
case_service.edit_case(
guild_id=guild_id,
case_number=case_number,
changes={"modlog_message_id": message.id},
)
except Exception as e:
logger.error(f"Failed to send case to modlog channel: {e}")
async def edit_case_modlog(
ctx: commands.Context[commands.Bot],
guild_id: int,
case_number: int,
new_reason: str,
) -> bool:
"""
Edits the reason for an existing case and updates the modlog message if it exists.
Args:
ctx: The context of the command invocation.
guild_id: The ID of the guild where the case exists.
case_number: The number of the case to edit.
new_reason: The new reason for the case.
Raises:
ValueError: If the case is not found.
Exception: If there's an error updating the modlog message.
"""
case = case_service.fetch_case_by_guild_and_number(guild_id, case_number)
if not case:
msg = f"Case {case_number} not found in guild {guild_id}"
raise ValueError(msg)
modlog_message_id = case.get("modlog_message_id")
if not modlog_message_id:
return False
mod_log_channel_id = modlog_service.fetch_modlog_channel_id(guild_id)
if not mod_log_channel_id:
return False
try:
mod_log_channel = await commands.TextChannelConverter().convert(
ctx,
str(mod_log_channel_id),
)
message = await mod_log_channel.fetch_message(modlog_message_id)
target = await commands.UserConverter().convert(ctx, str(case["target_id"]))
updated_embed: discord.Embed = create_case_embed(
ctx=ctx,
target=target,
case_number=case_number,
action_type=case["action_type"],
reason=new_reason,
timestamp=case["created_at"],
duration=case["duration"] or None,
)
await message.edit(embed=updated_embed)
logger.info(f"Updated case {case_number} in guild {guild_id}")
except Exception as e:
logger.error(f"Failed to update modlog message for case {case_number}: {e}")
return False
return True

View file

@ -114,12 +114,12 @@ def format_duration_to_seconds(duration: str) -> int:
if duration.isdigit(): if duration.isdigit():
return int(duration) return int(duration)
parsed_duration: int = parse(duration) # type: ignore try:
parsed_duration: int = parse(duration) # type: ignore
return max(0, parsed_duration)
if isinstance(parsed_duration, int): except Exception as e:
return parsed_duration raise exceptions.LumiException(CONST.STRINGS["error_invalid_duration"].format(duration)) from e
raise exceptions.LumiException(CONST.STRINGS["error_invalid_duration"].format(duration))
def format_seconds_to_duration_string(seconds: int) -> str: def format_seconds_to_duration_string(seconds: int) -> str:

View file

@ -0,0 +1,61 @@
import asyncio
from typing import cast
import discord
from discord.ext import commands
from lib.actionable import async_actionable
from lib.case_handler import create_case
from lib.const import CONST
from lib.exceptions import LumiException
from ui.embeds import Builder
class Warn(commands.Cog):
def __init__(self, bot: commands.Bot):
self.bot = bot
@commands.hybrid_command(name="warn", description="Warn a user")
@commands.has_permissions(manage_messages=True)
async def warn(self, ctx: commands.Context[commands.Bot], target: discord.Member, *, reason: str | None = None):
if not ctx.guild or not ctx.author or not ctx.bot.user:
raise LumiException
bot_member = await commands.MemberConverter().convert(ctx, str(ctx.bot.user))
await async_actionable(target, cast(discord.Member, ctx.author), bot_member)
output_reason = reason or CONST.STRINGS["mod_no_reason"]
dm_task = target.send(
embed=Builder.create_embed(
user_name=target.name,
author_text=CONST.STRINGS["mod_warned_author"],
description=CONST.STRINGS["mod_warn_dm"].format(
target.name,
ctx.guild.name,
output_reason,
),
hide_name_in_description=True,
),
)
respond_task = ctx.send(
embed=Builder.create_embed(
user_name=ctx.author.name,
author_text=CONST.STRINGS["mod_warned_author"],
description=CONST.STRINGS["mod_warned_user"].format(target.name),
),
)
create_case_task = create_case(ctx, cast(discord.User, target), "WARN", reason)
await asyncio.gather(
dm_task,
respond_task,
create_case_task,
return_exceptions=True,
)
async def setup(bot: commands.Bot) -> None:
await bot.add_cog(Warn(bot))

168
services/case_service.py Normal file
View file

@ -0,0 +1,168 @@
from typing import Any
from db.database import execute_query, select_query_dict, select_query_one
class CaseService:
def __init__(self) -> None:
pass
def create_case(
self,
guild_id: int,
target_id: int,
moderator_id: int,
action_type: str,
reason: str | None = None,
duration: int | None = None,
expires_at: str | None = None,
modlog_message_id: int | None = None,
) -> int:
# Get the next case number for the guild
query: str = """
SELECT IFNULL(MAX(case_number), 0) + 1
FROM cases
WHERE guild_id = %s
"""
case_number: int | None = select_query_one(query, (guild_id,))
if case_number is None:
msg: str = "Failed to retrieve the next case number."
raise ValueError(msg)
# Insert the new case
query: str = """
INSERT INTO cases (
guild_id, case_number, target_id, moderator_id, action_type, reason, duration, expires_at, modlog_message_id
) VALUES (
%s, %s, %s, %s, %s, %s, %s, %s, %s
)
"""
execute_query(
query,
(
guild_id,
case_number,
target_id,
moderator_id,
action_type.upper(),
reason,
duration,
expires_at,
modlog_message_id,
),
)
return int(case_number)
def close_case(self, guild_id: int, case_number: int) -> None:
query: str = """
UPDATE cases
SET is_closed = TRUE, updated_at = CURRENT_TIMESTAMP
WHERE guild_id = %s AND case_number = %s
"""
execute_query(query, (guild_id, case_number))
def edit_case_reason(
self,
guild_id: int,
case_number: int,
new_reason: str | None = None,
) -> bool:
query: str = """
UPDATE cases
SET reason = COALESCE(%s, reason),
updated_at = CURRENT_TIMESTAMP
WHERE guild_id = %s AND case_number = %s
"""
execute_query(
query,
(
new_reason,
guild_id,
case_number,
),
)
return True
def edit_case(self, guild_id: int, case_number: int, changes: dict[str, Any]) -> None:
set_clause: str = ", ".join([f"{key} = %s" for key in changes])
query: str = f"""
UPDATE cases
SET {set_clause}, updated_at = CURRENT_TIMESTAMP
WHERE guild_id = %s AND case_number = %s
"""
execute_query(query, (*changes.values(), guild_id, case_number))
def _fetch_cases(self, query: str, params: tuple[Any, ...]) -> list[dict[str, Any]]:
results: list[dict[str, Any]] = select_query_dict(query, params)
return results
def _fetch_single_case(self, query: str, params: tuple[Any, ...]) -> dict[str, Any] | None:
result = self._fetch_cases(query, params)
return result[0] if result else None
def fetch_case_by_id(self, case_id: int) -> dict[str, Any] | None:
query: str = """
SELECT * FROM cases
WHERE id = %s
LIMIT 1
"""
return self._fetch_single_case(query, (case_id,))
def fetch_case_by_guild_and_number(
self,
guild_id: int,
case_number: int,
) -> dict[str, Any] | None:
query: str = """
SELECT * FROM cases
WHERE guild_id = %s AND case_number = %s
ORDER BY case_number DESC
LIMIT 1
"""
return self._fetch_single_case(query, (guild_id, case_number))
def fetch_cases_by_guild(self, guild_id: int) -> list[dict[str, Any]]:
query: str = """
SELECT * FROM cases
WHERE guild_id = %s
ORDER BY case_number DESC
"""
return self._fetch_cases(query, (guild_id,))
def fetch_cases_by_target(
self,
guild_id: int,
target_id: int,
) -> list[dict[str, Any]]:
query: str = """
SELECT * FROM cases
WHERE guild_id = %s AND target_id = %s
ORDER BY case_number DESC
"""
return self._fetch_cases(query, (guild_id, target_id))
def fetch_cases_by_moderator(
self,
guild_id: int,
moderator_id: int,
) -> list[dict[str, Any]]:
query: str = """
SELECT * FROM cases
WHERE guild_id = %s AND moderator_id = %s
ORDER BY case_number DESC
"""
return self._fetch_cases(query, (guild_id, moderator_id))
def fetch_cases_by_action_type(
self,
guild_id: int,
action_type: str,
) -> list[dict[str, Any]]:
query: str = """
SELECT * FROM cases
WHERE guild_id = %s AND action_type = %s
ORDER BY case_number DESC
"""
return self._fetch_cases(query, (guild_id, action_type.upper()))

View file

@ -0,0 +1,30 @@
from db.database import execute_query, select_query_one
class ModLogService:
def __init__(self):
pass
def set_modlog_channel(self, guild_id: int, channel_id: int) -> None:
query: str = """
INSERT INTO mod_log (guild_id, channel_id, is_enabled)
VALUES (%s, %s, TRUE)
ON DUPLICATE KEY UPDATE channel_id = VALUES(channel_id), is_enabled = TRUE, updated_at = CURRENT_TIMESTAMP
"""
execute_query(query, (guild_id, channel_id))
def disable_modlog_channel(self, guild_id: int) -> None:
query: str = """
UPDATE mod_log
SET is_enabled = FALSE, updated_at = CURRENT_TIMESTAMP
WHERE guild_id = %s
"""
execute_query(query, (guild_id,))
def fetch_modlog_channel_id(self, guild_id: int) -> int | None:
query: str = """
SELECT channel_id FROM mod_log
WHERE guild_id = %s AND is_enabled = TRUE
"""
result = select_query_one(query, (guild_id,))
return result or None

107
ui/cases.py Normal file
View file

@ -0,0 +1,107 @@
import datetime
from typing import Any
import discord
from discord.ext import commands
from lib.const import CONST
from lib.format import format_case_number, format_seconds_to_duration_string
from ui.embeds import Builder
def create_case_embed(
ctx: commands.Context[commands.Bot],
target: discord.User,
case_number: int,
action_type: str,
reason: str | None,
timestamp: datetime.datetime | None = None,
duration: int | None = None,
) -> discord.Embed:
embed: discord.Embed = Builder.create_embed(
user_name=ctx.author.name,
author_text=CONST.STRINGS["case_new_case_author"],
thumbnail_url=target.display_avatar.url,
hide_name_in_description=True,
timestamp=timestamp,
)
embed.add_field(
name=CONST.STRINGS["case_case_field"],
value=CONST.STRINGS["case_case_field_value"].format(
format_case_number(case_number),
),
inline=True,
)
if not duration:
embed.add_field(
name=CONST.STRINGS["case_type_field"],
value=CONST.STRINGS["case_type_field_value"].format(
action_type.lower().capitalize(),
),
inline=True,
)
else:
embed.add_field(
name=CONST.STRINGS["case_type_field"],
value=CONST.STRINGS["case_type_field_value_with_duration"].format(
action_type.lower().capitalize(),
format_seconds_to_duration_string(duration),
),
inline=True,
)
embed.add_field(
name=CONST.STRINGS["case_moderator_field"],
value=CONST.STRINGS["case_moderator_field_value"].format(
ctx.author.name,
),
inline=True,
)
embed.add_field(
name=CONST.STRINGS["case_target_field"],
value=CONST.STRINGS["case_target_field_value"].format(target.name),
inline=False,
)
embed.add_field(
name=CONST.STRINGS["case_reason_field"],
value=CONST.STRINGS["case_reason_field_value"].format(
reason or CONST.STRINGS["mod_no_reason"],
),
inline=False,
)
return embed
def create_case_list_embed(
ctx: commands.Context[commands.Bot],
cases: list[dict[str, Any]],
author_text: str,
) -> discord.Embed:
embed: discord.Embed = Builder.create_embed(
user_name=ctx.author.name,
author_text=author_text,
hide_name_in_description=True,
)
for case in cases:
status_emoji = "" if case.get("is_closed") else ""
case_number = case.get("case_number", "N/A")
if isinstance(case_number, int):
case_number = format_case_number(case_number)
action_type = case.get("action_type", "Unknown")
timestamp = case.get("created_at", "Unknown")
if isinstance(timestamp, datetime.datetime):
formatted_timestamp = f"<t:{int(timestamp.timestamp())}:R>"
else:
formatted_timestamp = str(timestamp)
if embed.description is None:
embed.description = ""
embed.description += f"{status_emoji} `{case_number}` **[{action_type}]** {formatted_timestamp}\n"
return embed