
Commit 2b07f93

brianjlai and maxi297 authored
fix(connector-builder): Re-revert "fix: revert remerge concurrent cdk builder change because of flaky test" (#712)
Co-authored-by: maxime.c <[email protected]>
1 parent cd48741 · commit 2b07f93

27 files changed: +599 / -372 lines

airbyte_cdk/connector_builder/connector_builder_handler.py

Lines changed: 32 additions & 36 deletions
@@ -3,8 +3,8 @@
 #


-from dataclasses import asdict, dataclass, field
-from typing import Any, ClassVar, Dict, List, Mapping
+from dataclasses import asdict
+from typing import Any, Dict, List, Mapping, Optional

 from airbyte_cdk.connector_builder.test_reader import TestReader
 from airbyte_cdk.models import (
@@ -15,45 +15,32 @@
     Type,
 )
 from airbyte_cdk.models import Type as MessageType
+from airbyte_cdk.sources.declarative.concurrent_declarative_source import (
+    ConcurrentDeclarativeSource,
+    TestLimits,
+)
 from airbyte_cdk.sources.declarative.declarative_source import DeclarativeSource
 from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource
-from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import (
-    ModelToComponentFactory,
-)
 from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets
 from airbyte_cdk.utils.datetime_helpers import ab_datetime_now
 from airbyte_cdk.utils.traced_exception import AirbyteTracedException

-DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE = 5
-DEFAULT_MAXIMUM_NUMBER_OF_SLICES = 5
-DEFAULT_MAXIMUM_RECORDS = 100
-DEFAULT_MAXIMUM_STREAMS = 100
-
 MAX_PAGES_PER_SLICE_KEY = "max_pages_per_slice"
 MAX_SLICES_KEY = "max_slices"
 MAX_RECORDS_KEY = "max_records"
 MAX_STREAMS_KEY = "max_streams"


-@dataclass
-class TestLimits:
-    __test__: ClassVar[bool] = False  # Tell Pytest this is not a Pytest class, despite its name
-
-    max_records: int = field(default=DEFAULT_MAXIMUM_RECORDS)
-    max_pages_per_slice: int = field(default=DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE)
-    max_slices: int = field(default=DEFAULT_MAXIMUM_NUMBER_OF_SLICES)
-    max_streams: int = field(default=DEFAULT_MAXIMUM_STREAMS)
-
-
 def get_limits(config: Mapping[str, Any]) -> TestLimits:
     command_config = config.get("__test_read_config", {})
-    max_pages_per_slice = (
-        command_config.get(MAX_PAGES_PER_SLICE_KEY) or DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE
+    return TestLimits(
+        max_records=command_config.get(MAX_RECORDS_KEY, TestLimits.DEFAULT_MAX_RECORDS),
+        max_pages_per_slice=command_config.get(
+            MAX_PAGES_PER_SLICE_KEY, TestLimits.DEFAULT_MAX_PAGES_PER_SLICE
+        ),
+        max_slices=command_config.get(MAX_SLICES_KEY, TestLimits.DEFAULT_MAX_SLICES),
+        max_streams=command_config.get(MAX_STREAMS_KEY, TestLimits.DEFAULT_MAX_STREAMS),
     )
-    max_slices = command_config.get(MAX_SLICES_KEY) or DEFAULT_MAXIMUM_NUMBER_OF_SLICES
-    max_records = command_config.get(MAX_RECORDS_KEY) or DEFAULT_MAXIMUM_RECORDS
-    max_streams = command_config.get(MAX_STREAMS_KEY) or DEFAULT_MAXIMUM_STREAMS
-    return TestLimits(max_records, max_pages_per_slice, max_slices, max_streams)


 def should_migrate_manifest(config: Mapping[str, Any]) -> bool:
@@ -75,21 +62,30 @@ def should_normalize_manifest(config: Mapping[str, Any]) -> bool:
     return config.get("__should_normalize", False)


-def create_source(config: Mapping[str, Any], limits: TestLimits) -> ManifestDeclarativeSource:
+def create_source(
+    config: Mapping[str, Any],
+    limits: TestLimits,
+    catalog: Optional[ConfiguredAirbyteCatalog],
+    state: Optional[List[AirbyteStateMessage]],
+) -> ConcurrentDeclarativeSource[Optional[List[AirbyteStateMessage]]]:
     manifest = config["__injected_declarative_manifest"]
-    return ManifestDeclarativeSource(
+
+    # We enforce a concurrency level of 1 so that the stream is processed on a single thread
+    # to retain ordering for the grouping of the builder message responses.
+    if "concurrency_level" in manifest:
+        manifest["concurrency_level"]["default_concurrency"] = 1
+    else:
+        manifest["concurrency_level"] = {"type": "ConcurrencyLevel", "default_concurrency": 1}
+
+    return ConcurrentDeclarativeSource(
+        catalog=catalog,
         config=config,
-        emit_connector_builder_messages=True,
+        state=state,
         source_config=manifest,
+        emit_connector_builder_messages=True,
         migrate_manifest=should_migrate_manifest(config),
         normalize_manifest=should_normalize_manifest(config),
-        component_factory=ModelToComponentFactory(
-            emit_connector_builder_messages=True,
-            limit_pages_fetched_per_slice=limits.max_pages_per_slice,
-            limit_slices_fetched=limits.max_slices,
-            disable_retries=True,
-            disable_cache=True,
-        ),
+        limits=limits,
     )

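Taken together, the builder handler now resolves limits from the class-level defaults on TestLimits and pins the manifest to a single worker so that builder message grouping keeps its ordering. A hedged sketch of both behaviours, using only names that appear in the hunks above; the config values and the stand-in manifest are illustrative, not a real connector:

# Sketch: limit resolution plus the concurrency pinning shown in the hunk above.
from airbyte_cdk.connector_builder.connector_builder_handler import get_limits
from airbyte_cdk.sources.declarative.concurrent_declarative_source import TestLimits

config = {"__test_read_config": {"max_records": 50, "max_slices": 2}}
limits = get_limits(config)
assert limits.max_records == 50
# Keys left unset fall back to the class-level defaults, e.g. TestLimits.DEFAULT_MAX_PAGES_PER_SLICE.

# create_source applies the same override as the diff: force one worker so the
# grouping of builder message responses stays ordered.
manifest = {"version": "6.44.0", "streams": []}  # stand-in manifest, not a real connector
if "concurrency_level" in manifest:
    manifest["concurrency_level"]["default_concurrency"] = 1
else:
    manifest["concurrency_level"] = {"type": "ConcurrencyLevel", "default_concurrency": 1}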

airbyte_cdk/connector_builder/main.py

Lines changed: 3 additions & 3 deletions
@@ -91,12 +91,12 @@ def handle_connector_builder_request(
 def handle_request(args: List[str]) -> str:
     command, config, catalog, state = get_config_and_catalog_from_args(args)
     limits = get_limits(config)
-    source = create_source(config, limits)
-    return orjson.dumps(
+    source = create_source(config=config, limits=limits, catalog=catalog, state=state)
+    return orjson.dumps(  # type: ignore[no-any-return] # Serializer.dump() always returns AirbyteMessage
         AirbyteMessageSerializer.dump(
             handle_connector_builder_request(source, command, config, catalog, state, limits)
         )
-    ).decode()  # type: ignore[no-any-return] # Serializer.dump() always returns AirbyteMessage
+    ).decode()


 if __name__ == "__main__":

airbyte_cdk/connector_builder/test_reader/helpers.py

Lines changed: 24 additions & 2 deletions
@@ -5,7 +5,7 @@
 import json
 from copy import deepcopy
 from json import JSONDecodeError
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Dict, List, Mapping, Optional, Union

 from airbyte_cdk.connector_builder.models import (
     AuxiliaryRequest,
@@ -17,6 +17,8 @@
 from airbyte_cdk.models import (
     AirbyteLogMessage,
     AirbyteMessage,
+    AirbyteStateBlob,
+    AirbyteStateMessage,
     OrchestratorType,
     TraceType,
 )
@@ -466,7 +468,7 @@ def handle_current_slice(
     return StreamReadSlices(
         pages=current_slice_pages,
         slice_descriptor=current_slice_descriptor,
-        state=[latest_state_message] if latest_state_message else [],
+        state=[convert_state_blob_to_mapping(latest_state_message)] if latest_state_message else [],
         auxiliary_requests=auxiliary_requests if auxiliary_requests else [],
     )

@@ -718,3 +720,23 @@ def get_auxiliary_request_type(stream: dict, http: dict) -> str: # type: ignore
     Determines the type of the auxiliary request based on the stream and HTTP properties.
     """
     return "PARENT_STREAM" if stream.get("is_substream", False) else str(http.get("type", None))
+
+
+def convert_state_blob_to_mapping(
+    state_message: Union[AirbyteStateMessage, Dict[str, Any]],
+) -> Dict[str, Any]:
+    """
+    The AirbyteStreamState stores state as an AirbyteStateBlob which deceivingly is not
+    a dictionary, but rather a list of kwargs fields. This in turn causes it to not be
+    properly turned into a dictionary when translating this back into response output
+    by the connector_builder_handler using asdict()
+    """
+
+    if isinstance(state_message, AirbyteStateMessage) and state_message.stream:
+        state_value = state_message.stream.stream_state
+        if isinstance(state_value, AirbyteStateBlob):
+            state_value_mapping = {k: v for k, v in state_value.__dict__.items()}
+            state_message.stream.stream_state = state_value_mapping  # type: ignore  # we intentionally set this as a Dict so that StreamReadSlices is translated properly in the resulting HTTP response
+        return state_message  # type: ignore  # See above, but when this is an AirbyteStateMessage we must convert AirbyteStateBlob to a Dict
+    else:
+        return state_message  # type: ignore  # This is guaranteed to be a Dict since we check isinstance AirbyteStateMessage above
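To make the intent of the new helper concrete, here is a hedged usage sketch. The stream name and state field are illustrative, and it assumes the protocol models re-exported by airbyte_cdk.models behave as in current CDK releases:

# Sketch: AirbyteStateBlob is not a plain dict, so asdict() on StreamReadSlices would
# not serialize it; convert_state_blob_to_mapping swaps the blob for a real dict in place.
from airbyte_cdk.connector_builder.test_reader.helpers import convert_state_blob_to_mapping
from airbyte_cdk.models import (
    AirbyteStateBlob,
    AirbyteStateMessage,
    AirbyteStateType,
    AirbyteStreamState,
    StreamDescriptor,
)

state_message = AirbyteStateMessage(
    type=AirbyteStateType.STREAM,
    stream=AirbyteStreamState(
        stream_descriptor=StreamDescriptor(name="pokemon"),  # illustrative stream name
        stream_state=AirbyteStateBlob(updated_at="2024-01-01"),  # illustrative state field
    ),
)
converted = convert_state_blob_to_mapping(state_message)
# stream_state should now be a plain mapping that serializes cleanly in the HTTP response:
assert converted.stream.stream_state == {"updated_at": "2024-01-01"}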

airbyte_cdk/connector_builder/test_reader/message_grouper.py

Lines changed: 1 addition & 1 deletion
@@ -95,7 +95,7 @@ def get_message_groups(
     latest_state_message: Optional[Dict[str, Any]] = None
     slice_auxiliary_requests: List[AuxiliaryRequest] = []

-    while records_count < limit and (message := next(messages, None)):
+    while message := next(messages, None):
         json_message = airbyte_message_to_json(message)

         if is_page_http_request_for_different_stream(json_message, stream_name):
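The grouper now drains the message iterator until it is exhausted; record limiting is presumably enforced upstream by the limits handed to the source rather than inside this loop. A minimal, self-contained sketch of the walrus-operator pattern in isolation:

# Illustration of the `while message := next(messages, None)` consumption pattern.
from typing import Iterator, List, Optional

def drain(messages: Iterator[Optional[str]]) -> List[str]:
    collected: List[str] = []
    while message := next(messages, None):
        collected.append(message)
    return collected

assert drain(iter(["a", "b", "c"])) == ["a", "b", "c"]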

airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py

Lines changed: 15 additions & 22 deletions
@@ -2,6 +2,7 @@
 # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
 #
 import logging
+import os
 from typing import Dict, Iterable, List, Optional, Set

 from airbyte_cdk.exception_handler import generate_failed_streams_error_message
@@ -95,11 +96,14 @@ def on_partition(self, partition: Partition) -> None:
         """
         stream_name = partition.stream_name()
         self._streams_to_running_partitions[stream_name].add(partition)
+        cursor = self._stream_name_to_instance[stream_name].cursor
         if self._slice_logger.should_log_slice_message(self._logger):
             self._message_repository.emit_message(
                 self._slice_logger.create_slice_log_message(partition.to_slice())
             )
-        self._thread_pool_manager.submit(self._partition_reader.process_partition, partition)
+        self._thread_pool_manager.submit(
+            self._partition_reader.process_partition, partition, cursor
+        )

     def on_partition_complete_sentinel(
         self, sentinel: PartitionCompleteSentinel
@@ -112,26 +116,16 @@ def on_partition_complete_sentinel(
         """
         partition = sentinel.partition

-        try:
-            if sentinel.is_successful:
-                stream = self._stream_name_to_instance[partition.stream_name()]
-                stream.cursor.close_partition(partition)
-        except Exception as exception:
-            self._flag_exception(partition.stream_name(), exception)
-            yield AirbyteTracedException.from_exception(
-                exception, stream_descriptor=StreamDescriptor(name=partition.stream_name())
-            ).as_sanitized_airbyte_message()
-        finally:
-            partitions_running = self._streams_to_running_partitions[partition.stream_name()]
-            if partition in partitions_running:
-                partitions_running.remove(partition)
-                # If all partitions were generated and this was the last one, the stream is done
-                if (
-                    partition.stream_name() not in self._streams_currently_generating_partitions
-                    and len(partitions_running) == 0
-                ):
-                    yield from self._on_stream_is_done(partition.stream_name())
-            yield from self._message_repository.consume_queue()
+        partitions_running = self._streams_to_running_partitions[partition.stream_name()]
+        if partition in partitions_running:
+            partitions_running.remove(partition)
+            # If all partitions were generated and this was the last one, the stream is done
+            if (
+                partition.stream_name() not in self._streams_currently_generating_partitions
+                and len(partitions_running) == 0
+            ):
+                yield from self._on_stream_is_done(partition.stream_name())
+        yield from self._message_repository.consume_queue()

     def on_record(self, record: Record) -> Iterable[AirbyteMessage]:
         """
@@ -160,7 +154,6 @@ def on_record(self, record: Record) -> Iterable[AirbyteMessage]:
                 stream.as_airbyte_stream(), AirbyteStreamStatus.RUNNING
             )
         self._record_counter[stream.name] += 1
-        stream.cursor.observe(record)
         yield message
         yield from self._message_repository.consume_queue()

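The net effect of these hunks is that cursor bookkeeping (observe / close_partition) moves out of the read processor: the processor now only looks up the stream's cursor and submits it with the partition, and the worker that processes the partition is responsible for the rest. A simplified sketch with stand-in classes, not the CDK's own types:

# Stand-in sketch of the new division of labor between processor and worker.
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict

class FakeCursor:
    def observe(self, record: Any) -> None:
        print(f"observe {record}")

    def close_partition(self, partition_name: str) -> None:
        print(f"close {partition_name}")

def process_partition(partition: Dict[str, Any], cursor: FakeCursor) -> None:
    # The worker observes each record and closes the partition itself.
    for record in partition["records"]:
        cursor.observe(record)
    cursor.close_partition(partition["name"])

with ThreadPoolExecutor(max_workers=1) as executor:
    executor.submit(process_partition, {"name": "slice-0", "records": [1, 2]}, FakeCursor()).result()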

airbyte_cdk/sources/concurrent_source/concurrent_source.py

Lines changed: 30 additions & 18 deletions
@@ -1,10 +1,11 @@
 #
 # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
 #
+
 import concurrent
 import logging
 from queue import Queue
-from typing import Iterable, Iterator, List
+from typing import Iterable, Iterator, List, Optional

 from airbyte_cdk.models import AirbyteMessage
 from airbyte_cdk.sources.concurrent_source.concurrent_read_processor import ConcurrentReadProcessor
@@ -16,7 +17,7 @@
 from airbyte_cdk.sources.message import InMemoryMessageRepository, MessageRepository
 from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
 from airbyte_cdk.sources.streams.concurrent.partition_enqueuer import PartitionEnqueuer
-from airbyte_cdk.sources.streams.concurrent.partition_reader import PartitionReader
+from airbyte_cdk.sources.streams.concurrent.partition_reader import PartitionLogger, PartitionReader
 from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
 from airbyte_cdk.sources.streams.concurrent.partitions.types import (
     PartitionCompleteSentinel,
@@ -43,6 +44,7 @@ def create(
         logger: logging.Logger,
         slice_logger: SliceLogger,
         message_repository: MessageRepository,
+        queue: Optional[Queue[QueueItem]] = None,
         timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
     ) -> "ConcurrentSource":
         is_single_threaded = initial_number_of_partitions_to_generate == 1 and num_workers == 1
@@ -59,19 +61,21 @@
             logger,
         )
         return ConcurrentSource(
-            threadpool,
-            logger,
-            slice_logger,
-            message_repository,
-            initial_number_of_partitions_to_generate,
-            timeout_seconds,
+            threadpool=threadpool,
+            logger=logger,
+            slice_logger=slice_logger,
+            queue=queue,
+            message_repository=message_repository,
+            initial_number_partitions_to_generate=initial_number_of_partitions_to_generate,
+            timeout_seconds=timeout_seconds,
         )

     def __init__(
         self,
         threadpool: ThreadPoolManager,
         logger: logging.Logger,
         slice_logger: SliceLogger = DebugSliceLogger(),
+        queue: Optional[Queue[QueueItem]] = None,
         message_repository: MessageRepository = InMemoryMessageRepository(),
         initial_number_partitions_to_generate: int = 1,
         timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
@@ -91,33 +95,36 @@ def __init__(
         self._initial_number_partitions_to_generate = initial_number_partitions_to_generate
         self._timeout_seconds = timeout_seconds

+        # We set a maxsize to for the main thread to process record items when the queue size grows. This assumes that there are less
+        # threads generating partitions that than are max number of workers. If it weren't the case, we could have threads only generating
+        # partitions which would fill the queue. This number is arbitrarily set to 10_000 but will probably need to be changed given more
+        # information and might even need to be configurable depending on the source
+        self._queue = queue or Queue(maxsize=10_000)
+
     def read(
         self,
         streams: List[AbstractStream],
     ) -> Iterator[AirbyteMessage]:
         self._logger.info("Starting syncing")
-
-        # We set a maxsize to for the main thread to process record items when the queue size grows. This assumes that there are less
-        # threads generating partitions that than are max number of workers. If it weren't the case, we could have threads only generating
-        # partitions which would fill the queue. This number is arbitrarily set to 10_000 but will probably need to be changed given more
-        # information and might even need to be configurable depending on the source
-        queue: Queue[QueueItem] = Queue(maxsize=10_000)
         concurrent_stream_processor = ConcurrentReadProcessor(
             streams,
-            PartitionEnqueuer(queue, self._threadpool),
+            PartitionEnqueuer(self._queue, self._threadpool),
             self._threadpool,
             self._logger,
             self._slice_logger,
             self._message_repository,
-            PartitionReader(queue),
+            PartitionReader(
+                self._queue,
+                PartitionLogger(self._slice_logger, self._logger, self._message_repository),
+            ),
         )

         # Enqueue initial partition generation tasks
         yield from self._submit_initial_partition_generators(concurrent_stream_processor)

         # Read from the queue until all partitions were generated and read
         yield from self._consume_from_queue(
-            queue,
+            self._queue,
             concurrent_stream_processor,
         )
         self._threadpool.check_for_errors_and_shutdown()
@@ -141,7 +148,10 @@ def _consume_from_queue(
                 airbyte_message_or_record_or_exception,
                 concurrent_stream_processor,
             )
-            if concurrent_stream_processor.is_done() and queue.empty():
+            # In the event that a partition raises an exception, anything remaining in
+            # the queue will be missed because is_done() can raise an exception and exit
+            # out of this loop before remaining items are consumed
+            if queue.empty() and concurrent_stream_processor.is_done():
                 # all partitions were generated and processed. we're done here
                 break

@@ -161,5 +171,7 @@ def _handle_item(
             yield from concurrent_stream_processor.on_partition_complete_sentinel(queue_item)
         elif isinstance(queue_item, Record):
             yield from concurrent_stream_processor.on_record(queue_item)
+        elif isinstance(queue_item, AirbyteMessage):
+            yield queue_item
         else:
             raise ValueError(f"Unknown queue item type: {type(queue_item)}")
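With the queue now injectable and AirbyteMessage accepted as a queue item, a caller such as the connector builder can share one bounded queue with the source and still receive slice log messages emitted by PartitionLogger in order. A hedged construction sketch; the worker counts are illustrative, and the create() parameter names and import paths follow this excerpt and current CDK layout, so they may differ slightly in other versions:

# Sketch: hand ConcurrentSource a pre-built bounded queue instead of letting read()
# build one internally. Values are illustrative.
import logging
from queue import Queue

from airbyte_cdk.sources.concurrent_source.concurrent_source import ConcurrentSource
from airbyte_cdk.sources.message import InMemoryMessageRepository
from airbyte_cdk.sources.streams.concurrent.partitions.types import QueueItem
from airbyte_cdk.sources.utils.slice_logger import DebugSliceLogger

queue: Queue[QueueItem] = Queue(maxsize=10_000)
source = ConcurrentSource.create(
    num_workers=1,
    initial_number_of_partitions_to_generate=1,
    logger=logging.getLogger("airbyte"),
    slice_logger=DebugSliceLogger(),
    message_repository=InMemoryMessageRepository(),
    queue=queue,
)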
