diff --git a/poetry.lock b/poetry.lock index e6e679a4b2..64bbc0f91e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3879,6 +3879,20 @@ files = [ [package.dependencies] mpmath = ">=0.19" +[[package]] +name = "tenacity" +version = "8.2.3" +description = "Retry code until it succeeds" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tenacity-8.2.3-py3-none-any.whl", hash = "sha256:ce510e327a630c9e1beaf17d42e6ffacc88185044ad85cf74c0a8887c6a0f88c"}, + {file = "tenacity-8.2.3.tar.gz", hash = "sha256:5398ef0d78e63f40007c1fb4c0bff96e1911394d2fa8d194f77619c05ff6cc8a"}, +] + +[package.extras] +doc = ["reno", "sphinx", "tornado (>=4.5)"] + [[package]] name = "thrift" version = "0.16.0" diff --git a/pyiceberg/catalog/rest.py b/pyiceberg/catalog/rest.py index 765f04b128..6a75328cae 100644 --- a/pyiceberg/catalog/rest.py +++ b/pyiceberg/catalog/rest.py @@ -30,6 +30,7 @@ from pydantic import Field, ValidationError from requests import HTTPError, Session +from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt from pyiceberg import __version__ from pyiceberg.catalog import ( @@ -118,6 +119,19 @@ class Endpoints: NAMESPACE_SEPARATOR = b"\x1F".decode(UTF8) +def _retry_hook(retry_state: RetryCallState) -> None: + rest_catalog: RestCatalog = retry_state.args[0] + rest_catalog._refresh_token() # pylint: disable=protected-access + + +_RETRY_ARGS = { + "retry": retry_if_exception_type(AuthorizationExpiredError), + "stop": stop_after_attempt(2), + "before": _retry_hook, + "reraise": True, +} + + class TableResponse(IcebergBaseModel): metadata_location: str = Field(alias="metadata-location") metadata: TableMetadata @@ -225,13 +239,7 @@ def _create_session(self) -> Session: elif ssl_client_cert := ssl_client.get(CERT): session.cert = ssl_client_cert - # If we have credentials, but not a token, we want to fetch a token - if TOKEN not in self.properties and CREDENTIAL in self.properties: - self.properties[TOKEN] = self._fetch_access_token(session, self.properties[CREDENTIAL]) - - # Set Auth token for subsequent calls in the session - if token := self.properties.get(TOKEN): - session.headers[AUTHORIZATION_HEADER] = f"{BEARER_PREFIX} {token}" + self._refresh_token(session, self.properties.get(TOKEN)) # Set HTTP headers session.headers["Content-type"] = "application/json" @@ -438,6 +446,18 @@ def _response_to_table(self, identifier_tuple: Tuple[str, ...], table_response: catalog=self, ) + def _refresh_token(self, session: Optional[Session] = None, new_token: Optional[str] = None) -> None: + session = session or self._session + if new_token is not None: + self.properties[TOKEN] = new_token + elif CREDENTIAL in self.properties: + self.properties[TOKEN] = self._fetch_access_token(session, self.properties[CREDENTIAL]) + + # Set Auth token for subsequent calls in the session + if token := self.properties.get(TOKEN): + session.headers[AUTHORIZATION_HEADER] = f"{BEARER_PREFIX} {token}" + + @retry(**_RETRY_ARGS) def create_table( self, identifier: Union[str, Identifier], @@ -472,6 +492,7 @@ def create_table( table_response = TableResponse(**response.json()) return self._response_to_table(self.identifier_to_tuple(identifier), table_response) + @retry(**_RETRY_ARGS) def register_table(self, identifier: Union[str, Identifier], metadata_location: str) -> Table: """Register a new table using existing metadata. @@ -503,6 +524,7 @@ def register_table(self, identifier: Union[str, Identifier], metadata_location: table_response = TableResponse(**response.json()) return self._response_to_table(self.identifier_to_tuple(identifier), table_response) + @retry(**_RETRY_ARGS) def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]: namespace_tuple = self._check_valid_namespace_identifier(namespace) namespace_concat = NAMESPACE_SEPARATOR.join(namespace_tuple) @@ -513,6 +535,7 @@ def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]: self._handle_non_200_response(exc, {404: NoSuchNamespaceError}) return [(*table.namespace, table.name) for table in ListTablesResponse(**response.json()).identifiers] + @retry(**_RETRY_ARGS) def load_table(self, identifier: Union[str, Identifier]) -> Table: identifier_tuple = self.identifier_to_tuple_without_catalog(identifier) response = self._session.get( @@ -526,6 +549,7 @@ def load_table(self, identifier: Union[str, Identifier]) -> Table: table_response = TableResponse(**response.json()) return self._response_to_table(identifier_tuple, table_response) + @retry(**_RETRY_ARGS) def drop_table(self, identifier: Union[str, Identifier], purge_requested: bool = False) -> None: identifier_tuple = self.identifier_to_tuple_without_catalog(identifier) response = self._session.delete( @@ -538,9 +562,11 @@ def drop_table(self, identifier: Union[str, Identifier], purge_requested: bool = except HTTPError as exc: self._handle_non_200_response(exc, {404: NoSuchTableError}) + @retry(**_RETRY_ARGS) def purge_table(self, identifier: Union[str, Identifier]) -> None: self.drop_table(identifier=identifier, purge_requested=True) + @retry(**_RETRY_ARGS) def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: Union[str, Identifier]) -> Table: from_identifier_tuple = self.identifier_to_tuple_without_catalog(from_identifier) payload = { @@ -555,6 +581,7 @@ def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: U return self.load_table(to_identifier) + @retry(**_RETRY_ARGS) def _commit_table(self, table_request: CommitTableRequest) -> CommitTableResponse: """Update the table. @@ -585,6 +612,7 @@ def _commit_table(self, table_request: CommitTableRequest) -> CommitTableRespons ) return CommitTableResponse(**response.json()) + @retry(**_RETRY_ARGS) def create_namespace(self, namespace: Union[str, Identifier], properties: Properties = EMPTY_DICT) -> None: namespace_tuple = self._check_valid_namespace_identifier(namespace) payload = {"namespace": namespace_tuple, "properties": properties} @@ -594,6 +622,7 @@ def create_namespace(self, namespace: Union[str, Identifier], properties: Proper except HTTPError as exc: self._handle_non_200_response(exc, {404: NoSuchNamespaceError, 409: NamespaceAlreadyExistsError}) + @retry(**_RETRY_ARGS) def drop_namespace(self, namespace: Union[str, Identifier]) -> None: namespace_tuple = self._check_valid_namespace_identifier(namespace) namespace = NAMESPACE_SEPARATOR.join(namespace_tuple) @@ -603,6 +632,7 @@ def drop_namespace(self, namespace: Union[str, Identifier]) -> None: except HTTPError as exc: self._handle_non_200_response(exc, {404: NoSuchNamespaceError}) + @retry(**_RETRY_ARGS) def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identifier]: namespace_tuple = self.identifier_to_tuple(namespace) response = self._session.get( @@ -620,6 +650,7 @@ def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identi namespaces = ListNamespaceResponse(**response.json()) return [namespace_tuple + child_namespace for child_namespace in namespaces.namespaces] + @retry(**_RETRY_ARGS) def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Properties: namespace_tuple = self._check_valid_namespace_identifier(namespace) namespace = NAMESPACE_SEPARATOR.join(namespace_tuple) @@ -631,6 +662,7 @@ def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Proper return NamespaceResponse(**response.json()).properties + @retry(**_RETRY_ARGS) def update_namespace_properties( self, namespace: Union[str, Identifier], removals: Optional[Set[str]] = None, updates: Properties = EMPTY_DICT ) -> PropertiesUpdateSummary: diff --git a/pyproject.toml b/pyproject.toml index dcc91fbbe3..2c79ed1a22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ sortedcontainers = "2.4.0" fsspec = ">=2023.1.0,<2024.1.0" pyparsing = ">=3.1.0,<4.0.0" zstandard = ">=0.13.0,<1.0.0" +tenacity = ">=8.2.3,<9.0.0" pyarrow = { version = ">=9.0.0,<16.0.0", optional = true } pandas = { version = ">=1.0.0,<3.0.0", optional = true } duckdb = { version = ">=0.5.0,<1.0.0", optional = true } @@ -295,6 +296,10 @@ ignore_missing_imports = true module = "setuptools.*" ignore_missing_imports = true +[[tool.mypy.overrides]] +module = "tenacity.*" +ignore_missing_imports = true + [tool.coverage.run] source = ['pyiceberg/'] diff --git a/tests/catalog/test_rest.py b/tests/catalog/test_rest.py index 248cc14d88..7ae0d19558 100644 --- a/tests/catalog/test_rest.py +++ b/tests/catalog/test_rest.py @@ -26,6 +26,7 @@ from pyiceberg.catalog import PropertiesUpdateSummary, Table, load_catalog from pyiceberg.catalog.rest import AUTH_URL, RestCatalog from pyiceberg.exceptions import ( + AuthorizationExpiredError, NamespaceAlreadyExistsError, NoSuchNamespaceError, NoSuchTableError, @@ -266,6 +267,48 @@ def test_list_namespace_with_parent_200(rest_mock: Mocker) -> None: ] +def test_list_namespaces_419(rest_mock: Mocker) -> None: + new_token = "new_jwt_token" + new_header = dict(TEST_HEADERS) + new_header["Authorization"] = f"Bearer {new_token}" + + rest_mock.post( + f"{TEST_URI}v1/namespaces", + json={ + "error": { + "message": "Authorization expired.", + "type": "AuthorizationExpiredError", + "code": 419, + } + }, + status_code=419, + request_headers=TEST_HEADERS, + ) + rest_mock.post( + f"{TEST_URI}v1/oauth/tokens", + json={ + "access_token": new_token, + "token_type": "Bearer", + "expires_in": 86400, + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + }, + status_code=200, + ) + rest_mock.get( + f"{TEST_URI}v1/namespaces", + json={"namespaces": [["default"], ["examples"], ["fokko"], ["system"]]}, + status_code=200, + request_headers=new_header, + ) + catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN, credential=TEST_CREDENTIALS) + assert catalog.list_namespaces() == [ + ("default",), + ("examples",), + ("fokko",), + ("system",), + ] + + def test_create_namespace_200(rest_mock: Mocker) -> None: namespace = "leden" rest_mock.post( @@ -517,6 +560,35 @@ def test_create_table_409(rest_mock: Mocker, table_schema_simple: Schema) -> Non assert "Table already exists" in str(e.value) +def test_create_table_419(rest_mock: Mocker, table_schema_simple: Schema) -> None: + rest_mock.post( + f"{TEST_URI}v1/namespaces/fokko/tables", + json={ + "error": { + "message": "Authorization expired.", + "type": "AuthorizationExpiredError", + "code": 419, + } + }, + status_code=419, + request_headers=TEST_HEADERS, + ) + + with pytest.raises(AuthorizationExpiredError) as e: + RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN).create_table( + identifier=("fokko", "fokko2"), + schema=table_schema_simple, + location=None, + partition_spec=PartitionSpec( + PartitionField(source_id=1, field_id=1000, transform=TruncateTransform(width=3), name="id") + ), + sort_order=SortOrder(SortField(source_id=2, transform=IdentityTransform())), + properties={"owner": "fokko"}, + ) + assert "Authorization expired" in str(e.value) + assert rest_mock.call_count == 3 + + def test_register_table_200( rest_mock: Mocker, table_schema_simple: Schema, example_table_metadata_no_snapshot_v1_rest_json: Dict[str, Any] ) -> None: