Skip to content

Commit 67020f1

Browse files
committed
formatting (black) - fix some closures
1 parent bf0a2f6 commit 67020f1

13 files changed

+101
-60
lines changed

src/databricks/sql/backend/thrift_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def __init__(
232232
try:
233233
self._transport.open()
234234
except:
235-
self._transport.release_connection()
235+
self._transport.close()
236236
raise
237237

238238
self._request_lock = threading.RLock()
@@ -607,7 +607,7 @@ def open_session(self, session_configuration, catalog, schema) -> SessionId:
607607
self._session_id_hex = session_id.hex_guid
608608
return session_id
609609
except:
610-
self._transport.release_connection()
610+
self._transport.close()
611611
raise
612612

613613
def close_session(self, session_id: SessionId) -> None:

src/databricks/sql/client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,8 @@ def read(self) -> Optional[OAuthToken]:
284284
if hasattr(self, "session")
285285
else None,
286286
)
287+
if self.http_client:
288+
self.http_client.close()
287289
raise e
288290

289291
self.use_inline_params = self._set_use_inline_params_with_warning(

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,7 @@ def close(self):
359359
"""Flush remaining events before closing"""
360360
logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex)
361361
self._flush()
362+
self._http_client.close()
362363

363364

364365
class TelemetryClientFactory:
@@ -460,7 +461,6 @@ def initialize_telemetry_client(
460461
):
461462
"""Initialize a telemetry client for a specific connection if telemetry is enabled"""
462463
try:
463-
464464
with TelemetryClientFactory._lock:
465465
TelemetryClientFactory._initialize()
466466

tests/e2e/common/staging_ingestion_tests.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,7 @@ def test_staging_ingestion_life_cycle(self, ingestion_user):
8080

8181
# GET after REMOVE should fail
8282

83-
with pytest.raises(
84-
Error, match="too many 404 error responses"
85-
):
83+
with pytest.raises(Error, match="too many 404 error responses"):
8684
cursor = conn.cursor()
8785
query = f"GET 'stage://tmp/{ingestion_user}/tmp/11/16/file1.csv' TO '{new_temp_path}'"
8886
cursor.execute(query)

tests/e2e/common/uc_volume_tests.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,7 @@ def test_uc_volume_life_cycle(self, catalog, schema):
8080

8181
# GET after REMOVE should fail
8282

83-
with pytest.raises(
84-
Error, match="too many 404 error responses"
85-
):
83+
with pytest.raises(Error, match="too many 404 error responses"):
8684
cursor = conn.cursor()
8785
query = f"GET '/Volumes/{catalog}/{schema}/e2etests/file1.csv' TO '{new_temp_path}'"
8886
cursor.execute(query)

tests/e2e/test_concurrent_telemetry.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,13 @@ def execute_query_worker(thread_id):
122122
response = future.result()
123123
# Check status using urllib3 method (response.status instead of response.raise_for_status())
124124
if response.status >= 400:
125-
raise Exception(f"HTTP {response.status}: {getattr(response, 'reason', 'Unknown')}")
125+
raise Exception(
126+
f"HTTP {response.status}: {getattr(response, 'reason', 'Unknown')}"
127+
)
126128
# Parse JSON using urllib3 method (response.data.decode() instead of response.json())
127-
response_data = json.loads(response.data.decode()) if response.data else {}
129+
response_data = (
130+
json.loads(response.data.decode()) if response.data else {}
131+
)
128132
captured_responses.append(response_data)
129133
except Exception as e:
130134
captured_exceptions.append(e)

tests/e2e/test_driver.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,9 @@
6464
for name in test_loader.getTestCaseNames(DecimalTestsMixin):
6565
if name.startswith("test_"):
6666
fn = getattr(DecimalTestsMixin, name)
67-
decorated = skipUnless(pysql_supports_arrow(), "Decimal tests need arrow support")(
68-
fn
69-
)
67+
decorated = skipUnless(
68+
pysql_supports_arrow(), "Decimal tests need arrow support"
69+
)(fn)
7070
setattr(DecimalTestsMixin, name, decorated)
7171

7272

