Skip to content

Commit 4f11ff0

Browse files
Introduce row_limit param (#607)
* introduce row_limit — Signed-off-by: varun-edachali-dbx <[email protected]>
* move use_sea init to Session constructor — Signed-off-by: varun-edachali-dbx <[email protected]>
* more explicit typing — Signed-off-by: varun-edachali-dbx <[email protected]>
* add row_limit to Thrift backend — Signed-off-by: varun-edachali-dbx <[email protected]>
* formatting (black) — Signed-off-by: varun-edachali-dbx <[email protected]>
* add e2e test for thrift resultRowLimit — Signed-off-by: varun-edachali-dbx <[email protected]>
* explicitly convert extra cursor params to dict — Signed-off-by: varun-edachali-dbx <[email protected]>
* remove excess tests — Signed-off-by: varun-edachali-dbx <[email protected]>
* add docstring for row_limit — Signed-off-by: varun-edachali-dbx <[email protected]>
---------
Signed-off-by: varun-edachali-dbx <[email protected]>
1 parent abf9aab commit 4f11ff0

File tree

5 files changed

+85
-12
lines changed

5 files changed

+85
-12
lines changed

src/databricks/sql/backend/databricks_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def execute_command(
8585
parameters: List,
8686
async_op: bool,
8787
enforce_embedded_schema_correctness: bool,
88+
row_limit: Optional[int] = None,
8889
) -> Union["ResultSet", None]:
8990
"""
9091
Executes a SQL command or query within the specified session.
@@ -103,6 +104,7 @@ def execute_command(
103104
parameters: List of parameters to bind to the query
104105
async_op: Whether to execute the command asynchronously
105106
enforce_embedded_schema_correctness: Whether to enforce schema correctness
107+
row_limit: Maximum number of rows in the operation result.
106108
107109
Returns:
108110
If async_op is False, returns a ResultSet object containing the

src/databricks/sql/backend/sea/backend.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,7 @@ def execute_command(
405405
parameters: List[Dict[str, Any]],
406406
async_op: bool,
407407
enforce_embedded_schema_correctness: bool,
408+
row_limit: Optional[int] = None,
408409
) -> Union[SeaResultSet, None]:
409410
"""
410411
Execute a SQL command using the SEA backend.
@@ -462,7 +463,7 @@ def execute_command(
462463
format=format,
463464
wait_timeout=(WaitTimeout.ASYNC if async_op else WaitTimeout.SYNC).value,
464465
on_wait_timeout="CONTINUE",
465-
row_limit=max_rows,
466+
row_limit=row_limit,
466467
parameters=sea_parameters if sea_parameters else None,
467468
result_compression=result_compression,
468469
)

src/databricks/sql/backend/thrift_backend.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import math
55
import time
66
import threading
7-
from typing import List, Union, Any, TYPE_CHECKING
7+
from typing import List, Optional, Union, Any, TYPE_CHECKING
88

99
if TYPE_CHECKING:
1010
from databricks.sql.client import Cursor
@@ -929,6 +929,7 @@ def execute_command(
929929
parameters=[],
930930
async_op=False,
931931
enforce_embedded_schema_correctness=False,
932+
row_limit: Optional[int] = None,
932933
) -> Union["ResultSet", None]:
933934
thrift_handle = session_id.to_thrift_handle()
934935
if not thrift_handle:
@@ -969,6 +970,7 @@ def execute_command(
969970
useArrowNativeTypes=spark_arrow_types,
970971
parameters=parameters,
971972
enforceEmbeddedSchemaCorrectness=enforce_embedded_schema_correctness,
973+
resultRowLimit=row_limit,
972974
)
973975
resp = self.make_request(self._client.ExecuteStatement, req)
974976

src/databricks/sql/client.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -335,8 +335,14 @@ def cursor(
335335
self,
336336
arraysize: int = DEFAULT_ARRAY_SIZE,
337337
buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES,
338+
row_limit: Optional[int] = None,
338339
) -> "Cursor":
339340
"""
341+
Args:
342+
arraysize: The maximum number of rows in direct results.
343+
buffer_size_bytes: The maximum number of bytes in direct results.
344+
row_limit: The maximum number of rows in the result.
345+
340346
Return a new Cursor object using the connection.
341347
342348
Will throw an Error if the connection has been closed.
@@ -349,6 +355,7 @@ def cursor(
349355
self.session.backend,
350356
arraysize=arraysize,
351357
result_buffer_size_bytes=buffer_size_bytes,
358+
row_limit=row_limit,
352359
)
353360
self._cursors.append(cursor)
354361
return cursor
@@ -382,6 +389,7 @@ def __init__(
382389
backend: DatabricksClient,
383390
result_buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES,
384391
arraysize: int = DEFAULT_ARRAY_SIZE,
392+
row_limit: Optional[int] = None,
385393
) -> None:
386394
"""
387395
These objects represent a database cursor, which is used to manage the context of a fetch
@@ -391,16 +399,18 @@ def __init__(
391399
visible by other cursors or connections.
392400
"""
393401

394-
self.connection = connection
395-
self.rowcount = -1 # Return -1 as this is not supported
396-
self.buffer_size_bytes = result_buffer_size_bytes
402+
self.connection: Connection = connection
403+
404+
self.rowcount: int = -1 # Return -1 as this is not supported
405+
self.buffer_size_bytes: int = result_buffer_size_bytes
397406
self.active_result_set: Union[ResultSet, None] = None
398-
self.arraysize = arraysize
407+
self.arraysize: int = arraysize
408+
self.row_limit: Optional[int] = row_limit
399409
# Note that Cursor closed => active result set closed, but not vice versa
400-
self.open = True
401-
self.executing_command_id = None
402-
self.backend = backend
403-
self.active_command_id = None
410+
self.open: bool = True
411+
self.executing_command_id: Optional[CommandId] = None
412+
self.backend: DatabricksClient = backend
413+
self.active_command_id: Optional[CommandId] = None
404414
self.escaper = ParamEscaper()
405415
self.lastrowid = None
406416

@@ -779,6 +789,7 @@ def execute(
779789
parameters=prepared_params,
780790
async_op=False,
781791
enforce_embedded_schema_correctness=enforce_embedded_schema_correctness,
792+
row_limit=self.row_limit,
782793
)
783794

784795
if self.active_result_set and self.active_result_set.is_staging_operation:
@@ -835,6 +846,7 @@ def execute_async(
835846
parameters=prepared_params,
836847
async_op=True,
837848
enforce_embedded_schema_correctness=enforce_embedded_schema_correctness,
849+
row_limit=self.row_limit,
838850
)
839851

840852
return self

tests/e2e/test_driver.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,12 @@ def connection(self, extra_params=()):
113113
conn.close()
114114

115115
@contextmanager
116-
def cursor(self, extra_params=()):
116+
def cursor(self, extra_params=(), extra_cursor_params=()):
117117
with self.connection(extra_params) as conn:
118118
cursor = conn.cursor(
119-
arraysize=self.arraysize, buffer_size_bytes=self.buffer_size_bytes
119+
arraysize=self.arraysize,
120+
buffer_size_bytes=self.buffer_size_bytes,
121+
**dict(extra_cursor_params),
120122
)
121123
try:
122124
yield cursor
@@ -943,6 +945,60 @@ def test_catalogs_returns_arrow_table(self):
943945
results = cursor.fetchall_arrow()
944946
assert isinstance(results, pyarrow.Table)
945947

948+
def test_row_limit_with_larger_result(self):
949+
"""Test that row_limit properly constrains results when query would return more rows"""
950+
row_limit = 1000
951+
with self.cursor(extra_cursor_params={"row_limit": row_limit}) as cursor:
952+
# Execute a query that returns more than row_limit rows
953+
cursor.execute("SELECT * FROM range(2000)")
954+
rows = cursor.fetchall()
955+
956+
# Check if the number of rows is limited to row_limit
957+
assert len(rows) == row_limit, f"Expected {row_limit} rows, got {len(rows)}"
958+
959+
def test_row_limit_with_smaller_result(self):
960+
"""Test that row_limit doesn't affect results when query returns fewer rows than limit"""
961+
row_limit = 100
962+
expected_rows = 50
963+
with self.cursor(extra_cursor_params={"row_limit": row_limit}) as cursor:
964+
# Execute a query that returns fewer than row_limit rows
965+
cursor.execute(f"SELECT * FROM range({expected_rows})")
966+
rows = cursor.fetchall()
967+
968+
# Check if all rows are returned (not limited by row_limit)
969+
assert (
970+
len(rows) == expected_rows
971+
), f"Expected {expected_rows} rows, got {len(rows)}"
972+
973+
@skipUnless(pysql_supports_arrow(), "arrow test needs arrow support")
974+
def test_row_limit_with_arrow_larger_result(self):
975+
"""Test that row_limit properly constrains arrow results when query would return more rows"""
976+
row_limit = 800
977+
with self.cursor(extra_cursor_params={"row_limit": row_limit}) as cursor:
978+
# Execute a query that returns more than row_limit rows
979+
cursor.execute("SELECT * FROM range(1500)")
980+
arrow_table = cursor.fetchall_arrow()
981+
982+
# Check if the number of rows in the arrow table is limited to row_limit
983+
assert (
984+
arrow_table.num_rows == row_limit
985+
), f"Expected {row_limit} rows, got {arrow_table.num_rows}"
986+
987+
@skipUnless(pysql_supports_arrow(), "arrow test needs arrow support")
988+
def test_row_limit_with_arrow_smaller_result(self):
989+
"""Test that row_limit doesn't affect arrow results when query returns fewer rows than limit"""
990+
row_limit = 200
991+
expected_rows = 100
992+
with self.cursor(extra_cursor_params={"row_limit": row_limit}) as cursor:
993+
# Execute a query that returns fewer than row_limit rows
994+
cursor.execute(f"SELECT * FROM range({expected_rows})")
995+
arrow_table = cursor.fetchall_arrow()
996+
997+
# Check if all rows are returned (not limited by row_limit)
998+
assert (
999+
arrow_table.num_rows == expected_rows
1000+
), f"Expected {expected_rows} rows, got {arrow_table.num_rows}"
1001+
9461002

9471003
# use a RetrySuite to encapsulate these tests which we'll typically want to run together; however keep
9481004
# the 429/503 subsuites separate since they execute under different circumstances.

0 commit comments

Comments (0)