Skip to content

Commit 357a925

Browse files
committed
Add global cursor with fallback
1 parent a01c0b5 commit 357a925

File tree

4 files changed

+1086
-14
lines changed

4 files changed

+1086
-14
lines changed

airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
#
44
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
55
#
6+
import threading
67
import logging
78
from collections import OrderedDict
89
from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional
910

1011
from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
12+
from airbyte_cdk.sources.declarative.incremental.global_substream_cursor import iterate_with_last_flag_and_state, Timer
1113
from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor
1214
from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter
1315
from airbyte_cdk.sources.message import MessageRepository
@@ -77,6 +79,15 @@ def __init__(
7779
# The dict is ordered to ensure that once the maximum number of partitions is reached,
7880
# the oldest partitions can be efficiently removed, maintaining the most recent partitions.
7981
self._cursor_per_partition: OrderedDict[str, Cursor] = OrderedDict()
82+
self._state = {"states": []}
83+
self._semaphore_per_partition = OrderedDict()
84+
self._finished_partitions = set()
85+
self._lock = threading.Lock()
86+
self._timer = Timer()
87+
self._global_cursor = None
88+
self._new_global_cursor = None
89+
self._lookback_window = 0
90+
self._parent_state = None
8091
self._over_limit = 0
8192
self._partition_serializer = PerPartitionKeySerializer()
8293

@@ -91,7 +102,7 @@ def state(self) -> MutableMapping[str, Any]:
91102
states = []
92103
for partition_tuple, cursor in self._cursor_per_partition.items():
93104
cursor_state = cursor._connector_state_converter.convert_to_state_message(
94-
cursor._cursor_field, cursor.state
105+
self.cursor_field, cursor.state
95106
)
96107
if cursor_state:
97108
states.append(
@@ -101,16 +112,40 @@ def state(self) -> MutableMapping[str, Any]:
101112
}
102113
)
103114
state: dict[str, Any] = {"states": states}
115+
116+
state["state"] = self._global_cursor
117+
if self._lookback_window is not None:
118+
state["lookback_window"] = self._lookback_window
119+
if self._parent_state is not None:
120+
state["parent_state"] = self._parent_state
121+
print(state)
104122
return state
105123

106124
def close_partition(self, partition: Partition) -> None:
107-
self._cursor_per_partition[self._to_partition_key(partition._stream_slice.partition)].close_partition_without_emit(partition=partition)
125+
print(f"Closing partition {self._to_partition_key(partition._stream_slice.partition)}")
126+
self._cursor_per_partition[self._to_partition_key(partition._stream_slice.partition)].close_partition(partition=partition)
127+
with (self._lock):
128+
self._semaphore_per_partition[self._to_partition_key(partition._stream_slice.partition)].acquire()
129+
cursor = self._cursor_per_partition[self._to_partition_key(partition._stream_slice.partition)]
130+
cursor_state = cursor._connector_state_converter.convert_to_state_message(
131+
cursor._cursor_field, cursor.state
132+
)
133+
print(f"State {cursor_state} {cursor.state}")
134+
if self._to_partition_key(partition._stream_slice.partition) in self._finished_partitions \
135+
and self._semaphore_per_partition[self._to_partition_key(partition._stream_slice.partition)]._value == 0:
136+
if self._new_global_cursor is None \
137+
or self._new_global_cursor[self.cursor_field.cursor_field_key] < cursor_state[self.cursor_field.cursor_field_key]:
138+
self._new_global_cursor = copy.deepcopy(cursor_state)
108139

109140
def ensure_at_least_one_state_emitted(self) -> None:
110141
"""
111142
The platform expect to have at least one state message on successful syncs. Hence, whatever happens, we expect this method to be
112143
called.
113144
"""
145+
if not any(semaphore_item[1]._value for semaphore_item in self._semaphore_per_partition.items()):
146+
self._global_cursor = self._new_global_cursor
147+
self._lookback_window = self._timer.finish()
148+
self._parent_state = self._partition_router.get_stream_state()
114149
self._emit_state_message()
115150

116151
def _emit_state_message(self) -> None:
@@ -127,6 +162,7 @@ def _emit_state_message(self) -> None:
127162

128163
def stream_slices(self) -> Iterable[StreamSlice]:
129164
slices = self._partition_router.stream_slices()
165+
self._timer.start()
130166
for partition in slices:
131167
yield from self.generate_slices_from_partition(partition)
132168

