Skip to content

Commit a77ae43

Browse files
authoredApr 29, 2024
Merge pull request #25 from kbsriram/add-types-3
Update type annotations for itertools extras.
·
2.1.52.1.1
2 parents 750de7a + 3bd2dd9 commit a77ae43

File tree

3 files changed

+356
-25
lines changed

3 files changed

+356
-25
lines changed
 

‎adafruit_itertools/adafruit_itertools_extras.py

Lines changed: 68 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -41,26 +41,54 @@
4141

4242
import adafruit_itertools as it
4343

44+
try:
45+
from typing import (
46+
Any,
47+
Callable,
48+
Iterable,
49+
Iterator,
50+
List,
51+
Optional,
52+
Tuple,
53+
Type,
54+
TypeVar,
55+
Union,
56+
)
57+
from typing_extensions import TypeAlias
58+
59+
_T = TypeVar("_T")
60+
_N: TypeAlias = Union[int, float, complex]
61+
_Predicate: TypeAlias = Callable[[_T], bool]
62+
except ImportError:
63+
pass
64+
65+
4466
__version__ = "0.0.0+auto.0"
4567
__repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_Itertools.git"
4668

4769

48-
def all_equal(iterable):
70+
def all_equal(iterable: Iterable[Any]) -> bool:
4971
"""Returns True if all the elements are equal to each other.
5072
5173
:param iterable: source of values
5274
5375
"""
5476
g = it.groupby(iterable)
55-
next(g) # should succeed, value isn't relevant
5677
try:
57-
next(g) # should fail: only 1 group
78+
next(g) # value isn't relevant
79+
except StopIteration:
80+
# Empty iterable, return True to match cpython behavior.
81+
return True
82+
try:
83+
next(g)
84+
# more than one group, so we have different elements.
5885
return False
5986
except StopIteration:
87+
# Only one group - all elements must be equal.
6088
return True
6189

6290

63-
def dotproduct(vec1, vec2):
91+
def dotproduct(vec1: Iterable[_N], vec2: Iterable[_N]) -> _N:
6492
"""Compute the dot product of two vectors.
6593
6694
:param vec1: the first vector
@@ -71,7 +99,11 @@ def dotproduct(vec1, vec2):
7199
return sum(map(lambda x, y: x * y, vec1, vec2))
72100

73101

74-
def first_true(iterable, default=False, pred=None):
102+
def first_true(
103+
iterable: Iterable[_T],
104+
default: Union[bool, _T] = False,
105+
pred: Optional[_Predicate[_T]] = None,
106+
) -> Union[bool, _T]:
75107
"""Returns the first true value in the iterable.
76108
77109
If no true value is found, returns *default*
@@ -94,7 +126,7 @@ def first_true(iterable, default=False, pred=None):
94126
return default
95127

96128

97-
def flatten(iterable_of_iterables):
129+
def flatten(iterable_of_iterables: Iterable[Iterable[_T]]) -> Iterator[_T]:
98130
"""Flatten one level of nesting.
99131
100132
:param iterable_of_iterables: a sequence of iterables to flatten
@@ -104,7 +136,9 @@ def flatten(iterable_of_iterables):
104136
return it.chain_from_iterable(iterable_of_iterables)
105137

106138

107-
def grouper(iterable, n, fillvalue=None):
139+
def grouper(
140+
iterable: Iterable[_T], n: int, fillvalue: Optional[_T] = None
141+
) -> Iterator[Tuple[_T, ...]]:
108142
"""Collect data into fixed-length chunks or blocks.
109143
110144
:param iterable: source of values
@@ -118,7 +152,7 @@ def grouper(iterable, n, fillvalue=None):
118152
return it.zip_longest(*args, fillvalue=fillvalue)
119153

120154

