Skip to content

Commit f1d4f6f

Browse files
committed
add support for fields with selection, resolves #112
1 parent 4cdcced commit f1d4f6f

File tree

3 files changed

+337
-200
lines changed

3 files changed

+337
-200
lines changed

zarr/core.py

Lines changed: 92 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,15 @@
1616
from zarr.compat import reduce
1717
from zarr.codecs import AsType, get_codec
1818
from zarr.indexing import OIndex, OrthogonalIndexer, BasicIndexer, VIndex, CoordinateIndexer, \
19-
MaskIndexer
19+
MaskIndexer, check_fields, pop_fields, ensure_tuple
20+
21+
22+
def is_scalar(value, dtype):
23+
if np.isscalar(value):
24+
return True
25+
if isinstance(value, tuple) and dtype.names and len(value) == len(dtype.names):
26+
return True
27+
return False
2028

2129

2230
class Array(object):
@@ -460,19 +468,10 @@ def __getitem__(self, selection):
460468
461469
"""
462470

463-
if len(self._shape) == 0:
464-
return self._get_basic_selection_zd(selection)
465-
466-
elif len(self._shape) == 1:
467-
# safe to do "fancy" indexing, no ambiguity
468-
return self.get_orthogonal_selection(selection)
469-
470-
else:
471-
# "fancy" indexing can be ambiguous/hard to understand for multidimensional arrays,
472-
# force people to go through explicit methods
473-
return self.get_basic_selection(selection)
471+
fields, selection = pop_fields(selection)
472+
return self.get_basic_selection(selection, fields=fields)
474473

475-
def get_basic_selection(self, selection, out=None):
474+
def get_basic_selection(self, selection, out=None, fields=None):
476475
"""TODO"""
477476

478477
# refresh metadata
@@ -481,15 +480,16 @@ def get_basic_selection(self, selection, out=None):
481480

482481
# handle zero-dimensional arrays
483482
if self._shape == ():
484-
return self._get_basic_selection_zd(selection, out=out)
483+
return self._get_basic_selection_zd(selection=selection, out=out, fields=fields)
485484
else:
486-
return self._get_basic_selection_nd(selection, out=out)
485+
return self._get_basic_selection_nd(selection=selection, out=out, fields=fields)
487486

488-
def _get_basic_selection_zd(self, selection, out=None):
487+
def _get_basic_selection_zd(self, selection, out=None, fields=None):
489488
# special case basic selection for zero-dimensional array
490489

491490
# check selection is valid
492-
if selection not in ((), Ellipsis):
491+
selection = ensure_tuple(selection)
492+
if selection not in ((), (Ellipsis,)):
493493
raise IndexError('too many indices for array')
494494

495495
try:
@@ -514,17 +514,21 @@ def _get_basic_selection_zd(self, selection, out=None):
514514
else:
515515
out[selection] = chunk[selection]
516516

517+
# handle fields
518+
if fields:
519+
out = out[fields]
520+
517521
return out
518522

519-
def _get_basic_selection_nd(self, selection, out=None):
523+
def _get_basic_selection_nd(self, selection, out=None, fields=None):
520524
# implementation of basic selection for array with at least one dimension
521525

522526
# setup indexer
523527
indexer = BasicIndexer(selection, self)
524528

525-
return self._get_selection(indexer, out=out)
529+
return self._get_selection(indexer=indexer, out=out, fields=fields)
526530

527-
def get_orthogonal_selection(self, selection, out=None):
531+
def get_orthogonal_selection(self, selection, out=None, fields=None):
528532
"""TODO"""
529533

530534
# refresh metadata
@@ -534,9 +538,9 @@ def get_orthogonal_selection(self, selection, out=None):
534538
# setup indexer
535539
indexer = OrthogonalIndexer(selection, self)
536540

537-
return self._get_selection(indexer, out=out)
541+
return self._get_selection(indexer=indexer, out=out, fields=fields)
538542

539-
def get_coordinate_selection(self, selection, out=None):
543+
def get_coordinate_selection(self, selection, out=None, fields=None):
540544
"""TODO"""
541545

542546
# refresh metadata
@@ -546,9 +550,9 @@ def get_coordinate_selection(self, selection, out=None):
546550
# setup indexer
547551
indexer = CoordinateIndexer(selection, self)
548552

549-
return self._get_selection(indexer, out=out)
553+
return self._get_selection(indexer=indexer, out=out, fields=fields)
550554

551-
def get_mask_selection(self, selection, out=None):
555+
def get_mask_selection(self, selection, out=None, fields=None):
552556
"""TODO"""
553557

554558
# refresh metadata
@@ -558,9 +562,9 @@ def get_mask_selection(self, selection, out=None):
558562
# setup indexer
559563
indexer = MaskIndexer(selection, self)
560564

561-
return self._get_selection(indexer, out=out)
565+
return self._get_selection(indexer=indexer, out=out, fields=fields)
562566

563-
def _get_selection(self, indexer, out=None):
567+
def _get_selection(self, indexer, out=None, fields=None):
564568

565569
# We iterate over all chunks which overlap the selection and thus contain data that needs
566570
# to be extracted. Each chunk is processed in turn, extracting the necessary data and
@@ -569,25 +573,28 @@ def _get_selection(self, indexer, out=None):
569573
# N.B., it is an important optimisation that we only visit chunks which overlap the
570574
# selection. This minimises the nuimber of iterations in the main for loop.
571575

576+
# check fields are sensible
577+
out_dtype = check_fields(fields, self._dtype)
578+
572579
# determine output shape
573-
sel_shape = indexer.shape
580+
out_shape = indexer.shape
574581

575582
# setup output array
576583
if out is None:
577-
out = np.empty(sel_shape, dtype=self._dtype, order=self._order)
584+
out = np.empty(out_shape, dtype=out_dtype, order=self._order)
578585
else:
579586
# validate 'out' parameter
580587
if not hasattr(out, 'shape'):
581588
raise TypeError('out must be an array-like object')
582-
if out.shape != sel_shape:
589+
if out.shape != out_shape:
583590
raise ValueError('out has wrong shape for selection')
584591

585592
# iterate over chunks
586593
for chunk_coords, chunk_selection, out_selection in indexer:
587594

588595
# load chunk selection into output array
589596
self._chunk_getitem(chunk_coords, chunk_selection, out, out_selection,
590-
drop_axes=indexer.drop_axes)
597+
drop_axes=indexer.drop_axes, fields=fields)
591598

592599
if out.shape:
593600
return out
@@ -653,19 +660,10 @@ def __setitem__(self, selection, value):
653660
654661
"""
655662

