1
Fork 0
mirror of https://github.com/allthingslinux/tux.git synced 2024-10-02 16:43:12 +00:00

- [READABILITY] Reverted example config changes, and added a GIF_LIMITER category

- [READABILITY] Separated the GIF limiter message handler into several functions
- [PERFORMANCE] Convert GIF limiter settings dictionaries in constants.py instead of in the cog itself
- [BUG FIX] Added locks to prevent race conditions between the message handler and routine cleanup function
This commit is contained in:
rm-rf-omega 2024-09-26 12:05:46 +02:00
parent 11bd345c44
commit fd62c38845
4 changed files with 93 additions and 89 deletions

View file

@ -6,8 +6,9 @@ DEFAULT_PREFIX:
USER_IDS: USER_IDS:
SYSADMINS: SYSADMINS:
- 123456789012345 - 123456789012345679
BOT_OWNER: 123456789012345 - 123456789012345679
BOT_OWNER: 123456789012345679
TEMPVC_CATEGORY_ID: 123456789012345 TEMPVC_CATEGORY_ID: 123456789012345
TEMPVC_CHANNEL_ID: 123456789012345 TEMPVC_CHANNEL_ID: 123456789012345
@ -40,13 +41,13 @@ EMBED_ICONS:
TIMEOUT: "https://github.com/allthingslinux/tux/blob/main/assets/emojis/timeout.png?raw=true" TIMEOUT: "https://github.com/allthingslinux/tux/blob/main/assets/emojis/timeout.png?raw=true"
WARN: "https://github.com/allthingslinux/tux/blob/main/assets/emojis/warn.png?raw=true" WARN: "https://github.com/allthingslinux/tux/blob/main/assets/emojis/warn.png?raw=true"
RECENT_GIF_AGE: 60 GIF_LIMITER:
RECENT_GIF_AGE: 60
GIF_LIMIT_EXCLUDE: GIF_LIMIT_EXCLUDE:
- 123456789012345 - 123456789012345
GIF_LIMITS_USER: GIF_LIMITS_USER:
"123456789012345": 2 "123456789012345": 2
GIF_LIMITS_CHANNEL:
GIF_LIMITS_CHANNEL: "123456789012345": 3
"123456789012345": 3

View file

