Skip to content

Commit 3adaa1a

Browse files
iurytrotterdylan
iury
authored andcommitted
make itertools.chain be a class and from_iterable a classmethod of chain
1 parent d9af156 commit 3adaa1a

File tree

2 files changed

+26
-10
lines changed

2 files changed

+26
-10
lines changed

lib/itertools.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,20 @@
1717
import _collections
1818
import sys
1919

20-
def chain(*iterables):
21-
for it in iterables:
22-
for element in it:
23-
yield element
20+
class chain(object):
21+
22+
def from_iterable(cls, iterable):
23+
return cls(*iterable)
24+
25+
from_iterable = classmethod(from_iterable)
26+
27+
def __init__(self, *iterables):
28+
self.iterables = iterables
29+
30+
def __iter__(self):
31+
for it in self.iterables:
32+
for element in it:
33+
yield element
2434

2535
def compress(data, selectors):
2636
return (d for d,s in izip(data, selectors) if s)
@@ -49,11 +59,6 @@ def dropwhile(predicate, iterable):
4959
for x in iterable:
5060
yield x
5161

52-
def from_iterable(iterables):
53-
for it in iterables:
54-
for element in it:
55-
yield element
56-
5762
def ifilter(predicate, iterable):
5863
if predicate is None:
5964
predicate = bool

lib/itertools_test.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,17 @@ def TestDropwhile():
4646
got = tuple(itertools.dropwhile(*args))
4747
assert got == want, 'tuple(dropwhile%s) == %s, want %s' % (args, got, want)
4848

49+
def TestChain():
50+
r = range(10)
51+
cases = [
52+
([r], tuple(r)),
53+
([r, r], tuple(r) + tuple(r)),
54+
([], ())
55+
]
56+
for args, want in cases:
57+
got = tuple(itertools.chain(*args))
58+
assert got == want, 'tuple(chain%s) == %s, want %s' % (args, got, want)
59+
4960
def TestFromIterable():
5061
r = range(10)
5162
cases = [
@@ -54,7 +65,7 @@ def TestFromIterable():
5465
([], ())
5566
]
5667
for args, want in cases:
57-
got = tuple(itertools.from_iterable(args))
68+
got = tuple(itertools.chain.from_iterable(args))
5869
assert got == want, 'tuple(from_iterable%s) == %s, want %s' % (args, got, want)
5970

6071
def TestIFilter():

0 commit comments

Comments
 (0)