Commit abc8978

fix(extra_fields): Fix issue with substream partition router picking extra fields (#727)
Parent: eca065f

4 files changed: +165 −83 lines changed

airbyte_cdk/sources/declarative/partition_routers/cartesian_product_stream_slicer.py

Lines changed: 4 additions & 1 deletion

@@ -149,6 +149,7 @@ def stream_slices(self) -> Iterable[StreamSlice]:
         for stream_slice_tuple in product:
             partition = dict(ChainMap(*[s.partition for s in stream_slice_tuple])) # type: ignore # ChainMap expects a MutableMapping[Never, Never] for reasons
             cursor_slices = [s.cursor_slice for s in stream_slice_tuple if s.cursor_slice]
+            extra_fields = dict(ChainMap(*[s.extra_fields for s in stream_slice_tuple])) # type: ignore # ChainMap expects a MutableMapping[Never, Never] for reasons
             if len(cursor_slices) > 1:
                 raise ValueError(
                     f"There should only be a single cursor slice. Found {cursor_slices}"
@@ -157,7 +158,9 @@ def stream_slices(self) -> Iterable[StreamSlice]:
                 cursor_slice = cursor_slices[0]
             else:
                 cursor_slice = {}
-            yield StreamSlice(partition=partition, cursor_slice=cursor_slice)
+            yield StreamSlice(
+                partition=partition, cursor_slice=cursor_slice, extra_fields=extra_fields
+            )
 
     def set_initial_state(self, stream_state: StreamState) -> None:
         """
unit_tests/sources/declarative/partition_routers/helpers.py

Lines changed: 93 additions & 0 deletions

@@ -0,0 +1,93 @@
+#
+# Copyright (c) 2025 Airbyte, Inc., all rights reserved.
+#
+
+from typing import Any, Iterable, List, Mapping, Optional, Union
+
+from airbyte_cdk.models import SyncMode
+from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream
+from airbyte_cdk.sources.declarative.interpolation import InterpolatedString
+from airbyte_cdk.sources.streams.checkpoint import Cursor
+from airbyte_cdk.sources.types import Record, StreamSlice
+
+
+class MockStream(DeclarativeStream):
+    def __init__(self, slices, records, name, cursor_field="", cursor=None):
+        self.config = {}
+        self._slices = slices
+        self._records = records
+        self._stream_cursor_field = (
+            InterpolatedString.create(cursor_field, parameters={})
+            if isinstance(cursor_field, str)
+            else cursor_field
+        )
+        self._name = name
+        self._state = {"states": []}
+        self._cursor = cursor
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    @property
+    def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
+        return "id"
+
+    @property
+    def state(self) -> Mapping[str, Any]:
+        return self._state
+
+    @state.setter
+    def state(self, value: Mapping[str, Any]) -> None:
+        self._state = value
+
+    @property
+    def is_resumable(self) -> bool:
+        return bool(self._cursor)
+
+    def get_cursor(self) -> Optional[Cursor]:
+        return self._cursor
+
+    def stream_slices(
+        self,
+        *,
+        sync_mode: SyncMode,
+        cursor_field: List[str] = None,
+        stream_state: Mapping[str, Any] = None,
+    ) -> Iterable[Optional[StreamSlice]]:
+        for s in self._slices:
+            if isinstance(s, StreamSlice):
+                yield s
+            else:
+                yield StreamSlice(partition=s, cursor_slice={})
+
+    def read_records(
+        self,
+        sync_mode: SyncMode,
+        cursor_field: List[str] = None,
+        stream_slice: Mapping[str, Any] = None,
+        stream_state: Mapping[str, Any] = None,
+    ) -> Iterable[Mapping[str, Any]]:
+        # The parent stream's records should always be read as full refresh
+        assert sync_mode == SyncMode.full_refresh
+
+        if not stream_slice:
+            result = self._records
+        else:
+            result = [
+                Record(data=r, associated_slice=stream_slice, stream_name=self.name)
+                for r in self._records
+                if r["slice"] == stream_slice["slice"]
+            ]
+
+        yield from result
+
+        # Update the state only after reading the full slice
+        cursor_field = self._stream_cursor_field.eval(config=self.config)
+        if stream_slice and cursor_field and result:
+            self._state["states"].append(
+                {cursor_field: result[-1][cursor_field], "partition": stream_slice["slice"]}
+            )
+
+    def get_json_schema(self) -> Mapping[str, Any]:
+        return {}
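
This helper is the MockStream that previously lived inline in test_substream_partition_router.py (see the deletions below), extracted so both partition-router test modules can share it. A rough usage sketch, assuming airbyte_cdk is installed and the unit_tests package is importable (as it is when pytest runs from the repository root):

from airbyte_cdk.models import SyncMode
from unit_tests.sources.declarative.partition_routers.helpers import MockStream

parent = MockStream(
    slices=[{}],
    records=[{"a": {"b": 0}, "extra_field_key": "extra_field_value_0"}],
    name="first_stream",
)

# Each plain-dict entry in `slices` is wrapped into a StreamSlice with an empty cursor_slice.
for s in parent.stream_slices(sync_mode=SyncMode.full_refresh):
    print(s.partition, s.cursor_slice)
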

unit_tests/sources/declarative/partition_routers/test_cartesian_product_partition_router.py

Lines changed: 67 additions & 0 deletions

@@ -11,11 +11,16 @@
     CartesianProductStreamSlicer,
     ListPartitionRouter,
 )
+from airbyte_cdk.sources.declarative.partition_routers.substream_partition_router import (
+    ParentStreamConfig,
+    SubstreamPartitionRouter,
+)
 from airbyte_cdk.sources.declarative.requesters.request_option import (
     RequestOption,
     RequestOptionType,
 )
 from airbyte_cdk.sources.types import StreamSlice
+from unit_tests.sources.declarative.partition_routers.helpers import MockStream
 
 
 @pytest.mark.parametrize(
@@ -171,6 +176,68 @@ def test_substream_slicer(test_name, stream_slicers, expected_slices):
     assert slices == expected_slices
 
 
+@pytest.mark.parametrize(
+    "test_name, stream_slicers, expected_slices",
+    [
+        (
+            "test_single_stream_slicer",
+            [
+                SubstreamPartitionRouter(
+                    parent_stream_configs=[
+                        ParentStreamConfig(
+                            stream=MockStream(
+                                [{}],
+                                [
+                                    {"a": {"b": 0}, "extra_field_key": "extra_field_value_0"},
+                                    {"a": {"b": 1}, "extra_field_key": "extra_field_value_1"},
+                                    {"a": {"c": 2}, "extra_field_key": "extra_field_value_2"},
+                                    {"a": {"b": 3}, "extra_field_key": "extra_field_value_3"},
+                                ],
+                                "first_stream",
+                            ),
+                            parent_key="a/b",
+                            partition_field="first_stream_id",
+                            parameters={},
+                            config={},
+                            extra_fields=[["extra_field_key"]],
+                        )
+                    ],
+                    parameters={},
+                    config={},
+                ),
+            ],
+            [
+                StreamSlice(
+                    partition={"first_stream_id": 0, "parent_slice": {}},
+                    cursor_slice={},
+                    extra_fields={"extra_field_key": "extra_field_value_0"},
+                ),
+                StreamSlice(
+                    partition={"first_stream_id": 1, "parent_slice": {}},
+                    cursor_slice={},
+                    extra_fields={"extra_field_key": "extra_field_value_1"},
+                ),
+                StreamSlice(
+                    partition={"first_stream_id": 3, "parent_slice": {}},
+                    cursor_slice={},
+                    extra_fields={"extra_field_key": "extra_field_value_3"},
+                ),
+            ],
+        )
+    ],
+)
+def test_substream_slicer_with_extra_fields(test_name, stream_slicers, expected_slices):
+    slicer = CartesianProductStreamSlicer(stream_slicers=stream_slicers, parameters={})
+    slices = [s for s in slicer.stream_slices()]
+    partitions = [s.partition for s in slices]
+    expected_partitions = [s.partition for s in expected_slices]
+    assert partitions == expected_partitions
+
+    extra_fields = [s.extra_fields for s in slices]
+    expected_extra_fields = [s.extra_fields for s in expected_slices]
+    assert extra_fields == expected_extra_fields
+
+
 def test_stream_slices_raises_exception_if_multiple_cursor_slice_components():
     stream_slicers = [
         DatetimeBasedCursor(
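
The behavioral detail this test pins down is that the extra field values ride along on each emitted StreamSlice without being folded into its partition, as the expected slices above show. A small sketch restating that expectation directly against the StreamSlice type (values taken from the test data; not a new API):

from airbyte_cdk.sources.types import StreamSlice

s = StreamSlice(
    partition={"first_stream_id": 0, "parent_slice": {}},
    cursor_slice={},
    extra_fields={"extra_field_key": "extra_field_value_0"},
)

# The extra field is exposed via extra_fields, not via the partition mapping.
assert s.extra_fields == {"extra_field_key": "extra_field_value_0"}
assert "extra_field_key" not in s.partition
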

unit_tests/sources/declarative/partition_routers/test_substream_partition_router.py

Lines changed: 1 addition & 82 deletions

@@ -34,6 +34,7 @@
 from airbyte_cdk.sources.streams.checkpoint import Cursor
 from airbyte_cdk.sources.types import Record, StreamSlice
 from airbyte_cdk.utils import AirbyteTracedException
+from unit_tests.sources.declarative.partition_routers.helpers import MockStream
 
 parent_records = [{"id": 1, "data": "data1"}, {"id": 2, "data": "data2"}]
 more_records = [
@@ -63,88 +64,6 @@
 )
 
 
-class MockStream(DeclarativeStream):
-    def __init__(self, slices, records, name, cursor_field="", cursor=None):
-        self.config = {}
-        self._slices = slices
-        self._records = records
-        self._stream_cursor_field = (
-            InterpolatedString.create(cursor_field, parameters={})
-            if isinstance(cursor_field, str)
-            else cursor_field
-        )
-        self._name = name
-        self._state = {"states": []}
-        self._cursor = cursor
-
-    @property
-    def name(self) -> str:
-        return self._name
-
-    @property
-    def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
-        return "id"
-
-    @property
-    def state(self) -> Mapping[str, Any]:
-        return self._state
-
-    @state.setter
-    def state(self, value: Mapping[str, Any]) -> None:
-        self._state = value
-
-    @property
-    def is_resumable(self) -> bool:
-        return bool(self._cursor)
-
-    def get_cursor(self) -> Optional[Cursor]:
-        return self._cursor
-
-    def stream_slices(
-        self,
-        *,
-        sync_mode: SyncMode,
-        cursor_field: List[str] = None,
-        stream_state: Mapping[str, Any] = None,
-    ) -> Iterable[Optional[StreamSlice]]:
-        for s in self._slices:
-            if isinstance(s, StreamSlice):
-                yield s
-            else:
-                yield StreamSlice(partition=s, cursor_slice={})
-
-    def read_records(
-        self,
-        sync_mode: SyncMode,
-        cursor_field: List[str] = None,
-        stream_slice: Mapping[str, Any] = None,
-        stream_state: Mapping[str, Any] = None,
-    ) -> Iterable[Mapping[str, Any]]:
-        # The parent stream's records should always be read as full refresh
-        assert sync_mode == SyncMode.full_refresh
-
-        if not stream_slice:
-            result = self._records
-        else:
-            result = [
-                Record(data=r, associated_slice=stream_slice, stream_name=self.name)
-                for r in self._records
-                if r["slice"] == stream_slice["slice"]
-            ]
-
-        yield from result
-
-        # Update the state only after reading the full slice
-        cursor_field = self._stream_cursor_field.eval(config=self.config)
-        if stream_slice and cursor_field and result:
-            self._state["states"].append(
-                {cursor_field: result[-1][cursor_field], "partition": stream_slice["slice"]}
-            )
-
-    def get_json_schema(self) -> Mapping[str, Any]:
-        return {}
-
-
 class MockIncrementalStream(MockStream):
     def __init__(self, slices, records, name, cursor_field="", cursor=None, date_ranges=None):
         super().__init__(slices, records, name, cursor_field, cursor)

0 commit comments