
Commit f05b1ae

[bug] fix reading with to_arrow_batch_reader and limit (#1042)

* fix project_batches with limit
* add test
* lint + readability

1 parent 2e73a41 · commit f05b1ae
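For context, a minimal sketch of the scenario this commit addresses, assuming a configured catalog and the "default.test_limit" table that the new test below relies on (the catalog and table names are illustrative, not prescribed by this commit). Before the fix, to_arrow_batch_reader() could return more rows than the requested limit when the scan spanned multiple data files, whereas to_arrow() already honored the limit:

    from pyiceberg.catalog import load_catalog

    # Assumes a catalog named "default" is configured (e.g. via .pyiceberg.yaml).
    catalog = load_catalog("default")
    table = catalog.load_table("default.test_limit")

    # With this fix the batch reader stops once the limit is satisfied,
    # even when the matching rows are spread across several data files.
    reader = table.scan(selected_fields=("idx",), limit=1).to_arrow_batch_reader()
    assert len(reader.read_all()) == 1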

2 files changed: 54 additions and 2 deletions


pyiceberg/io/pyarrow.py (6 additions, 2 deletions)

@@ -1455,6 +1455,9 @@ def project_batches(
     total_row_count = 0
 
     for task in tasks:
+        # stop early if limit is satisfied
+        if limit is not None and total_row_count >= limit:
+            break
         batches = _task_to_record_batches(
             fs,
             task,
@@ -1468,9 +1471,10 @@ def project_batches(
         )
         for batch in batches:
             if limit is not None:
-                if total_row_count + len(batch) >= limit:
-                    yield batch.slice(0, limit - total_row_count)
+                if total_row_count >= limit:
                     break
+                elif total_row_count + len(batch) >= limit:
+                    batch = batch.slice(0, limit - total_row_count)
             yield batch
             total_row_count += len(batch)
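Read in isolation, the change makes project_batches stop scanning further tasks once the limit is satisfied, and it trims the final batch and counts it through the same yield path so total_row_count stays accurate. A self-contained sketch of that logic over plain pyarrow record batches (take_limited and its arguments are illustrative names, not part of the library, and the limit=None case is omitted):

    import pyarrow as pa

    def take_limited(batches, limit):
        """Yield record batches until `limit` rows have been produced."""
        total_row_count = 0
        for batch in batches:
            if total_row_count >= limit:
                break  # limit already satisfied, stop early
            if total_row_count + len(batch) >= limit:
                # trim the last batch so exactly `limit` rows are yielded in total
                batch = batch.slice(0, limit - total_row_count)
            yield batch
            total_row_count += len(batch)

    # Three 4-row batches with a limit of 5: expect one full batch plus a 1-row slice.
    batches = [pa.RecordBatch.from_pydict({"idx": list(range(4))}) for _ in range(3)]
    assert sum(len(b) for b in take_limited(batches, limit=5)) == 5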

tests/integration/test_reads.py (48 additions, 0 deletions)

@@ -244,6 +244,54 @@ def test_pyarrow_limit(catalog: Catalog) -> None:
     full_result = table_test_limit.scan(selected_fields=("idx",), limit=999).to_arrow()
     assert len(full_result) == 10
 
+    # test `to_arrow_batch_reader`
+    limited_result = table_test_limit.scan(selected_fields=("idx",), limit=1).to_arrow_batch_reader().read_all()
+    assert len(limited_result) == 1
+
+    empty_result = table_test_limit.scan(selected_fields=("idx",), limit=0).to_arrow_batch_reader().read_all()
+    assert len(empty_result) == 0
+
+    full_result = table_test_limit.scan(selected_fields=("idx",), limit=999).to_arrow_batch_reader().read_all()
+    assert len(full_result) == 10
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
+def test_pyarrow_limit_with_multiple_files(catalog: Catalog) -> None:
+    table_name = "default.test_pyarrow_limit_with_multiple_files"
+    try:
+        catalog.drop_table(table_name)
+    except NoSuchTableError:
+        pass
+    reference_table = catalog.load_table("default.test_limit")
+    data = reference_table.scan().to_arrow()
+    table_test_limit = catalog.create_table(table_name, schema=reference_table.schema())
+
+    n_files = 2
+    for _ in range(n_files):
+        table_test_limit.append(data)
+    assert len(table_test_limit.inspect.files()) == n_files
+
+    # test with multiple files
+    limited_result = table_test_limit.scan(selected_fields=("idx",), limit=1).to_arrow()
+    assert len(limited_result) == 1
+
+    empty_result = table_test_limit.scan(selected_fields=("idx",), limit=0).to_arrow()
+    assert len(empty_result) == 0
+
+    full_result = table_test_limit.scan(selected_fields=("idx",), limit=999).to_arrow()
+    assert len(full_result) == 10 * n_files
+
+    # test `to_arrow_batch_reader`
+    limited_result = table_test_limit.scan(selected_fields=("idx",), limit=1).to_arrow_batch_reader().read_all()
+    assert len(limited_result) == 1
+
+    empty_result = table_test_limit.scan(selected_fields=("idx",), limit=0).to_arrow_batch_reader().read_all()
+    assert len(empty_result) == 0
+
+    full_result = table_test_limit.scan(selected_fields=("idx",), limit=999).to_arrow_batch_reader().read_all()
+    assert len(full_result) == 10 * n_files
+
 
 @pytest.mark.integration
 @pytest.mark.filterwarnings("ignore")