tests/unit/test_auth.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,9 @@ def test_get_python_sql_connector_auth_provider_access_token(self):
145145
hostname = "moderakh-test.cloud.databricks.com"
146146
kwargs = {"access_token": "dpi123"}
147147
mock_http_client = MagicMock()
148-
auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs)
148+
auth_provider = get_python_sql_connector_auth_provider(
149+
hostname, mock_http_client, **kwargs
150+
)
149151
self.assertTrue(type(auth_provider).__name__, "AccessTokenAuthProvider")
150152

151153
headers = {}
@@ -163,7 +165,9 @@ def __call__(self, *args, **kwargs) -> HeaderFactory:
163165
hostname = "moderakh-test.cloud.databricks.com"
164166
kwargs = {"credentials_provider": MyProvider()}
165167
mock_http_client = MagicMock()
166-
auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs)
168+
auth_provider = get_python_sql_connector_auth_provider(
169+
hostname, mock_http_client, **kwargs
170+
)
167171
self.assertTrue(type(auth_provider).__name__, "ExternalAuthProvider")
168172

169173
headers = {}
@@ -179,7 +183,9 @@ def test_get_python_sql_connector_auth_provider_noop(self):
179183
"_use_cert_as_auth": use_cert_as_auth,
180184
}
181185
mock_http_client = MagicMock()
182-
auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs)
186+
auth_provider = get_python_sql_connector_auth_provider(
187+
hostname, mock_http_client, **kwargs
188+
)
183189
self.assertTrue(type(auth_provider).__name__, "CredentialProvider")
184190

185191
def test_get_python_sql_connector_basic_auth(self):
@@ -189,7 +195,9 @@ def test_get_python_sql_connector_basic_auth(self):
189195
}
190196
mock_http_client = MagicMock()
191197
with self.assertRaises(ValueError) as e:
192-
get_python_sql_connector_auth_provider("foo.cloud.databricks.com", mock_http_client, **kwargs)
198+
get_python_sql_connector_auth_provider(
199+
"foo.cloud.databricks.com", mock_http_client, **kwargs
200+
)
193201
self.assertIn(
194202
"Username/password authentication is no longer supported", str(e.exception)
195203
)
@@ -198,7 +206,9 @@ def test_get_python_sql_connector_basic_auth(self):
198206
def test_get_python_sql_connector_default_auth(self, mock__initial_get_token):
199207
hostname = "foo.cloud.databricks.com"
200208
mock_http_client = MagicMock()
201-
auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client)
209+
auth_provider = get_python_sql_connector_auth_provider(
210+
hostname, mock_http_client
211+
)
202212
self.assertTrue(type(auth_provider).__name__, "DatabricksOAuthProvider")
203213
self.assertTrue(auth_provider._client_id, PYSQL_OAUTH_CLIENT_ID)
204214

