diff --git a/airbyte_cdk/connector_builder/connector_builder_handler.py b/airbyte_cdk/connector_builder/connector_builder_handler.py
index 27929dfa2..ebd196976 100644
--- a/airbyte_cdk/connector_builder/connector_builder_handler.py
+++ b/airbyte_cdk/connector_builder/connector_builder_handler.py
@@ -4,7 +4,7 @@
 from dataclasses import asdict, dataclass, field
-from typing import Any, Dict, List, Mapping
+from typing import Any, Dict, List, Mapping, Optional
 
 from airbyte_cdk.connector_builder.test_reader import TestReader
 from airbyte_cdk.models import (
@@ -15,6 +15,14 @@
     Type,
 )
 from airbyte_cdk.models import Type as MessageType
+from airbyte_cdk.sources.declarative.concurrent_declarative_source import (
+    DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE,
+    DEFAULT_MAXIMUM_NUMBER_OF_SLICES,
+    DEFAULT_MAXIMUM_RECORDS,
+    DEFAULT_MAXIMUM_STREAMS,
+    ConcurrentDeclarativeSource,
+    TestLimits,
+)
 from airbyte_cdk.sources.declarative.declarative_source import DeclarativeSource
 from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource
 from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import (
@@ -24,25 +32,12 @@
 from airbyte_cdk.utils.datetime_helpers import ab_datetime_now
 from airbyte_cdk.utils.traced_exception import AirbyteTracedException
 
-DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE = 5
-DEFAULT_MAXIMUM_NUMBER_OF_SLICES = 5
-DEFAULT_MAXIMUM_RECORDS = 100
-DEFAULT_MAXIMUM_STREAMS = 100
-
 MAX_PAGES_PER_SLICE_KEY = "max_pages_per_slice"
 MAX_SLICES_KEY = "max_slices"
 MAX_RECORDS_KEY = "max_records"
 MAX_STREAMS_KEY = "max_streams"
 
-@dataclass
-class TestLimits:
-    max_records: int = field(default=DEFAULT_MAXIMUM_RECORDS)
-    max_pages_per_slice: int = field(default=DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE)
-    max_slices: int = field(default=DEFAULT_MAXIMUM_NUMBER_OF_SLICES)
-    max_streams: int = field(default=DEFAULT_MAXIMUM_STREAMS)
-
-
 def get_limits(config: Mapping[str, Any]) -> TestLimits:
     command_config = config.get("__test_read_config", {})
     max_pages_per_slice = (
@@ -54,19 +49,24 @@ def get_limits(config: Mapping[str, Any]) -> TestLimits:
     return TestLimits(max_records, max_pages_per_slice, max_slices, max_streams)
 
 
-def create_source(config: Mapping[str, Any], limits: TestLimits) -> ManifestDeclarativeSource:
+def _ensure_concurrency_level(manifest: Dict[str, Any]) -> None:
+    # We need to do that to ensure that the state in the StreamReadSlices only contains the changes for one slice
+    # Note that this is below the _LOWEST_SAFE_CONCURRENCY_LEVEL but it is fine in this case because we are limiting the number of slices
+    # being generated which means that the memory usage is limited anyway
+    if "concurrency_level" not in manifest:
+        manifest["concurrency_level"] = {"type": "ConcurrencyLevel"}
+    manifest["concurrency_level"]["default_concurrency"] = 1
+
+def create_source(config: Mapping[str, Any], limits: TestLimits, catalog: Optional[ConfiguredAirbyteCatalog] = None, state: Any = None) -> ManifestDeclarativeSource:
     manifest = config["__injected_declarative_manifest"]
-    return ManifestDeclarativeSource(
+    _ensure_concurrency_level(manifest)
+    return ConcurrentDeclarativeSource(
         config=config,
-        emit_connector_builder_messages=True,
+        catalog=catalog,
+        state=state,
         source_config=manifest,
-        component_factory=ModelToComponentFactory(
-            emit_connector_builder_messages=True,
-            limit_pages_fetched_per_slice=limits.max_pages_per_slice,
-            limit_slices_fetched=limits.max_slices,
-            disable_retries=True,
-            disable_cache=True,
-        ),
+        emit_connector_builder_messages=True,
+        limits=limits,
     )
diff --git a/airbyte_cdk/connector_builder/main.py b/airbyte_cdk/connector_builder/main.py
index ad2d6650f..ae8e9ba55 100644
--- a/airbyte_cdk/connector_builder/main.py
+++ b/airbyte_cdk/connector_builder/main.py
@@ -91,7 +91,7 @@ def handle_connector_builder_request(
 def handle_request(args: List[str]) -> str:
     command, config, catalog, state = get_config_and_catalog_from_args(args)
     limits = get_limits(config)
-    source = create_source(config, limits)
+    source = create_source(config, limits, catalog, state)
     return orjson.dumps(
         AirbyteMessageSerializer.dump(
             handle_connector_builder_request(source, command, config, catalog, state, limits)
diff --git a/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py b/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py
index f57db7e14..fe993c246 100644
--- a/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py
+++ b/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py
@@ -34,7 +34,6 @@ def __init__(
         partition_enqueuer: PartitionEnqueuer,
         thread_pool_manager: ThreadPoolManager,
         logger: logging.Logger,
-        slice_logger: SliceLogger,
         message_repository: MessageRepository,
         partition_reader: PartitionReader,
     ):
@@ -44,7 +43,6 @@ def __init__(
         :param partition_enqueuer: PartitionEnqueuer instance
         :param thread_pool_manager: ThreadPoolManager instance
         :param logger: Logger instance
-        :param slice_logger: SliceLogger instance
         :param message_repository: MessageRepository instance
         :param partition_reader: PartitionReader instance
         """
@@ -59,7 +57,6 @@ def __init__(
         self._stream_instances_to_start_partition_generation = stream_instances_to_read_from
         self._streams_currently_generating_partitions: List[str] = []
         self._logger = logger
-        self._slice_logger = slice_logger
         self._message_repository = message_repository
         self._partition_reader = partition_reader
         self._streams_done: Set[str] = set()
@@ -95,11 +92,7 @@ def on_partition(self, partition: Partition) -> None:
         """
         stream_name = partition.stream_name()
         self._streams_to_running_partitions[stream_name].add(partition)
-        if self._slice_logger.should_log_slice_message(self._logger):
-            self._message_repository.emit_message(
-                self._slice_logger.create_slice_log_message(partition.to_slice())
-            )
-        self._thread_pool_manager.submit(self._partition_reader.process_partition, partition)
+        self._thread_pool_manager.submit(self._partition_reader.process_partition, partition, self._stream_name_to_instance[partition.stream_name()].cursor)
 
     def on_partition_complete_sentinel(
         self, sentinel: PartitionCompleteSentinel
@@ -112,26 +105,19 @@ def on_partition_complete_sentinel(
         """
         partition = sentinel.partition
 
-        try:
-            if sentinel.is_successful:
-                stream = self._stream_name_to_instance[partition.stream_name()]
-                stream.cursor.close_partition(partition)
-        except Exception as exception:
-            self._flag_exception(partition.stream_name(), exception)
-            yield AirbyteTracedException.from_exception(
-                exception, stream_descriptor=StreamDescriptor(name=partition.stream_name())
-            ).as_sanitized_airbyte_message()
-        finally:
-            partitions_running = self._streams_to_running_partitions[partition.stream_name()]
-            if partition in partitions_running:
-                partitions_running.remove(partition)
-                # If all partitions were generated and this was the last one, the stream is done
-                if (
-                    partition.stream_name() not in self._streams_currently_generating_partitions
-                    and len(partitions_running) == 0
-                ):
-                    yield from self._on_stream_is_done(partition.stream_name())
-            yield from self._message_repository.consume_queue()
+        if sentinel.is_successful:
+            stream = self._stream_name_to_instance[partition.stream_name()]
+
+        partitions_running = self._streams_to_running_partitions[partition.stream_name()]
+        if partition in partitions_running:
+            partitions_running.remove(partition)
+            # If all partitions were generated and this was the last one, the stream is done
+            if (
+                partition.stream_name() not in self._streams_currently_generating_partitions
+                and len(partitions_running) == 0
+            ):
+                yield from self._on_stream_is_done(partition.stream_name())
+        yield from self._message_repository.consume_queue()
 
     def on_record(self, record: Record) -> Iterable[AirbyteMessage]:
         """
@@ -160,7 +146,6 @@ def on_record(self, record: Record) -> Iterable[AirbyteMessage]:
                 stream.as_airbyte_stream(), AirbyteStreamStatus.RUNNING
             )
         self._record_counter[stream.name] += 1
-        stream.cursor.observe(record)
         yield message
         yield from self._message_repository.consume_queue()
diff --git a/airbyte_cdk/sources/concurrent_source/concurrent_source.py b/airbyte_cdk/sources/concurrent_source/concurrent_source.py
index bc7d97cdd..d119858a5 100644
--- a/airbyte_cdk/sources/concurrent_source/concurrent_source.py
+++ b/airbyte_cdk/sources/concurrent_source/concurrent_source.py
@@ -4,7 +4,7 @@
 import concurrent
 import logging
 from queue import Queue
-from typing import Iterable, Iterator, List
+from typing import Iterable, Iterator, List, Optional
 
 from airbyte_cdk.models import AirbyteMessage
 from airbyte_cdk.sources.concurrent_source.concurrent_read_processor import ConcurrentReadProcessor
@@ -16,7 +16,7 @@
 from airbyte_cdk.sources.message import InMemoryMessageRepository, MessageRepository
 from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
 from airbyte_cdk.sources.streams.concurrent.partition_enqueuer import PartitionEnqueuer
-from airbyte_cdk.sources.streams.concurrent.partition_reader import PartitionReader
+from airbyte_cdk.sources.streams.concurrent.partition_reader import PartitionLogger, PartitionReader
 from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
 from airbyte_cdk.sources.streams.concurrent.partitions.types import (
     PartitionCompleteSentinel,
@@ -44,6 +44,7 @@ def create(
         slice_logger: SliceLogger,
         message_repository: MessageRepository,
         timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
+        queue: Optional[Queue[QueueItem]] = None
     ) -> "ConcurrentSource":
         is_single_threaded = initial_number_of_partitions_to_generate == 1 and num_workers == 1
         too_many_generator = (
@@ -65,6 +66,7 @@ def create(
             message_repository,
             initial_number_of_partitions_to_generate,
             timeout_seconds,
+            queue,
         )
 
     def __init__(
@@ -75,6 +77,7 @@ def __init__(
         message_repository: MessageRepository = InMemoryMessageRepository(),
         initial_number_partitions_to_generate: int = 1,
         timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
+        queue: Optional[Queue[QueueItem]] = None,
     ) -> None:
         """
         :param threadpool: The threadpool to submit tasks to
@@ -90,6 +93,7 @@ def __init__(
         self._message_repository = message_repository
         self._initial_number_partitions_to_generate = initial_number_partitions_to_generate
         self._timeout_seconds = timeout_seconds
+        self._queue = queue if queue else Queue(maxsize=10_000)
 
     def read(
         self,
@@ -101,15 +105,13 @@ def read(
         # threads generating partitions that than are max number of workers. If it weren't the case, we could have threads only generating
         # partitions which would fill the queue. This number is arbitrarily set to 10_000 but will probably need to be changed given more
         # information and might even need to be configurable depending on the source
-        queue: Queue[QueueItem] = Queue(maxsize=10_000)
         concurrent_stream_processor = ConcurrentReadProcessor(
             streams,
-            PartitionEnqueuer(queue, self._threadpool),
+            PartitionEnqueuer(self._queue, self._threadpool),
             self._threadpool,
             self._logger,
-            self._slice_logger,
             self._message_repository,
-            PartitionReader(queue),
+            PartitionReader(self._queue, PartitionLogger(self._slice_logger, self._logger, self._message_repository)),
         )
 
         # Enqueue initial partition generation tasks
@@ -117,7 +119,7 @@ def read(
 
         # Read from the queue until all partitions were generated and read
         yield from self._consume_from_queue(
-            queue,
+            self._queue,
             concurrent_stream_processor,
         )
         self._threadpool.check_for_errors_and_shutdown()
@@ -161,5 +163,7 @@ def _handle_item(
             yield from concurrent_stream_processor.on_partition_complete_sentinel(queue_item)
         elif isinstance(queue_item, Record):
             yield from concurrent_stream_processor.on_record(queue_item)
+        elif isinstance(queue_item, AirbyteMessage):
+            yield queue_item
         else:
             raise ValueError(f"Unknown queue item type: {type(queue_item)}")
diff --git a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py
index e212b0f2a..5dc20b447 100644
--- a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py
+++ b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py
@@ -3,7 +3,22 @@
 #
 
 import logging
-from typing import Any, Generic, Iterator, List, Mapping, MutableMapping, Optional, Tuple
+from dataclasses import dataclass, field
+from queue import Queue
+from typing import (
+    Any,
+    Callable,
+    Generic,
+    Iterable,
+    Iterator,
+    List,
+    Mapping,
+    MutableMapping,
+    Optional,
+    Tuple,
+)
+
+from airbyte_protocol_dataclasses.models import Level
 
 from airbyte_cdk.models import (
     AirbyteCatalog,
@@ -43,7 +58,9 @@
     DeclarativePartitionFactory,
     StreamSlicerPartitionGenerator,
 )
+from airbyte_cdk.sources.declarative.stream_slicers.stream_slicer import TestReadSlicerDecorator
 from airbyte_cdk.sources.declarative.types import ConnectionDefinition
+from airbyte_cdk.sources.message import InMemoryMessageRepository, LogMessage, MessageRepository
 from airbyte_cdk.sources.source import TState
 from airbyte_cdk.sources.streams import Stream
 from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
@@ -54,6 +71,45 @@
 from airbyte_cdk.sources.streams.concurrent.cursor import ConcurrentCursor, FinalStateCursor
 from airbyte_cdk.sources.streams.concurrent.default_stream import DefaultStream
 from airbyte_cdk.sources.streams.concurrent.helpers import get_primary_key_from_stream
+from airbyte_cdk.sources.streams.concurrent.partitions.types import QueueItem
+
+DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE = 5
+DEFAULT_MAXIMUM_NUMBER_OF_SLICES = 5
+DEFAULT_MAXIMUM_RECORDS = 100
+DEFAULT_MAXIMUM_STREAMS = 100
+
+
+@dataclass
+class TestLimits:
+    max_records: int = field(default=DEFAULT_MAXIMUM_RECORDS)
+    max_pages_per_slice: int = field(default=DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE)
+    max_slices: int = field(default=DEFAULT_MAXIMUM_NUMBER_OF_SLICES)
+    max_streams: int = field(default=DEFAULT_MAXIMUM_STREAMS)
+
+
+class ConcurrentMessageRepository(MessageRepository):
+
+    def __init__(self, queue: Queue[QueueItem], message_repository: MessageRepository):
+        self._queue = queue
+        self._decorated = message_repository
+
+    def emit_message(self, message: AirbyteMessage) -> None:
+        self._decorated.emit_message(message)
+        for message in self._decorated.consume_queue():
+            self._queue.put(message)
+
+    def log_message(self, level: Level, message_provider: Callable[[], LogMessage]) -> None:
+        self._decorated.log_message(level, message_provider)
+        for message in self._decorated.consume_queue():
+            self._queue.put(message)
+
+    def consume_queue(self) -> Iterable[AirbyteMessage]:
+        """
+        The consumption of messages from the ConcurrentMessageRepository is done through the queue passed in parameters.
+
+        TODO to confirm but it seems like yielding from self._queue.get() could cause locking issues as anyone can consume from the queue today, not just the main thread
+        """
+        yield from []
 
 
 class ConcurrentDeclarativeSource(ManifestDeclarativeSource, Generic[TState]):
@@ -69,7 +125,7 @@ def __init__(
         source_config: ConnectionDefinition,
         debug: bool = False,
         emit_connector_builder_messages: bool = False,
-        component_factory: Optional[ModelToComponentFactory] = None,
+        limits: Optional[TestLimits] = None,
         **kwargs: Any,
     ) -> None:
         # todo: We could remove state from initialization. Now that streams are grouped during the read(), a source
@@ -80,10 +136,16 @@ def __init__(
         #  cursors. We do this by no longer automatically instantiating RFR cursors when converting
         #  the declarative models into runtime components. Concurrent sources will continue to checkpoint
         #  incremental streams running in full refresh.
-        component_factory = component_factory or ModelToComponentFactory(
+        queue: Queue[QueueItem] = Queue(maxsize=10_000)
+        component_factory = ModelToComponentFactory(
             emit_connector_builder_messages=emit_connector_builder_messages,
             disable_resumable_full_refresh=True,
             connector_state_manager=self._connector_state_manager,
+            message_repository=ConcurrentMessageRepository(queue, InMemoryMessageRepository(Level.DEBUG if emit_connector_builder_messages else Level.INFO)),
+            limit_pages_fetched_per_slice = limits.max_pages_per_slice if limits else None,
+            limit_slices_fetched = limits.max_slices if limits else None,
+            disable_retries = True if limits else False,
+            disable_cache = True if limits else False,
         )
 
         super().__init__(
@@ -120,6 +182,7 @@ def __init__(
             logger=self.logger,
             slice_logger=self._slice_logger,
             message_repository=self.message_repository,
+            queue=queue,
         )
 
         # TODO: Remove this. This property is necessary to safely migrate Stripe during the transition state.
@@ -360,16 +423,19 @@ def _group_streams(
                     and incremental_sync_component_definition.get("type", "")
                     == DatetimeBasedCursorModel.__name__
                     and hasattr(declarative_stream.retriever, "stream_slicer")
-                    and isinstance(
-                        declarative_stream.retriever.stream_slicer, PerPartitionWithGlobalCursor
+                    and (
+                        isinstance(
+                            declarative_stream.retriever.stream_slicer, PerPartitionWithGlobalCursor
+                        ) or
+                        hasattr(declarative_stream.retriever.stream_slicer, "_decorated") and isinstance(
+                            declarative_stream.retriever.stream_slicer._decorated, PerPartitionWithGlobalCursor)
                     )
-                ):
+                ):
                     stream_state = self._connector_state_manager.get_stream_state(
                         stream_name=declarative_stream.name, namespace=declarative_stream.namespace
                     )
                     stream_state = self._migrate_state(declarative_stream, stream_state)
-
-                    partition_router = declarative_stream.retriever.stream_slicer._partition_router
+                    partition_router = declarative_stream.retriever.stream_slicer._decorated._partition_router if hasattr(declarative_stream.retriever.stream_slicer, "_decorated") else declarative_stream.retriever.stream_slicer._partition_router
 
                     perpartition_cursor = (
                         self._constructor.create_concurrent_cursor_from_perpartition_cursor(
@@ -386,6 +452,9 @@ def _group_streams(
 
                     retriever = self._get_retriever(declarative_stream, stream_state)
 
+                    from airbyte_cdk.sources.declarative.stream_slicers.stream_slicer import (
+                        TestReadSlicerDecorator,
+                    )
                     partition_generator = StreamSlicerPartitionGenerator(
                         DeclarativePartitionFactory(
                             declarative_stream.name,
@@ -393,7 +462,7 @@ def _group_streams(
                             retriever,
                             self.message_repository,
                         ),
-                        perpartition_cursor,
+                        TestReadSlicerDecorator(perpartition_cursor, declarative_stream.retriever.stream_slicer._maximum_number_of_slices) if hasattr(declarative_stream.retriever.stream_slicer, "_decorated") else perpartition_cursor,
                     )
 
                     concurrent_streams.append(
diff --git a/airbyte_cdk/sources/declarative/interpolation/macros.py b/airbyte_cdk/sources/declarative/interpolation/macros.py
index dc77744dc..e754977fb 100644
--- a/airbyte_cdk/sources/declarative/interpolation/macros.py
+++ b/airbyte_cdk/sources/declarative/interpolation/macros.py
@@ -177,6 +177,7 @@ def format_datetime(
     dt_datetime = (
         datetime.datetime.strptime(dt, input_format) if input_format else str_to_datetime(dt)
     )
+    dt_datetime = dt_datetime.replace(tzinfo=pytz.utc)
     return DatetimeParser().format(dt=dt_datetime, format=format)
diff --git a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py
index 4f4638190..0c2ca1f9e 100644
--- a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py
+++ b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py
@@ -106,7 +106,6 @@
 )
 from airbyte_cdk.sources.declarative.models import (
     CustomStateMigration,
-    GzipDecoder,
 )
 from airbyte_cdk.sources.declarative.models.declarative_component_schema import (
     AddedFieldDefinition as AddedFieldDefinitionModel,
@@ -389,10 +388,6 @@
 from airbyte_cdk.sources.declarative.models.declarative_component_schema import (
     ZipfileDecoder as ZipfileDecoderModel,
 )
-from airbyte_cdk.sources.declarative.parsers.custom_code_compiler import (
-    COMPONENTS_MODULE_NAME,
-    SDM_COMPONENTS_MODULE_NAME,
-)
 from airbyte_cdk.sources.declarative.partition_routers import (
     CartesianProductStreamSlicer,
     GroupingPartitionRouter,
@@ -464,6 +459,7 @@
 )
 from airbyte_cdk.sources.declarative.spec import Spec
 from airbyte_cdk.sources.declarative.stream_slicers import StreamSlicer
+from airbyte_cdk.sources.declarative.stream_slicers.stream_slicer import TestReadSlicerDecorator
 from airbyte_cdk.sources.declarative.transformations import (
     AddFields,
     RecordTransformation,
@@ -518,7 +514,7 @@
     IncrementingCountStreamStateConverter,
 )
 from airbyte_cdk.sources.streams.http.error_handlers.response_models import ResponseAction
-from airbyte_cdk.sources.types import Config
+from airbyte_cdk.sources.types import Config, StreamSlice, StreamState
 from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer
 
 ComponentDefinition = Mapping[str, Any]
@@ -2845,6 +2841,8 @@ def create_simple_retriever(
         )
 
         if self._limit_slices_fetched or self._emit_connector_builder_messages:
+            slice_limit = self._limit_slices_fetched or 5
+            stream_slicer = TestReadSlicerDecorator(stream_slicer, slice_limit)  # FIXME Once log formatter is removed, we can just pass this to the SimpleRetriever
             return SimpleRetrieverTestReadDecorator(
                 name=name,
                 paginator=paginator,
@@ -2855,7 +2853,6 @@ def create_simple_retriever(
                 request_option_provider=request_options_provider,
                 cursor=cursor,
                 config=config,
-                maximum_number_of_slices=self._limit_slices_fetched or 5,
                 ignore_stream_slicer_parameters_on_paginated_requests=ignore_stream_slicer_parameters_on_paginated_requests,
                 parameters=model.parameters or {},
             )
diff --git a/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py b/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py
index 65aa5d406..1d86199ae 100644
--- a/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py
+++ b/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py
@@ -584,10 +584,6 @@ def __post_init__(self, options: Mapping[str, Any]) -> None:
                 f"The maximum number of slices on a test read needs to be strictly positive. Got {self.maximum_number_of_slices}"
             )
 
-    # stream_slices is defined with arguments on http stream and fixing this has a long tail of dependencies. Will be resolved by the decoupling of http stream and simple retriever
-    def stream_slices(self) -> Iterable[Optional[StreamSlice]]:  # type: ignore
-        return islice(super().stream_slices(), self.maximum_number_of_slices)
-
     def _fetch_next_page(
         self,
         stream_state: Mapping[str, Any],
@@ -623,6 +619,7 @@ def _fetch_next_page(
                 stream_slice=stream_slice,
                 next_page_token=next_page_token,
             ),
+            # FIXME remove this implementation and have the log_formatter depend on the fact that the logger is debug or not
             log_formatter=lambda response: format_http_message(
                 response,
                 f"Stream '{self.name}' request",
diff --git a/airbyte_cdk/sources/declarative/schema/default_schema_loader.py b/airbyte_cdk/sources/declarative/schema/default_schema_loader.py
index a9b625e7d..f860d60fb 100644
--- a/airbyte_cdk/sources/declarative/schema/default_schema_loader.py
+++ b/airbyte_cdk/sources/declarative/schema/default_schema_loader.py
@@ -37,7 +37,7 @@ def get_json_schema(self) -> Mapping[str, Any]:
 
         try:
             return self.default_loader.get_json_schema()
-        except OSError:
+        except (OSError, ValueError):
            # A slight hack since we don't directly have the stream name. However, when building the default filepath we assume the
            # runtime options stores stream name 'name' so we'll do the same here
            stream_name = self._parameters.get("name", "")
diff --git a/airbyte_cdk/sources/declarative/stream_slicers/stream_slicer.py b/airbyte_cdk/sources/declarative/stream_slicers/stream_slicer.py
index db15496ff..f27743578 100644
--- a/airbyte_cdk/sources/declarative/stream_slicers/stream_slicer.py
+++ b/airbyte_cdk/sources/declarative/stream_slicers/stream_slicer.py
@@ -3,6 +3,8 @@
 #
 
 from abc import ABC
+from itertools import islice
+from typing import Any, Iterable, Mapping, Optional, Union
 
 from airbyte_cdk.sources.declarative.requesters.request_options.request_options_provider import (
     RequestOptionsProvider,
@@ -10,6 +12,7 @@
 from airbyte_cdk.sources.streams.concurrent.partitions.stream_slicer import (
     StreamSlicer as ConcurrentStreamSlicer,
 )
+from airbyte_cdk.sources.types import StreamSlice, StreamState
 
 
 class StreamSlicer(ConcurrentStreamSlicer, RequestOptionsProvider, ABC):
@@ -23,3 +26,48 @@ class StreamSlicer(ConcurrentStreamSlicer, RequestOptionsProvider, ABC):
     """
 
     pass
+
+
+class TestReadSlicerDecorator(StreamSlicer):
+    """
+    A stream slicer wrapper for test reads which limits the number of slices produced.
+    """
+
+    def __init__(self, stream_slicer: StreamSlicer, maximum_number_of_slices: int) -> None:
+        self._decorated = stream_slicer
+        self._maximum_number_of_slices = maximum_number_of_slices
+
+    def stream_slices(self) -> Iterable[StreamSlice]:
+        return islice(self._decorated.stream_slices(), self._maximum_number_of_slices)
+
+    def get_request_params(self, *, stream_state: Optional[StreamState] = None, stream_slice: Optional[StreamSlice] = None,
+                           next_page_token: Optional[Mapping[str, Any]] = None) -> Mapping[str, Any]:
+        return self._decorated.get_request_params(
+            stream_state=stream_state,
+            stream_slice=stream_slice,
+            next_page_token=next_page_token,
+        )
+
+    def get_request_headers(self, *, stream_state: Optional[StreamState] = None, stream_slice: Optional[StreamSlice] = None,
+                            next_page_token: Optional[Mapping[str, Any]] = None) -> Mapping[str, Any]:
+        return self._decorated.get_request_headers(
+            stream_state=stream_state,
+            stream_slice=stream_slice,
+            next_page_token=next_page_token,
+        )
+
+    def get_request_body_data(self, *, stream_state: Optional[StreamState] = None, stream_slice: Optional[StreamSlice] = None,
+                              next_page_token: Optional[Mapping[str, Any]] = None) -> Union[Mapping[str, Any], str]:
+        return self._decorated.get_request_body_data(
+            stream_state=stream_state,
+            stream_slice=stream_slice,
+            next_page_token=next_page_token,
+        )
+
+    def get_request_body_json(self, *, stream_state: Optional[StreamState] = None, stream_slice: Optional[StreamSlice] = None,
+                              next_page_token: Optional[Mapping[str, Any]] = None) -> Mapping[str, Any]:
+        return self._decorated.get_request_body_json(
+            stream_state=stream_state,
+            stream_slice=stream_slice,
+            next_page_token=next_page_token,
+        )
\ No newline at end of file
diff --git a/airbyte_cdk/sources/streams/concurrent/partition_reader.py b/airbyte_cdk/sources/streams/concurrent/partition_reader.py
index 3d23fd9cf..20677537c 100644
--- a/airbyte_cdk/sources/streams/concurrent/partition_reader.py
+++ b/airbyte_cdk/sources/streams/concurrent/partition_reader.py
@@ -1,16 +1,33 @@
 #
 # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
 #
+import logging
 from queue import Queue
+from typing import Optional
 
 from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException
+from airbyte_cdk.sources.message import MessageRepository
+from airbyte_cdk.sources.streams.concurrent.cursor import Cursor
 from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
 from airbyte_cdk.sources.streams.concurrent.partitions.types import (
     PartitionCompleteSentinel,
     QueueItem,
 )
+from airbyte_cdk.sources.utils.slice_logger import SliceLogger
+
+class PartitionLogger:
+    def __init__(self, slice_logger: SliceLogger, logger: logging.Logger, message_repository: MessageRepository):
+        self._slice_logger = slice_logger
+        self._logger = logger
+        self._message_repository = message_repository
+
+    def log(self, partition: Partition) -> None:
+        if self._slice_logger.should_log_slice_message(self._logger):
+            self._message_repository.emit_message(
+                self._slice_logger.create_slice_log_message(partition.to_slice())
+            )
+
 class PartitionReader:
     """
     Generates records from a partition and puts them in a queue.
@@ -18,13 +35,14 @@ class PartitionReader:
 
     _IS_SUCCESSFUL = True
 
-    def __init__(self, queue: Queue[QueueItem]) -> None:
+    def __init__(self, queue: Queue[QueueItem], partition_logger: Optional[PartitionLogger]) -> None:
         """
         :param queue: The queue to put the records in.
         """
         self._queue = queue
+        self._partition_logger = partition_logger
 
-    def process_partition(self, partition: Partition) -> None:
+    def process_partition(self, partition: Partition, cursor: Cursor) -> None:
         """
         Process a partition and put the records in the output queue.
         When all the partitions are added to the queue, a sentinel is added to the queue to indicate that all the partitions have been generated.
@@ -36,9 +54,16 @@ def process_partition(self, partition: Partition) -> None:
         :param partition: The partition to read data from
         :return: None
         """
+        if self._partition_logger:
+            self._partition_logger.log(partition)
+
         try:
             for record in partition.read():
                 self._queue.put(record)
+                cursor.observe(record)
+
+            # this assumes the cursor will put a state message on the queue. It also needs to be before the completion sentinel else the concurrent_read_processor might end the sync before consuming the state
+            cursor.close_partition(partition)
             self._queue.put(PartitionCompleteSentinel(partition, self._IS_SUCCESSFUL))
         except Exception as e:
             self._queue.put(StreamThreadException(e, partition.stream_name()))
diff --git a/airbyte_cdk/sources/streams/concurrent/partitions/types.py b/airbyte_cdk/sources/streams/concurrent/partitions/types.py
index 77644c6b9..415c388dd 100644
--- a/airbyte_cdk/sources/streams/concurrent/partitions/types.py
+++ b/airbyte_cdk/sources/streams/concurrent/partitions/types.py
@@ -4,6 +4,7 @@
 
 from typing import Any, Union
 
+from airbyte_cdk.models import AirbyteMessage
 from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import (
     PartitionGenerationCompletedSentinel,
 )
@@ -34,5 +35,5 @@ def __eq__(self, other: Any) -> bool:
 Typedef representing the items that can be added to the ThreadBasedConcurrentStream
 """
 QueueItem = Union[
-    Record, Partition, PartitionCompleteSentinel, PartitionGenerationCompletedSentinel, Exception
+    Record, Partition, PartitionCompleteSentinel, PartitionGenerationCompletedSentinel, Exception, AirbyteMessage,
 ]
diff --git a/airbyte_cdk/test/entrypoint_wrapper.py b/airbyte_cdk/test/entrypoint_wrapper.py
index f8e85bfb0..9eccbd902 100644
--- a/airbyte_cdk/test/entrypoint_wrapper.py
+++ b/airbyte_cdk/test/entrypoint_wrapper.py
@@ -157,6 +157,8 @@ def _run_command(
     stream_handler.setFormatter(AirbyteLogFormatter())
     parent_logger = logging.getLogger("")
     parent_logger.addHandler(stream_handler)
+    if "--debug" not in args:
+        args.append("--debug")
 
     parsed_args = AirbyteEntrypoint.parse_args(args)
@@ -195,7 +197,7 @@ def discover(
     config_file = make_file(tmp_directory_path / "config.json", config)
 
     return _run_command(
-        source, ["discover", "--config", config_file, "--debug"], expecting_exception
+        source, ["discover", "--config", config_file], expecting_exception
     )
diff --git a/unit_tests/connector_builder/test_connector_builder_handler.py b/unit_tests/connector_builder/test_connector_builder_handler.py
index 5c537811b..0950e5efd 100644
--- a/unit_tests/connector_builder/test_connector_builder_handler.py
+++ b/unit_tests/connector_builder/test_connector_builder_handler.py
@@ -61,6 +61,7 @@
 from airbyte_cdk.sources.declarative.retrievers import SimpleRetrieverTestReadDecorator
 from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever
 from airbyte_cdk.test.mock_http import HttpMocker, HttpRequest, HttpResponse
+from airbyte_cdk.test.state_builder import StateBuilder
 from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets, update_secrets
 from unit_tests.connector_builder.utils import create_configured_catalog
 
@@ -74,6 +75,7 @@
 }
 
 _page_size = 2
+_NO_STATE = None
 
 _A_STATE = [
     AirbyteStateMessage(
         type="STREAM",
@@ -140,6 +142,19 @@
             "type": "DeclarativeStream",
             "$parameters": _stream_options,
             "retriever": "#/definitions/retriever",
+            "incremental_sync": {
+                "type": "DatetimeBasedCursor",
+                "cursor_field": "updated_at",
+                "datetime_format": "%Y-%m-%dT%H:%M:%S.%fZ",
+                "start_datetime": {
+                    "datetime": "2025-01-01T00:00:00.000Z",
+                    "datetime_format": "%Y-%m-%dT%H:%M:%S.%fZ"
+                },
+                "start_time_option": {
+                    "field_name": "lastModifiedDateTime",
+                    "inject_into": "request_parameter"
+                }
+            },
         },
     ],
     "check": {"type": "CheckStream", "stream_names": ["lists"]},
@@ -568,6 +583,33 @@ def test_resolve_manifest(valid_resolve_manifest_config_file):
         "streams": [
             {
"DeclarativeStream", + "incremental_sync": { + "type": "DatetimeBasedCursor", + "cursor_field": "updated_at", + "datetime_format": "%Y-%m-%dT%H:%M:%S.%fZ", + "start_datetime": { + "type": "MinMaxDatetime", + "datetime": "2025-01-01T00:00:00.000Z", + "datetime_format": "%Y-%m-%dT%H:%M:%S.%fZ", + "$parameters": _stream_options, + "name": _stream_name, + "url_base": _stream_url_base, + "primary_key": _stream_primary_key, + }, + "start_time_option": { + "type": "RequestOption", + "field_name": "lastModifiedDateTime", + "inject_into": "request_parameter", + "$parameters": _stream_options, + "name": _stream_name, + "primary_key": _stream_primary_key, + "url_base": _stream_url_base, + }, + "name": _stream_name, + "primary_key": _stream_primary_key, + "url_base": _stream_url_base, + "$parameters": _stream_options, + }, "retriever": { "type": "SimpleRetriever", "paginator": { @@ -867,19 +909,20 @@ def test_handle_429_response(): config = TEST_READ_CONFIG limits = TestLimits() - source = create_source(config, limits) + catalog = ConfiguredAirbyteCatalogSerializer.load(CONFIGURED_CATALOG) + source = create_source(config, limits, catalog) with patch("requests.Session.send", return_value=response) as mock_send: response = handle_connector_builder_request( source, "test_read", config, - ConfiguredAirbyteCatalogSerializer.load(CONFIGURED_CATALOG), + catalog, _A_PER_PARTITION_STATE, limits, ) - mock_send.assert_called_once() + assert mock_send.call_count == limits.max_slices @pytest.mark.parametrize( @@ -1047,8 +1090,8 @@ def _create_429_page_response(response_body): requests.Session, "send", side_effect=( - _create_page_response({"result": [{"id": 0}, {"id": 1}], "_metadata": {"next": "next"}}), - _create_page_response({"result": [{"id": 2}], "_metadata": {"next": "next"}}), + _create_page_response({"result": [{"id": 0, "updated_at": "2025-01-01T00:00:00.000Z"}, {"id": 1, "updated_at": "2025-01-02T00:00:00.000Z"}], "_metadata": {"next": "next"}}), + _create_page_response({"result": [{"id": 2, "updated_at": "2025-01-03T00:00:00.000Z"}], "_metadata": {"next": "next"}}), ) * 10, ) @@ -1081,7 +1124,7 @@ def test_read_source(mock_http_stream): config = {"__injected_declarative_manifest": MANIFEST} - source = create_source(config, limits) + source = create_source(config, limits, catalog, _NO_STATE) output_data = read_stream(source, config, catalog, _A_PER_PARTITION_STATE, limits).record.data slices = output_data["slices"] @@ -1094,6 +1137,7 @@ def test_read_source(mock_http_stream): first_page, second_page = pages[0], pages[1] assert len(first_page["records"]) == _page_size assert len(second_page["records"]) == 1 + assert s["state"] streams = source.streams(config) for s in streams: @@ -1104,8 +1148,8 @@ def test_read_source(mock_http_stream): requests.Session, "send", side_effect=( - _create_page_response({"result": [{"id": 0}, {"id": 1}], "_metadata": {"next": "next"}}), - _create_page_response({"result": [{"id": 2}], "_metadata": {"next": "next"}}), + _create_page_response({"result": [{"id": 0, "updated_at": "2025-01-01T00:00:00.000Z"}, {"id": 1, "updated_at": "2025-01-02T00:00:00.000Z"}], "_metadata": {"next": "next"}}), + _create_page_response({"result": [{"id": 2, "updated_at": "2025-01-03T00:00:00.000Z"}], "_metadata": {"next": "next"}}), ), ) def test_read_source_single_page_single_slice(mock_http_stream): @@ -1244,13 +1288,13 @@ def test_handle_read_external_requests(deployment_mode, url_base, expected_error pytest.param( "CLOUD", "https://10.0.27.27/tokens/bearer", - "AirbyteTracedException", + 
"StreamThreadException", id="test_cloud_read_with_private_endpoint", ), pytest.param( "CLOUD", "http://unsecured.protocol/tokens/bearer", - "InvalidSchema", + "Invalid Protocol Scheme", id="test_cloud_read_with_unsecured_endpoint", ), pytest.param( diff --git a/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py b/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py index 9d462f330..318ae8640 100644 --- a/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py +++ b/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py @@ -2902,7 +2902,7 @@ def test_use_request_options_provider_for_datetime_based_cursor(): parameters={}, ) - connector_builder_factory = ModelToComponentFactory(emit_connector_builder_messages=True) + connector_builder_factory = ModelToComponentFactory() retriever = connector_builder_factory.create_component( model_type=SimpleRetrieverModel, component_definition=simple_retriever_model, @@ -2990,7 +2990,7 @@ def test_do_not_separate_request_options_provider_for_non_datetime_based_cursor( partition_router=list_partition_router, ) - connector_builder_factory = ModelToComponentFactory(emit_connector_builder_messages=True) + connector_builder_factory = ModelToComponentFactory() retriever = connector_builder_factory.create_component( model_type=SimpleRetrieverModel, component_definition=simple_retriever_model, @@ -3032,7 +3032,7 @@ def test_use_default_request_options_provider(): }, } - connector_builder_factory = ModelToComponentFactory(emit_connector_builder_messages=True) + connector_builder_factory = ModelToComponentFactory() retriever = connector_builder_factory.create_component( model_type=SimpleRetrieverModel, component_definition=simple_retriever_model, diff --git a/unit_tests/sources/declarative/retrievers/test_simple_retriever.py b/unit_tests/sources/declarative/retrievers/test_simple_retriever.py index 0b5778b7b..adf65fb2c 100644 --- a/unit_tests/sources/declarative/retrievers/test_simple_retriever.py +++ b/unit_tests/sources/declarative/retrievers/test_simple_retriever.py @@ -692,27 +692,6 @@ def test_path(test_name, requester_path, paginator_path, expected_path): assert actual_path == expected_path -def test_limit_stream_slices(): - maximum_number_of_slices = 4 - stream_slicer = MagicMock() - stream_slicer.stream_slices.return_value = _generate_slices(maximum_number_of_slices * 2) - retriever = SimpleRetrieverTestReadDecorator( - name="stream_name", - primary_key=primary_key, - requester=MagicMock(), - paginator=MagicMock(), - record_selector=MagicMock(), - stream_slicer=stream_slicer, - maximum_number_of_slices=maximum_number_of_slices, - parameters={}, - config={}, - ) - - truncated_slices = list(retriever.stream_slices()) - - assert truncated_slices == _generate_slices(maximum_number_of_slices) - - @pytest.mark.parametrize( "test_name, first_greater_than_second", [ diff --git a/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py b/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py index e288cdc1b..6f2af46e0 100644 --- a/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py +++ b/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py @@ -95,7 +95,6 @@ def test_stream_is_not_done_initially(self): self._partition_enqueuer, self._thread_pool_manager, self._logger, - self._slice_logger, self._message_repository, self._partition_reader, ) @@ -109,7 +108,6 @@ def 
@@ -109,7 +108,6 @@ def test_handle_partition_done_no_other_streams_to_generate_partitions_for(self)
             self._partition_enqueuer,
             self._thread_pool_manager,
             self._logger,
-            self._slice_logger,
             self._message_repository,
             self._partition_reader,
         )
@@ -135,7 +133,6 @@ def test_handle_last_stream_partition_done(self):
             self._partition_enqueuer,
             self._thread_pool_manager,
             self._logger,
-            self._slice_logger,
             self._message_repository,
             self._partition_reader,
         )
@@ -171,7 +168,6 @@ def test_handle_partition(self):
             self._partition_enqueuer,
             self._thread_pool_manager,
             self._logger,
-            self._slice_logger,
             self._message_repository,
             self._partition_reader,
         )
@@ -179,82 +175,12 @@ def test_handle_partition(self):
         handler.on_partition(self._a_closed_partition)
 
         self._thread_pool_manager.submit.assert_called_with(
-            self._partition_reader.process_partition, self._a_closed_partition
+            self._partition_reader.process_partition, self._a_closed_partition, self._another_stream.cursor
         )
         assert (
             self._a_closed_partition in handler._streams_to_running_partitions[_ANOTHER_STREAM_NAME]
         )
 
-    def test_handle_partition_emits_log_message_if_it_should_be_logged(self):
-        stream_instances_to_read_from = [self._stream]
-        self._slice_logger = Mock(spec=SliceLogger)
-        self._slice_logger.should_log_slice_message.return_value = True
-        self._slice_logger.create_slice_log_message.return_value = self._log_message
-
-        handler = ConcurrentReadProcessor(
-            stream_instances_to_read_from,
-            self._partition_enqueuer,
-            self._thread_pool_manager,
-            self._logger,
-            self._slice_logger,
-            self._message_repository,
-            self._partition_reader,
-        )
-
-        handler.on_partition(self._an_open_partition)
-
-        self._thread_pool_manager.submit.assert_called_with(
-            self._partition_reader.process_partition, self._an_open_partition
-        )
-        self._message_repository.emit_message.assert_called_with(self._log_message)
-
-        assert self._an_open_partition in handler._streams_to_running_partitions[_STREAM_NAME]
-
-    @freezegun.freeze_time("2020-01-01T00:00:00")
-    def test_handle_on_partition_complete_sentinel_with_messages_from_repository(self):
-        stream_instances_to_read_from = [self._stream]
-        partition = Mock(spec=Partition)
-        log_message = Mock(spec=LogMessage)
-        partition.to_slice.return_value = log_message
-        partition.stream_name.return_value = _STREAM_NAME
-
-        handler = ConcurrentReadProcessor(
-            stream_instances_to_read_from,
-            self._partition_enqueuer,
-            self._thread_pool_manager,
-            self._logger,
-            self._slice_logger,
-            self._message_repository,
-            self._partition_reader,
-        )
-        handler.start_next_partition_generator()
-        handler.on_partition(partition)
-
-        sentinel = PartitionCompleteSentinel(partition)
-
-        self._message_repository.consume_queue.return_value = [
-            AirbyteMessage(
-                type=MessageType.LOG,
-                log=AirbyteLogMessage(
-                    level=LogLevel.INFO, message="message emitted from the repository"
-                ),
-            )
-        ]
-
-        messages = list(handler.on_partition_complete_sentinel(sentinel))
-
-        expected_messages = [
-            AirbyteMessage(
-                type=MessageType.LOG,
-                log=AirbyteLogMessage(
-                    level=LogLevel.INFO, message="message emitted from the repository"
-                ),
-            )
-        ]
-        assert messages == expected_messages
-
-        self._stream.cursor.close_partition.assert_called_once()
-
     @freezegun.freeze_time("2020-01-01T00:00:00")
     def test_handle_on_partition_complete_sentinel_yields_status_message_if_the_stream_is_done(
         self,
@@ -270,7 +196,6 @@ def test_handle_on_partition_complete_sentinel_yields_status_message_if_the_stre
             self._partition_enqueuer,
             self._thread_pool_manager,
             self._logger,
-            self._slice_logger,
             self._message_repository,
             self._partition_reader,
         )
@@ -302,55 +227,6 @@ def test_handle_on_partition_complete_sentinel_yields_status_message_if_the_stre
             )
         ]
         assert messages == expected_messages
-        self._another_stream.cursor.close_partition.assert_called_once()
-
-    @freezegun.freeze_time("2020-01-01T00:00:00")
-    def test_given_exception_on_partition_complete_sentinel_then_yield_error_trace_message_and_stream_is_incomplete(
-        self,
-    ) -> None:
-        self._a_closed_partition.stream_name.return_value = self._stream.name
-        self._stream.cursor.close_partition.side_effect = ValueError
-
-        handler = ConcurrentReadProcessor(
-            [self._stream],
-            self._partition_enqueuer,
-            self._thread_pool_manager,
-            self._logger,
-            self._slice_logger,
-            self._message_repository,
-            self._partition_reader,
-        )
-        handler.start_next_partition_generator()
-        handler.on_partition(self._a_closed_partition)
-        list(
-            handler.on_partition_generation_completed(
-                PartitionGenerationCompletedSentinel(self._stream)
-            )
-        )
-        messages = list(
-            handler.on_partition_complete_sentinel(
-                PartitionCompleteSentinel(self._a_closed_partition)
-            )
-        )
-
-        expected_status_message = AirbyteMessage(
-            type=MessageType.TRACE,
-            trace=AirbyteTraceMessage(
-                type=TraceType.STREAM_STATUS,
-                stream_status=AirbyteStreamStatusTraceMessage(
-                    stream_descriptor=StreamDescriptor(
-                        name=self._stream.name,
-                    ),
-                    status=AirbyteStreamStatus.INCOMPLETE,
-                ),
-                emitted_at=1577836800000.0,
-            ),
-        )
-        assert list(map(lambda message: message.trace.type, messages)) == [
-            TraceType.ERROR,
-            TraceType.STREAM_STATUS,
-        ]
-        assert messages[1] == expected_status_message
 
     @freezegun.freeze_time("2020-01-01T00:00:00")
     def test_handle_on_partition_complete_sentinel_yields_no_status_message_if_the_stream_is_not_done(
         self,
@@ -367,7 +243,6 @@ def test_handle_on_partition_complete_sentinel_yields_no_status_message_if_the_s
             self._partition_enqueuer,
             self._thread_pool_manager,
             self._logger,
-            self._slice_logger,
             self._message_repository,
             self._partition_reader,
         )
@@ -379,7 +254,6 @@ def test_handle_on_partition_complete_sentinel_yields_no_status_message_if_the_s
 
         expected_messages = []
         assert messages == expected_messages
-        self._stream.cursor.close_partition.assert_called_once()
 
     @freezegun.freeze_time("2020-01-01T00:00:00")
     def test_on_record_no_status_message_no_repository_messge(self):
@@ -395,7 +269,6 @@ def test_on_record_no_status_message_no_repository_messge(self):
             self._partition_enqueuer,
             self._thread_pool_manager,
             self._logger,
-            self._slice_logger,
             self._message_repository,
             self._partition_reader,
         )
@@ -441,7 +314,6 @@ def test_on_record_with_repository_messge(self):
             self._partition_enqueuer,
             self._thread_pool_manager,
             self._logger,
-            self._slice_logger,
             self._message_repository,
             self._partition_reader,
         )
@@ -490,7 +362,6 @@ def test_on_record_emits_status_message_on_first_record_no_repository_message(se
             self._partition_enqueuer,
             self._thread_pool_manager,
             self._logger,
-            self._slice_logger,
             self._message_repository,
             self._partition_reader,
         )
@@ -541,7 +412,6 @@ def test_on_record_emits_status_message_on_first_record_with_repository_message(
             self._partition_enqueuer,
             self._thread_pool_manager,
             self._logger,
-            self._slice_logger,
             self._message_repository,
             self._partition_reader,
         )
@@ -594,7 +464,6 @@ def test_on_exception_return_trace_message_and_on_stream_complete_return_stream_
             self._partition_enqueuer,
             self._thread_pool_manager,
             self._logger,
-            self._slice_logger,
             self._message_repository,
             self._partition_reader,
         )
@@ -657,7 +526,6 @@ def test_given_underlying_exception_is_traced_exception_on_exception_return_trac
             self._partition_enqueuer,
             self._thread_pool_manager,
             self._logger,
-            self._slice_logger,
             self._message_repository,
             self._partition_reader,
         )
@@ -718,7 +586,6 @@ def test_given_partition_completion_is_not_success_then_do_not_close_partition(s
             self._partition_enqueuer,
             self._thread_pool_manager,
             self._logger,
-            self._slice_logger,
             self._message_repository,
             self._partition_reader,
         )
@@ -747,7 +614,6 @@ def test_is_done_is_false_if_there_are_any_instances_to_read_from(self):
             self._partition_enqueuer,
             self._thread_pool_manager,
             self._logger,
-            self._slice_logger,
             self._message_repository,
             self._partition_reader,
         )
@@ -762,7 +628,6 @@ def test_is_done_is_false_if_there_are_streams_still_generating_partitions(self)
             self._partition_enqueuer,
             self._thread_pool_manager,
             self._logger,
-            self._slice_logger,
             self._message_repository,
             self._partition_reader,
         )
@@ -779,7 +644,6 @@ def test_is_done_is_false_if_all_partitions_are_not_closed(self):
             self._partition_enqueuer,
             self._thread_pool_manager,
             self._logger,
-            self._slice_logger,
             self._message_repository,
             self._partition_reader,
         )
@@ -802,7 +666,6 @@ def test_is_done_is_true_if_all_partitions_are_closed_and_no_streams_are_generat
             self._partition_enqueuer,
             self._thread_pool_manager,
             self._logger,
-            self._slice_logger,
             self._message_repository,
             self._partition_reader,
         )
@@ -817,7 +680,6 @@ def test_start_next_partition_generator(self):
             self._partition_enqueuer,
             self._thread_pool_manager,
             self._logger,
-            self._slice_logger,
             self._message_repository,
             self._partition_reader,
         )
diff --git a/unit_tests/sources/streams/concurrent/test_partition_reader.py b/unit_tests/sources/streams/concurrent/test_partition_reader.py
index 1910e034d..f632b07b1 100644
--- a/unit_tests/sources/streams/concurrent/test_partition_reader.py
+++ b/unit_tests/sources/streams/concurrent/test_partition_reader.py
@@ -9,7 +9,8 @@
 import pytest
 
 from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException
-from airbyte_cdk.sources.streams.concurrent.partition_reader import PartitionReader
+from airbyte_cdk.sources.streams.concurrent.cursor import Cursor
+from airbyte_cdk.sources.streams.concurrent.partition_reader import PartitionLogger, PartitionReader
 from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
 from airbyte_cdk.sources.streams.concurrent.partitions.types import (
     PartitionCompleteSentinel,
@@ -26,10 +27,11 @@ class PartitionReaderTest(unittest.TestCase):
     def setUp(self) -> None:
         self._queue: Queue[QueueItem] = Queue()
-        self._partition_reader = PartitionReader(self._queue)
+        self._partition_reader = PartitionReader(self._queue, Mock(spec=PartitionLogger))  # FIXME ensure partition logger is called properly
+        self._cursor = Mock(spec=Cursor)
 
     def test_given_no_records_when_process_partition_then_only_emit_sentinel(self):
-        self._partition_reader.process_partition(self._a_partition([]))
+        self._partition_reader.process_partition(self._a_partition([]), self._cursor)
 
         while queue_item := self._queue.get():
             if not isinstance(queue_item, PartitionCompleteSentinel):
@@ -40,7 +42,7 @@ def test_given_read_partition_successful_when_process_partition_then_queue_recor
         self,
     ):
         partition = self._a_partition(_RECORDS)
-        self._partition_reader.process_partition(partition)
+        self._partition_reader.process_partition(partition, self._cursor)
 
         queue_content = self._consume_queue()
 
@@ -52,7 +54,7 @@ def test_given_exception_when_process_partition_then_queue_records_and_exception
         partition = Mock()
         exception = ValueError()
         partition.read.side_effect = self._read_with_exception(_RECORDS, exception)
-        self._partition_reader.process_partition(partition)
+        self._partition_reader.process_partition(partition, self._cursor)
 
         queue_content = self._consume_queue()
 
diff --git a/unit_tests/sources/streams/test_stream_read.py b/unit_tests/sources/streams/test_stream_read.py
index ebe258ef2..c47f691cc 100644
--- a/unit_tests/sources/streams/test_stream_read.py
+++ b/unit_tests/sources/streams/test_stream_read.py
@@ -606,7 +606,6 @@ def test_concurrent_incremental_read_two_slices():
         Mock(spec=PartitionEnqueuer),
         Mock(spec=ThreadPoolManager),
         logger,
-        slice_logger,
         message_repository,
         Mock(spec=PartitionReader),
     )