Skip to content

Commit f8da93a

Browse files
authored
Merge pull request #122 from jakirkham/add_visitor_patterns
Add visitor pattern methods
2 parents d4e01f4 + 143954a commit f8da93a

File tree

3 files changed

+248
-0
lines changed

3 files changed

+248
-0
lines changed

docs/api/hierarchy.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ Groups (``zarr.hierarchy``)
1515
.. automethod:: groups
1616
.. automethod:: array_keys
1717
.. automethod:: arrays
18+
.. automethod:: visit
19+
.. automethod:: visitkeys
20+
.. automethod:: visitvalues
21+
.. automethod:: visititems
1822
.. automethod:: create_group
1923
.. automethod:: require_group
2024
.. automethod:: create_groups

zarr/hierarchy.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# -*- coding: utf-8 -*-
22
from __future__ import absolute_import, print_function, division
33
from collections import MutableMapping
4+
from itertools import islice
45

56

67
import numpy as np
@@ -55,6 +56,10 @@ class Group(MutableMapping):
5556
groups
5657
array_keys
5758
arrays
59+
visit
60+
visitkeys
61+
visitvalues
62+
visititems
5863
create_group
5964
require_group
6065
create_groups
@@ -414,6 +419,129 @@ def arrays(self):
414419
chunk_store=self._chunk_store,
415420
synchronizer=self._synchronizer)
416421

422+
def visitvalues(self, func):
423+
"""Run ``func`` on each object.
424+
425+
Note: If ``func`` returns ``None`` (or doesn't return),
426+
iteration continues. However, if ``func`` returns
427+
anything else, it ceases and returns that value.
428+
429+
Examples
430+
--------
431+
>>> import zarr
432+
>>> g1 = zarr.group()
433+
>>> g2 = g1.create_group('foo')
434+
>>> g3 = g1.create_group('bar')
435+
>>> g4 = g3.create_group('baz')
436+
>>> g5 = g3.create_group('quux')
437+
>>> def print_visitor(obj):
438+
... print(obj)
439+
>>> g1.visitvalues(print_visitor)
440+
Group(/bar, 2)
441+
groups: 2; baz, quux
442+
store: DictStore
443+
Group(/bar/baz, 0)
444+
store: DictStore
445+
Group(/bar/quux, 0)
446+
store: DictStore
447+
Group(/foo, 0)
448+
store: DictStore
449+
>>> g3.visitvalues(print_visitor)
450+
Group(/bar/baz, 0)
451+
store: DictStore
452+
Group(/bar/quux, 0)
453+
store: DictStore
454+
455+
"""
456+
457+
def _visit(obj):
458+
yield obj
459+
460+
keys = sorted(getattr(obj, "keys", lambda : [])())
461+
for each_key in keys:
462+
for each_obj in _visit(obj[each_key]):
463+
yield each_obj
464+
465+
for each_obj in islice(_visit(self), 1, None):
466+
value = func(each_obj)
467+
if value is not None:
468+
return value
469+
470+
def visit(self, func):
471+
"""Run ``func`` on each object's path.
472+
473+
Note: If ``func`` returns ``None`` (or doesn't return),
474+
iteration continues. However, if ``func`` returns
475+
anything else, it ceases and returns that value.
476+
477+
Examples
478+
--------
479+
>>> import zarr
480+
>>> g1 = zarr.group()
481+
>>> g2 = g1.create_group('foo')
482+
>>> g3 = g1.create_group('bar')
483+
>>> g4 = g3.create_group('baz')
484+
>>> g5 = g3.create_group('quux')
485+
>>> def print_visitor(name):
486+
... print(name)
487+
>>> g1.visit(print_visitor)
488+
bar
489+
bar/baz
490+
bar/quux
491+
foo
492+
>>> g3.visit(print_visitor)
493+
baz
494+
quux
495+
496+
"""
497+
498+
base_len = len(self.name)
499+
return self.visitvalues(lambda o: func(o.name[base_len:].lstrip("/")))
500+
501+
def visitkeys(self, func):
502+
"""An alias for :py:meth:`~Group.visit`.
503+
"""
504+
505+
return self.visit(func)
506+
507+
def visititems(self, func):
508+
"""Run ``func`` on each object's path and the object itself.
509+
510+
Note: If ``func`` returns ``None`` (or doesn't return),
511+
iteration continues. However, if ``func`` returns
512+
anything else, it ceases and returns that value.
513+
514+
Examples
515+
--------
516+
>>> import zarr
517+
>>> g1 = zarr.group()
518+
>>> g2 = g1.create_group('foo')
519+
>>> g3 = g1.create_group('bar')
520+
>>> g4 = g3.create_group('baz')
521+
>>> g5 = g3.create_group('quux')
522+
>>> def print_visitor(name, obj):
523+
... print((name, obj))
524+
>>> g1.visititems(print_visitor)
525+
('bar', Group(/bar, 2)
526+
groups: 2; baz, quux
527+
store: DictStore)
528+
('bar/baz', Group(/bar/baz, 0)
529+
store: DictStore)
530+
('bar/quux', Group(/bar/quux, 0)
531+
store: DictStore)
532+
('foo', Group(/foo, 0)
533+
store: DictStore)
534+
>>> g3.visititems(print_visitor)
535+
('baz', Group(/bar/baz, 0)
536+
store: DictStore)
537+
('quux', Group(/bar/quux, 0)
538+
store: DictStore)
539+
540+
"""
541+
542+
base_len = len(self.name)
543+
return self.visitvalues(lambda o: func(o.name[base_len:].lstrip("/"), o))
544+
417545
def _write_op(self, f, *args, **kwargs):
418546

