diff --git a/.gitignore b/.gitignore index 8989f962..109173f2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,7 @@ *.log DiscordBot/tokens.json TABot/tokens.json +venv/ +.DS_Store +.history +gcp_key.json diff --git a/DiscordBot/bot.py b/DiscordBot/bot.py index ec5dddb6..57830148 100644 --- a/DiscordBot/bot.py +++ b/DiscordBot/bot.py @@ -6,8 +6,13 @@ import logging import re import requests +from google.oauth2 import service_account +from google.auth.transport.requests import Request +from enum import Enum from report import Report +from report_queue import SubmittedReport, PriorityReportQueue import pdb +from moderate import ModeratorReview # Set up logging to the console logger = logging.getLogger('discord') @@ -26,6 +31,24 @@ discord_token = tokens['discord'] +MOD_TODO_START = "---------------------------\nTODO" +MODERATE_KEYWORD = "moderate" + +NUM_QUEUE_LEVELS = 3 + +CLASSIFIER_URL = "placeholder/classify" +GCP_SERVICE_ACCOUNT_TOKEN_FILE = "gcp_key.json" # service account key that lets the bot call the classifier +credentials = service_account.IDTokenCredentials.from_service_account_file( + GCP_SERVICE_ACCOUNT_TOKEN_FILE, + target_audience=CLASSIFIER_URL +) + +class ConversationState(Enum): + NOFLOW = 0 + REPORTING = 1 + MODERATING = 2 + + class ModBot(discord.Client): def __init__(self): intents = discord.Intents.default() @@ -34,6 +57,11 @@ def __init__(self): self.group_num = None self.mod_channels = {} # Map from guild to the mod channel id for that guild self.reports = {} # Map from user IDs to the state of their report + self.moderations = {} # Map from moderator user IDs to the state of their review + self.report_id_counter = 0 + # should equal the number of distinct priorities defined in Report.get_priority + self.report_queue = PriorityReportQueue(NUM_QUEUE_LEVELS, ["Imminent physical/mental harm", "Imminent financial/property harm", "Non-imminent"]) + self.conversationState = ConversationState.NOFLOW async def on_ready(self): print(f'{self.user.name} has connected to Discord! 
It is these guilds:') @@ -70,45 +98,225 @@ async def on_message(self, message): else: await self.handle_dm(message) + async def handle_dm(self, message): # Handle a help message if message.content == Report.HELP_KEYWORD: reply = "Use the `report` command to begin the reporting process.\n" reply += "Use the `cancel` command to cancel the report process.\n" + reply += "Use the `moderate` command to begin the moderation process.\n" await message.channel.send(reply) return + + if message.content.startswith(Report.START_KEYWORD) or self.conversationState == ConversationState.REPORTING: + self.conversationState = ConversationState.REPORTING + await self.handle_report(message) + elif message.content.startswith(MODERATE_KEYWORD) or self.conversationState == ConversationState.MODERATING: + self.conversationState = ConversationState.MODERATING + await self.handle_moderation(message) + + async def handle_report(self, message): author_id = message.author.id responses = [] - # Only respond to messages if they're part of a reporting flow - if author_id not in self.reports and not message.content.startswith(Report.START_KEYWORD): - return - # If we don't currently have an active report for this user, add one if author_id not in self.reports: self.reports[author_id] = Report(self) - # Let the report class handle this message; forward all the messages it returns to uss + # Let the report class handle this message and collect the replies it returns responses = await self.reports[author_id].handle_message(message) - for r in responses: - await message.channel.send(r) + + # report.py updates its own state; forward its replies to the reporter + for r in responses: + await message.channel.send(r) + + # if self.reports[author_id].harm_identified(): + # reply = responses[0] + # harm = responses[1] + # if harm: + # # TODO escalate (or simulate it) + # print("Escalating report") + # await message.channel.send(reply) + + # if self.reports[author_id].block_step(): + # reply = responses[0] + # block = responses[1] + # if block: + # # TODO block user (or simulate it) + # print("Blocking user") + # await message.channel.send(reply) # If the report is complete or cancelled, remove it from our map - if self.reports[author_id].report_complete(): + if self.reports[author_id].is_report_complete(): + + if not self.reports[author_id].is_cancelled(): + + reported_author = self.reports[author_id].get_reported_author() + reported_content = self.reports[author_id].get_reported_content() + report_type = self.reports[author_id].get_report_type() + misinfo_type = self.reports[author_id].get_misinfo_type() + misinfo_subtype = self.reports[author_id].get_misinfo_subtype() + imminent = self.reports[author_id].get_imminent() + priority = 
self.reports[author_id].get_priority() + id = self.report_id_counter + self.report_id_counter += 1 + reported_message = self.reports[author_id].get_reported_message() + + # Put the report in the mod channel + message_guild_id = self.reports[author_id].get_message_guild_id() + mod_channel = self.mod_channels[message_guild_id] + # todo are we worried about code injection via author name or content? + report_info_msg = "Report ID: " + str(id) + "\n" + report_info_msg += "User " + message.author.name + " reported user " + str(reported_author) + "'s message.\n" + # report_info_msg += "Here is the message: \n```" + str(reported_content) + "\n```" + report_info_msg += "Category: " + str(report_type) + " > " + str(misinfo_type) + " > " + str(misinfo_subtype) + "\n" + if imminent: + report_info_msg += "URGENT: Imminent " + imminent + " harm reported." + submitted_report = SubmittedReport(id, reported_message, reported_author, reported_content, report_type, misinfo_type, misinfo_subtype, imminent, message_guild_id, priority) + self.report_queue.enqueue(submitted_report) + + await mod_channel.send(report_info_msg) + + # remove self.reports.pop(author_id) + self.conversationState = ConversationState.NOFLOW + + # ------ starter code relevant to MILESTONE 3: -------------- + # scores = self.eval_text(message.content) + # await mod_channel.send(self.code_format(scores)) + #------------------------------------------------- + + async def handle_moderation(self, message): + + author_id = message.author.id + + if author_id not in self.moderations and self.report_queue.is_empty(): + await message.channel.send("No pending reports.") + self.conversationState = ConversationState.NOFLOW + return + + if author_id not in self.moderations: + try: + next_report = self.report_queue.dequeue() + except IndexError: + await message.channel.send("No pending reports.") + self.conversationState = ConversationState.NOFLOW + return + review = ModeratorReview() + review.original_report = next_report + review.original_priority = next_report.priority + review.report_type = next_report.report_type + review.misinfo_type = next_report.misinfo_type + review.misinfo_subtype = next_report.subtype + review.imminent = next_report.imminent + review.reported_author_metadata = f"User: {next_report.author}" + review.reported_content_metadata = f"Msg: \"{next_report.content}\"" + review.message_guild_id = next_report.message_guild_id + review.reported_message = next_report.reported_message + self.moderations[author_id] = review + preview = self.report_queue.display_one(next_report, showContent=False) + if preview: + await message.channel.send(f"```{preview}```") + + review = self.moderations[author_id] + + responses = await review.handle_message(message) + for r in responses: + await message.channel.send(r) + + if review.is_review_complete(): + + if self.moderations[author_id].action_taken in ["Allowed", "Removed"]: + # Put the verdict in the mod channel + mod_channel = self.mod_channels[self.moderations[author_id].message_guild_id] + # todo are we worried about code injection via author name or content? 
+ mod_info_msg = "Report ID: " + str(id) + "\n" + mod_info_msg += "has been moderated.\n" + mod_info_msg += "Verdict: " + self.moderations[author_id].action_taken + ".\n" + await mod_channel.send(mod_info_msg) + if self.moderations[author_id].action_taken == "Removed": + await review.reported_message.add_reaction("❌") + + elif self.moderations[author_id].action_taken in ["Skipped", "Escalated"]: + original_report = self.moderations[author_id].original_report + self.report_queue.enqueue(original_report) + + + self.moderations.pop(author_id, None) + self.conversationState = ConversationState.NOFLOW async def handle_channel_message(self, message): - # Only handle messages sent in the "group-#" channel - if not message.channel.name == f'group-{self.group_num}': + if not message.channel.name in [f'group-{self.group_num}', f'group-{self.group_num}-mod']: return - # Forward the message to the mod channel - mod_channel = self.mod_channels[message.guild.id] - await mod_channel.send(f'Forwarded message:\n{message.author.name}: "{message.content}"') - scores = self.eval_text(message.content) - await mod_channel.send(self.code_format(scores)) + # moderator commands + if message.channel.name == f'group-{self.group_num}-mod': + if message.content == "report summary": + await message.channel.send(self.report_queue.summary()) + elif message.content.startswith("report display"): + if "showcontent" in message.content: + await message.channel.send(self.report_queue.display(showContent=True)) + else: + await message.channel.send(self.report_queue.display()) + + + # ----- teddy: for milestone 3, send every msg to classifier/llm ----------------------- + # TODO figure out api call for our classifier (in gcp) and send it and wait for a response + # TODO uncomment and edit this code below + # credentials.refresh(Request()) + # token = credentials.token + # headers = { + # "Authorization": f"Bearer {token}" + # } + # payload = {"message": message.content} + # try: + # response = requests.post(CLASSIFIER_URL, headers=headers, json=payload) + # result = response.json() + + # classification = result.get("classification") + # confidence = result.get("confidence_score") + # TODO replace this line with sending to LLM to fill out report info + # await message.channel.send( + # f"Classification: {classification}, Confidence: {confidence:.2f}" + # ) + + # except Exception as e: + # await message.channel.send("Error classifying message.") + # print(e) + + return def eval_text(self, message): '''' @@ -124,8 +332,22 @@ def code_format(self, text): evaluated, insert your code here for formatting the string to be shown in the mod channel. 
''' + #teddy: not sure if we need this function return "Evaluated: '" + text+ "'" + + # def process_response(self, responses): + + # reply = responses["reply"] + # if not isinstance(reply, str): # just in case i forget brackets in report.py + # reply = [reply] + # del responses["reply"] + + # for key, value in responses.items(): # go through data (not including reply) + # if key not in self.current_report: # don't allow overwriting + # self.current_report[key] = value + + # return reply client = ModBot() -client.run(discord_token) \ No newline at end of file +client.run(discord_token) diff --git a/DiscordBot/moderate.py b/DiscordBot/moderate.py new file mode 100644 index 00000000..c5e107ff --- /dev/null +++ b/DiscordBot/moderate.py @@ -0,0 +1,179 @@ +from enum import Enum, auto + +class ModState(Enum): + MOD_START = auto() + AWAITING_DECISION = auto() + AWAITING_SKIP_REASON = auto() + AWAITING_SUMMARY_CONFIRM = auto() + AWAITING_ACTION = auto() + REVIEW_COMPLETE = auto() + +class ModeratorReview: + def __init__(self): + self.state = ModState.MOD_START + + self.message_guild_id = None + + self.original_report = None + + self.original_priority = None + + self.report_type = None + self.misinfo_type = None + self.misinfo_subtype = None + self.imminent = None + self.filter = False + + self.reported_author_metadata = None + self.reported_content_metadata = None + self.reported_message = None + + self.skip_reason = None + self.action_taken = None + + async def handle_message(self, message): + if self.state == ModState.MOD_START: + self.state = ModState.AWAITING_DECISION + return [ + "New reported content available.", + "Would you like to review it now?", + "Type `yes` to begin review, or `skip` to pass." + ] + + if self.state == ModState.AWAITING_DECISION: + if message.content.lower() == "yes": + self.state = ModState.AWAITING_SUMMARY_CONFIRM + reply = "This content was reported as " + self.report_type + ".\n" + reply += "Misinfo category: " + str(self.misinfo_type) + " - " + str(self.misinfo_subtype) + "\n" + reply += "Here is the relevant wikipedia article: https://en.wikipedia.org/wiki/Misinformation" + if self.imminent: + reply += "Potential imminent harm: " + self.imminent + "\n" + if self.filter: + reply += "User requested filtering/blocking.\n" + reply += "Author metadata: " + str(self.reported_author_metadata) + "\n" + reply += "Content metadata: " + str(self.reported_content_metadata) + "\n\n" + reply += "Type any key to continue." + return [reply] + + elif message.content.lower() == "skip": + self.state = ModState.AWAITING_SKIP_REASON + return [ + "Please select a reason for skipping:", + "1. Personal reasons", + "2. Bias/Conflict of interest (recusal)", + "3. Requires escalation" + ] + else: + return ["Invalid response. Type `yes` or `skip`."] + + if self.state == ModState.AWAITING_SKIP_REASON: + reasons = { + "1": "Personal reasons", + "2": "Bias/Conflict of interest (recusal)", + "3": "Requires escalation" + } + if message.content in reasons: + self.skip_reason = reasons[message.content] + self.state = ModState.REVIEW_COMPLETE + self.action_taken = "Skipped" + return [f"You skipped this review due to: {self.skip_reason}.", "Returning to queue."] + else: + return ["Please choose a valid skip reason: 1, 2, or 3."] + + if self.state == ModState.AWAITING_SUMMARY_CONFIRM: + self.state = ModState.AWAITING_ACTION + return [ + "What action would you like to take on this content?", + "1. Remove content", + "2. Allow content", + "3. 
Uncertain (Escalate)" + ] + + if self.state == ModState.AWAITING_ACTION: + if message.content == "1": + self.action_taken = "Removed" + print("TODO ACTUALLY REMOVE MESSAGE") + self.state = ModState.REVIEW_COMPLETE + return ["Content has been removed. Review complete."] + elif message.content == "2": + self.action_taken = "Allowed" + self.state = ModState.REVIEW_COMPLETE + return ["Content has been allowed. Review complete."] + elif message.content == "3": + self.action_taken = "Escalated" + self.state = ModState.REVIEW_COMPLETE + return [f"You escalated this review due to uncertainty.", "Returning to queue."] + else: + return ["Invalid action. Type 1 to Remove, 2 to Allow, or 3 to Escalate."] + + return [] + + def get_message_guild_id(self): + return self.message_guild_id + + def get_report_type(self): + return self.report_type + + def get_misinfo_type(self): + return self.misinfo_type + + def get_misinfo_subtype(self): + return self.misinfo_subtype + + def get_imminent(self): + return self.imminent + + def get_filter(self): + return self.filter + + def get_reported_author_metadata(self): + return self.reported_author_metadata + + def get_reported_content_metadata(self): + return self.reported_content_metadata + + def get_skip_reason(self): + return self.skip_reason + + def get_action_taken(self): + return self.action_taken + + def get_state(self): + return self.state + + def set_report_info(self, report): + self.report_type = report.get_report_type() + self.misinfo_type = report.get_misinfo_type() + self.misinfo_subtype = report.get_misinfo_subtype() + self.imminent = report.get_imminent() + self.filter = report.get_filter() + + def set_metadata(self, author_meta, content_meta, primer=None): + self.reported_author_metadata = author_meta + self.reported_content_metadata = content_meta + + def get_priority(self): + if self.imminent in ["physical", "mental"]: + return 0 + elif self.imminent == "financial": + return 1 + else: + return 2 + + def is_review_complete(self): + return self.state == ModState.REVIEW_COMPLETE + + def is_mod_start(self): + return self.state == ModState.MOD_START + + def is_awaiting_decision(self): + return self.state == ModState.AWAITING_DECISION + + def is_awaiting_skip_reason(self): + return self.state == ModState.AWAITING_SKIP_REASON + + def is_awaiting_summary_confirm(self): + return self.state == ModState.AWAITING_SUMMARY_CONFIRM + + def is_awaiting_action(self): + return self.state == ModState.AWAITING_ACTION diff --git a/DiscordBot/report.py b/DiscordBot/report.py index d2bba994..2e0fe147 100644 --- a/DiscordBot/report.py +++ b/DiscordBot/report.py @@ -5,7 +5,14 @@ class State(Enum): REPORT_START = auto() AWAITING_MESSAGE = auto() - MESSAGE_IDENTIFIED = auto() + + AWAITING_REASON = auto() + AWAITING_DISINFORMATION_TYPE = auto() + AWAITING_POLITICAL_DISINFORMATION_TYPE =auto() + AWAITING_HEALTHL_DISINFORMATION_TYPE =auto() + AWAITING_FILTER_ACTION = auto() + AWAITING_HARMFUL_CONTENT_STATUS = auto() + REPORT_COMPLETE = auto() class Report: @@ -16,7 +23,18 @@ class Report: def __init__(self, client): self.state = State.REPORT_START self.client = client - self.message = None + self.reported_message = None + + self.cancelled = False + + self.message_guild_id = None + self.reported_author = None + self.reported_content = None + self.report_type = None + self.misinfo_type = None + self.misinfo_subtype = None + self.filter = False + self.imminent = None async def handle_message(self, message): ''' @@ -27,7 +45,10 @@ async def handle_message(self, message): if 
message.content == self.CANCEL_KEYWORD: self.state = State.REPORT_COMPLETE + self.cancelled = True return ["Report cancelled."] + else: + self.cancelled = False if self.state == State.REPORT_START: reply = "Thank you for starting the reporting process. " @@ -36,37 +57,395 @@ async def handle_message(self, message): reply += "You can obtain this link by right-clicking the message and clicking `Copy Message Link`." self.state = State.AWAITING_MESSAGE return [reply] + else: + if message.content == self.START_KEYWORD: + reply = "You currently have an active report open, the status is " + self.state.name + ". " + reply += "Please continue this report or say `cancel` to cancel.\n" + return [reply] if self.state == State.AWAITING_MESSAGE: # Parse out the three ID strings from the message link m = re.search('/(\d+)/(\d+)/(\d+)', message.content) if not m: return ["I'm sorry, I couldn't read that link. Please try again or say `cancel` to cancel."] + guild = self.client.get_guild(int(m.group(1))) if not guild: return ["I cannot accept reports of messages from guilds that I'm not in. Please have the guild owner add me to the guild and try again."] + channel = guild.get_channel(int(m.group(2))) if not channel: return ["It seems this channel was deleted or never existed. Please try again or say `cancel` to cancel."] + try: message = await channel.fetch_message(int(m.group(3))) + self.reported_message = message except discord.errors.NotFound: return ["It seems this message was deleted or never existed. Please try again or say `cancel` to cancel."] - # Here we've found the message - it's up to you to decide what to do next! - self.state = State.MESSAGE_IDENTIFIED - return ["I found this message:", "```" + message.author.name + ": " + message.content + "```", \ - "This is all I know how to do right now - it's up to you to build out the rest of my reporting flow!"] + self.state = State.AWAITING_REASON + + # add guild ID so we know where to send the moderation todo + self.message_guild_id = message.guild.id + + self.reported_author = message.author.name + self.reported_content = message.content + + reply = "I found this message:```" + message.author.name + ": " + message.content + "```\n" + reply += "Please select the reason for reporting this message by typing the corresponding number:\n" + reply += "1. Misinformation\n" + reply += "2. Other\n" + return [reply] - if self.state == State.MESSAGE_IDENTIFIED: - return [""] + if self.state == State.AWAITING_REASON: + # Process user's report reason - return [] + if message.content == "1": + # Handling misinformation + self.report_type = "Misinformation" + self.state = State.AWAITING_DISINFORMATION_TYPE - def report_complete(self): - return self.state == State.REPORT_COMPLETE - + reply = "You have selected " + self.report_type + ".\n" + reply += "Please select the type of misinformation by typing the corresponding number:\n" + reply += "1. Political Misinformation\n" + reply += "2. Health Misinformation\n" + reply += "3. Other Misinformation\n" + return [reply] + + elif message.content == "2" : + # Handling Other Abuse types + self.report_type = "Other" + # self.misinfo_type = "[out of scope of project]" + # self.misinfo_subtype = "[out of scope of project]" + self.state = State.REPORT_COMPLETE + # return [ + # "Thank you for reporting " + self.report_type + " content.", + # "Our content moderation team will review the message and take action which may result in content or account removal." 
+ # ] + reply = "Thank you for reporting " + self.report_type + " content.\n" + reply += "Our content moderation team will review the message and take action which may result in content or account removal.\n" + return [reply] + + # elif message.content == "3" : + # # Handling Harassment + # self.report_type = "Harassment" + # self.misinfo_type = "[out of scope of project]" + # self.misinfo_subtype = "[out of scope of project]" + # self.state = State.REPORT_COMPLETE + # return [ + # "Thank you for reporting " + self.report_type + " content.", + # "Our content moderation team will review the message and take action which may result in content or account removal." + # ] + + + # elif message.content == "4" : + # # Handling Spam + # self.report_type = "Spam" + # self.misinfo_type = "[out of scope of project]" + # self.misinfo_subtype = "[out of scope of project]" + # self.state = State.REPORT_COMPLETE + # return [ + # "Thank you for reporting " + self.report_type + " content", + # "Our content moderation team will review the message and take action which may result in content or account removal." + # ] + + else: + # Handling wrong report reason + reply = "Kindly enter a valid report reason by selecting the correponding number:\n" + reply += "1. Misinformation\n" + reply += "2. Other\n" + reply += "Please try again or say `cancel` to cancel.\n" + return [reply] + + if self.state == State.AWAITING_DISINFORMATION_TYPE : + # Process Misinformation options + + if message.content == "1": + # Handle political misinformation + self.state = State.AWAITING_POLITICAL_DISINFORMATION_TYPE + self.misinfo_type = "Political Misinformation" + reply = "You have selected " + self.misinfo_type + ".\n" + reply += "Please select the type of political Misinformation by typing the corresponding number:\n" + reply += "1. Election/Campaign Misinformation\n" + reply += "2. Government/Civic Services\n" + reply += "3. Manipulated Photos/Video\n" + reply += "4. Other\n" + return [reply] + + elif message.content == "2" : + # Handle Health Misinformation + self.state = State.AWAITING_HEALTHL_DISINFORMATION_TYPE + self.misinfo_type = "Health Misinformation" + reply = "You have selected " + self.misinfo_type + ".\n" + reply += "Please select the type of health misinformation by typing the corresponding number:\n" + reply += "1. Vaccines\n" + reply += "2. Cures and Treatments\n" + reply += "3. Mental Health\n" + reply += "4. Other\n" + return [reply] + + + elif message.content == "3" : + # Handle other Misinformation + self.state = State.AWAITING_HARMFUL_CONTENT_STATUS + self.misinfo_type = "Other Misinformation" + self.misinfo_subtype = "[out of scope of project]" + reply = "You have selected " + self.misinfo_type + ".\n" + reply += "Could this content likely cause imminent harm to people or public safety? Select the correponding number:\n" + reply += "1. No.\n" + reply += "2. Yes, physical harm.\n" + reply += "3. Yes, mental harm.\n" + reply += "4. Yes, financial or property harm.\n" + return [reply] + + else : + # Handling wrong misinformation type + reply = "Kindly enter a valid misinformation type by selecting the correponding number:\n" + reply += "1. Political Misinformation\n" + reply += "2. Health Misinformation\n" + reply += "3. 
Other Misinformation\n" + reply += "Please try again or say `cancel` to cancel.\n" + return [reply] + if self.state == State.AWAITING_POLITICAL_DISINFORMATION_TYPE : + # Process political misinformation options + if message.content == "1": + # Handling Election/Campaign Misinformation + self.misinfo_subtype = "Election/Campaign Misinformation" + self.state = State.AWAITING_HARMFUL_CONTENT_STATUS + reply = "You have selected " + self.misinfo_subtype + ".\n" + reply += "Could this content likely cause imminent harm to people or public safety? Select the correponding number:\n" + reply += "1. No.\n" + reply += "2. Yes, physical harm.\n" + reply += "3. Yes, mental harm.\n" + reply += "4. Yes, financial or property harm.\n" + return [reply] + + elif message.content == "2": + # Handling Government/Civic Services + self.misinfo_subtype = "Government/Civic Services" + self.state = State.AWAITING_HARMFUL_CONTENT_STATUS + reply = "You have selected " + self.misinfo_subtype + ".\n" + reply += "Could this content likely cause imminent harm to people or public safety? Select the correponding number:\n" + reply += "1. No.\n" + reply += "2. Yes, physical harm.\n" + reply += "3. Yes, mental harm.\n" + reply += "4. Yes, financial or property harm.\n" + return [reply] + + elif message.content == "3": + # Handling Manipulated Photos/Video + self.misinfo_subtype = "Manipulated Photos/Video" + self.state = State.AWAITING_HARMFUL_CONTENT_STATUS + reply = "You have selected " + self.misinfo_subtype + ".\n" + reply += "Could this content likely cause imminent harm to people or public safety? Select the correponding number:\n" + reply += "1. No.\n" + reply += "2. Yes, physical harm.\n" + reply += "3. Yes, mental harm.\n" + reply += "4. Yes, financial or property harm.\n" + return [reply] + + elif message.content == "4": + # Handling Other + self.misinfo_subtype = "Other" + self.state = State.AWAITING_HARMFUL_CONTENT_STATUS + reply = "You have selected " + self.misinfo_subtype + ".\n" + reply += "Could this content likely cause imminent harm to people or public safety? Select the correponding number:\n" + reply += "1. No.\n" + reply += "2. Yes, physical harm.\n" + reply += "3. Yes, mental harm.\n" + reply += "4. Yes, financial or property harm.\n" + return [reply] + + else : + # Handling wrong political misinformation type + reply = "Please select the type of political Misinformation by typing the corresponding number:\n" + reply += "1. Election/Campaign Misinformation\n" + reply += "2. Government/Civic Services\n" + reply += "3. Manipulated Photos/Video\n" + reply += "4. Other\n" + reply += "Please try again or say `cancel` to cancel." + return [reply] + + if self.state == State.AWAITING_HEALTHL_DISINFORMATION_TYPE: + # Process health misinformation options + + if message.content == "1": + # Handling Vaccines + self.misinfo_subtype = "Vaccines" + self.state = State.AWAITING_HARMFUL_CONTENT_STATUS + reply = "You have selected " + self.misinfo_subtype + ".\n" + reply += "Could this content likely cause imminent harm to people or public safety? Select the correponding number:\n" + reply += "1. No.\n" + reply += "2. Yes, physical harm.\n" + reply += "3. Yes, mental harm.\n" + reply += "4. 
Yes, financial or property harm.\n" + return [reply] + + elif message.content == "2": + # Handling Cures and Treatments + self.misinfo_subtype = "Cures and Treatments" + self.state = State.AWAITING_HARMFUL_CONTENT_STATUS + reply = "You have selected " + self.misinfo_subtype + ".\n" + reply += "Could this content likely cause imminent harm to people or public safety? Select the correponding number:\n" + reply += "1. No.\n" + reply += "2. Yes, physical harm.\n" + reply += "3. Yes, mental harm.\n" + reply += "4. Yes, financial or property harm.\n" + return [reply] + + elif message.content == "3": + # Handling Mental Health + self.misinfo_subtype = "Mental Health" + self.state = State.AWAITING_HARMFUL_CONTENT_STATUS + reply = "You have selected " + self.misinfo_subtype + ".\n" + reply += "Could this content likely cause imminent harm to people or public safety? Select the correponding number:\n" + reply += "1. No.\n" + reply += "2. Yes, physical harm.\n" + reply += "3. Yes, mental harm.\n" + reply += "4. Yes, financial or property harm.\n" + return [reply] + + elif message.content == "4": + # Handling Other + self.misinfo_subtype = "Other" + self.state = State.AWAITING_HARMFUL_CONTENT_STATUS + reply = "You have selected " + self.misinfo_subtype + ".\n" + reply += "Could this content likely cause imminent harm to people or public safety? Select the correponding number:\n" + reply += "1. No.\n" + reply += "2. Yes, physical harm.\n" + reply += "3. Yes, mental harm.\n" + reply += "4. Yes, financial or property harm.\n" + return [reply] + + else : + # Handling wrong health misinformation type + reply = "Please select the type of health Misinformation by typing the corresponding number:\n" + reply += "1. Vaccines\n" + reply += "2. Cures and Treatments\n" + reply += "3. Mental Health\n" + reply += "4. Other\n" + reply += "Please try again or say `cancel` to cancel." + return [reply] + + if self.state == State.AWAITING_HARMFUL_CONTENT_STATUS: + # Handle decision making on whether content is harmful + + if message.content == "1" : + # No harmful content + self.state = State.AWAITING_FILTER_ACTION + reply = "Please indicate if you would like to block content from this account on your feed. Select the correponding number:\n" + reply += "1. No \n" + reply += "2. Yes \n" + return [reply] + + elif message.content in ["2", "3", "4"] : + # Harmful content + harm_dict = { + "2": "physical", + "3": "mental", + "4": "financial" + } + self.imminent = harm_dict[message.content] + self.state = State.AWAITING_FILTER_ACTION + reply = "Thank you. Our team has been notified.\n" + reply += "Please indicate if you would like to block content from this account on your feed. Select the correponding number:\n" + reply += "1. No \n" + reply += "2. Yes \n" + return [reply] + + else: + # Handle wrong response to harmful prompt + reply = "Kindly indicate if this content likely cause imminent harm to people or public safety? Select the correponding number:\n" + reply += "1. No.\n" + reply += "2. Yes, physical harm.\n" + reply += "3. Yes, mental harm.\n" + reply += "4. Yes, financial or property harm.\n" + reply += "Please try again or say `cancel` to cancel." 
+ return [reply] + + + if self.state == State.AWAITING_FILTER_ACTION: + # Handling responses to filter account content + + if message.content == "1": + # Handle no content filtering action + self.state = State.REPORT_COMPLETE + reply = "Thank you for reporting " + self.report_type + " content.\n" + reply += "Our content moderation team will review the message and take action which may result in content or account removal.\n" + return [reply] + + elif message.content == "2": + # Handle content filtering action + self.filter = True + self.state = State.REPORT_COMPLETE + reply = "Thank you for reporting " + self.report_type + " content.\n" + reply += "Our content moderation team will review the message and take action which may result in content or account removal.\n" + return [reply] + + else: + # wrong option for account filtering prompt + reply = "Please indicate if you would like to block content from this account on your feed. Select the corresponding number:\n" + reply += "1. No \n" + reply += "2. Yes \n" + reply += "Please try again or say `cancel` to cancel." + return [reply] + +# if self.state == State.BLOCK_STEP: +# # if user wants to block then block +# user_wants_to_block = True +# return [user_wants_to_block] + + return [] + + # getters for state + def get_message_guild_id(self): + return self.message_guild_id + def get_reported_author(self): + return self.reported_author + def get_reported_content(self): + return self.reported_content + def get_report_type(self): + return self.report_type + def get_misinfo_type(self): + return self.misinfo_type + def get_misinfo_subtype(self): + return self.misinfo_subtype + def get_imminent(self): + return self.imminent + def get_priority(self): # defining priorities, can be changed + if self.imminent in ["physical", "mental"]: + return 0 + elif self.imminent == "financial": + return 1 + else: + return 2 + def get_filter(self): + return self.filter + def get_reported_message(self): + return self.reported_message + def is_report_start(self): + return self.state == State.REPORT_START + def is_awaiting_message(self): + return self.state == State.AWAITING_MESSAGE + def is_awaiting_reason(self): + return self.state == State.AWAITING_REASON + def is_awaiting_misinformation_type(self): + return self.state == State.AWAITING_DISINFORMATION_TYPE + def is_awaiting_political_misinformation_type(self): + return self.state == State.AWAITING_POLITICAL_DISINFORMATION_TYPE + def is_awaiting_healthl_misinformation_type(self): + return self.state == State.AWAITING_HEALTHL_DISINFORMATION_TYPE + def is_awaiting_harmful_content_status(self): + return self.state == State.AWAITING_HARMFUL_CONTENT_STATUS + def is_awaiting_filter_action(self): + return self.state == State.AWAITING_FILTER_ACTION + # def block_step(self): + # return self.state == State.BLOCK_STEP + def is_report_complete(self): + return self.state == State.REPORT_COMPLETE + def is_cancelled(self): + return self.cancelled diff --git a/DiscordBot/report_queue.py b/DiscordBot/report_queue.py new file mode 100644 index 00000000..85918ca5 --- /dev/null +++ b/DiscordBot/report_queue.py @@ -0,0 +1,77 @@ +from collections import deque + +class SubmittedReport: + def __init__(self, id, reported_message, author, content, report_type, misinfo_type, misinfo_subtype, imminent, message_guild_id, priority): + self.author = author + self.id = id + self.reported_message = reported_message + self.content = content + self.report_type = report_type + self.misinfo_type = misinfo_type + self.subtype = misinfo_subtype + self.imminent = imminent + 
self.message_guild_id = message_guild_id + self.priority = priority + +class PriorityReportQueue: + def __init__(self, num_levels, queue_names): + self.num_queues = num_levels + self.queue_names = queue_names + self.queues = [deque() for _ in range(num_levels)] + + def enqueue(self, report): + if not (0 <= report.priority < len(self.queues)): + raise ValueError("Invalid priority level") + self.queues[report.priority].append(report) + + def dequeue(self): + for queue in self.queues: + if queue: + return queue.popleft() + raise IndexError("All queues are empty") + + def is_empty(self): + return all(len(q) == 0 for q in self.queues) + + def __getitem__(self, priority): + return list(self.queues[priority]) + + def summary(self): + out = "```" + out += "Priority | Queue Name | # Reports\n" + out += "-" * 58 + "\n" + total = 0 + for i in range(self.num_queues): + queue = self.queues[i] + out += f"{i:^8} | {self.queue_names[i]:<35} | {len(queue):^9}\n" + total += len(queue) + out += "-" * 58 + "\n" + out += f"Total pending reports: {total}\n" + out += "```" + return out + + def display_one(self, report, showContent=False): + output = ( + f" Report ID: {report.id}\n" + f" Author: {report.author}\n" + f" Type: {report.misinfo_type}\n" + f" Subtype: {report.subtype}\n" + f" Imminent: {report.imminent}\n" + ) + if showContent: + output += f" Content: `{report.content}`\n" + return output + + def display(self, showContent=False): + output = "" + for i in range(self.num_queues): + output += f"--- Priority {i}: {self.queue_names[i]} ---\n" + queue = self.queues[i] + if not queue: + output += " (No reports)\n" + else: + for idx, report in enumerate(queue): + output += f" [{idx+1}]\n" + output += self.display_one(report, showContent) + return output.strip() + diff --git a/LLM/.gitignore b/LLM/.gitignore new file mode 100644 index 00000000..12ef38ef --- /dev/null +++ b/LLM/.gitignore @@ -0,0 +1,2 @@ +api_key.txt +test.py \ No newline at end of file diff --git a/LLM/LLM_reports.py b/LLM/LLM_reports.py new file mode 100644 index 00000000..3e4f91f9 --- /dev/null +++ b/LLM/LLM_reports.py @@ -0,0 +1,308 @@ +from google import genai +from google.genai import types +import re + +# Load API key from a text file +try: + with open("api_key.txt", "r") as f: + api_key = f.read().strip() + client = genai.Client(api_key=api_key) +except FileNotFoundError: + print("Error: API key file not found. 
Create 'api_key.txt' with your API key.") + exit(1) +except Exception as e: + print(f"Error loading API key: {e}") + exit(1) + + +def call_gemini(sys_instruction, content): + try : + response = client.models.generate_content( + model= "gemini-2.0-flash", + config=types.GenerateContentConfig( + system_instruction= sys_instruction), + contents= content + ) + + # print(f"LLM output is: {response.text}") + return response.text + + except Exception as e : + print(f"Error connecting to LLM: {e}") + return None + + + + + +# Function to invoke report generation +def LLM_report(message_content, classifier_label, confidence_score,metadata, reporter_info = 'Classifier'): + + # Dictionary for keeping track of report details + report_details = { + 'message_guild_id' : f"{metadata.get('message_guild_id')}", + 'classifier_label' : classifier_label, + 'confidence_score' : confidence_score, + 'reported_author' : f"{metadata.get('message_author')}", + 'reported_content' : message_content, + 'report_type' : None, + 'misinfo_type' : None, + 'misinfo_subtype': None, + 'imminent' : None, + 'filter' : False, + 'LLM_recommendation' : None + } + + # Perform initial Classification + report_type_response = call_report_type(message_content, classifier_label, confidence_score,metadata) + report_type_response = report_type_response[0] + # report_type_response = re.search(r'(\d+)',report_type_response) + print(f"Report type response is: {report_type_response}") + # Update misinfo_type in report details + if report_type_response in ["1", "2"] : + report_details['report_type'] = "Misinformation" if report_type_response == "1" else "other" + + # Initiate userflow for misiniformation + if report_type_response == "1" : + + # Call to classify type of misinformation + misinfo_type_response = call_misinfo_type(message_content) + misinfo_type_response = misinfo_type_response[0] + + + #================== Decision logic for Misinformation Type Response ================== + + # Political Misinfo + if misinfo_type_response == "1" : + report_details ['misinfo_type'] = "Political Misinformation" + + # Call to classify political misinfo subtype + pol_misinfo_subtype_response = call_pol_misinfo_subtype(message_content) + pol_misinfo_subtype_response = pol_misinfo_subtype_response[0] + + #=============== Decision logic for Political misinfo subtye response =============== + if pol_misinfo_subtype_response == "1": + report_details['misinfo_subtype'] = 'Election/Campaign Misinformation' + + elif pol_misinfo_subtype_response == "2": + report_details['misinfo_subtype'] = 'Government/Civic Services' + + elif pol_misinfo_subtype_response == "3": + report_details['misinfo_subtype'] = 'Manipulated Photos/Video' + + elif pol_misinfo_subtype_response == "4": + report_details['misinfo_subtype'] = 'Other' + + # Health Misinfo + elif misinfo_type_response == "2" : + report_details ['misinfo_type'] = "Health Misinformation" + + # Call to classify health misinfo subtype + health_misinfo_subtype_response = call_health_misinfo_subtype(message_content) + health_misinfo_subtype_response = health_misinfo_subtype_response[0] + + #=============== Decision logic for Health misinfo subtye response =============== + if health_misinfo_subtype_response == "1": + report_details['misinfo_subtype'] = 'Vaccines' + + elif health_misinfo_subtype_response == "2": + report_details['misinfo_subtype'] = 'Cures and Treatments' + + elif health_misinfo_subtype_response == "3": + report_details['misinfo_subtype'] = 'Mental Health' + + elif health_misinfo_subtype_response == 
"4": + report_details['misinfo_subtype'] = 'Oher' + + + elif misinfo_type_response == "3" : + report_details ['misinfo_type'] = "Other Misinformation" + report_details['misinfo_subtype'] = 'Other' + + + # Initiate userflow for Harmful content + imminent_response = call_imminent(message_content) + imminent_response = imminent_response[0] + + #================== Decision logic for Imminent Harm Response ================== + + if imminent_response == "2": + report_details['imminent'] = 'physical' + + elif imminent_response == "3": + report_details['imminent'] = 'mental' + + elif imminent_response == "4": + report_details['imminent'] = 'financial or property' + + """ + Discussion : Not sure if to factor in the filter flag since this is detected automatically and + not specific to a particular user's feed + """ + + # Initiate userflow for LLM Recommendation + recommendation_response = call_recommedation(message_content, report_details['imminent'], report_details['confidence_score']) + report_details['LLM_recommendation'] = recommendation_response + + # Think about logic for instances where LLM returns non option value + + return report_details + + + + +def call_report_type(message_content, classifier_label, confidence_score,metadata): + # Step 1: Initial classification - Misinformation or Other + print("====Step 1: Initial classification - Misinformation or Other===") + print(f"Message: {message_content}") + + system_instruction = f""" + You are a trust & safety expert content moderator for a social media platform who has been assigned to generate a user + report for a post that has been flagged by the platform's classifier. + """ + + content = f""" + + Message Content : {message_content}, + + Initial Classification from the Automated Post Classifier: + - Label : {classifier_label}, + - Confidence : {confidence_score}, + + Metadata : + - Hashtags : {metadata.get('hashtags', 'Unkown')}, + - Previous Violation Count : {metadata.get('violation count', '0')} + + Validate the classifier's decision by selecting a category: + 1. Misinformation + 2. Other inappropriate content + + Respond with ONLY the number (1 or 2). + """ + + + return call_gemini (system_instruction, content) + +def call_misinfo_type (message_content): + # Step 2: Type of Misinformation + print("====Step 2: Misinformation type ===") + + system_instruction = f""" + You are a misinformation trust & safety expert content moderator for a social media platform who has been assigned to analyze content reported + as misinformation. + """ + + content = f""" + Message Content: {message_content} + Please select the type of misinformation: + 1. Political Misinformation + 2. Health Misinformation + 3. Other Misinformation + + Respond with ONLY the number (1-3). + """ + + return call_gemini(system_instruction, content) + + + +def call_pol_misinfo_subtype(message_content): + # Step 3a. Type of Political Misinformation + print("====Step 3a. Type of Political Misinformation ===") + + system_instruction = f""" + You are a political trust & safety expert content moderator for a social media platform who has been assigned to analyze content reported + as political misinformation. + """ + + content = f""" + Message Content: {message_content} + Classify the type of political misinformation which the message falls under : + 1. Election/Campaign Misinformation + 2. Government/Civic Services + 3. Manipulated Photos/Video + 4. Other political misinformation + + Respond with ONLY the number (1-4). 
+ """ + + return call_gemini(system_instruction, content) + + + +def call_health_misinfo_subtype(message_content): + # Step 3b. Type of Health Misinformation + print("====Step 3b. Type of Health Misinformation ===") + + system_instruction = f""" + You are a health trust & safety expert content moderator for a social media platform who has been assigned to analyze content reported + as health misinformation. + """ + + content = f""" + Message Content: {message_content} + Classify the type of health misinformation which the message falls under : + 1. Vaccines + 2. Cures and Treatments + 3. Mental Health + 4. Other health misinformation + + Respond with ONLY the number (1-4). + """ + + return call_gemini(system_instruction,content) + + +def call_imminent(message_content): + # Step 4: Imminent Harm + print("====Step 4: Imminent Harm===") + + system_instruction = f""" + You are a trust & safety expert content moderator for a social media platform who has been assigned to analyze content reported + and assess potential harm of the reported content. + """ + + content = f""" + Message Content: {message_content} + Could this content likely cause imminent harm to people or public safety? + 1. No + 2. Yes, physical harm + 3. Yes, mental harm + 4. Yes, financial or property harm + + Respond with ONLY the number (1-4). + """ + + return call_gemini(system_instruction, content) + + + +def call_recommedation(message_content, harm, score): + # Step 5: Recommendation + print("====Step 5: Recommendation===") + + system_instruction = f""" + You are a trust & safety expert content moderator for a social media platform who has been assigned to analyze content reported + and assess based on its potential harm, message content and confidence score, recommend an action which should be limited to the + options provided . + """ + + content = f""" + Message Content: {message_content} + Potential Harm Label : {harm} + Confidence Score : {score} + + Based on the message content, potential harm label and confidence score, which of the following do you recommend : + 1. Allow Content + 2. Remove Content + 3. Escalate to a human moderator + + Respond with ONLY one of these phrases: 'Allow Content', 'Remove Content', or 'Escalate to a human moderator' and in less than 80 words, + justify your recommendation. Adhere strictly to the word limit of 80. + """ + + return call_gemini(system_instruction,content) +# Recommendation based on filter, imminent harm, + + + diff --git a/LLM/__pycache__/LLM_reports.cpython-313.pyc b/LLM/__pycache__/LLM_reports.cpython-313.pyc new file mode 100644 index 00000000..cf3331b6 Binary files /dev/null and b/LLM/__pycache__/LLM_reports.cpython-313.pyc differ diff --git a/LLM/demo.py b/LLM/demo.py new file mode 100644 index 00000000..b993f540 --- /dev/null +++ b/LLM/demo.py @@ -0,0 +1,108 @@ +import sys +import os +from datetime import datetime +from LLM_reports import LLM_report + + +# Test cases +test_data_1 = { + "message_content": "BREAKING: Scientists confirm that COVID vaccines contain microchips to track your movements. Share before this gets deleted!", + "classifier_label": "potential_misinfo", + "confidence_score": 0.92, + "metadata": { + "message_guild_id": "123456789012345678", + "message_author": "HealthTruth88", + "hashtags": ["#vaccine", "#truth", "#covid"], + "violation count": 2 + }, + "reporter_info": "ModerationBot" +} + +test_data_2 = { + "message_content": "CONFIRMED: The election was rigged! Counting machines were programmed to switch votes. 
The evidence is being suppressed by the government!", + "classifier_label": "political_misinfo", + "confidence_score": 0.89, + "metadata": { + "message_guild_id": "876543210987654321", + "message_author": "PatriotTruth1776", + "hashtags": ["#election", "#fraud", "#truth"], + "violation count": 1 + }, + "reporter_info": "ElectionIntegrity" +} + +test_data_3 = { + "message_content": "Scientists have been hiding the truth: the Earth is flat and NASA has been faking space photos for decades!", + "classifier_label": "potential_misinfo", + "confidence_score": 0.75, + "metadata": { + "message_guild_id": "555666777888999000", + "message_author": "TruthSeeker42", + "hashtags": ["#flatearth", "#nasa", "#truth"], + "violation count": 0 + }, + "reporter_info": "ScienceChannel" +} + +test_data_4 = { + "message_content": "I've heard some people say that drinking lemon water every morning can boost your immune system by 50%. Interesting!", + "classifier_label": "potential_misinfo", + "confidence_score": 0.58, + "metadata": { + "message_guild_id": "111222333444555666", + "message_author": "WellnessJourney", + "hashtags": ["#health", "#wellness", "#immunity"], + "violation count": 0 + }, + "reporter_info": "HealthFacts" +} + +test_data_5 = { + "message_content": "According to the CDC website updated yesterday, vaccination rates have increased by 15% nationwide in the past month.", + "classifier_label": "potential_misinfo", + "confidence_score": 0.35, + "metadata": { + "message_guild_id": "999888777666555444", + "message_author": "NewsUpdates", + "hashtags": ["#covid", "#vaccination", "#data"], + "violation count": 0 + }, + "reporter_info": "FactChecker" +} + +test_data_6 = { + "message_content": "URGENT investment opportunity! Send $500 in Bitcoin to this wallet and I guarantee you'll receive $5000 back within 24 hours. 
This is a secret method banks don't want you to know!", + "classifier_label": "potential_scam", + "confidence_score": 0.95, + "metadata": { + "message_guild_id": "444333222111000999", + "message_author": "WealthMaker2023", + "hashtags": ["#investment", "#bitcoin", "#getrich"], + "violation count": 3 + }, + "reporter_info": "ScamAlert" +} + +if __name__ == "__main__": + print("Starting LLM Report Tests...") + print(f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + test_cases = [test_data_1, test_data_2, test_data_3, test_data_4, test_data_5, test_data_6] + + # Generate report details for sample flagged posts + for post in test_cases : + report_details = LLM_report(post["message_content"], + post["classifier_label"], + post["confidence_score"], + post["metadata"], + post["reporter_info"]) + + + print("\nResults:") + print(f"Report Type: {report_details.get('report_type', 'Not classified')}") + print(f"Misinfo Type: {report_details.get('misinfo_type', 'N/A')}") + print(f"Misinfo Subtype: {report_details.get('misinfo_subtype', 'N/A')}") + print(f"Imminent Harm: {report_details.get('imminent', 'None')}") + print(f"Filter Recommendation: {report_details.get('filter', 'No recommendation')}") + print(f"LLM Recommendation: {report_details.get('LLM_recommendation', 'No recommendation')}") + + print("\nTests complete!") \ No newline at end of file diff --git a/LLM/prompts.txt b/LLM/prompts.txt new file mode 100644 index 00000000..1c039181 --- /dev/null +++ b/LLM/prompts.txt @@ -0,0 +1,4 @@ +// Classifier + + +- Set test data to test at each func implementation \ No newline at end of file diff --git a/classifier_gcp/.gitignore b/classifier_gcp/.gitignore new file mode 100644 index 00000000..0d29fb0a --- /dev/null +++ b/classifier_gcp/.gitignore @@ -0,0 +1 @@ +gcp_key.json \ No newline at end of file diff --git a/classifier_gcp/Dockerfile b/classifier_gcp/Dockerfile new file mode 100644 index 00000000..4ee92d39 --- /dev/null +++ b/classifier_gcp/Dockerfile @@ -0,0 +1,7 @@ +FROM python:3.12 + +WORKDIR /app +COPY . . + +RUN pip install --no-cache-dir -r requirements.txt +CMD ["python", "main.py"] \ No newline at end of file diff --git a/classifier_gcp/README.md b/classifier_gcp/README.md new file mode 100644 index 00000000..592dc353 --- /dev/null +++ b/classifier_gcp/README.md @@ -0,0 +1,30 @@ +this folder will be for the code in GCP that runs our classifier + +- main.py: Flask server +- requirements.txt: dependencies +- Dockerfile: creates the python runtime for our code. I probably need to update this. also not sure what version of python to run but I assume 3.12 is fine. 
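+- note (assumption, not wired up yet): Cloud Run hands the container a port in the `PORT` env var, so `main.py` will likely need to listen on that (and on 0.0.0.0) instead of the Flask defaults, roughly:
+```
+import os
+app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 8080)))
+```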
+ +TODOs: +- once rhea gets the model trained, get it locally and then upload to gcp: +`gsutil cp path/to/your_model.pkl gs://pol-disinfo-classifier/` +- deploy cloud run instance: + ``` + gcloud builds submit --tag gcr.io/YOUR_PROJECT_ID/discord-classifier + gcloud run deploy discord-classifier \ + --image gcr.io/YOUR_PROJECT_ID/discord-classifier \ + --platform managed \ + --region us-central1 \ + --no-allow-unauthenticated + ``` +- use the public URL that is given for our code + + +we can curl the endpoint using this: +``` +curl -H "Authorization: Bearer $(gcloud auth print-identity-token)" \ + -H "Content-Type: application/json" \ + -X POST \ + -d '{"message": "this is a test"}' \ + https://your-service-url.a.run.app/classify + +``` \ No newline at end of file diff --git a/classifier_gcp/__pycache__/model.cpython-38.pyc b/classifier_gcp/__pycache__/model.cpython-38.pyc new file mode 100644 index 00000000..6aa10e84 Binary files /dev/null and b/classifier_gcp/__pycache__/model.cpython-38.pyc differ diff --git a/classifier_gcp/main.py b/classifier_gcp/main.py new file mode 100644 index 00000000..a84a867e --- /dev/null +++ b/classifier_gcp/main.py @@ -0,0 +1,57 @@ +from flask import Flask, request, jsonify +from pytorch_pretrained_bert import BertTokenizer, BertModel, BertConfig +import torch +import torch.nn as nn +import torch.nn.functional as F + +# Import your model definition (or paste class directly) +from model import BertForSequenceClassification # or define directly below if easier + +app = Flask(__name__) + +# Load tokenizer +tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + +# BERT config (must match your training config) +config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + +# Load model and weights +model = BertForSequenceClassification(num_labels=2) +model.load_state_dict(torch.load("bert_model_finetuned.pth", map_location=torch.device('cpu'))) +model.eval() + +def preprocess(text, max_len): + tokens = tokenizer.tokenize(text if text else "None") + tokens = tokens[:max_len] + token_ids = tokenizer.convert_tokens_to_ids(tokens) + padded = token_ids + [0] * (max_len - len(token_ids)) + return torch.tensor(padded).unsqueeze(0) # shape: (1, max_len) + +@app.route("/classify", methods=["POST"]) +def predict(): + data = request.json + + statement = data.get("message", "") + justification = data.get("justification", "") + metadata = data.get("metadata", "") + credit = data.get("credit_score", 0.5) + + input_ids1 = preprocess(statement, max_len=64) + input_ids2 = preprocess(justification, max_len=256) + input_ids3 = preprocess(metadata, max_len=32) + credit_tensor = torch.tensor([credit] * 2304).unsqueeze(0) # shape (1, 2304) + + with torch.no_grad(): + logits = model(input_ids1, input_ids2, input_ids3, credit_tensor) + probs = F.softmax(logits, dim=1) + confidence = probs[0][1].item() + prediction = "misinformation" if confidence > 0.5 else "factual" + + return jsonify({ + "classification": prediction, + "confidence_score": round(confidence, 4) + }) + +if __name__ == "__main__": + app.run(debug=True) \ No newline at end of file diff --git a/classifier_gcp/model.py b/classifier_gcp/model.py new file mode 100644 index 00000000..2ca814ad --- /dev/null +++ b/classifier_gcp/model.py @@ -0,0 +1,46 @@ +import torch +import torch.nn as nn +from pytorch_pretrained_bert import BertModel, BertConfig + +# Optional layer norm class (not currently used, but included for completeness) 
+class BertLayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-12): + super(BertLayerNorm, self).__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.bias = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x): + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.weight * x + self.bias + + +# Main classifier model class +class BertForSequenceClassification(nn.Module): + def __init__(self, num_labels=2): + super(BertForSequenceClassification, self).__init__() + self.num_labels = num_labels + self.config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + self.bert = BertModel.from_pretrained('bert-base-uncased') + self.dropout = nn.Dropout(0.1) + self.classifier = nn.Linear(self.config.hidden_size * 3, num_labels) + nn.init.xavier_normal_(self.classifier.weight) + + def forward_once(self, input_ids, token_type_ids=None, attention_mask=None): + _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) + pooled_output = self.dropout(pooled_output) + return pooled_output + + def forward(self, input_ids1, input_ids2, input_ids3, credit_sc): + output1 = self.forward_once(input_ids1) + output2 = self.forward_once(input_ids2) + output3 = self.forward_once(input_ids3) + + out = torch.cat((output1, output2, output3), dim=1) + out = out + credit_sc # add credit score vector + + logits = self.classifier(out) + return logits \ No newline at end of file diff --git a/classifier_gcp/requirements.txt b/classifier_gcp/requirements.txt new file mode 100644 index 00000000..7c28ff58 --- /dev/null +++ b/classifier_gcp/requirements.txt @@ -0,0 +1,10 @@ +transformers +datasets +torch +pandas +Flask +joblib +scikit-learn +pytorch_pretrained_bert +requests +numpy \ No newline at end of file diff --git a/classifier_gcp/test_classifier.py b/classifier_gcp/test_classifier.py new file mode 100644 index 00000000..b9bc6d87 --- /dev/null +++ b/classifier_gcp/test_classifier.py @@ -0,0 +1,12 @@ +import requests + +data = { + "message": "Joe Biden banned all beef products in the US.", + "justification": "A claim on a partisan blog.", + "metadata": "Biden politics diet USA", + "credit_score": 0.5 +} + + +response = requests.post("http://127.0.0.1:5000/classify", json=data) +print(response.json()) \ No newline at end of file
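
For reference, a rough sketch of what the commented-out classifier call in `handle_channel_message` (bot.py) could look like once `CLASSIFIER_URL` points at the deployed Cloud Run service. This is untested: the URL below is the placeholder from the classifier_gcp README, and the payload shape just mirrors `test_classifier.py`.

```python
# Untested sketch: authenticated call to the /classify endpoint from the bot.
# Assumes gcp_key.json is the service account key and that Cloud Run accepts the
# full endpoint URL as the token audience (mirroring bot.py's target_audience).
import requests
from google.oauth2 import service_account
from google.auth.transport.requests import Request

CLASSIFIER_URL = "https://your-service-url.a.run.app/classify"  # placeholder
credentials = service_account.IDTokenCredentials.from_service_account_file(
    "gcp_key.json", target_audience=CLASSIFIER_URL
)

def classify(message_content):
    # Mint a fresh identity token and send it as a Bearer token, like the README's curl example
    credentials.refresh(Request())
    headers = {"Authorization": f"Bearer {credentials.token}"}
    payload = {"message": message_content, "credit_score": 0.5}
    response = requests.post(CLASSIFIER_URL, headers=headers, json=payload, timeout=30)
    response.raise_for_status()
    result = response.json()
    return result.get("classification"), result.get("confidence_score")
```

If Cloud Run rejects the audience that includes the `/classify` path, the base service URL may need to be used as `target_audience` instead.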