From bf05ae118cfa91ac504bced45edc322ac51b2366 Mon Sep 17 00:00:00 2001 From: wlinator Date: Thu, 29 Aug 2024 05:19:09 -0400 Subject: [PATCH] Add format.py and format slowmode durations --- lib/format.py | 139 +++++++++++++++++++++++++++++++++ main.py | 6 +- modules/moderation/slowmode.py | 10 ++- poetry.lock | 13 ++- pyproject.toml | 1 + services/config_service.py | 135 ++++++++++++++++++++++++++++++++ 6 files changed, 297 insertions(+), 7 deletions(-) create mode 100644 lib/format.py create mode 100644 services/config_service.py diff --git a/lib/format.py b/lib/format.py new file mode 100644 index 0000000..c6b8d7e --- /dev/null +++ b/lib/format.py @@ -0,0 +1,139 @@ +import textwrap + +import discord +from discord.ext import commands +from pytimeparse import parse # type: ignore + +from lib import exceptions +from lib.const import CONST +from services.config_service import GuildConfig + + +def template(text: str, username: str, level: int | None = None) -> str: + """ + Replaces placeholders in the given text with actual values. + + Args: + text (str): The template text containing placeholders. + username (str): The username to replace the "{user}" placeholder. + level (int | None, optional): The level to replace the "{level}" placeholder. Defaults to None. + + Returns: + str: The formatted text with placeholders replaced by actual values. + """ + replacements: dict[str, str] = { + "{user}": username, + "{level}": str(level) if level else "", + } + + for placeholder, value in replacements.items(): + text = text.replace(placeholder, value) + + return text + + +def shorten(text: str, width: int = 200) -> str: + """ + Shortens the input text to the specified width by adding a placeholder at the end if the text exceeds the width. + + Args: + text (str): The text to be shortened. + width (int): The maximum width of the shortened text (default is 200). + + Returns: + str: The shortened text. + + Examples: + shortened_text = shorten("Lorem ipsum dolor sit amet", 10) + """ + return textwrap.shorten(text, width=width, placeholder="...") + + +def format_case_number(case_number: int) -> str: + """ + Formats a case number as a string with leading zeros if necessary. + + Args: + case_number (int): The case number to format. + + Returns: + str: The formatted case number as a string. + If the case number is less than 1000, it will be padded with leading zeros to three digits. + If the case number is 1000 or greater, it will be returned as a regular string. + + Examples: + >>> format_case_number(1) + '001' + >>> format_case_number(42) + '042' + >>> format_case_number(999) + '999' + >>> format_case_number(1000) + '1000' + """ + return f"{case_number:03d}" if case_number < 1000 else str(case_number) + + +def get_prefix(ctx: commands.Context[commands.Bot]) -> str: + """ + Attempts to retrieve the prefix for the given guild context. + + Args: + ctx (discord.ext.commands.Context): The context of the command invocation. + + Returns: + str: The prefix for the guild. Defaults to "." if the guild or prefix is not found. + """ + try: + return GuildConfig.get_prefix(ctx.guild.id if ctx.guild else 0) + except (AttributeError, TypeError): + return "." + + +def get_invoked_name(ctx: commands.Context[commands.Bot]) -> str | None: + """ + Attempts to get the alias of the command used. If the user used a SlashCommand, return the command name. + + Args: + ctx (discord.ext.commands.Context): The context of the command invocation. + + Returns: + str: The alias or name of the invoked command. + """ + try: + return ctx.invoked_with + + except (discord.app_commands.CommandInvokeError, AttributeError): + return ctx.command.name if ctx.command else None + + +def format_duration_to_seconds(duration: str) -> int: + """ + Formats a duration in seconds to a human-readable string. + """ + parsed_duration: int = parse(duration) # type: ignore + + if isinstance(parsed_duration, int): + return parsed_duration + + raise exceptions.LumiException(CONST.STRINGS["error_invalid_duration"].format(duration)) + + +def format_seconds_to_duration_string(seconds: int) -> str: + """ + Formats a duration in seconds to a human-readable string. + Returns seconds if shorter than a minute. + """ + if seconds < 60: + return f"{seconds}s" + + days = seconds // 86400 + hours = (seconds % 86400) // 3600 + minutes = (seconds % 3600) // 60 + + if days > 0: + return f"{days}d{hours}h" if hours > 0 else f"{days}d" + if hours > 0: + return f"{hours}h{minutes}m" if minutes > 0 else f"{hours}h" + + return f"{minutes}m" diff --git a/main.py b/main.py index ce8bc98..ed4d50f 100644 --- a/main.py +++ b/main.py @@ -7,13 +7,15 @@ from loguru import logger from lib.client import Luminara from lib.const import CONST +from services.config_service import GuildConfig logger.remove() logger.add(sys.stdout, format=CONST.LOG_FORMAT, colorize=True, level=CONST.LOG_LEVEL) -async def get_prefix(bot, message): - return commands.when_mentioned_or(".")(bot, message) +async def get_prefix(bot: Luminara, message: discord.Message) -> list[str]: + extras = GuildConfig.get_prefix(message) + return commands.when_mentioned_or(*extras)(bot, message) async def main() -> None: diff --git a/modules/moderation/slowmode.py b/modules/moderation/slowmode.py index e213e30..a257b82 100644 --- a/modules/moderation/slowmode.py +++ b/modules/moderation/slowmode.py @@ -4,6 +4,8 @@ import discord from discord.ext import commands from lib.const import CONST +from lib.exceptions import LumiException +from lib.format import format_duration_to_seconds from ui.embeds import Builder @@ -32,11 +34,11 @@ class Slowmode(commands.Cog): try: channel = await commands.TextChannelConverter().convert(ctx, arg) except commands.BadArgument: - with contextlib.suppress(ValueError): - duration = int(arg) + with contextlib.suppress(LumiException): + duration = format_duration_to_seconds(arg) else: - with contextlib.suppress(ValueError): - duration = int(arg) + with contextlib.suppress(LumiException): + duration = format_duration_to_seconds(arg) if not channel: await ctx.send(CONST.STRINGS["slowmode_channel_not_found"]) diff --git a/poetry.lock b/poetry.lock index c61b3d2..40b5a73 100644 --- a/poetry.lock +++ b/poetry.lock @@ -972,6 +972,17 @@ nodeenv = ">=1.6.0" all = ["twine (>=3.4.1)"] dev = ["twine (>=3.4.1)"] +[[package]] +name = "pytimeparse" +version = "1.1.8" +description = "Time expression parser" +optional = false +python-versions = "*" +files = [ + {file = "pytimeparse-1.1.8-py2.py3-none-any.whl", hash = "sha256:04b7be6cc8bd9f5647a6325444926c3ac34ee6bc7e69da4367ba282f076036bd"}, + {file = "pytimeparse-1.1.8.tar.gz", hash = "sha256:e86136477be924d7e670646a98561957e8ca7308d44841e21f5ddea757556a0a"}, +] + [[package]] name = "pytz" version = "2024.1" @@ -1299,4 +1310,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.12" -content-hash = "e6ab702cf6efc2ec25a9c033029869f9c4c3631e2e6063c4241153ea8c1f8e79" +content-hash = "4a7a75036f4de7e0126a8f6b058eb3deb52f710becf44e4da6bac4a0ea0a1a2f" diff --git a/pyproject.toml b/pyproject.toml index 810a6b9..eab4b62 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ pyyaml = "^6.0.2" ruff = "^0.6.2" typing-extensions = "^4.12.2" pydantic = "^2.8.2" +pytimeparse = "^1.1.8" [build-system] build-backend = "poetry.core.masonry.api" diff --git a/services/config_service.py b/services/config_service.py new file mode 100644 index 0000000..172224a --- /dev/null +++ b/services/config_service.py @@ -0,0 +1,135 @@ +from typing import Any + +from db import database + + +class GuildConfig: + def __init__(self, guild_id: int) -> None: + self.guild_id: int = guild_id + self.birthday_channel_id: int | None = None + self.command_channel_id: int | None = None + self.intro_channel_id: int | None = None + self.welcome_channel_id: int | None = None + self.welcome_message: str | None = None + self.boost_channel_id: int | None = None + self.boost_message: str | None = None + self.boost_image_url: str | None = None + self.level_channel_id: int | None = None + self.level_message: str | None = None + self.level_message_type: int = 1 + + self.fetch_or_create_config() + + def fetch_or_create_config(self) -> None: + """ + Gets a Guild Config from the database or inserts a new row if it doesn't exist yet. + """ + query: str = """ + SELECT birthday_channel_id, command_channel_id, intro_channel_id, + welcome_channel_id, welcome_message, boost_channel_id, + boost_message, boost_image_url, level_channel_id, + level_message, level_message_type + FROM guild_config WHERE guild_id = %s + """ + + try: + self._extracted_from_fetch_or_create_config_14(query) + except (IndexError, TypeError): + # No record found for the specified guild_id + query = "INSERT INTO guild_config (guild_id) VALUES (%s)" + database.execute_query(query, (self.guild_id,)) + + # TODO Rename this here and in `fetch_or_create_config` + def _extracted_from_fetch_or_create_config_14(self, query: str) -> None: + result: tuple[Any, ...] = database.select_query(query, (self.guild_id,))[0] + ( + self.birthday_channel_id, + self.command_channel_id, + self.intro_channel_id, + self.welcome_channel_id, + self.welcome_message, + self.boost_channel_id, + self.boost_message, + self.boost_image_url, + self.level_channel_id, + self.level_message, + self.level_message_type, + ) = result + + def push(self) -> None: + query: str = """ + UPDATE guild_config + SET + birthday_channel_id = %s, + command_channel_id = %s, + intro_channel_id = %s, + welcome_channel_id = %s, + welcome_message = %s, + boost_channel_id = %s, + boost_message = %s, + boost_image_url = %s, + level_channel_id = %s, + level_message = %s, + level_message_type = %s + WHERE guild_id = %s; + """ + + database.execute_query( + query, + ( + self.birthday_channel_id, + self.command_channel_id, + self.intro_channel_id, + self.welcome_channel_id, + self.welcome_message, + self.boost_channel_id, + self.boost_message, + self.boost_image_url, + self.level_channel_id, + self.level_message, + self.level_message_type, + self.guild_id, + ), + ) + + @staticmethod + def get_prefix(message: Any) -> str: + """ + Gets the prefix from a given guild. + This function is done as static method to make the prefix fetch process faster. + """ + query: str = """ + SELECT prefix + FROM guild_config + WHERE guild_id = %s + """ + + prefix: str | None = database.select_query_one( + query, + (message.guild.id if message.guild else None,), + ) + + return prefix or "." + + @staticmethod + def get_prefix_from_guild_id(guild_id: int) -> str: + query: str = """ + SELECT prefix + FROM guild_config + WHERE guild_id = %s + """ + + return database.select_query_one(query, (guild_id,)) or "." + + @staticmethod + def set_prefix(guild_id: int, prefix: str) -> None: + """ + Sets the prefix for a given guild. + """ + query: str = """ + UPDATE guild_config + SET prefix = %s + WHERE guild_id = %s; + """ + + database.execute_query(query, (prefix, guild_id))