656-
if len(self._shape) == 0:
657-
self._set_basic_selection_zd(selection, value)
663+
fields, selection = pop_fields(selection)
664+
self.set_basic_selection(selection, value, fields=fields)
658665

659-
elif len(self._shape) == 1:
660-
# safe to do "fancy" indexing, no ambiguity
661-
self.set_orthogonal_selection(selection, value)
662-
663-
else:
664-
# "fancy" indexing can be ambiguous/hard to understand for multidimensional arrays,
665-
# force people to go through explicit methods
666-
self.set_basic_selection(selection, value)
667-
668-
def set_basic_selection(self, selection, value):
666+
def set_basic_selection(self, selection, value, fields=None):
669667
"""TODO"""
670668

671669
# guard conditions
@@ -678,11 +676,11 @@ def set_basic_selection(self, selection, value):
678676

679677
# handle zero-dimensional arrays
680678
if self._shape == ():
681-
return self._set_basic_selection_zd(selection, value)
679+
return self._set_basic_selection_zd(selection, value, fields=fields)
682680
else:
683-
return self._set_basic_selection_nd(selection, value)
681+
return self._set_basic_selection_nd(selection, value, fields=fields)
684682

685-
def set_orthogonal_selection(self, selection, value):
683+
def set_orthogonal_selection(self, selection, value, fields=None):
686684
"""TODO"""
687685

688686
# guard conditions
@@ -696,9 +694,9 @@ def set_orthogonal_selection(self, selection, value):
696694
# setup indexer
697695
indexer = OrthogonalIndexer(selection, self)
698696

