Skip to content
This repository was archived by the owner on Sep 3, 2025. It is now read-only.

Commit a643cd1

Browse files
authored
Optimize signal ingestion pipeline and create perf-test cli util (#3337)
* Optimize signal ingestion pipeline and create perf-test cli util * Update src/dispatch/cli.py * address feedback
1 parent 3812486 commit a643cd1

File tree

4 files changed

+215
-30
lines changed

4 files changed

+215
-30
lines changed

src/dispatch/cli.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import logging
22
import os
3+
import time
34

45
import click
6+
import requests
57
import uvicorn
68
from dispatch import __version__, config
9+
from dispatch.config import DISPATCH_UI_URL
710
from dispatch.enums import UserRoles
811
from dispatch.plugin.models import PluginInstance
912

@@ -689,6 +692,104 @@ def start_tasks(tasks, exclude, eager):
689692
scheduler.start()
690693

691694

@dispatch_scheduler.command("perf-test")
@click.option("--num-instances", default=1000, help="Number of signal instances to send.")
@click.option("--num-workers", default=1000, help="Number of threads to use.")
@click.option(
    "--api-endpoint",
    default=f"{DISPATCH_UI_URL}/api/v1/default/signals/instances",
    required=True,
    help="API endpoint to send the signal instances.",
)
@click.option(
    "--api-token",
    required=True,
    help="API token to use.",
)
@click.option(
    "--project",
    default="Test",
    required=True,
    help="The Dispatch project to send the instances to",
)
def perf_test(
    num_instances: int, num_workers: int, api_endpoint: str, api_token: str, project: str
) -> None:
    """Performance testing utility for creating signal instances.

    Sends ``num_instances`` empty signal instances to ``api_endpoint`` using a
    thread pool of ``num_workers`` workers, then reports the elapsed time.
    """
    import concurrent.futures

    from fastapi import status

    session = requests.Session()
    # Auth and content-type are set once on the session; every request made
    # through it inherits them, so they are not repeated per call.
    session.headers.update(
        {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_token}",
        }
    )
    start_time = time.time()

    def _send_signal_instance(signal_instance: dict[str, str]) -> None:
        """POST a single signal instance and log the outcome."""
        # BUG FIX: in the original, `r` was referenced in the except branch
        # even when session.post() itself raised (connection error), which
        # produced a NameError instead of the intended log line.
        r = None
        try:
            r = session.post(api_endpoint, json=signal_instance)
            log.info(f"Response: {r.json()}")
            if r.status_code == status.HTTP_401_UNAUTHORIZED:
                raise PermissionError(
                    "Unauthorized. Please check your bearer token. You can find it in the Dev Tools under Request Headers -> Authorization."
                )

            r.raise_for_status()

        except requests.exceptions.RequestException as e:
            # BUG FIX: `r.json() if r else 'N/A'` relied on Response.__bool__,
            # which is False for 4xx/5xx — exactly the case where we want the
            # body. Compare against None instead.
            log.error(
                f"Unable to send finding. Reason: {e} Response: {r.json() if r is not None else 'N/A'}"
            )
        else:
            log.info(f"{signal_instance.get('raw', {}).get('id')} created successfully")

    def _send_signal_instances(signal_instances: list[dict[str, str]]) -> None:
        """Fan the instances out across a thread pool and wait for completion."""
        with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
            futures = [
                executor.submit(_send_signal_instance, signal_instance=signal_instance)
                for signal_instance in signal_instances
            ]
            results = [future.result() for future in concurrent.futures.as_completed(futures)]

        log.info(f"\nSent {len(results)} of {num_instances} signal instances")

    # Build independent payload dicts; the original `[{...}] * N` aliased a
    # single dict N times, which is fragile if a payload is ever mutated.
    signal_instances = [
        {
            "project": {"name": project},
            "raw": {},
        }
        for _ in range(num_instances)
    ]

    _send_signal_instances(signal_instances)

    elapsed_time = time.time() - start_time
    click.echo(f"Elapsed time: {elapsed_time:.2f} seconds")
692793
@dispatch_cli.group("server")
693794
def dispatch_server():
694795
"""Container for all dispatch server commands."""

src/dispatch/signal/flows.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
from datetime import timedelta
12
import logging
3+
from cachetools import TTLCache
24

