Skip to content

Commit 1583c40

Browse files
authored
gh-113202: Add a strict option to itertools.batched() (gh-113203)
1 parent fe479fb commit 1583c40

File tree

5 files changed

+60
-24
lines changed

5 files changed

+60
-24
lines changed

Doc/library/itertools.rst

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,11 +164,14 @@ loops that truncate the stream.
164164
Added the optional *initial* parameter.
165165

166166

167-
.. function:: batched(iterable, n)
167+
.. function:: batched(iterable, n, *, strict=False)
168168

169169
Batch data from the *iterable* into tuples of length *n*. The last
170170
batch may be shorter than *n*.
171171

172+
If *strict* is true, will raise a :exc:`ValueError` if the final
173+
batch is shorter than *n*.
174+
172175
Loops over the input iterable and accumulates data into tuples up to
173176
size *n*. The input is consumed lazily, just enough to fill a batch.
174177
The result is yielded as soon as the batch is full or when the input
@@ -190,16 +193,21 @@ loops that truncate the stream.
190193

191194
Roughly equivalent to::
192195

193-
def batched(iterable, n):
196+
def batched(iterable, n, *, strict=False):
194197
# batched('ABCDEFG', 3) --> ABC DEF G
195198
if n < 1:
196199
raise ValueError('n must be at least one')
197200
it = iter(iterable)
198201
while batch := tuple(islice(it, n)):
202+
if strict and len(batch) != n:
203+
raise ValueError('batched(): incomplete batch')
199204
yield batch
200205

201206
.. versionadded:: 3.12
202207

208+
.. versionchanged:: 3.13
209+
Added the *strict* option.
210+
203211

204212
.. function:: chain(*iterables)
205213

@@ -1039,7 +1047,7 @@ The following recipes have a more mathematical flavor:
10391047
def reshape(matrix, cols):
10401048
"Reshape a 2-D matrix to have a given number of columns."
10411049
# reshape([(0, 1), (2, 3), (4, 5)], 3) --> (0, 1, 2), (3, 4, 5)
1042-
return batched(chain.from_iterable(matrix), cols)
1050+
return batched(chain.from_iterable(matrix), cols, strict=True)
10431051

10441052
def transpose(matrix):
10451053
"Swap the rows and columns of a 2-D matrix."
@@ -1270,6 +1278,10 @@ The following recipes have a more mathematical flavor:
12701278
[(0, 1, 2), (3, 4, 5), (6, 7, 8), (9, 10, 11)]
12711279
>>> list(reshape(M, 4))
12721280
[(0, 1, 2, 3), (4, 5, 6, 7), (8, 9, 10, 11)]
1281+
>>> list(reshape(M, 5))
1282+
Traceback (most recent call last):
1283+
...
1284+
ValueError: batched(): incomplete batch
12731285
>>> list(reshape(M, 6))
12741286
[(0, 1, 2, 3, 4, 5), (6, 7, 8, 9, 10, 11)]
12751287
>>> list(reshape(M, 12))

Lib/test/test_itertools.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,11 @@ def test_batched(self):
187187
[('A', 'B'), ('C', 'D'), ('E', 'F'), ('G',)])
188188
self.assertEqual(list(batched('ABCDEFG', 1)),
189189
[('A',), ('B',), ('C',), ('D',), ('E',), ('F',), ('G',)])
190+
self.assertEqual(list(batched('ABCDEF', 2, strict=True)),
191+
[('A', 'B'), ('C', 'D'), ('E', 'F')])
190192

193+
with self.assertRaises(ValueError): # Incomplete batch when strict
194+
list(batched('ABCDEFG', 3, strict=True))
191195
with self.assertRaises(TypeError): # Too few arguments
192196
list(batched('ABCDEFG'))
193197
with self.assertRaises(TypeError):
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add a ``strict`` option to ``batched()`` in the ``itertools`` module.

Modules/clinic/itertoolsmodule.c.h

Lines changed: 23 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Modules/itertoolsmodule.c

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -105,27 +105,21 @@ class itertools.pairwise "pairwiseobject *" "clinic_state()->pairwise_type"
105105

106106
/* batched object ************************************************************/
107107

108-
/* Note: The built-in zip() function includes a "strict" argument
109-
that was needed because that function would silently truncate data,
110-
and there was no easy way for a user to detect the data loss.
111-
The same reasoning does not apply to batched() which never drops data.
112-
Instead, batched() produces a shorter tuple which can be handled
113-
as the user sees fit. If requested, it would be reasonable to add
114-
"fillvalue" support which had demonstrated value in zip_longest().
115-
For now, the API is kept simple and clean.
116-
*/
117-
118108
typedef struct {
119109
PyObject_HEAD
120110
PyObject *it;
121111
Py_ssize_t batch_size;
112+
bool strict;
122113
} batchedobject;
123114

124115
/*[clinic input]
125116
@classmethod
126117
itertools.batched.__new__ as batched_new
127118
iterable: object
128119
n: Py_ssize_t
120+
*
121+
strict: bool = False
122+
129123
Batch data into tuples of length n. The last batch may be shorter than n.
130124
131125
Loops over the input iterable and accumulates data into tuples
@@ -140,11 +134,15 @@ or when the input iterable is exhausted.
140134
('D', 'E', 'F')
141135
('G',)
142136
137+
If "strict" is True, raises a ValueError if the final batch is shorter
138+
than n.
139+
143140
[clinic start generated code]*/
144141

145142
static PyObject *
146-
batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n)
147-
/*[clinic end generated code: output=7ebc954d655371b6 input=ffd70726927c5129]*/
143+
batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n,
144+
int strict)
145+
/*[clinic end generated code: output=c6de11b061529d3e input=7814b47e222f5467]*/
148146
{
149147
PyObject *it;
150148
batchedobject *bo;
@@ -170,6 +168,7 @@ batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n)
170168
}
171169
bo->batch_size = n;
172170
bo->it = it;
171+
bo->strict = (bool) strict;
173172
return (PyObject *)bo;
174173
}
175174

@@ -233,6 +232,12 @@ batched_next(batchedobject *bo)
233232
Py_DECREF(result);
234233
return NULL;
235234
}
235+
if (bo->strict) {
236+
Py_CLEAR(bo->it);
237+
Py_DECREF(result);
238+
PyErr_SetString(PyExc_ValueError, "batched(): incomplete batch");
239+
return NULL;
240+
}
236241
_PyTuple_Resize(&result, i);
237242
return result;
238243
}

0 commit comments

Comments
 (0)