diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py
index abe6bd1ab..edd171b05 100644
--- a/examples/experimental/sea_connector_test.py
+++ b/examples/experimental/sea_connector_test.py
@@ -1,66 +1,115 @@
+"""
+Main script to run all SEA connector tests.
+
+This script runs all the individual test modules and displays
+a summary of test results with visual indicators.
+"""
 import os
 import sys
 import logging
-from databricks.sql.client import Connection
+import subprocess
+from typing import List, Tuple
 
 logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger(__name__)
 
-def test_sea_session():
-    """
-    Test opening and closing a SEA session using the connector.
-
-    This function connects to a Databricks SQL endpoint using the SEA backend,
-    opens a session, and then closes it.
-
-    Required environment variables:
-    - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname
-    - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint
-    - DATABRICKS_TOKEN: Personal access token for authentication
-    """
-
-    server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME")
-    http_path = os.environ.get("DATABRICKS_HTTP_PATH")
-    access_token = os.environ.get("DATABRICKS_TOKEN")
-    catalog = os.environ.get("DATABRICKS_CATALOG")
-
-    if not all([server_hostname, http_path, access_token]):
-        logger.error("Missing required environment variables.")
-        logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.")
-        sys.exit(1)
-
-    logger.info(f"Connecting to {server_hostname}")
-    logger.info(f"HTTP Path: {http_path}")
-    if catalog:
-        logger.info(f"Using catalog: {catalog}")
-
-    try:
-        logger.info("Creating connection with SEA backend...")
-        connection = Connection(
-            server_hostname=server_hostname,
-            http_path=http_path,
-            access_token=access_token,
-            catalog=catalog,
-            schema="default",
-            use_sea=True,
-            user_agent_entry="SEA-Test-Client" # add custom user agent
+TEST_MODULES = [
+    "test_sea_session",
+    "test_sea_sync_query",
+    "test_sea_async_query",
+    "test_sea_metadata",
+]
+
+
+def run_test_module(module_name: str) -> bool:
+    """Run a test module and return success status."""
+    module_path = os.path.join(
+        os.path.dirname(os.path.abspath(__file__)), "tests", f"{module_name}.py"
+    )
+
+    # Simply run the module as a script - each module handles its own test execution
+    result = subprocess.run(
+        [sys.executable, module_path], capture_output=True, text=True
+    )
+
+    # Log the output from the test module
+    if result.stdout:
+        for line in result.stdout.strip().split("\n"):
+            logger.info(line)
+
+    if result.stderr:
+        for line in result.stderr.strip().split("\n"):
+            logger.error(line)
+
+    return result.returncode == 0
+
+
+def run_tests() -> List[Tuple[str, bool]]:
+    """Run all tests and return results."""
+    results = []
+
+    for module_name in TEST_MODULES:
+        try:
+            logger.info(f"\n{'=' * 50}")
+            logger.info(f"Running test: {module_name}")
+            logger.info(f"{'-' * 50}")
+
+            success = run_test_module(module_name)
+            results.append((module_name, success))
+
+            status = "✅ PASSED" if success else "❌ FAILED"
+            logger.info(f"Test {module_name}: {status}")
+
+        except Exception as e:
+            logger.error(f"Error running test {module_name}: {str(e)}")
+            import traceback
+
+            logger.error(traceback.format_exc())
+            results.append((module_name, False))
+
+    return results
+
+
+def print_summary(results: List[Tuple[str, bool]]) -> None:
+    """Print a summary of test results."""
+    logger.info(f"\n{'=' * 50}")
+    logger.info("TEST SUMMARY")
+    logger.info(f"{'-' * 50}")
+
+    passed = sum(1 for _, success in results if success)
+    total = len(results)
+
+    for module_name, success in results:
+        status = "✅ PASSED" if success else "❌ FAILED"
+        logger.info(f"{status} - {module_name}")
+
+    logger.info(f"{'-' * 50}")
+    logger.info(f"Total: {total} | Passed: {passed} | Failed: {total - passed}")
+    logger.info(f"{'=' * 50}")
+
+
+if __name__ == "__main__":
+    # Check if required environment variables are set
+    required_vars = [
+        "DATABRICKS_SERVER_HOSTNAME",
+        "DATABRICKS_HTTP_PATH",
+        "DATABRICKS_TOKEN",
+    ]
+    missing_vars = [var for var in required_vars if not os.environ.get(var)]
+
+    if missing_vars:
+        logger.error(
+            f"Missing required environment variables: {', '.join(missing_vars)}"
         )
-
-        logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}")
-        logger.info(f"backend type: {type(connection.session.backend)}")
-
-        # Close the connection
-        logger.info("Closing the SEA session...")
-        connection.close()
-        logger.info("Successfully closed SEA session")
-
-    except Exception as e:
-        logger.error(f"Error testing SEA session: {str(e)}")
-        import traceback
-        logger.error(traceback.format_exc())
+        logger.error("Please set these variables before running the tests.")
         sys.exit(1)
-
-    logger.info("SEA session test completed successfully")
 
-if __name__ == "__main__":
-    test_sea_session()
+    # Run all tests
+    results = run_tests()
+
+    # Print summary
+    print_summary(results)
+
+    # Exit with appropriate status code
+    all_passed = all(success for _, success in results)
+    sys.exit(0 if all_passed else 1)
diff --git a/examples/experimental/tests/__init__.py b/examples/experimental/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/experimental/tests/test_sea_async_query.py b/examples/experimental/tests/test_sea_async_query.py
new file mode 100644
index 000000000..f805834b4
--- /dev/null
+++ b/examples/experimental/tests/test_sea_async_query.py
@@ -0,0 +1,246 @@
+"""
+Test for SEA asynchronous query execution functionality.
+"""
+import os
+import sys
+import logging
+import time
+from databricks.sql.client import Connection
+from databricks.sql.backend.types import CommandState
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def test_sea_async_query_with_cloud_fetch():
+    """
+    Test executing a query asynchronously using the SEA backend with cloud fetch enabled.
+
+    This function connects to a Databricks SQL endpoint using the SEA backend,
+    executes a query asynchronously with cloud fetch enabled, and verifies that execution completes successfully.
+    """
+    server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME")
+    http_path = os.environ.get("DATABRICKS_HTTP_PATH")
+    access_token = os.environ.get("DATABRICKS_TOKEN")
+    catalog = os.environ.get("DATABRICKS_CATALOG")
+
+    if not all([server_hostname, http_path, access_token]):
+        logger.error("Missing required environment variables.")
+        logger.error(
+            "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN."
+ ) + return False + + try: + # Create connection with cloud fetch enabled + logger.info( + "Creating connection for asynchronous query execution with cloud fetch enabled" + ) + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=True, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a query that generates large rows to force multiple chunks + requested_row_count = 5000 + cursor = connection.cursor() + query = f""" + SELECT + id, + concat('value_', repeat('a', 10000)) as test_value + FROM range(1, {requested_row_count} + 1) AS t(id) + """ + + logger.info( + f"Executing asynchronous query with cloud fetch to generate {requested_row_count} rows" + ) + cursor.execute_async(query) + logger.info( + "Asynchronous query submitted successfully with cloud fetch enabled" + ) + + # Check query state + logger.info("Checking query state...") + while cursor.is_query_pending(): + logger.info("Query is still pending, waiting...") + time.sleep(1) + + logger.info("Query is no longer pending, getting results...") + cursor.get_async_execution_result() + + results = [cursor.fetchone()] + results.extend(cursor.fetchmany(10)) + results.extend(cursor.fetchall()) + actual_row_count = len(results) + logger.info( + f"{actual_row_count} rows retrieved against {requested_row_count} requested" + ) + + logger.info( + f"Requested {requested_row_count} rows, received {actual_row_count} rows" + ) + + # Verify total row count + if actual_row_count != requested_row_count: + logger.error( + f"FAIL: Row count mismatch. Expected {requested_row_count}, got {actual_row_count}" + ) + return False + + logger.info( + "PASS: Received correct number of rows with cloud fetch and all fetch methods work correctly" + ) + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error( + f"Error during SEA asynchronous query execution test with cloud fetch: {str(e)}" + ) + import traceback + + logger.error(traceback.format_exc()) + return False + + +def test_sea_async_query_without_cloud_fetch(): + """ + Test executing a query asynchronously using the SEA backend with cloud fetch disabled. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a query asynchronously with cloud fetch disabled, and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
+ ) + return False + + try: + # Create connection with cloud fetch disabled + logger.info( + "Creating connection for asynchronous query execution with cloud fetch disabled" + ) + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=False, + enable_query_result_lz4_compression=False, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # For non-cloud fetch, use a smaller row count to avoid exceeding inline limits + requested_row_count = 100 + cursor = connection.cursor() + query = f""" + SELECT + id, + concat('value_', repeat('a', 100)) as test_value + FROM range(1, {requested_row_count} + 1) AS t(id) + """ + + logger.info( + f"Executing asynchronous query without cloud fetch to generate {requested_row_count} rows" + ) + cursor.execute_async(query) + logger.info( + "Asynchronous query submitted successfully with cloud fetch disabled" + ) + + # Check query state + logger.info("Checking query state...") + while cursor.is_query_pending(): + logger.info("Query is still pending, waiting...") + time.sleep(1) + + logger.info("Query is no longer pending, getting results...") + cursor.get_async_execution_result() + results = [cursor.fetchone()] + results.extend(cursor.fetchmany(10)) + results.extend(cursor.fetchall()) + actual_row_count = len(results) + logger.info( + f"{actual_row_count} rows retrieved against {requested_row_count} requested" + ) + + logger.info( + f"Requested {requested_row_count} rows, received {actual_row_count} rows" + ) + + # Verify total row count + if actual_row_count != requested_row_count: + logger.error( + f"FAIL: Row count mismatch. Expected {requested_row_count}, got {actual_row_count}" + ) + return False + + logger.info( + "PASS: Received correct number of rows without cloud fetch and all fetch methods work correctly" + ) + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error( + f"Error during SEA asynchronous query execution test without cloud fetch: {str(e)}" + ) + import traceback + + logger.error(traceback.format_exc()) + return False + + +def test_sea_async_query_exec(): + """ + Run both asynchronous query tests and return overall success. + """ + with_cloud_fetch_success = test_sea_async_query_with_cloud_fetch() + logger.info( + f"Asynchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}" + ) + + without_cloud_fetch_success = test_sea_async_query_without_cloud_fetch() + logger.info( + f"Asynchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}" + ) + + return with_cloud_fetch_success and without_cloud_fetch_success + + +if __name__ == "__main__": + success = test_sea_async_query_exec() + sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_metadata.py b/examples/experimental/tests/test_sea_metadata.py new file mode 100644 index 000000000..a200d97d3 --- /dev/null +++ b/examples/experimental/tests/test_sea_metadata.py @@ -0,0 +1,98 @@ +""" +Test for SEA metadata functionality. +""" +import os +import sys +import logging +from databricks.sql.client import Connection + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_metadata(): + """ + Test metadata operations using the SEA backend. 
+ + This function connects to a Databricks SQL endpoint using the SEA backend, + and executes metadata operations like catalogs(), schemas(), tables(), and columns(). + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + if not catalog: + logger.error( + "DATABRICKS_CATALOG environment variable is required for metadata tests." + ) + return False + + try: + # Create connection + logger.info("Creating connection for metadata operations") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Test catalogs + cursor = connection.cursor() + logger.info("Fetching catalogs...") + cursor.catalogs() + logger.info("Successfully fetched catalogs") + + # Test schemas + logger.info(f"Fetching schemas for catalog '{catalog}'...") + cursor.schemas(catalog_name=catalog) + logger.info("Successfully fetched schemas") + + # Test tables + logger.info(f"Fetching tables for catalog '{catalog}', schema 'default'...") + cursor.tables(catalog_name=catalog, schema_name="default") + logger.info("Successfully fetched tables") + + # Test columns for a specific table + # Using a common table that should exist in most environments + logger.info( + f"Fetching columns for catalog '{catalog}', schema 'default', table 'customer'..." + ) + cursor.columns( + catalog_name=catalog, schema_name="default", table_name="customer" + ) + logger.info("Successfully fetched columns") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error during SEA metadata test: {str(e)}") + import traceback + + logger.error(traceback.format_exc()) + return False + + +if __name__ == "__main__": + success = test_sea_metadata() + sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_session.py b/examples/experimental/tests/test_sea_session.py new file mode 100644 index 000000000..516c1bbb8 --- /dev/null +++ b/examples/experimental/tests/test_sea_session.py @@ -0,0 +1,71 @@ +""" +Test for SEA session management functionality. +""" +import os +import sys +import logging +from databricks.sql.client import Connection + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_session(): + """ + Test opening and closing a SEA session using the connector. + + This function connects to a Databricks SQL endpoint using the SEA backend, + opens a session, and then closes it. 
+ + Required environment variables: + - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname + - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint + - DATABRICKS_TOKEN: Personal access token for authentication + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + try: + logger.info("Creating connection with SEA backend...") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + logger.info(f"Backend type: {type(connection.session.backend)}") + + # Close the connection + logger.info("Closing the SEA session...") + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error testing SEA session: {str(e)}") + import traceback + + logger.error(traceback.format_exc()) + return False + + +if __name__ == "__main__": + success = test_sea_session() + sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_sync_query.py b/examples/experimental/tests/test_sea_sync_query.py new file mode 100644 index 000000000..bfb86b82b --- /dev/null +++ b/examples/experimental/tests/test_sea_sync_query.py @@ -0,0 +1,180 @@ +""" +Test for SEA synchronous query execution functionality. +""" +import os +import sys +import logging +from databricks.sql.client import Connection + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + + +def test_sea_sync_query_with_cloud_fetch(): + """ + Test executing a query synchronously using the SEA backend with cloud fetch enabled. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a query with cloud fetch enabled, and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
+ ) + return False + + try: + # Create connection with cloud fetch enabled + logger.info( + "Creating connection for synchronous query execution with cloud fetch enabled" + ) + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=True, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a query that generates large rows to force multiple chunks + requested_row_count = 10000 + cursor = connection.cursor() + query = f""" + SELECT + id, + concat('value_', repeat('a', 10000)) as test_value + FROM range(1, {requested_row_count} + 1) AS t(id) + """ + + logger.info( + f"Executing synchronous query with cloud fetch to generate {requested_row_count} rows" + ) + cursor.execute(query) + results = [cursor.fetchone()] + results.extend(cursor.fetchmany(10)) + results.extend(cursor.fetchall()) + actual_row_count = len(results) + logger.info( + f"{actual_row_count} rows retrieved against {requested_row_count} requested" + ) + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error( + f"Error during SEA synchronous query execution test with cloud fetch: {str(e)}" + ) + import traceback + + logger.error(traceback.format_exc()) + return False + + +def test_sea_sync_query_without_cloud_fetch(): + """ + Test executing a query synchronously using the SEA backend with cloud fetch disabled. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a query with cloud fetch disabled, and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
+        )
+        return False
+
+    try:
+        # Create connection with cloud fetch disabled
+        logger.info(
+            "Creating connection for synchronous query execution with cloud fetch disabled"
+        )
+        connection = Connection(
+            server_hostname=server_hostname,
+            http_path=http_path,
+            access_token=access_token,
+            catalog=catalog,
+            schema="default",
+            use_sea=True,
+            user_agent_entry="SEA-Test-Client",
+            use_cloud_fetch=False,
+            enable_query_result_lz4_compression=False,
+        )
+
+        logger.info(
+            f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}"
+        )
+
+        # For non-cloud fetch, use a smaller row count to avoid exceeding inline limits
+        requested_row_count = 100
+        cursor = connection.cursor()
+        logger.info("Executing synchronous query without cloud fetch: SELECT 100 rows")
+        cursor.execute(
+            "SELECT id, 'test_value_' || CAST(id as STRING) as test_value FROM range(1, 101)"
+        )
+
+        results = [cursor.fetchone()]
+        results.extend(cursor.fetchmany(10))
+        results.extend(cursor.fetchall())
+        logger.info(f"{len(results)} rows retrieved against 100 requested")
+
+        # Close resources
+        cursor.close()
+        connection.close()
+        logger.info("Successfully closed SEA session")
+
+        return True
+
+    except Exception as e:
+        logger.error(
+            f"Error during SEA synchronous query execution test without cloud fetch: {str(e)}"
+        )
+        import traceback
+
+        logger.error(traceback.format_exc())
+        return False
+
+
+def test_sea_sync_query_exec():
+    """
+    Run both synchronous query tests and return overall success.
+    """
+    with_cloud_fetch_success = test_sea_sync_query_with_cloud_fetch()
+    logger.info(
+        f"Synchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}"
+    )
+
+    without_cloud_fetch_success = test_sea_sync_query_without_cloud_fetch()
+    logger.info(
+        f"Synchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}"
+    )
+
+    return with_cloud_fetch_success and without_cloud_fetch_success
+
+
+if __name__ == "__main__":
+    success = test_sea_sync_query_exec()
+    sys.exit(0 if success else 1)
diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py
index e824de1c2..02d335aa4 100644
--- a/src/databricks/sql/backend/thrift_backend.py
+++ b/src/databricks/sql/backend/thrift_backend.py
@@ -40,11 +40,11 @@
 )
 
 from databricks.sql.utils import (
-    ResultSetQueueFactory,
+    ThriftResultSetQueueFactory,
     _bound,
     RequestErrorInfo,
     NoRetryReason,
-    ResultSetQueueFactory,
+    ThriftResultSetQueueFactory,
     convert_arrow_based_set_to_arrow_table,
     convert_decimals_in_arrow_table,
     convert_column_based_set_to_arrow_table,
@@ -1232,7 +1232,7 @@ def fetch_results(
             )
         )
 
-        queue = ResultSetQueueFactory.build_queue(
+        queue = ThriftResultSetQueueFactory.build_queue(
             row_set_type=resp.resultSetMetadata.resultFormat,
             t_row_set=resp.results,
             arrow_schema_bytes=arrow_schema_bytes,
diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py
index 38b8a3c2f..49394b12a 100644
--- a/src/databricks/sql/result_set.py
+++ b/src/databricks/sql/result_set.py
@@ -6,6 +6,7 @@
 import pandas
 
 from databricks.sql.backend.sea.backend import SeaDatabricksClient
+from databricks.sql.backend.sea.models.base import ResultData, ResultManifest
 
 try:
     import pyarrow
@@ -19,7 +20,12 @@
 from databricks.sql.thrift_api.TCLIService import ttypes
 from databricks.sql.types import Row
 from databricks.sql.exc import Error, RequestError, CursorAlreadyClosedError
-from databricks.sql.utils import ColumnTable, ColumnQueue
+from databricks.sql.utils import (
+    ColumnTable,
+    ColumnQueue,
+    JsonQueue,
+    SeaResultSetQueueFactory,
+)
 from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse
 
 logger = logging.getLogger(__name__)
@@ -88,6 +94,44 @@ def __iter__(self):
             else:
                 break
 
+    def _convert_arrow_table(self, table):
+        column_names = [c[0] for c in self.description]
+        ResultRow = Row(*column_names)
+
+        if self.connection.disable_pandas is True:
+            return [
+                ResultRow(*[v.as_py() for v in r]) for r in zip(*table.itercolumns())
+            ]
+
+        # Need to use nullable types, as otherwise type can change when there are missing values.
+        # See https://arrow.apache.org/docs/python/pandas.html#nullable-types
+        # NOTE: This api is experimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html
+        dtype_mapping = {
+            pyarrow.int8(): pandas.Int8Dtype(),
+            pyarrow.int16(): pandas.Int16Dtype(),
+            pyarrow.int32(): pandas.Int32Dtype(),
+            pyarrow.int64(): pandas.Int64Dtype(),
+            pyarrow.uint8(): pandas.UInt8Dtype(),
+            pyarrow.uint16(): pandas.UInt16Dtype(),
+            pyarrow.uint32(): pandas.UInt32Dtype(),
+            pyarrow.uint64(): pandas.UInt64Dtype(),
+            pyarrow.bool_(): pandas.BooleanDtype(),
+            pyarrow.float32(): pandas.Float32Dtype(),
+            pyarrow.float64(): pandas.Float64Dtype(),
+            pyarrow.string(): pandas.StringDtype(),
+        }
+
+        # Need to rename columns, as the to_pandas function cannot handle duplicate column names
+        table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)])
+        df = table_renamed.to_pandas(
+            types_mapper=dtype_mapping.get,
+            date_as_object=True,
+            timestamp_as_object=True,
+        )
+
+        res = df.to_numpy(na_value=None, dtype="object")
+        return [ResultRow(*v) for v in res]
+
     @property
     def rownumber(self):
         return self._next_row_index
@@ -97,12 +141,6 @@ def is_staging_operation(self) -> bool:
         """Whether this result set represents a staging operation."""
         return self._is_staging_operation
 
-    # Define abstract methods that concrete implementations must implement
-    @abstractmethod
-    def _fill_results_buffer(self):
-        """Fill the results buffer from the backend."""
-        pass
-
     @abstractmethod
     def fetchone(self) -> Optional[Row]:
         """Fetch the next row of a query result set."""
@@ -189,10 +227,10 @@ def __init__(
         # Build the results queue if t_row_set is provided
         results_queue = None
         if t_row_set and execute_response.result_format is not None:
-            from databricks.sql.utils import ResultSetQueueFactory
+            from databricks.sql.utils import ThriftResultSetQueueFactory
 
             # Create the results queue using the provided format
-            results_queue = ResultSetQueueFactory.build_queue(
+            results_queue = ThriftResultSetQueueFactory.build_queue(
                 row_set_type=execute_response.result_format,
                 t_row_set=t_row_set,
                 arrow_schema_bytes=execute_response.arrow_schema_bytes or b"",
@@ -249,44 +287,6 @@ def _convert_columnar_table(self, table):
 
         return result
 
-    def _convert_arrow_table(self, table):
-        column_names = [c[0] for c in self.description]
-        ResultRow = Row(*column_names)
-
-        if self.connection.disable_pandas is True:
-            return [
-                ResultRow(*[v.as_py() for v in r]) for r in zip(*table.itercolumns())
-            ]
-
-        # Need to use nullable types, as otherwise type can change when there are missing values.
-        # See https://arrow.apache.org/docs/python/pandas.html#nullable-types
-        # NOTE: This api is epxerimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html
-        dtype_mapping = {
-            pyarrow.int8(): pandas.Int8Dtype(),
-            pyarrow.int16(): pandas.Int16Dtype(),
-            pyarrow.int32(): pandas.Int32Dtype(),
-            pyarrow.int64(): pandas.Int64Dtype(),
-            pyarrow.uint8(): pandas.UInt8Dtype(),
-            pyarrow.uint16(): pandas.UInt16Dtype(),
-            pyarrow.uint32(): pandas.UInt32Dtype(),
-            pyarrow.uint64(): pandas.UInt64Dtype(),
-            pyarrow.bool_(): pandas.BooleanDtype(),
-            pyarrow.float32(): pandas.Float32Dtype(),
-            pyarrow.float64(): pandas.Float64Dtype(),
-            pyarrow.string(): pandas.StringDtype(),
-        }
-
-        # Need to rename columns, as the to_pandas function cannot handle duplicate column names
-        table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)])
-        df = table_renamed.to_pandas(
-            types_mapper=dtype_mapping.get,
-            date_as_object=True,
-            timestamp_as_object=True,
-        )
-
-        res = df.to_numpy(na_value=None, dtype="object")
-        return [ResultRow(*v) for v in res]
-
     def merge_columnar(self, result1, result2) -> "ColumnTable":
         """
         Function to merge / combining the columnar results into a single result
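Why the nullable `dtype_mapping` above matters: without a `types_mapper`, pyarrow's default `to_pandas()` conversion turns an integer column that contains a NULL into `float64`, silently changing values like `2` into `2.0`. A minimal standalone sketch, using plain pyarrow and pandas rather than anything from this patch:

```python
import pyarrow as pa
import pandas as pd

# An int64 column with one missing value.
table = pa.table({"n": pa.array([1, None, 2], type=pa.int64())})

# Default conversion: the NULL forces float64 (1.0, NaN, 2.0).
print(table.to_pandas()["n"].dtype)  # float64

# With a nullable-dtype mapper (as in _convert_arrow_table above),
# integers stay integers and the NULL becomes pandas.NA.
mapper = {pa.int64(): pd.Int64Dtype()}.get
print(table.to_pandas(types_mapper=mapper)["n"].dtype)  # Int64
```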
@@ -456,8 +456,8 @@ def __init__(
         sea_client: "SeaDatabricksClient",
         buffer_size_bytes: int = 104857600,
         arraysize: int = 10000,
-        result_data=None,
-        manifest=None,
+        result_data: Optional["ResultData"] = None,
+        manifest: Optional["ResultManifest"] = None,
     ):
         """
         Initialize a SeaResultSet with the response from a SEA query execution.
@@ -472,6 +472,19 @@
             manifest: Manifest from SEA response (optional)
         """
 
+        results_queue = None
+        if result_data:
+            results_queue = SeaResultSetQueueFactory.build_queue(
+                result_data,
+                manifest,
+                str(execute_response.command_id.to_sea_statement_id()),
+                description=execute_response.description,
+                max_download_threads=sea_client.max_download_threads,
+                sea_client=sea_client,
+                lz4_compressed=execute_response.lz4_compressed,
+            )
+
+        # Call parent constructor with common attributes
         super().__init__(
             connection=connection,
             backend=sea_client,
@@ -483,43 +496,132 @@
             description=execute_response.description,
             is_staging_operation=execute_response.is_staging_operation,
             lz4_compressed=execute_response.lz4_compressed,
-            arrow_schema_bytes=execute_response.arrow_schema_bytes,
+            arrow_schema_bytes=execute_response.arrow_schema_bytes or b"",
         )
 
-    def _fill_results_buffer(self):
-        """Fill the results buffer from the backend."""
-        raise NotImplementedError(
-            "_fill_results_buffer is not implemented for SEA backend"
-        )
+        # Initialize queue for result data if not provided
+        self.results = results_queue or JsonQueue([])
+
+    def _convert_json_table(self, rows):
+        """
+        Convert raw data rows to Row objects with named columns based on description.
+
+        Args:
+            rows: List of raw data rows
+
+        Returns:
+            List of Row objects with named columns
+        """
+        if not self.description or not rows:
+            return rows
+
+        column_names = [col[0] for col in self.description]
+        ResultRow = Row(*column_names)
+        return [ResultRow(*row) for row in rows]
+
+    def fetchmany_json(self, size: int):
+        """
+        Fetch the next set of rows as raw JSON data.
+
+        Args:
+            size: Number of rows to fetch
+
+        Returns:
+            List of raw data rows containing the fetched rows
+
+        Raises:
+            ValueError: If size is negative
+        """
+        if size < 0:
+            raise ValueError(f"size argument for fetchmany is {size} but must be >= 0")
+
+        results = self.results.next_n_rows(size)
+        self._next_row_index += len(results)
+
+        return results
+
+    def fetchall_json(self):
+        """
+        Fetch all remaining rows as raw JSON data.
+
+        Returns:
+            List of raw data rows containing all remaining rows
+        """
+        results = self.results.remaining_rows()
+        self._next_row_index += len(results)
+
+        return results
+
+    def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
+        """
+        Fetch the next set of rows as an Arrow table.
+
+        Args:
+            size: Number of rows to fetch
+
+        Returns:
+            PyArrow Table containing the fetched rows
+
+        Raises:
+            ImportError: If PyArrow is not installed
+            ValueError: If size is negative
+        """
+        if size < 0:
+            raise ValueError(f"size argument for fetchmany is {size} but must be >= 0")
+
+        results = self.results.next_n_rows(size)
+        self._next_row_index += results.num_rows
+
+        return results
+
+    def fetchall_arrow(self) -> "pyarrow.Table":
+        """
+        Fetch all remaining rows as an Arrow table.
+        """
+        results = self.results.remaining_rows()
+        self._next_row_index += results.num_rows
+
+        return results
 
     def fetchone(self) -> Optional[Row]:
         """
         Fetch the next row of a query result set, returning a single sequence,
         or None when no more data is available.
+
+        Returns:
+            A single Row object or None if no more rows are available
         """
+        if isinstance(self.results, JsonQueue):
+            res = self._convert_json_table(self.fetchmany_json(1))
+        else:
+            raise NotImplementedError("fetchone only supported for JSON data")
 
-        raise NotImplementedError("fetchone is not implemented for SEA backend")
+        return res[0] if res else None
 
-    def fetchmany(self, size: Optional[int] = None) -> List[Row]:
+    def fetchmany(self, size: int) -> List[Row]:
         """
         Fetch the next set of rows of a query result, returning a list of rows.
 
-        An empty sequence is returned when no more rows are available.
-        """
+        Args:
+            size: Number of rows to fetch
 
-        raise NotImplementedError("fetchmany is not implemented for SEA backend")
+        Returns:
+            List of Row objects
 
-    def fetchall(self) -> List[Row]:
+        Raises:
+            ValueError: If size is negative
         """
-        Fetch all (remaining) rows of a query result, returning them as a list of rows.
-        """
-
-        raise NotImplementedError("fetchall is not implemented for SEA backend")
+        if isinstance(self.results, JsonQueue):
+            return self._convert_json_table(self.fetchmany_json(size))
+        else:
+            raise NotImplementedError("fetchmany only supported for JSON data")
 
-    def fetchmany_arrow(self, size: int) -> Any:
-        """Fetch the next set of rows as an Arrow table."""
-        raise NotImplementedError("fetchmany_arrow is not implemented for SEA backend")
+    def fetchall(self) -> List[Row]:
+        """
+        Fetch all remaining rows of a query result, returning them as a list of rows.
-    def fetchall_arrow(self) -> Any:
-        """Fetch all remaining rows as an Arrow table."""
-        raise NotImplementedError("fetchall_arrow is not implemented for SEA backend")
+
+        Returns:
+            List of Row objects containing all remaining rows
+        """
+        if isinstance(self.results, JsonQueue):
+            return self._convert_json_table(self.fetchall_json())
+        else:
+            raise NotImplementedError("fetchall only supported for JSON data")
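The JSON fetch path added above, in one self-contained picture: SEA's JSON_ARRAY format delivers rows as plain lists of strings, and `_convert_json_table` names them using the cursor description. The two-column description below is a made-up example; `Row` is used exactly the way the methods above use it:

```python
from databricks.sql.types import Row

# Hypothetical cursor description: (name, type_code, ..., null_ok) tuples.
description = [
    ("id", "int", None, None, None, None, None),
    ("test_value", "string", None, None, None, None, None),
]
raw_rows = [["1", "value_a"], ["2", "value_b"]]  # JSON_ARRAY rows

# The same conversion _convert_json_table performs.
column_names = [col[0] for col in description]
ResultRow = Row(*column_names)
rows = [ResultRow(*r) for r in raw_rows]
print(rows[0].test_value)  # "value_a"
```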
diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py
index 8997bda22..933032044 100644
--- a/src/databricks/sql/utils.py
+++ b/src/databricks/sql/utils.py
@@ -13,6 +13,9 @@
 
 import lz4.frame
 
+from databricks.sql.backend.sea.backend import SeaDatabricksClient
+from databricks.sql.backend.sea.models.base import ResultData, ResultManifest
+
 try:
     import pyarrow
 except ImportError:
@@ -48,7 +51,7 @@ def remaining_rows(self):
         pass
 
 
-class ResultSetQueueFactory(ABC):
+class ThriftResultSetQueueFactory(ABC):
     @staticmethod
     def build_queue(
         row_set_type: TSparkRowSetType,
@@ -107,6 +110,71 @@ def build_queue(
         raise AssertionError("Row set type is not valid")
 
 
+class SeaResultSetQueueFactory(ABC):
+    @staticmethod
+    def build_queue(
+        sea_result_data: ResultData,
+        manifest: Optional[ResultManifest],
+        statement_id: str,
+        description: Optional[List[Tuple[Any, ...]]] = None,
+        schema_bytes: Optional[bytes] = None,
+        max_download_threads: Optional[int] = None,
+        sea_client: Optional["SeaDatabricksClient"] = None,
+        lz4_compressed: bool = False,
+    ) -> ResultSetQueue:
+        """
+        Factory method to build a result set queue for SEA backend.
+
+        Args:
+            sea_result_data (ResultData): Result data from SEA response
+            manifest (ResultManifest): Manifest from SEA response
+            statement_id (str): Statement ID for the query
+            description (List[Tuple[Any, ...]]): Column descriptions
+            schema_bytes (bytes): Arrow schema bytes
+            max_download_threads (int): Maximum number of download threads
+            sea_client (SeaDatabricksClient): SEA client for fetching additional links
+            lz4_compressed (bool): Whether the data is LZ4 compressed
+
+        Returns:
+            ResultSetQueue: The appropriate queue for the result data
+        """
+
+        if sea_result_data.data is not None:
+            # INLINE disposition with JSON_ARRAY format
+            return JsonQueue(sea_result_data.data)
+        elif sea_result_data.external_links is not None:
+            # EXTERNAL_LINKS disposition
+            raise NotImplementedError(
+                "EXTERNAL_LINKS disposition is not implemented for SEA backend"
+            )
+        else:
+            # Empty result set
+            return JsonQueue([])
+
+
+class JsonQueue(ResultSetQueue):
+    """Queue implementation for JSON_ARRAY format data."""
+
+    def __init__(self, data_array):
+        """Initialize with JSON array data."""
+        self.data_array = data_array
+        self.cur_row_index = 0
+        self.n_valid_rows = len(data_array)
+
+    def next_n_rows(self, num_rows):
+        """Get the next n rows from the data array."""
+        length = min(num_rows, self.n_valid_rows - self.cur_row_index)
+        slice = self.data_array[self.cur_row_index : self.cur_row_index + length]
+        self.cur_row_index += length
+        return slice
+
+    def remaining_rows(self):
+        """Get all remaining rows from the data array."""
+        slice = self.data_array[self.cur_row_index :]
+        self.cur_row_index += len(slice)
+        return slice
+
+
 class ColumnTable:
     def __init__(self, column_table, column_names):
         self.column_table = column_table
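The `JsonQueue` contract above, checked in isolation. This is a restatement of its slicing logic for illustration, not the class from `utils.py`; the point is that reads past the end clamp instead of raising:

```python
data = [[i] for i in range(5)]  # five JSON_ARRAY rows
cur = 0

def next_n_rows(n):
    global cur
    length = min(n, len(data) - cur)   # clamp at the end of the data
    chunk = data[cur : cur + length]
    cur += length
    return chunk

assert next_n_rows(2) == [[0], [1]]
assert len(next_n_rows(10)) == 3       # only three rows were left
assert next_n_rows(1) == []            # drained: empty list, no error
```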
diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py
index c596dbc14..f0049e3aa 100644
--- a/tests/unit/test_sea_result_set.py
+++ b/tests/unit/test_sea_result_set.py
@@ -122,80 +122,3 @@ def test_close_when_connection_closed(
         mock_sea_client.close_command.assert_not_called()
         assert result_set.has_been_closed_server_side is True
         assert result_set.status == CommandState.CLOSED
-
-    def test_unimplemented_methods(
-        self, mock_connection, mock_sea_client, execute_response
-    ):
-        """Test that unimplemented methods raise NotImplementedError."""
-        result_set = SeaResultSet(
-            connection=mock_connection,
-            execute_response=execute_response,
-            sea_client=mock_sea_client,
-            buffer_size_bytes=1000,
-            arraysize=100,
-        )
-
-        # Test each unimplemented method individually with specific error messages
-        with pytest.raises(
-            NotImplementedError, match="fetchone is not implemented for SEA backend"
-        ):
-            result_set.fetchone()
-
-        with pytest.raises(
-            NotImplementedError, match="fetchmany is not implemented for SEA backend"
-        ):
-            result_set.fetchmany(10)
-
-        with pytest.raises(
-            NotImplementedError, match="fetchmany is not implemented for SEA backend"
-        ):
-            # Test with default parameter value
-            result_set.fetchmany()
-
-        with pytest.raises(
-            NotImplementedError, match="fetchall is not implemented for SEA backend"
-        ):
-            result_set.fetchall()
-
-        with pytest.raises(
-            NotImplementedError,
-            match="fetchmany_arrow is not implemented for SEA backend",
-        ):
-            result_set.fetchmany_arrow(10)
-
-        with pytest.raises(
-            NotImplementedError,
-            match="fetchall_arrow is not implemented for SEA backend",
-        ):
-            result_set.fetchall_arrow()
-
-        with pytest.raises(
-            NotImplementedError, match="fetchone is not implemented for SEA backend"
-        ):
-            # Test iteration protocol (calls fetchone internally)
-            next(iter(result_set))
-
-        with pytest.raises(
-            NotImplementedError, match="fetchone is not implemented for SEA backend"
-        ):
-            # Test using the result set in a for loop
-            for row in result_set:
-                pass
-
-    def test_fill_results_buffer_not_implemented(
-        self, mock_connection, mock_sea_client, execute_response
-    ):
-        """Test that _fill_results_buffer raises NotImplementedError."""
-        result_set = SeaResultSet(
-            connection=mock_connection,
-            execute_response=execute_response,
-            sea_client=mock_sea_client,
-            buffer_size_bytes=1000,
-            arraysize=100,
-        )
-
-        with pytest.raises(
-            NotImplementedError,
-            match="_fill_results_buffer is not implemented for SEA backend",
-        ):
-            result_set._fill_results_buffer()
diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py
index 57b5e9b58..4a4295e11 100644
--- a/tests/unit/test_thrift_backend.py
+++ b/tests/unit/test_thrift_backend.py
@@ -610,7 +610,8 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self):
         self.assertIn("some information about the error", str(cm.exception))
 
     @patch(
-        "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()
+        "databricks.sql.utils.ThriftResultSetQueueFactory.build_queue",
+        return_value=Mock(),
     )
     def test_handle_execute_response_sets_compression_in_direct_results(
         self, build_queue
@@ -998,7 +999,8 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class):
         )
 
     @patch(
-        "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()
+        "databricks.sql.utils.ThriftResultSetQueueFactory.build_queue",
+        return_value=Mock(),
     )
     @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
     def test_handle_execute_response_reads_has_more_rows_in_direct_results(
@@ -1043,7 +1045,8 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results(
         self.assertEqual(is_direct_results, has_more_rows_result)
 
     @patch(
-        "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()
+        "databricks.sql.utils.ThriftResultSetQueueFactory.build_queue",
+        return_value=Mock(),
     )
    @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
    def test_handle_execute_response_reads_has_more_rows_in_result_response(
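For reference, the polling pattern the asynchronous tests exercise, reduced to its core. It assumes the same environment variables the tests require and uses only cursor methods that appear in the tests above (`execute_async`, `is_query_pending`, `get_async_execution_result`):

```python
import os
import time

from databricks.sql.client import Connection

connection = Connection(
    server_hostname=os.environ["DATABRICKS_SERVER_HOSTNAME"],
    http_path=os.environ["DATABRICKS_HTTP_PATH"],
    access_token=os.environ["DATABRICKS_TOKEN"],
    use_sea=True,
)
cursor = connection.cursor()
cursor.execute_async("SELECT 1")     # returns without waiting for the query

while cursor.is_query_pending():     # query still queued or running
    time.sleep(1)

cursor.get_async_execution_result()  # attach the finished result set
print(cursor.fetchall())

cursor.close()
connection.close()
```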
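And the back-of-envelope behind the cloud-fetch test queries: `repeat('a', 10000)` makes every row roughly 10 KB, so the requested row counts push the results far past anything a server would return inline (the exact inline threshold is not part of this diff and is an assumption here):

```python
row_bytes = 10_000    # repeat('a', 10000) per row
async_rows = 5_000    # requested_row_count in test_sea_async_query
sync_rows = 10_000    # requested_row_count in test_sea_sync_query

print(f"async test: ~{row_bytes * async_rows / 1e6:.0f} MB")  # ~50 MB
print(f"sync test:  ~{row_bytes * sync_rows / 1e6:.0f} MB")   # ~100 MB
```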