6
6
import discord
7
7
from discord .ext import commands
8
8
import snakecore
9
- from typing import TypedDict , Collection
9
+ from typing import TypedDict , Collection , cast
10
10
from collections import OrderedDict
11
11
import logging
12
12
13
13
from ..base import BaseExtensionCog
14
14
15
15
# Define the type for the bot, supporting both Bot and AutoShardedBot from snakecore
16
16
BotT = snakecore .commands .Bot | snakecore .commands .AutoShardedBot
17
+ MessageableGuildChannel = (
18
+ discord .TextChannel | discord .VoiceChannel | discord .StageChannel | discord .Thread
19
+ )
17
20
18
21
logger = logging .getLogger (__name__ )
19
22
@@ -55,13 +58,17 @@ async def crosspost_cmp(message: discord.Message, other: discord.Message) -> boo
55
58
"""
56
59
Compare two messages to determine if they are crossposts or duplicates.
57
60
58
- Args:
59
- message (discord.Message): The first message to compare.
60
- other (discord.Message): The second message to compare.
61
-
62
- Returns:
63
- bool: True if the messages are similar enough to be considered
64
- duplicates, otherwise False.
61
+ Parameters
62
+ ----------
63
+ message : discord.Message
64
+ The first message to compare.
65
+ other : discord.Message
66
+ The second message to compare.
67
+
68
+ Returns
69
+ -------
70
+ bool
71
+ True if the messages are similar enough to be considered duplicates, otherwise False.
65
72
"""
66
73
67
74
similarity_score = None
@@ -123,14 +130,15 @@ class UserCrosspostCache(TypedDict):
123
130
"""
124
131
125
132
message_groups : list [list [discord .Message ]]
126
- message_to_alert : dict [int , int ] # Mapping from message ID to alert message ID
133
+ message_to_alert_map : dict [int , int ] # Mapping from message ID to alert message ID
127
134
128
135
129
136
class AntiCrosspostCog (BaseExtensionCog , name = "anti-crosspost" ):
130
137
def __init__ (
131
138
self ,
132
139
bot : BotT ,
133
140
channel_ids : Collection [int ],
141
+ exclude_alert_channel_ids : Collection [int ] | None ,
134
142
crosspost_timedelta_threshold : int ,
135
143
same_channel_message_length_threshold : int ,
136
144
cross_channel_message_length_threshold : int ,
@@ -141,21 +149,30 @@ def __init__(
141
149
"""
142
150
Initialize the AntiCrosspostCog.
143
151
144
- Args:
145
- bot (BotT): The bot instance.
146
- channel_ids (Collection[int]): Collection of channel IDs to monitor.
147
- crosspost_timedelta_threshold (int): Minimum time difference between messages to not be considered crossposts.
148
- same_channel_message_length_threshold (int): Minimum length of a text-only message to be considered
149
- if the messages are in the same channel.
150
- cross_channel_message_length_threshold (int): Minimum length of a text-only message to be considered
151
- if the messages are in different channels.
152
- max_tracked_users (int): Maximum number of users to track.
153
- max_tracked_message_groups_per_user (int): Maximum number of message
154
- groups to track per user.
155
- theme_color (int | discord.Color): Theme color for the bot's responses.
152
+ Parameters
153
+ ----------
154
+ bot : BotT
155
+ The bot instance.
156
+ channel_ids : Collection[int]
157
+ Collection of channel IDs to watch.
158
+ exclude_alert_channel_ids : Collection[int] or None
159
+ Collection of channel IDs to exclude from alerting.
160
+ crosspost_timedelta_threshold : int
161
+ Minimum time difference between messages to not be considered crossposts.
162
+ same_channel_message_length_threshold : int
163
+ Minimum length of a text-only message to be considered if the messages are in the same channel.
164
+ cross_channel_message_length_threshold : int
165
+ Minimum length of a text-only message to be considered if the messages are in different channels.
166
+ max_tracked_users : int
167
+ Maximum number of users to track.
168
+ max_tracked_message_groups_per_user : int
169
+ Maximum number of message groups to track per user.
170
+ theme_color : int or discord.Color, optional
171
+ Theme color for the bot's responses, by default 0.
156
172
"""
157
173
super ().__init__ (bot , theme_color )
158
174
self .channel_ids = set (channel_ids )
175
+ self .exclude_alert_channel_ids = set (exclude_alert_channel_ids or ())
159
176
self .crossposting_cache : OrderedDict [int , UserCrosspostCache ] = OrderedDict ()
160
177
161
178
self .crosspost_timedelta_threshold = crosspost_timedelta_threshold
@@ -170,15 +187,9 @@ def __init__(
170
187
171
188
@commands .Cog .listener ()
172
189
async def on_message (self , message : discord .Message ):
173
- """
174
- Event listener for new messages.
175
-
176
- Args:
177
- message (discord.Message): The message object.
178
- """
179
190
if (
180
191
message .author .bot
181
- or not await self ._is_watched_channel (message .channel ) # type: ignore
192
+ or not await self ._check_channel (message .channel , self . channel_ids ) # type: ignore
182
193
or message .type != discord .MessageType .default
183
194
or (
184
195
message .content
@@ -205,14 +216,15 @@ async def on_message(self, message: discord.Message):
205
216
206
217
user_cache = self .crossposting_cache [user_id ]
207
218
if not any (len (group ) > 1 for group in user_cache ["message_groups" ]):
219
+ # Remove user from cache if they dont have any crossposts
208
220
self .crossposting_cache .pop (user_id )
209
221
logger .debug (f"Removed user { user_id } from cache to enforce size limit" )
210
222
211
223
# Initialize cache for new users
212
224
if message .author .id not in self .crossposting_cache :
213
225
self .crossposting_cache [message .author .id ] = UserCrosspostCache (
214
226
message_groups = [[message ]],
215
- message_to_alert = {},
227
+ message_to_alert_map = {},
216
228
)
217
229
logger .debug (f"Initialized cache for new user { message .author .name } " )
218
230
else :
@@ -248,14 +260,40 @@ async def on_message(self, message: discord.Message):
248
260
logger .debug (
249
261
f"Found crosspost for user { message .author .name } , message URL { message .jump_url } !!!!!!!!!!"
250
262
)
263
+ alert_channel = cast (MessageableGuildChannel , message .channel )
264
+ if self .exclude_alert_channel_ids and not await self ._check_channel (
265
+ alert_channel , deny = self .exclude_alert_channel_ids
266
+ ):
267
+ # Attempt to find the next best channel to alert in
268
+ print ( [ msg .content for msg in messages [:- 1 ] ])
269
+ for message in reversed (messages [:- 1 ]):
270
+ alert_channel = cast (
271
+ MessageableGuildChannel , message .channel
272
+ )
273
+ if await self ._check_channel (
274
+ alert_channel , deny = self .exclude_alert_channel_ids
275
+ ):
276
+ break
277
+ else :
278
+ logger .debug (
279
+ f"No allowed alerting channel for user { message .author .name } found"
280
+ )
281
+ break # Don't issue an alert if not possible
282
+
283
+ if message .id in user_cache ["message_to_alert_map" ]:
284
+ logger .debug (
285
+ f"Message { message .id } is already being alerted for user { message .author .name } "
286
+ )
287
+ break # Don't issue an alert if already alerted
251
288
252
289
try :
253
- alert_message = await message . reply (
290
+ alert_message = await alert_channel . send (
254
291
"This message is a recent crosspost/duplicate among the following messages: "
255
292
+ ", " .join ([m .jump_url for m in messages ])
256
- + ".\n \n Please delete all duplicate messages."
293
+ + ".\n \n Please delete all duplicate messages." ,
294
+ reference = message ,
257
295
)
258
- user_cache ["message_to_alert " ][
296
+ user_cache ["message_to_alert_map " ][
259
297
message .id
260
298
] = alert_message .id
261
299
logger .debug (
@@ -290,13 +328,7 @@ async def on_message(self, message: discord.Message):
290
328
291
329
@commands .Cog .listener ()
292
330
async def on_message_delete (self , message : discord .Message ):
293
- """
294
- Event listener for deleted messages.
295
-
296
- Args:
297
- message (discord.Message): The message object.
298
- """
299
- if not await self ._is_watched_channel (message .channel ): # type: ignore
331
+ if not await self ._check_channel (message .channel , self .channel_ids ): # type: ignore
300
332
return
301
333
302
334
if message .author .id not in self .crossposting_cache :
@@ -309,9 +341,9 @@ async def on_message_delete(self, message: discord.Message):
309
341
for j in range (len (messages ) - 1 , - 1 , - 1 ):
310
342
if message .id == messages [j ].id :
311
343
del messages [j ]
312
- if message .id in user_cache ["message_to_alert " ]:
344
+ if message .id in user_cache ["message_to_alert_map " ]:
313
345
stale_alert_message_ids .append (
314
- user_cache ["message_to_alert " ].pop (message .id )
346
+ user_cache ["message_to_alert_map " ].pop (message .id )
315
347
)
316
348
logger .debug (
317
349
f"Removed message { message .jump_url } from user { message .author .name } 's cache due to deletion"
@@ -320,9 +352,12 @@ async def on_message_delete(self, message: discord.Message):
320
352
321
353
# Mark last alert message for this crosspost group as stale if the group
322
354
# has only one message
323
- if len (messages ) == 1 and messages [0 ].id in user_cache ["message_to_alert" ]:
355
+ if (
356
+ len (messages ) == 1
357
+ and messages [0 ].id in user_cache ["message_to_alert_map" ]
358
+ ):
324
359
stale_alert_message_ids .append (
325
- user_cache ["message_to_alert " ].pop (messages [0 ].id )
360
+ user_cache ["message_to_alert_map " ].pop (messages [0 ].id )
326
361
)
327
362
328
363
# Delete stale alert messages
@@ -337,29 +372,48 @@ async def on_message_delete(self, message: discord.Message):
337
372
f"Failed to delete alert message ID { alert_message_id } : { e } "
338
373
)
339
374
340
- async def _is_watched_channel (self , channel : discord .abc .GuildChannel ) -> bool :
375
+ @staticmethod
376
+ async def _check_channel (
377
+ channel : discord .abc .GuildChannel | discord .Thread ,
378
+ allow : Collection [int ] = (),
379
+ deny : Collection [int ] = (),
380
+ ) -> bool :
381
+ """
382
+ Check if a guild channel or thread is allowed or denied for something based on the provided allow and deny lists.
383
+
384
+ Parameters
385
+ ----------
386
+ channel : discord.abc.GuildChannel | discord.Thread
387
+ The channel to check.
388
+ allow : Collection[int], optional
389
+ Collection of channel IDs to allow, by default ()
390
+ deny : Collection[int], optional
391
+ Collection of channel IDs to deny, by default ()
392
+
393
+ Returns
394
+ -------
395
+ bool: True if the channel is allowed, False if it is denied, and None if neither is allowed.
341
396
"""
342
- Check if a channel is watched for crossposts based on the configured channel IDs.
343
397
344
- Args:
345
- channel (discord.abc.GuildChannel): The channel to check.
398
+ if not (allow or deny ):
399
+ raise ValueError ("Either 'allow' or 'deny' must be provided" )
400
+
401
+ result = False
346
402
347
- Returns:
348
- bool: True if the channel is watched, otherwise False.
349
- """
350
403
if isinstance (channel , discord .abc .GuildChannel ):
351
404
# Check if the channel ID or category ID is in the monitored channel IDs
352
- if (
353
- channel .id in self .channel_ids
354
- or channel .category_id in self .channel_ids
355
- ):
356
- return True
357
-
358
- # If the channel is a thread, check if the parent or the parent's category ID is in the monitored channel IDs
359
- if isinstance (channel , discord .Thread ):
360
- if channel .parent_id in self .channel_ids :
361
- return True
405
+ result = (
406
+ bool (allow ) and (channel .id in allow or channel .category_id in allow )
407
+ ) or not (
408
+ bool (deny ) and (channel .id in deny or channel .category_id in deny )
409
+ )
362
410
411
+ # If the channel is a thread, check if the parent or the parent's category ID is in the monitored channel IDs
412
+ elif isinstance (channel , discord .Thread ):
413
+ if not (
414
+ result := (bool (allow ) and channel .parent_id in allow )
415
+ or not (bool (deny ) and channel .parent_id in deny )
416
+ ):
363
417
try :
364
418
parent = (
365
419
channel .parent
@@ -369,16 +423,18 @@ async def _is_watched_channel(self, channel: discord.abc.GuildChannel) -> bool:
369
423
except discord .NotFound :
370
424
pass
371
425
else :
372
- if parent and parent .category_id in self .channel_ids :
373
- return True
426
+ result = (bool (allow ) and parent .category_id in allow ) or not (
427
+ bool (deny ) and parent .category_id in deny
428
+ )
374
429
375
- return False
430
+ return result
376
431
377
432
378
433
@snakecore .commands .decorators .with_config_kwargs
379
434
async def setup (
380
435
bot : BotT ,
381
436
channel_ids : Collection [int ],
437
+ exclude_alert_channel_ids : Collection [int ] | None = None ,
382
438
max_tracked_users : int = 10 ,
383
439
max_tracked_message_groups_per_user : int = 10 ,
384
440
crosspost_timedelta_threshold : int = 86400 ,
@@ -389,22 +445,30 @@ async def setup(
389
445
"""
390
446
Setup function to add the AntiCrosspostCog to the bot.
391
447
392
- Args:
393
- bot (BotT): The bot instance.
394
- channel_ids (Collection[int]): Collection of channel IDs to monitor.
395
- max_tracked_users (int): Maximum number of users to track.
396
- max_tracked_message_groups_per_user (int): Maximum number of message groups to track per user.
397
- crosspost_timedelta_threshold (int): Minimum time difference between messages to not be considered crossposts.
398
- same_channel_message_length_threshold (int): Minimum length of a text-only message to be considered
399
- if the messages are in the same channel.
400
- cross_channel_message_length_threshold (int): Minimum length of a text-only message to be considered
401
- if the messages are in different channels.
402
- theme_color (int | discord.Color): Theme color for the bot's responses.
448
+ Parameters
449
+ ----------
450
+ bot : BotT
451
+ The bot instance.
452
+ channel_ids : Collection[int]
453
+ Collection of channel IDs to watch.
454
+ exclude_alert_channel_ids : Collection[int] or None, optional
455
+ Collection of channel IDs to exclude from alerting, by default None
456
+ max_tracked_users : int, optional
457
+ Maximum number of users to track, by default 10
458
+ max_tracked_message_groups_per_user : int, optional
459
+ Maximum number of message groups to track per user, by default 10
460
+ crosspost_timedelta_threshold : int, optional
461
+ Minimum time difference between messages to not be considered crossposts, by default 86400
462
+ same_channel_message_length_threshold : int, optional
463
+ Minimum length of a text-only message to be considered if the messages are in the same channel, by default 64
464
+ cross_channel_message_length_threshold : int, optional
465
+ Minimum length of a text-only message to be considered if the messages are in different channels, by default 16
403
466
"""
404
467
await bot .add_cog (
405
468
AntiCrosspostCog (
406
469
bot ,
407
470
channel_ids ,
471
+ exclude_alert_channel_ids ,
408
472
crosspost_timedelta_threshold ,
409
473
same_channel_message_length_threshold ,
410
474
cross_channel_message_length_threshold ,
0 commit comments