Skip to content

Commit 6320a0f

Browse files
committed
Add 'exclude_alert_channel_ids' int list field support to disable sending crossposting alerts to those channels
1 parent 945ba18 commit 6320a0f

File tree

1 file changed

+138
-74
lines changed

1 file changed

+138
-74
lines changed

pcbot/exts/anti_crosspost.py

Lines changed: 138 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,17 @@
66
import discord
77
from discord.ext import commands
88
import snakecore
9-
from typing import TypedDict, Collection
9+
from typing import TypedDict, Collection, cast
1010
from collections import OrderedDict
1111
import logging
1212

1313
from ..base import BaseExtensionCog
1414

1515
# Define the type for the bot, supporting both Bot and AutoShardedBot from snakecore
1616
BotT = snakecore.commands.Bot | snakecore.commands.AutoShardedBot
17+
MessageableGuildChannel = (
18+
discord.TextChannel | discord.VoiceChannel | discord.StageChannel | discord.Thread
19+
)
1720

1821
logger = logging.getLogger(__name__)
1922

@@ -55,13 +58,17 @@ async def crosspost_cmp(message: discord.Message, other: discord.Message) -> boo
5558
"""
5659
Compare two messages to determine if they are crossposts or duplicates.
5760
58-
Args:
59-
message (discord.Message): The first message to compare.
60-
other (discord.Message): The second message to compare.
61-
62-
Returns:
63-
bool: True if the messages are similar enough to be considered
64-
duplicates, otherwise False.
61+
Parameters
62+
----------
63+
message : discord.Message
64+
The first message to compare.
65+
other : discord.Message
66+
The second message to compare.
67+
68+
Returns
69+
-------
70+
bool
71+
True if the messages are similar enough to be considered duplicates, otherwise False.
6572
"""
6673

