Skip to content

Commit cf7df8d

Browse files
committed
query tags with sea flow
Signed-off-by: Sreekanth Vadigi <[email protected]>
1 parent b4926d6 commit cf7df8d

File tree

4 files changed

+21
-15
lines changed

4 files changed

+21
-15
lines changed

src/databricks/sql/backend/sea/utils/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"STATEMENT_TIMEOUT": "0",
1616
"TIMEZONE": "UTC",
1717
"USE_CACHED_RESULT": "true",
18+
"QUERY_TAGS": "",
1819
}
1920

2021

tests/e2e/test_driver.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -848,10 +848,21 @@ def test_socket_timeout_user_defined(self):
848848
query = "select * from range(1000000000)"
849849
cursor.execute(query)
850850

851-
def test_ssp_passthrough(self):
851+
@pytest.mark.parametrize(
852+
"extra_params",
853+
[
854+
{
855+
"use_sea": False,
856+
},
857+
{
858+
"use_sea": True,
859+
},
860+
],
861+
)
862+
def test_ssp_passthrough(self, extra_params):
852863
for enable_ansi in (True, False):
853864
with self.cursor(
854-
{"session_configuration": {"ansi_mode": enable_ansi, "QUERY_TAGS": "team:marketing,dashboard:abc123,driver:python"}}
865+
{"session_configuration": {"ansi_mode": enable_ansi, "QUERY_TAGS": "team:marketing,dashboard:abc123,driver:python"}, **extra_params}
855866
) as cursor:
856867
cursor.execute("SET ansi_mode")
857868
assert list(cursor.fetchone()) == ["ansi_mode", str(enable_ansi)]

tests/unit/test_sea_backend.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i
185185
session_config = {
186186
"ANSI_MODE": "FALSE", # Supported parameter
187187
"STATEMENT_TIMEOUT": "3600", # Supported parameter
188+
"QUERY_TAGS": "team:marketing,dashboard:abc123", # Supported parameter
188189
"unsupported_param": "value", # Unsupported parameter
189190
}
190191
catalog = "test_catalog"
@@ -196,6 +197,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i
196197
"session_confs": {
197198
"ansi_mode": "FALSE",
198199
"statement_timeout": "3600",
200+
"query_tags": "team:marketing,dashboard:abc123",
199201
},
200202
"catalog": catalog,
201203
"schema": schema,
@@ -641,6 +643,7 @@ def test_filter_session_configuration(self):
641643
"TIMEZONE": "UTC",
642644
"enable_photon": False,
643645
"MAX_FILE_PARTITION_BYTES": 128.5,
646+
"QUERY_TAGS": "team:engineering,project:data-pipeline",
644647
"unsupported_param": "value",
645648
"ANOTHER_UNSUPPORTED": 42,
646649
}
@@ -663,6 +666,7 @@ def test_filter_session_configuration(self):
663666
"timezone": "UTC", # string -> "UTC", key lowercased
664667
"enable_photon": "False", # boolean False -> "False", key lowercased
665668
"max_file_partition_bytes": "128.5", # float -> "128.5", key lowercased
669+
"query_tags": "team:engineering,project:data-pipeline",
666670
}
667671

668672
assert result == expected_result
@@ -683,12 +687,14 @@ def test_filter_session_configuration(self):
683687
"ansi_mode": "false", # lowercase key
684688
"STATEMENT_TIMEOUT": 7200, # uppercase key
685689
"TiMeZoNe": "America/New_York", # mixed case key
690+
"QueRy_TaGs": "team:marketing,test:case-insensitive",
686691
}
687692
result = _filter_session_configuration(case_insensitive_config)
688693
expected_case_result = {
689694
"ansi_mode": "false",
690695
"statement_timeout": "7200",
691696
"timezone": "America/New_York",
697+
"query_tags": "team:marketing,test:case-insensitive",
692698
}
693699
assert result == expected_case_result
694700

tests/unit/test_session.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -154,26 +154,14 @@ def test_socket_timeout_passthrough(self, mock_client_class):
154154

155155
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
156156
def test_configuration_passthrough(self, mock_client_class):
157-
mock_session_config = Mock()
157+
mock_session_config = {"ANSI_MODE": "FALSE", "QUERY_TAGS": "team:engineering,project:data-pipeline"}
158158
databricks.sql.connect(
159159
session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS
160160
)
161161

162162
call_kwargs = mock_client_class.return_value.open_session.call_args[1]
163163
assert call_kwargs["session_configuration"] == mock_session_config
164164

165-
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
166-
def test_query_tags_configuration_passthrough(self, mock_client_class):
167-
"""Test that Query Tags are properly passed through to the backend."""
168-
query_tags_config = {"QUERY_TAGS": "team:marketing,dashboard:abc123"}
169-
databricks.sql.connect(
170-
session_configuration=query_tags_config, **self.DUMMY_CONNECTION_ARGS
171-
)
172-
173-
call_kwargs = mock_client_class.return_value.open_session.call_args[1]
174-
assert call_kwargs["session_configuration"] == query_tags_config
175-
assert call_kwargs["session_configuration"]["QUERY_TAGS"] == "team:marketing,dashboard:abc123"
176-
177165
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
178166
def test_initial_namespace_passthrough(self, mock_client_class):
179167
mock_cat = Mock()

0 commit comments

Comments
 (0)