@ -1,41 +1,20 @@
"""
This cog is a handler for GIF ratelimiting.
It keeps a list of GIF send times and routinely removes old times.
If a user posts a GIF, the message_handler function should be externally called.
It will delete the message if the user or channel quota is exceeded.
"""
import asyncio import asyncio
from collections import defaultdict from collections import defaultdict
from time import time from time import time
import discord import discord
from discord import Message
from discord.ext import commands, tasks from discord.ext import commands, tasks
from loguru import logger
from tux.bot import Tux from tux.bot import Tux
from tux.utils.constants import CONST from tux.utils.constants import CONST
def convert_dict_str_to_int(original_dict: dict[str, int]) -> dict[int, int]:
"""Helper function required as YAML keys are str. Channel and user IDs are int."""
converted_dict: dict[int, int] = {}
for key, value in original_dict.items():
try:
int_key: int = int(key)
converted_dict[int_key] = value
except ValueError:
logger.error("An error occurred when loading the GIF ratelimiter configuration.")
return converted_dict
class GifLimiter(commands.Cog): class GifLimiter(commands.Cog):
"""Main class with GIF tracking and message handlers""" """
This class is a handler for GIF ratelimiting.
It keeps a list of GIF send times and routinely removes old times.
It will prevent people from posting GIFs if the quotas are exceeded.
"""
def __init__(self, bot: Tux) -> None: def __init__(self, bot: Tux) -> None:
self.bot = bot self.bot = bot
@ -43,9 +22,9 @@ class GifLimiter(commands.Cog):
self.recent_gif_age: int = CONST.RECENT_GIF_AGE self.recent_gif_age: int = CONST.RECENT_GIF_AGE
# Max number of GIFs sent recently in a channel # Max number of GIFs sent recently in a channel
self.channelwide_gif_limits: dict[int, int] = convert_dict_str_to_int(CONST.GIF_LIMITS_CHANNEL) self.channelwide_gif_limits: dict[int, int] = CONST.GIF_LIMITS_CHANNEL
# Max number of GIFs sent recently by a user to be able to post one in specified channels # Max number of GIFs sent recently by a user to be able to post one in specified channels
self.user_gif_limits: dict[int, int] = convert_dict_str_to_int(CONST.GIF_LIMITS) self.user_gif_limits: dict[int, int] = CONST.GIF_LIMITS
# list of channels in which not to count GIFs # list of channels in which not to count GIFs
self.gif_limit_exclude: list[int] = CONST.GIF_LIMIT_EXCLUDE self.gif_limit_exclude: list[int] = CONST.GIF_LIMIT_EXCLUDE
@ -57,68 +36,73 @@ class GifLimiter(commands.Cog):
# Channel ID, list of timestamps # Channel ID, list of timestamps
self.recent_gifs_by_channel: defaultdict[int, list[int]] = defaultdict(list) self.recent_gifs_by_channel: defaultdict[int, list[int]] = defaultdict(list)
async def delete_message(self, message: discord.Message, epilogue: str) -> None: # Lock to prevent race conditions
self.gif_lock = asyncio.Lock()
self.old_gif_remover.start()
async def _should_process_message(self, message: discord.Message) -> bool:
""" Checks if a message contains a GIF and was not sent in a blacklisted channel """
return not (len(message.embeds) == 0
or "gif" not in message.content.lower()
or message.channel.id in self.gif_limit_exclude)
async def _handle_gif_message(self, message: discord.Message) -> None:
""" Checks for ratelimit infringements """
async with self.gif_lock:
channel: int = message.channel.id
user: int = message.author.id
if (
channel in self.channelwide_gif_limits
and channel in self.recent_gifs_by_channel
and len(self.recent_gifs_by_channel[channel]) >= self.channelwide_gif_limits[channel]
):
await self._delete_message(message, "for channel")
return
if (
user in self.recent_gifs_by_user
and channel in self.user_gif_limits
and len(self.recent_gifs_by_user[user]) >= self.user_gif_limits[channel]
):
await self._delete_message(message, "for user")
return
# Add message to recent GIFs if it doesn't infringe on ratelimits
current_time: int = int(time())
self.recent_gifs_by_channel[channel].append(current_time)
self.recent_gifs_by_user[user].append(current_time)
async def _delete_message(self, message: discord.Message, epilogue: str) -> None:
""" """
Deletes the message passed as an argument, and sends a self-deleting message with the reason Deletes the message passed as an argument, and sends a self-deleting message with the reason
""" """
sent_message: Message = await message.channel.send(f"-# GIF ratelimit exceeded {epilogue}")
await message.delete() await message.delete()
await asyncio.sleep(3) await message.channel.send(f"-# GIF ratelimit exceeded {epilogue}", delete_after=3)
await sent_message.delete()
@commands.Cog.listener() @commands.Cog.listener()
async def on_message(self, message: discord.Message) -> None: async def on_message(self, message: discord.Message) -> None:
"""Checks for GIFs in every sent message""" """Checks for GIFs in every sent message"""
# Nothing to do if the message doesn't have a .gif embed, if (await self._should_process_message(message)):
# or if it was sent in a blacklisted channel await self._handle_gif_message(message)
if (
len(message.embeds) == 0
or "gif" not in message.content.lower()
or message.channel.id in self.gif_limit_exclude
):
return
channel: int = message.channel.id
user: int = message.author.id
# Check if the message infringes on any ratelimits
if (
channel in self.channelwide_gif_limits
and channel in self.recent_gifs_by_channel
and len(self.recent_gifs_by_channel[channel]) >= self.channelwide_gif_limits[channel]
):
await self.delete_message(message, "for channel")
return
if (
user in self.recent_gifs_by_user
and channel in self.user_gif_limits
and len(self.recent_gifs_by_user[user]) >= self.user_gif_limits[channel]
):
await self.delete_message(message, "for user")
return
# If it doesn't, add it to recent GIFs
current_time: int = int(time())
self.recent_gifs_by_channel[channel].append(current_time)
self.recent_gifs_by_user[user].append(current_time)
@tasks.loop(seconds=20) @tasks.loop(seconds=20)
async def old_gif_remover(self) -> None: async def old_gif_remover(self) -> None:
"""Regularly cleans old GIF timestamps""" """Regularly cleans old GIF timestamps"""
current_time: int = int(time()) current_time: int = int(time())
for channel_id, timestamps in self.recent_gifs_by_channel.items(): async with self.gif_lock:
self.recent_gifs_by_channel[channel_id] = [t for t in timestamps if current_time - t < self.recent_gif_age] for channel_id, timestamps in self.recent_gifs_by_channel.items():
self.recent_gifs_by_channel[channel_id] = [t for t in timestamps if current_time - t < self.recent_gif_age]
for user_id, timestamps in self.recent_gifs_by_user.items(): for user_id, timestamps in self.recent_gifs_by_user.items():
self.recent_gifs_by_user[user_id] = [t for t in timestamps if current_time - t < self.recent_gif_age] self.recent_gifs_by_user[user_id] = [t for t in timestamps if current_time - t < self.recent_gif_age]
# Delete user key if no GIF has recently been sent by them
if len(self.recent_gifs_by_user[user_id]) == 0:
del self.recent_gifs_by_user[user_id]
# Delete user key if no GIF has recently been sent by them
if len(self.recent_gifs_by_user[user_id]) == 0:
del self.recent_gifs_by_user[user_id]
async def setup(bot: Tux) -> None: async def setup(bot: Tux) -> None:
await bot.add_cog(GifLimiter(bot)) await bot.add_cog(GifLimiter(bot))