6774
similarity_score = None
@@ -123,14 +130,15 @@ class UserCrosspostCache(TypedDict):
123130
"""
124131

125132
message_groups: list[list[discord.Message]]
126-
message_to_alert: dict[int, int] # Mapping from message ID to alert message ID
133+
message_to_alert_map: dict[int, int] # Mapping from message ID to alert message ID
127134

128135

129136
class AntiCrosspostCog(BaseExtensionCog, name="anti-crosspost"):
130137
def __init__(
131138
self,
132139
bot: BotT,
133140
channel_ids: Collection[int],
141+
exclude_alert_channel_ids: Collection[int] | None,
134142
crosspost_timedelta_threshold: int,
135143
same_channel_message_length_threshold: int,
136144
cross_channel_message_length_threshold: int,
@@ -141,21 +149,30 @@ def __init__(
141149
"""
142150
Initialize the AntiCrosspostCog.
143151
144-
Args:
145-
bot (BotT): The bot instance.
146-
channel_ids (Collection[int]): Collection of channel IDs to monitor.
147-
crosspost_timedelta_threshold (int): Minimum time difference between messages to not be considered crossposts.
148-
same_channel_message_length_threshold (int): Minimum length of a text-only message to be considered
149-
if the messages are in the same channel.
150-
cross_channel_message_length_threshold (int): Minimum length of a text-only message to be considered
151-
if the messages are in different channels.
152-
max_tracked_users (int): Maximum number of users to track.
153-
max_tracked_message_groups_per_user (int): Maximum number of message
154-
groups to track per user.
155-
theme_color (int | discord.Color): Theme color for the bot's responses.
152+
Parameters
153+
----------
154+
bot : BotT
155+
The bot instance.
156+
channel_ids : Collection[int]
157+
Collection of channel IDs to watch.
158+
exclude_alert_channel_ids : Collection[int] or None
159+
Collection of channel IDs to exclude from alerting.
160+
crosspost_timedelta_threshold : int
161+
Minimum time difference between messages to not be considered crossposts.
162+
same_channel_message_length_threshold : int
163+
Minimum length of a text-only message to be considered if the messages are in the same channel.
164+
cross_channel_message_length_threshold : int
165+
Minimum length of a text-only message to be considered if the messages are in different channels.
166+
max_tracked_users : int
167+
Maximum number of users to track.
168+
max_tracked_message_groups_per_user : int
169+
Maximum number of message groups to track per user.
170+
theme_color : int or discord.Color, optional
171+
Theme color for the bot's responses, by default 0.
156172
"""
157173
super().__init__(bot, theme_color)
158174
self.channel_ids = set(channel_ids)
175+
self.exclude_alert_channel_ids = set(exclude_alert_channel_ids or ())
159176
self.crossposting_cache: OrderedDict[int, UserCrosspostCache] = OrderedDict()
160177

161178
self.crosspost_timedelta_threshold = crosspost_timedelta_threshold
@@ -170,15 +187,9 @@ def __init__(
170187

171188
@commands.Cog.listener()
172189
async def on_message(self, message: discord.Message):
173-
"""
174-
Event listener for new messages.
175-
176-
Args:
177-
message (discord.Message): The message object.
178-
"""
179190
if (
180191
message.author.bot
181-
or not await self._is_watched_channel(message.channel) # type: ignore
192+
or not await self._check_channel(message.channel, self.channel_ids) # type: ignore
182193
or message.type != discord.MessageType.default
183194
or (
184195
message.content
@@ -205,14 +216,15 @@ async def on_message(self, message: discord.Message):
205216

206217
user_cache = self.crossposting_cache[user_id]
207218
if not any(len(group) > 1 for group in user_cache["message_groups"]):
219+
# Remove user from cache if they dont have any crossposts
208220
self.crossposting_cache.pop(user_id)
209221
logger.debug(f"Removed user {user_id} from cache to enforce size limit")
210222

211223
# Initialize cache for new users
212224
if message.author.id not in self.crossposting_cache:
213225
self.crossposting_cache[message.author.id] = UserCrosspostCache(
214226
message_groups=[[message]],
215-
message_to_alert={},
227+
message_to_alert_map={},
216228
)
217229
logger.debug(f"Initialized cache for new user {message.author.name}")
218230
else:
@@ -248,14 +260,40 @@ async def on_message(self, message: discord.Message):
248260
logger.debug(
249261
f"Found crosspost for user {message.author.name}, message URL {message.jump_url}!!!!!!!!!!"
250262
)
263+
alert_channel = cast(MessageableGuildChannel, message.channel)
264+
if self.exclude_alert_channel_ids and not await self._check_channel(
265+
alert_channel, deny=self.exclude_alert_channel_ids
266+
):
267+
# Attempt to find the next best channel to alert in
268+
print( [ msg.content for msg in messages[:-1] ])
269+
for message in reversed(messages[:-1]):
270+
alert_channel = cast(
271+
MessageableGuildChannel, message.channel
272+
)
273+
if await self._check_channel(
274+
alert_channel, deny=self.exclude_alert_channel_ids
275+
):
276+
break
277+
else:
278+
logger.debug(
279+
f"No allowed alerting channel for user {message.author.name} found"
280+
)
281+
break # Don't issue an alert if not possible
282+
283+
if message.id in user_cache["message_to_alert_map"]:
284+
logger.debug(
285+
f"Message {message.id} is already being alerted for user {message.author.name}"
286+
)
287+
break # Don't issue an alert if already alerted
251288

252289
try:
253-
alert_message = await message.reply(
290+
alert_message = await alert_channel.send(
254291
"This message is a recent crosspost/duplicate among the following messages: "
255292
+ ", ".join([m.jump_url for m in messages])
256-
+ ".\n\nPlease delete all duplicate messages."
293+
+ ".\n\nPlease delete all duplicate messages.",
294+
reference=message,
257295
)
258-
user_cache["message_to_alert"][
296+
user_cache["message_to_alert_map"][
259297
message.id
260298
] = alert_message.id
261299
logger.debug(
@@ -290,13 +328,7 @@ async def on_message(self, message: discord.Message):
290328

291329
@commands.Cog.listener()
292330
async def on_message_delete(self, message: discord.Message):
293-
"""
294-
Event listener for deleted messages.
295-
296-
Args:
297-
message (discord.Message): The message object.
298-
"""
299-
if not await self._is_watched_channel(message.channel): # type: ignore
331+
if not await self._check_channel(message.channel, self.channel_ids): # type: ignore
300332
return
301333

302334
if message.author.id not in self.crossposting_cache:
@@ -309,9 +341,9 @@ async def on_message_delete(self, message: discord.Message):
309341
for j in range(len(messages) - 1, -1, -1):
310342
if message.id == messages[j].id:
311343
del messages[j]
312-
if message.id in user_cache["message_to_alert"]:
344+
if message.id in user_cache["message_to_alert_map"]:
313345
stale_alert_message_ids.append(
314-
user_cache["message_to_alert"].pop(message.id)
346+
user_cache["message_to_alert_map"].pop(message.id)
315347
)
316348
logger.debug(
317349
f"Removed message {message.jump_url} from user {message.author.name}'s cache due to deletion"
@@ -320,9 +352,12 @@ async def on_message_delete(self, message: discord.Message):
320352

321353
# Mark last alert message for this crosspost group as stale if the group
322354
# has only one message
323-
if len(messages) == 1 and messages[0].id in user_cache["message_to_alert"]:
355+
if (
356+
len(messages) == 1
357+
and messages[0].id in user_cache["message_to_alert_map"]
358+
):
324359
stale_alert_message_ids.append(
325-
user_cache["message_to_alert"].pop(messages[0].id)
360+
user_cache["message_to_alert_map"].pop(messages[0].id)
326361
)
327362

328363
# Delete stale alert messages
@@ -337,29 +372,48 @@ async def on_message_delete(self, message: discord.Message):
337372
f"Failed to delete alert message ID {alert_message_id}: {e}"
338373
)
339374

340-
async def _is_watched_channel(self, channel: discord.abc.GuildChannel) -> bool:
375+
@staticmethod
376+
async def _check_channel(
377+
channel: discord.abc.GuildChannel | discord.Thread,
378+
allow: Collection[int] = (),
379+
deny: Collection[int] = (),
380+
) -> bool:
381+
"""
382+
Check if a guild channel or thread is allowed or denied for something based on the provided allow and deny lists.
383+
384+
Parameters
385+
----------
386+
channel : discord.abc.GuildChannel | discord.Thread
387+
The channel to check.
388+
allow : Collection[int], optional
389+
Collection of channel IDs to allow, by default ()
390+
deny : Collection[int], optional
391+
Collection of channel IDs to deny, by default ()
392+
393+
Returns
394+
-------
395+
bool: True if the channel is allowed, False if it is denied, and None if neither is allowed.
341396
"""
342-
Check if a channel is watched for crossposts based on the configured channel IDs.
343397

344-
Args:
345-
channel (discord.abc.GuildChannel): The channel to check.
398+
if not (allow or deny):
399+
raise ValueError("Either 'allow' or 'deny' must be provided")
400+
401+
result = False
346402

347-
Returns:
348-
bool: True if the channel is watched, otherwise False.
349-
"""
350403
if isinstance(channel, discord.abc.GuildChannel):
351404
# Check if the channel ID or category ID is in the monitored channel IDs
352-
if (
353-
channel.id in self.channel_ids
354-
or channel.category_id in self.channel_ids
355-
):
356-
return True
357-
358-
# If the channel is a thread, check if the parent or the parent's category ID is in the monitored channel IDs
359-
if isinstance(channel, discord.Thread):
360-
if channel.parent_id in self.channel_ids:
361-
return True
405+
result = (
406+
bool(allow) and (channel.id in allow or channel.category_id in allow)
407+
) or not (
408+
bool(deny) and (channel.id in deny or channel.category_id in deny)
409+
)
362410

