diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py
new file mode 100644
index 00000000..b9778788
--- /dev/null
+++ b/src/databricks/sql/backend/filters.py
@@ -0,0 +1,165 @@
+"""
+Client-side filtering utilities for Databricks SQL connector.
+
+This module provides filtering capabilities for result sets returned by different backends.
+"""
+
+import logging
+from typing import (
+    List,
+    Optional,
+    Any,
+    Callable,
+    cast,
+    TYPE_CHECKING,
+)
+
+from databricks.sql.backend.sea.backend import SeaDatabricksClient
+from databricks.sql.backend.types import ExecuteResponse
+
+if TYPE_CHECKING:
+    from databricks.sql.result_set import ResultSet, SeaResultSet
+
+logger = logging.getLogger(__name__)
+
+
+class ResultSetFilter:
+    """
+    A general-purpose filter for result sets that can be applied to any backend.
+
+    This class provides methods to filter result sets based on various criteria,
+    similar to the client-side filtering in the JDBC connector.
+    """
+
+    @staticmethod
+    def _filter_sea_result_set(
+        result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool]
+    ) -> "SeaResultSet":
+        """
+        Filter a SEA result set using the provided filter function.
+
+        Args:
+            result_set: The SEA result set to filter
+            filter_func: Function that takes a row and returns True if the row should be included
+
+        Returns:
+            A filtered SEA result set
+        """
+
+        # Get all remaining rows
+        all_rows = result_set.results.remaining_rows()
+
+        # Filter rows
+        filtered_rows = [row for row in all_rows if filter_func(row)]
+
+        # Import SeaResultSet here to avoid circular imports
+        from databricks.sql.result_set import SeaResultSet
+
+        # Reuse the command_id from the original result set
+        command_id = result_set.command_id
+
+        # Create an ExecuteResponse with the filtered data
+        execute_response = ExecuteResponse(
+            command_id=command_id,
+            status=result_set.status,
+            description=result_set.description,
+            has_been_closed_server_side=result_set.has_been_closed_server_side,
+            lz4_compressed=result_set.lz4_compressed,
+            arrow_schema_bytes=result_set._arrow_schema_bytes,
+            is_staging_operation=False,
+        )
+
+        # Create a new ResultData object with filtered data
+        from databricks.sql.backend.sea.models.base import ResultData
+
+        result_data = ResultData(data=filtered_rows, external_links=None)
+
+        # Create a new SeaResultSet with the filtered data
+        filtered_result_set = SeaResultSet(
+            connection=result_set.connection,
+            execute_response=execute_response,
+            sea_client=cast(SeaDatabricksClient, result_set.backend),
+            buffer_size_bytes=result_set.buffer_size_bytes,
+            arraysize=result_set.arraysize,
+            result_data=result_data,
+        )
+
+        return filtered_result_set
+
+    @staticmethod
+    def filter_by_column_values(
+        result_set: "ResultSet",
+        column_index: int,
+        allowed_values: List[str],
+        case_sensitive: bool = False,
+    ) -> "ResultSet":
+        """
+        Filter a result set by values in a specific column.
+
+        Args:
+            result_set: The result set to filter
+            column_index: The index of the column to filter on
+            allowed_values: List of allowed values for the column
+            case_sensitive: Whether to perform case-sensitive comparison
+
+        Returns:
+            A filtered result set
+        """
+
+        # Convert to uppercase for case-insensitive comparison if needed
+        if not case_sensitive:
+            allowed_values = [v.upper() for v in allowed_values]
+
+        # Determine the type of result set and apply appropriate filtering
+        from databricks.sql.result_set import SeaResultSet
+
+        if isinstance(result_set, SeaResultSet):
+            return ResultSetFilter._filter_sea_result_set(
+                result_set,
+                lambda row: (
+                    len(row) > column_index
+                    and isinstance(row[column_index], str)
+                    and (
+                        row[column_index].upper()
+                        if not case_sensitive
+                        else row[column_index]
+                    )
+                    in allowed_values
+                ),
+            )
+
+        # For other result set types, return the original (should be handled by specific implementations)
+        logger.warning(
+            f"Filtering not implemented for result set type: {type(result_set).__name__}"
+        )
+        return result_set
+
+    @staticmethod
+    def filter_tables_by_type(
+        result_set: "ResultSet", table_types: Optional[List[str]] = None
+    ) -> "ResultSet":
+        """
+        Filter a result set of tables by the specified table types.
+
+        This is a client-side filter that processes the result set after it has been
+        retrieved from the server. It filters out tables whose type does not match
+        any of the types in the table_types list.
+
+        Args:
+            result_set: The original result set containing tables
+            table_types: List of table types to include (e.g., ["TABLE", "VIEW"])
+
+        Returns:
+            A filtered result set containing only tables of the specified types
+        """
+
+        # Default table types if none specified
+        DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"]
+        valid_types = (
+            table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES
+        )
+
+        # Table type is the 6th column (index 5)
+        return ResultSetFilter.filter_by_column_values(
+            result_set, 5, valid_types, case_sensitive=True
+        )
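Reviewer note: the heart of `filter_by_column_values` is the row predicate it hands to `_filter_sea_result_set`. A minimal standalone sketch of that comparison rule (plain Python for illustration; `keep` is a hypothetical helper, not part of this module):

```python
# Mirrors the lambda above: out-of-range or non-string cells never match, and
# values are uppercased on both sides unless case_sensitive=True is requested.
rows = [
    ["catalog1", "schema1", "table1", "owner1", "2023-01-01", "TABLE", ""],
    ["catalog1", "schema1", "table2", "owner1", "2023-01-01", "view", ""],
]
allowed_values = ["TABLE", "VIEW"]  # pre-uppercased, as in the case-insensitive path
column_index = 5  # table-type column

def keep(row, case_sensitive=False):
    return (
        len(row) > column_index
        and isinstance(row[column_index], str)
        and (row[column_index].upper() if not case_sensitive else row[column_index])
        in allowed_values
    )

print([row[2] for row in rows if keep(row)])  # ['table1', 'table2']
```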
diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py
index 76903ccd..9d301d3b 100644
--- a/src/databricks/sql/backend/sea/backend.py
+++ b/src/databricks/sql/backend/sea/backend.py
@@ -40,6 +40,11 @@
     GetStatementResponse,
     CreateSessionResponse,
 )
+from databricks.sql.backend.sea.models.responses import (
+    _parse_status,
+    _parse_manifest,
+    _parse_result,
+)

 logger = logging.getLogger(__name__)

@@ -627,9 +632,22 @@ def get_catalogs(
         max_rows: int,
         max_bytes: int,
         cursor: "Cursor",
-    ):
-        """Not implemented yet."""
-        raise NotImplementedError("get_catalogs is not yet implemented for SEA backend")
+    ) -> "ResultSet":
+        """Get available catalogs by executing 'SHOW CATALOGS'."""
+        result = self.execute_command(
+            operation="SHOW CATALOGS",
+            session_id=session_id,
+            max_rows=max_rows,
+            max_bytes=max_bytes,
+            lz4_compression=False,
+            cursor=cursor,
+            use_cloud_fetch=False,
+            parameters=[],
+            async_op=False,
+            enforce_embedded_schema_correctness=False,
+        )
+        assert result is not None, "execute_command returned None in synchronous mode"
+        return result

     def get_schemas(
         self,
         session_id: SessionId,
         max_rows: int,
         max_bytes: int,
         cursor: "Cursor",
         catalog_name: Optional[str] = None,
         schema_name: Optional[str] = None,
-    ):
-        """Not implemented yet."""
-        raise NotImplementedError("get_schemas is not yet implemented for SEA backend")
+    ) -> "ResultSet":
+        """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'."""
+        if not catalog_name:
+            raise ValueError("Catalog name is required for get_schemas")
+
+        operation = f"SHOW SCHEMAS IN {catalog_name}"
+
+        if schema_name:
+            operation += f" LIKE '{schema_name}'"
+
+        result = self.execute_command(
+            operation=operation,
+            session_id=session_id,
+            max_rows=max_rows,
+            max_bytes=max_bytes,
+            lz4_compression=False,
+            cursor=cursor,
+            use_cloud_fetch=False,
+            parameters=[],
+            async_op=False,
+            enforce_embedded_schema_correctness=False,
+        )
+        assert result is not None, "execute_command returned None in synchronous mode"
+        return result

     def get_tables(
         self,
         session_id: SessionId,
         max_rows: int,
         max_bytes: int,
         cursor: "Cursor",
         catalog_name: Optional[str] = None,
         schema_name: Optional[str] = None,
         table_name: Optional[str] = None,
         table_types: Optional[List[str]] = None,
-    ):
-        """Not implemented yet."""
-        raise NotImplementedError("get_tables is not yet implemented for SEA backend")
+    ) -> "ResultSet":
+        """Get tables by executing 'SHOW TABLES IN [ALL CATALOGS | CATALOG catalog] [SCHEMA LIKE pattern] [LIKE pattern]'."""
+        if not catalog_name:
+            raise ValueError("Catalog name is required for get_tables")
+
+        operation = "SHOW TABLES IN " + (
+            "ALL CATALOGS"
+            if catalog_name in ["*", "%"]
+            else f"CATALOG {catalog_name}"
+        )
+
+        if schema_name:
+            operation += f" SCHEMA LIKE '{schema_name}'"
+
+        if table_name:
+            operation += f" LIKE '{table_name}'"
+
+        result = self.execute_command(
+            operation=operation,
+            session_id=session_id,
+            max_rows=max_rows,
+            max_bytes=max_bytes,
+            lz4_compression=False,
+            cursor=cursor,
+            use_cloud_fetch=False,
+            parameters=[],
+            async_op=False,
+            enforce_embedded_schema_correctness=False,
+        )
+        assert result is not None, "execute_command returned None in synchronous mode"
+
+        # Apply client-side filtering by table_types
+        from databricks.sql.backend.filters import ResultSetFilter
+
+        result = ResultSetFilter.filter_tables_by_type(result, table_types)
+
+        return result

     def get_columns(
         self,
         session_id: SessionId,
         max_rows: int,
         max_bytes: int,
         cursor: "Cursor",
         catalog_name: Optional[str] = None,
         schema_name: Optional[str] = None,
         table_name: Optional[str] = None,
         column_name: Optional[str] = None,
-    ):
-        """Not implemented yet."""
-        raise NotImplementedError("get_columns is not yet implemented for SEA backend")
+    ) -> "ResultSet":
+        """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'."""
+        if not catalog_name:
+            raise ValueError("Catalog name is required for get_columns")
+
+        operation = f"SHOW COLUMNS IN CATALOG {catalog_name}"
+
+        if schema_name:
+            operation += f" SCHEMA LIKE '{schema_name}'"
+
+        if table_name:
+            operation += f" TABLE LIKE '{table_name}'"
+
+        if column_name:
+            operation += f" LIKE '{column_name}'"
+
+        result = self.execute_command(
+            operation=operation,
+            session_id=session_id,
+            max_rows=max_rows,
+            max_bytes=max_bytes,
+            lz4_compression=False,
+            cursor=cursor,
+            use_cloud_fetch=False,
+            parameters=[],
+            async_op=False,
+            enforce_embedded_schema_correctness=False,
+        )
+        assert result is not None, "execute_command returned None in synchronous mode"
+        return result
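Reviewer note: to make the generated SQL concrete, here is how the `get_tables` operation string composes, extracted into a standalone sketch (illustrative only; `build_show_tables` is a hypothetical name, not part of the backend):

```python
from typing import Optional

def build_show_tables(
    catalog_name: str,
    schema_name: Optional[str] = None,
    table_name: Optional[str] = None,
) -> str:
    # Wildcard catalogs map to ALL CATALOGS; anything else is named explicitly.
    operation = "SHOW TABLES IN " + (
        "ALL CATALOGS" if catalog_name in ["*", "%"] else f"CATALOG {catalog_name}"
    )
    if schema_name:
        operation += f" SCHEMA LIKE '{schema_name}'"
    if table_name:
        operation += f" LIKE '{table_name}'"
    return operation

assert build_show_tables("*") == "SHOW TABLES IN ALL CATALOGS"
assert (
    build_show_tables("test_catalog", "test_schema", "test_table")
    == "SHOW TABLES IN CATALOG test_catalog SCHEMA LIKE 'test_schema' LIKE 'test_table'"
)
```

Note that the schema/table patterns are interpolated directly into the quoted SQL text rather than bound as parameters, so callers should treat them as literal LIKE patterns, not untrusted input.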
+""" + +import unittest +from unittest.mock import MagicMock, patch + +from databricks.sql.backend.filters import ResultSetFilter + + +class TestResultSetFilter(unittest.TestCase): + """Tests for the ResultSetFilter class.""" + + def setUp(self): + """Set up test fixtures.""" + # Create a mock SeaResultSet + self.mock_sea_result_set = MagicMock() + + # Set up the remaining_rows method on the results attribute + self.mock_sea_result_set.results = MagicMock() + self.mock_sea_result_set.results.remaining_rows.return_value = [ + ["catalog1", "schema1", "table1", "owner1", "2023-01-01", "TABLE", ""], + ["catalog1", "schema1", "table2", "owner1", "2023-01-01", "VIEW", ""], + [ + "catalog1", + "schema1", + "table3", + "owner1", + "2023-01-01", + "SYSTEM TABLE", + "", + ], + [ + "catalog1", + "schema1", + "table4", + "owner1", + "2023-01-01", + "EXTERNAL TABLE", + "", + ], + ] + + # Set up the connection and other required attributes + self.mock_sea_result_set.connection = MagicMock() + self.mock_sea_result_set.backend = MagicMock() + self.mock_sea_result_set.buffer_size_bytes = 1000 + self.mock_sea_result_set.arraysize = 100 + self.mock_sea_result_set.statement_id = "test-statement-id" + self.mock_sea_result_set.lz4_compressed = False + + # Create a mock CommandId + from databricks.sql.backend.types import CommandId, BackendType + + mock_command_id = CommandId(BackendType.SEA, "test-statement-id") + self.mock_sea_result_set.command_id = mock_command_id + + self.mock_sea_result_set.status = MagicMock() + self.mock_sea_result_set.description = [ + ("catalog_name", "string", None, None, None, None, True), + ("schema_name", "string", None, None, None, None, True), + ("table_name", "string", None, None, None, None, True), + ("owner", "string", None, None, None, None, True), + ("creation_time", "string", None, None, None, None, True), + ("table_type", "string", None, None, None, None, True), + ("remarks", "string", None, None, None, None, True), + ] + self.mock_sea_result_set.has_been_closed_server_side = False + self.mock_sea_result_set._arrow_schema_bytes = None + + def test_filter_by_column_values(self): + """Test filtering by column values with various options.""" + # Case 1: Case-sensitive filtering + allowed_values = ["table1", "table3"] + + with patch("databricks.sql.backend.filters.isinstance", return_value=True): + with patch( + "databricks.sql.result_set.SeaResultSet" + ) as mock_sea_result_set_class: + mock_instance = MagicMock() + mock_sea_result_set_class.return_value = mock_instance + + # Call filter_by_column_values on the table_name column (index 2) + result = ResultSetFilter.filter_by_column_values( + self.mock_sea_result_set, 2, allowed_values, case_sensitive=True + ) + + # Verify the filter was applied correctly + mock_sea_result_set_class.assert_called_once() + + # Check the filtered data passed to the constructor + args, kwargs = mock_sea_result_set_class.call_args + result_data = kwargs.get("result_data") + self.assertIsNotNone(result_data) + self.assertEqual(len(result_data.data), 2) + self.assertIn(result_data.data[0][2], allowed_values) + self.assertIn(result_data.data[1][2], allowed_values) + + # Case 2: Case-insensitive filtering + mock_sea_result_set_class.reset_mock() + with patch("databricks.sql.backend.filters.isinstance", return_value=True): + with patch( + "databricks.sql.result_set.SeaResultSet" + ) as mock_sea_result_set_class: + mock_instance = MagicMock() + mock_sea_result_set_class.return_value = mock_instance + + # Call filter_by_column_values with 
diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py
index f30c92ed..d75359f2 100644
--- a/tests/unit/test_sea_backend.py
+++ b/tests/unit/test_sea_backend.py
@@ -632,54 +632,254 @@ def test_utility_methods(self, sea_client):
             sea_client._extract_description_from_manifest(no_columns_manifest) is None
         )

-    def test_unimplemented_metadata_methods(
-        self, sea_client, sea_session_id, mock_cursor
-    ):
-        """Test that metadata methods raise NotImplementedError."""
-        # Test get_catalogs
-        with pytest.raises(NotImplementedError):
-            sea_client.get_catalogs(sea_session_id, 100, 1000, mock_cursor)
-
-        # Test get_schemas
-        with pytest.raises(NotImplementedError):
-            sea_client.get_schemas(sea_session_id, 100, 1000, mock_cursor)
-
-        # Test get_schemas with optional parameters
-        with pytest.raises(NotImplementedError):
-            sea_client.get_schemas(
-                sea_session_id, 100, 1000, mock_cursor, "catalog", "schema"
+    def test_get_catalogs(self, sea_client, sea_session_id, mock_cursor):
+        """Test the get_catalogs method."""
+        # Mock the execute_command method
+        mock_result_set = Mock()
+        with patch.object(
+            sea_client, "execute_command", return_value=mock_result_set
+        ) as mock_execute:
+            # Call get_catalogs
+            result = sea_client.get_catalogs(
+                session_id=sea_session_id,
+                max_rows=100,
+                max_bytes=1000,
+                cursor=mock_cursor,
+            )
+
+            # Verify execute_command was called with the correct parameters
+            mock_execute.assert_called_once_with(
+                operation="SHOW CATALOGS",
+                session_id=sea_session_id,
+                max_rows=100,
+                max_bytes=1000,
+                lz4_compression=False,
+                cursor=mock_cursor,
+                use_cloud_fetch=False,
+                parameters=[],
+                async_op=False,
+                enforce_embedded_schema_correctness=False,
+            )
+
+            # Verify the result is correct
+            assert result == mock_result_set
+
+    def test_get_schemas(self, sea_client, sea_session_id, mock_cursor):
+        """Test the get_schemas method with various parameter combinations."""
+        # Mock the execute_command method
+        mock_result_set = Mock()
+        with patch.object(
+            sea_client, "execute_command", return_value=mock_result_set
+        ) as mock_execute:
+            # Case 1: With catalog name only
+            result = sea_client.get_schemas(
+                session_id=sea_session_id,
+                max_rows=100,
+                max_bytes=1000,
+                cursor=mock_cursor,
+                catalog_name="test_catalog",
+            )
+
+            mock_execute.assert_called_with(
+                operation="SHOW SCHEMAS IN test_catalog",
+                session_id=sea_session_id,
+                max_rows=100,
+                max_bytes=1000,
+                lz4_compression=False,
+                cursor=mock_cursor,
+                use_cloud_fetch=False,
+                parameters=[],
+                async_op=False,
+                enforce_embedded_schema_correctness=False,
+            )
+
+            # Case 2: With catalog and schema names
+            result = sea_client.get_schemas(
+                session_id=sea_session_id,
+                max_rows=100,
+                max_bytes=1000,
+                cursor=mock_cursor,
+                catalog_name="test_catalog",
+                schema_name="test_schema",
+            )
+
+            mock_execute.assert_called_with(
+                operation="SHOW SCHEMAS IN test_catalog LIKE 'test_schema'",
+                session_id=sea_session_id,
+                max_rows=100,
+                max_bytes=1000,
+                lz4_compression=False,
+                cursor=mock_cursor,
+                use_cloud_fetch=False,
+                parameters=[],
+                async_op=False,
+                enforce_embedded_schema_correctness=False,
+            )
+
+            # Case 3: Without catalog name (should raise ValueError)
+            with pytest.raises(ValueError) as excinfo:
+                sea_client.get_schemas(
+                    session_id=sea_session_id,
+                    max_rows=100,
+                    max_bytes=1000,
+                    cursor=mock_cursor,
+                )
+            assert "Catalog name is required for get_schemas" in str(excinfo.value)
+
+    def test_get_tables(self, sea_client, sea_session_id, mock_cursor):
+        """Test the get_tables method with various parameter combinations."""
+        # Mock the execute_command method
+        mock_result_set = Mock()
+        with patch.object(
+            sea_client, "execute_command", return_value=mock_result_set
+        ) as mock_execute:
+            # Mock the filter_tables_by_type method
+            with patch(
+                "databricks.sql.backend.filters.ResultSetFilter.filter_tables_by_type",
+                return_value=mock_result_set,
+            ) as mock_filter:
+                # Case 1: With catalog name only
+                result = sea_client.get_tables(
+                    session_id=sea_session_id,
+                    max_rows=100,
+                    max_bytes=1000,
+                    cursor=mock_cursor,
+                    catalog_name="test_catalog",
+                )
+
+                mock_execute.assert_called_with(
+                    operation="SHOW TABLES IN CATALOG test_catalog",
+                    session_id=sea_session_id,
+                    max_rows=100,
+                    max_bytes=1000,
+                    lz4_compression=False,
+                    cursor=mock_cursor,
+                    use_cloud_fetch=False,
+                    parameters=[],
+                    async_op=False,
+                    enforce_embedded_schema_correctness=False,
+                )
+                mock_filter.assert_called_with(mock_result_set, None)
+
+                # Case 2: With all parameters
+                table_types = ["TABLE", "VIEW"]
+                result = sea_client.get_tables(
+                    session_id=sea_session_id,
+                    max_rows=100,
+                    max_bytes=1000,
+                    cursor=mock_cursor,
+                    catalog_name="test_catalog",
+                    schema_name="test_schema",
+                    table_name="test_table",
+                    table_types=table_types,
+                )
+
+                mock_execute.assert_called_with(
+                    operation="SHOW TABLES IN CATALOG test_catalog SCHEMA LIKE 'test_schema' LIKE 'test_table'",
+                    session_id=sea_session_id,
+                    max_rows=100,
+                    max_bytes=1000,
+                    lz4_compression=False,
+                    cursor=mock_cursor,
+                    use_cloud_fetch=False,
+                    parameters=[],
+                    async_op=False,
+                    enforce_embedded_schema_correctness=False,
+                )
+                mock_filter.assert_called_with(mock_result_set, table_types)
+
+                # Case 3: With wildcard catalog
+                result = sea_client.get_tables(
+                    session_id=sea_session_id,
+                    max_rows=100,
+                    max_bytes=1000,
+                    cursor=mock_cursor,
+                    catalog_name="*",
+                )
+
+                mock_execute.assert_called_with(
+                    operation="SHOW TABLES IN ALL CATALOGS",
+                    session_id=sea_session_id,
+                    max_rows=100,
+                    max_bytes=1000,
+                    lz4_compression=False,
+                    cursor=mock_cursor,
+                    use_cloud_fetch=False,
+                    parameters=[],
+                    async_op=False,
+                    enforce_embedded_schema_correctness=False,
+                )
+
+                # Case 4: Without catalog name (should raise ValueError)
+                with pytest.raises(ValueError) as excinfo:
+                    sea_client.get_tables(
+                        session_id=sea_session_id,
+                        max_rows=100,
+                        max_bytes=1000,
+                        cursor=mock_cursor,
+                    )
+                assert "Catalog name is required for get_tables" in str(excinfo.value)
+
+    def test_get_columns(self, sea_client, sea_session_id, mock_cursor):
+        """Test the get_columns method with various parameter combinations."""
+        # Mock the execute_command method
+        mock_result_set = Mock()
+        with patch.object(
+            sea_client, "execute_command", return_value=mock_result_set
+        ) as mock_execute:
+            # Case 1: With catalog name only
+            result = sea_client.get_columns(
+                session_id=sea_session_id,
+                max_rows=100,
+                max_bytes=1000,
+                cursor=mock_cursor,
+                catalog_name="test_catalog",
+            )
+
+            mock_execute.assert_called_with(
+                operation="SHOW COLUMNS IN CATALOG test_catalog",
+                session_id=sea_session_id,
+                max_rows=100,
+                max_bytes=1000,
+                lz4_compression=False,
+                cursor=mock_cursor,
+                use_cloud_fetch=False,
+                parameters=[],
+                async_op=False,
+                enforce_embedded_schema_correctness=False,
             )

-        # Test get_tables
-        with pytest.raises(NotImplementedError):
-            sea_client.get_tables(sea_session_id, 100, 1000, mock_cursor)
-
-        # Test get_tables with optional parameters
-        with pytest.raises(NotImplementedError):
-            sea_client.get_tables(
-                sea_session_id,
-                100,
-                1000,
-                mock_cursor,
-                catalog_name="catalog",
-                schema_name="schema",
-                table_name="table",
-                table_types=["TABLE", "VIEW"],
+            # Case 2: With all parameters
+            result = sea_client.get_columns(
+                session_id=sea_session_id,
+                max_rows=100,
+                max_bytes=1000,
+                cursor=mock_cursor,
+                catalog_name="test_catalog",
+                schema_name="test_schema",
+                table_name="test_table",
+                column_name="test_column",
             )

-        # Test get_columns
-        with pytest.raises(NotImplementedError):
-            sea_client.get_columns(sea_session_id, 100, 1000, mock_cursor)
-
-        # Test get_columns with optional parameters
-        with pytest.raises(NotImplementedError):
-            sea_client.get_columns(
-                sea_session_id,
-                100,
-                1000,
-                mock_cursor,
-                catalog_name="catalog",
-                schema_name="schema",
-                table_name="table",
-                column_name="column",
+            mock_execute.assert_called_with(
+                operation="SHOW COLUMNS IN CATALOG test_catalog SCHEMA LIKE 'test_schema' TABLE LIKE 'test_table' LIKE 'test_column'",
+                session_id=sea_session_id,
+                max_rows=100,
+                max_bytes=1000,
+                lz4_compression=False,
+                cursor=mock_cursor,
+                use_cloud_fetch=False,
+                parameters=[],
+                async_op=False,
+                enforce_embedded_schema_correctness=False,
             )
+
+            # Case 3: Without catalog name (should raise ValueError)
+            with pytest.raises(ValueError) as excinfo:
+                sea_client.get_columns(
+                    session_id=sea_session_id,
+                    max_rows=100,
+                    max_bytes=1000,
+                    cursor=mock_cursor,
+                )
+            assert "Catalog name is required for get_columns" in str(excinfo.value)
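Reviewer note: end to end, these backend methods are reached through the standard PEP 249 metadata helpers on the cursor. A hedged usage sketch follows; the connection arguments are placeholders, and how a session is routed to the SEA backend depends on connection configuration that is not part of this diff:

```python
from databricks import sql

# Placeholders: fill in real workspace values. Routing this connection to the
# SEA backend is assumed to be handled by connection configuration not shown here.
with sql.connect(
    server_hostname="<workspace-host>",
    http_path="<warehouse-http-path>",
    access_token="<personal-access-token>",
) as connection:
    with connection.cursor() as cursor:
        cursor.catalogs()  # issues SHOW CATALOGS via get_catalogs
        print(cursor.fetchall())

        # Issues SHOW TABLES IN CATALOG main, then filters rows client-side
        # through ResultSetFilter.filter_tables_by_type.
        cursor.tables(catalog_name="main", table_types=["TABLE", "VIEW"])
        print(cursor.fetchall())
```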