View file

@ -6,6 +6,8 @@ from typing import Final
import yaml import yaml
from dotenv import load_dotenv, set_key from dotenv import load_dotenv, set_key
from tux.utils.functions import convert_dict_str_to_int
load_dotenv(verbose=True) load_dotenv(verbose=True)
config_file = Path("config/settings.yml") config_file = Path("config/settings.yml")
@ -78,11 +80,11 @@ class Constants:
EMBED_ICONS: Final[dict[str, str]] = config["EMBED_ICONS"] EMBED_ICONS: Final[dict[str, str]] = config["EMBED_ICONS"]
# GIF ratelimit constants # GIF ratelimit constants
RECENT_GIF_AGE: Final[int] = config["RECENT_GIF_AGE"] RECENT_GIF_AGE: Final[int] = config["GIF_LIMITER"]["RECENT_GIF_AGE"]
GIF_LIMIT_EXCLUDE: Final[list[int]] = config["GIF_LIMIT_EXCLUDE"] GIF_LIMIT_EXCLUDE: Final[list[int]] = config["GIF_LIMITER"]["GIF_LIMIT_EXCLUDE"]
# Ideally would be int, int but YAML doesn't support integer keys
GIF_LIMITS: Final[dict[str, int]] = config["GIF_LIMITS_USER"] GIF_LIMITS: Final[dict[int, int]] = convert_dict_str_to_int(config["GIF_LIMITER"]["GIF_LIMITS_USER"])
GIF_LIMITS_CHANNEL: Final[dict[str, int]] = config["GIF_LIMITS_CHANNEL"] GIF_LIMITS_CHANNEL: Final[dict[int, int]] = convert_dict_str_to_int(config["GIF_LIMITER"]["GIF_LIMITS_CHANNEL"])
# Embed limit constants # Embed limit constants
EMBED_MAX_NAME_LENGTH = 256 EMBED_MAX_NAME_LENGTH = 256

View file

@ -3,6 +3,7 @@ from datetime import UTC, datetime, timedelta
from typing import Any from typing import Any
import discord import discord
from loguru import logger
harmful_command_pattern = r"(?:sudo\s+|doas\s+|run0\s+)?rm\s+(-[frR]*|--force|--recursive|--no-preserve-root|\s+)*([/\~]\s*|\*|/bin|/boot|/etc|/lib|/proc|/root|/sbin|/sys|/tmp|/usr|/var|/var/log|/network.|/system)(\s+--no-preserve-root|\s+\*)*|:\(\)\{ :|:& \};:" # noqa: RUF001 harmful_command_pattern = r"(?:sudo\s+|doas\s+|run0\s+)?rm\s+(-[frR]*|--force|--recursive|--no-preserve-root|\s+)*([/\~]\s*|\*|/bin|/boot|/etc|/lib|/proc|/root|/sbin|/sys|/tmp|/usr|/var|/var/log|/network.|/system)(\s+--no-preserve-root|\s+\*)*|:\(\)\{ :|:& \};:" # noqa: RUF001
@ -300,3 +301,19 @@ def extract_member_attrs(member: discord.Member) -> dict[str, Any]:
"status": member.status, "status": member.status,
"activity": member.activity, "activity": member.activity,
} }
def convert_dict_str_to_int(original_dict: dict[str, int]) -> dict[int, int]:
"""Helper function used for GIF Limiter constants.
Required as YAML keys are str. Channel and user IDs are int."""
converted_dict: dict[int, int] = {}
for key, value in original_dict.items():
try:
int_key: int = int(key)
converted_dict[int_key] = value
except ValueError:
logger.exception(f"An error occurred when loading the GIF ratelimiter configuration at key {key}")
return converted_dict