419547
# guard condition

zarr/tests/test_hierarchy.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,122 @@ def test_getitem_contains_iterators(self):
473473
eq('baz', arrays[0][0])
474474
eq(g1['foo']['baz'], arrays[0][1])
475475

476+
# visitor collection tests
477+
items = []
478+
479+
def visitor2(obj):
480+
items.append(obj.path)
481+
482+
def visitor3(name, obj=None):
483+
items.append(name)
484+
485+
def visitor4(name, obj):
486+
items.append((name, obj))
487+
488+
del items[:]
489+
g1.visitvalues(visitor2)
490+
eq([
491+
"a",
492+
"a/b",
493+
"a/b/c",
494+
"foo",
495+
"foo/bar",
496+
"foo/baz",
497+
], items)
498+
499+
del items[:]
500+
g1["foo"].visitvalues(visitor2)
501+
eq([
502+
"foo/bar",
503+
"foo/baz",
504+
], items)
505+
506+
del items[:]
507+
g1.visit(visitor3)
508+
eq([
509+
"a",
510+
"a/b",
511+
"a/b/c",
512+
"foo",
513+
"foo/bar",
514+
"foo/baz",
515+
], items)
516+
517+
del items[:]
518+
g1["foo"].visit(visitor3)
519+
eq([
520+
"bar",
521+
"baz",
522+
], items)
523+
524+
del items[:]
525+
g1.visitkeys(visitor3)
526+
eq([
527+
"a",
528+
"a/b",
529+
"a/b/c",
530+
"foo",
531+
"foo/bar",
532+
"foo/baz",
533+
], items)
534+
535+
del items[:]
536+
g1["foo"].visitkeys(visitor3)
537+
eq([
538+
"bar",
539+
"baz",
540+
], items)
541+
542+
del items[:]
543+
g1.visititems(visitor3)
544+
eq([
545+
"a",
546+
"a/b",
547+
"a/b/c",
548+
"foo",
549+
"foo/bar",
550+
"foo/baz",
551+
], items)
552+
553+
del items[:]
554+
g1["foo"].visititems(visitor3)
555+
eq([
556+
"bar",
557+
"baz",
558+
], items)
559+
560+
del items[:]
561+
g1.visititems(visitor4)
562+
for n, o in items:
563+
eq(g1[n], o)
564+
565+
del items[:]
566+
g1["foo"].visititems(visitor4)
567+
for n, o in items:
568+
eq(g1["foo"][n], o)
569+
570+
# visitor filter tests
571+
def visitor0(val, *args):
572+
name = getattr(val, "path", val)
573+
574+
if name == "a/b/c/d":
575+
return True # pragma: no cover
576+
577+
def visitor1(val, *args):
578+
name = getattr(val, "path", val)
579+
580+
if name == "a/b/c":
581+
return True # pragma: no cover
582+
583+
eq(None, g1.visit(visitor0))
584+
eq(None, g1.visitkeys(visitor0))
585+
eq(None, g1.visitvalues(visitor0))
586+
eq(None, g1.visititems(visitor0))
587+
eq(True, g1.visit(visitor1))
588+
eq(True, g1.visitkeys(visitor1))
589+
eq(True, g1.visitvalues(visitor1))
590+
eq(True, g1.visititems(visitor1))
591+
476592
def test_empty_getitem_contains_iterators(self):
477593
# setup
478594
g = self.create_group()

0 commit comments

Comments
 (0)