Commit db07eff

fix(cdk): determine state from the connector state manager when no state is passed to the per-partition router (#544)
1 parent 895756d · commit db07eff

3 files changed: +261 −5 lines changed

airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py

Lines changed: 9 additions & 3 deletions
@@ -1484,6 +1484,7 @@ def create_concurrent_cursor_from_perpartition_cursor(
                 stream_state_migrations=stream_state_migrations,
             )
         )
+
         stream_state = self.apply_stream_state_migrations(stream_state_migrations, stream_state)
         # Per-partition state doesn't make sense for GroupingPartitionRouter, so force the global state
         use_global_cursor = isinstance(
@@ -1993,14 +1994,19 @@ def _build_incremental_cursor(
     ) -> Optional[StreamSlicer]:
         if model.incremental_sync and stream_slicer:
            if model.retriever.type == "AsyncRetriever":
+                stream_name = model.name or ""
+                stream_namespace = None
+                stream_state = self._connector_state_manager.get_stream_state(
+                    stream_name, stream_namespace
+                )
                 return self.create_concurrent_cursor_from_perpartition_cursor(  # type: ignore # This is a known issue that we are creating and returning a ConcurrentCursor which does not technically implement the (low-code) StreamSlicer. However, (low-code) StreamSlicer and ConcurrentCursor both implement StreamSlicer.stream_slices() which is the primary method needed for checkpointing
                     state_manager=self._connector_state_manager,
                     model_type=DatetimeBasedCursorModel,
                     component_definition=model.incremental_sync.__dict__,
-                    stream_name=model.name or "",
-                    stream_namespace=None,
+                    stream_name=stream_name,
+                    stream_namespace=stream_namespace,
                     config=config or {},
-                    stream_state={},
+                    stream_state=stream_state,
                     partition_router=stream_slicer,
                 )
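In short, `_build_incremental_cursor` previously seeded the per-partition concurrent cursor with a hard-coded `stream_state={}` for `AsyncRetriever` streams, so any state handed to the connector was silently dropped; it now asks the `ConnectorStateManager` for the stream's state first. A minimal sketch of the lookup the fixed path performs — the `lists` stream name and the state blob here are illustrative, borrowed from the test added later in this commit:

```python
from airbyte_cdk.models import (
    AirbyteStateBlob,
    AirbyteStateMessage,
    AirbyteStateType,
    AirbyteStreamState,
    StreamDescriptor,
)
from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager

# State as the platform would hand it to the connector on a resumed sync.
manager = ConnectorStateManager(
    state=[
        AirbyteStateMessage(
            type=AirbyteStateType.STREAM,
            stream=AirbyteStreamState(
                stream_descriptor=StreamDescriptor(name="lists"),
                stream_state=AirbyteStateBlob({"use_global_cursor": False, "states": []}),
            ),
        )
    ]
)

# The fixed code path resolves the stream's state from the manager
# instead of passing an empty dict into the per-partition cursor.
stream_state = manager.get_stream_state("lists", None)
assert stream_state == {"use_global_cursor": False, "states": []}
```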

unit_tests/sources/declarative/parsers/resources/stream_with_incremental_and_aync_retriever_with_partition_router.yaml

Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
decoder:
  type: JsonDecoder
extractor:
  type: DpathExtractor
selector:
  type: RecordSelector
  record_filter:
    type: RecordFilter
    condition: "{{ record['id'] > stream_state['id'] }}"
requester:
  type: HttpRequester
  name: "{{ parameters['name'] }}"
  url_base: "https://api.sendgrid.com/v3/"
  http_method: "GET"
  authenticator:
    type: SessionTokenAuthenticator
    decoder:
      type: JsonDecoder
    expiration_duration: P10D
    login_requester:
      path: /session
      type: HttpRequester
      url_base: 'https://api.sendgrid.com'
      http_method: POST
      request_body_json:
        password: '{{ config.apikey }}'
        username: '{{ parameters.name }}'
    session_token_path:
      - id
    request_authentication:
      type: ApiKey
      inject_into:
        type: RequestOption
        field_name: X-Metabase-Session
        inject_into: header
  request_parameters:
    unit: "day"
list_stream:
  type: DeclarativeStream
  name: lists
  schema_loader:
    type: JsonFileSchemaLoader
    file_path: "./source_sendgrid/schemas/{{ parameters.name }}.json"
  incremental_sync:
    type: DatetimeBasedCursor
    $parameters:
      datetime_format: "%Y-%m-%dT%H:%M:%S%z"
    start_datetime:
      type: MinMaxDatetime
      datetime: "{{ config['reports_start_date'] }}"
      datetime_format: "%Y-%m-%d"
    end_datetime:
      type: MinMaxDatetime
      datetime: "{{ format_datetime(now_utc(), '%Y-%m-%d') }}"
      datetime_format: "%Y-%m-%d"
    cursor_field: TimePeriod
    cursor_datetime_formats:
      - "%Y-%m-%dT%H:%M:%S%z"
  retriever:
    type: AsyncRetriever
    name: "{{ parameters['name'] }}"
    decoder:
      $ref: "#/decoder"
    partition_router:
      type: ListPartitionRouter
      values: "{{config['repos']}}"
      cursor_field: a_key
      request_option:
        inject_into: header
        field_name: a_key
    status_mapping:
      failed:
        - Error
      running:
        - Pending
      completed:
        - Success
      timeout: []
    status_extractor:
      type: DpathExtractor
      field_path:
        - ReportRequestStatus
        - Status
    download_target_extractor:
      type: DpathExtractor
      field_path:
        - ReportRequestStatus
        - ReportDownloadUrl
    creation_requester:
      type: HttpRequester
      url_base: https://reporting.api.bingads.microsoft.com/
      path: Reporting/v13/GenerateReport/Submit
      http_method: POST
      request_headers:
        Content-Type: application/json
        DeveloperToken: "{{ config['developer_token'] }}"
        CustomerId: "'{{ stream_partition['customer_id'] }}'"
        CustomerAccountId: "'{{ stream_partition['account_id'] }}'"
      request_body_json:
        ReportRequest:
          ExcludeColumnHeaders: false
    polling_requester:
      type: HttpRequester
      url_base: https://fakerporting.api.bingads.microsoft.com/
      path: Reporting/v13/GenerateReport/Poll
      http_method: POST
      request_headers:
        Content-Type: application/json
        DeveloperToken: "{{ config['developer_token'] }}"
        CustomerId: "'{{ stream_partition['customer_id'] }}'"
        CustomerAccountId: "'{{ stream_partition['account_id'] }}'"
      request_body_json:
        ReportRequestId: "'{{ creation_response['ReportRequestId'] }}'"
    download_requester:
      type: HttpRequester
      url_base: "{{download_target}}"
      http_method: GET
    paginator:
      type: DefaultPaginator
      page_size_option:
        inject_into: request_parameter
        field_name: page_size
      page_token_option:
        inject_into: path
        type: RequestPath
      pagination_strategy:
        type: "CursorPagination"
        cursor_value: "{{ response._metadata.next }}"
        page_size: 10
    requester:
      $ref: "#/requester"
      path: "{{ next_page_token['next_page_url'] }}"
    record_selector:
      $ref: "#/selector"
  $parameters:
    name: "lists"
    primary_key: "id"
    extractor:
      $ref: "#/extractor"
      field_path: ["{{ parameters['name'] }}"]
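For context, a sketch of how the new test consumes this resource: it reads the file, parses it with `YamlDeclarativeSource._parse`, resolves the `$ref` pointers, and propagates `$parameters` into the `list_stream` definition. The wiring mirrors the test below; the resource path (relative to the repository root) is an assumption:

```python
from pathlib import Path

from airbyte_cdk.sources.declarative.parsers.manifest_component_transformer import (
    ManifestComponentTransformer,
)
from airbyte_cdk.sources.declarative.parsers.manifest_reference_resolver import (
    ManifestReferenceResolver,
)
from airbyte_cdk.sources.declarative.yaml_declarative_source import YamlDeclarativeSource

# Assumed location of the resource, relative to the repository root.
resource = Path(
    "unit_tests/sources/declarative/parsers/resources/"
    "stream_with_incremental_and_aync_retriever_with_partition_router.yaml"
)
content = resource.read_text()

parsed = YamlDeclarativeSource._parse(content)  # YAML -> component dict
resolved = ManifestReferenceResolver().preprocess_manifest(parsed)  # resolve "$ref" pointers
stream_manifest = ManifestComponentTransformer().propagate_types_and_parameters(
    "", resolved["list_stream"], {}  # push $parameters down the component tree
)
```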

unit_tests/sources/declarative/parsers/test_model_to_component_factory.py

Lines changed: 112 additions & 2 deletions
@@ -1,14 +1,17 @@
 #
 # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
 #
+from copy import deepcopy
 
 # mypy: ignore-errors
 from datetime import datetime, timedelta, timezone
-from typing import Any, Iterable, Mapping
+from pathlib import Path
+from typing import Any, Iterable, Mapping, Optional, Union
 
 import freezegun
 import pytest
 import requests
+from freezegun.api import FakeDatetime
 from pydantic.v1 import ValidationError
 
 from airbyte_cdk import AirbyteTracedException
@@ -42,6 +45,7 @@
     ClientSideIncrementalRecordFilterDecorator,
 )
 from airbyte_cdk.sources.declarative.incremental import (
+    ConcurrentPerPartitionCursor,
     CursorFactory,
     DatetimeBasedCursor,
     PerPartitionCursor,
@@ -166,7 +170,7 @@
     MonthClampingStrategy,
     WeekClampingStrategy,
 )
-from airbyte_cdk.sources.streams.concurrent.cursor import ConcurrentCursor
+from airbyte_cdk.sources.streams.concurrent.cursor import ConcurrentCursor, CursorField
 from airbyte_cdk.sources.streams.concurrent.state_converters.datetime_stream_state_converter import (
     CustomFormatConcurrentStreamStateConverter,
 )
@@ -190,6 +194,21 @@
 input_config = {"apikey": "verysecrettoken", "repos": ["airbyte", "airbyte-cloud"]}
 
 
+def get_factory_with_parameters(
+    connector_state_manager: Optional[ConnectorStateManager] = None,
+) -> ModelToComponentFactory:
+    return ModelToComponentFactory(
+        connector_state_manager=connector_state_manager,
+    )
+
+
+def read_yaml_file(resource_path: Union[str, Path]) -> str:
+    yaml_path = Path(__file__).parent / resource_path
+    with open(yaml_path, "r") as file:
+        content = file.read()
+    return content
+
+
 def test_create_check_stream():
     manifest = {"check": {"type": "CheckStream", "stream_names": ["list_stream"]}}
 
@@ -925,6 +944,97 @@ def test_stream_with_incremental_and_retriever_with_partition_router():
     assert list_stream_slicer._cursor_field.string == "a_key"
 
 
+@freezegun.freeze_time("2025-05-14")
+def test_stream_with_incremental_and_async_retriever_with_partition_router():
+    content = read_yaml_file(
+        "resources/stream_with_incremental_and_aync_retriever_with_partition_router.yaml"
+    )
+    parsed_manifest = YamlDeclarativeSource._parse(content)
+    resolved_manifest = resolver.preprocess_manifest(parsed_manifest)
+    stream_manifest = transformer.propagate_types_and_parameters(
+        "", resolved_manifest["list_stream"], {}
+    )
+    cursor_time_period_value = "2025-05-06T12:00:00+0000"
+    cursor_field_key = "TimePeriod"
+    parent_user_id = "102023653"
+    per_partition_key = {
+        "account_id": 999999999,
+        "parent_slice": {"parent_slice": {}, "user_id": parent_user_id},
+    }
+    stream_state = {
+        "use_global_cursor": False,
+        "states": [
+            {"partition": per_partition_key, "cursor": {cursor_field_key: cursor_time_period_value}}
+        ],
+        "state": {cursor_field_key: "2025-05-12T12:00:00+0000"},
+        "lookback_window": 46,
+    }
+    connector_state_manager = ConnectorStateManager(
+        state=[
+            AirbyteStateMessage(
+                type=AirbyteStateType.STREAM,
+                stream=AirbyteStreamState(
+                    stream_descriptor=StreamDescriptor(name="lists"),
+                    stream_state=AirbyteStateBlob(stream_state),
+                ),
+            )
+        ]
+    )
+
+    factory_with_parameters = get_factory_with_parameters(
+        connector_state_manager=connector_state_manager
+    )
+    connector_config = deepcopy(input_config)
+    connector_config["reports_start_date"] = "2025-01-01"
+    stream = factory_with_parameters.create_component(
+        model_type=DeclarativeStreamModel,
+        component_definition=stream_manifest,
+        config=connector_config,
+    )
+
+    assert isinstance(stream, DeclarativeStream)
+    assert isinstance(stream.retriever, AsyncRetriever)
+    stream_slicer = stream.retriever.stream_slicer.stream_slicer
+    assert isinstance(stream_slicer, ConcurrentPerPartitionCursor)
+    assert stream_slicer.state == stream_state
+    import json
+
+    cursor_perpartition = stream_slicer._cursor_per_partition
+    expected_cursor_perpartition_key = json.dumps(per_partition_key, sort_keys=True).replace(
+        " ", ""
+    )
+    assert (
+        cursor_perpartition[expected_cursor_perpartition_key].cursor_field.cursor_field_key
+        == cursor_field_key
+    )
+    assert cursor_perpartition[expected_cursor_perpartition_key].start == datetime(
+        2025, 5, 6, 12, 0, tzinfo=timezone.utc
+    )
+    assert (
+        cursor_perpartition[expected_cursor_perpartition_key].state[cursor_field_key]
+        == cursor_time_period_value
+    )
+
+    concurrent_cursor = cursor_perpartition[expected_cursor_perpartition_key]
+    assert concurrent_cursor._concurrent_state == {
+        "legacy": {cursor_field_key: cursor_time_period_value},
+        "slices": [
+            {
+                "end": FakeDatetime(2025, 5, 6, 12, 0, tzinfo=timezone.utc),
+                "most_recent_cursor_value": FakeDatetime(2025, 5, 6, 12, 0, tzinfo=timezone.utc),
+                "start": FakeDatetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc),
+            }
+        ],
+        "state_type": "date-range",
+    }
+
+    stream_slices = list(concurrent_cursor.stream_slices())
+    expected_stream_slices = [
+        {"start_time": cursor_time_period_value, "end_time": "2025-05-14T00:00:00+0000"}
+    ]
+    assert stream_slices == expected_stream_slices
+
+
 def test_resumable_full_refresh_stream():
     content = """
 decoder:
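A note on the frozen clock: `@freezegun.freeze_time("2025-05-14")` makes the manifest's `now_utc()`-based `end_datetime` deterministic, and datetimes created under the freeze are `FakeDatetime` instances, which is why the expected concurrent state above compares against `FakeDatetime`. A minimal illustration:

```python
from datetime import datetime, timezone

import freezegun
from freezegun.api import FakeDatetime

with freezegun.freeze_time("2025-05-14"):
    now = datetime.now(timezone.utc)
    # Under the freeze, "now" is both deterministic and a FakeDatetime,
    # which still compares equal to a plain datetime with the same value.
    assert isinstance(now, FakeDatetime)
    assert now == datetime(2025, 5, 14, tzinfo=timezone.utc)
```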
