Skip to content

Commit af372ec

Browse files
authored
Merge branch 'main' into patch-1
2 parents bc845fc + d2016c6 commit af372ec

35 files changed

+1756
-134
lines changed

airbyte_cdk/sources/declarative/auth/oauth.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut
3939
token_expiry_date_format str: format of the datetime; provide it if expires_in is returned in datetime instead of seconds
4040
token_expiry_is_time_of_expiration bool: set True it if expires_in is returned as time of expiration instead of the number seconds until expiration
4141
refresh_request_body (Optional[Mapping[str, Any]]): The request body to send in the refresh request
42+
refresh_request_headers (Optional[Mapping[str, Any]]): The request headers to send in the refresh request
4243
grant_type: The grant_type to request for access_token. If set to refresh_token, the refresh_token parameter has to be provided
4344
message_repository (MessageRepository): the message repository used to emit logs on HTTP requests
4445
"""
@@ -56,8 +57,13 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut
5657
token_expiry_is_time_of_expiration: bool = False
5758
access_token_name: Union[InterpolatedString, str] = "access_token"
5859
access_token_value: Optional[Union[InterpolatedString, str]] = None
60+
client_id_name: Union[InterpolatedString, str] = "client_id"
61+
client_secret_name: Union[InterpolatedString, str] = "client_secret"
5962
expires_in_name: Union[InterpolatedString, str] = "expires_in"
63+
refresh_token_name: Union[InterpolatedString, str] = "refresh_token"
6064
refresh_request_body: Optional[Mapping[str, Any]] = None
65+
refresh_request_headers: Optional[Mapping[str, Any]] = None
66+
grant_type_name: Union[InterpolatedString, str] = "grant_type"
6167
grant_type: Union[InterpolatedString, str] = "refresh_token"
6268
message_repository: MessageRepository = NoopMessageRepository()
6369

@@ -69,8 +75,15 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
6975
)
7076
else:
7177
self._token_refresh_endpoint = None
78+
self._client_id_name = InterpolatedString.create(self.client_id_name, parameters=parameters)
7279
self._client_id = InterpolatedString.create(self.client_id, parameters=parameters)
80+
self._client_secret_name = InterpolatedString.create(
81+
self.client_secret_name, parameters=parameters
82+
)
7383
self._client_secret = InterpolatedString.create(self.client_secret, parameters=parameters)
84+
self._refresh_token_name = InterpolatedString.create(
85+
self.refresh_token_name, parameters=parameters
86+
)
7487
if self.refresh_token is not None:
7588
self._refresh_token: Optional[InterpolatedString] = InterpolatedString.create(
7689
self.refresh_token, parameters=parameters
@@ -83,10 +96,16 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
8396
self.expires_in_name = InterpolatedString.create(
8497
self.expires_in_name, parameters=parameters
8598
)
99+
self.grant_type_name = InterpolatedString.create(
100+
self.grant_type_name, parameters=parameters
101+
)
86102
self.grant_type = InterpolatedString.create(self.grant_type, parameters=parameters)
87103
self._refresh_request_body = InterpolatedMapping(
88104
self.refresh_request_body or {}, parameters=parameters
89105
)
106+
self._refresh_request_headers = InterpolatedMapping(
107+
self.refresh_request_headers or {}, parameters=parameters
108+
)
90109
self._token_expiry_date: pendulum.DateTime = (
91110
pendulum.parse(
92111
InterpolatedString.create(self.token_expiry_date, parameters=parameters).eval(
@@ -122,18 +141,27 @@ def get_token_refresh_endpoint(self) -> Optional[str]:
122141
return refresh_token_endpoint
123142
return None
124143

144+
def get_client_id_name(self) -> str:
145+
return self._client_id_name.eval(self.config) # type: ignore # eval returns a string in this context
146+
125147
def get_client_id(self) -> str:
126148
client_id: str = self._client_id.eval(self.config)
127149
if not client_id:
128150
raise ValueError("OAuthAuthenticator was unable to evaluate client_id parameter")
129151
return client_id
130152

153+
def get_client_secret_name(self) -> str:
154+
return self._client_secret_name.eval(self.config) # type: ignore # eval returns a string in this context
155+
131156
def get_client_secret(self) -> str:
132157
client_secret: str = self._client_secret.eval(self.config)
133158
if not client_secret:
134159
raise ValueError("OAuthAuthenticator was unable to evaluate client_secret parameter")
135160
return client_secret
136161

162+
def get_refresh_token_name(self) -> str:
163+
return self._refresh_token_name.eval(self.config) # type: ignore # eval returns a string in this context
164+
137165
def get_refresh_token(self) -> Optional[str]:
138166
return None if self._refresh_token is None else str(self._refresh_token.eval(self.config))
139167

@@ -146,12 +174,18 @@ def get_access_token_name(self) -> str:
146174
def get_expires_in_name(self) -> str:
147175
return self.expires_in_name.eval(self.config) # type: ignore # eval returns a string in this context
148176

177+
def get_grant_type_name(self) -> str:
178+
return self.grant_type_name.eval(self.config) # type: ignore # eval returns a string in this context
179+
149180
def get_grant_type(self) -> str:
150181
return self.grant_type.eval(self.config) # type: ignore # eval returns a string in this context
151182

152183
def get_refresh_request_body(self) -> Mapping[str, Any]:
153184
return self._refresh_request_body.eval(self.config)
154185

186+
def get_refresh_request_headers(self) -> Mapping[str, Any]:
187+
return self._refresh_request_headers.eval(self.config)
188+
155189
def get_token_expiry_date(self) -> pendulum.DateTime:
156190
return self._token_expiry_date # type: ignore # _token_expiry_date is a pendulum.DateTime. It is never None despite what mypy thinks
157191

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,24 @@
11
#
2-
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
2+
# Copyright (c) 2025 Airbyte, Inc., all rights reserved.
33
#
44

5+
from typing import Mapping
6+
7+
from pydantic.v1 import BaseModel
8+
9+
from airbyte_cdk.sources.declarative.checks.check_dynamic_stream import CheckDynamicStream
510
from airbyte_cdk.sources.declarative.checks.check_stream import CheckStream
611
from airbyte_cdk.sources.declarative.checks.connection_checker import ConnectionChecker
12+
from airbyte_cdk.sources.declarative.models import (
13+
CheckDynamicStream as CheckDynamicStreamModel,
14+
)
15+
from airbyte_cdk.sources.declarative.models import (
16+
CheckStream as CheckStreamModel,
17+
)
18+
19+
COMPONENTS_CHECKER_TYPE_MAPPING: Mapping[str, type[BaseModel]] = {
20+
"CheckStream": CheckStreamModel,
21+
"CheckDynamicStream": CheckDynamicStreamModel,
22+
}
723

8-
__all__ = ["CheckStream", "ConnectionChecker"]
24+
__all__ = ["CheckStream", "CheckDynamicStream", "ConnectionChecker"]
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#
2+
# Copyright (c) 2025 Airbyte, Inc., all rights reserved.
3+
#
4+
5+
import logging
6+
import traceback
7+
from dataclasses import InitVar, dataclass
8+
from typing import Any, List, Mapping, Tuple
9+
10+
from airbyte_cdk import AbstractSource
11+
from airbyte_cdk.sources.declarative.checks.connection_checker import ConnectionChecker
12+
from airbyte_cdk.sources.streams.http.availability_strategy import HttpAvailabilityStrategy
13+
14+
15+
@dataclass
16+
class CheckDynamicStream(ConnectionChecker):
17+
"""
18+
Checks the connections by checking availability of one or many dynamic streams
19+
20+
Attributes:
21+
stream_count (int): numbers of streams to check
22+
"""
23+
24+
stream_count: int
25+
parameters: InitVar[Mapping[str, Any]]
26+
27+
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
28+
self._parameters = parameters
29+
30+
def check_connection(
31+
self, source: AbstractSource, logger: logging.Logger, config: Mapping[str, Any]
32+
) -> Tuple[bool, Any]:
33+
streams = source.streams(config=config)
34+
if len(streams) == 0:
35+
return False, f"No streams to connect to from source {source}"
36+
37+
for stream_index in range(min(self.stream_count, len(streams))):
38+
stream = streams[stream_index]
39+
availability_strategy = HttpAvailabilityStrategy()
40+
try:
41+
stream_is_available, reason = availability_strategy.check_availability(
42+
stream, logger
43+
)
44+
if not stream_is_available:
45+
return False, reason
46+
except Exception as error:
47+
logger.error(
48+
f"Encountered an error trying to connect to stream {stream.name}. Error: \n {traceback.format_exc()}"
49+
)
50+
return False, f"Unable to connect to stream {stream.name} - {error}"
51+
return True, None

0 commit comments

Comments
 (0)