diff --git a/pyiceberg/catalog/hive.py b/pyiceberg/catalog/hive.py index ec832727a2..2d411cb409 100644 --- a/pyiceberg/catalog/hive.py +++ b/pyiceberg/catalog/hive.py @@ -18,6 +18,7 @@ import logging import socket import time +from functools import cached_property from types import TracebackType from typing import ( TYPE_CHECKING, @@ -143,40 +144,47 @@ class _HiveClient: """Helper class to nicely open and close the transport.""" _transport: TTransport - _client: Client _ugi: Optional[List[str]] def __init__(self, uri: str, ugi: Optional[str] = None, kerberos_auth: Optional[bool] = HIVE_KERBEROS_AUTH_DEFAULT): self._uri = uri self._kerberos_auth = kerberos_auth self._ugi = ugi.split(":") if ugi else None + self._transport = self._init_thrift_transport() - self._init_thrift_client() - - def _init_thrift_client(self) -> None: + def _init_thrift_transport(self) -> TTransport: url_parts = urlparse(self._uri) - socket = TSocket.TSocket(url_parts.hostname, url_parts.port) - if not self._kerberos_auth: - self._transport = TTransport.TBufferedTransport(socket) + return TTransport.TBufferedTransport(socket) else: - self._transport = TTransport.TSaslClientTransport(socket, host=url_parts.hostname, service="hive") + return TTransport.TSaslClientTransport(socket, host=url_parts.hostname, service="hive") + @cached_property + def _client(self) -> Client: protocol = TBinaryProtocol.TBinaryProtocol(self._transport) - - self._client = Client(protocol) + client = Client(protocol) + if self._ugi: + client.set_ugi(*self._ugi) + return client def __enter__(self) -> Client: - self._transport.open() - if self._ugi: - self._client.set_ugi(*self._ugi) + """Make sure the transport is initialized and open.""" + if not self._transport.isOpen(): + try: + self._transport.open() + except TTransport.TTransportException: + # reinitialize _transport + self._transport = self._init_thrift_transport() + self._transport.open() return self._client def __exit__( self, exctype: Optional[Type[BaseException]], excinst: Optional[BaseException], exctb: Optional[TracebackType] ) -> None: - self._transport.close() + """Close transport if it was opened.""" + if self._transport.isOpen(): + self._transport.close() def _construct_hive_storage_descriptor(