11
11
Optional ,
12
12
Tuple ,
13
13
Union ,
14
+ cast ,
14
15
)
15
16
16
17
import numpy as np
@@ -212,7 +213,10 @@ def from_variables(
212
213
213
214
@classmethod
214
215
def from_pandas_index (
215
- cls , index : pd .Index , dim : Hashable
216
+ cls ,
217
+ index : pd .Index ,
218
+ dim : Hashable ,
219
+ var_meta : Optional [Dict [Any , Dict ]] = None ,
216
220
) -> Tuple ["PandasIndex" , IndexVars ]:
217
221
from .variable import IndexVariable
218
222
@@ -223,10 +227,21 @@ def from_pandas_index(
223
227
else :
224
228
name = index .name
225
229
226
- data = PandasIndexingAdapter (index )
227
- index_var = IndexVariable (dim , data , fastpath = True )
230
+ if var_meta is None :
231
+ var_meta = {name : {}}
232
+
233
+ data = PandasIndexingAdapter (index , dtype = var_meta [name ].get ("dtype" ))
234
+ index_var = IndexVariable (
235
+ dim ,
236
+ data ,
237
+ fastpath = True ,
238
+ attrs = var_meta [name ].get ("attrs" ),
239
+ encoding = var_meta [name ].get ("encoding" ),
240
+ )
228
241
229
- return cls (index , dim ), {name : index_var }
242
+ return cls (index , dim , coord_dtype = var_meta [name ].get ("dtype" )), {
243
+ name : index_var
244
+ }
230
245
231
246
def to_pandas_index (self ) -> pd .Index :
232
247
return self .index
@@ -297,13 +312,11 @@ def rename(self, name_dict, dims_dict):
297
312
return self , {}
298
313
299
314
new_name = name_dict .get (self .index .name , self .index .name )
300
- pd_idx = self .index .rename (new_name )
315
+ index = self .index .rename (new_name )
301
316
new_dim = dims_dict .get (self .dim , self .dim )
317
+ var_meta = {new_name : {"dtype" : self .coord_dtype }}
302
318
303
- index , index_vars = self .from_pandas_index (pd_idx , dim = new_dim )
304
- index .coord_dtype = self .coord_dtype
305
-
306
- return index , index_vars
319
+ return self .from_pandas_index (index , dim = new_dim , var_meta = var_meta )
307
320
308
321
def copy (self , deep = True ):
309
322
return self ._replace (self .index .copy (deep = deep ))
@@ -411,13 +424,12 @@ def from_variables_maybe_expand(
411
424
"""Create a new multi-index maybe by expanding an existing one with
412
425
new variables as index levels.
413
426
414
- the index might be created along a new dimension.
427
+ The index and its corresponding coordinates may be created along a new dimension.
415
428
"""
416
429
names : List [Hashable ] = []
417
430
codes : List [List [int ]] = []
418
431
levels : List [List [int ]] = []
419
432
var_meta : Dict [str , Dict ] = {}
420
- level_coords_dtype : Dict [Hashable , Any ] = {}
421
433
422
434
_check_dim_compat ({** current_variables , ** variables })
423
435
@@ -427,20 +439,21 @@ def add_level_var(name, var):
427
439
"attrs" : var .attrs ,
428
440
"encoding" : var .encoding ,
429
441
}
430
- level_coords_dtype [name ] = var .dtype
431
442
432
443
if len (current_variables ) > 1 :
433
- current_index : pd .MultiIndex = next (
434
- iter (current_variables .values ())
435
- )._data .array
444
+ # expand from an existing multi-index
445
+ data = cast (
446
+ PandasMultiIndexingAdapter , next (iter (current_variables .values ()))._data
447
+ )
448
+ current_index = data .array
436
449
names .extend (current_index .names )
437
450
codes .extend (current_index .codes )
438
451
levels .extend (current_index .levels )
439
452
for name in current_index .names :
440
453
add_level_var (name , current_variables [name ])
441
454
442
455
elif len (current_variables ) == 1 :
443
- # one 1D variable (no multi-index): convert it to an index level
456
+ # expand from one 1D variable (no multi-index): convert it to an index level
444
457
var = next (iter (current_variables .values ()))
445
458
new_var_name = f"{ dim } _level_0"
446
459
names .append (new_var_name )
@@ -457,27 +470,63 @@ def add_level_var(name, var):
457
470
add_level_var (name , var )
458
471
459
472
index = pd .MultiIndex (levels , codes , names = names )
460
- obj = cls (index , dim , level_coords_dtype = level_coords_dtype )
461
- index_vars = _create_variables_from_multiindex (index , dim , var_meta = var_meta )
462
473
463
- return obj , index_vars
474
+ return cls .from_pandas_index (index , dim , var_meta = var_meta )
475
+
476
def keep_levels(
    self, level_variables: Mapping[Any, "Variable"]
) -> Tuple[Union["PandasMultiIndex", PandasIndex], IndexVars]:
    """Return a new index (and its coordinates) restricted to the given levels.

    The dtype, attrs and encoding of every variable in ``level_variables``
    are collected and forwarded as ``var_meta`` to ``from_pandas_index``.
    Levels of ``self.index`` whose name is not a key of ``level_variables``
    are dropped. If the remaining pandas index is still a ``pd.MultiIndex``
    a multi-index is built from it, otherwise a ``PandasIndex`` is built.
    """
    # Per-level metadata taken from the variables that are kept.
    var_meta: Dict[str, Dict] = {
        name: {
            "dtype": var.dtype,
            "attrs": var.attrs,
            "encoding": var.encoding,
        }
        for name, var in level_variables.items()
    }

    dropped = [lvl for lvl in self.index.names if lvl not in level_variables]
    index = self.index.droplevel(dropped)

    # droplevel may leave a plain (single-level) pandas index behind.
    if not isinstance(index, pd.MultiIndex):
        return PandasIndex.from_pandas_index(index, self.dim, var_meta=var_meta)
    return self.from_pandas_index(index, self.dim, var_meta=var_meta)
464
500
465
501
@classmethod
def from_pandas_index(
    cls,
    index: pd.MultiIndex,
    dim: Hashable,
    var_meta: Optional[Dict[Any, Dict]] = None,
) -> Tuple["PandasMultiIndex", IndexVars]:
    """Create a multi-index and its coordinate variables from a pandas one.

    Parameters
    ----------
    index : pd.MultiIndex
        The pandas multi-index to wrap. Unnamed levels are named
        ``f"{dim}_level_{i}"``, with ``i`` the position of the level.
    dim : Hashable
        Name of the dimension the index is created along.
    var_meta : dict, optional
        Per-level metadata, keyed by level name. A ``"dtype"`` entry, when
        present, takes precedence over the dtype of the pandas level.
        The mapping passed by the caller is not modified.

    Returns
    -------
    The new index object and the variables created by
    ``_create_variables_from_multiindex``.

    Raises
    ------
    ValueError
        If a level name is equal to ``dim``.
    """
    names = []
    idx_dtypes = {}
    for i, idx in enumerate(index.levels):
        name = idx.name or f"{dim}_level_{i}"
        if name == dim:
            raise ValueError(
                f"conflicting multi-index level name {name!r} with dimension {dim!r}"
            )
        names.append(name)
        idx_dtypes[name] = idx.dtype

    # Work on a copy: the dtype defaults filled in below must not leak into
    # the mapping (or the per-level dicts) owned by the caller.
    level_meta: Dict[Any, Dict] = {
        key: dict(meta) for key, meta in (var_meta or {}).items()
    }
    for name, dtype in idx_dtypes.items():
        # A level missing from ``var_meta`` falls back to the pandas dtype
        # instead of raising KeyError; an explicit "dtype" entry is kept.
        level_meta.setdefault(name, {}).setdefault("dtype", dtype)

    level_coords_dtype = {k: level_meta[k]["dtype"] for k in names}

    index = index.rename(names)
    index_vars = _create_variables_from_multiindex(index, dim, var_meta=level_meta)
    return cls(index, dim, level_coords_dtype=level_coords_dtype), index_vars
481
530
482
531
def query (self , labels , method = None , tolerance = None ) -> QueryResult :
483
532
if method is not None or tolerance is not None :
@@ -570,15 +619,19 @@ def query(self, labels, method=None, tolerance=None) -> QueryResult:
570
619
raise KeyError (f"not all values found in index { coord_name !r} " )
571
620
572
621
if new_index is not None :
622
+ # variable(s) attrs and encoding metadata are propagated
623
+ # when replacing the indexes in the resulting xarray object
624
+ var_meta = {k : {"dtype" : v } for k , v in self .level_coords_dtype .items ()}
625
+
573
626
if isinstance (new_index , pd .MultiIndex ):
574
627
new_index , new_vars = PandasMultiIndex .from_pandas_index (
575
- new_index , self .dim
628
+ new_index , self .dim , var_meta = var_meta
576
629
)
577
630
dims_dict = {}
578
631
drop_coords = set (self .index .names ) - set (new_index .index .names )
579
632
else :
580
633
new_index , new_vars = PandasIndex .from_pandas_index (
581
- new_index , new_index .name
634
+ new_index , new_index .name , var_meta = var_meta
582
635
)
583
636
dims_dict = {self .dim : new_index .index .name }
584
637
drop_coords = set (self .index .names ) - {new_index .index .name } | {
@@ -602,15 +655,14 @@ def rename(self, name_dict, dims_dict):
602
655
603
656
# pandas 1.3.0: could simply do `self.index.rename(names_dict)`
604
657
new_names = [name_dict .get (k , k ) for k in self .index .names ]
605
- pd_idx = self .index .rename (new_names )
606
- new_dim = dims_dict .get (self .dim , self .dim )
658
+ index = self .index .rename (new_names )
607
659
608
- index , index_vars = self . from_pandas_index ( pd_idx , new_dim )
609
- index . level_coords_dtype = {
610
- k : v for k , v in zip (new_names , self .level_coords_dtype .values ())
660
+ new_dim = dims_dict . get ( self . dim , self . dim )
661
+ var_meta = {
662
+ k : { "dtype" : v } for k , v in zip (new_names , self .level_coords_dtype .values ())
611
663
}
612
664
613
- return index , index_vars
665
+ return self . from_pandas_index ( index , new_dim , var_meta = var_meta )
614
666
615
667
616
668
def remove_unused_levels_categories (index : pd .Index ) -> pd .Index :
0 commit comments