diff --git a/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py b/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py index 1f4a1b81a..5967ec55c 100644 --- a/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py +++ b/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py @@ -114,7 +114,8 @@ def on_partition_complete_sentinel( try: if sentinel.is_successful: - partition.close() + stream = self._stream_name_to_instance[partition.stream_name()] + stream.cursor.close_partition(partition) except Exception as exception: self._flag_exception(partition.stream_name(), exception) yield AirbyteTracedException.from_exception( diff --git a/airbyte_cdk/sources/file_based/stream/concurrent/adapters.py b/airbyte_cdk/sources/file_based/stream/concurrent/adapters.py index fda609ae8..0c5daf06e 100644 --- a/airbyte_cdk/sources/file_based/stream/concurrent/adapters.py +++ b/airbyte_cdk/sources/file_based/stream/concurrent/adapters.py @@ -226,7 +226,6 @@ def __init__( sync_mode: SyncMode, cursor_field: Optional[List[str]], state: Optional[MutableMapping[str, Any]], - cursor: "AbstractConcurrentFileBasedCursor", ): self._stream = stream self._slice = _slice @@ -234,8 +233,6 @@ def __init__( self._sync_mode = sync_mode self._cursor_field = cursor_field self._state = state - self._cursor = cursor - self._is_closed = False def read(self) -> Iterable[Record]: try: @@ -289,13 +286,6 @@ def to_slice(self) -> Optional[Mapping[str, Any]]: file = self._slice["files"][0] return {"files": [file]} - def close(self) -> None: - self._cursor.close_partition(self) - self._is_closed = True - - def is_closed(self) -> bool: - return self._is_closed - def __hash__(self) -> int: if self._slice: # Convert the slice to a string so that it can be hashed @@ -352,7 +342,6 @@ def generate(self) -> Iterable[FileBasedStreamPartition]: self._sync_mode, self._cursor_field, self._state, - self._cursor, ) ) self._cursor.set_pending_partitions(pending_partitions) diff --git a/airbyte_cdk/sources/streams/concurrent/adapters.py b/airbyte_cdk/sources/streams/concurrent/adapters.py index d4b539a52..1df713037 100644 --- a/airbyte_cdk/sources/streams/concurrent/adapters.py +++ b/airbyte_cdk/sources/streams/concurrent/adapters.py @@ -96,7 +96,6 @@ def create_from_stream( else SyncMode.incremental, [cursor_field] if cursor_field is not None else None, state, - cursor, ), name=stream.name, namespace=stream.namespace, @@ -259,7 +258,6 @@ def __init__( sync_mode: SyncMode, cursor_field: Optional[List[str]], state: Optional[MutableMapping[str, Any]], - cursor: Cursor, ): """ :param stream: The stream to delegate to @@ -272,8 +270,6 @@ def __init__( self._sync_mode = sync_mode self._cursor_field = cursor_field self._state = state - self._cursor = cursor - self._is_closed = False def read(self) -> Iterable[Record]: """ @@ -323,13 +319,6 @@ def __hash__(self) -> int: def stream_name(self) -> str: return self._stream.name - def close(self) -> None: - self._cursor.close_partition(self) - self._is_closed = True - - def is_closed(self) -> bool: - return self._is_closed - def __repr__(self) -> str: return f"StreamPartition({self._stream.name}, {self._slice})" @@ -349,7 +338,6 @@ def __init__( sync_mode: SyncMode, cursor_field: Optional[List[str]], state: Optional[MutableMapping[str, Any]], - cursor: Cursor, ): """ :param stream: The stream to delegate to @@ -360,7 +348,6 @@ def __init__( self._sync_mode = sync_mode self._cursor_field = cursor_field self._state = state - self._cursor = cursor def generate(self) -> Iterable[Partition]: for s in self._stream.stream_slices( @@ -373,7 +360,6 @@ def generate(self) -> Iterable[Partition]: self._sync_mode, self._cursor_field, self._state, - self._cursor, ) @@ -451,7 +437,6 @@ def generate(self) -> Iterable[Partition]: self._sync_mode, self._cursor_field, self._state, - self._cursor, ) diff --git a/airbyte_cdk/sources/streams/concurrent/partitions/partition.py b/airbyte_cdk/sources/streams/concurrent/partitions/partition.py index 09f83d8f8..b51baf812 100644 --- a/airbyte_cdk/sources/streams/concurrent/partitions/partition.py +++ b/airbyte_cdk/sources/streams/concurrent/partitions/partition.py @@ -40,21 +40,6 @@ def stream_name(self) -> str: """ pass - @abstractmethod - def close(self) -> None: - """ - Closes the partition. - """ - pass - - @abstractmethod - def is_closed(self) -> bool: - """ - Returns whether the partition is closed. - :return: - """ - pass - @abstractmethod def __hash__(self) -> int: """ diff --git a/unit_tests/sources/file_based/stream/concurrent/test_adapters.py b/unit_tests/sources/file_based/stream/concurrent/test_adapters.py index 3c271dfe4..aea439735 100644 --- a/unit_tests/sources/file_based/stream/concurrent/test_adapters.py +++ b/unit_tests/sources/file_based/stream/concurrent/test_adapters.py @@ -124,7 +124,12 @@ def test_file_based_stream_partition(transformer, expected_records): cursor_field = None state = None partition = FileBasedStreamPartition( - stream, _slice, message_repository, sync_mode, cursor_field, state, _ANY_CURSOR + stream, + _slice, + message_repository, + sync_mode, + cursor_field, + state, ) a_log_message = AirbyteMessage( @@ -168,7 +173,6 @@ def test_file_based_stream_partition_raising_exception(exception_type, expected_ _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE, - _ANY_CURSOR, ) stream.read_records.side_effect = Exception() @@ -204,7 +208,12 @@ def test_file_based_stream_partition_hash(_slice, expected_hash): stream = Mock() stream.name = "stream" partition = FileBasedStreamPartition( - stream, _slice, Mock(), _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE, _ANY_CURSOR + stream, + _slice, + Mock(), + _ANY_SYNC_MODE, + _ANY_CURSOR_FIELD, + _ANY_STATE, ) _hash = partition.__hash__() diff --git a/unit_tests/sources/file_based/stream/concurrent/test_file_based_concurrent_cursor.py b/unit_tests/sources/file_based/stream/concurrent/test_file_based_concurrent_cursor.py index ce48f845f..4f4d0b5fa 100644 --- a/unit_tests/sources/file_based/stream/concurrent/test_file_based_concurrent_cursor.py +++ b/unit_tests/sources/file_based/stream/concurrent/test_file_based_concurrent_cursor.py @@ -214,7 +214,6 @@ def test_add_file( SyncMode.full_refresh, FileBasedConcurrentCursor.CURSOR_FIELD, initial_state, - cursor, ) for uri, timestamp in pending_files ] diff --git a/unit_tests/sources/streams/concurrent/test_adapters.py b/unit_tests/sources/streams/concurrent/test_adapters.py index cbebfe7ce..93e8fd212 100644 --- a/unit_tests/sources/streams/concurrent/test_adapters.py +++ b/unit_tests/sources/streams/concurrent/test_adapters.py @@ -76,7 +76,7 @@ def test_stream_partition_generator(sync_mode): stream.stream_slices.return_value = stream_slices partition_generator = StreamPartitionGenerator( - stream, message_repository, _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE, _ANY_CURSOR + stream, message_repository, _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE ) partitions = list(partition_generator.generate()) @@ -115,9 +115,7 @@ def test_stream_partition(transformer, expected_records): sync_mode = SyncMode.full_refresh cursor_field = None state = None - partition = StreamPartition( - stream, _slice, message_repository, sync_mode, cursor_field, state, _ANY_CURSOR - ) + partition = StreamPartition(stream, _slice, message_repository, sync_mode, cursor_field, state) a_log_message = AirbyteMessage( type=MessageType.LOG, @@ -162,7 +160,6 @@ def test_stream_partition_raising_exception(exception_type, expected_display_mes _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE, - _ANY_CURSOR, ) stream.read_records.side_effect = Exception() @@ -188,7 +185,7 @@ def test_stream_partition_hash(_slice, expected_hash): stream = Mock() stream.name = "stream" partition = StreamPartition( - stream, _slice, Mock(), _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE, _ANY_CURSOR + stream, _slice, Mock(), _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE ) _hash = partition.__hash__() diff --git a/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py b/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py index f6f6ecfba..cf94f8f9f 100644 --- a/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py +++ b/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py @@ -249,7 +249,7 @@ def test_handle_on_partition_complete_sentinel_with_messages_from_repository(sel ] assert messages == expected_messages - partition.close.assert_called_once() + self._stream.cursor.close_partition.assert_called_once() @freezegun.freeze_time("2020-01-01T00:00:00") def test_handle_on_partition_complete_sentinel_yields_status_message_if_the_stream_is_done( @@ -298,14 +298,14 @@ def test_handle_on_partition_complete_sentinel_yields_status_message_if_the_stre ) ] assert messages == expected_messages - self._a_closed_partition.close.assert_called_once() + self._another_stream.cursor.close_partition.assert_called_once() @freezegun.freeze_time("2020-01-01T00:00:00") def test_given_exception_on_partition_complete_sentinel_then_yield_error_trace_message_and_stream_is_incomplete( self, ) -> None: self._a_closed_partition.stream_name.return_value = self._stream.name - self._a_closed_partition.close.side_effect = ValueError + self._stream.cursor.close_partition.side_effect = ValueError handler = ConcurrentReadProcessor( [self._stream], @@ -375,7 +375,7 @@ def test_handle_on_partition_complete_sentinel_yields_no_status_message_if_the_s expected_messages = [] assert messages == expected_messages - partition.close.assert_called_once() + self._stream.cursor.close_partition.assert_called_once() @freezegun.freeze_time("2020-01-01T00:00:00") def test_on_record_no_status_message_no_repository_messge(self): @@ -733,7 +733,7 @@ def test_given_partition_completion_is_not_success_then_do_not_close_partition(s ) ) - assert self._an_open_partition.close.call_count == 0 + assert self._stream.cursor.close_partition.call_count == 0 def test_is_done_is_false_if_there_are_any_instances_to_read_from(self): stream_instances_to_read_from = [self._stream]