121-
def iter_except(func, exception):
155+
def iter_except(func: Callable[[], _T], exception: Type[BaseException]) -> Iterator[_T]:
122156
"""Call a function repeatedly, yielding the results, until exception is raised.
123157
124158
Converts a call-until-exception interface to an iterator interface.
@@ -143,7 +177,7 @@ def iter_except(func, exception):
143177
pass
144178

145179

146-
def ncycles(iterable, n):
180+
def ncycles(iterable: Iterable[_T], n: int) -> Iterator[_T]:
147181
"""Returns the sequence elements a number of times.
148182
149183
:param iterable: the source of values
@@ -153,7 +187,7 @@ def ncycles(iterable, n):
153187
return it.chain_from_iterable(it.repeat(tuple(iterable), n))
154188

155189

156-
def nth(iterable, n, default=None):
190+
def nth(iterable: Iterable[_T], n: int, default: Optional[_T] = None) -> Optional[_T]:
157191
"""Returns the nth item or a default value.
158192
159193
:param iterable: the source of values
@@ -166,7 +200,7 @@ def nth(iterable, n, default=None):
166200
return default
167201

168202

169-
def padnone(iterable):
203+
def padnone(iterable: Iterable[_T]) -> Iterator[Optional[_T]]:
170204
"""Returns the sequence elements and then returns None indefinitely.
171205
172206
Useful for emulating the behavior of the built-in map() function.
@@ -177,13 +211,17 @@ def padnone(iterable):
177211
return it.chain(iterable, it.repeat(None))
178212

179213

180-
def pairwise(iterable):
181-
"""Pair up valuesin the iterable.
214+
def pairwise(iterable: Iterable[_T]) -> Iterator[Tuple[_T, _T]]:
215+
"""Return successive overlapping pairs from the iterable.
216+
217+
The number of tuples from the output will be one fewer than the
218+
number of values in the input. It will be empty if the input has
219+
fewer than two values.
182220
183221
:param iterable: source of values
184222
185223
"""
186-
# pairwise(range(11)) -> (1, 2), (3, 4), (5, 6), (7, 8), (9, 10)
224+
# pairwise(range(5)) -> (0, 1), (1, 2), (2, 3), (3, 4)
187225
a, b = it.tee(iterable)
188226
try:
189227
next(b)
@@ -192,7 +230,9 @@ def pairwise(iterable):
192230
return zip(a, b)
193231

194232

195-
def partition(pred, iterable):
233+
def partition(
234+
pred: _Predicate[_T], iterable: Iterable[_T]
235+
) -> Tuple[Iterator[_T], Iterator[_T]]:
196236
"""Use a predicate to partition entries into false entries and true entries.
197237
198238
:param pred: the predicate that divides the values
@@ -204,7 +244,7 @@ def partition(pred, iterable):
204244
return it.filterfalse(pred, t1), filter(pred, t2)
205245

206246

207-
def prepend(value, iterator):
247+
def prepend(value: _T, iterator: Iterable[_T]) -> Iterator[_T]:
208248
"""Prepend a single value in front of an iterator
209249
210250
:param value: the value to prepend
@@ -215,7 +255,7 @@ def prepend(value, iterator):
215255
return it.chain([value], iterator)
216256

217257

218-
def quantify(iterable, pred=bool):
258+
def quantify(iterable: Iterable[_T], pred: _Predicate[_T] = bool) -> int:
219259
"""Count how many times the predicate is true.
220260
221261
:param iterable: source of values
@@ -227,7 +267,9 @@ def quantify(iterable, pred=bool):
227267
return sum(map(pred, iterable))
228268

229269

230-
def repeatfunc(func, times=None, *args):
270+
def repeatfunc(
271+
func: Callable[..., _T], times: Optional[int] = None, *args: Any
272+
) -> Iterator[_T]:
231273
"""Repeat calls to func with specified arguments.
232274
233275
Example: repeatfunc(random.random)
@@ -242,7 +284,7 @@ def repeatfunc(func, times=None, *args):
242284
return it.starmap(func, it.repeat(args, times))
243285

244286

245-
def roundrobin(*iterables):
287+
def roundrobin(*iterables: Iterable[_T]) -> Iterator[_T]:
246288
"""Return an iterable created by repeatedly picking value from each
247289
argument in order.
248290
@@ -263,18 +305,19 @@ def roundrobin(*iterables):
263305
nexts = it.cycle(it.islice(nexts, num_active))
264306