699-
self._set_selection(indexer, value)
697+
self._set_selection(indexer, value, fields=fields)
700698

701-
def set_coordinate_selection(self, selection, value):
699+
def set_coordinate_selection(self, selection, value, fields=None):
702700
"""TODO"""
703701

704702
# guard conditions
@@ -712,9 +710,9 @@ def set_coordinate_selection(self, selection, value):
712710
# setup indexer
713711
indexer = CoordinateIndexer(selection, self)
714712

715-
self._set_selection(indexer, value)
713+
self._set_selection(indexer, value, fields=fields)
716714

717-
def set_mask_selection(self, selection, value):
715+
def set_mask_selection(self, selection, value, fields=None):
718716
"""TODO"""
719717

720718
# guard conditions
@@ -728,13 +726,17 @@ def set_mask_selection(self, selection, value):
728726
# setup indexer
729727
indexer = MaskIndexer(selection, self)
730728

731-
self._set_selection(indexer, value)
729+
self._set_selection(indexer, value, fields=fields)
732730

733-
def _set_basic_selection_zd(self, selection, value):
731+
def _set_basic_selection_zd(self, selection, value, fields=None):
734732
# special case __setitem__ for zero-dimensional array
735733

734+
if fields:
735+
raise IndexError('fields not supported for 0d array')
736+
736737
# check item is valid
737-
if selection not in ((), Ellipsis):
738+
selection = ensure_tuple(selection)
739+
if selection not in ((), (Ellipsis,)):
738740
raise IndexError('too many indices for array')
739741

740742
# setup data to store
@@ -751,15 +753,15 @@ def _set_basic_selection_zd(self, selection, value):
751753
cdata = self._encode_chunk(arr)
752754
self.chunk_store[ckey] = cdata
753755

754-
def _set_basic_selection_nd(self, selection, value):
756+
def _set_basic_selection_nd(self, selection, value, fields=None):
755757
# implementation of __setitem__ for array with at least one dimension
756758

757759
# setup indexer
758760
indexer = BasicIndexer(selection, self)
759761

760-
self._set_selection(indexer, value)
762+
self._set_selection(indexer, value, fields=fields)
761763

762-
def _set_selection(self, indexer, value):
764+
def _set_selection(self, indexer, value, fields=None):
763765

764766
# We iterate over all chunks which overlap the selection and thus contain data that needs
765767
# to be replaced. Each chunk is processed in turn, extracting the necessary data from the
@@ -768,15 +770,20 @@ def _set_selection(self, indexer, value):
768770
# N.B., it is an important optimisation that we only visit chunks which overlap the
769771
# selection. This minimises the nuimber of iterations in the main for loop.
770772

773+
# check fields are sensible
774+
check_fields(fields, self._dtype)
775+
if fields and isinstance(fields, list):
776+
raise ValueError('multi-field assignment is not supported')
777+
771778
# determine indices of chunks overlapping the selection
772779
sel_shape = indexer.shape
773780

774781
# check value shape
775-
if np.isscalar(value):
782+
if is_scalar(value, self._dtype):
776783
pass
777784
else:
778785
if not hasattr(value, 'shape'):
779-
raise TypeError('value must be an array-like object')
786+
value = np.asarray(value)
780787
if value.shape != sel_shape:
781788
raise ValueError('value has wrong shape for selection; expected {}, got {}'
782789
.format(sel_shape, value.shape))
@@ -785,7 +792,7 @@ def _set_selection(self, indexer, value):
785792
for chunk_coords, chunk_selection, out_selection in indexer:
786793

787794
# extract data to store
788-
if np.isscalar(value):
795+
if is_scalar(value, self._dtype):
789796
chunk_value = value
790797
else:
791798
chunk_value = value[out_selection]
@@ -797,9 +804,10 @@ def _set_selection(self, indexer, value):
797804
chunk_value = chunk_value[item]
798805

799806
# put data
800-
self._chunk_setitem(chunk_coords, chunk_selection, chunk_value)
807+
self._chunk_setitem(chunk_coords, chunk_selection, chunk_value, fields=fields)
801808

