Skip to content

Commit 82d8892

Browse files
authored
Retry with a new access token on HTTP 419 (#340)
* Refresh Auth token on expiry
* Check call count
* Add test to cover retry logic
* Update poetry.lock with tenacity
* Fix tests for Python <= 3.9
1 parent f2aee48 commit 82d8892

File tree

4 files changed

+130
-7
lines changed

4 files changed

+130
-7
lines changed

poetry.lock

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyiceberg/catalog/rest.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
from pydantic import Field, ValidationError
3232
from requests import HTTPError, Session
33+
from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt
3334

3435
from pyiceberg import __version__
3536
from pyiceberg.catalog import (
@@ -118,6 +119,19 @@ class Endpoints:
118119
NAMESPACE_SEPARATOR = b"\x1F".decode(UTF8)
119120

120121

122+
def _retry_hook(retry_state: RetryCallState) -> None:
    """Refresh the auth token of the catalog whose call is being retried.

    Intended as a tenacity callback. The first positional argument of the
    retried call is taken to be the catalog instance, which holds for the
    decorated catalog methods where that argument is ``self``.

    Args:
        retry_state: Tenacity state object for the call being retried.
    """
    catalog: RestCatalog = retry_state.args[0]
    # Module-level hook reaching into the instance, hence the lint waiver.
    catalog._refresh_token()  # pylint: disable=protected-access
125+
126+
127+
# Shared tenacity settings for catalog calls that may fail on an expired token.
_RETRY_ARGS = {
    # Only an expired authorization is retried; other errors propagate at once.
    "retry": retry_if_exception_type(AuthorizationExpiredError),
    # One original attempt plus a single retry.
    "stop": stop_after_attempt(2),
    # Refresh the token only between a failed attempt and its retry. A ``before``
    # hook would run ahead of every attempt, the first one included, and so
    # refresh the token on each catalog call even while it is still valid.
    "before_sleep": _retry_hook,
    # Re-raise the original exception instead of wrapping it in a RetryError.
    "reraise": True,
}
133+
134+
121135
class TableResponse(IcebergBaseModel):
122136
metadata_location: str = Field(alias="metadata-location")
123137
metadata: TableMetadata
@@ -225,13 +239,7 @@ def _create_session(self) -> Session:
225239
elif ssl_client_cert := ssl_client.get(CERT):
226240
session.cert = ssl_client_cert
227241

228-
# If we have credentials, but not a token, we want to fetch a token
229-
if TOKEN not in self.properties and CREDENTIAL in self.properties:
230-
self.properties[TOKEN] = self._fetch_access_token(session, self.properties[CREDENTIAL])
231-
232-
# Set Auth token for subsequent calls in the session
233-
if token := self.properties.get(TOKEN):
234-
session.headers[AUTHORIZATION_HEADER] = f"{BEARER_PREFIX} {token}"
242+
self._refresh_token(session, self.properties.get(TOKEN))
235243

236244
# Set HTTP headers
237245
session.headers["Content-type"] = "application/json"
@@ -439,6 +447,18 @@ def _response_to_table(self, identifier_tuple: Tuple[str, ...], table_response:
439447
catalog=self,
440448
)
441449

450+
def _refresh_token(self, session: Optional[Session] = None, new_token: Optional[str] = None) -> None:
    """Update the stored token and the session's Authorization header.

    Args:
        session: Session whose headers are updated. Falls back to
            ``self._session`` only when omitted.
        new_token: Token to store. When omitted and a credential is present in
            the properties, a token is fetched with that credential instead.
    """
    # Check for None explicitly instead of relying on truthiness: the fallback
    # must only trigger when no session was passed at all.
    if session is None:
        session = self._session

    if new_token is not None:
        self.properties[TOKEN] = new_token
    elif CREDENTIAL in self.properties:
        self.properties[TOKEN] = self._fetch_access_token(session, self.properties[CREDENTIAL])

    # Set Auth token for subsequent calls in the session
    if token := self.properties.get(TOKEN):
        session.headers[AUTHORIZATION_HEADER] = f"{BEARER_PREFIX} {token}"
460+
461+
@retry(**_RETRY_ARGS)
442462
def create_table(
443463
self,
444464
identifier: Union[str, Identifier],
@@ -475,6 +495,7 @@ def create_table(
475495
table_response = TableResponse(**response.json())
476496
return self._response_to_table(self.identifier_to_tuple(identifier), table_response)
477497

498+
@retry(**_RETRY_ARGS)
478499
def register_table(self, identifier: Union[str, Identifier], metadata_location: str) -> Table:
479500
"""Register a new table using existing metadata.
480501
@@ -506,6 +527,7 @@ def register_table(self, identifier: Union[str, Identifier], metadata_location:
506527
table_response = TableResponse(**response.json())
507528
return self._response_to_table(self.identifier_to_tuple(identifier), table_response)
508529

530+
@retry(**_RETRY_ARGS)
509531
def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]:
510532
namespace_tuple = self._check_valid_namespace_identifier(namespace)
511533
namespace_concat = NAMESPACE_SEPARATOR.join(namespace_tuple)
@@ -516,6 +538,7 @@ def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]:
516538
self._handle_non_200_response(exc, {404: NoSuchNamespaceError})
517539
return [(*table.namespace, table.name) for table in ListTablesResponse(**response.json()).identifiers]
518540

541+
@retry(**_RETRY_ARGS)
519542
def load_table(self, identifier: Union[str, Identifier]) -> Table:
520543
identifier_tuple = self.identifier_to_tuple_without_catalog(identifier)
521544
response = self._session.get(
@@ -529,6 +552,7 @@ def load_table(self, identifier: Union[str, Identifier]) -> Table:
529552
table_response = TableResponse(**response.json())
530553
return self._response_to_table(identifier_tuple, table_response)
531554

555+
@retry(**_RETRY_ARGS)
532556
def drop_table(self, identifier: Union[str, Identifier], purge_requested: bool = False) -> None:
533557
identifier_tuple = self.identifier_to_tuple_without_catalog(identifier)
534558
response = self._session.delete(
@@ -541,9 +565,11 @@ def drop_table(self, identifier: Union[str, Identifier], purge_requested: bool =
541565
except HTTPError as exc:
542566
self._handle_non_200_response(exc, {404: NoSuchTableError})
543567

568+
def purge_table(self, identifier: Union[str, Identifier]) -> None:
    """Drop the table with the purge flag set.

    Args:
        identifier: Identifier of the table to drop.
    """
    # Not decorated with @retry on purpose: drop_table carries the retry
    # decorator itself, and stacking a second one here would nest the retries
    # and multiply the number of requests sent on an expired token.
    self.drop_table(identifier=identifier, purge_requested=True)
546571

572+
@retry(**_RETRY_ARGS)
547573
def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: Union[str, Identifier]) -> Table:
548574
from_identifier_tuple = self.identifier_to_tuple_without_catalog(from_identifier)
549575
payload = {
@@ -558,6 +584,7 @@ def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: U
558584

559585
return self.load_table(to_identifier)
560586

587+
@retry(**_RETRY_ARGS)
561588
def _commit_table(self, table_request: CommitTableRequest) -> CommitTableResponse:
562589
"""Update the table.
563590
@@ -588,6 +615,7 @@ def _commit_table(self, table_request: CommitTableRequest) -> CommitTableRespons
588615
)
589616
return CommitTableResponse(**response.json())
590617

618+
@retry(**_RETRY_ARGS)
591619
def create_namespace(self, namespace: Union[str, Identifier], properties: Properties = EMPTY_DICT) -> None:
592620
namespace_tuple = self._check_valid_namespace_identifier(namespace)
593621
payload = {"namespace": namespace_tuple, "properties": properties}
@@ -597,6 +625,7 @@ def create_namespace(self, namespace: Union[str, Identifier], properties: Proper
597625
except HTTPError as exc:
598626
self._handle_non_200_response(exc, {404: NoSuchNamespaceError, 409: NamespaceAlreadyExistsError})
599627

628+
@retry(**_RETRY_ARGS)
600629
def drop_namespace(self, namespace: Union[str, Identifier]) -> None:
601630
namespace_tuple = self._check_valid_namespace_identifier(namespace)
602631
namespace = NAMESPACE_SEPARATOR.join(namespace_tuple)
@@ -606,6 +635,7 @@ def drop_namespace(self, namespace: Union[str, Identifier]) -> None:
606635
except HTTPError as exc:
607636
self._handle_non_200_response(exc, {404: NoSuchNamespaceError})
608637

638+
@retry(**_RETRY_ARGS)
609639
def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identifier]:
610640
namespace_tuple = self.identifier_to_tuple(namespace)
611641
response = self._session.get(
@@ -623,6 +653,7 @@ def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identi
623653
namespaces = ListNamespaceResponse(**response.json())
624654
return [namespace_tuple + child_namespace for child_namespace in namespaces.namespaces]
625655

656+
@retry(**_RETRY_ARGS)
626657
def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Properties:
627658
namespace_tuple = self._check_valid_namespace_identifier(namespace)
628659
namespace = NAMESPACE_SEPARATOR.join(namespace_tuple)
@@ -634,6 +665,7 @@ def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Proper
634665

635666
return NamespaceResponse(**response.json()).properties
636667

668+
@retry(**_RETRY_ARGS)
637669
def update_namespace_properties(
638670
self, namespace: Union[str, Identifier], removals: Optional[Set[str]] = None, updates: Properties = EMPTY_DICT
639671
) -> PropertiesUpdateSummary:

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ sortedcontainers = "2.4.0"
5757
fsspec = ">=2023.1.0,<2024.1.0"
5858
pyparsing = ">=3.1.0,<4.0.0"
5959
zstandard = ">=0.13.0,<1.0.0"
60+
tenacity = ">=8.2.3,<9.0.0"
6061
pyarrow = { version = ">=9.0.0,<16.0.0", optional = true }
6162
pandas = { version = ">=1.0.0,<3.0.0", optional = true }
6263
duckdb = { version = ">=0.5.0,<1.0.0", optional = true }
@@ -301,6 +302,10 @@ ignore_missing_imports = true
301302
module = "setuptools.*"
302303
ignore_missing_imports = true
303304

305+
[[tool.mypy.overrides]]
306+
module = "tenacity.*"
307+
ignore_missing_imports = true
308+
304309
[tool.coverage.run]
305310
source = ['pyiceberg/']
306311

tests/catalog/test_rest.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from pyiceberg.catalog import PropertiesUpdateSummary, Table, load_catalog
2727
from pyiceberg.catalog.rest import AUTH_URL, RestCatalog
2828
from pyiceberg.exceptions import (
29+
AuthorizationExpiredError,
2930
NamespaceAlreadyExistsError,
3031
NoSuchNamespaceError,
3132
NoSuchTableError,
@@ -266,6 +267,48 @@ def test_list_namespace_with_parent_200(rest_mock: Mocker) -> None:
266267
]
267268

268269

270+
def test_list_namespaces_419(rest_mock: Mocker) -> None:
    """An HTTP 419 on list_namespaces triggers a token refresh and one retry."""
    new_token = "new_jwt_token"
    new_header = dict(TEST_HEADERS)
    new_header["Authorization"] = f"Bearer {new_token}"

    # list_namespaces issues a GET, so the expired-token response has to be
    # registered for GET; registered for POST it would never be served and the
    # retry path would go untested.
    # NOTE(review): assumes TEST_HEADERS carries the Authorization header for
    # TEST_TOKEN, so this matcher only answers requests using the old token.
    rest_mock.get(
        f"{TEST_URI}v1/namespaces",
        json={
            "error": {
                "message": "Authorization expired.",
                "type": "AuthorizationExpiredError",
                "code": 419,
            }
        },
        status_code=419,
        request_headers=TEST_HEADERS,
    )
    # Token endpoint used by the refresh; hands out the replacement token.
    rest_mock.post(
        f"{TEST_URI}v1/oauth/tokens",
        json={
            "access_token": new_token,
            "token_type": "Bearer",
            "expires_in": 86400,
            "issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
        },
        status_code=200,
    )
    # The retried request only succeeds when it carries the refreshed token.
    rest_mock.get(
        f"{TEST_URI}v1/namespaces",
        json={"namespaces": [["default"], ["examples"], ["fokko"], ["system"]]},
        status_code=200,
        request_headers=new_header,
    )
    catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN, credential=TEST_CREDENTIALS)
    assert catalog.list_namespaces() == [
        ("default",),
        ("examples",),
        ("fokko",),
        ("system",),
    ]
310+
311+
269312
def test_create_namespace_200(rest_mock: Mocker) -> None:
270313
namespace = "leden"
271314
rest_mock.post(
@@ -517,6 +560,35 @@ def test_create_table_409(rest_mock: Mocker, table_schema_simple: Schema) -> Non
517560
assert "Table already exists" in str(e.value)
518561

519562

563+
def test_create_table_419(rest_mock: Mocker, table_schema_simple: Schema) -> None:
    """Without a credential, a persistent HTTP 419 surfaces as AuthorizationExpiredError."""
    expired_payload = {
        "error": {
            "message": "Authorization expired.",
            "type": "AuthorizationExpiredError",
            "code": 419,
        }
    }
    rest_mock.post(
        f"{TEST_URI}v1/namespaces/fokko/tables",
        json=expired_payload,
        status_code=419,
        request_headers=TEST_HEADERS,
    )

    truncated_id = PartitionField(source_id=1, field_id=1000, transform=TruncateTransform(width=3), name="id")
    spec = PartitionSpec(truncated_id)
    order = SortOrder(SortField(source_id=2, transform=IdentityTransform()))

    with pytest.raises(AuthorizationExpiredError) as e:
        RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN).create_table(
            identifier=("fokko", "fokko2"),
            schema=table_schema_simple,
            location=None,
            partition_spec=spec,
            sort_order=order,
            properties={"owner": "fokko"},
        )
    assert "Authorization expired" in str(e.value)
    # NOTE(review): presumably one request made while constructing the catalog
    # plus the two create_table attempts - confirm against the rest_mock fixture.
    assert rest_mock.call_count == 3
590+
591+
520592
def test_register_table_200(
521593
rest_mock: Mocker, table_schema_simple: Schema, example_table_metadata_no_snapshot_v1_rest_json: Dict[str, Any]
522594
) -> None:

0 commit comments

Comments
 (0)