Skip to content

Commit 7cee132

Browse files
authored
ref(spans): Pin redis sharding to partition index (#93209)
We observe performance degradation of the `process-spans` consumer with an increasing number of Redis shards. By using the partition index instead of the trace ID as the sharding key, we pin each partition to exactly one Redis shard, which should reduce this effect. The downside is that a single hot partition can no longer be spread across multiple Redis shards. At the current time, we operate 4-8 partitions per shard, so we do not expect this to become an issue in the short term.
1 parent 3096a2b commit 7cee132

File tree

5 files changed

+80
-45
lines changed

5 files changed

+80
-45
lines changed

src/sentry/scripts/spans/add-buffer.lua

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,19 @@ ARGS:
1414
1515
]]--
1616

17-
local project_and_trace = KEYS[1]
17+
local partition = KEYS[1]
1818

19-
local num_spans = ARGV[1]
20-
local parent_span_id = ARGV[2]
21-
local has_root_span = ARGV[3] == "true"
22-
local set_timeout = tonumber(ARGV[4])
19+
local project_and_trace = ARGV[1]
20+
local num_spans = ARGV[2]
21+
local parent_span_id = ARGV[3]
22+
local has_root_span = ARGV[4] == "true"
23+
local set_timeout = tonumber(ARGV[5])
24+
local LAST_ARG = 5 -- last index of ARGV that is not a span_id
2325

2426
local set_span_id = parent_span_id
2527
local redirect_depth = 0
2628

27-
local main_redirect_key = string.format("span-buf:sr:{%s}", project_and_trace)
29+
local main_redirect_key = string.format("span-buf:sr:{%s}:%s", partition, project_and_trace)
2830

2931
for i = 0, 10000 do -- theoretically this limit means that segment trees of depth 10k may not be joined together correctly.
3032
local new_set_span = redis.call("hget", main_redirect_key, set_span_id)
@@ -36,10 +38,10 @@ for i = 0, 10000 do -- theoretically this limit means that segment trees of dep
3638
set_span_id = new_set_span
3739
end
3840

39-
local set_key = string.format("span-buf:s:{%s}:%s", project_and_trace, set_span_id)
40-
local parent_key = string.format("span-buf:s:{%s}:%s", project_and_trace, parent_span_id)
41+
local set_key = string.format("span-buf:s:{%s}:%s:%s", partition, project_and_trace, set_span_id)
42+
local parent_key = string.format("span-buf:s:{%s}:%s:%s", partition, project_and_trace, parent_span_id)
4143

42-
local has_root_span_key = string.format("span-buf:hrs:%s", set_key)
44+
local has_root_span_key = string.format("span-buf:hrs:{%s}:%s", partition, set_key)
4345
has_root_span = has_root_span or redis.call("get", has_root_span_key) == "1"
4446
if has_root_span then
4547
redis.call("setex", has_root_span_key, set_timeout, "1")
@@ -52,15 +54,15 @@ if set_span_id ~= parent_span_id and redis.call("scard", parent_key) > 0 then
5254
table.insert(sunionstore_args, parent_key)
5355
end
5456

55-
for i = 5, num_spans + 4 do
57+
for i = LAST_ARG + 1, num_spans + LAST_ARG do
5658
local span_id = ARGV[i]
5759
local is_root_span = span_id == parent_span_id
5860

5961
table.insert(hset_args, span_id)
6062
table.insert(hset_args, set_span_id)
6163

6264
if not is_root_span then
63-
local span_key = string.format("span-buf:s:{%s}:%s", project_and_trace, span_id)
65+
local span_key = string.format("span-buf:s:{%s}:%s:%s", partition, project_and_trace, span_id)
6466
table.insert(sunionstore_args, span_key)
6567
end
6668
end

src/sentry/spans/buffer.py

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@
8080

8181
# SegmentKey is an internal identifier used by the redis buffer that is also
8282
# directly used as raw redis key. the format is
83-
# "span-buf:s:{project_id:trace_id}:span_id", and the type is bytes because our
83+
# "span-buf:s:{partition}:project_id:trace_id:span_id", and the type is bytes because our
8484
# redis client is bytes.
8585
#
8686
# The segment ID in the Kafka protocol is only the span ID.
@@ -92,16 +92,17 @@
9292

9393

9494
def _segment_key_to_span_id(segment_key: SegmentKey) -> bytes:
95-
return parse_segment_key(segment_key)[2]
95+
return parse_segment_key(segment_key)[3]
9696

9797

98-
def parse_segment_key(segment_key: SegmentKey) -> tuple[bytes, bytes, bytes]:
98+
def parse_segment_key(segment_key: SegmentKey) -> tuple[int, bytes, bytes, bytes]:
9999
segment_key_parts = segment_key.split(b":")
100-
project_id = segment_key_parts[2][1:]
101-
trace_id = segment_key_parts[3][:-1]
102-
span_id = segment_key_parts[4]
100+
partition = int(segment_key_parts[2][1:-1])
101+
project_id = segment_key_parts[3]
102+
trace_id = segment_key_parts[4]
103+
span_id = segment_key_parts[5]
103104

