Skip to content

Commit bd548c7

Browse files
committed
first attempt to make expiration date optional
1 parent 533b70a commit bd548c7

File tree

2 files changed

+22
-18
lines changed

2 files changed

+22
-18
lines changed

airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -130,17 +130,17 @@ def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
130130
headers = self.get_refresh_request_headers()
131131
return headers if headers else None
132132

133-
def refresh_access_token(self) -> Tuple[str, Union[str, int]]:
133+
def refresh_access_token(self) -> Tuple[str, Union[str, int, None]]:
134134
"""
135-
Returns the refresh token and its expiration datetime
135+
Returns the refreshed access token and its expiration datetime
136136
137137
:return: a tuple of (access_token, token_lifespan)
138138
"""
139139
response_json = self._make_handled_request()
140140
self._ensure_access_token_in_response(response_json)
141141

142142
return (
143-
self._extract_access_token(response_json),
143+
str(self._extract_access_token(response_json)),
144144
self._extract_token_expiry_date(response_json),
145145
)
146146

@@ -184,7 +184,7 @@ def _wrap_refresh_token_exception(
184184
),
185185
max_time=300,
186186
)
187-
def _make_handled_request(self) -> Any:
187+
def _make_handled_request(self) -> Mapping[str, Any]:
188188
"""
189189
Makes a handled HTTP request to refresh an OAuth token.
190190
@@ -292,41 +292,44 @@ def _extract_access_token(self, response_data: Mapping[str, Any]) -> Any:
292292
response_data (Mapping[str, Any]): The response data from which to extract the access token.
293293
294294
Returns:
295-
str: The extracted access token.
295+
str: The extracted access token or None if not found.
296296
"""
297-
return self._find_and_get_value_from_response(response_data, self.get_access_token_name())
297+
access_token = self._find_and_get_value_from_response(response_data, self.get_access_token_name())
298+
return str(access_token) if access_token is not None else None
298299

299-
def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any:
300+
def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> str | None:
300301
"""
301302
Extracts the refresh token from the given response data.
302303
303304
Args:
304305
response_data (Mapping[str, Any]): The response data from which to extract the refresh token.
305306
306307
Returns:
307-
str: The extracted refresh token.
308+
str: The extracted refresh token or None if not found.
308309
"""
309-
return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())
310+
refresh_token = self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())
311+
return str(refresh_token) if refresh_token is not None else None
310312

311-
def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> Any:
313+
def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> str | None:
312314
"""
313315
Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.
314316
315317
Args:
316318
response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.
317319
318320
Returns:
319-
str: The extracted token_expiry_date.
321+
str: The extracted token_expiry_date or None if not found.
320322
"""
321-
return self._find_and_get_value_from_response(response_data, self.get_expires_in_name())
323+
token_expiry_date = self._find_and_get_value_from_response(response_data, self.get_expires_in_name())
324+
return str(token_expiry_date) if token_expiry_date is not None else None
322325

323326
def _find_and_get_value_from_response(
324327
self,
325328
response_data: Mapping[str, Any],
326329
key_name: str,
327330
max_depth: int = 5,
328331
current_depth: int = 0,
329-
) -> Any:
332+
) -> Any | None:
330333
"""
331334
Recursively searches for a specified key in a nested dictionary or list and returns its value if found.
332335

airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -346,16 +346,17 @@ def get_access_token(self) -> str:
346346
new_access_token, access_token_expires_in, new_refresh_token = (
347347
self.refresh_access_token()
348348
)
349-
new_token_expiry_date: AirbyteDateTime = self.get_new_token_expiry_date(
350-
access_token_expires_in, self._token_expiry_date_format
351-
)
349+
if access_token_expires_in is not None:
350+
new_token_expiry_date: AirbyteDateTime = self.get_new_token_expiry_date(
351+
access_token_expires_in, self._token_expiry_date_format
352+
)
353+
self.set_token_expiry_date(new_token_expiry_date)
352354
self.access_token = new_access_token
353355
self.set_refresh_token(new_refresh_token)
354-
self.set_token_expiry_date(new_token_expiry_date)
355356
self._emit_control_message()
356357
return self.access_token
357358

358-
def refresh_access_token(self) -> Tuple[str, str, str]: # type: ignore[override]
359+
def refresh_access_token(self) -> Tuple[str, str | None, str]: # type: ignore[override]
359360
"""
360361
Refreshes the access token by making a handled request and extracting the necessary token information.
361362

0 commit comments

Comments
 (0)