265307

266-
def tabulate(function, start=0):
267-
"""Apply a function to a sequence of consecutive integers.
308+
def tabulate(function: Callable[[int], int], start: int = 0) -> Iterator[int]:
309+
"""Apply a function to a sequence of consecutive numbers.
268310
269-
:param function: the function of one integer argument
311+
:param function: the function of one numeric argument.
270312
:param start: optional value to start at (default is 0)
271313
272314
"""
273315
# take(5, tabulate(lambda x: x * x))) -> 0 1 4 9 16
274-
return map(function, it.count(start))
316+
counter: Iterator[int] = it.count(start) # type: ignore[assignment]
317+
return map(function, counter)
275318

276319

277-
def tail(n, iterable):
320+
def tail(n: int, iterable: Iterable[_T]) -> Iterator[_T]:
278321
"""Return an iterator over the last n items
279322
280323
:param n: how many values to return
@@ -294,7 +337,7 @@ def tail(n, iterable):
294337
return iter(buf)
295338

296339

297-
def take(n, iterable):
340+
def take(n: int, iterable: Iterable[_T]) -> List[_T]:
298341
"""Return first n items of the iterable as a list
299342
300343
:param n: how many values to take

‎optional_requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
# SPDX-FileCopyrightText: 2022 Alec Delaney, for Adafruit Industries
22
#
33
# SPDX-License-Identifier: Unlicense
4+
5+
# For comparison when running tests
6+
more-itertools

‎tests/test_itertools_extras.py

Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
# SPDX-FileCopyrightText: KB Sriram
2+
# SPDX-License-Identifier: MIT
3+
4+
from typing import (
5+
Callable,
6+
Iterator,
7+
Optional,
8+
Sequence,
9+
TypeVar,
10+
)
11+
from typing_extensions import TypeAlias
12+
13+
import more_itertools as itextras
14+
import pytest
15+
from adafruit_itertools import adafruit_itertools_extras as aextras
16+
17+
_K = TypeVar("_K")
18+
_T = TypeVar("_T")
19+
_S = TypeVar("_S")
20+
_Predicate: TypeAlias = Callable[[_T], bool]
21+
22+
23+
def _take(n: int, iterator: Iterator[_T]) -> Sequence[_T]:
24+
"""Extract the first n elements from a long/infinite iterator."""
25+
return [v for _, v in zip(range(n), iterator)]
26+
27+
28+
@pytest.mark.parametrize(
29+
"data",
30+
[
31+
"aaaa",
32+
"abcd",
33+
"a",
34+
"",
35+
(1, 2),
36+
(3, 3),
37+
("", False),
38+
(42, True),
39+
],
40+
)
41+
def test_all_equal(data: Sequence[_T]) -> None:
42+
assert itextras.all_equal(data) == aextras.all_equal(data)
43+
44+
45+
@pytest.mark.parametrize(
46+
("vec1", "vec2"),
47+
[
48+
([1, 2], [3, 4]),
49+
([], []),
50+
([1], [2, 3]),
51+
([4, 5], [6]),
52+
],
53+
)
54+
def test_dotproduct(vec1: Sequence[int], vec2: Sequence[int]) -> None:
55+
assert itextras.dotproduct(vec1, vec2) == aextras.dotproduct(vec1, vec2)
56+
57+
58+
@pytest.mark.parametrize(
59+
("seq", "dflt", "pred"),
60+
[
61+
([0, 2], 0, None),
62+
([], 10, None),
63+
([False], True, None),
64+
([1, 2], -1, lambda _: False),
65+
([0, 1], -1, lambda _: True),
66+
([], -1, lambda _: True),
67+
],
68+
)
69+
def test_first_true(
70+
seq: Sequence[_T], dflt: _T, pred: Optional[_Predicate[_T]]
71+
) -> None:
72+
assert itextras.first_true(seq, dflt, pred) == aextras.first_true(seq, dflt, pred)
73+
74+
75+
@pytest.mark.parametrize(
76+
("seq1", "seq2"),
77+
[
78+
("abc", "def"),
79+
("", "def"),
80+
("abc", ""),
81+
("", ""),
82+
],
83+
)
84+
def test_flatten(seq1: str, seq2: str) -> None:
85+
assert list(itextras.flatten(seq1 + seq2)) == list(aextras.flatten(seq1 + seq2))
86+
for repeat in range(3):
87+
assert list(itextras.flatten([seq1] * repeat)) == list(
88+
aextras.flatten([seq1] * repeat)
89+
)
90+
assert list(itextras.flatten([seq2] * repeat)) == list(
91+
aextras.flatten([seq2] * repeat)
92+
)
93+
94+
95+
@pytest.mark.parametrize(
96+
("seq", "count", "fill"),
97+
[
98+
("abc", 3, None),
99+
("abcd", 3, None),
100+
("abc", 3, "x"),
101+
("abcd", 3, "x"),
102+
("abc", 0, None),
103+
("", 3, "xy"),
104+
],
105+
)
106+
def test_grouper(seq: Sequence[str], count: int, fill: Optional[str]) -> None:
107+
assert list(itextras.grouper(seq, count, fillvalue=fill)) == list(
108+
aextras.grouper(seq, count, fillvalue=fill)
109+
)
110+
111+
112+
@pytest.mark.parametrize(
113+
("data"),
114+
[
115+
(1, 2, 3),
116+
(),
117+
],
118+
)
119+
def test_iter_except(data: Sequence[int]) -> None:
120+
assert list(itextras.iter_except(list(data).pop, IndexError)) == list(
121+
aextras.iter_except(list(data).pop, IndexError)
122+
)
123+
124+
125+
@pytest.mark.parametrize(
126+
("seq", "count"),
127+
[
128+
("abc", 4),
129+
("abc", 0),
130+
("", 4),
131+
],
132+
)
133+
def test_ncycles(seq: str, count: int) -> None:
134+
assert list(itextras.ncycles(seq, count)) == list(aextras.ncycles(seq, count))
135+
136+
137+
@pytest.mark.parametrize(
138+
("seq", "n", "dflt"),
139+
[
140+
("abc", 1, None),
141+
("abc", 10, None),
142+
("abc", 10, "x"),
143+
("", 0, None),
144+
],
145+
)
146+
def test_nth(seq: str, n: int, dflt: Optional[str]) -> None:
147+
assert itextras.nth(seq, n, dflt) == aextras.nth(seq, n, dflt)
148+
149+
150+
@pytest.mark.parametrize(
151+
("seq"),
152+
[
153+
"abc",
154+
"",
155+
],
156+
)
157+
def test_padnone(seq: str) -> None:
158+
assert _take(10, itextras.padnone(seq)) == _take(10, aextras.padnone(seq))
159+
160+
161+
@pytest.mark.parametrize(
162+
("seq"),
163+
[
164+
(),
165+
(1,),
166+
(1, 2),
167+
(1, 2, 3),
168+
(1, 2, 3, 4),
169+
],
170+
)
171+
def test_pairwise(seq: Sequence[int]) -> None:
172+
assert list(itextras.pairwise(seq)) == list(aextras.pairwise(seq))
173+
174+
175+
@pytest.mark.parametrize(
176+
("pred", "seq"),
177+
[
178+
(lambda x: x % 2, (0, 1, 2, 3)),
179+
(lambda x: x % 2, (0, 2)),
180+
(lambda x: x % 2, ()),
181+
],
182+
)
183+
def test_partition(pred: _Predicate[int], seq: Sequence[int]) -> None:
184+
# assert list(itextras.partition(pred, seq)) == list(aextras.partition(pred, seq))
185+
true1, false1 = itextras.partition(pred, seq)
186+
true2, false2 = aextras.partition(pred, seq)
187+
assert list(true1) == list(true2)
188+
assert list(false1) == list(false2)
189+
190+
191+
@pytest.mark.parametrize(
192+
("value", "seq"),
193+
[
194+
(1, (2, 3)),
195+
(1, ()),
196+
],
197+
)
198+
def test_prepend(value: int, seq: Sequence[int]) -> None:
199+
assert list(itextras.prepend(value, seq)) == list(aextras.prepend(value, seq))
200+
201+
202+
@pytest.mark.parametrize(
203+
("seq", "pred"),
204+
[
205+
((0, 1), lambda x: x % 2 == 0),
206+
((1, 1), lambda x: x % 2 == 0),
207+
((), lambda x: x % 2 == 0),
208+
],
209+
)
210+
def test_quantify(seq: Sequence[int], pred: _Predicate[int]) -> None:
211+
assert itextras.quantify(seq) == aextras.quantify(seq)
212+
assert itextras.quantify(seq, pred) == aextras.quantify(seq, pred)
213+
214+
215+
@pytest.mark.parametrize(
216+
("func", "times", "args"),
217+
[
218+
(lambda: 1, 5, []),
219+
(lambda: 1, 0, []),
220+
(lambda x: x + 1, 10, [3]),
221+
(lambda x, y: x + y, 10, [3, 4]),
222+
],
223+
)
224+
def test_repeatfunc(func: Callable, times: int, args: Sequence[int]) -> None:
225+
assert _take(5, itextras.repeatfunc(func, None, *args)) == _take(
226+
5, aextras.repeatfunc(func, None, *args)
227+
)
228+
assert list(itextras.repeatfunc(func, times, *args)) == list(
229+
aextras.repeatfunc(func, times, *args)
230+
)
231+
232+
233+
@pytest.mark.parametrize(
234+
("seq1", "seq2"),
235+
[
236+
("abc", "def"),
237+
("a", "bc"),
238+
("ab", "c"),
239+
("", "abc"),
240+
("", ""),
241+
],
242+
)
243+
def test_roundrobin(seq1: str, seq2: str) -> None:
244+
assert list(itextras.roundrobin(seq1)) == list(aextras.roundrobin(seq1))
245+
assert list(itextras.roundrobin(seq1, seq2)) == list(aextras.roundrobin(seq1, seq2))
246+
247+
248+
@pytest.mark.parametrize(
249+
("func", "start"),
250+
[
251+
(lambda x: 2 * x, 17),
252+
(lambda x: -x, -3),
253+
],
254+
)
255+
def test_tabulate(func: Callable[[int], int], start: int) -> None:
256+
assert _take(5, itextras.tabulate(func)) == _take(5, aextras.tabulate(func))
257+
assert _take(5, itextras.tabulate(func, start)) == _take(
258+
5, aextras.tabulate(func, start)
259+
)
260+
261+
262+
@pytest.mark.parametrize(
263+
("n", "seq"),
264+
[
265+
(3, "abcdefg"),
266+
(0, "abcdefg"),
267+
(10, "abcdefg"),
268+
(5, ""),
269+
],
270+
)
271+
def test_tail(n: int, seq: str) -> None:
272+
assert list(itextras.tail(n, seq)) == list(aextras.tail(n, seq))
273+
274+
275+
@pytest.mark.parametrize(
276+
("n", "seq"),
277+
[
278+
(3, "abcdefg"),
279+
(0, "abcdefg"),
280+
(10, "abcdefg"),
281+
(5, ""),
282+
],
283+
)
284+
def test_take(n: int, seq: str) -> None:
285+
assert list(itextras.take(n, seq)) == list(aextras.take(n, seq))

0 commit comments

Comments
 (0)
Please sign in to comment.