gh-113202: Add a strict option to itertools.batched() #113203

Merged
merged 5 commits into from
Dec 16, 2023
18 changes: 15 additions & 3 deletions Doc/library/itertools.rst
@@ -164,11 +164,14 @@ loops that truncate the stream.
       Added the optional *initial* parameter.
 
 
-.. function:: batched(iterable, n)
+.. function:: batched(iterable, n, *, strict=False)
 
    Batch data from the *iterable* into tuples of length *n*. The last
    batch may be shorter than *n*.
 
+   If *strict* is true, will raise a :exc:`ValueError` if the final
+   batch is shorter than *n*.
+
    Loops over the input iterable and accumulates data into tuples up to
    size *n*. The input is consumed lazily, just enough to fill a batch.
    The result is yielded as soon as the batch is full or when the input
@@ -190,16 +193,21 @@ loops that truncate the stream.
 
    Roughly equivalent to::
 
-      def batched(iterable, n):
+      def batched(iterable, n, *, strict=False):
           # batched('ABCDEFG', 3) --> ABC DEF G
           if n < 1:
               raise ValueError('n must be at least one')
           it = iter(iterable)
           while batch := tuple(islice(it, n)):
+              if strict and len(batch) != n:
+                  raise ValueError('batched(): incomplete batch')
               yield batch
 
    .. versionadded:: 3.12
 
+   .. versionchanged:: 3.13
+      Added the *strict* option.
+
 
 .. function:: chain(*iterables)

@@ -1039,7 +1047,7 @@ The following recipes have a more mathematical flavor:
    def reshape(matrix, cols):
        "Reshape a 2-D matrix to have a given number of columns."
        # reshape([(0, 1), (2, 3), (4, 5)], 3) --> (0, 1, 2), (3, 4, 5)
-       return batched(chain.from_iterable(matrix), cols)
+       return batched(chain.from_iterable(matrix), cols, strict=True)
 
    def transpose(matrix):
        "Swap the rows and columns of a 2-D matrix."
@@ -1270,6 +1278,10 @@ The following recipes have a more mathematical flavor:
     [(0, 1, 2), (3, 4, 5), (6, 7, 8), (9, 10, 11)]
     >>> list(reshape(M, 4))
     [(0, 1, 2, 3), (4, 5, 6, 7), (8, 9, 10, 11)]
+    >>> list(reshape(M, 5))
+    Traceback (most recent call last):
+    ...
+    ValueError: batched(): incomplete batch
     >>> list(reshape(M, 6))
     [(0, 1, 2, 3, 4, 5), (6, 7, 8, 9, 10, 11)]
     >>> list(reshape(M, 12))
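
The `reshape` doctests above can be tried outside the docs build with a sketch like the following. It makes two assumptions: `M` holds the values 0 through 11 in three rows of four (only its flattened contents can be inferred from the outputs shown), and on interpreters older than 3.13 a fallback copied from the "roughly equivalent" code in this diff stands in for `itertools.batched`. It is an illustration, not part of the change.

```python
import sys
from itertools import chain, islice

if sys.version_info >= (3, 13):
    from itertools import batched
else:
    def batched(iterable, n, *, strict=False):
        # Fallback taken from the "roughly equivalent" code in the docs diff.
        if n < 1:
            raise ValueError('n must be at least one')
        it = iter(iterable)
        while batch := tuple(islice(it, n)):
            if strict and len(batch) != n:
                raise ValueError('batched(): incomplete batch')
            yield batch


def reshape(matrix, cols):
    "Reshape a 2-D matrix to have a given number of columns."
    return batched(chain.from_iterable(matrix), cols, strict=True)


# Assumed shape for M; the doctests only reveal that it flattens to 0..11.
M = [(0, 1, 2, 3), (4, 5, 6, 7), (8, 9, 10, 11)]

print(list(reshape(M, 6)))

# 12 items do not divide into rows of 5. The error only surfaces when the
# short final row is reached, after the complete rows were already yielded.
rows = []
try:
    for row in reshape(M, 5):
        rows.append(row)
except ValueError as exc:
    print(f'{len(rows)} full rows, then: {exc}')
```

Because the check is lazy, a consumer that acts on each row as it arrives will have processed the complete rows before the `ValueError` is raised.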
4 changes: 4 additions & 0 deletions Lib/test/test_itertools.py
@@ -187,7 +187,11 @@ def test_batched(self):
                          [('A', 'B'), ('C', 'D'), ('E', 'F'), ('G',)])
         self.assertEqual(list(batched('ABCDEFG', 1)),
                          [('A',), ('B',), ('C',), ('D',), ('E',), ('F',), ('G',)])
+        self.assertEqual(list(batched('ABCDEF', 2, strict=True)),
+                         [('A', 'B'), ('C', 'D'), ('E', 'F')])
 
+        with self.assertRaises(ValueError):  # Incomplete batch when strict
+            list(batched('ABCDEFG', 3, strict=True))
         with self.assertRaises(TypeError):  # Too few arguments
             list(batched('ABCDEFG'))
         with self.assertRaises(TypeError):
@@ -0,0 +1 @@
+Add a ``strict`` option to ``batched()`` in the ``itertools`` module.
32 changes: 23 additions & 9 deletions Modules/clinic/itertoolsmodule.c.h

Some generated files are not rendered by default.

29 changes: 17 additions & 12 deletions Modules/itertoolsmodule.c
@@ -105,27 +105,21 @@ class itertools.pairwise "pairwiseobject *" "clinic_state()->pairwise_type"
 
 /* batched object ************************************************************/
 
-/* Note: The built-in zip() function includes a "strict" argument
-   that was needed because that function would silently truncate data,
-   and there was no easy way for a user to detect the data loss.
-   The same reasoning does not apply to batched() which never drops data.
-   Instead, batched() produces a shorter tuple which can be handled
-   as the user sees fit. If requested, it would be reasonable to add
-   "fillvalue" support which had demonstrated value in zip_longest().
-   For now, the API is kept simple and clean.
- */
-
 typedef struct {
     PyObject_HEAD
     PyObject *it;
     Py_ssize_t batch_size;
+    bool strict;
 } batchedobject;
 
 /*[clinic input]
 @classmethod
 itertools.batched.__new__ as batched_new
     iterable: object
     n: Py_ssize_t
+    *
+    strict: bool = False
+
 Batch data into tuples of length n. The last batch may be shorter than n.
 
 Loops over the input iterable and accumulates data into tuples
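
The note deleted in this hunk contrasted `zip()`, which silently truncates, with `batched()`, which never drops data. A minimal sketch of that contrast follows. The `batches` helper is hypothetical and only imitates non-strict batching so the example runs on any Python 3.8+; the sample inputs are made up for illustration.

```python
from itertools import chain, islice, zip_longest


def batches(iterable, n):
    "Hypothetical helper imitating non-strict batching."
    it = iter(iterable)
    while batch := tuple(islice(it, n)):
        yield batch


letters = 'ABCDEFG'

# zip() stops at the shortest input, so 'D' through 'G' disappear silently.
zipped = list(zip(letters, [1, 2, 3]))
print(zipped)    # [('A', 1), ('B', 2), ('C', 3)]

# Non-strict batching keeps every item; the tail is just a shorter tuple.
grouped = list(batches(letters, 3))
print(grouped)   # [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]
print(''.join(chain.from_iterable(grouped)) == letters)   # True

# zip_longest() pads instead of truncating, via its fillvalue argument.
padded = list(zip_longest(letters[:2], [1, 2, 3], fillvalue='-'))
print(padded)    # [('A', 1), ('B', 2), ('-', 3)]
```

With `strict` added, a caller who would otherwise have to inspect the length of the last tuple can ask for a `ValueError` instead.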
@@ -140,11 +134,15 @@ or when the input iterable is exhausted.
     ('D', 'E', 'F')
     ('G',)
 
+If "strict" is True, raises a ValueError if the final batch is shorter
+than n.
+
 [clinic start generated code]*/
 
 static PyObject *
-batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n)
-/*[clinic end generated code: output=7ebc954d655371b6 input=ffd70726927c5129]*/
+batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n,
+                 int strict)
+/*[clinic end generated code: output=c6de11b061529d3e input=7814b47e222f5467]*/
 {
     PyObject *it;
     batchedobject *bo;
@@ -170,6 +168,7 @@ batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n)
     }
     bo->batch_size = n;
     bo->it = it;
+    bo->strict = (bool) strict;
     return (PyObject *)bo;
 }

@@ -233,6 +232,12 @@ batched_next(batchedobject *bo)
         Py_DECREF(result);
         return NULL;
     }
+    if (bo->strict) {
+        Py_CLEAR(bo->it);
+        Py_DECREF(result);
+        PyErr_SetString(PyExc_ValueError, "batched(): incomplete batch");
+        return NULL;
+    }
     _PyTuple_Resize(&result, i);
     return result;
 }
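
For readers more comfortable in Python, the strict branch added to `batched_next` can be approximated by an iterator class. This is a sketch under assumptions, not a translation of the C code: `BatchedSketch` and its attribute names are hypothetical, and the empty-read branch is a guess at logic that lies outside the hunk. The point it tries to capture is that the input iterator is dropped before the `ValueError` is raised, so the object behaves as exhausted afterwards.

```python
from itertools import islice


class BatchedSketch:
    "Hypothetical Python approximation of the batched object."

    def __init__(self, iterable, n, *, strict=False):
        if n < 1:
            raise ValueError('n must be at least one')
        self._it = iter(iterable)
        self._n = n
        self._strict = strict

    def __iter__(self):
        return self

    def __next__(self):
        if self._it is None:
            # Input was dropped earlier: stay exhausted.
            raise StopIteration
        batch = tuple(islice(self._it, self._n))
        if len(batch) == self._n:
            return batch
        if not batch:
            # Assumed behaviour outside the hunk: nothing left to yield.
            self._it = None
            raise StopIteration
        if self._strict:
            # Counterpart of the added branch: drop the input, then raise.
            self._it = None
            raise ValueError('batched(): incomplete batch')
        # Non-strict: hand back the short final batch.
        return batch


it = BatchedSketch('ABCDEFG', 3, strict=True)
print(next(it))
print(next(it))
try:
    next(it)
except ValueError as exc:
    print('error:', exc)
print(list(it))   # nothing further is produced after the error
```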