[PECOBLR-201] add variant support #560

Open · wants to merge 3 commits into base: main
44 changes: 35 additions & 9 deletions src/databricks/sql/thrift_backend.py
@@ -665,7 +665,7 @@ def convert_col(t_column_desc):
         return pyarrow.schema([convert_col(col) for col in t_table_schema.columns])

     @staticmethod
-    def _col_to_description(col):
+    def _col_to_description(col, field):
         type_entry = col.typeDesc.types[0]

         if type_entry.primitiveEntry:
@@ -692,12 +692,36 @@ def _col_to_description(col):
         else:
             precision, scale = None, None

+        # Extract variant type from field if available
[Inline review thread on the line above]

Contributor: Are you sure this is correct? I tried it and got metadata as null when the column type is variant. Also, in my testing the pyarrow schema just shows string for variant; shouldn't the server return a variant type?

Contributor Author (shivam2680): Yes. Debug output:

    [SHIVAM] field pyarrow.Field<CAST(1 AS VARIANT): string>
    [SHIVAM] field metadata {b'Spark:DataType:SqlName': b'VARIANT', b'Spark:DataType:JsonType': b'"variant"'}

Contributor (@jprakash-db, Jun 17, 2025): @shivam2680 I am getting this as the arrow_schema, where metadata is null. Is this some transient behaviour, or am I missing something? [screenshot: arrow_schema output showing null metadata, 2025-06-17]
+        if field is not None:
+            try:
+                # Check for variant type in metadata
+                if field.metadata and b"Spark:DataType:SqlName" in field.metadata:
+                    sql_type = field.metadata.get(b"Spark:DataType:SqlName")
+                    if sql_type == b"VARIANT":
+                        cleaned_type = "variant"
+            except Exception as e:
+                logger.debug(f"Could not extract variant type from field: {e}")
+
         return col.columnName, cleaned_type, None, None, precision, scale, None

     @staticmethod
-    def _hive_schema_to_description(t_table_schema):
+    def _hive_schema_to_description(t_table_schema, schema_bytes=None):
+        # Create a field lookup dictionary for efficient column access
+        field_dict = {}
+        if pyarrow and schema_bytes:
+            try:
+                arrow_schema = pyarrow.ipc.read_schema(pyarrow.py_buffer(schema_bytes))
+                # Build a dictionary mapping column names to fields
+                for field in arrow_schema:
+                    field_dict[field.name] = field
+            except Exception as e:
+                logger.debug(f"Could not parse arrow schema: {e}")
+
+        # Process each column with its corresponding Arrow field (if available)
         return [
-            ThriftBackend._col_to_description(col) for col in t_table_schema.columns
+            ThriftBackend._col_to_description(col, field_dict.get(col.columnName))
+            for col in t_table_schema.columns
         ]

     def _results_message_to_execute_response(self, resp, operation_state):

@@ -726,9 +750,6 @@ def _results_message_to_execute_response(self, resp, operation_state):
             or (not direct_results.resultSet)
             or direct_results.resultSet.hasMoreRows
         )
-        description = self._hive_schema_to_description(
-            t_result_set_metadata_resp.schema
-        )

         if pyarrow:
             schema_bytes = (
@@ -740,6 +761,10 @@ def _results_message_to_execute_response(self, resp, operation_state):
         else:
             schema_bytes = None

+        description = self._hive_schema_to_description(
+            t_result_set_metadata_resp.schema, schema_bytes
+        )
+
         lz4_compressed = t_result_set_metadata_resp.lz4Compressed
         is_staging_operation = t_result_set_metadata_resp.isStagingOperation
         if direct_results and direct_results.resultSet:
@@ -793,9 +818,6 @@ def get_execution_result(self, op_handle, cursor):
         lz4_compressed = t_result_set_metadata_resp.lz4Compressed
         is_staging_operation = t_result_set_metadata_resp.isStagingOperation
         has_more_rows = resp.hasMoreRows
-        description = self._hive_schema_to_description(
-            t_result_set_metadata_resp.schema
-        )

         if pyarrow:
             schema_bytes = (
@@ -807,6 +829,10 @@ def get_execution_result(self, op_handle, cursor):
         else:
             schema_bytes = None

+        description = self._hive_schema_to_description(
+            t_result_set_metadata_resp.schema, schema_bytes
+        )
+
         queue = ResultSetQueueFactory.build_queue(
             row_set_type=resp.resultSetMetadata.resultFormat,
             t_row_set=resp.results,
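Aside on the review thread above: the metadata round-trip is easy to check locally with pyarrow alone, no warehouse connection needed. The sketch below assumes only that the server attaches a Spark:DataType:SqlName key to the Arrow field, which is exactly the behaviour the thread is debating; the column name is illustrative.

import pyarrow

# Build a schema the way the server is expected to: a string column whose
# field metadata marks it as a Spark VARIANT (assumed server behaviour).
field = pyarrow.field(
    "variant_col",
    pyarrow.string(),
    metadata={b"Spark:DataType:SqlName": b"VARIANT"},
)
schema_bytes = pyarrow.schema([field]).serialize().to_pybytes()

# Read it back exactly as _hive_schema_to_description does.
arrow_schema = pyarrow.ipc.read_schema(pyarrow.py_buffer(schema_bytes))
for f in arrow_schema:
    print(f)           # pyarrow.Field<variant_col: string>
    print(f.metadata)  # {b'Spark:DataType:SqlName': b'VARIANT'}

If the schema really arrives without metadata, as in the jprakash-db screenshot, field.metadata is None and _col_to_description falls back to the primitive type name, so the column is reported as string rather than variant.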
80 changes: 80 additions & 0 deletions tests/e2e/test_variant_types.py
@@ -0,0 +1,80 @@
import pytest
from datetime import datetime
import json

try:
    import pyarrow
except ImportError:
    pyarrow = None

from tests.e2e.test_driver import PySQLPytestTestCase
from tests.e2e.common.predicates import pysql_supports_arrow


class TestVariantTypes(PySQLPytestTestCase):
    """Tests for the proper detection and handling of VARIANT type columns"""

    @pytest.fixture(scope="class")
    def variant_table(self, connection_details):
        """A pytest fixture that creates a test table and cleans up after tests"""
        self.arguments = connection_details.copy()
        table_name = "pysql_test_variant_types_table"

        with self.cursor() as cursor:
            try:
                # Create the table with variant columns
                cursor.execute(
                    """
                    CREATE TABLE IF NOT EXISTS pysql_test_variant_types_table (
                        id INTEGER,
                        variant_col VARIANT,
                        regular_string_col STRING
                    )
                    """
                )

                # Insert test records with different variant values
                cursor.execute(
                    """
                    INSERT INTO pysql_test_variant_types_table
                    VALUES
                    (1, PARSE_JSON('{"name": "John", "age": 30}'), 'regular string'),
                    (2, PARSE_JSON('[1, 2, 3, 4]'), 'another string')
                    """
                )
                yield table_name
            finally:
                cursor.execute(f"DROP TABLE IF EXISTS {table_name}")

    @pytest.mark.skipif(not pysql_supports_arrow(), reason="Requires arrow support")
    def test_variant_type_detection(self, variant_table):
        """Test that VARIANT type columns are properly detected in schema"""
        with self.cursor() as cursor:
            cursor.execute(f"SELECT * FROM {variant_table} LIMIT 0")

            # Verify column types in description
            assert cursor.description[0][1] == 'int', "Integer column type not correctly identified"
            assert cursor.description[1][1] == 'variant', "VARIANT column type not correctly identified"
            assert cursor.description[2][1] == 'string', "String column type not correctly identified"

    @pytest.mark.skipif(not pysql_supports_arrow(), reason="Requires arrow support")
    def test_variant_data_retrieval(self, variant_table):
        """Test that VARIANT data is properly retrieved and can be accessed as JSON"""
        with self.cursor() as cursor:
            cursor.execute(f"SELECT * FROM {variant_table} ORDER BY id")
            rows = cursor.fetchall()

            # First row should have a JSON object
            json_obj = rows[0][1]
            assert isinstance(json_obj, str), "VARIANT column should be returned as string"

            parsed = json.loads(json_obj)
            assert parsed.get('name') == 'John'
            assert parsed.get('age') == 30

            # Second row should have a JSON array
            json_array = rows[1][1]
            assert isinstance(json_array, str), "VARIANT array should be returned as string"

            # Parse to verify it's a valid JSON array
            parsed_array = json.loads(json_array)
            assert isinstance(parsed_array, list)
            assert parsed_array == [1, 2, 3, 4]
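For reference, this is the shape of cursor.description the first test asserts against, following the PEP 249 seven-item layout (name, type_code, display_size, internal_size, precision, scale, null_ok) that _col_to_description returns. A sketch, not captured server output:

# Expected description for: SELECT * FROM pysql_test_variant_types_table LIMIT 0
[
    ("id", "int", None, None, None, None, None),
    ("variant_col", "variant", None, None, None, None, None),
    ("regular_string_col", "string", None, None, None, None, None),
]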
143 changes: 143 additions & 0 deletions tests/unit/test_thrift_backend.py
@@ -2200,6 +2200,149 @@ def test_execute_command_sets_complex_type_fields_correctly(
            t_execute_statement_req.useArrowNativeTypes.intervalTypesAsArrow
        )

    def test_col_to_description_with_variant_type(self):
        # Test variant type detection from Arrow field metadata
        col = ttypes.TColumnDesc(
            columnName="variant_col",
            typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
        )

        # Create a field with variant type in metadata
        field = pyarrow.field(
            "variant_col",
            pyarrow.string(),
            metadata={b'Spark:DataType:SqlName': b'VARIANT'},
        )

        result = ThriftBackend._col_to_description(col, field)

        # Verify the result has variant as the type
        self.assertEqual(result[0], "variant_col")  # Column name
        self.assertEqual(result[1], "variant")  # Type name (should be variant instead of string)
        self.assertIsNone(result[2])  # No display size
        self.assertIsNone(result[3])  # No internal size
        self.assertIsNone(result[4])  # No precision
        self.assertIsNone(result[5])  # No scale
        self.assertIsNone(result[6])  # No null ok

    def test_col_to_description_without_variant_type(self):
        # Test normal column without variant type
        col = ttypes.TColumnDesc(
            columnName="normal_col",
            typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
        )

        # Create a normal field without variant metadata
        field = pyarrow.field(
            "normal_col",
            pyarrow.string(),
            metadata={},
        )

        result = ThriftBackend._col_to_description(col, field)

        # Verify the result has string as the type (unchanged)
        self.assertEqual(result[0], "normal_col")  # Column name
        self.assertEqual(result[1], "string")  # Type name (should be string)
        self.assertIsNone(result[2])  # No display size
        self.assertIsNone(result[3])  # No internal size
        self.assertIsNone(result[4])  # No precision
        self.assertIsNone(result[5])  # No scale
        self.assertIsNone(result[6])  # No null ok

    def test_col_to_description_with_null_field(self):
        # Test handling of a null field
        col = ttypes.TColumnDesc(
            columnName="missing_field",
            typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
        )

        # Pass None as the field
        result = ThriftBackend._col_to_description(col, None)

        # Verify the result has string as the type (unchanged)
        self.assertEqual(result[0], "missing_field")  # Column name
        self.assertEqual(result[1], "string")  # Type name (should be string)
        self.assertIsNone(result[2])  # No display size
        self.assertIsNone(result[3])  # No internal size
        self.assertIsNone(result[4])  # No precision
        self.assertIsNone(result[5])  # No scale
        self.assertIsNone(result[6])  # No null ok

    def test_hive_schema_to_description_with_arrow_schema(self):
        # Create a table schema with regular and variant columns
        columns = [
            ttypes.TColumnDesc(
                columnName="regular_col",
                typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
            ),
            ttypes.TColumnDesc(
                columnName="variant_col",
                typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
            ),
        ]
        t_table_schema = ttypes.TTableSchema(columns=columns)

        # Create an Arrow schema with one variant column
        fields = [
            pyarrow.field("regular_col", pyarrow.string()),
            pyarrow.field(
                "variant_col",
                pyarrow.string(),
                metadata={b'Spark:DataType:SqlName': b'VARIANT'},
            ),
        ]
        arrow_schema = pyarrow.schema(fields)
        schema_bytes = arrow_schema.serialize().to_pybytes()

        # Get the description
        description = ThriftBackend._hive_schema_to_description(t_table_schema, schema_bytes)

        # Verify regular column type
        self.assertEqual(description[0][0], "regular_col")
        self.assertEqual(description[0][1], "string")

        # Verify variant column type
        self.assertEqual(description[1][0], "variant_col")
        self.assertEqual(description[1][1], "variant")

    def test_hive_schema_to_description_with_null_schema_bytes(self):
        # Create a simple table schema
        columns = [
            ttypes.TColumnDesc(
                columnName="regular_col",
                typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
            ),
        ]
        t_table_schema = ttypes.TTableSchema(columns=columns)

        # Get the description with null schema_bytes
        description = ThriftBackend._hive_schema_to_description(t_table_schema, None)

        # Verify column type remains unchanged
        self.assertEqual(description[0][0], "regular_col")
        self.assertEqual(description[0][1], "string")

    def test_col_to_description_with_malformed_metadata(self):
        # Test handling of malformed metadata
        col = ttypes.TColumnDesc(
            columnName="weird_field",
            typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
        )

        # Create a field with malformed metadata
        field = pyarrow.field(
            "weird_field",
            pyarrow.string(),
            metadata={b'Spark:DataType:SqlName': b'Some unexpected value'},
        )

        result = ThriftBackend._col_to_description(col, field)

        # Verify the type remains unchanged
        self.assertEqual(result[0], "weird_field")  # Column name
        self.assertEqual(result[1], "string")  # Type name (should remain string)


if __name__ == "__main__":
    unittest.main()