Skip to content

Commit 4514b88

Browse files
committed
make public: get_message_iterator() and get_message_by_types()
1 parent 1d9fc99 commit 4514b88

File tree

3 files changed

+26
-31
lines changed

3 files changed

+26
-31
lines changed

airbyte_cdk/test/entrypoint_wrapper.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def __init__(
7474
if messages is not None and message_file is not None:
7575
raise ValueError("Only one of messages or message_file can be provided")
7676

77-
self._messages: list[AirbyteMessage] | None = []
77+
self._messages: list[AirbyteMessage] | None = None
7878
self._message_file: Path | None = message_file
7979
if messages:
8080
try:
@@ -106,7 +106,7 @@ def _parse_message(message: str) -> AirbyteMessage:
106106
def records_and_state_messages(
107107
self,
108108
) -> list[AirbyteMessage]:
109-
return self._get_message_by_types(
109+
return self.get_message_by_types(
110110
message_types=[Type.RECORD, Type.STATE],
111111
safe_iterator=False,
112112
)
@@ -119,14 +119,14 @@ def records_and_state_messages_iterator(
119119
Use this instead of `records_and_state_messages` when the volume of messages could be large
120120
enough to overload available memory.
121121
"""
122-
return self._get_message_by_types(
122+
return self.get_message_by_types(
123123
message_types=[Type.RECORD, Type.STATE],
124124
safe_iterator=True,
125125
)
126126

127127
@property
128128
def records(self) -> List[AirbyteMessage]:
129-
return self._get_message_by_types([Type.RECORD])
129+
return self.get_message_by_types([Type.RECORD])
130130

131131
@property
132132
def records_iterator(self) -> Generator[AirbyteMessage, None, None]:
@@ -135,23 +135,23 @@ def records_iterator(self) -> Generator[AirbyteMessage, None, None]:
135135
Use this instead of `records` when the volume of records could be large
136136
enough to overload available memory.
137137
"""
138-
return self._get_message_by_types([Type.RECORD], safe_iterator=True)
138+
return self.get_message_by_types([Type.RECORD], safe_iterator=True)
139139

140140
@property
141141
def state_messages(self) -> List[AirbyteMessage]:
142-
return self._get_message_by_types([Type.STATE])
142+
return self.get_message_by_types([Type.STATE])
143143

144144
@property
145145
def spec_messages(self) -> List[AirbyteMessage]:
146-
return self._get_message_by_types([Type.SPEC])
146+
return self.get_message_by_types([Type.SPEC])
147147

148148
@property
149149
def connection_status_messages(self) -> List[AirbyteMessage]:
150-
return self._get_message_by_types([Type.CONNECTION_STATUS])
150+
return self.get_message_by_types([Type.CONNECTION_STATUS])
151151

152152
@property
153153
def most_recent_state(self) -> AirbyteStreamState | None:
154-
state_message_iterator = self._get_message_by_types(
154+
state_message_iterator = self.get_message_by_types(
155155
[Type.STATE],
156156
safe_iterator=True,
157157
)
@@ -168,11 +168,11 @@ def most_recent_state(self) -> AirbyteStreamState | None:
168168

169169
@property
170170
def logs(self) -> List[AirbyteMessage]:
171-
return self._get_message_by_types([Type.LOG])
171+
return self.get_message_by_types([Type.LOG])
172172

173173
@property
174174
def trace_messages(self) -> List[AirbyteMessage]:
175-
return self._get_message_by_types([Type.TRACE])
175+
return self.get_message_by_types([Type.TRACE])
176176

177177
@property
178178
def analytics_messages(self) -> List[AirbyteMessage]:
@@ -184,7 +184,7 @@ def errors(self) -> List[AirbyteMessage]:
184184

185185
@property
186186
def catalog(self) -> AirbyteMessage:
187-
catalog = self._get_message_by_types([Type.CATALOG])
187+
catalog = self.get_message_by_types([Type.CATALOG])
188188
if len(catalog) != 1:
189189
raise ValueError(f"Expected exactly one catalog but got {len(catalog)}")
190190
return catalog[0]
@@ -199,7 +199,7 @@ def get_stream_statuses(self, stream_name: str) -> List[AirbyteStreamStatus]:
199199
)
200200
return list(status_messages)
201201

202-
def _read_all_messages(self) -> Generator[AirbyteMessage, None, None]:
202+
def get_message_iterator(self) -> Generator[AirbyteMessage, None, None]:
203203
"""Creates a generator which yields messages one by one.
204204
205205
This will iterate over all messages in the output file (if provided) or the messages
@@ -221,35 +221,35 @@ def _read_all_messages(self) -> Generator[AirbyteMessage, None, None]:
221221
if self._messages is not None:
222222
yield from self._messages
223223

224-
# Overloads to provide proper type hints for different usages of `_get_message_by_types`.
224+
# Overloads to provide proper type hints for different usages of `get_message_by_types`.
225225

226226
@overload
227-
def _get_message_by_types(
227+
def get_message_by_types(
228228
self,
229229
message_types: list[Type],
230230
) -> list[AirbyteMessage]: ...
231231

232232
@overload
233-
def _get_message_by_types(
233+
def get_message_by_types(
234234
self,
235235
message_types: list[Type],
236236
*,
237237
safe_iterator: Literal[False],
238238
) -> list[AirbyteMessage]: ...
239239

240240
@overload
241-
def _get_message_by_types(
241+
def get_message_by_types(
242242
self,
243243
message_types: list[Type],
244244
*,
245245
safe_iterator: Literal[True],
246246
) -> Generator[AirbyteMessage, None, None]: ...
247247

248-
def _get_message_by_types(
248+
def get_message_by_types(
249249
self,
250250
message_types: list[Type],
251251
*,
252-
safe_iterator: bool = True,
252+
safe_iterator: bool = False,
253253
) -> list[AirbyteMessage] | Generator[AirbyteMessage, None, None]:
254254
"""Get messages of specific types.
255255
@@ -259,7 +259,7 @@ def _get_message_by_types(
259259
Use `safe_iterator=True` when the volume of messages could overload the available
260260
memory.
261261
"""
262-
message_generator = self._read_all_messages()
262+
message_generator = self.get_message_iterator()
263263

264264
if safe_iterator:
265265
return (message for message in message_generator if message.type in message_types)
@@ -269,7 +269,7 @@ def _get_message_by_types(
269269
def _get_trace_message_by_trace_type(self, trace_type: TraceType) -> List[AirbyteMessage]:
270270
return [
271271
message
272-
for message in self._get_message_by_types(
272+
for message in self.get_message_by_types(
273273
[Type.TRACE],
274274
safe_iterator=True,
275275
)

airbyte_cdk/test/standard_tests/connector_base.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,7 @@ def test_check(
112112
test_scenario=scenario,
113113
connector_root=self.get_connector_root_dir(),
114114
)
115-
conn_status_messages: list[AirbyteMessage] = [
116-
msg for msg in result._messages if msg.type == Type.CONNECTION_STATUS
117-
] # noqa: SLF001 # Non-public API
118-
assert len(conn_status_messages) == 1, (
119-
f"Expected exactly one CONNECTION_STATUS message. Got: {result._messages}"
115+
assert len(result.connection_status_messages) == 1, (
116+
f"Expected exactly one CONNECTION_STATUS message. "
117+
"Got: {result.connection_status_messages!s}"
120118
)

airbyte_cdk/test/standard_tests/source_base.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,10 @@ def test_check(
4848
test_scenario=scenario,
4949
connector_root=self.get_connector_root_dir(),
5050
)
51-
conn_status_messages: list[AirbyteMessage] = [
52-
msg for msg in result._messages if msg.type == Type.CONNECTION_STATUS
53-
] # noqa: SLF001 # Non-public API
54-
num_status_messages = len(conn_status_messages)
51+
num_status_messages = len(result.connection_status_messages)
5552
assert num_status_messages == 1, (
5653
f"Expected exactly one CONNECTION_STATUS message. Got {num_status_messages}: \n"
57-
+ "\n".join([str(m) for m in result._messages])
54+
+ "\n".join([str(m) for m in result.get_message_iterator()])
5855
)
5956

6057
def test_discover(

0 commit comments

Comments
 (0)