16
16
from zarr .compat import reduce
17
17
from zarr .codecs import AsType , get_codec
18
18
from zarr .indexing import OIndex , OrthogonalIndexer , BasicIndexer , VIndex , CoordinateIndexer , \
19
- MaskIndexer
19
+ MaskIndexer , check_fields , pop_fields , ensure_tuple
20
+
21
+
22
+ def is_scalar (value , dtype ):
23
+ if np .isscalar (value ):
24
+ return True
25
+ if isinstance (value , tuple ) and dtype .names and len (value ) == len (dtype .names ):
26
+ return True
27
+ return False
20
28
21
29
22
30
class Array (object ):
@@ -460,19 +468,10 @@ def __getitem__(self, selection):
460
468
461
469
"""
462
470
463
- if len (self ._shape ) == 0 :
464
- return self ._get_basic_selection_zd (selection )
465
-
466
- elif len (self ._shape ) == 1 :
467
- # safe to do "fancy" indexing, no ambiguity
468
- return self .get_orthogonal_selection (selection )
469
-
470
- else :
471
- # "fancy" indexing can be ambiguous/hard to understand for multidimensional arrays,
472
- # force people to go through explicit methods
473
- return self .get_basic_selection (selection )
471
+ fields , selection = pop_fields (selection )
472
+ return self .get_basic_selection (selection , fields = fields )
474
473
475
- def get_basic_selection (self , selection , out = None ):
474
+ def get_basic_selection (self , selection , out = None , fields = None ):
476
475
"""TODO"""
477
476
478
477
# refresh metadata
@@ -481,15 +480,16 @@ def get_basic_selection(self, selection, out=None):
481
480
482
481
# handle zero-dimensional arrays
483
482
if self ._shape == ():
484
- return self ._get_basic_selection_zd (selection , out = out )
483
+ return self ._get_basic_selection_zd (selection = selection , out = out , fields = fields )
485
484
else :
486
- return self ._get_basic_selection_nd (selection , out = out )
485
+ return self ._get_basic_selection_nd (selection = selection , out = out , fields = fields )
487
486
488
- def _get_basic_selection_zd (self , selection , out = None ):
487
+ def _get_basic_selection_zd (self , selection , out = None , fields = None ):
489
488
# special case basic selection for zero-dimensional array
490
489
491
490
# check selection is valid
492
- if selection not in ((), Ellipsis ):
491
+ selection = ensure_tuple (selection )
492
+ if selection not in ((), (Ellipsis ,)):
493
493
raise IndexError ('too many indices for array' )
494
494
495
495
try :
@@ -514,17 +514,21 @@ def _get_basic_selection_zd(self, selection, out=None):
514
514
else :
515
515
out [selection ] = chunk [selection ]
516
516
517
+ # handle fields
518
+ if fields :
519
+ out = out [fields ]
520
+
517
521
return out
518
522
519
- def _get_basic_selection_nd (self , selection , out = None ):
523
+ def _get_basic_selection_nd (self , selection , out = None , fields = None ):
520
524
# implementation of basic selection for array with at least one dimension
521
525
522
526
# setup indexer
523
527
indexer = BasicIndexer (selection , self )
524
528
525
- return self ._get_selection (indexer , out = out )
529
+ return self ._get_selection (indexer = indexer , out = out , fields = fields )
526
530
527
- def get_orthogonal_selection (self , selection , out = None ):
531
+ def get_orthogonal_selection (self , selection , out = None , fields = None ):
528
532
"""TODO"""
529
533
530
534
# refresh metadata
@@ -534,9 +538,9 @@ def get_orthogonal_selection(self, selection, out=None):
534
538
# setup indexer
535
539
indexer = OrthogonalIndexer (selection , self )
536
540
537
- return self ._get_selection (indexer , out = out )
541
+ return self ._get_selection (indexer = indexer , out = out , fields = fields )
538
542
539
- def get_coordinate_selection (self , selection , out = None ):
543
+ def get_coordinate_selection (self , selection , out = None , fields = None ):
540
544
"""TODO"""
541
545
542
546
# refresh metadata
@@ -546,9 +550,9 @@ def get_coordinate_selection(self, selection, out=None):
546
550
# setup indexer
547
551
indexer = CoordinateIndexer (selection , self )
548
552
549
- return self ._get_selection (indexer , out = out )
553
+ return self ._get_selection (indexer = indexer , out = out , fields = fields )
550
554
551
- def get_mask_selection (self , selection , out = None ):
555
+ def get_mask_selection (self , selection , out = None , fields = None ):
552
556
"""TODO"""
553
557
554
558
# refresh metadata
@@ -558,9 +562,9 @@ def get_mask_selection(self, selection, out=None):
558
562
# setup indexer
559
563
indexer = MaskIndexer (selection , self )
560
564
561
- return self ._get_selection (indexer , out = out )
565
+ return self ._get_selection (indexer = indexer , out = out , fields = fields )
562
566
563
- def _get_selection (self , indexer , out = None ):
567
+ def _get_selection (self , indexer , out = None , fields = None ):
564
568
565
569
# We iterate over all chunks which overlap the selection and thus contain data that needs
566
570
# to be extracted. Each chunk is processed in turn, extracting the necessary data and
@@ -569,25 +573,28 @@ def _get_selection(self, indexer, out=None):
569
573
# N.B., it is an important optimisation that we only visit chunks which overlap the
570
574
# selection. This minimises the number of iterations in the main for loop.
571
575
576
+ # check fields are sensible
577
+ out_dtype = check_fields (fields , self ._dtype )
578
+
572
579
# determine output shape
573
- sel_shape = indexer .shape
580
+ out_shape = indexer .shape
574
581
575
582
# setup output array
576
583
if out is None :
577
- out = np .empty (sel_shape , dtype = self . _dtype , order = self ._order )
584
+ out = np .empty (out_shape , dtype = out_dtype , order = self ._order )
578
585
else :
579
586
# validate 'out' parameter
580
587
if not hasattr (out , 'shape' ):
581
588
raise TypeError ('out must be an array-like object' )
582
- if out .shape != sel_shape :
589
+ if out .shape != out_shape :
583
590
raise ValueError ('out has wrong shape for selection' )
584
591
585
592
# iterate over chunks
586
593
for chunk_coords , chunk_selection , out_selection in indexer :
587
594
588
595
# load chunk selection into output array
589
596
self ._chunk_getitem (chunk_coords , chunk_selection , out , out_selection ,
590
- drop_axes = indexer .drop_axes )
597
+ drop_axes = indexer .drop_axes , fields = fields )
591
598
592
599
if out .shape :
593
600
return out
@@ -653,19 +660,10 @@ def __setitem__(self, selection, value):
653
660
654
661
"""
655
662
656
- if len ( self . _shape ) == 0 :
657
- self ._set_basic_selection_zd (selection , value )
663
+ fields , selection = pop_fields ( selection )
664
+ self .set_basic_selection (selection , value , fields = fields )
658
665
659
- elif len (self ._shape ) == 1 :
660
- # safe to do "fancy" indexing, no ambiguity
661
- self .set_orthogonal_selection (selection , value )
662
-
663
- else :
664
- # "fancy" indexing can be ambiguous/hard to understand for multidimensional arrays,
665
- # force people to go through explicit methods
666
- self .set_basic_selection (selection , value )
667
-
668
- def set_basic_selection (self , selection , value ):
666
+ def set_basic_selection (self , selection , value , fields = None ):
669
667
"""TODO"""
670
668
671
669
# guard conditions
@@ -678,11 +676,11 @@ def set_basic_selection(self, selection, value):
678
676
679
677
# handle zero-dimensional arrays
680
678
if self ._shape == ():
681
- return self ._set_basic_selection_zd (selection , value )
679
+ return self ._set_basic_selection_zd (selection , value , fields = fields )
682
680
else :
683
- return self ._set_basic_selection_nd (selection , value )
681
+ return self ._set_basic_selection_nd (selection , value , fields = fields )
684
682
685
- def set_orthogonal_selection (self , selection , value ):
683
+ def set_orthogonal_selection (self , selection , value , fields = None ):
686
684
"""TODO"""
687
685
688
686
# guard conditions
@@ -696,9 +694,9 @@ def set_orthogonal_selection(self, selection, value):
696
694
# setup indexer
697
695
indexer = OrthogonalIndexer (selection , self )
698
696
699
- self ._set_selection (indexer , value )
697
+ self ._set_selection (indexer , value , fields = fields )
700
698
701
- def set_coordinate_selection (self , selection , value ):
699
+ def set_coordinate_selection (self , selection , value , fields = None ):
702
700
"""TODO"""
703
701
704
702
# guard conditions
@@ -712,9 +710,9 @@ def set_coordinate_selection(self, selection, value):
712
710
# setup indexer
713
711
indexer = CoordinateIndexer (selection , self )
714
712
715
- self ._set_selection (indexer , value )
713
+ self ._set_selection (indexer , value , fields = fields )
716
714
717
- def set_mask_selection (self , selection , value ):
715
+ def set_mask_selection (self , selection , value , fields = None ):
718
716
"""TODO"""
719
717
720
718
# guard conditions
@@ -728,13 +726,17 @@ def set_mask_selection(self, selection, value):
728
726
# setup indexer
729
727
indexer = MaskIndexer (selection , self )
730
728
731
- self ._set_selection (indexer , value )
729
+ self ._set_selection (indexer , value , fields = fields )
732
730
733
- def _set_basic_selection_zd (self , selection , value ):
731
+ def _set_basic_selection_zd (self , selection , value , fields = None ):
734
732
# special case __setitem__ for zero-dimensional array
735
733
734
+ if fields :
735
+ raise IndexError ('fields not supported for 0d array' )
736
+
736
737
# check item is valid
737
- if selection not in ((), Ellipsis ):
738
+ selection = ensure_tuple (selection )
739
+ if selection not in ((), (Ellipsis ,)):
738
740
raise IndexError ('too many indices for array' )
739
741
740
742
# setup data to store
@@ -751,15 +753,15 @@ def _set_basic_selection_zd(self, selection, value):
751
753
cdata = self ._encode_chunk (arr )
752
754
self .chunk_store [ckey ] = cdata
753
755
754
- def _set_basic_selection_nd (self , selection , value ):
756
+ def _set_basic_selection_nd (self , selection , value , fields = None ):
755
757
# implementation of __setitem__ for array with at least one dimension
756
758
757
759
# setup indexer
758
760
indexer = BasicIndexer (selection , self )
759
761
760
- self ._set_selection (indexer , value )
762
+ self ._set_selection (indexer , value , fields = fields )
761
763
762
- def _set_selection (self , indexer , value ):
764
+ def _set_selection (self , indexer , value , fields = None ):
763
765
764
766
# We iterate over all chunks which overlap the selection and thus contain data that needs
765
767
# to be replaced. Each chunk is processed in turn, extracting the necessary data from the
@@ -768,15 +770,20 @@ def _set_selection(self, indexer, value):
768
770
# N.B., it is an important optimisation that we only visit chunks which overlap the
769
771
# selection. This minimises the number of iterations in the main for loop.
770
772
773
+ # check fields are sensible
774
+ check_fields (fields , self ._dtype )
775
+ if fields and isinstance (fields , list ):
776
+ raise ValueError ('multi-field assignment is not supported' )
777
+
771
778
# determine indices of chunks overlapping the selection
772
779
sel_shape = indexer .shape
773
780
774
781
# check value shape
775
- if np . isscalar (value ):
782
+ if is_scalar (value , self . _dtype ):
776
783
pass
777
784
else :
778
785
if not hasattr (value , 'shape' ):
779
- raise TypeError ( ' value must be an array-like object' )
786
+ value = np . asarray ( value )
780
787
if value .shape != sel_shape :
781
788
raise ValueError ('value has wrong shape for selection; expected {}, got {}'
782
789
.format (sel_shape , value .shape ))
@@ -785,7 +792,7 @@ def _set_selection(self, indexer, value):
785
792
for chunk_coords , chunk_selection , out_selection in indexer :
786
793
787
794
# extract data to store
788
- if np . isscalar (value ):
795
+ if is_scalar (value , self . _dtype ):
789
796
chunk_value = value
790
797
else :
791
798
chunk_value = value [out_selection ]
@@ -797,9 +804,10 @@ def _set_selection(self, indexer, value):
797
804
chunk_value = chunk_value [item ]
798
805
799
806
# put data
800
- self ._chunk_setitem (chunk_coords , chunk_selection , chunk_value )
807
+ self ._chunk_setitem (chunk_coords , chunk_selection , chunk_value , fields = fields )
801
808
802
- def _chunk_getitem (self , chunk_coords , chunk_selection , out , out_selection , drop_axes = None ):
809
+ def _chunk_getitem (self , chunk_coords , chunk_selection , out , out_selection , drop_axes = None ,
810
+ fields = None ):
803
811
"""Obtain part or whole of a chunk.
804
812
805
813
Parameters
@@ -814,6 +822,8 @@ def _chunk_getitem(self, chunk_coords, chunk_selection, out, out_selection, drop
814
822
Location of region within output array to store results in.
815
823
drop_axes : tuple of ints
816
824
Axes to squeeze out of the chunk.
825
+ fields
826
+ TODO
817
827
818
828
"""
819
829
@@ -833,10 +843,11 @@ def _chunk_getitem(self, chunk_coords, chunk_selection, out, out_selection, drop
833
843
834
844
else :
835
845
836
- if isinstance (out , np .ndarray ) and \
837
- isinstance (out_selection , slice ) and \
838
- is_total_slice (chunk_selection , self ._chunks ) and \
839
- not self ._filters :
846
+ if (isinstance (out , np .ndarray ) and
847
+ not fields and
848
+ isinstance (out_selection , slice ) and
849
+ is_total_slice (chunk_selection , self ._chunks ) and
850
+ not self ._filters ):
840
851
841
852
dest = out [out_selection ]
842
853
contiguous = ((self ._order == 'C' and dest .flags .c_contiguous ) or
@@ -859,13 +870,17 @@ def _chunk_getitem(self, chunk_coords, chunk_selection, out, out_selection, drop
859
870
# decode chunk
860
871
chunk = self ._decode_chunk (cdata )
861
872
862
- # set data in output array
873
+ # select data from chunk
874
+ if fields :
875
+ chunk = chunk [fields ]
863
876
tmp = chunk [chunk_selection ]
864
877
if drop_axes :
865
878
tmp = np .squeeze (tmp , axis = drop_axes )
879
+
880
+ # store selected data in output
866
881
out [out_selection ] = tmp
867
882
868
- def _chunk_setitem (self , chunk_coords , chunk_selection , value ):
883
+ def _chunk_setitem (self , chunk_coords , chunk_selection , value , fields = None ):
869
884
"""Replace part or whole of a chunk.
870
885
871
886
Parameters
@@ -881,25 +896,25 @@ def _chunk_setitem(self, chunk_coords, chunk_selection, value):
881
896
882
897
# synchronization
883
898
if self ._synchronizer is None :
884
- self ._chunk_setitem_nosync (chunk_coords , chunk_selection , value )
899
+ self ._chunk_setitem_nosync (chunk_coords , chunk_selection , value , fields = fields )
885
900
else :
886
901
# synchronize on the chunk
887
902
ckey = self ._chunk_key (chunk_coords )
888
903
with self ._synchronizer [ckey ]:
889
- self ._chunk_setitem_nosync (chunk_coords , chunk_selection , value )
904
+ self ._chunk_setitem_nosync (chunk_coords , chunk_selection , value , fields = fields )
890
905
891
- def _chunk_setitem_nosync (self , chunk_coords , chunk_selection , value ):
906
+ def _chunk_setitem_nosync (self , chunk_coords , chunk_selection , value , fields = None ):
892
907
893
908
# obtain key for chunk storage
894
909
ckey = self ._chunk_key (chunk_coords )
895
910
896
- if is_total_slice (chunk_selection , self ._chunks ):
911
+ if is_total_slice (chunk_selection , self ._chunks ) and not fields :
897
912
# totally replace chunk
898
913
899
914
# optimization: we are completely replacing the chunk, so no need
900
915
# to access the existing chunk data
901
916
902
- if np . isscalar (value ):
917
+ if is_scalar (value , self . _dtype ):
903
918
904
919
# setup array filled with value
905
920
chunk = np .empty (self ._chunks , dtype = self ._dtype , order = self ._order )
@@ -950,7 +965,12 @@ def _chunk_setitem_nosync(self, chunk_coords, chunk_selection, value):
950
965
chunk = chunk .copy (order = 'K' )
951
966
952
967
# modify
953
- chunk [chunk_selection ] = value
968
+ if fields :
969
+ # N.B., currently multi-field assignment is not supported in numpy, so this only
970
+ # works for a single field
971
+ chunk [fields ][chunk_selection ] = value
972
+ else :
973
+ chunk [chunk_selection ] = value
954
974
955
975
# encode chunk
956
976
cdata = self ._encode_chunk (chunk )
0 commit comments