Skip to content

Commit 35290eb

Browse files
committed
[ISSUE #10552] move stream slicer concept in concurrent CDK
1 parent 4aaf1e7 commit 35290eb

File tree

12 files changed

+384
-264
lines changed

12 files changed

+384
-264
lines changed

airbyte_cdk/sources/declarative/concurrent_declarative_source.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,15 @@
3030
)
3131
from airbyte_cdk.sources.declarative.requesters import HttpRequester
3232
from airbyte_cdk.sources.declarative.retrievers import SimpleRetriever
33+
from airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator import (
34+
DeclarativePartitionFactory,
35+
StreamSlicerPartitionGenerator,
36+
)
3337
from airbyte_cdk.sources.declarative.transformations.add_fields import AddFields
3438
from airbyte_cdk.sources.declarative.types import ConnectionDefinition
3539
from airbyte_cdk.sources.source import TState
3640
from airbyte_cdk.sources.streams import Stream
3741
from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
38-
from airbyte_cdk.sources.streams.concurrent.adapters import CursorPartitionGenerator
3942
from airbyte_cdk.sources.streams.concurrent.availability_strategy import (
4043
AlwaysAvailableAvailabilityStrategy,
4144
)
@@ -228,13 +231,15 @@ def _group_streams(
228231
)
229232
declarative_stream.retriever.cursor = None
230233

