Skip to content

Commit b13f37b

Browse files
committed
cherry-pick-me: add safe entrypoint methods, add file handling for the EntrypointOutput class
1 parent 08a1c22 commit b13f37b

File tree

1 file changed

+151
-16
lines changed

1 file changed

+151
-16
lines changed

airbyte_cdk/test/entrypoint_wrapper.py

Lines changed: 151 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@
1919
import re
2020
import tempfile
2121
import traceback
22+
from collections import deque
23+
from collections.abc import Generator, Mapping
2224
from io import StringIO
2325
from pathlib import Path
24-
from typing import Any, List, Mapping, Optional, Union
26+
from typing import Any, List, Literal, Optional, Union, final, overload
2527

2628
import orjson
2729
from pydantic import ValidationError as V2ValidationError
@@ -43,18 +45,47 @@
4345
TraceType,
4446
Type,
4547
)
48+
from airbyte_cdk.models.airbyte_protocol import AirbyteMessage, AirbyteStreamState
4649
from airbyte_cdk.sources import Source
4750
from airbyte_cdk.test.models.scenario import ExpectedOutcome
4851

4952

5053
class EntrypointOutput:
51-
def __init__(self, messages: List[str], uncaught_exception: Optional[BaseException] = None):
52-
try:
53-
self._messages = [self._parse_message(message) for message in messages]
54-
except V2ValidationError as exception:
55-
raise ValueError("All messages are expected to be AirbyteMessage") from exception
54+
"""A class to encapsulate the output of an Airbyte connector's execution.
55+
56+
This class can be initialized with a list of messages or a file containing messages.
57+
It provides methods to access different types of messages produced during the execution
58+
of an Airbyte connector, including both successful messages and error messages.
59+
60+
When working with records and state messages, it provides both a list and an iterator
61+
implementation. Lists are easier to work with, but generators are better suited to handle
62+
large volumes of messages without overflowing the available memory.
63+
"""
64+
65+
def __init__(
66+
self,
67+
messages: list[str] | None = None,
68+
uncaught_exception: Optional[BaseException] = None,
69+
*,
70+
message_file: Path | None = None,
71+
) -> None:
72+
if messages is None and message_file is None:
73+
raise ValueError("Either messages or message_file must be provided")
74+
if messages is not None and message_file is not None:
75+
raise ValueError("Only one of messages or message_file can be provided")
76+
77+
self._messages: list[AirbyteMessage] | None = []
78+
self._message_file: Path | None = message_file
79+
if messages:
80+
try:
81+
self._messages = [self._parse_message(message) for message in messages]
82+
except V2ValidationError as exception:
83+
raise ValueError("All messages are expected to be AirbyteMessage") from exception
5684

5785
if uncaught_exception:
86+
if self._messages is None:
87+
self._messages = []
88+
5889
self._messages.append(
5990
assemble_uncaught_exception(
6091
type(uncaught_exception), uncaught_exception
@@ -72,13 +103,40 @@ def _parse_message(message: str) -> AirbyteMessage:
72103
)
73104

74105
@property
75-
def records_and_state_messages(self) -> List[AirbyteMessage]:
76-
return self._get_message_by_types([Type.RECORD, Type.STATE])
106+
def records_and_state_messages(
107+
self,
108+
) -> list[AirbyteMessage]:
109+
return self._get_message_by_types(
110+
message_types=[Type.RECORD, Type.STATE],
111+
safe_iterator=False,
112+
)
113+
114+
def records_and_state_messages_iterator(
115+
self,
116+
) -> Generator[AirbyteMessage, None, None]:
117+
"""Returns a generator that yields record and state messages one by one.
118+
119+
Use this instead of `records_and_state_messages` when the volume of messages could be large
120+
enough to overload available memory.
121+
"""
122+
return self._get_message_by_types(
123+
message_types=[Type.RECORD, Type.STATE],
124+
safe_iterator=True,
125+
)
77126

78127
@property
79128
def records(self) -> List[AirbyteMessage]:
80129
return self._get_message_by_types([Type.RECORD])
81130

131+
@property
132+
def records_iterator(self) -> Generator[AirbyteMessage, None, None]:
133+
"""Returns a generator that yields record messages one by one.
134+
135+
Use this instead of `records` when the volume of records could be large
136+
enough to overload available memory.
137+
"""
138+
return self._get_message_by_types([Type.RECORD], safe_iterator=True)
139+
82140
@property
83141
def state_messages(self) -> List[AirbyteMessage]:
84142
return self._get_message_by_types([Type.STATE])
@@ -92,11 +150,21 @@ def connection_status_messages(self) -> List[AirbyteMessage]:
92150
return self._get_message_by_types([Type.CONNECTION_STATUS])
93151

94152
@property
95-
def most_recent_state(self) -> Any:
96-
state_messages = self._get_message_by_types([Type.STATE])
97-
if not state_messages:
98-
raise ValueError("Can't provide most recent state as there are no state messages")
99-
return state_messages[-1].state.stream # type: ignore[union-attr] # state has `stream`
153+
def most_recent_state(self) -> AirbyteStreamState | None:
154+
state_message_iterator = self._get_message_by_types(
155+
[Type.STATE],
156+
safe_iterator=True,
157+
)
158+
# Use a deque with maxlen=1 to efficiently get the last state message
159+
double_ended_queue = deque(state_message_iterator, maxlen=1)
160+
try:
161+
final_state_message: AirbyteMessage = double_ended_queue.pop()
162+
except IndexError:
163+
raise ValueError(
164+
"Can't provide most recent state as there are no state messages."
165+
) from None
166+
167+
return final_state_message.state.stream # type: ignore[union-attr] # state has `stream`
100168

101169
@property
102170
def logs(self) -> List[AirbyteMessage]:
@@ -131,13 +199,80 @@ def get_stream_statuses(self, stream_name: str) -> List[AirbyteStreamStatus]:
131199
)
132200
return list(status_messages)
133201

