18
18
19
19
if TYPE_CHECKING :
20
20
from databricks .sql .client import Cursor
21
- from databricks .sql .backend .sea .result_set import SeaResultSet
21
+
22
+ from databricks .sql .backend .sea .result_set import SeaResultSet
22
23
23
24
from databricks .sql .backend .databricks_client import DatabricksClient
24
25
from databricks .sql .backend .types import (
@@ -332,7 +333,7 @@ def _extract_description_from_manifest(
332
333
return columns
333
334
334
335
def _results_message_to_execute_response (
335
- self , response : GetStatementResponse
336
+ self , response : Union [ ExecuteStatementResponse , GetStatementResponse ]
336
337
) -> ExecuteResponse :
337
338
"""
338
339
Convert a SEA response to an ExecuteResponse and extract result data.
@@ -366,6 +367,27 @@ def _results_message_to_execute_response(
366
367
367
368
return execute_response
368
369
370
+ def _response_to_result_set (
371
+ self ,
372
+ response : Union [ExecuteStatementResponse , GetStatementResponse ],
373
+ cursor : Cursor ,
374
+ ) -> SeaResultSet :
375
+ """
376
+ Convert a SEA response to a SeaResultSet.
377
+ """
378
+
379
+ execute_response = self ._results_message_to_execute_response (response )
380
+
381
+ return SeaResultSet (
382
+ connection = cursor .connection ,
383
+ execute_response = execute_response ,
384
+ sea_client = self ,
385
+ result_data = response .result ,
386
+ manifest = response .manifest ,
387
+ buffer_size_bytes = cursor .buffer_size_bytes ,
388
+ arraysize = cursor .arraysize ,
389
+ )
390
+
369
391
def _check_command_not_in_failed_or_closed_state (
370
392
self , state : CommandState , command_id : CommandId
371
393
) -> None :
@@ -386,21 +408,24 @@ def _check_command_not_in_failed_or_closed_state(
386
408
387
409
def _wait_until_command_done (
388
410
self , response : ExecuteStatementResponse
389
- ) -> CommandState :
411
+ ) -> Union [ ExecuteStatementResponse , GetStatementResponse ] :
390
412
"""
391
413
Wait until a command is done.
392
414
"""
393
415
394
- state = response .status .state
395
- command_id = CommandId .from_sea_statement_id (response .statement_id )
416
+ final_response : Union [ExecuteStatementResponse , GetStatementResponse ] = response
417
+
418
+ state = final_response .status .state
419
+ command_id = CommandId .from_sea_statement_id (final_response .statement_id )
396
420
397
421
while state in [CommandState .PENDING , CommandState .RUNNING ]:
398
422
time .sleep (self .POLL_INTERVAL_SECONDS )
399
- state = self .get_query_state (command_id )
423
+ final_response = self ._poll_query (command_id )
424
+ state = final_response .status .state
400
425
401
426
self ._check_command_not_in_failed_or_closed_state (state , command_id )
402
427
403
- return state
428
+ return final_response
404
429
405
430
def execute_command (
406
431
self ,
@@ -506,8 +531,11 @@ def execute_command(
506
531
if async_op :
507
532
return None
508
533
509
- self ._wait_until_command_done (response )
510
- return self .get_execution_result (command_id , cursor )
534
+ final_response : Union [ExecuteStatementResponse , GetStatementResponse ] = response
535
+ if response .status .state != CommandState .SUCCEEDED :
536
+ final_response = self ._wait_until_command_done (response )
537
+
538
+ return self ._response_to_result_set (final_response , cursor )
511
539
512
540
def cancel_command (self , command_id : CommandId ) -> None :
513
541
"""
@@ -559,18 +587,9 @@ def close_command(self, command_id: CommandId) -> None:
559
587
data = request .to_dict (),
560
588
)
561
589
562
- def get_query_state (self , command_id : CommandId ) -> CommandState :
590
+ def _poll_query (self , command_id : CommandId ) -> GetStatementResponse :
563
591
"""
564
- Get the state of a running query.
565
-
566
- Args:
567
- command_id: Command identifier
568
-
569
- Returns:
570
- CommandState: The current state of the command
571
-
572
- Raises:
573
- ValueError: If the command ID is invalid
592
+ Poll for the current command info.
574
593
"""
575
594
576
595
if command_id .backend_type != BackendType .SEA :
@@ -586,9 +605,25 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
586
605
path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
587
606
data = request .to_dict (),
588
607
)
589
-
590
- # Parse the response
591
608
response = GetStatementResponse .from_dict (response_data )
609
+
610
+ return response
611
+
612
+ def get_query_state (self , command_id : CommandId ) -> CommandState :
613
+ """
614
+ Get the state of a running query.
615
+
616
+ Args:
617
+ command_id: Command identifier
618
+
619
+ Returns:
620
+ CommandState: The current state of the command
621
+
622
+ Raises:
623
+ ProgrammingError: If the command ID is invalid
624
+ """
625
+
626
+ response = self ._poll_query (command_id )
592
627
return response .status .state
593
628
594
629
def get_execution_result (
@@ -610,38 +645,8 @@ def get_execution_result(
610
645
ValueError: If the command ID is invalid
611
646
"""
612
647
613
- if command_id .backend_type != BackendType .SEA :
614
- raise ValueError ("Not a valid SEA command ID" )
615
-
616
- sea_statement_id = command_id .to_sea_statement_id ()
617
- if sea_statement_id is None :
618
- raise ValueError ("Not a valid SEA command ID" )
619
-
620
- # Create the request model
621
- request = GetStatementRequest (statement_id = sea_statement_id )
622
-
623
- # Get the statement result
624
- response_data = self ._http_client ._make_request (
625
- method = "GET" ,
626
- path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
627
- data = request .to_dict (),
628
- )
629
- response = GetStatementResponse .from_dict (response_data )
630
-
631
- # Create and return a SeaResultSet
632
- from databricks .sql .backend .sea .result_set import SeaResultSet
633
-
634
- execute_response = self ._results_message_to_execute_response (response )
635
-
636
- return SeaResultSet (
637
- connection = cursor .connection ,
638
- execute_response = execute_response ,
639
- sea_client = self ,
640
- result_data = response .result ,
641
- manifest = response .manifest ,
642
- buffer_size_bytes = cursor .buffer_size_bytes ,
643
- arraysize = cursor .arraysize ,
644
- )
648
+ response = self ._poll_query (command_id )
649
+ return self ._response_to_result_set (response , cursor )
645
650
646
651
def get_chunk_links (
647
652
self , statement_id : str , chunk_index : int
0 commit comments