35
from email_validator import validate_email, EmailNotValidError
46
from sqlalchemy.orm import Session
@@ -40,13 +42,13 @@ def signal_instance_create_flow(
4042
db_session.commit()
4143

4244
# we don't need to continue if a filter action took place
43-
if signal_service.filter_signal(db_session=db_session, signal_instance=signal_instance):
45+
if signal_service.filter_signal(
46+
db_session=db_session,
47+
signal_instance=signal_instance,
48+
):
4449
# If the signal was deduplicated, we can assume a case exists,
4550
# and we need to update the corresponding signal message
46-
if (
47-
signal_instance.filter_action == SignalFilterAction.deduplicate
48-
and signal_instance.case.signal_thread_ts # noqa
49-
):
51+
if _should_update_signal_message(signal_instance):
5052
update_signal_message(
5153
db_session=db_session,
5254
signal_instance=signal_instance,
@@ -226,3 +228,32 @@ def update_signal_message(db_session: Session, signal_instance: SignalInstance)
226228
db_session=db_session,
227229
thread_id=signal_instance.case.signal_thread_ts,
228230
)
231+
232+
# Tracks, per case id, the most recent signal instance that caused a message
# update. Entries expire after 60 seconds; at most 4 cases are tracked.
_last_nonupdated_signal_cache = TTLCache(maxsize=4, ttl=60)


def _should_update_signal_message(signal_instance: SignalInstance) -> bool:
    """Determine if the signal message should be updated.

    Returns True for the first instance seen for a case (within the cache
    TTL), and afterwards only for deduplicated instances whose case has a
    signal thread and that arrive at least 5 seconds after the last
    instance that triggered an update. Throttles Slack/message updates on
    high-volume cases.
    """
    # NOTE: no `global` needed — the cache object is only mutated in place,
    # never rebound.
    case_id = str(signal_instance.case_id)

    # First instance observed for this case within the TTL window: update
    # and remember it.
    if case_id not in _last_nonupdated_signal_cache:
        _last_nonupdated_signal_cache[case_id] = signal_instance
        return True

    last_nonupdated_signal = _last_nonupdated_signal_cache[case_id]
    time_since_last_update = signal_instance.created_at - last_nonupdated_signal.created_at

    if (
        signal_instance.filter_action == SignalFilterAction.deduplicate
        and signal_instance.case.signal_thread_ts  # noqa
        and time_since_last_update >= timedelta(seconds=5)  # noqa
    ):
        _last_nonupdated_signal_cache[case_id] = signal_instance
        return True

    return False

src/dispatch/signal/scheduled.py

Lines changed: 74 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,19 @@
44
:copyright: (c) 2022 by Netflix Inc., see AUTHORS for more
55
:license: Apache, see LICENSE for more details.
66
"""
7+
from datetime import datetime, timedelta, timezone
78
import logging
9+
import queue
10+
from sqlalchemy import asc
11+
from sqlalchemy.orm import scoped_session
12+
813
from schedule import every
9-
from dispatch.database.core import SessionLocal
14+
from dispatch.database.core import SessionLocal, sessionmaker, engine
1015
from dispatch.scheduler import scheduler
1116
from dispatch.project.models import Project
1217
from dispatch.plugin import service as plugin_service
1318
from dispatch.signal import flows as signal_flows
14-
from dispatch.decorators import scheduled_project_task
19+
from dispatch.decorators import scheduled_project_task, timer
1520
from dispatch.signal.models import SignalInstance
1621

1722
log = logging.getLogger(__name__)
@@ -47,27 +52,77 @@ def consume_signals(db_session: SessionLocal, project: Project):
4752
log.debug(signal_instance_data)
4853
log.exception(e)
4954

50-
if signal_instances:
51-
plugin.instance.delete()
5255

56+
@timer
def process_signal_instance(db_session: SessionLocal, signal_instance: SignalInstance) -> None:
    """Run the signal-instance creation flow for a single instance.

    Any exception is logged and swallowed so that one bad instance cannot
    stop the rest of the batch; the session is closed in all cases.
    """
    try:
        signal_flows.signal_instance_create_flow(
            db_session=db_session,
            signal_instance_id=signal_instance.id,
        )
    except Exception as e:
        log.debug(signal_instance)
        log.exception(e)
    finally:
        # Always return the connection to the pool, success or failure.
        db_session.close()


# Upper bound on how many instances a single scheduler run will pick up.
MAX_SIGNAL_INSTANCES = 500
signal_queue = queue.Queue(maxsize=MAX_SIGNAL_INSTANCES)
5372

73+
74+
@timer
5475
@scheduler.add(every(1).minutes, name="signal-process")
5576
@scheduled_project_task
5677
def process_signals(db_session: SessionLocal, project: Project):
57-
"""Processes signals and create cases if appropriate."""
78+
"""
79+
Process signals and create cases if appropriate.
80+
81+
This function processes signals within a given project, creating cases if necessary.
82+
It runs every minute, processing signals that meet certain criteria within the last 5 minutes.
83+
Signals are added to a queue for processing, and then each signal instance is processed.
84+
85+
Args:
86+
db_session: The database session used to query and update the database.
87+
project: The project for which the signals will be processed.
88+
89+
Notes:
90+
The function is decorated with three decorators:
91+
- scheduler.add: schedules the function to run every minute.
92+
- scheduled_project_task: ensures that the function is executed as a scheduled project task.
93+
94+
The function uses a queue to process signal instances in a first-in-first-out (FIFO) order
95+
This ensures that signals are processed in the order they were added to the queue.
96+
97+
A scoped session is used to create a new database session for each signal instance
98+
This ensures that each signal instance is processed using its own separate database connection,
99+
preventing potential issues with concurrent connections.
100+
"""
101+
one_hour_ago = datetime.now(timezone.utc) - timedelta(hours=1)
58102
signal_instances = (
59-
db_session.query(SignalInstance)
60-
.filter(SignalInstance.project_id == project.id)
61-
.filter(SignalInstance.filter_action == None) # noqa
62-
.filter(SignalInstance.case_id == None) # noqa
63-
).limit(100)
103+
(
104+
db_session.query(SignalInstance)
105+
.filter(SignalInstance.project_id == project.id)
106+
.filter(SignalInstance.filter_action == None) # noqa
107+
.filter(SignalInstance.case_id == None) # noqa
108+
.filter(SignalInstance.created_at >= one_hour_ago)
109+
)
110+
.order_by(asc(SignalInstance.created_at))
111+
.limit(MAX_SIGNAL_INSTANCES)
112+
)
113+
# Add each signal_instance to the queue for processing
64114
for signal_instance in signal_instances:
65-
log.info(f"Attempting to process the following signal: {signal_instance.id}")
66-
try:
67-
signal_flows.signal_instance_create_flow(
68-
db_session=db_session,
69-
signal_instance_id=signal_instance.id,
70-
)
71-
except Exception as e:
72-
log.debug(signal_instance)
73-
log.exception(e)
115+
signal_queue.put(signal_instance)
116+
117+
schema_engine = engine.execution_options(
118+
schema_translate_map={
119+
None: "dispatch_organization_default",
120+
}
121+
)
122+
session = scoped_session(sessionmaker(bind=schema_engine))
123+
124+
# Process each signal instance in the queue
125+
while not signal_queue.empty():
126+
signal_instance = signal_queue.get()
127+
db_session = session()
128+
process_signal_instance(db_session, signal_instance)

src/dispatch/signal/service.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Optional
44

55
from pydantic.error_wrappers import ErrorWrapper, ValidationError
6-
from sqlalchemy import asc
6+
from sqlalchemy import desc, asc
77
from sqlalchemy.orm import Session
88

99
from dispatch.auth.models import DispatchUser
@@ -552,13 +552,11 @@ def filter_signal(*, db_session: Session, signal_instance: SignalInstance) -> bo
552552
SignalInstance.signal_id == signal_instance.signal_id,
553553
SignalInstance.created_at >= default_dedup_window,
554554
SignalInstance.id != signal_instance.id,
555-
SignalInstance.case_id.isnot(None),
555+
SignalInstance.case_id.isnot(None), # noqa
556556
)
557-
.order_by(asc(SignalInstance.created_at))
558-
.all()
557+
.order_by(desc(SignalInstance.created_at))
559558
)
560-
561-
if default_dedup_query:
559+
if default_dedup_query.all():
562560
signal_instance.case_id = default_dedup_query[0].case_id
563561
signal_instance.filter_action = SignalFilterAction.deduplicate
564562
filtered = True

0 commit comments

Comments
 (0)