@@ -571,6 +571,57 @@ def test_concat_str_dtype(self, dtype, dim) -> None:
571
571
572
572
assert np .issubdtype (actual .x2 .dtype , dtype )
573
573
574
+ @pytest .mark .parametrize ("dim" , [True , False ])
575
+ @pytest .mark .parametrize ("coord" , [True , False ])
576
+ def test_concat_fill_missing_variables (self , dim , coord ):
577
+ # create var names list with one missing value
578
+ def get_var_names (var_cnt = 10 , list_cnt = 10 ):
579
+ orig = [f'd{ i :02d} ' for i in range (var_cnt )]
580
+ var_names = []
581
+ for i in range (0 , list_cnt ):
582
+ l1 = orig .copy ()
583
+ var_names .append (l1 )
584
+ return var_names
585
+
586
+ def create_ds (var_names , dim = False , coord = False , drop_idx = False ):
587
+ out_ds = []
588
+ ds = Dataset ()
589
+ ds = ds .assign_coords ({"x" : np .arange (2 )})
590
+ ds = ds .assign_coords ({"y" : np .arange (3 )})
591
+ ds = ds .assign_coords ({"z" : np .arange (4 )})
592
+ for i , dsl in enumerate (var_names ):
593
+ vlist = dsl .copy ()
594
+ if drop_idx :
595
+ vlist .pop (drop_idx [i ])
596
+ foo_data = np .arange (48 , dtype = float ).reshape (2 , 2 , 3 , 4 )
597
+ dsi = ds .copy ()
598
+ if coord :
599
+ dsi = ds .assign ({"time" : (["time" ], [i * 2 , i * 2 + 1 ])})
600
+ for k in vlist :
601
+ dsi = dsi .assign ({k : (["time" , "x" , "y" , "z" ], foo_data .copy ())})
602
+ if not dim :
603
+ dsi = dsi .isel (time = 0 )
604
+ out_ds .append (dsi )
605
+ return out_ds
606
+ var_names = get_var_names ()
607
+
608
+ import random
609
+ random .seed (42 )
610
+ drop_idx = [random .randrange (len (vlist )) for vlist in var_names ]
611
+ expected = concat (create_ds (var_names , dim = dim , coord = coord ), dim = "time" , data_vars = "all" )
612
+ for i , idx in enumerate (drop_idx ):
613
+ if dim :
614
+ expected [var_names [0 ][idx ]][i * 2 : i * 2 + 2 ] = np .nan
615
+ else :
616
+ expected [var_names [0 ][idx ]][i ] = np .nan
617
+
618
+ concat_ds = create_ds (var_names , dim = dim , coord = coord , drop_idx = drop_idx )
619
+ actual = concat (concat_ds , dim = "time" , data_vars = "all" )
620
+
621
+ for name in var_names [0 ]:
622
+ assert_equal (expected [name ], actual [name ])
623
+ assert_equal (expected , actual )
624
+
574
625
575
626
class TestConcatDataArray :
576
627
def test_concat (self ) -> None :
0 commit comments