@@ -259,16 +269,16 @@ def test_no_token_refresh__when_token_is_not_expired(
259269

260270
def test_get_token_success(self, token_source, http_response):
261271
mock_http_client = MagicMock()
262-
272+
263273
with patch.object(token_source, "_http_client", mock_http_client):
264274
# Create a mock response with the expected format
265275
mock_response = MagicMock()
266276
mock_response.status = 200
267277
mock_response.data.decode.return_value = '{"access_token": "abc123", "token_type": "Bearer", "refresh_token": null}'
268-
278+
269279
# Mock the request method to return the response directly
270280
mock_http_client.request.return_value = mock_response
271-
281+
272282
token = token_source.get_token()
273283

274284
# Assert
@@ -279,16 +289,16 @@ def test_get_token_success(self, token_source, http_response):
279289

280290
def test_get_token_failure(self, token_source, http_response):
281291
mock_http_client = MagicMock()
282-
292+
283293
with patch.object(token_source, "_http_client", mock_http_client):
284294
# Create a mock response with error
285295
mock_response = MagicMock()
286296
mock_response.status = 400
287297
mock_response.data.decode.return_value = "Bad Request"
288-
298+
289299
# Mock the request method to return the response directly
290300
mock_http_client.request.return_value = mock_response
291-
301+
292302
with pytest.raises(Exception) as e:
293303
token_source.get_token()
294304
assert "Failed to get token: 400" in str(e.value)

tests/unit/test_cloud_fetch_queue.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,24 @@
1313

1414
@pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed")
1515
class CloudFetchQueueSuite(unittest.TestCase):
16-
def create_queue(self, schema_bytes=None, result_links=None, description=None, **kwargs):
16+
def create_queue(
17+
self, schema_bytes=None, result_links=None, description=None, **kwargs
18+
):
1719
"""Helper method to create ThriftCloudFetchQueue with sensible defaults"""
1820
# Set up defaults for commonly used parameters
1921
defaults = {
20-
'max_download_threads': 10,
21-
'ssl_options': SSLOptions(),
22-
'session_id_hex': Mock(),
23-
'statement_id': Mock(),
24-
'chunk_id': 0,
25-
'start_row_offset': 0,
26-
'lz4_compressed': True,
22+
"max_download_threads": 10,
23+
"ssl_options": SSLOptions(),
24+
"session_id_hex": Mock(),
25+
"statement_id": Mock(),
26+
"chunk_id": 0,
27+
"start_row_offset": 0,
28+
"lz4_compressed": True,
2729
}
28-
30+
2931
# Override defaults with any provided kwargs
3032
defaults.update(kwargs)
31-
33+
3234
mock_http_client = MagicMock()
3335
return utils.ThriftCloudFetchQueue(
3436
schema_bytes=schema_bytes or MagicMock(),
@@ -198,7 +200,12 @@ def test_next_n_rows_only_one_table_returned(self, mock_create_next_table):
198200
def test_next_n_rows_empty_table(self, mock_create_next_table):
199201
schema_bytes = self.get_schema_bytes()
200202
# Create description that matches the 4-column schema
201-
description = [("col0", "uint32"), ("col1", "uint32"), ("col2", "uint32"), ("col3", "uint32")]
203+
description = [
204+
("col0", "uint32"),
205+
("col1", "uint32"),
206+
("col2", "uint32"),
207+
("col3", "uint32"),
208+
]
202209
queue = self.create_queue(schema_bytes=schema_bytes, description=description)
203210
assert queue.table is None
204211

@@ -277,7 +284,12 @@ def test_remaining_rows_multiple_tables_fully_returned(
277284
def test_remaining_rows_empty_table(self, mock_create_next_table):
278285
schema_bytes = self.get_schema_bytes()
279286
# Create description that matches the 4-column schema
280-
description = [("col0", "uint32"), ("col1", "uint32"), ("col2", "uint32"), ("col3", "uint32")]
287+
description = [
288+
("col0", "uint32"),
289+
("col1", "uint32"),
290+
("col2", "uint32"),
291+
("col3", "uint32"),
292+
]
281293
queue = self.create_queue(schema_bytes=schema_bytes, description=description)
282294
assert queue.table is None
283295

tests/unit/test_downloader.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def test_run_uncompressed_successful(self, mock_time):
131131
self._setup_mock_http_response(mock_http_client, status=200, data=file_bytes)
132132

133133
# Patch the log metrics method to avoid division by zero
134-
with patch.object(downloader.ResultSetDownloadHandler, '_log_download_metrics'):
134+
with patch.object(downloader.ResultSetDownloadHandler, "_log_download_metrics"):
135135
d = downloader.ResultSetDownloadHandler(
136136
settings,
137137
result_link,
@@ -160,11 +160,16 @@ def test_run_compressed_successful(self, mock_time):
160160
result_link.fileLink = "https://s3.amazonaws.com/bucket/file.arrow?token=xyz789"
161161

162162
# Setup mock HTTP response using helper method
163-
self._setup_mock_http_response(mock_http_client, status=200, data=compressed_bytes)
163+
self._setup_mock_http_response(
164+
mock_http_client, status=200, data=compressed_bytes
165+
)
164166

165167
# Mock the decompression method and log metrics to avoid issues
166-
with patch.object(downloader.ResultSetDownloadHandler, '_decompress_data', return_value=file_bytes), \
167-
patch.object(downloader.ResultSetDownloadHandler, '_log_download_metrics'):
168+
with patch.object(
169+
downloader.ResultSetDownloadHandler,
170+
"_decompress_data",
171+
return_value=file_bytes,
172+
), patch.object(downloader.ResultSetDownloadHandler, "_log_download_metrics"):
168173
d = downloader.ResultSetDownloadHandler(
169174
settings,
170175
result_link,

0 commit comments

Comments
 (0)