19
19
import re
20
20
import tempfile
21
21
import traceback
22
+ from collections import deque
23
+ from collections .abc import Generator , Mapping
22
24
from io import StringIO
23
25
from pathlib import Path
24
- from typing import Any , List , Mapping , Optional , Union
26
+ from typing import Any , List , Literal , Optional , Union , final , overload
25
27
26
28
import orjson
27
29
from pydantic import ValidationError as V2ValidationError
43
45
TraceType ,
44
46
Type ,
45
47
)
48
+ from airbyte_cdk .models .airbyte_protocol import AirbyteMessage , AirbyteStreamState
46
49
from airbyte_cdk .sources import Source
47
50
from airbyte_cdk .test .models .scenario import ExpectedOutcome
48
51
49
52
50
53
class EntrypointOutput :
51
- def __init__ (self , messages : List [str ], uncaught_exception : Optional [BaseException ] = None ):
52
- try :
53
- self ._messages = [self ._parse_message (message ) for message in messages ]
54
- except V2ValidationError as exception :
55
- raise ValueError ("All messages are expected to be AirbyteMessage" ) from exception
54
+ """A class to encapsulate the output of an Airbyte connector's execution.
55
+
56
+ This class can be initialized with a list of messages or a file containing messages.
57
+ It provides methods to access different types of messages produced during the execution
58
+ of an Airbyte connector, including both successful messages and error messages.
59
+
60
+ When working with records and state messages, it provides both a list and an iterator
61
+ implementation. Lists are easier to work with, but generators are better suited to handle
62
+ large volumes of messages without overflowing the available memory.
63
+ """
64
+
65
+ def __init__ (
66
+ self ,
67
+ messages : list [str ] | None = None ,
68
+ uncaught_exception : Optional [BaseException ] = None ,
69
+ * ,
70
+ message_file : Path | None = None ,
71
+ ) -> None :
72
+ if messages is None and message_file is None :
73
+ raise ValueError ("Either messages or message_file must be provided" )
74
+ if messages is not None and message_file is not None :
75
+ raise ValueError ("Only one of messages or message_file can be provided" )
76
+
77
+ self ._messages : list [AirbyteMessage ] | None = []
78
+ self ._message_file : Path | None = message_file
79
+ if messages :
80
+ try :
81
+ self ._messages = [self ._parse_message (message ) for message in messages ]
82
+ except V2ValidationError as exception :
83
+ raise ValueError ("All messages are expected to be AirbyteMessage" ) from exception
56
84
57
85
if uncaught_exception :
86
+ if self ._messages is None :
87
+ self ._messages = []
88
+
58
89
self ._messages .append (
59
90
assemble_uncaught_exception (
60
91
type (uncaught_exception ), uncaught_exception
@@ -72,13 +103,40 @@ def _parse_message(message: str) -> AirbyteMessage:
72
103
)
73
104
74
105
@property
def records_and_state_messages(self) -> list[AirbyteMessage]:
    """Return all RECORD and STATE messages as a list.

    The messages are collected eagerly (``safe_iterator=False``), so the
    result can be indexed, measured with ``len()`` and iterated repeatedly.
    """
    wanted_types = [Type.RECORD, Type.STATE]
    return self._get_message_by_types(message_types=wanted_types, safe_iterator=False)
113
+
114
def records_and_state_messages_iterator(self) -> Generator[AirbyteMessage, None, None]:
    """Lazily yield RECORD and STATE messages one at a time.

    Prefer this over `records_and_state_messages` when the number of messages
    may be too large to hold in memory at once.

    Note that this is a regular method (call it with parentheses), not a
    property.
    """
    wanted_types = [Type.RECORD, Type.STATE]
    return self._get_message_by_types(message_types=wanted_types, safe_iterator=True)
77
126
78
127
@property
def records(self) -> List[AirbyteMessage]:
    """Return every RECORD message as a fully materialized list.

    ``safe_iterator=False`` is passed explicitly so that the result is a real
    list (supporting ``len()``, indexing and repeated iteration) regardless of
    the default that `_get_message_by_types` uses for that flag.
    """
    return self._get_message_by_types([Type.RECORD], safe_iterator=False)
81
130
131
@property
def records_iterator(self) -> Generator[AirbyteMessage, None, None]:
    """Lazily yield RECORD messages one at a time.

    Prefer this over `records` when the number of records may be too large
    to hold in memory at once.
    """
    record_types = [Type.RECORD]
    return self._get_message_by_types(record_types, safe_iterator=True)
139
+
82
140
@property
def state_messages(self) -> List[AirbyteMessage]:
    """Return every STATE message as a fully materialized list.

    ``safe_iterator=False`` is passed explicitly so that the result is a real
    list (supporting ``len()``, indexing and repeated iteration) regardless of
    the default that `_get_message_by_types` uses for that flag.
    """
    return self._get_message_by_types([Type.STATE], safe_iterator=False)
@@ -92,11 +150,21 @@ def connection_status_messages(self) -> List[AirbyteMessage]:
92
150
return self ._get_message_by_types ([Type .CONNECTION_STATUS ])
93
151
94
152
@property
def most_recent_state(self) -> AirbyteStreamState | None:
    """Return the ``state.stream`` value of the last STATE message.

    Raises:
        ValueError: If the output contains no STATE messages.
    """
    # A deque bounded to one element consumes the whole stream while only
    # ever retaining the most recent message, keeping memory use constant.
    last_only = deque(
        self._get_message_by_types([Type.STATE], safe_iterator=True),
        maxlen=1,
    )
    if not last_only:
        raise ValueError(
            "Can't provide most recent state as there are no state messages."
        ) from None

    newest_state_message: AirbyteMessage = last_only[-1]
    return newest_state_message.state.stream  # type: ignore[union-attr]  # state has `stream`
100
168
101
169
@property
102
170
def logs (self ) -> List [AirbyteMessage ]:
@@ -131,13 +199,80 @@ def get_stream_statuses(self, stream_name: str) -> List[AirbyteStreamStatus]:
131
199
)
132
200
return list (status_messages )
133
201
134
- def _get_message_by_types (self , message_types : List [Type ]) -> List [AirbyteMessage ]:
135
- return [message for message in self ._messages if message .type in message_types ]
202
+ def _read_all_messages (self ) -> Generator [AirbyteMessage , None , None ]:
203
+ """Creates a generator which yields messages one by one.
204
+
205
+ This will iterate over all messages in the output file (if provided) or the messages
206
+ provided during initialization. File results are provided first, followed by any
207
+ messages that were passed in directly.
208
+ """
209
+ if self ._message_file :
210
+ try :
211
+ with open (self ._message_file , "r" , encoding = "utf-8" ) as file :
212
+ for line in file :
213
+ if not line .strip ():
214
+ # Skip empty lines
215
+ continue
216
+
217
+ yield self ._parse_message (line .strip ())
218
+ except FileNotFoundError :
219
+ raise ValueError (f"Message file { self ._message_file } not found" )
220
+
221
+ if self ._messages is not None :
222
+ yield from self ._messages
223
+
224
+ # Overloads to provide proper type hints for different usages of `_get_message_by_types`.
225
+
226
+ @overload
227
+ def _get_message_by_types (
228
+ self ,
229
+ message_types : list [Type ],
230
+ ) -> list [AirbyteMessage ]: ...
231
+
232
+ @overload
233
+ def _get_message_by_types (
234
+ self ,
235
+ message_types : list [Type ],
236
+ * ,
237
+ safe_iterator : Literal [False ],
238
+ ) -> list [AirbyteMessage ]: ...
239
+
240
+ @overload
241
+ def _get_message_by_types (
242
+ self ,
243
+ message_types : list [Type ],
244
+ * ,
245
+ safe_iterator : Literal [True ],
246
+ ) -> Generator [AirbyteMessage , None , None ]: ...
247
+
248
+ def _get_message_by_types (
249
+ self ,
250
+ message_types : list [Type ],
251
+ * ,
252
+ safe_iterator : bool = True ,
253
+ ) -> list [AirbyteMessage ] | Generator [AirbyteMessage , None , None ]:
254
+ """Get messages of specific types.
255
+
256
+ If `safe_iterator` is True, returns a generator that yields messages one by one.
257
+ If `safe_iterator` is False, returns a list of messages.
258
+
259
+ Use `safe_iterator=True` when the volume of messages could overload the available
260
+ memory.
261
+ """
262
+ message_generator = self ._read_all_messages ()
263
+
264
+ if safe_iterator :
265
+ return (message for message in message_generator if message .type in message_types )
266
+
267
+ return [message for message in message_generator if message .type in message_types ]
136
268
137
269
def _get_trace_message_by_trace_type(self, trace_type: TraceType) -> List[AirbyteMessage]:
    """Return the TRACE messages whose ``trace.type`` equals ``trace_type``.

    TRACE messages are streamed lazily from `_get_message_by_types`; only the
    matching ones are collected into the returned list.
    """
    matching: List[AirbyteMessage] = []
    for candidate in self._get_message_by_types([Type.TRACE], safe_iterator=True):
        if candidate.trace.type == trace_type:  # type: ignore[union-attr]  # trace has `type`
            matching.append(candidate)
    return matching
143
278
0 commit comments