Skip to content

Commit 2185bd9

Browse files
feat(low-code): pass refresh headers to oauth (#219)
1 parent 40a9f1e commit 2185bd9

File tree

8 files changed

+187
-4
lines changed

8 files changed

+187
-4
lines changed

airbyte_cdk/sources/declarative/auth/oauth.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut
3939
token_expiry_date_format str: format of the datetime; provide it if expires_in is returned in datetime instead of seconds
4040
token_expiry_is_time_of_expiration bool: set True it if expires_in is returned as time of expiration instead of the number seconds until expiration
4141
refresh_request_body (Optional[Mapping[str, Any]]): The request body to send in the refresh request
42+
refresh_request_headers (Optional[Mapping[str, Any]]): The request headers to send in the refresh request
4243
grant_type: The grant_type to request for access_token. If set to refresh_token, the refresh_token parameter has to be provided
4344
message_repository (MessageRepository): the message repository used to emit logs on HTTP requests
4445
"""
@@ -61,6 +62,7 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut
6162
expires_in_name: Union[InterpolatedString, str] = "expires_in"
6263
refresh_token_name: Union[InterpolatedString, str] = "refresh_token"
6364
refresh_request_body: Optional[Mapping[str, Any]] = None
65+
refresh_request_headers: Optional[Mapping[str, Any]] = None
6466
grant_type_name: Union[InterpolatedString, str] = "grant_type"
6567
grant_type: Union[InterpolatedString, str] = "refresh_token"
6668
message_repository: MessageRepository = NoopMessageRepository()
@@ -101,6 +103,9 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
101103
self._refresh_request_body = InterpolatedMapping(
102104
self.refresh_request_body or {}, parameters=parameters
103105
)
106+
self._refresh_request_headers = InterpolatedMapping(
107+
self.refresh_request_headers or {}, parameters=parameters
108+
)
104109
self._token_expiry_date: pendulum.DateTime = (
105110
pendulum.parse(
106111
InterpolatedString.create(self.token_expiry_date, parameters=parameters).eval(
@@ -178,6 +183,9 @@ def get_grant_type(self) -> str:
178183
def get_refresh_request_body(self) -> Mapping[str, Any]:
179184
return self._refresh_request_body.eval(self.config)
180185

186+
def get_refresh_request_headers(self) -> Mapping[str, Any]:
187+
return self._refresh_request_headers.eval(self.config)
188+
181189
def get_token_expiry_date(self) -> pendulum.DateTime:
182190
return self._token_expiry_date # type: ignore # _token_expiry_date is a pendulum.DateTime. It is never None despite what mypy thinks
183191

airbyte_cdk/sources/declarative/declarative_component_schema.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1139,6 +1139,14 @@ definitions:
11391139
- applicationId: "{{ config['application_id'] }}"
11401140
applicationSecret: "{{ config['application_secret'] }}"
11411141
token: "{{ config['token'] }}"
1142+
refresh_request_headers:
1143+
title: Refresh Request Headers
1144+
description: Headers of the request sent to get a new access token.
1145+
type: object
1146+
additionalProperties: true
1147+
examples:
1148+
- Authorization: "<AUTH_TOKEN>"
1149+
Content-Type: "application/x-www-form-urlencoded"
11421150
scopes:
11431151
title: Scopes
11441152
description: List of scopes that should be granted to the access token.

airbyte_cdk/sources/declarative/models/declarative_component_schema.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,17 @@ class OAuthAuthenticator(BaseModel):
571571
],
572572
title="Refresh Request Body",
573573
)
574+
refresh_request_headers: Optional[Dict[str, Any]] = Field(
575+
None,
576+
description="Headers of the request sent to get a new access token.",
577+
examples=[
578+
{
579+
"Authorization": "<AUTH_TOKEN>",
580+
"Content-Type": "application/x-www-form-urlencoded",
581+
}
582+
],
583+
title="Refresh Request Headers",
584+
)
574585
scopes: Optional[List[str]] = Field(
575586
None,
576587
description="List of scopes that should be granted to the access token.",

airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1919,6 +1919,9 @@ def create_oauth_authenticator(
19191919
refresh_request_body=InterpolatedMapping(
19201920
model.refresh_request_body or {}, parameters=model.parameters or {}
19211921
).eval(config),
1922+
refresh_request_headers=InterpolatedMapping(
1923+
model.refresh_request_headers or {}, parameters=model.parameters or {}
1924+
).eval(config),
19221925
scopes=model.scopes,
19231926
token_expiry_date_format=model.token_expiry_date_format,
19241927
message_repository=self._message_repository,
@@ -1938,6 +1941,7 @@ def create_oauth_authenticator(
19381941
grant_type_name=model.grant_type_name or "grant_type",
19391942
grant_type=model.grant_type or "refresh_token",
19401943
refresh_request_body=model.refresh_request_body,
1944+
refresh_request_headers=model.refresh_request_headers,
19411945
refresh_token_name=model.refresh_token_name or "refresh_token",
19421946
refresh_token=model.refresh_token,
19431947
scopes=model.scopes,

airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,14 @@ def build_refresh_request_body(self) -> Mapping[str, Any]:
9898

9999
return payload
100100

101+
def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
102+
"""
103+
Returns the request headers to set on the refresh request
104+
105+
"""
106+
headers = self.get_refresh_request_headers()
107+
return headers if headers else None
108+
101109
def _wrap_refresh_token_exception(
102110
self, exception: requests.exceptions.RequestException
103111
) -> bool:
@@ -128,6 +136,7 @@ def _get_refresh_access_token_response(self) -> Any:
128136
method="POST",
129137
url=self.get_token_refresh_endpoint(), # type: ignore # returns None, if not provided, but str | bytes is expected.
130138
data=self.build_refresh_request_body(),
139+
headers=self.build_refresh_request_headers(),
131140
)
132141
if response.ok:
133142
response_json = response.json()
@@ -254,6 +263,10 @@ def get_expires_in_name(self) -> str:
254263
def get_refresh_request_body(self) -> Mapping[str, Any]:
255264
"""Returns the request body to set on the refresh request"""
256265

266+
@abstractmethod
267+
def get_refresh_request_headers(self) -> Mapping[str, Any]:
268+
"""Returns the request headers to set on the refresh request"""
269+
257270
@abstractmethod
258271
def get_grant_type(self) -> str:
259272
"""Returns grant_type specified for requesting access_token"""

airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def __init__(
3939
access_token_name: str = "access_token",
4040
expires_in_name: str = "expires_in",
4141
refresh_request_body: Mapping[str, Any] | None = None,
42+
refresh_request_headers: Mapping[str, Any] | None = None,
4243
grant_type_name: str = "grant_type",
4344
grant_type: str = "refresh_token",
4445
token_expiry_is_time_of_expiration: bool = False,
@@ -57,6 +58,7 @@ def __init__(
5758
self._access_token_name = access_token_name
5859
self._expires_in_name = expires_in_name
5960
self._refresh_request_body = refresh_request_body
61+
self._refresh_request_headers = refresh_request_headers
6062
self._grant_type_name = grant_type_name
6163
self._grant_type = grant_type
6264

@@ -101,6 +103,9 @@ def get_expires_in_name(self) -> str:
101103
def get_refresh_request_body(self) -> Mapping[str, Any]:
102104
return self._refresh_request_body # type: ignore [return-value]
103105

106+
def get_refresh_request_headers(self) -> Mapping[str, Any]:
107+
return self._refresh_request_headers # type: ignore [return-value]
108+
104109
def get_grant_type_name(self) -> str:
105110
return self._grant_type_name
106111

@@ -149,6 +154,7 @@ def __init__(
149154
expires_in_name: str = "expires_in",
150155
refresh_token_name: str = "refresh_token",
151156
refresh_request_body: Mapping[str, Any] | None = None,
157+
refresh_request_headers: Mapping[str, Any] | None = None,
152158
grant_type_name: str = "grant_type",
153159
grant_type: str = "refresh_token",
154160
client_id_name: str = "client_id",
@@ -174,6 +180,7 @@ def __init__(
174180
expires_in_name (str, optional): Name of the name of the field that characterizes when the current access token will expire, used to parse the refresh token response. Defaults to "expires_in".
175181
refresh_token_name (str, optional): Name of the name of the refresh token field, used to parse the refresh token response. Defaults to "refresh_token".
176182
refresh_request_body (Mapping[str, Any], optional): Custom key value pair that will be added to the refresh token request body. Defaults to None.
183+
refresh_request_headers (Mapping[str, Any], optional): Custom key value pair that will be added to the refresh token request headers. Defaults to None.
177184
grant_type (str, optional): OAuth grant type. Defaults to "refresh_token".
178185
client_id (Optional[str]): The client id to authenticate. If not specified, defaults to credentials.client_id in the config object.
179186
client_secret (Optional[str]): The client secret to authenticate. If not specified, defaults to credentials.client_secret in the config object.
@@ -220,6 +227,7 @@ def __init__(
220227
access_token_name=access_token_name,
221228
expires_in_name=expires_in_name,
222229
refresh_request_body=refresh_request_body,
230+
refresh_request_headers=refresh_request_headers,
223231
grant_type_name=self._grant_type_name,
224232
grant_type=grant_type,
225233
token_expiry_date_format=token_expiry_date_format,

unit_tests/sources/declarative/auth/test_oauth.py

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,42 @@ def test_refresh_request_body(self):
6969
}
7070
assert body == expected
7171

72+
def test_refresh_request_headers(self):
73+
"""
74+
Request headers should match given configuration.
75+
"""
76+
oauth = DeclarativeOauth2Authenticator(
77+
token_refresh_endpoint="{{ config['refresh_endpoint'] }}",
78+
client_id="{{ config['client_id'] }}",
79+
client_secret="{{ config['client_secret'] }}",
80+
refresh_token="{{ parameters['refresh_token'] }}",
81+
config=config,
82+
token_expiry_date="{{ config['token_expiry_date'] }}",
83+
refresh_request_headers={
84+
"Authorization": "Basic {{ [config['client_id'], config['client_secret']] | join(':') | base64encode }}",
85+
"Content-Type": "application/x-www-form-urlencoded",
86+
},
87+
parameters=parameters,
88+
)
89+
headers = oauth.build_refresh_request_headers()
90+
expected = {
91+
"Authorization": "Basic c29tZV9jbGllbnRfaWQ6c29tZV9jbGllbnRfc2VjcmV0",
92+
"Content-Type": "application/x-www-form-urlencoded",
93+
}
94+
assert headers == expected
95+
96+
oauth = DeclarativeOauth2Authenticator(
97+
token_refresh_endpoint="{{ config['refresh_endpoint'] }}",
98+
client_id="{{ config['client_id'] }}",
99+
client_secret="{{ config['client_secret'] }}",
100+
refresh_token="{{ parameters['refresh_token'] }}",
101+
config=config,
102+
token_expiry_date="{{ config['token_expiry_date'] }}",
103+
parameters=parameters,
104+
)
105+
headers = oauth.build_refresh_request_headers()
106+
assert headers is None
107+
72108
def test_refresh_with_encode_config_params(self):
73109
oauth = DeclarativeOauth2Authenticator(
74110
token_refresh_endpoint="{{ config['refresh_endpoint'] }}",
@@ -191,6 +227,36 @@ def test_refresh_access_token(self, mocker):
191227
filtered = filter_secrets("access_token")
192228
assert filtered == "****"
193229

230+
def test_refresh_access_token_when_headers_provided(self, mocker):
231+
expected_headers = {
232+
"Authorization": "Bearer some_access_token",
233+
"Content-Type": "application/x-www-form-urlencoded",
234+
}
235+
oauth = DeclarativeOauth2Authenticator(
236+
token_refresh_endpoint="{{ config['refresh_endpoint'] }}",
237+
client_id="{{ config['client_id'] }}",
238+
client_secret="{{ config['client_secret'] }}",
239+
refresh_token="{{ config['refresh_token'] }}",
240+
config=config,
241+
scopes=["scope1", "scope2"],
242+
token_expiry_date="{{ config['token_expiry_date'] }}",
243+
refresh_request_headers=expected_headers,
244+
parameters={},
245+
)
246+
247+
resp.status_code = 200
248+
mocker.patch.object(
249+
resp, "json", return_value={"access_token": "access_token", "expires_in": 1000}
250+
)
251+
mocked_request = mocker.patch.object(
252+
requests, "request", side_effect=mock_request, autospec=True
253+
)
254+
token = oauth.refresh_access_token()
255+
256+
assert ("access_token", 1000) == token
257+
258+
assert mocked_request.call_args.kwargs["headers"] == expected_headers
259+
194260
def test_refresh_access_token_missing_access_token(self, mocker):
195261
oauth = DeclarativeOauth2Authenticator(
196262
token_refresh_endpoint="{{ config['refresh_endpoint'] }}",
@@ -371,7 +437,9 @@ def test_error_handling(self, mocker):
371437
assert e.value.errno == 400
372438

373439

374-
def mock_request(method, url, data):
440+
def mock_request(method, url, data, headers):
375441
if url == "refresh_end":
376442
return resp
377-
raise Exception(f"Error while refreshing access token with request: {method}, {url}, {data}")
443+
raise Exception(
444+
f"Error while refreshing access token with request: {method}, {url}, {data}, {headers}"
445+
)

unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,38 @@ def test_refresh_request_body(self):
165165
}
166166
assert body == expected
167167

168+
def test_refresh_request_headers(self):
169+
"""
170+
Request headers should match given configuration.
171+
"""
172+
oauth = Oauth2Authenticator(
173+
token_refresh_endpoint="refresh_end",
174+
client_id="some_client_id",
175+
client_secret="some_client_secret",
176+
refresh_token="some_refresh_token",
177+
token_expiry_date=pendulum.now().add(days=3),
178+
refresh_request_headers={
179+
"Authorization": "Bearer some_refresh_token",
180+
"Content-Type": "application/x-www-form-urlencoded",
181+
},
182+
)
183+
headers = oauth.build_refresh_request_headers()
184+
expected = {
185+
"Authorization": "Bearer some_refresh_token",
186+
"Content-Type": "application/x-www-form-urlencoded",
187+
}
188+
assert headers == expected
189+
190+
oauth = Oauth2Authenticator(
191+
token_refresh_endpoint="refresh_end",
192+
client_id="some_client_id",
193+
client_secret="some_client_secret",
194+
refresh_token="some_refresh_token",
195+
token_expiry_date=pendulum.now().add(days=3),
196+
)
197+
headers = oauth.build_refresh_request_headers()
198+
assert headers is None
199+
168200
def test_refresh_request_body_with_keys_override(self):
169201
"""
170202
Request body should match given configuration.
@@ -245,6 +277,35 @@ def test_refresh_access_token(self, mocker):
245277
assert isinstance(expires_in, str)
246278
assert ("access_token", "2022-04-24T00:00:00Z") == (token, expires_in)
247279

280+
def test_refresh_access_token_when_headers_provided(self, mocker):
281+
expected_headers = {
282+
"Authorization": "Bearer some_access_token",
283+
"Content-Type": "application/x-www-form-urlencoded",
284+
}
285+
oauth = Oauth2Authenticator(
286+
token_refresh_endpoint="refresh_end",
287+
client_id="some_client_id",
288+
client_secret="some_client_secret",
289+
refresh_token="some_refresh_token",
290+
scopes=["scope1", "scope2"],
291+
token_expiry_date=pendulum.now().add(days=3),
292+
refresh_request_headers=expected_headers,
293+
)
294+
295+
resp.status_code = 200
296+
mocker.patch.object(
297+
resp, "json", return_value={"access_token": "access_token", "expires_in": 1000}
298+
)
299+
mocked_request = mocker.patch.object(
300+
requests, "request", side_effect=mock_request, autospec=True
301+
)
302+
token, expires_in = oauth.refresh_access_token()
303+
304+
assert isinstance(expires_in, int)
305+
assert ("access_token", 1000) == (token, expires_in)
306+
307+
assert mocked_request.call_args.kwargs["headers"] == expected_headers
308+
248309
@pytest.mark.parametrize(
249310
"expires_in_response, token_expiry_date_format, expected_token_expiry_date",
250311
[
@@ -557,7 +618,9 @@ def test_refresh_access_token(self, mocker, connector_config):
557618
)
558619

559620

560-
def mock_request(method, url, data):
621+
def mock_request(method, url, data, headers):
561622
if url == "refresh_end":
562623
return resp
563-
raise Exception(f"Error while refreshing access token with request: {method}, {url}, {data}")
624+
raise Exception(
625+
f"Error while refreshing access token with request: {method}, {url}, {data}, {headers}"
626+
)

0 commit comments

Comments
 (0)