From 8f567638a1765cc30f59866d9867a79c60e69d56 Mon Sep 17 00:00:00 2001 From: darynaishchenko Date: Wed, 15 Jan 2025 13:58:16 +0200 Subject: [PATCH 1/4] added refresh_request_headers to oauth --- airbyte_cdk/sources/declarative/auth/oauth.py | 8 ++++++++ .../declarative/declarative_component_schema.yaml | 8 ++++++++ .../models/declarative_component_schema.py | 11 +++++++++++ .../parsers/model_to_component_factory.py | 4 ++++ .../http/requests_native_auth/abstract_oauth.py | 13 +++++++++++++ .../streams/http/requests_native_auth/oauth.py | 8 ++++++++ 6 files changed, 52 insertions(+) diff --git a/airbyte_cdk/sources/declarative/auth/oauth.py b/airbyte_cdk/sources/declarative/auth/oauth.py index f3ba528ac..6527b0a93 100644 --- a/airbyte_cdk/sources/declarative/auth/oauth.py +++ b/airbyte_cdk/sources/declarative/auth/oauth.py @@ -39,6 +39,7 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut token_expiry_date_format str: format of the datetime; provide it if expires_in is returned in datetime instead of seconds token_expiry_is_time_of_expiration bool: set True it if expires_in is returned as time of expiration instead of the number seconds until expiration refresh_request_body (Optional[Mapping[str, Any]]): The request body to send in the refresh request + refresh_request_headers (Optional[Mapping[str, Any]]): The request headers to send in the refresh request grant_type: The grant_type to request for access_token. If set to refresh_token, the refresh_token parameter has to be provided message_repository (MessageRepository): the message repository used to emit logs on HTTP requests """ @@ -58,6 +59,7 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut access_token_value: Optional[Union[InterpolatedString, str]] = None expires_in_name: Union[InterpolatedString, str] = "expires_in" refresh_request_body: Optional[Mapping[str, Any]] = None + refresh_request_headers: Optional[Mapping[str, Any]] = None grant_type: Union[InterpolatedString, str] = "refresh_token" message_repository: MessageRepository = NoopMessageRepository() @@ -87,6 +89,9 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._refresh_request_body = InterpolatedMapping( self.refresh_request_body or {}, parameters=parameters ) + self._refresh_request_headers = InterpolatedMapping( + self.refresh_request_headers or {}, parameters=parameters + ) self._token_expiry_date: pendulum.DateTime = ( pendulum.parse( InterpolatedString.create(self.token_expiry_date, parameters=parameters).eval( @@ -152,6 +157,9 @@ def get_grant_type(self) -> str: def get_refresh_request_body(self) -> Mapping[str, Any]: return self._refresh_request_body.eval(self.config) + def get_refresh_request_headers(self) -> Mapping[str, Any]: + return self._refresh_request_headers.eval(self.config) + def get_token_expiry_date(self) -> pendulum.DateTime: return self._token_expiry_date # type: ignore # _token_expiry_date is a pendulum.DateTime. It is never None despite what mypy thinks diff --git a/airbyte_cdk/sources/declarative/declarative_component_schema.yaml b/airbyte_cdk/sources/declarative/declarative_component_schema.yaml index 7a3619a45..7e27ad868 100644 --- a/airbyte_cdk/sources/declarative/declarative_component_schema.yaml +++ b/airbyte_cdk/sources/declarative/declarative_component_schema.yaml @@ -1111,6 +1111,14 @@ definitions: - applicationId: "{{ config['application_id'] }}" applicationSecret: "{{ config['application_secret'] }}" token: "{{ config['token'] }}" + refresh_request_headers: + title: Refresh Request Headers + description: Headers of the request sent to get a new access token. + type: object + additionalProperties: true + examples: + - Authorization: "" + Content-Type: "application/x-www-form-urlencoded" scopes: title: Scopes description: List of scopes that should be granted to the access token. diff --git a/airbyte_cdk/sources/declarative/models/declarative_component_schema.py b/airbyte_cdk/sources/declarative/models/declarative_component_schema.py index df6925eaa..de0c28052 100644 --- a/airbyte_cdk/sources/declarative/models/declarative_component_schema.py +++ b/airbyte_cdk/sources/declarative/models/declarative_component_schema.py @@ -547,6 +547,17 @@ class OAuthAuthenticator(BaseModel): ], title="Refresh Request Body", ) + refresh_request_headers: Optional[Dict[str, Any]] = Field( + None, + description="Headers of the request sent to get a new access token.", + examples=[ + { + "Authorization": "", + "Content-Type": "application/x-www-form-urlencoded", + } + ], + title="Refresh Request Headers", + ) scopes: Optional[List[str]] = Field( None, description="List of scopes that should be granted to the access token.", diff --git a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py index 8a31fab2e..bcb7b9f83 100644 --- a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py +++ b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py @@ -1900,6 +1900,9 @@ def create_oauth_authenticator( refresh_request_body=InterpolatedMapping( model.refresh_request_body or {}, parameters=model.parameters or {} ).eval(config), + refresh_request_headers=InterpolatedMapping( + model.refresh_request_headers or {}, parameters=model.parameters or {} + ).eval(config), scopes=model.scopes, token_expiry_date_format=model.token_expiry_date_format, message_repository=self._message_repository, @@ -1916,6 +1919,7 @@ def create_oauth_authenticator( expires_in_name=model.expires_in_name or "expires_in", grant_type=model.grant_type or "refresh_token", refresh_request_body=model.refresh_request_body, + refresh_request_headers=model.refresh_request_headers, refresh_token=model.refresh_token, scopes=model.scopes, token_expiry_date=model.token_expiry_date, diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py index 1f3c1c85e..2edb71e74 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py @@ -98,6 +98,14 @@ def build_refresh_request_body(self) -> Mapping[str, Any]: return payload + def build_refresh_request_headers(self) -> Mapping[str, Any] | None: + """ + Returns the request headers to set on the refresh request + + """ + headers = self.get_refresh_request_headers() + return headers if headers else None + def _wrap_refresh_token_exception( self, exception: requests.exceptions.RequestException ) -> bool: @@ -128,6 +136,7 @@ def _get_refresh_access_token_response(self) -> Any: method="POST", url=self.get_token_refresh_endpoint(), # type: ignore # returns None, if not provided, but str | bytes is expected. data=self.build_refresh_request_body(), + headers=self.build_refresh_request_headers(), ) if response.ok: response_json = response.json() @@ -242,6 +251,10 @@ def get_expires_in_name(self) -> str: def get_refresh_request_body(self) -> Mapping[str, Any]: """Returns the request body to set on the refresh request""" + @abstractmethod + def get_refresh_request_headers(self) -> Mapping[str, Any]: + """Returns the request body to set on the refresh request""" + @abstractmethod def get_grant_type(self) -> str: """Returns grant_type specified for requesting access_token""" diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py index 8e5c71458..9490d1757 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py @@ -36,6 +36,7 @@ def __init__( access_token_name: str = "access_token", expires_in_name: str = "expires_in", refresh_request_body: Mapping[str, Any] | None = None, + refresh_request_headers: Mapping[str, Any] | None = None, grant_type: str = "refresh_token", token_expiry_is_time_of_expiration: bool = False, refresh_token_error_status_codes: Tuple[int, ...] = (), @@ -50,6 +51,7 @@ def __init__( self._access_token_name = access_token_name self._expires_in_name = expires_in_name self._refresh_request_body = refresh_request_body + self._refresh_request_headers = refresh_request_headers self._grant_type = grant_type self._token_expiry_date = token_expiry_date or pendulum.now().subtract(days=1) # type: ignore [no-untyped-call] @@ -84,6 +86,9 @@ def get_expires_in_name(self) -> str: def get_refresh_request_body(self) -> Mapping[str, Any]: return self._refresh_request_body # type: ignore [return-value] + def get_refresh_request_headers(self) -> Mapping[str, Any]: + return self._refresh_request_headers # type: ignore [return-value] + def get_grant_type(self) -> str: return self._grant_type @@ -129,6 +134,7 @@ def __init__( expires_in_name: str = "expires_in", refresh_token_name: str = "refresh_token", refresh_request_body: Mapping[str, Any] | None = None, + refresh_request_headers: Mapping[str, Any] | None = None, grant_type: str = "refresh_token", client_id: Optional[str] = None, client_secret: Optional[str] = None, @@ -151,6 +157,7 @@ def __init__( expires_in_name (str, optional): Name of the name of the field that characterizes when the current access token will expire, used to parse the refresh token response. Defaults to "expires_in". refresh_token_name (str, optional): Name of the name of the refresh token field, used to parse the refresh token response. Defaults to "refresh_token". refresh_request_body (Mapping[str, Any], optional): Custom key value pair that will be added to the refresh token request body. Defaults to None. + refresh_request_headers (Mapping[str, Any], optional): Custom key value pair that will be added to the refresh token request headers. Defaults to None. grant_type (str, optional): OAuth grant type. Defaults to "refresh_token". client_id (Optional[str]): The client id to authenticate. If not specified, defaults to credentials.client_id in the config object. client_secret (Optional[str]): The client secret to authenticate. If not specified, defaults to credentials.client_secret in the config object. @@ -191,6 +198,7 @@ def __init__( access_token_name=access_token_name, expires_in_name=expires_in_name, refresh_request_body=refresh_request_body, + refresh_request_headers=refresh_request_headers, grant_type=grant_type, token_expiry_date_format=token_expiry_date_format, token_expiry_is_time_of_expiration=token_expiry_is_time_of_expiration, From ab3ee27d35083349230d8decf15913d5b2b732fc Mon Sep 17 00:00:00 2001 From: darynaishchenko Date: Wed, 15 Jan 2025 14:59:12 +0200 Subject: [PATCH 2/4] updated unit tests --- .../sources/declarative/auth/test_oauth.py | 69 ++++++++++++++++++- .../test_requests_native_auth.py | 67 +++++++++++++++++- 2 files changed, 132 insertions(+), 4 deletions(-) diff --git a/unit_tests/sources/declarative/auth/test_oauth.py b/unit_tests/sources/declarative/auth/test_oauth.py index 4130a9dc8..b285d260f 100644 --- a/unit_tests/sources/declarative/auth/test_oauth.py +++ b/unit_tests/sources/declarative/auth/test_oauth.py @@ -69,6 +69,39 @@ def test_refresh_request_body(self): } assert body == expected + def test_refresh_request_headers(self): + """ + Request headers should match given configuration. + """ + oauth = DeclarativeOauth2Authenticator( + token_refresh_endpoint="{{ config['refresh_endpoint'] }}", + client_id="{{ config['client_id'] }}", + client_secret="{{ config['client_secret'] }}", + refresh_token="{{ parameters['refresh_token'] }}", + config=config, + token_expiry_date="{{ config['token_expiry_date'] }}", + refresh_request_headers={ + "Authorization": "", + "Content-Type": "application/x-www-form-urlencoded", + }, + parameters=parameters, + ) + headers = oauth.build_refresh_request_headers() + expected = {"Authorization": "", "Content-Type": "application/x-www-form-urlencoded"} + assert headers == expected + + oauth = DeclarativeOauth2Authenticator( + token_refresh_endpoint="{{ config['refresh_endpoint'] }}", + client_id="{{ config['client_id'] }}", + client_secret="{{ config['client_secret'] }}", + refresh_token="{{ parameters['refresh_token'] }}", + config=config, + token_expiry_date="{{ config['token_expiry_date'] }}", + parameters=parameters, + ) + headers = oauth.build_refresh_request_headers() + assert headers is None + def test_refresh_with_encode_config_params(self): oauth = DeclarativeOauth2Authenticator( token_refresh_endpoint="{{ config['refresh_endpoint'] }}", @@ -191,6 +224,36 @@ def test_refresh_access_token(self, mocker): filtered = filter_secrets("access_token") assert filtered == "****" + def test_refresh_access_token_when_headers_provided(self, mocker): + expected_headers = { + "Authorization": "Bearer some_access_token", + "Content-Type": "application/x-www-form-urlencoded", + } + oauth = DeclarativeOauth2Authenticator( + token_refresh_endpoint="{{ config['refresh_endpoint'] }}", + client_id="{{ config['client_id'] }}", + client_secret="{{ config['client_secret'] }}", + refresh_token="{{ config['refresh_token'] }}", + config=config, + scopes=["scope1", "scope2"], + token_expiry_date="{{ config['token_expiry_date'] }}", + refresh_request_headers=expected_headers, + parameters={}, + ) + + resp.status_code = 200 + mocker.patch.object( + resp, "json", return_value={"access_token": "access_token", "expires_in": 1000} + ) + mocked_request = mocker.patch.object( + requests, "request", side_effect=mock_request, autospec=True + ) + token = oauth.refresh_access_token() + + assert ("access_token", 1000) == token + + assert mocked_request.call_args.kwargs["headers"] == expected_headers + def test_refresh_access_token_missing_access_token(self, mocker): oauth = DeclarativeOauth2Authenticator( token_refresh_endpoint="{{ config['refresh_endpoint'] }}", @@ -371,7 +434,9 @@ def test_error_handling(self, mocker): assert e.value.errno == 400 -def mock_request(method, url, data): +def mock_request(method, url, data, headers): if url == "refresh_end": return resp - raise Exception(f"Error while refreshing access token with request: {method}, {url}, {data}") + raise Exception( + f"Error while refreshing access token with request: {method}, {url}, {data}, {headers}" + ) diff --git a/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py b/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py index 093c136e5..a0b541599 100644 --- a/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py +++ b/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py @@ -165,6 +165,38 @@ def test_refresh_request_body(self): } assert body == expected + def test_refresh_request_headers(self): + """ + Request headers should match given configuration. + """ + oauth = Oauth2Authenticator( + token_refresh_endpoint="refresh_end", + client_id="some_client_id", + client_secret="some_client_secret", + refresh_token="some_refresh_token", + token_expiry_date=pendulum.now().add(days=3), + refresh_request_headers={ + "Authorization": "Bearer some_refresh_token", + "Content-Type": "application/x-www-form-urlencoded", + }, + ) + headers = oauth.build_refresh_request_headers() + expected = { + "Authorization": "Bearer some_refresh_token", + "Content-Type": "application/x-www-form-urlencoded", + } + assert headers == expected + + oauth = Oauth2Authenticator( + token_refresh_endpoint="refresh_end", + client_id="some_client_id", + client_secret="some_client_secret", + refresh_token="some_refresh_token", + token_expiry_date=pendulum.now().add(days=3), + ) + headers = oauth.build_refresh_request_headers() + assert headers is None + def test_refresh_access_token(self, mocker): oauth = Oauth2Authenticator( token_refresh_endpoint="refresh_end", @@ -210,6 +242,35 @@ def test_refresh_access_token(self, mocker): assert isinstance(expires_in, str) assert ("access_token", "2022-04-24T00:00:00Z") == (token, expires_in) + def test_refresh_access_token_when_headers_provided(self, mocker): + expected_headers = { + "Authorization": "Bearer some_access_token", + "Content-Type": "application/x-www-form-urlencoded", + } + oauth = Oauth2Authenticator( + token_refresh_endpoint="refresh_end", + client_id="some_client_id", + client_secret="some_client_secret", + refresh_token="some_refresh_token", + scopes=["scope1", "scope2"], + token_expiry_date=pendulum.now().add(days=3), + refresh_request_headers=expected_headers, + ) + + resp.status_code = 200 + mocker.patch.object( + resp, "json", return_value={"access_token": "access_token", "expires_in": 1000} + ) + mocked_request = mocker.patch.object( + requests, "request", side_effect=mock_request, autospec=True + ) + token, expires_in = oauth.refresh_access_token() + + assert isinstance(expires_in, int) + assert ("access_token", 1000) == (token, expires_in) + + assert mocked_request.call_args.kwargs["headers"] == expected_headers + @pytest.mark.parametrize( "expires_in_response, token_expiry_date_format, expected_token_expiry_date", [ @@ -522,7 +583,9 @@ def test_refresh_access_token(self, mocker, connector_config): ) -def mock_request(method, url, data): +def mock_request(method, url, data, headers): if url == "refresh_end": return resp - raise Exception(f"Error while refreshing access token with request: {method}, {url}, {data}") + raise Exception( + f"Error while refreshing access token with request: {method}, {url}, {data}, {headers}" + ) From 60ecfbf7d570711a84b1ef08184f7cc63b740da8 Mon Sep 17 00:00:00 2001 From: darynaishchenko Date: Wed, 15 Jan 2025 15:06:02 +0200 Subject: [PATCH 3/4] updated comment --- .../sources/streams/http/requests_native_auth/abstract_oauth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py index 2edb71e74..d21a36722 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py @@ -253,7 +253,7 @@ def get_refresh_request_body(self) -> Mapping[str, Any]: @abstractmethod def get_refresh_request_headers(self) -> Mapping[str, Any]: - """Returns the request body to set on the refresh request""" + """Returns the request headers to set on the refresh request""" @abstractmethod def get_grant_type(self) -> str: From 464ce0dba5e359abf703db08fefd5cf92a158046 Mon Sep 17 00:00:00 2001 From: darynaishchenko Date: Thu, 16 Jan 2025 13:23:16 +0200 Subject: [PATCH 4/4] updated unit test --- unit_tests/sources/declarative/auth/test_oauth.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/unit_tests/sources/declarative/auth/test_oauth.py b/unit_tests/sources/declarative/auth/test_oauth.py index b285d260f..dc384bb10 100644 --- a/unit_tests/sources/declarative/auth/test_oauth.py +++ b/unit_tests/sources/declarative/auth/test_oauth.py @@ -81,13 +81,16 @@ def test_refresh_request_headers(self): config=config, token_expiry_date="{{ config['token_expiry_date'] }}", refresh_request_headers={ - "Authorization": "", + "Authorization": "Basic {{ [config['client_id'], config['client_secret']] | join(':') | base64encode }}", "Content-Type": "application/x-www-form-urlencoded", }, parameters=parameters, ) headers = oauth.build_refresh_request_headers() - expected = {"Authorization": "", "Content-Type": "application/x-www-form-urlencoded"} + expected = { + "Authorization": "Basic c29tZV9jbGllbnRfaWQ6c29tZV9jbGllbnRfc2VjcmV0", + "Content-Type": "application/x-www-form-urlencoded", + } assert headers == expected oauth = DeclarativeOauth2Authenticator(