104-
return project_id, trace_id, span_id
105+
return partition, project_id, trace_id, span_id
105106

106107

107108
def get_redis_client() -> RedisCluster[bytes] | StrictRedis[bytes]:
@@ -113,6 +114,7 @@ def get_redis_client() -> RedisCluster[bytes] | StrictRedis[bytes]:
113114

114115
# NamedTuples are faster to construct than dataclasses
115116
class Span(NamedTuple):
117+
partition: int
116118
trace_id: str
117119
span_id: str
118120
parent_span_id: str | None
@@ -153,8 +155,8 @@ def client(self) -> RedisCluster[bytes] | StrictRedis[bytes]:
153155
def __reduce__(self):
154156
return (SpansBuffer, (self.assigned_shards,))
155157

156-
def _get_span_key(self, project_and_trace: str, span_id: str) -> bytes:
157-
return f"span-buf:s:{{{project_and_trace}}}:{span_id}".encode("ascii")
158+
def _get_span_key(self, partition: int, project_and_trace: str, span_id: str) -> bytes:
159+
return f"span-buf:s:{{{partition}}}:{project_and_trace}:{span_id}".encode("ascii")
158160

159161
def process_spans(self, spans: Sequence[Span], now: int):
160162
"""
@@ -176,8 +178,8 @@ def process_spans(self, spans: Sequence[Span], now: int):
176178
trees = self._group_by_parent(spans)
177179

178180
with self.client.pipeline(transaction=False) as p:
179-
for (project_and_trace, parent_span_id), subsegment in trees.items():
180-
set_key = self._get_span_key(project_and_trace, parent_span_id)
181+
for (partition, project_and_trace, parent_span_id), subsegment in trees.items():
182+
set_key = self._get_span_key(partition, project_and_trace, parent_span_id)
181183
p.sadd(set_key, *[span.payload for span in subsegment])
182184

183185
p.execute()
@@ -189,11 +191,12 @@ def process_spans(self, spans: Sequence[Span], now: int):
189191
add_buffer_sha = self._ensure_script()
190192

191193
with self.client.pipeline(transaction=False) as p:
192-
for (project_and_trace, parent_span_id), subsegment in trees.items():
194+
for (partition, project_and_trace, parent_span_id), subsegment in trees.items():
193195
p.execute_command(
194196
"EVALSHA",
195197
add_buffer_sha,
196198
1,
199+
partition,
197200
project_and_trace,
198201
len(subsegment),
199202
parent_span_id,
@@ -203,7 +206,7 @@ def process_spans(self, spans: Sequence[Span], now: int):
203206
)
204207

205208
is_root_span_count += sum(span.is_segment_span for span in subsegment)
206-
result_meta.append((project_and_trace, parent_span_id))
209+
result_meta.append((partition, project_and_trace, parent_span_id))
207210

208211
results = p.execute()
209212

@@ -213,14 +216,10 @@ def process_spans(self, spans: Sequence[Span], now: int):
213216

214217
assert len(result_meta) == len(results)
215218

216-
for (project_and_trace, parent_span_id), result in zip(result_meta, results):
219+
for (partition, project_and_trace, parent_span_id), result in zip(result_meta, results):
217220
redirect_depth, set_key, has_root_span = result
218221

219-
shard = self.assigned_shards[
220-
int(project_and_trace.split(":")[1], 16) % len(self.assigned_shards)
221-
]
222-
queue_key = self._get_queue_key(shard)
223-
222+
queue_key = self._get_queue_key(partition)
224223
min_redirect_depth = min(min_redirect_depth, redirect_depth)
225224
max_redirect_depth = max(max_redirect_depth, redirect_depth)
226225

@@ -235,10 +234,11 @@ def process_spans(self, spans: Sequence[Span], now: int):
235234
zadd_items = queue_adds.setdefault(queue_key, {})
236235
zadd_items[set_key] = now + offset
237236

238-
subsegment_spans = trees[project_and_trace, parent_span_id]
237+
subsegment_spans = trees[partition, project_and_trace, parent_span_id]
239238
delete_set = queue_deletes.setdefault(queue_key, set())
240239
delete_set.update(
241-
self._get_span_key(project_and_trace, span.span_id) for span in subsegment_spans
240+
self._get_span_key(partition, project_and_trace, span.span_id)
241+
for span in subsegment_spans
242242
)
243243
delete_set.discard(set_key)
244244

@@ -271,7 +271,7 @@ def _ensure_script(self):
271271
def _get_queue_key(self, shard: int) -> bytes:
272272
return f"span-buf:q:{shard}".encode("ascii")
273273

274-
def _group_by_parent(self, spans: Sequence[Span]) -> dict[tuple[str, str], list[Span]]:
274+
def _group_by_parent(self, spans: Sequence[Span]) -> dict[tuple[int, str, str], list[Span]]:
275275
"""
276276
Groups partial trees of spans by their top-most parent span ID in the
277277
provided list. The result is a dictionary where the keys identify a
@@ -282,7 +282,7 @@ def _group_by_parent(self, spans: Sequence[Span]) -> dict[tuple[str, str], list[
282282
:return: Dictionary of grouped spans. The key is a tuple of
283283
the `project_and_trace`, and the `parent_span_id`.
284284
"""
285-
trees: dict[tuple[str, str], list[Span]] = {}
285+
trees: dict[tuple[int, str, str], list[Span]] = {}
286286
redirects: dict[str, dict[str, str]] = {}
287287

288288
for span in spans:
@@ -293,9 +293,9 @@ def _group_by_parent(self, spans: Sequence[Span]) -> dict[tuple[str, str], list[
293293
while redirect := trace_redirects.get(parent):
294294
parent = redirect
295295

296-
subsegment = trees.setdefault((project_and_trace, parent), [])
296+
subsegment = trees.setdefault((span.partition, project_and_trace, parent), [])
297297
if parent != span.span_id:
298-
subsegment.extend(trees.pop((project_and_trace, span.span_id), []))
298+
subsegment.extend(trees.pop((span.partition, project_and_trace, span.span_id), []))
299299
trace_redirects[span.span_id] = parent
300300
subsegment.append(span)
301301

@@ -471,8 +471,12 @@ def done_flush_segments(self, segment_keys: dict[SegmentKey, FlushedSegment]):
471471
p.delete(hrs_key)
472472
p.unlink(segment_key)
473473

474-
project_id, trace_id, _ = parse_segment_key(segment_key)
475-
redirect_map_key = b"span-buf:sr:{%s:%s}" % (project_id, trace_id)
474+
partition, project_id, trace_id, _ = parse_segment_key(segment_key)
475+
redirect_map_key = b"span-buf:sr:{%d}:%s:%s" % (
476+
partition,
477+
project_id,
478+
trace_id,
479+
)
476480
p.zrem(flushed_segment.queue_key, segment_key)
477481

478482
for span_batch in itertools.batched(flushed_segment.spans, 100):

src/sentry/spans/consumers/process/factory.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def create_with_partitions(
6464
committer = CommitOffsets(commit)
6565

6666
buffer = SpansBuffer(assigned_shards=[p.index for p in partitions])
67+
first_partition = next((p.index for p in partitions), 0)
6768

6869
# patch onto self just for testing
6970
flusher: ProcessingStrategy[FilteredPayload | int]
@@ -75,7 +76,7 @@ def create_with_partitions(
7576

7677
if self.num_processes != 1:
7778
run_task = run_task_with_multiprocessing(
78-
function=partial(process_batch, buffer),
79+
function=partial(process_batch, buffer, first_partition),
7980
next_step=flusher,
8081
max_batch_size=self.max_batch_size,
8182
max_batch_time=self.max_batch_time,
@@ -85,7 +86,7 @@ def create_with_partitions(
8586
)
8687
else:
8788
run_task = RunTask(
88-
function=partial(process_batch, buffer),
89+
function=partial(process_batch, buffer, first_partition),
8990
next_step=flusher,
9091
)
9192

@@ -119,7 +120,9 @@ def shutdown(self) -> None:
119120

120121

121122
def process_batch(
122-
buffer: SpansBuffer, values: Message[ValuesBatch[tuple[int, KafkaPayload]]]
123+
buffer: SpansBuffer,
124+
first_partition: int,
125+
values: Message[ValuesBatch[tuple[int, KafkaPayload]]],
123126
) -> int:
124127
min_timestamp = None
125128
spans = []
@@ -130,10 +133,9 @@ def process_batch(
130133

131134
val = rapidjson.loads(payload.value)
132135

133-
partition_id = None
134-
136+
partition_id: int = first_partition
135137
if len(value.committable) == 1:
136-
partition_id = value.committable[next(iter(value.committable))]
138+
partition_id = next(iter(value.committable)).index
137139

138140
if killswitches.killswitch_matches_context(
139141
"spans.drop-in-buffer",
@@ -147,6 +149,7 @@ def process_batch(
147149
continue
148150

149151
span = Span(
152+
partition=partition_id,
150153
trace_id=val["trace_id"],
151154
span_id=val["span_id"],
152155
parent_span_id=val.get("parent_span_id"),

tests/sentry/spans/consumers/process/test_flusher.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,27 +39,31 @@ def append(msg):
3939

4040
spans = [
4141
Span(
42+
partition=0,
4243
payload=_payload(b"a" * 16),
4344
trace_id=trace_id,
4445
span_id="a" * 16,
4546
parent_span_id="b" * 16,
4647
project_id=1,
4748
),
4849
Span(
50+
partition=0,
4951
payload=_payload(b"d" * 16),
5052
trace_id=trace_id,
5153
span_id="d" * 16,
5254
parent_span_id="b" * 16,
5355
project_id=1,
5456
),
5557
Span(
58+
partition=0,
5659
payload=_payload(b"c" * 16),
5760
trace_id=trace_id,
5861
span_id="c" * 16,
5962
parent_span_id="b" * 16,
6063
project_id=1,
6164
),
6265
Span(
66+
partition=0,
6367
payload=_payload(b"b" * 16),
6468
trace_id=trace_id,
6569
span_id="b" * 16,

0 commit comments

Comments
 (0)