231-
partition_generator = CursorPartitionGenerator(
232-
stream=declarative_stream,
233-
message_repository=self.message_repository, # type: ignore # message_repository is always instantiated with a value by factory
234-
cursor=cursor,
235-
connector_state_converter=connector_state_converter,
236-
cursor_field=[cursor.cursor_field.cursor_field_key],
237-
slice_boundary_fields=cursor.slice_boundary_fields,
234+
235+
partition_generator = StreamSlicerPartitionGenerator(
236+
DeclarativePartitionFactory(
237+
declarative_stream.name,
238+
declarative_stream.get_json_schema(),
239+
declarative_stream.retriever,
240+
self.message_repository,
241+
),
242+
cursor,
238243
)
239244

240245
concurrent_streams.append(

airbyte_cdk/sources/declarative/manifest_declarative_source.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def resolved_manifest(self) -> Mapping[str, Any]:
9494
return self._source_config
9595

9696
@property
97-
def message_repository(self) -> Union[None, MessageRepository]:
97+
def message_repository(self) -> MessageRepository:
9898
return self._message_repository
9999

100100
@property
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
2+
3+
from typing import Iterable, Optional, Mapping, Any
4+
5+
from airbyte_cdk.sources.declarative.retrievers import Retriever
6+
from airbyte_cdk.sources.message import MessageRepository
7+
from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
8+
from airbyte_cdk.sources.streams.concurrent.partitions.partition_generator import PartitionGenerator
9+
from airbyte_cdk.sources.streams.concurrent.partitions.record import Record
10+
from airbyte_cdk.sources.streams.concurrent.partitions.stream_slicer import StreamSlicer
11+
from airbyte_cdk.sources.types import StreamSlice
12+
from airbyte_cdk.utils.slice_hasher import SliceHasher
13+
14+
15+
class DeclarativePartitionFactory:
16+
def __init__(self, stream_name: str, json_schema: Mapping[str, Any], retriever: Retriever, message_repository: MessageRepository) -> None:
17+
self._stream_name = stream_name
18+
self._json_schema = json_schema
19+
self._retriever = retriever # FIXME: it should be a retriever_factory here to ensure that paginators and other classes don't share internal/class state
20+
self._message_repository = message_repository
21+
22+
def create(self, stream_slice: StreamSlice) -> Partition:
23+
return DeclarativePartition(
24+
self._stream_name,
25+
self._json_schema,
26+
self._retriever,
27+
self._message_repository,
28+
stream_slice,
29+
)
30+
31+
32+
class DeclarativePartition(Partition):
33+
def __init__(self, stream_name: str, json_schema: Mapping[str, Any], retriever: Retriever, message_repository: MessageRepository, stream_slice: StreamSlice):
34+
self._stream_name = stream_name
35+
self._json_schema = json_schema
36+
self._retriever = retriever
37+
self._message_repository = message_repository
38+
self._stream_slice = stream_slice
39+
self._hash = SliceHasher.hash(self._stream_name, self._stream_slice)
40+
41+
def read(self) -> Iterable[Record]:
42+
for stream_data in self._retriever.read_records(self._json_schema, self._stream_slice):
43+
if isinstance(stream_data, Mapping):
44+
# TODO validate if this is necessary: self._stream.transformer.transform(data_to_return, self._stream.get_json_schema())
45+
yield Record(stream_data, self)
46+
else:
47+
self._message_repository.emit_message(stream_data)
48+
49+
def to_slice(self) -> Optional[Mapping[str, Any]]:
50+
return self._stream_slice
51+
52+
def stream_name(self) -> str:
53+
return self._stream_name
54+
55+
def __hash__(self) -> int:
56+
return self._hash
57+
58+
59+
class StreamSlicerPartitionGenerator(PartitionGenerator):
60+
def __init__(self, partition_factory: DeclarativePartitionFactory, stream_slicer: StreamSlicer) -> None:
61+
self._partition_factory = partition_factory
62+
self._stream_slicer = stream_slicer
63+
64+
def generate(self) -> Iterable[Partition]:
65+
for stream_slice in self._stream_slicer.stream_slices():
66+
yield self._partition_factory.create(stream_slice)

airbyte_cdk/sources/declarative/stream_slicers/stream_slicer.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,15 @@
22
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
33
#
44

5-
from abc import abstractmethod
6-
from dataclasses import dataclass
7-
from typing import Iterable
5+
from abc import ABC
86

97
from airbyte_cdk.sources.declarative.requesters.request_options.request_options_provider import (
108
RequestOptionsProvider,
119
)
12-
from airbyte_cdk.sources.types import StreamSlice
10+
from airbyte_cdk.sources.streams.concurrent.partitions.stream_slicer import StreamSlicer as ConcurrentStreamSlicer
1311

1412

15-
@dataclass
16-
class StreamSlicer(RequestOptionsProvider):
13+
class StreamSlicer(ConcurrentStreamSlicer, RequestOptionsProvider, ABC):
1714
"""
1815
Slices the stream into a subset of records.
1916
Slices enable state checkpointing and data retrieval parallelization.
@@ -22,11 +19,4 @@ class StreamSlicer(RequestOptionsProvider):
2219
2320
See the stream slicing section of the docs for more information.
2421
"""
25-
26-
@abstractmethod
27-
def stream_slices(self) -> Iterable[StreamSlice]:
28-
"""
29-
Defines stream slices
30-
31-
:return: List of stream slices
32-
"""
22+
pass

airbyte_cdk/sources/streams/concurrent/adapters.py

Lines changed: 4 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
from airbyte_cdk.sources.utils.slice_logger import SliceLogger
4848
from deprecated.classic import deprecated
4949

50+
from airbyte_cdk.utils.slice_hasher import SliceHasher
51+
5052
"""
5153
This module contains adapters to help enabling concurrency on Stream objects without needing to migrate to AbstractStream
5254
"""
@@ -270,6 +272,7 @@ def __init__(
270272
self._sync_mode = sync_mode
271273
self._cursor_field = cursor_field
272274
self._state = state
275+
self._hash = SliceHasher.hash(self._stream.name, self._slice)
273276

274277
def read(self) -> Iterable[Record]:
275278
"""
@@ -309,12 +312,7 @@ def to_slice(self) -> Optional[Mapping[str, Any]]:
309312
return self._slice
310313

311314
def __hash__(self) -> int:
312-
if self._slice:
313-
# Convert the slice to a string so that it can be hashed
314-
s = json.dumps(self._slice, sort_keys=True, cls=SliceEncoder)
315-
return hash((self._stream.name, s))
316-
else:
317-
return hash(self._stream.name)
315+
return self._hash
318316

319317
def stream_name(self) -> str:
320318
return self._stream.name
@@ -363,83 +361,6 @@ def generate(self) -> Iterable[Partition]:
363361
)
364362

365363

366-
class CursorPartitionGenerator(PartitionGenerator):
367-
"""
368-
This class generates partitions using the concurrent cursor and iterates through state slices to generate partitions.
369-
370-
It is used when synchronizing a stream in incremental or full-refresh mode where state information is maintained
371-
across partitions. Each partition represents a subset of the stream's data and is determined by the cursor's state.
372-
"""
373-
374-
_START_BOUNDARY = 0
375-
_END_BOUNDARY = 1
376-
377-
def __init__(
378-
self,
379-
stream: Stream,
380-
message_repository: MessageRepository,
381-
cursor: Cursor,
382-
connector_state_converter: DateTimeStreamStateConverter,
383-
cursor_field: Optional[List[str]],
384-
slice_boundary_fields: Optional[Tuple[str, str]],
385-
):
386-
"""
387-
Initialize the CursorPartitionGenerator with a stream, sync mode, and cursor.
388-
389-
:param stream: The stream to delegate to for partition generation.
390-
:param message_repository: The message repository to use to emit non-record messages.
391-
:param sync_mode: The synchronization mode.
392-
:param cursor: A Cursor object that maintains the state and the cursor field.
393-
"""
394-
self._stream = stream
395-
self.message_repository = message_repository
396-
self._sync_mode = SyncMode.full_refresh
397-
self._cursor = cursor
398-
self._cursor_field = cursor_field
399-
self._state = self._cursor.state
400-
self._slice_boundary_fields = slice_boundary_fields
401-
self._connector_state_converter = connector_state_converter
402-
403-
def generate(self) -> Iterable[Partition]:
404-
"""
405-
Generate partitions based on the slices in the cursor's state.
406-
407-
This method iterates through the list of slices found in the cursor's state, and for each slice, it generates
408-
a `StreamPartition` object.
409-
410-
:return: An iterable of StreamPartition objects.
411-
"""
412-
413-
start_boundary = (
414-
self._slice_boundary_fields[self._START_BOUNDARY]
415-
if self._slice_boundary_fields
416-
else "start"
417-
)
418-
end_boundary = (
419-
self._slice_boundary_fields[self._END_BOUNDARY]
420-
if self._slice_boundary_fields
421-
else "end"
422-
)
423-
424-
for slice_start, slice_end in self._cursor.generate_slices():
425-
stream_slice = StreamSlice(
426-
partition={},
427-
cursor_slice={
428-
start_boundary: self._connector_state_converter.output_format(slice_start),
429-
end_boundary: self._connector_state_converter.output_format(slice_end),
430-
},
431-
)
432-
433-
yield StreamPartition(
434-
self._stream,
435-
copy.deepcopy(stream_slice),
436-
self.message_repository,
437-
self._sync_mode,
438-
self._cursor_field,
439-
self._state,
440-
)
441-
442-
443364
@deprecated(
444365
"Availability strategy has been soft deprecated. Do not use. Class is subject to removal",
445366
category=ExperimentalClassWarning,

airbyte_cdk/sources/streams/concurrent/cursor.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@
1111
from airbyte_cdk.sources.streams import NO_CURSOR_STATE_KEY
1212
from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
1313
from airbyte_cdk.sources.streams.concurrent.partitions.record import Record
14+
from airbyte_cdk.sources.streams.concurrent.partitions.stream_slicer import StreamSlicer
1415
from airbyte_cdk.sources.streams.concurrent.state_converters.abstract_stream_state_converter import (
1516
AbstractStreamStateConverter,
1617
)
18+
from airbyte_cdk.sources.types import StreamSlice
1719

1820

1921
def _extract_value(mapping: Mapping[str, Any], path: List[str]) -> Any:
@@ -61,7 +63,7 @@ def extract_value(self, record: Record) -> CursorValueType:
6163
return cursor_value # type: ignore # we assume that the value the path points at is a comparable
6264

6365

64-
class Cursor(ABC):
66+
class Cursor(StreamSlicer, ABC):
6567
@property
6668
@abstractmethod
6769
def state(self) -> MutableMapping[str, Any]: ...
@@ -88,12 +90,12 @@ def ensure_at_least_one_state_emitted(self) -> None:
8890
"""
8991
raise NotImplementedError()
9092

91-
def generate_slices(self) -> Iterable[Tuple[Any, Any]]:
93+
def stream_slices(self) -> Iterable[StreamSlice]:
9294
"""
9395
Default placeholder implementation of stream_slices.
9496
Subclasses can override this method to provide actual behavior.
9597
"""
96-
yield from ()
98+
yield StreamSlice(partition={}, cursor_slice={})
9799

98100

99101
class FinalStateCursor(Cursor):
@@ -184,8 +186,8 @@ def cursor_field(self) -> CursorField:
184186
return self._cursor_field
185187

186188
@property
187-
def slice_boundary_fields(self) -> Optional[Tuple[str, str]]:
188-
return self._slice_boundary_fields
189+
def _slice_boundary_fields_wrapper(self) -> Tuple[str, str]:
190+
return self._slice_boundary_fields if self._slice_boundary_fields else (self._connector_state_converter.START_KEY, self._connector_state_converter.END_KEY)
189191

190192
def _get_concurrent_state(
191193
self, state: MutableMapping[str, Any]
@@ -299,7 +301,7 @@ def ensure_at_least_one_state_emitted(self) -> None:
299301
"""
300302
self._emit_state_message()
301303

302-
def generate_slices(self) -> Iterable[Tuple[CursorValueType, CursorValueType]]:
304+
def stream_slices(self) -> Iterable[StreamSlice]:
303305
"""
304306
Generating slices based on a few parameters:
305307
* lookback_window: Buffer to remove from END_KEY of the highest slice
@@ -368,7 +370,7 @@ def _calculate_lower_boundary_of_last_slice(
368370

369371
def _split_per_slice_range(
370372
self, lower: CursorValueType, upper: CursorValueType, upper_is_end: bool
371-
) -> Iterable[Tuple[CursorValueType, CursorValueType]]:
373+
) -> Iterable[StreamSlice]:
372374
if lower >= upper:
373375
return
374376

@@ -377,10 +379,14 @@ def _split_per_slice_range(
377379

378380
lower = max(lower, self._start) if self._start else lower
379381
if not self._slice_range or self._evaluate_upper_safely(lower, self._slice_range) >= upper:
380-
if self._cursor_granularity and not upper_is_end:
381-
yield lower, upper - self._cursor_granularity
382-
else:
383-
yield lower, upper
382+
start_value, end_value = (lower, upper - self._cursor_granularity) if self._cursor_granularity and not upper_is_end else (lower, upper)
383+
yield StreamSlice(
384+
partition={},
385+
cursor_slice={
386+
self._slice_boundary_fields_wrapper[self._START_BOUNDARY]: self._connector_state_converter.output_format(start_value),
387+
self._slice_boundary_fields_wrapper[self._END_BOUNDARY]: self._connector_state_converter.output_format(end_value)
388+
}
389+
)
384390
else:
385391
stop_processing = False
386392
current_lower_boundary = lower
@@ -389,12 +395,15 @@ def _split_per_slice_range(
389395
self._evaluate_upper_safely(current_lower_boundary, self._slice_range), upper
390396
)
391397
has_reached_upper_boundary = current_upper_boundary >= upper
392-
if self._cursor_granularity and (
393-
not upper_is_end or not has_reached_upper_boundary
394-
):
395-
yield current_lower_boundary, current_upper_boundary - self._cursor_granularity
396-
else:
397-
yield current_lower_boundary, current_upper_boundary
398+
399+
start_value, end_value = (current_lower_boundary, current_upper_boundary - self._cursor_granularity) if self._cursor_granularity and (not upper_is_end or not has_reached_upper_boundary) else (current_lower_boundary, current_upper_boundary)
400+
yield StreamSlice(
401+
partition={},
402+
cursor_slice={
403+
self._slice_boundary_fields_wrapper[self._START_BOUNDARY]: self._connector_state_converter.output_format(start_value),
404+
self._slice_boundary_fields_wrapper[self._END_BOUNDARY]: self._connector_state_converter.output_format(end_value)
405+
}
406+
)
398407
current_lower_boundary = current_upper_boundary
399408
if current_upper_boundary >= upper:
400409
stop_processing = True
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
2+
3+
from abc import ABC, abstractmethod
4+
from typing import Iterable
5+
6+
from airbyte_cdk.sources.types import StreamSlice
7+
8+
9+
class StreamSlicer(ABC):
10+
"""
11+
Slices the stream into chunks that can be fetched independently. Slices enable state checkpointing and data retrieval parallelization.
12+
"""
13+
14+
@abstractmethod
15+
def stream_slices(self) -> Iterable[StreamSlice]:
16+
"""
17+
Defines stream slices
18+
19+
:return: An iterable of stream slices
20+
"""
21+
pass

0 commit comments

Comments
 (0)