ref(spans): Pin redis sharding to partition index #93209

Merged · 4 commits · Jun 10, 2025

24 changes: 13 additions & 11 deletions src/sentry/scripts/spans/add-buffer.lua
@@ -14,17 +14,19 @@ ARGS:

]]--

-local project_and_trace = KEYS[1]
+local partition = KEYS[1]

-local num_spans = ARGV[1]
-local parent_span_id = ARGV[2]
-local has_root_span = ARGV[3] == "true"
-local set_timeout = tonumber(ARGV[4])
+local project_and_trace = ARGV[1]
+local num_spans = ARGV[2]
+local parent_span_id = ARGV[3]
+local has_root_span = ARGV[4] == "true"
+local set_timeout = tonumber(ARGV[5])
+local LAST_ARG = 5 -- last index of ARGV that is not a span_id

local set_span_id = parent_span_id
local redirect_depth = 0

-local main_redirect_key = string.format("span-buf:sr:{%s}", project_and_trace)
+local main_redirect_key = string.format("span-buf:sr:{%s}:%s", partition, project_and_trace)

for i = 0, 10000 do -- theoretically this limit means that segment trees of depth 10k may not be joined together correctly.
local new_set_span = redis.call("hget", main_redirect_key, set_span_id)
@@ -36,10 +38,10 @@ for i = 0, 10000 do -- theoretically this limit means that segment trees of depth 10k may not be joined together correctly.
set_span_id = new_set_span
end

-local set_key = string.format("span-buf:s:{%s}:%s", project_and_trace, set_span_id)
-local parent_key = string.format("span-buf:s:{%s}:%s", project_and_trace, parent_span_id)
+local set_key = string.format("span-buf:s:{%s}:%s:%s", partition, project_and_trace, set_span_id)
+local parent_key = string.format("span-buf:s:{%s}:%s:%s", partition, project_and_trace, parent_span_id)

-local has_root_span_key = string.format("span-buf:hrs:%s", set_key)
+local has_root_span_key = string.format("span-buf:hrs:{%s}:%s", partition, set_key)
has_root_span = has_root_span or redis.call("get", has_root_span_key) == "1"
if has_root_span then
redis.call("setex", has_root_span_key, set_timeout, "1")
@@ -52,15 +54,15 @@ if set_span_id ~= parent_span_id and redis.call("scard", parent_key) > 0 then
table.insert(sunionstore_args, parent_key)
end

-for i = 5, num_spans + 4 do
+for i = LAST_ARG + 1, num_spans + LAST_ARG do
local span_id = ARGV[i]
local is_root_span = span_id == parent_span_id

table.insert(hset_args, span_id)
table.insert(hset_args, set_span_id)

if not is_root_span then
-        local span_key = string.format("span-buf:s:{%s}:%s", project_and_trace, span_id)
+        local span_key = string.format("span-buf:s:{%s}:%s:%s", partition, project_and_trace, span_id)
table.insert(sunionstore_args, span_key)
end
end
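The net effect of the Lua changes: every key the script touches now carries a {partition} hash tag instead of {project_id:trace_id}, so Redis Cluster routes all keys written by one Kafka partition to the same slot. A rough sketch of the two layouts, with made-up IDs purely for illustration:

    # Hypothetical values, not taken from the PR.
    partition, project_id, trace_id, span_id = 3, 1, "ab" * 16, "cd" * 8

    # Old layout: the hash tag (the {...} part Redis Cluster hashes for slot
    # routing) is derived from project and trace.
    old_key = f"span-buf:s:{{{project_id}:{trace_id}}}:{span_id}"
    # -> span-buf:s:{1:abab...}:cdcd...

    # New layout: the hash tag is the partition alone, so every key produced
    # by Kafka partition 3 hashes to the same cluster slot.
    new_key = f"span-buf:s:{{{partition}}}:{project_id}:{trace_id}:{span_id}"
    # -> span-buf:s:{3}:1:abab...:cdcd...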
58 changes: 31 additions & 27 deletions src/sentry/spans/buffer.py
@@ -80,7 +80,7 @@

# SegmentKey is an internal identifier used by the redis buffer that is also
# directly used as raw redis key. the format is
# "span-buf:s:{project_id:trace_id}:span_id", and the type is bytes because our
# "span-buf:s:{partition}:project_id:trace_id:span_id", and the type is bytes because our
# redis client is bytes.
#
# The segment ID in the Kafka protocol is only the span ID.
@@ -92,16 +92,17 @@


def _segment_key_to_span_id(segment_key: SegmentKey) -> bytes:
-    return parse_segment_key(segment_key)[2]
+    return parse_segment_key(segment_key)[3]


-def parse_segment_key(segment_key: SegmentKey) -> tuple[bytes, bytes, bytes]:
+def parse_segment_key(segment_key: SegmentKey) -> tuple[int, bytes, bytes, bytes]:
segment_key_parts = segment_key.split(b":")
-    project_id = segment_key_parts[2][1:]
-    trace_id = segment_key_parts[3][:-1]
-    span_id = segment_key_parts[4]
+    partition = int(segment_key_parts[2][1:-1])
+    project_id = segment_key_parts[3]
+    trace_id = segment_key_parts[4]
+    span_id = segment_key_parts[5]

-    return project_id, trace_id, span_id
+    return partition, project_id, trace_id, span_id


def get_redis_client() -> RedisCluster[bytes] | StrictRedis[bytes]:
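As a quick sanity check of the new key layout (made-up IDs, not values from the PR), the reworked parse_segment_key splits on ":" and reads the partition out of the braces:

    key = b"span-buf:s:{3}:1:" + b"ab" * 16 + b":" + b"cd" * 8
    partition, project_id, trace_id, span_id = parse_segment_key(key)
    # partition == 3, project_id == b"1",
    # trace_id == b"abab...ab", span_id == b"cdcd...cd"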
@@ -113,6 +114,7 @@ def get_redis_client() -> RedisCluster[bytes] | StrictRedis[bytes]:

# NamedTuples are faster to construct than dataclasses
class Span(NamedTuple):
+    partition: int
trace_id: str
span_id: str
parent_span_id: str | None
@@ -153,8 +155,8 @@ def client(self) -> RedisCluster[bytes] | StrictRedis[bytes]:
def __reduce__(self):
return (SpansBuffer, (self.assigned_shards,))

-    def _get_span_key(self, project_and_trace: str, span_id: str) -> bytes:
-        return f"span-buf:s:{{{project_and_trace}}}:{span_id}".encode("ascii")
+    def _get_span_key(self, partition: int, project_and_trace: str, span_id: str) -> bytes:
+        return f"span-buf:s:{{{partition}}}:{project_and_trace}:{span_id}".encode("ascii")

def process_spans(self, spans: Sequence[Span], now: int):
"""
@@ -176,8 +178,8 @@ def process_spans(self, spans: Sequence[Span], now: int):
trees = self._group_by_parent(spans)

with self.client.pipeline(transaction=False) as p:
-            for (project_and_trace, parent_span_id), subsegment in trees.items():
-                set_key = self._get_span_key(project_and_trace, parent_span_id)
+            for (partition, project_and_trace, parent_span_id), subsegment in trees.items():
+                set_key = self._get_span_key(partition, project_and_trace, parent_span_id)
p.sadd(set_key, *[span.payload for span in subsegment])

p.execute()
@@ -189,11 +191,12 @@ def process_spans(self, spans: Sequence[Span], now: int):
add_buffer_sha = self._ensure_script()

with self.client.pipeline(transaction=False) as p:
-            for (project_and_trace, parent_span_id), subsegment in trees.items():
+            for (partition, project_and_trace, parent_span_id), subsegment in trees.items():
p.execute_command(
"EVALSHA",
add_buffer_sha,
1,
+                    partition,
project_and_trace,
len(subsegment),
parent_span_id,
@@ -203,7 +206,7 @@ def process_spans(self, spans: Sequence[Span], now: int):
)

is_root_span_count += sum(span.is_segment_span for span in subsegment)
-                result_meta.append((project_and_trace, parent_span_id))
+                result_meta.append((partition, project_and_trace, parent_span_id))

results = p.execute()
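For reference, the EVALSHA call above lines up with the script's new signature: the partition is the single KEYS entry, which pins the invocation to one cluster slot, and project_and_trace moves into ARGV. The trailing arguments are folded out of this view; the annotated mapping below fills them in from the Lua side's ARGV handling, so treat it as an inferred summary rather than the literal call:

    # KEYS[1]  = partition          # routing key for Redis Cluster
    # ARGV[1]  = project_and_trace
    # ARGV[2]  = num_spans (len(subsegment))
    # ARGV[3]  = parent_span_id
    # ARGV[4]  = has_root_span ("true"/"false")
    # ARGV[5]  = set_timeout
    # ARGV[6+] = one span id per buffered span (LAST_ARG = 5 in add-buffer.lua)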

@@ -213,14 +216,10 @@

assert len(result_meta) == len(results)

-        for (project_and_trace, parent_span_id), result in zip(result_meta, results):
+        for (partition, project_and_trace, parent_span_id), result in zip(result_meta, results):
redirect_depth, set_key, has_root_span = result

-            shard = self.assigned_shards[
-                int(project_and_trace.split(":")[1], 16) % len(self.assigned_shards)
-            ]
-            queue_key = self._get_queue_key(shard)
-
+            queue_key = self._get_queue_key(partition)
min_redirect_depth = min(min_redirect_depth, redirect_depth)
max_redirect_depth = max(max_redirect_depth, redirect_depth)

@@ -235,10 +234,11 @@ def process_spans(self, spans: Sequence[Span], now: int):
zadd_items = queue_adds.setdefault(queue_key, {})
zadd_items[set_key] = now + offset

-            subsegment_spans = trees[project_and_trace, parent_span_id]
+            subsegment_spans = trees[partition, project_and_trace, parent_span_id]
delete_set = queue_deletes.setdefault(queue_key, set())
delete_set.update(
-                self._get_span_key(project_and_trace, span.span_id) for span in subsegment_spans
+                self._get_span_key(partition, project_and_trace, span.span_id)
+                for span in subsegment_spans
)
delete_set.discard(set_key)

@@ -271,7 +271,7 @@ def _ensure_script(self):
def _get_queue_key(self, shard: int) -> bytes:
return f"span-buf:q:{shard}".encode("ascii")

-    def _group_by_parent(self, spans: Sequence[Span]) -> dict[tuple[str, str], list[Span]]:
+    def _group_by_parent(self, spans: Sequence[Span]) -> dict[tuple[int, str, str], list[Span]]:
"""
Groups partial trees of spans by their top-most parent span ID in the
provided list. The result is a dictionary where the keys identify a
@@ -282,7 +282,7 @@ def _group_by_parent(self, spans: Sequence[Span]) -> dict[tuple[str, str], list[Span]]:
:return: Dictionary of grouped spans. The key is a tuple of
the `project_and_trace`, and the `parent_span_id`.
"""
-        trees: dict[tuple[str, str], list[Span]] = {}
+        trees: dict[tuple[int, str, str], list[Span]] = {}
redirects: dict[str, dict[str, str]] = {}

for span in spans:
@@ -293,9 +293,9 @@ def _group_by_parent(self, spans: Sequence[Span]) -> dict[tuple[str, str], list[Span]]:
while redirect := trace_redirects.get(parent):
parent = redirect

-            subsegment = trees.setdefault((project_and_trace, parent), [])
+            subsegment = trees.setdefault((span.partition, project_and_trace, parent), [])
if parent != span.span_id:
-                subsegment.extend(trees.pop((project_and_trace, span.span_id), []))
+                subsegment.extend(trees.pop((span.partition, project_and_trace, span.span_id), []))
trace_redirects[span.span_id] = parent
subsegment.append(span)

@@ -471,8 +471,12 @@ def done_flush_segments(self, segment_keys: dict[SegmentKey, FlushedSegment]):
p.delete(hrs_key)
p.unlink(segment_key)

-            project_id, trace_id, _ = parse_segment_key(segment_key)
-            redirect_map_key = b"span-buf:sr:{%s:%s}" % (project_id, trace_id)
+            partition, project_id, trace_id, _ = parse_segment_key(segment_key)
+            redirect_map_key = b"span-buf:sr:{%d}:%s:%s" % (
+                partition,
+                project_id,
+                trace_id,
+            )
p.zrem(flushed_segment.queue_key, segment_key)

for span_batch in itertools.batched(flushed_segment.spans, 100):
15 changes: 9 additions & 6 deletions src/sentry/spans/consumers/process/factory.py
@@ -64,6 +64,7 @@ def create_with_partitions(
committer = CommitOffsets(commit)

buffer = SpansBuffer(assigned_shards=[p.index for p in partitions])
+        first_partition = next((p.index for p in partitions), 0)

# patch onto self just for testing
flusher: ProcessingStrategy[FilteredPayload | int]
@@ -75,7 +76,7 @@

if self.num_processes != 1:
run_task = run_task_with_multiprocessing(
-                function=partial(process_batch, buffer),
+                function=partial(process_batch, buffer, first_partition),
next_step=flusher,
max_batch_size=self.max_batch_size,
max_batch_time=self.max_batch_time,
@@ -85,7 +86,7 @@
)
else:
run_task = RunTask(
-                function=partial(process_batch, buffer),
+                function=partial(process_batch, buffer, first_partition),
next_step=flusher,
)

@@ -119,7 +120,9 @@ def shutdown(self) -> None:


def process_batch(
-    buffer: SpansBuffer, values: Message[ValuesBatch[tuple[int, KafkaPayload]]]
+    buffer: SpansBuffer,
+    first_partition: int,
+    values: Message[ValuesBatch[tuple[int, KafkaPayload]]],
) -> int:
min_timestamp = None
spans = []
@@ -130,10 +133,9 @@

val = rapidjson.loads(payload.value)

-        partition_id = None
-
+        partition_id: int = first_partition
         if len(value.committable) == 1:
-            partition_id = value.committable[next(iter(value.committable))]
+            partition_id = next(iter(value.committable)).index

if killswitches.killswitch_matches_context(
"spans.drop-in-buffer",
@@ -147,6 +149,7 @@
continue

span = Span(
+            partition=partition_id,
trace_id=val["trace_id"],
span_id=val["span_id"],
parent_span_id=val.get("parent_span_id"),
4 changes: 4 additions & 0 deletions tests/sentry/spans/consumers/process/test_flusher.py
@@ -39,27 +39,31 @@ def append(msg):

spans = [
Span(
+            partition=0,
payload=_payload(b"a" * 16),
trace_id=trace_id,
span_id="a" * 16,
parent_span_id="b" * 16,
project_id=1,
),
Span(
+            partition=0,
payload=_payload(b"d" * 16),
trace_id=trace_id,
span_id="d" * 16,
parent_span_id="b" * 16,
project_id=1,
),
Span(
+            partition=0,
payload=_payload(b"c" * 16),
trace_id=trace_id,
span_id="c" * 16,
parent_span_id="b" * 16,
project_id=1,
),
Span(
+            partition=0,
payload=_payload(b"b" * 16),
trace_id=trace_id,
span_id="b" * 16,