@@ -223,8 +223,13 @@ def apply_dataarray_vfunc(
223
223
args , join = join , copy = False , exclude = exclude_dims , raise_on_invalid = False
224
224
)
225
225
226
- if keep_attrs and hasattr (args [0 ], "name" ):
227
- name = args [0 ].name
226
+ for arg in args :
227
+ first_obj = arg
228
+ if isinstance (arg , DataArray ):
229
+ break
230
+
231
+ if keep_attrs and hasattr (first_obj , "name" ):
232
+ name = first_obj .name
228
233
else :
229
234
name = result_name (args )
230
235
result_coords = build_output_coords (args , signature , exclude_dims )
@@ -241,6 +246,12 @@ def apply_dataarray_vfunc(
241
246
(coords ,) = result_coords
242
247
out = DataArray (result_var , coords , name = name , fastpath = True )
243
248
249
+ if keep_attrs and hasattr (first_obj , "attrs" ):
250
+ if isinstance (out , tuple ):
251
+ out = tuple (da ._copy_attrs_from (first_obj ) for da in out )
252
+ else :
253
+ out ._copy_attrs_from (first_obj )
254
+
244
255
return out
245
256
246
257
@@ -361,7 +372,10 @@ def apply_dataset_vfunc(
361
372
"""
362
373
from .dataset import Dataset
363
374
364
- first_obj = args [0 ] # we'll copy attrs from this in case keep_attrs=True
375
+ for arg in args :
376
+ first_obj = args
377
+ if isinstance (first_obj , Dataset ):
378
+ break
365
379
366
380
if dataset_join not in _JOINS_WITHOUT_FILL_VALUES and fill_value is _NO_FILL_VALUE :
367
381
raise TypeError (
@@ -554,6 +568,11 @@ def apply_variable_ufunc(
554
568
"""
555
569
from .variable import Variable , as_compatible_data
556
570
571
+ for arg in args :
572
+ first_obj = arg
573
+ if isinstance (arg , Variable ):
574
+ break
575
+
557
576
dim_sizes = unified_dim_sizes (
558
577
(a for a in args if hasattr (a , "dims" )), exclude_dims = exclude_dims
559
578
)
@@ -639,8 +658,8 @@ def func(*arrays):
639
658
)
640
659
)
641
660
642
- if keep_attrs and isinstance (args [ 0 ] , Variable ):
643
- var .attrs .update (args [ 0 ] .attrs )
661
+ if keep_attrs and isinstance (first_obj , Variable ):
662
+ var .attrs .update (first_obj .attrs )
644
663
output .append (var )
645
664
646
665
if signature .num_outputs == 1 :
0 commit comments