134-
def _get_message_by_types(self, message_types: List[Type]) -> List[AirbyteMessage]:
135-
return [message for message in self._messages if message.type in message_types]
202+
def _read_all_messages(self) -> Generator[AirbyteMessage, None, None]:
203+
"""Creates a generator which yields messages one by one.
204+
205+
This will iterate over all messages in the output file (if provided) or the messages
206+
provided during initialization. File results are provided first, followed by any
207+
messages that were passed in directly.
208+
"""
209+
if self._message_file:
210+
try:
211+
with open(self._message_file, "r", encoding="utf-8") as file:
212+
for line in file:
213+
if not line.strip():
214+
# Skip empty lines
215+
continue
216+
217+
yield self._parse_message(line.strip())
218+
except FileNotFoundError:
219+
raise ValueError(f"Message file {self._message_file} not found")
220+
221+
if self._messages is not None:
222+
yield from self._messages
223+
224+
# Overloads to provide proper type hints for different usages of `_get_message_by_types`.
225+
226+
@overload
227+
def _get_message_by_types(
228+
self,
229+
message_types: list[Type],
230+
) -> list[AirbyteMessage]: ...
231+
232+
@overload
233+
def _get_message_by_types(
234+
self,
235+
message_types: list[Type],
236+
*,
237+
safe_iterator: Literal[False],
238+
) -> list[AirbyteMessage]: ...
239+
240+
@overload
241+
def _get_message_by_types(
242+
self,
243+
message_types: list[Type],
244+
*,
245+
safe_iterator: Literal[True],
246+
) -> Generator[AirbyteMessage, None, None]: ...
247+
248+
def _get_message_by_types(
249+
self,
250+
message_types: list[Type],
251+
*,
252+
safe_iterator: bool = True,
253+
) -> list[AirbyteMessage] | Generator[AirbyteMessage, None, None]:
254+
"""Get messages of specific types.
255+
256+
If `safe_iterator` is True, returns a generator that yields messages one by one.
257+
If `safe_iterator` is False, returns a list of messages.
258+
259+
Use `safe_iterator=True` when the volume of messages could overload the available
260+
memory.
261+
"""
262+
message_generator = self._read_all_messages()
263+
264+
if safe_iterator:
265+
return (message for message in message_generator if message.type in message_types)
266+
267+
return [message for message in message_generator if message.type in message_types]
136268

137269
def _get_trace_message_by_trace_type(self, trace_type: TraceType) -> List[AirbyteMessage]:
138270
return [
139271
message
140-
for message in self._get_message_by_types([Type.TRACE])
272+
for message in self._get_message_by_types(
273+
[Type.TRACE],
274+
safe_iterator=True,
275+
)
141276
if message.trace.type == trace_type # type: ignore[union-attr] # trace has `type`
142277
]
143278

0 commit comments

Comments
 (0)