1
Fork 0
mirror of https://github.com/wlinator/luminara.git synced 2024-10-02 18:23:12 +00:00

Add format.py and format slowmode durations

This commit is contained in:
wlinator 2024-08-29 05:19:09 -04:00
parent f87f8a0b39
commit bf05ae118c
6 changed files with 297 additions and 7 deletions

139
lib/format.py Normal file
View file

@ -0,0 +1,139 @@
import textwrap
import discord
from discord.ext import commands
from pytimeparse import parse # type: ignore
from lib import exceptions
from lib.const import CONST
from services.config_service import GuildConfig
def template(text: str, username: str, level: int | None = None) -> str:
"""
Replaces placeholders in the given text with actual values.
Args:
text (str): The template text containing placeholders.
username (str): The username to replace the "{user}" placeholder.
level (int | None, optional): The level to replace the "{level}" placeholder. Defaults to None.
Returns:
str: The formatted text with placeholders replaced by actual values.
"""
replacements: dict[str, str] = {
"{user}": username,
"{level}": str(level) if level else "",
}
for placeholder, value in replacements.items():
text = text.replace(placeholder, value)
return text
def shorten(text: str, width: int = 200) -> str:
"""
Shortens the input text to the specified width by adding a placeholder at the end if the text exceeds the width.
Args:
text (str): The text to be shortened.
width (int): The maximum width of the shortened text (default is 200).
Returns:
str: The shortened text.
Examples:
shortened_text = shorten("Lorem ipsum dolor sit amet", 10)
"""
return textwrap.shorten(text, width=width, placeholder="...")
def format_case_number(case_number: int) -> str:
"""
Formats a case number as a string with leading zeros if necessary.
Args:
case_number (int): The case number to format.
Returns:
str: The formatted case number as a string.
If the case number is less than 1000, it will be padded with leading zeros to three digits.
If the case number is 1000 or greater, it will be returned as a regular string.
Examples:
>>> format_case_number(1)
'001'
>>> format_case_number(42)
'042'
>>> format_case_number(999)
'999'
>>> format_case_number(1000)
'1000'
"""
return f"{case_number:03d}" if case_number < 1000 else str(case_number)
def get_prefix(ctx: commands.Context[commands.Bot]) -> str:
"""
Attempts to retrieve the prefix for the given guild context.
Args:
ctx (discord.ext.commands.Context): The context of the command invocation.
Returns:
str: The prefix for the guild. Defaults to "." if the guild or prefix is not found.
"""
try:
return GuildConfig.get_prefix(ctx.guild.id if ctx.guild else 0)
except (AttributeError, TypeError):
return "."
def get_invoked_name(ctx: commands.Context[commands.Bot]) -> str | None:
"""
Attempts to get the alias of the command used. If the user used a SlashCommand, return the command name.
Args:
ctx (discord.ext.commands.Context): The context of the command invocation.
Returns:
str: The alias or name of the invoked command.
"""
try:
return ctx.invoked_with
except (discord.app_commands.CommandInvokeError, AttributeError):
return ctx.command.name if ctx.command else None
def format_duration_to_seconds(duration: str) -> int:
"""
Formats a duration in seconds to a human-readable string.
"""
parsed_duration: int = parse(duration) # type: ignore
if isinstance(parsed_duration, int):
return parsed_duration
raise exceptions.LumiException(CONST.STRINGS["error_invalid_duration"].format(duration))
def format_seconds_to_duration_string(seconds: int) -> str:
"""
Formats a duration in seconds to a human-readable string.
Returns seconds if shorter than a minute.
"""
if seconds < 60:
return f"{seconds}s"
days = seconds // 86400
hours = (seconds % 86400) // 3600
minutes = (seconds % 3600) // 60
if days > 0:
return f"{days}d{hours}h" if hours > 0 else f"{days}d"
if hours > 0:
return f"{hours}h{minutes}m" if minutes > 0 else f"{hours}h"
return f"{minutes}m"

View file

@ -7,13 +7,15 @@ from loguru import logger
from lib.client import Luminara
from lib.const import CONST
from services.config_service import GuildConfig
logger.remove()
logger.add(sys.stdout, format=CONST.LOG_FORMAT, colorize=True, level=CONST.LOG_LEVEL)
async def get_prefix(bot, message):
return commands.when_mentioned_or(".")(bot, message)
async def get_prefix(bot: Luminara, message: discord.Message) -> list[str]:
extras = GuildConfig.get_prefix(message)
return commands.when_mentioned_or(*extras)(bot, message)
async def main() -> None:

View file

@ -4,6 +4,8 @@ import discord
from discord.ext import commands
from lib.const import CONST
from lib.exceptions import LumiException
from lib.format import format_duration_to_seconds
from ui.embeds import Builder
@ -32,11 +34,11 @@ class Slowmode(commands.Cog):
try:
channel = await commands.TextChannelConverter().convert(ctx, arg)
except commands.BadArgument:
with contextlib.suppress(ValueError):
duration = int(arg)
with contextlib.suppress(LumiException):
duration = format_duration_to_seconds(arg)
else:
with contextlib.suppress(ValueError):
duration = int(arg)
with contextlib.suppress(LumiException):
duration = format_duration_to_seconds(arg)
if not channel:
await ctx.send(CONST.STRINGS["slowmode_channel_not_found"])