@@ -143,8 +179,15 @@ def generate_slices_from_partition(self, partition: StreamSlice) -> Iterable[Str
143179
)
144180
cursor = self._create_cursor(partition_state)
145181
self._cursor_per_partition[self._to_partition_key(partition.partition)] = cursor
146-
147-
for cursor_slice in cursor.stream_slices():
182+
self._semaphore_per_partition[self._to_partition_key(partition.partition)] = threading.Semaphore(0)
183+
184+
for cursor_slice, is_last_slice, _ in iterate_with_last_flag_and_state(
185+
cursor.stream_slices(),
186+
lambda: None,
187+
):
188+
self._semaphore_per_partition[self._to_partition_key(partition.partition)].release()
189+
if is_last_slice:
190+
self._finished_partitions.add(self._to_partition_key(partition.partition))
148191
yield StreamSlice(
149192
partition=partition, cursor_slice=cursor_slice, extra_fields=partition.extra_fields
150193
)
@@ -208,6 +251,7 @@ def _set_initial_state(self, stream_state: StreamState) -> None:
208251
self._cursor_per_partition[self._to_partition_key(state["partition"])] = (
209252
self._create_cursor(state["cursor"])
210253
)
254+
self._semaphore_per_partition[self._to_partition_key(state["partition"])] = threading.Semaphore(0)
211255

212256
# set default state for missing partitions if it is per partition with fallback to global
213257
if "state" in stream_state:
@@ -217,6 +261,7 @@ def _set_initial_state(self, stream_state: StreamState) -> None:
217261
self._partition_router.set_initial_state(stream_state)
218262

219263
def observe(self, record: Record) -> None:
264+
print(self._to_partition_key(record.associated_slice.partition), record)
220265
self._cursor_per_partition[self._to_partition_key(record.associated_slice.partition)].observe(record)
221266

222267
def _to_partition_key(self, partition: Mapping[str, Any]) -> str:

airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,7 @@
376376
InMemoryMessageRepository,
377377
LogAppenderMessageRepositoryDecorator,
378378
MessageRepository,
379+
NoopMessageRepository,
379380
)
380381
from airbyte_cdk.sources.streams.concurrent.cursor import ConcurrentCursor, CursorField
381382
from airbyte_cdk.sources.streams.concurrent.state_converters.datetime_stream_state_converter import (
@@ -773,6 +774,7 @@ def create_concurrent_cursor_from_datetime_based_cursor(
773774
stream_namespace: Optional[str],
774775
config: Config,
775776
stream_state: MutableMapping[str, Any],
777+
message_repository: Optional[MessageRepository] = None,
776778
**kwargs: Any,
777779
) -> ConcurrentCursor:
778780
component_type = component_definition.get("type")
@@ -908,7 +910,7 @@ def create_concurrent_cursor_from_datetime_based_cursor(
908910
stream_name=stream_name,
909911
stream_namespace=stream_namespace,
910912
stream_state=stream_state,
911-
message_repository=self._message_repository,
913+
message_repository=message_repository or self._message_repository,
912914
connector_state_manager=state_manager,
913915
connector_state_converter=connector_state_converter,
914916
cursor_field=cursor_field,
@@ -961,6 +963,7 @@ def create_concurrent_cursor_from_perpartition_cursor(
961963
stream_name=stream_name,
962964
stream_namespace=stream_namespace,
963965
config=config,
966+
message_repository=NoopMessageRepository()
964967
)
965968
)
966969

airbyte_cdk/sources/streams/concurrent/cursor.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ def _get_concurrent_state(
226226
)
227227

228228
def observe(self, record: Record) -> None:
229+
print(f"Observing record: {record}")
229230
most_recent_cursor_value = self._most_recent_cursor_value_per_partition.get(
230231
record.associated_slice
231232
)
@@ -240,15 +241,6 @@ def observe(self, record: Record) -> None:
240241
def _extract_cursor_value(self, record: Record) -> Any:
241242
return self._connector_state_converter.parse_value(self._cursor_field.extract_value(record))
242243

243-
def close_partition_without_emit(self, partition: Partition) -> None:
244-
slice_count_before = len(self.state.get("slices", []))
245-
self._add_slice_to_state(partition)
246-
if slice_count_before < len(
247-
self.state["slices"]
248-
): # only emit if at least one slice has been processed
249-
self._merge_partitions()
250-
self._has_closed_at_least_one_slice = True
251-
252244
def close_partition(self, partition: Partition) -> None:
253245
slice_count_before = len(self.state.get("slices", []))
254246
self._add_slice_to_state(partition)

0 commit comments

Comments
 (0)