Skip to content

Commit 8fbca9d

Browse files
SEA: Reduce network calls for synchronous commands (#633)
* remove additional call on success
  Signed-off-by: varun-edachali-dbx <[email protected]>
* reduce additional network call after wait
  Signed-off-by: varun-edachali-dbx <[email protected]>
* re-introduce GetStatementResponse
  Signed-off-by: varun-edachali-dbx <[email protected]>
* remove need for lazy load of SeaResultSet
  Signed-off-by: varun-edachali-dbx <[email protected]>
* re-organise GetStatementResponse import
  Signed-off-by: varun-edachali-dbx <[email protected]>
---------
Signed-off-by: varun-edachali-dbx <[email protected]>
1 parent 640cc82 commit 8fbca9d

File tree

3 files changed

+63
-61
lines changed

3 files changed

+63
-61
lines changed

src/databricks/sql/backend/sea/backend.py

Lines changed: 59 additions & 54 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,8 @@
1818

1919
if TYPE_CHECKING:
2020
from databricks.sql.client import Cursor
21-
from databricks.sql.backend.sea.result_set import SeaResultSet
21+
22+
from databricks.sql.backend.sea.result_set import SeaResultSet
2223

2324
from databricks.sql.backend.databricks_client import DatabricksClient
2425
from databricks.sql.backend.types import (
@@ -332,7 +333,7 @@ def _extract_description_from_manifest(
332333
return columns
333334

334335
def _results_message_to_execute_response(
335-
self, response: GetStatementResponse
336+
self, response: Union[ExecuteStatementResponse, GetStatementResponse]
336337
) -> ExecuteResponse:
337338
"""
338339
Convert a SEA response to an ExecuteResponse and extract result data.
@@ -366,6 +367,27 @@ def _results_message_to_execute_response(
366367

367368
return execute_response
368369

370+
def _response_to_result_set(
371+
self,
372+
response: Union[ExecuteStatementResponse, GetStatementResponse],
373+
cursor: Cursor,
374+
) -> SeaResultSet:
375+
"""
376+
Convert a SEA response to a SeaResultSet.
377+
"""
378+
379+
execute_response = self._results_message_to_execute_response(response)
380+
381+
return SeaResultSet(
382+
connection=cursor.connection,
383+
execute_response=execute_response,
384+
sea_client=self,
385+
result_data=response.result,
386+
manifest=response.manifest,
387+
buffer_size_bytes=cursor.buffer_size_bytes,
388+
arraysize=cursor.arraysize,
389+
)
390+
369391
def _check_command_not_in_failed_or_closed_state(
370392
self, state: CommandState, command_id: CommandId
371393
) -> None:
@@ -386,21 +408,24 @@ def _check_command_not_in_failed_or_closed_state(
386408

387409
def _wait_until_command_done(
388410
self, response: ExecuteStatementResponse
389-
) -> CommandState:
411+
) -> Union[ExecuteStatementResponse, GetStatementResponse]:
390412
"""
391413
Wait until a command is done.
392414
"""
393415

394-
state = response.status.state
395-
command_id = CommandId.from_sea_statement_id(response.statement_id)
416+
final_response: Union[ExecuteStatementResponse, GetStatementResponse] = response
417+
418+
state = final_response.status.state
419+
command_id = CommandId.from_sea_statement_id(final_response.statement_id)
396420

397421
while state in [CommandState.PENDING, CommandState.RUNNING]:
398422
time.sleep(self.POLL_INTERVAL_SECONDS)
399-
state = self.get_query_state(command_id)
423+
final_response = self._poll_query(command_id)
424+
state = final_response.status.state
400425

401426
self._check_command_not_in_failed_or_closed_state(state, command_id)
402427

403-
return state
428+
return final_response
404429

405430
def execute_command(
406431
self,
@@ -506,8 +531,11 @@ def execute_command(
506531
if async_op:
507532
return None
508533

509-
self._wait_until_command_done(response)
510-
return self.get_execution_result(command_id, cursor)
534+
final_response: Union[ExecuteStatementResponse, GetStatementResponse] = response
535+
if response.status.state != CommandState.SUCCEEDED:
536+
final_response = self._wait_until_command_done(response)
537+
538+
return self._response_to_result_set(final_response, cursor)
511539

512540
def cancel_command(self, command_id: CommandId) -> None:
513541
"""
@@ -559,18 +587,9 @@ def close_command(self, command_id: CommandId) -> None:
559587
data=request.to_dict(),
560588
)
561589

562-
def get_query_state(self, command_id: CommandId) -> CommandState:
590+
def _poll_query(self, command_id: CommandId) -> GetStatementResponse:
563591
"""
564-
Get the state of a running query.
565-
566-
Args:
567-
command_id: Command identifier
568-
569-
Returns:
570-
CommandState: The current state of the command
571-
572-
Raises:
573-
ValueError: If the command ID is invalid
592+
Poll for the current command info.
574593
"""
575594

576595
if command_id.backend_type != BackendType.SEA:
@@ -586,9 +605,25 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
586605
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
587606
data=request.to_dict(),
588607
)
589-
590-
# Parse the response
591608
response = GetStatementResponse.from_dict(response_data)
609+
610+
return response
611+
612+
def get_query_state(self, command_id: CommandId) -> CommandState:
613+
"""
614+
Get the state of a running query.
615+
616+
Args:
617+
command_id: Command identifier
618+
619+
Returns:
620+
CommandState: The current state of the command
621+
622+
Raises:
623+
ProgrammingError: If the command ID is invalid
624+
"""
625+
626+
response = self._poll_query(command_id)
592627
return response.status.state
593628

594629
def get_execution_result(
@@ -610,38 +645,8 @@ def get_execution_result(
610645
ValueError: If the command ID is invalid
611646
"""
612647

613-
if command_id.backend_type != BackendType.SEA:
614-
raise ValueError("Not a valid SEA command ID")
615-
616-
sea_statement_id = command_id.to_sea_statement_id()
617-
if sea_statement_id is None:
618-
raise ValueError("Not a valid SEA command ID")
619-
620-
# Create the request model
621-
request = GetStatementRequest(statement_id=sea_statement_id)
622-
623-
# Get the statement result
624-
response_data = self._http_client._make_request(
625-
method="GET",
626-
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
627-
data=request.to_dict(),
628-
)
629-
response = GetStatementResponse.from_dict(response_data)
630-
631-
# Create and return a SeaResultSet
632-
from databricks.sql.backend.sea.result_set import SeaResultSet
633-
634-
execute_response = self._results_message_to_execute_response(response)
635-
636-
return SeaResultSet(
637-
connection=cursor.connection,
638-
execute_response=execute_response,
639-
sea_client=self,
640-
result_data=response.result,
641-
manifest=response.manifest,
642-
buffer_size_bytes=cursor.buffer_size_bytes,
643-
arraysize=cursor.arraysize,
644-
)
648+
response = self._poll_query(command_id)
649+
return self._response_to_result_set(response, cursor)
645650

646651
def get_chunk_links(
647652
self, statement_id: str, chunk_index: int

src/databricks/sql/backend/sea/result_set.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,6 @@
44

55
import logging
66

7-
from databricks.sql.backend.sea.backend import SeaDatabricksClient
87
from databricks.sql.backend.sea.models.base import ResultData, ResultManifest
98
from databricks.sql.backend.sea.utils.conversion import SqlTypeConverter
109

@@ -15,6 +14,7 @@
1514

1615
if TYPE_CHECKING:
1716
from databricks.sql.client import Connection
17+
from databricks.sql.backend.sea.backend import SeaDatabricksClient
1818
from databricks.sql.types import Row
1919
from databricks.sql.backend.sea.queue import JsonQueue, SeaResultSetQueueFactory
2020
from databricks.sql.backend.types import ExecuteResponse

tests/unit/test_sea_backend.py

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -227,7 +227,7 @@ def test_command_execution_sync(
227227
mock_http_client._make_request.return_value = execute_response
228228

229229
with patch.object(
230-
sea_client, "get_execution_result", return_value="mock_result_set"
230+
sea_client, "_response_to_result_set", return_value="mock_result_set"
231231
) as mock_get_result:
232232
result = sea_client.execute_command(
233233
operation="SELECT 1",
@@ -242,9 +242,6 @@ def test_command_execution_sync(
242242
enforce_embedded_schema_correctness=False,
243243
)
244244
assert result == "mock_result_set"
245-
cmd_id_arg = mock_get_result.call_args[0][0]
246-
assert isinstance(cmd_id_arg, CommandId)
247-
assert cmd_id_arg.guid == "test-statement-123"
248245

249246
# Test with invalid session ID
250247
with pytest.raises(ValueError) as excinfo:
@@ -332,7 +329,7 @@ def test_command_execution_advanced(
332329
mock_http_client._make_request.side_effect = [initial_response, poll_response]
333330

334331
with patch.object(
335-
sea_client, "get_execution_result", return_value="mock_result_set"
332+
sea_client, "_response_to_result_set", return_value="mock_result_set"
336333
) as mock_get_result:
337334
with patch("time.sleep"):
338335
result = sea_client.execute_command(
@@ -360,7 +357,7 @@ def test_command_execution_advanced(
360357
dbsql_param = IntegerParameter(name="param1", value=1)
361358
param = dbsql_param.as_tspark_param(named=True)
362359

363-
with patch.object(sea_client, "get_execution_result"):
360+
with patch.object(sea_client, "_response_to_result_set"):
364361
sea_client.execute_command(
365362
operation="SELECT * FROM table WHERE col = :param1",
366363
session_id=sea_session_id,

0 commit comments

Comments
 (0)