13
poetry.lock generated
View file

@ -972,6 +972,17 @@ nodeenv = ">=1.6.0"
all = ["twine (>=3.4.1)"]
dev = ["twine (>=3.4.1)"]
[[package]]
name = "pytimeparse"
version = "1.1.8"
description = "Time expression parser"
optional = false
python-versions = "*"
files = [
{file = "pytimeparse-1.1.8-py2.py3-none-any.whl", hash = "sha256:04b7be6cc8bd9f5647a6325444926c3ac34ee6bc7e69da4367ba282f076036bd"},
{file = "pytimeparse-1.1.8.tar.gz", hash = "sha256:e86136477be924d7e670646a98561957e8ca7308d44841e21f5ddea757556a0a"},
]
[[package]]
name = "pytz"
version = "2024.1"
@ -1299,4 +1310,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = "^3.12"
content-hash = "e6ab702cf6efc2ec25a9c033029869f9c4c3631e2e6063c4241153ea8c1f8e79"
content-hash = "4a7a75036f4de7e0126a8f6b058eb3deb52f710becf44e4da6bac4a0ea0a1a2f"

View file

@ -25,6 +25,7 @@ pyyaml = "^6.0.2"
ruff = "^0.6.2"
typing-extensions = "^4.12.2"
pydantic = "^2.8.2"
pytimeparse = "^1.1.8"
[build-system]
build-backend = "poetry.core.masonry.api"

135
services/config_service.py Normal file
View file

@ -0,0 +1,135 @@
from typing import Any
from db import database
class GuildConfig:
def __init__(self, guild_id: int) -> None:
self.guild_id: int = guild_id
self.birthday_channel_id: int | None = None
self.command_channel_id: int | None = None
self.intro_channel_id: int | None = None
self.welcome_channel_id: int | None = None
self.welcome_message: str | None = None
self.boost_channel_id: int | None = None
self.boost_message: str | None = None
self.boost_image_url: str | None = None
self.level_channel_id: int | None = None
self.level_message: str | None = None
self.level_message_type: int = 1
self.fetch_or_create_config()
def fetch_or_create_config(self) -> None:
"""
Gets a Guild Config from the database or inserts a new row if it doesn't exist yet.
"""
query: str = """
SELECT birthday_channel_id, command_channel_id, intro_channel_id,
welcome_channel_id, welcome_message, boost_channel_id,
boost_message, boost_image_url, level_channel_id,
level_message, level_message_type
FROM guild_config WHERE guild_id = %s
"""
try:
self._extracted_from_fetch_or_create_config_14(query)
except (IndexError, TypeError):
# No record found for the specified guild_id
query = "INSERT INTO guild_config (guild_id) VALUES (%s)"
database.execute_query(query, (self.guild_id,))
# TODO Rename this here and in `fetch_or_create_config`
def _extracted_from_fetch_or_create_config_14(self, query: str) -> None:
result: tuple[Any, ...] = database.select_query(query, (self.guild_id,))[0]
(
self.birthday_channel_id,
self.command_channel_id,
self.intro_channel_id,
self.welcome_channel_id,
self.welcome_message,
self.boost_channel_id,
self.boost_message,
self.boost_image_url,
self.level_channel_id,
self.level_message,
self.level_message_type,
) = result
def push(self) -> None:
query: str = """
UPDATE guild_config
SET
birthday_channel_id = %s,
command_channel_id = %s,
intro_channel_id = %s,
welcome_channel_id = %s,
welcome_message = %s,
boost_channel_id = %s,
boost_message = %s,
boost_image_url = %s,
level_channel_id = %s,
level_message = %s,
level_message_type = %s
WHERE guild_id = %s;
"""
database.execute_query(
query,
(
self.birthday_channel_id,
self.command_channel_id,
self.intro_channel_id,
self.welcome_channel_id,
self.welcome_message,
self.boost_channel_id,
self.boost_message,
self.boost_image_url,
self.level_channel_id,
self.level_message,
self.level_message_type,
self.guild_id,
),
)
@staticmethod
def get_prefix(message: Any) -> str:
"""
Gets the prefix from a given guild.
This function is done as static method to make the prefix fetch process faster.
"""
query: str = """
SELECT prefix
FROM guild_config
WHERE guild_id = %s
"""
prefix: str | None = database.select_query_one(
query,
(message.guild.id if message.guild else None,),
)
return prefix or "."
@staticmethod
def get_prefix_from_guild_id(guild_id: int) -> str:
query: str = """
SELECT prefix
FROM guild_config
WHERE guild_id = %s
"""
return database.select_query_one(query, (guild_id,)) or "."
@staticmethod
def set_prefix(guild_id: int, prefix: str) -> None:
"""
Sets the prefix for a given guild.
"""
query: str = """
UPDATE guild_config
SET prefix = %s
WHERE guild_id = %s;
"""
database.execute_query(query, (prefix, guild_id))