diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py
index e3dc38ad5..2faa9d4a8 100644
--- a/src/databricks/sql/thrift_backend.py
+++ b/src/databricks/sql/thrift_backend.py
@@ -675,7 +675,7 @@ def convert_col(t_column_desc):
         return pyarrow.schema([convert_col(col) for col in t_table_schema.columns])
 
     @staticmethod
-    def _col_to_description(col):
+    def _col_to_description(col, field):
         type_entry = col.typeDesc.types[0]
 
         if type_entry.primitiveEntry:
@@ -702,12 +702,36 @@ def _col_to_description(col):
         else:
             precision, scale = None, None
 
+        # Extract variant type from field if available
+        if field is not None:
+            try:
+                # Check for variant type in metadata
+                if field.metadata and b"Spark:DataType:SqlName" in field.metadata:
+                    sql_type = field.metadata.get(b"Spark:DataType:SqlName")
+                    if sql_type == b"VARIANT":
+                        cleaned_type = "variant"
+            except Exception as e:
+                logger.debug(f"Could not extract variant type from field: {e}")
+
         return col.columnName, cleaned_type, None, None, precision, scale, None
 
     @staticmethod
-    def _hive_schema_to_description(t_table_schema):
+    def _hive_schema_to_description(t_table_schema, schema_bytes=None):
+        # Create a field lookup dictionary for efficient column access
+        field_dict = {}
+        if pyarrow and schema_bytes:
+            try:
+                arrow_schema = pyarrow.ipc.read_schema(pyarrow.py_buffer(schema_bytes))
+                # Build a dictionary mapping column names to fields
+                for field in arrow_schema:
+                    field_dict[field.name] = field
+            except Exception as e:
+                logger.debug(f"Could not parse arrow schema: {e}")
+
+        # Process each column with its corresponding Arrow field (if available)
         return [
-            ThriftBackend._col_to_description(col) for col in t_table_schema.columns
+            ThriftBackend._col_to_description(col, field_dict.get(col.columnName))
+            for col in t_table_schema.columns
         ]
 
     def _results_message_to_execute_response(self, resp, operation_state):
@@ -736,9 +760,6 @@ def _results_message_to_execute_response(self, resp, operation_state):
             or (not direct_results.resultSet)
             or direct_results.resultSet.hasMoreRows
         )
-        description = self._hive_schema_to_description(
-            t_result_set_metadata_resp.schema
-        )
 
         if pyarrow:
             schema_bytes = (
@@ -750,6 +771,10 @@
         else:
             schema_bytes = None
 
+        description = self._hive_schema_to_description(
+            t_result_set_metadata_resp.schema, schema_bytes
+        )
+
         lz4_compressed = t_result_set_metadata_resp.lz4Compressed
         is_staging_operation = t_result_set_metadata_resp.isStagingOperation
         if direct_results and direct_results.resultSet:
@@ -803,9 +828,6 @@ def get_execution_result(self, op_handle, cursor):
         lz4_compressed = t_result_set_metadata_resp.lz4Compressed
         is_staging_operation = t_result_set_metadata_resp.isStagingOperation
         has_more_rows = resp.hasMoreRows
-        description = self._hive_schema_to_description(
-            t_result_set_metadata_resp.schema
-        )
 
         if pyarrow:
             schema_bytes = (
@@ -817,6 +839,10 @@
         else:
             schema_bytes = None
 
+        description = self._hive_schema_to_description(
+            t_result_set_metadata_resp.schema, schema_bytes
+        )
+
         queue = ResultSetQueueFactory.build_queue(
             row_set_type=resp.resultSetMetadata.resultFormat,
             t_row_set=resp.results,
diff --git a/tests/e2e/test_variant_types.py b/tests/e2e/test_variant_types.py
new file mode 100644
index 000000000..11236e6d2
--- /dev/null
+++ b/tests/e2e/test_variant_types.py
@@ -0,0 +1,80 @@
+import pytest
+from datetime import datetime
+import json
+try:
+    import pyarrow
+except ImportError:
+    pyarrow = None
+
+from tests.e2e.test_driver import PySQLPytestTestCase
+from tests.e2e.common.predicates import pysql_supports_arrow
+
+class TestVariantTypes(PySQLPytestTestCase):
+    """Tests for the proper detection and handling of VARIANT type columns"""
+
+    @pytest.fixture(scope="class")
+    def variant_table(self, connection_details):
+        """A pytest fixture that creates a test table and cleans up after tests"""
+        self.arguments = connection_details.copy()
+        table_name = "pysql_test_variant_types_table"
+
+        with self.cursor() as cursor:
+            try:
+                # Create the table with variant columns
+                cursor.execute(
+                    """
+                    CREATE TABLE IF NOT EXISTS pysql_test_variant_types_table (
+                        id INTEGER,
+                        variant_col VARIANT,
+                        regular_string_col STRING
+                    )
+                    """
+                )
+
+                # Insert test records with different variant values
+                cursor.execute(
+                    """
+                    INSERT INTO pysql_test_variant_types_table
+                    VALUES
+                    (1, PARSE_JSON('{"name": "John", "age": 30}'), 'regular string'),
+                    (2, PARSE_JSON('[1, 2, 3, 4]'), 'another string')
+                    """
+                )
+                yield table_name
+            finally:
+                cursor.execute(f"DROP TABLE IF EXISTS {table_name}")
+
+    @pytest.mark.skipif(not pysql_supports_arrow(), reason="Requires arrow support")
+    def test_variant_type_detection(self, variant_table):
+        """Test that VARIANT type columns are properly detected in schema"""
+        with self.cursor() as cursor:
+            cursor.execute(f"SELECT * FROM {variant_table} LIMIT 0")
+
+            # Verify column types in description
+            assert cursor.description[0][1] == 'int', "Integer column type not correctly identified"
+            assert cursor.description[1][1] == 'variant', "VARIANT column type not correctly identified"
+            assert cursor.description[2][1] == 'string', "String column type not correctly identified"
+
+    @pytest.mark.skipif(not pysql_supports_arrow(), reason="Requires arrow support")
+    def test_variant_data_retrieval(self, variant_table):
+        """Test that VARIANT data is properly retrieved and can be accessed as JSON"""
+        with self.cursor() as cursor:
+            cursor.execute(f"SELECT * FROM {variant_table} ORDER BY id")
+            rows = cursor.fetchall()
+
+            # First row should have a JSON object
+            json_obj = rows[0][1]
+            assert isinstance(json_obj, str), "VARIANT column should be returned as string"
+
+            parsed = json.loads(json_obj)
+            assert parsed.get('name') == 'John'
+            assert parsed.get('age') == 30
+
+            # Second row should have a JSON array
+            json_array = rows[1][1]
+            assert isinstance(json_array, str), "VARIANT array should be returned as string"
+
+            # Parsing to verify it's valid JSON array
+            parsed_array = json.loads(json_array)
+            assert isinstance(parsed_array, list)
+            assert parsed_array == [1, 2, 3, 4]
\ No newline at end of file
diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py
index 458ea9a82..b3f96b5f2 100644
--- a/tests/unit/test_thrift_backend.py
+++ b/tests/unit/test_thrift_backend.py
@@ -2202,6 +2202,149 @@ def test_execute_command_sets_complex_type_fields_correctly(
                 t_execute_statement_req.useArrowNativeTypes.intervalTypesAsArrow
             )
 
+    def test_col_to_description_with_variant_type(self):
+        # Test variant type detection from Arrow field metadata
+        col = ttypes.TColumnDesc(
+            columnName="variant_col",
+            typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
+        )
+
+        # Create a field with variant type in metadata
+        field = pyarrow.field(
+            "variant_col",
+            pyarrow.string(),
+            metadata={b'Spark:DataType:SqlName': b'VARIANT'}
+        )
+
+        result = ThriftBackend._col_to_description(col, field)
+
+        # Verify the result has variant as the type
+        self.assertEqual(result[0], "variant_col")  # Column name
+        self.assertEqual(result[1], "variant")  # Type name (should be variant instead of string)
+        self.assertIsNone(result[2])  # No display size
+        self.assertIsNone(result[3])  # No internal size
+        self.assertIsNone(result[4])  # No precision
+        self.assertIsNone(result[5])  # No scale
+        self.assertIsNone(result[6])  # No null ok
+
+    def test_col_to_description_without_variant_type(self):
+        # Test normal column without variant type
+        col = ttypes.TColumnDesc(
+            columnName="normal_col",
+            typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
+        )
+
+        # Create a normal field without variant metadata
+        field = pyarrow.field(
+            "normal_col",
+            pyarrow.string(),
+            metadata={}
+        )
+
+        result = ThriftBackend._col_to_description(col, field)
+
+        # Verify the result has string as the type (unchanged)
+        self.assertEqual(result[0], "normal_col")  # Column name
+        self.assertEqual(result[1], "string")  # Type name (should be string)
+        self.assertIsNone(result[2])  # No display size
+        self.assertIsNone(result[3])  # No internal size
+        self.assertIsNone(result[4])  # No precision
+        self.assertIsNone(result[5])  # No scale
+        self.assertIsNone(result[6])  # No null ok
+
+    def test_col_to_description_with_null_field(self):
+        # Test handling of null field
+        col = ttypes.TColumnDesc(
+            columnName="missing_field",
+            typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
+        )
+
+        # Pass None as the field
+        result = ThriftBackend._col_to_description(col, None)
+
+        # Verify the result has string as the type (unchanged)
+        self.assertEqual(result[0], "missing_field")  # Column name
+        self.assertEqual(result[1], "string")  # Type name (should be string)
+        self.assertIsNone(result[2])  # No display size
+        self.assertIsNone(result[3])  # No internal size
+        self.assertIsNone(result[4])  # No precision
+        self.assertIsNone(result[5])  # No scale
+        self.assertIsNone(result[6])  # No null ok
+
+    def test_hive_schema_to_description_with_arrow_schema(self):
+        # Create a table schema with regular and variant columns
+        columns = [
+            ttypes.TColumnDesc(
+                columnName="regular_col",
+                typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
+            ),
+            ttypes.TColumnDesc(
+                columnName="variant_col",
+                typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
+            ),
+        ]
+        t_table_schema = ttypes.TTableSchema(columns=columns)
+
+        # Create an Arrow schema with one variant column
+        fields = [
+            pyarrow.field("regular_col", pyarrow.string()),
+            pyarrow.field(
+                "variant_col",
+                pyarrow.string(),
+                metadata={b'Spark:DataType:SqlName': b'VARIANT'}
+            )
+        ]
+        arrow_schema = pyarrow.schema(fields)
+        schema_bytes = arrow_schema.serialize().to_pybytes()
+
+        # Get the description
+        description = ThriftBackend._hive_schema_to_description(t_table_schema, schema_bytes)
+
+        # Verify regular column type
+        self.assertEqual(description[0][0], "regular_col")
+        self.assertEqual(description[0][1], "string")
+
+        # Verify variant column type
+        self.assertEqual(description[1][0], "variant_col")
+        self.assertEqual(description[1][1], "variant")
+
+    def test_hive_schema_to_description_with_null_schema_bytes(self):
+        # Create a simple table schema
+        columns = [
+            ttypes.TColumnDesc(
+                columnName="regular_col",
+                typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
+            ),
+        ]
+        t_table_schema = ttypes.TTableSchema(columns=columns)
+
+        # Get the description with null schema_bytes
+        description = ThriftBackend._hive_schema_to_description(t_table_schema, None)
+
+        # Verify column type remains unchanged
+        self.assertEqual(description[0][0], "regular_col")
+        self.assertEqual(description[0][1], "string")
+
+    def test_col_to_description_with_malformed_metadata(self):
+        # Test handling of malformed metadata
+        col = ttypes.TColumnDesc(
+            columnName="weird_field",
+            typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
+        )
+
+        # Create a field with malformed metadata
+        field = pyarrow.field(
+            "weird_field",
+            pyarrow.string(),
+            metadata={b'Spark:DataType:SqlName': b'Some unexpected value'}
+        )
+
+        result = ThriftBackend._col_to_description(col, field)
+
+        # Verify the type remains unchanged
+        self.assertEqual(result[0], "weird_field")  # Column name
+        self.assertEqual(result[1], "string")  # Type name (should remain string)
+
 
 if __name__ == "__main__":
     unittest.main()
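
For reviewers, a minimal usage sketch of how this change surfaces through the DB-API cursor. This is not part of the patch; the table `events`, its VARIANT column `payload`, and the connection parameters below are hypothetical placeholders.

import json

from databricks import sql

# Hypothetical connection parameters; substitute real workspace values.
with sql.connect(
    server_hostname="<workspace-hostname>",
    http_path="<sql-warehouse-http-path>",
    access_token="<personal-access-token>",
) as connection:
    with connection.cursor() as cursor:
        # `events.payload` is assumed to be a VARIANT column for illustration.
        cursor.execute("SELECT payload FROM events LIMIT 1")

        # With this patch, VARIANT columns report "variant" in cursor.description
        # (previously they surfaced as "string").
        print(cursor.description[0][1])  # expected: "variant"

        # Values are still returned as JSON strings and can be parsed as before.
        row = cursor.fetchone()
        payload = json.loads(row[0])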