|
6 | 6 | import logging
|
7 | 7 | import threading
|
8 | 8 | import time
|
9 |
| -from collections import OrderedDict |
| 9 | +from collections import deque, OrderedDict |
10 | 10 | from copy import deepcopy
|
11 | 11 | from datetime import timedelta
|
12 | 12 | from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional
|
@@ -99,9 +99,13 @@ def __init__(
|
99 | 99 | self._semaphore_per_partition: OrderedDict[str, threading.Semaphore] = OrderedDict()
|
100 | 100 |
|
101 | 101 | # Parent-state tracking: store each partition’s parent state in creation order
|
102 |
| - self._partition_parent_state_map: OrderedDict[str, Mapping[str, Any]] = OrderedDict() |
| 102 | + self._partition_parent_state_map: OrderedDict[str, tuple[Mapping[str, Any], int]] = OrderedDict() |
103 | 103 |
|
104 | 104 | self._finished_partitions: set[str] = set()
|
| 105 | + self._open_seqs: deque[int] = deque() |
| 106 | + self._next_seq: int = 0 |
| 107 | + self._seq_by_partition: dict[str, int] = {} |
| 108 | + |
105 | 109 | self._lock = threading.Lock()
|
106 | 110 | self._timer = Timer()
|
107 | 111 | self._new_global_cursor: Optional[StreamState] = None
|
@@ -162,55 +166,28 @@ def close_partition(self, partition: Partition) -> None:
|
162 | 166 | ):
|
163 | 167 | self._update_global_cursor(cursor.state[self.cursor_field.cursor_field_key])
|
164 | 168 |
|
| 169 | + # Clean up the partition if it is fully processed |
| 170 | + self._cleanup_if_done(partition_key) |
| 171 | + |
165 | 172 | self._check_and_update_parent_state()
|
166 | 173 |
|
167 | 174 | self._emit_state_message()
|
168 | 175 |
|
169 | 176 | def _check_and_update_parent_state(self) -> None:
|
170 |
| - """ |
171 |
| - Pop the leftmost partition state from _partition_parent_state_map only if |
172 |
| - *all partitions* up to (and including) that partition key in _semaphore_per_partition |
173 |
| - are fully finished (i.e. in _finished_partitions and semaphore._value == 0). |
174 |
| - Additionally, delete finished semaphores with a value of 0 to free up memory, |
175 |
| - as they are only needed to track errors and completion status. |
176 |
| - """ |
177 | 177 | last_closed_state = None
|
178 | 178 |
|
179 | 179 | while self._partition_parent_state_map:
|
180 |
| - # Look at the earliest partition key in creation order |
181 |
| - earliest_key = next(iter(self._partition_parent_state_map)) |
182 |
| - |
183 |
| - # Verify ALL partitions from the left up to earliest_key are finished |
184 |
| - all_left_finished = True |
185 |
| - for p_key, sem in list( |
186 |
| - self._semaphore_per_partition.items() |
187 |
| - ): # Use list to allow modification during iteration |
188 |
| - # If any earlier partition is still not finished, we must stop |
189 |
| - if p_key not in self._finished_partitions or sem._value != 0: |
190 |
| - all_left_finished = False |
191 |
| - break |
192 |
| - # Once we've reached earliest_key in the semaphore order, we can stop checking |
193 |
| - if p_key == earliest_key: |
194 |
| - break |
195 |
| - |
196 |
| - # If the partitions up to earliest_key are not all finished, break the while-loop |
197 |
| - if not all_left_finished: |
198 |
| - break |
| 180 | + earliest_key, (candidate_state, candidate_seq) = \ |
| 181 | + next(iter(self._partition_parent_state_map.items())) |
199 | 182 |
|
200 |
| - # Pop the leftmost entry from parent-state map |
201 |
| - _, closed_parent_state = self._partition_parent_state_map.popitem(last=False) |
202 |
| - last_closed_state = closed_parent_state |
| 183 | + # if any partition that started <= candidate_seq is still open, we must wait |
| 184 | + if self._open_seqs and self._open_seqs[0] <= candidate_seq: |
| 185 | + break |
203 | 186 |
|
204 |
| - # Clean up finished semaphores with value 0 up to and including earliest_key |
205 |
| - for p_key in list(self._semaphore_per_partition.keys()): |
206 |
| - sem = self._semaphore_per_partition[p_key] |
207 |
| - if p_key in self._finished_partitions and sem._value == 0: |
208 |
| - del self._semaphore_per_partition[p_key] |
209 |
| - logger.debug(f"Deleted finished semaphore for partition {p_key} with value 0") |
210 |
| - if p_key == earliest_key: |
211 |
| - break |
| 187 | + # safe to pop |
| 188 | + self._partition_parent_state_map.popitem(last=False) |
| 189 | + last_closed_state = candidate_state |
212 | 190 |
|
213 |
| - # Update _parent_state if we popped at least one partition |
214 | 191 | if last_closed_state is not None:
|
215 | 192 | self._parent_state = last_closed_state
|
216 | 193 |
|
@@ -293,14 +270,19 @@ def _generate_slices_from_partition(
|
293 | 270 | self._semaphore_per_partition[partition_key] = threading.Semaphore(0)
|
294 | 271 |
|
295 | 272 | with self._lock:
|
| 273 | + seq = self._next_seq |
| 274 | + self._next_seq += 1 |
| 275 | + self._open_seqs.append(seq) |
| 276 | + self._seq_by_partition[partition_key] = seq |
| 277 | + |
296 | 278 | if (
|
297 | 279 | len(self._partition_parent_state_map) == 0
|
298 | 280 | or self._partition_parent_state_map[
|
299 | 281 | next(reversed(self._partition_parent_state_map))
|
300 |
| - ] |
| 282 | + ][0] |
301 | 283 | != parent_state
|
302 | 284 | ):
|
303 |
| - self._partition_parent_state_map[partition_key] = deepcopy(parent_state) |
| 285 | + self._partition_parent_state_map[partition_key] = (deepcopy(parent_state), seq) |
304 | 286 |
|
305 | 287 | for cursor_slice, is_last_slice, _ in iterate_with_last_flag_and_state(
|
306 | 288 | cursor.stream_slices(),
|
@@ -338,10 +320,7 @@ def _ensure_partition_limit(self) -> None:
|
338 | 320 | while len(self._cursor_per_partition) > self.DEFAULT_MAX_PARTITIONS_NUMBER - 1:
|
339 | 321 | # Try removing finished partitions first
|
340 | 322 | for partition_key in list(self._cursor_per_partition.keys()):
|
341 |
| - if partition_key in self._finished_partitions and ( |
342 |
| - partition_key not in self._semaphore_per_partition |
343 |
| - or self._semaphore_per_partition[partition_key]._value == 0 |
344 |
| - ): |
| 323 | + if partition_key not in self._seq_by_partition: |
345 | 324 | oldest_partition = self._cursor_per_partition.pop(
|
346 | 325 | partition_key
|
347 | 326 | ) # Remove the oldest partition
|
@@ -474,6 +453,25 @@ def _update_global_cursor(self, value: Any) -> None:
|
474 | 453 | ):
|
475 | 454 | self._new_global_cursor = {self.cursor_field.cursor_field_key: copy.deepcopy(value)}
|
476 | 455 |
|
| 456 | + def _cleanup_if_done(self, partition_key: str) -> None: |
| 457 | + """ |
| 458 | + Free every in-memory structure that belonged to a completed partition: |
| 459 | + cursor, semaphore, flag inside `_finished_partitions` |
| 460 | + """ |
| 461 | + if not ( |
| 462 | + partition_key in self._finished_partitions |
| 463 | + and self._semaphore_per_partition[partition_key]._value == 0 |
| 464 | + ): |
| 465 | + return |
| 466 | + |
| 467 | + self._semaphore_per_partition.pop(partition_key, None) |
| 468 | + self._finished_partitions.discard(partition_key) |
| 469 | + |
| 470 | + seq = self._seq_by_partition.pop(partition_key) |
| 471 | + self._open_seqs.remove(seq) |
| 472 | + |
| 473 | + logger.debug(f"Partition {partition_key} fully processed and cleaned up.") |
| 474 | + |
477 | 475 | def _to_partition_key(self, partition: Mapping[str, Any]) -> str:
|
478 | 476 | return self._partition_serializer.to_partition_key(partition)
|
479 | 477 |
|
|
0 commit comments