Skip to content

Commit 262370e

Browse files
committed
Add skipping duplicate partitions
1 parent 276bb8c commit 262370e

File tree

2 files changed

+165
-37
lines changed

2 files changed

+165
-37
lines changed

airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
import logging
77
import threading
88
import time
9-
from collections import OrderedDict, deque
9+
from collections import OrderedDict
1010
from copy import deepcopy
1111
from datetime import timedelta
12-
from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional
12+
from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional
1313

1414
from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
1515
from airbyte_cdk.sources.declarative.incremental.global_substream_cursor import (
@@ -66,8 +66,8 @@ class ConcurrentPerPartitionCursor(Cursor):
6666
_GLOBAL_STATE_KEY = "state"
6767
_PERPARTITION_STATE_KEY = "states"
6868
_IS_PARTITION_DUPLICATION_LOGGED = False
69-
_KEY = 0
70-
_VALUE = 1
69+
_PARENT_STATE = 0
70+
_GENERATION_SEQUENCE = 1
7171

7272
def __init__(
7373
self,
@@ -105,12 +105,12 @@ def __init__(
105105
self._parent_state: Optional[StreamState] = None
106106

107107
# Tracks when the last slice for partition is emitted
108-
self._finished_partitions: set[str] = set()
109-
# Used to track the sequence numbers of open partitions
110-
self._open_seqs: deque[int] = deque()
111-
self._next_seq: int = 0
112-
# Dictionary to map partition keys to their sequence numbers
113-
self._seq_by_partition: dict[str, int] = {}
108+
self._partitions_done_generating_stream_slices: set[str] = set()
109+
# Used to track the index of partitions that are not closed yet
110+
self._processing_partitions_indexes: List[int] = list()
111+
self._generated_partitions_count: int = 0
112+
# Dictionary to map partition keys to their index
113+
self._partition_key_to_index: dict[str, int] = {}
114114

115115
self._lock = threading.Lock()
116116
self._lookback_window: int = 0
@@ -167,7 +167,7 @@ def close_partition(self, partition: Partition) -> None:
167167
self._cursor_per_partition[partition_key].close_partition(partition=partition)
168168
cursor = self._cursor_per_partition[partition_key]
169169
if (
170-
partition_key in self._finished_partitions
170+
partition_key in self._partitions_done_generating_stream_slices
171171
and self._semaphore_per_partition[partition_key]._value == 0
172172
):
173173
self._update_global_cursor(cursor.state[self.cursor_field.cursor_field_key])
@@ -188,7 +188,10 @@ def _check_and_update_parent_state(self) -> None:
188188
)
189189

190190
# if any partition that started <= candidate_seq is still open, we must wait
191-
if self._open_seqs and self._open_seqs[0] <= candidate_seq:
191+
if (
192+
self._processing_partitions_indexes
193+
and self._processing_partitions_indexes[0] <= candidate_seq
194+
):
192195
break
193196

194197
# safe to pop
@@ -273,20 +276,21 @@ def _generate_slices_from_partition(
273276
if not self._IS_PARTITION_DUPLICATION_LOGGED:
274277
logger.warning(f"Partition duplication detected for stream {self._stream_name}")
275278
self._IS_PARTITION_DUPLICATION_LOGGED = True
279+
return
276280
else:
277281
self._semaphore_per_partition[partition_key] = threading.Semaphore(0)
278282

279283
with self._lock:
280-
seq = self._next_seq
281-
self._next_seq += 1
282-
self._open_seqs.append(seq)
283-
self._seq_by_partition[partition_key] = seq
284+
seq = self._generated_partitions_count
285+
self._generated_partitions_count += 1
286+
self._processing_partitions_indexes.append(seq)
287+
self._partition_key_to_index[partition_key] = seq
284288

285289
if (
286290
len(self._partition_parent_state_map) == 0
287291
or self._partition_parent_state_map[
288292
next(reversed(self._partition_parent_state_map))
289-
][0]
293+
][self._PARENT_STATE]
290294
!= parent_state
291295
):
292296
self._partition_parent_state_map[partition_key] = (deepcopy(parent_state), seq)
@@ -297,7 +301,7 @@ def _generate_slices_from_partition(
297301
):
298302
self._semaphore_per_partition[partition_key].release()
299303
if is_last_slice:
300-
self._finished_partitions.add(partition_key)
304+
self._partitions_done_generating_stream_slices.add(partition_key)
301305
yield StreamSlice(
302306
partition=partition, cursor_slice=cursor_slice, extra_fields=partition.extra_fields
303307
)
@@ -327,7 +331,7 @@ def _ensure_partition_limit(self) -> None:
327331
while len(self._cursor_per_partition) > self.DEFAULT_MAX_PARTITIONS_NUMBER - 1:
328332
# Try removing finished partitions first
329333
for partition_key in list(self._cursor_per_partition.keys()):
330-
if partition_key not in self._seq_by_partition:
334+
if partition_key not in self._partition_key_to_index:
331335
oldest_partition = self._cursor_per_partition.pop(
332336
partition_key
333337
) # Remove the oldest partition
@@ -466,16 +470,16 @@ def _cleanup_if_done(self, partition_key: str) -> None:
466470
cursor, semaphore, flag inside `_partitions_done_generating_stream_slices`
467471
"""
468472
if not (
469-
partition_key in self._finished_partitions
473+
partition_key in self._partitions_done_generating_stream_slices
470474
and self._semaphore_per_partition[partition_key]._value == 0
471475
):
472476
return
473477

474478
self._semaphore_per_partition.pop(partition_key, None)
475-
self._finished_partitions.discard(partition_key)
479+
self._partitions_done_generating_stream_slices.discard(partition_key)
476480

477-
seq = self._seq_by_partition.pop(partition_key)
478-
self._open_seqs.remove(seq)
481+
seq = self._partition_key_to_index.pop(partition_key)
482+
self._processing_partitions_indexes.remove(seq)
479483

480484
logger.debug(f"Partition {partition_key} fully processed and cleaned up.")
481485

unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py

Lines changed: 138 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3438,9 +3438,9 @@ def test_given_unfinished_first_parent_partition_no_parent_state_update():
34383438
assert mock_cursor_2.stream_slices.call_count == 1 # Called once for each partition
34393439

34403440
assert len(cursor._semaphore_per_partition) == 1
3441-
assert len(cursor._finished_partitions) == 1
3442-
assert len(cursor._open_seqs) == 1
3443-
assert len(cursor._seq_by_partition) == 1
3441+
assert len(cursor._partitions_done_generating_stream_slices) == 1
3442+
assert len(cursor._processing_partitions_indexes) == 1
3443+
assert len(cursor._partition_key_to_index) == 1
34443444

34453445

34463446
def test_given_unfinished_last_parent_partition_with_partial_parent_state_update():
@@ -3526,9 +3526,9 @@ def test_given_unfinished_last_parent_partition_with_partial_parent_state_update
35263526
assert mock_cursor_2.stream_slices.call_count == 1 # Called once for each partition
35273527

35283528
assert len(cursor._semaphore_per_partition) == 1
3529-
assert len(cursor._finished_partitions) == 1
3530-
assert len(cursor._open_seqs) == 1
3531-
assert len(cursor._seq_by_partition) == 1
3529+
assert len(cursor._partitions_done_generating_stream_slices) == 1
3530+
assert len(cursor._processing_partitions_indexes) == 1
3531+
assert len(cursor._partition_key_to_index) == 1
35323532

35333533

35343534
def test_given_all_partitions_finished_when_close_partition_then_final_state_emitted():
@@ -3606,9 +3606,9 @@ def test_given_all_partitions_finished_when_close_partition_then_final_state_emi
36063606

36073607
# Checks that all internal variables are cleaned up
36083608
assert len(cursor._semaphore_per_partition) == 0
3609-
assert len(cursor._finished_partitions) == 0
3610-
assert len(cursor._open_seqs) == 0
3611-
assert len(cursor._seq_by_partition) == 0
3609+
assert len(cursor._partitions_done_generating_stream_slices) == 0
3610+
assert len(cursor._processing_partitions_indexes) == 0
3611+
assert len(cursor._partition_key_to_index) == 0
36123612

36133613

36143614
def test_given_partition_limit_exceeded_when_close_partition_then_switch_to_global_cursor():
@@ -3727,8 +3727,8 @@ def test_semaphore_cleanup():
37273727
# Verify initial state
37283728
assert len(cursor._semaphore_per_partition) == 2
37293729
assert len(cursor._partition_parent_state_map) == 2
3730-
assert len(cursor._open_seqs) == 2
3731-
assert len(cursor._seq_by_partition) == 2
3730+
assert len(cursor._processing_partitions_indexes) == 2
3731+
assert len(cursor._partition_key_to_index) == 2
37323732
assert cursor._partition_parent_state_map['{"id":"1"}'][0] == {"parent": {"state": "state1"}}
37333733
assert cursor._partition_parent_state_map['{"id":"2"}'][0] == {"parent": {"state": "state2"}}
37343734

@@ -3737,10 +3737,10 @@ def test_semaphore_cleanup():
37373737
cursor.close_partition(DeclarativePartition("test_stream", {}, MagicMock(), MagicMock(), s))
37383738

37393739
# Check state after closing partitions
3740-
assert len(cursor._finished_partitions) == 0
3740+
assert len(cursor._partitions_done_generating_stream_slices) == 0
37413741
assert len(cursor._semaphore_per_partition) == 0
3742-
assert len(cursor._open_seqs) == 0
3743-
assert len(cursor._seq_by_partition) == 0
3742+
assert len(cursor._processing_partitions_indexes) == 0
3743+
assert len(cursor._partition_key_to_index) == 0
37443744
assert len(cursor._partition_parent_state_map) == 0 # All parent states should be popped
37453745
assert cursor._parent_state == {"parent": {"state": "state2"}} # Last parent state
37463746

@@ -3788,3 +3788,127 @@ def test_given_global_state_when_read_then_state_is_not_per_partition() -> None:
37883788
"use_global_cursor": True, # ensures that it is running the Concurrent CDK version as this is not populated in the declarative implementation
37893789
}, # this state does have per partition which would be under `states`
37903790
)
3791+
3792+
3793+
def _make_inner_cursor(ts: str) -> MagicMock:
3794+
"""Return an inner cursor that yields exactly one slice and has a proper state."""
3795+
inner = MagicMock()
3796+
inner.stream_slices.return_value = iter([{"dummy": "slice"}])
3797+
inner.state = {"updated_at": ts}
3798+
inner.close_partition.return_value = None
3799+
inner.observe.return_value = None
3800+
return inner
3801+
3802+
3803+
def test_duplicate_partition_after_cleanup():
    """Partition key "1" shows up again after its first occurrence was closed.

    The router yields partitions "1", "2", "1". Every emitted slice is closed
    before the next one is requested; afterwards the duplication flag must
    still be unset and the per-partition tracking structures must be empty.
    """
    timestamps = (
        "2024-01-01T00:00:00Z",  # for first "1"
        "2024-01-02T00:00:00Z",  # for "2"
        "2024-01-03T00:00:00Z",  # for second "1"
    )
    cursor_factory_mock = MagicMock()
    cursor_factory_mock.create.side_effect = [_make_inner_cursor(ts) for ts in timestamps]

    state_converter = CustomFormatConcurrentStreamStateConverter(
        datetime_format="%Y-%m-%dT%H:%M:%SZ",
        input_datetime_formats=["%Y-%m-%dT%H:%M:%SZ"],
        is_sequential_state=True,
        cursor_granularity=timedelta(0),
    )

    cursor = ConcurrentPerPartitionCursor(
        cursor_factory=cursor_factory_mock,
        partition_router=MagicMock(),
        stream_name="dup_stream",
        stream_namespace=None,
        stream_state={},
        message_repository=MagicMock(),
        connector_state_manager=MagicMock(),
        connector_state_converter=state_converter,
        cursor_field=CursorField(cursor_field_key="updated_at"),
    )

    cursor.DEFAULT_MAX_PARTITIONS_NUMBER = 1

    # Partition sequence: 1 -> 2 -> 1
    partition_ids = ("1", "2", "1")
    router = cursor._partition_router
    router.stream_slices.return_value = iter(
        [
            StreamSlice(partition={"id": partition_id}, cursor_slice={}, extra_fields={})
            for partition_id in partition_ids
        ]
    )
    router.get_stream_state.return_value = {}

    # Pull slices one at a time so that each emitted slice is closed before
    # the next partition is requested from the generator.
    slice_gen = cursor.stream_slices()
    for _ in partition_ids:
        emitted = next(slice_gen)
        cursor.close_partition(
            DeclarativePartition("dup_stream", {}, MagicMock(), MagicMock(), emitted)
        )

    assert cursor._IS_PARTITION_DUPLICATION_LOGGED is False  # No duplicate detected
    assert len(cursor._semaphore_per_partition) == 0
    assert len(cursor._processing_partitions_indexes) == 0
    assert len(cursor._partition_key_to_index) == 0
3864+
3865+
3866+
def test_duplicate_partition_while_processing():
    """Partition key "1" shows up again while its first occurrence is still open.

    All slices are generated before any partition is closed. The test expects
    only two slices to be emitted for the sequence "1", "2", "1", the
    duplication flag to be set, two per-partition cursors to exist, and the
    tracking structures to be empty once both partitions are closed.
    """
    timestamps = (
        "2024-01-01T00:00:00Z",  # first "1"
        "2024-01-02T00:00:00Z",  # "2"
        "2024-01-03T00:00:00Z",  # for second "1"
    )
    factory = MagicMock()
    factory.create.side_effect = [_make_inner_cursor(ts) for ts in timestamps]

    cursor = ConcurrentPerPartitionCursor(
        cursor_factory=factory,
        partition_router=MagicMock(),
        stream_name="dup_stream",
        stream_namespace=None,
        stream_state={},
        message_repository=MagicMock(),
        connector_state_manager=MagicMock(),
        connector_state_converter=MagicMock(),
        cursor_field=CursorField(cursor_field_key="updated_at"),
    )

    router = cursor._partition_router
    router.stream_slices.return_value = iter(
        [
            StreamSlice(partition={"id": partition_id}, cursor_slice={}, extra_fields={})
            for partition_id in ("1", "2", "1")
        ]
    )
    router.get_stream_state.return_value = {}

    generated = list(cursor.stream_slices())
    # Only "1" and "2" emitted - duplicate "1" skipped
    assert len(generated) == 2

    # Close "2" first, then the initial "1"
    for emitted in reversed(generated):
        cursor.close_partition(
            DeclarativePartition("dup_stream", {}, MagicMock(), MagicMock(), emitted)
        )

    assert cursor._IS_PARTITION_DUPLICATION_LOGGED is True  # warning emitted
    assert len(cursor._cursor_per_partition) == 2
    assert len(cursor._semaphore_per_partition) == 0
    assert len(cursor._processing_partitions_indexes) == 0
    assert len(cursor._partition_key_to_index) == 0

0 commit comments

Comments
 (0)