34
34
T_DSorDA = TypeVar ("T_DSorDA" , DataArray , Dataset )
35
35
36
36
37
- def get_index_vars (obj : Union [DataArray , Dataset ]) -> dict :
38
- return {dim : obj [dim ] for dim in obj .indexes }
39
-
40
-
41
37
def to_object_array (iterable ):
42
38
npargs = np .empty ((len (iterable ),), dtype = np .object )
43
39
for idx , item in enumerate (iterable ):
@@ -247,7 +243,7 @@ def _wrapper(
247
243
raise ValueError (f"Dimensions { missing_dimensions } missing on returned object." )
248
244
249
245
# check that index lengths and values are as expected
250
- for name , index in get_index_vars ( result ) .items ():
246
+ for name , index in result . indexes .items ():
251
247
if name in check_shapes :
252
248
if len (index ) != check_shapes [name ]:
253
249
raise ValueError (
@@ -412,11 +408,11 @@ def map_blocks(
412
408
413
409
# check that chunk sizes are compatible
414
410
input_chunks = dict (npargs [0 ].chunks )
415
- input_indexes = get_index_vars (npargs [0 ])
411
+ input_indexes = dict (npargs [0 ]. indexes )
416
412
for arg in npargs [1 :][is_xarray [1 :]]:
417
413
assert_chunks_compatible (npargs [0 ], arg )
418
414
input_chunks .update (arg .chunks )
419
- input_indexes .update (get_index_vars ( arg ) )
415
+ input_indexes .update (arg . indexes )
420
416
421
417
if template is None :
422
418
# infer template by providing zero-shaped arrays
@@ -425,15 +421,15 @@ def map_blocks(
425
421
preserved_indexes = template_indexes & set (input_indexes )
426
422
new_indexes = template_indexes - set (input_indexes )
427
423
indexes = {dim : input_indexes [dim ] for dim in preserved_indexes }
428
- indexes .update ({k : template [k ] for k in new_indexes })
424
+ indexes .update ({k : template . indexes [k ] for k in new_indexes })
429
425
output_chunks = {
430
426
dim : input_chunks [dim ] for dim in template .dims if dim in input_chunks
431
427
}
432
428
433
429
else :
434
430
# template xarray object has been provided with proper sizes and chunk shapes
435
431
indexes = input_indexes
436
- indexes .update (get_index_vars ( template ) )
432
+ indexes .update (template . indexes )
437
433
if isinstance (template , DataArray ):
438
434
output_chunks = dict (zip (template .dims , template .chunks )) # type: ignore
439
435
else :
0 commit comments