411+
# If the channel is a thread, check if the parent or the parent's category ID is in the monitored channel IDs
412+
elif isinstance(channel, discord.Thread):
413+
if not (
414+
result := (bool(allow) and channel.parent_id in allow)
415+
or not (bool(deny) and channel.parent_id in deny)
416+
):
363417
try:
364418
parent = (
365419
channel.parent
@@ -369,16 +423,18 @@ async def _is_watched_channel(self, channel: discord.abc.GuildChannel) -> bool:
369423
except discord.NotFound:
370424
pass
371425
else:
372-
if parent and parent.category_id in self.channel_ids:
373-
return True
426+
result = (bool(allow) and parent.category_id in allow) or not (
427+
bool(deny) and parent.category_id in deny
428+
)
374429

375-
return False
430+
return result
376431

377432

378433
@snakecore.commands.decorators.with_config_kwargs
379434
async def setup(
380435
bot: BotT,
381436
channel_ids: Collection[int],
437+
exclude_alert_channel_ids: Collection[int] | None = None,
382438
max_tracked_users: int = 10,
383439
max_tracked_message_groups_per_user: int = 10,
384440
crosspost_timedelta_threshold: int = 86400,
@@ -389,22 +445,30 @@ async def setup(
389445
"""
390446
Setup function to add the AntiCrosspostCog to the bot.
391447
392-
Args:
393-
bot (BotT): The bot instance.
394-
channel_ids (Collection[int]): Collection of channel IDs to monitor.
395-
max_tracked_users (int): Maximum number of users to track.
396-
max_tracked_message_groups_per_user (int): Maximum number of message groups to track per user.
397-
crosspost_timedelta_threshold (int): Minimum time difference between messages to not be considered crossposts.
398-
same_channel_message_length_threshold (int): Minimum length of a text-only message to be considered
399-
if the messages are in the same channel.
400-
cross_channel_message_length_threshold (int): Minimum length of a text-only message to be considered
401-
if the messages are in different channels.
402-
theme_color (int | discord.Color): Theme color for the bot's responses.
448+
Parameters
449+
----------
450+
bot : BotT
451+
The bot instance.
452+
channel_ids : Collection[int]
453+
Collection of channel IDs to watch.
454+
exclude_alert_channel_ids : Collection[int] or None, optional
455+
Collection of channel IDs to exclude from alerting, by default None
456+
max_tracked_users : int, optional
457+
Maximum number of users to track, by default 10
458+
max_tracked_message_groups_per_user : int, optional
459+
Maximum number of message groups to track per user, by default 10
460+
crosspost_timedelta_threshold : int, optional
461+
Minimum time difference between messages to not be considered crossposts, by default 86400
462+
same_channel_message_length_threshold : int, optional
463+
Minimum length of a text-only message to be considered if the messages are in the same channel, by default 64
464+
cross_channel_message_length_threshold : int, optional
465+
Minimum length of a text-only message to be considered if the messages are in different channels, by default 16
403466
"""
404467
await bot.add_cog(
405468
AntiCrosspostCog(
406469
bot,
407470
channel_ids,
471+
exclude_alert_channel_ids,
408472
crosspost_timedelta_threshold,
409473
same_channel_message_length_threshold,
410474
cross_channel_message_length_threshold,

0 commit comments

Comments
 (0)