802-
def _chunk_getitem(self, chunk_coords, chunk_selection, out, out_selection, drop_axes=None):
809+
def _chunk_getitem(self, chunk_coords, chunk_selection, out, out_selection, drop_axes=None,
810+
fields=None):
803811
"""Obtain part or whole of a chunk.
804812
805813
Parameters
@@ -814,6 +822,8 @@ def _chunk_getitem(self, chunk_coords, chunk_selection, out, out_selection, drop
814822
Location of region within output array to store results in.
815823
drop_axes : tuple of ints
816824
Axes to squeeze out of the chunk.
825+
fields
826+
TODO
817827
818828
"""
819829

@@ -833,10 +843,11 @@ def _chunk_getitem(self, chunk_coords, chunk_selection, out, out_selection, drop
833843

834844
else:
835845

836-
if isinstance(out, np.ndarray) and \
837-
isinstance(out_selection, slice) and \
838-
is_total_slice(chunk_selection, self._chunks) and \
839-
not self._filters:
846+
if (isinstance(out, np.ndarray) and
847+
not fields and
848+
isinstance(out_selection, slice) and
849+
is_total_slice(chunk_selection, self._chunks) and
850+
not self._filters):
840851

841852
dest = out[out_selection]
842853
contiguous = ((self._order == 'C' and dest.flags.c_contiguous) or
@@ -859,13 +870,17 @@ def _chunk_getitem(self, chunk_coords, chunk_selection, out, out_selection, drop
859870
# decode chunk
860871
chunk = self._decode_chunk(cdata)
861872

862-
# set data in output array
873+
# select data from chunk
874+
if fields:
875+
chunk = chunk[fields]
863876
tmp = chunk[chunk_selection]
864877
if drop_axes:
865878
tmp = np.squeeze(tmp, axis=drop_axes)
879+
880+
# store selected data in output
866881
out[out_selection] = tmp
867882

868-
def _chunk_setitem(self, chunk_coords, chunk_selection, value):
883+
def _chunk_setitem(self, chunk_coords, chunk_selection, value, fields=None):
869884
"""Replace part or whole of a chunk.
870885
871886
Parameters
@@ -881,25 +896,25 @@ def _chunk_setitem(self, chunk_coords, chunk_selection, value):
881896

882897
# synchronization
883898
if self._synchronizer is None:
884-
self._chunk_setitem_nosync(chunk_coords, chunk_selection, value)
899+
self._chunk_setitem_nosync(chunk_coords, chunk_selection, value, fields=fields)
885900
else:
886901
# synchronize on the chunk
887902
ckey = self._chunk_key(chunk_coords)
888903
with self._synchronizer[ckey]:
889-
self._chunk_setitem_nosync(chunk_coords, chunk_selection, value)
904+
self._chunk_setitem_nosync(chunk_coords, chunk_selection, value, fields=fields)
890905

891-
def _chunk_setitem_nosync(self, chunk_coords, chunk_selection, value):
906+
def _chunk_setitem_nosync(self, chunk_coords, chunk_selection, value, fields=None):
892907

893908
# obtain key for chunk storage
894909
ckey = self._chunk_key(chunk_coords)
895910

896-
if is_total_slice(chunk_selection, self._chunks):
911+
if is_total_slice(chunk_selection, self._chunks) and not fields:
897912
# totally replace chunk
898913

899914
# optimization: we are completely replacing the chunk, so no need
900915
# to access the existing chunk data
901916

902-
if np.isscalar(value):
917+
if is_scalar(value, self._dtype):
903918

904919
# setup array filled with value
905920
chunk = np.empty(self._chunks, dtype=self._dtype, order=self._order)
@@ -950,7 +965,12 @@ def _chunk_setitem_nosync(self, chunk_coords, chunk_selection, value):
950965
chunk = chunk.copy(order='K')
951966

952967
# modify
953-
chunk[chunk_selection] = value
968+
if fields:
969+
# N.B., currently multi-field assignment is not supported in numpy, so this only
970+
# works for a single field
971+
chunk[fields][chunk_selection] = value
972+
else:
973+
chunk[chunk_selection] = value
954974

955975
# encode chunk
956976
cdata = self._encode_chunk(chunk)

0 commit comments

Comments
 (0)