@@ -815,6 +815,29 @@ def test_astype_extension_dtypes_duplicate_col(self, dtype):
815
815
expected = concat ([a1 .astype (dtype ), a2 .astype (dtype )], axis = 1 )
816
816
tm .assert_frame_equal (result , expected )
817
817
818
+ def test_df_where_change_dtype (self ):
819
+ # GH 16979
820
+ df = DataFrame (np .arange (2 * 3 ).reshape (2 , 3 ), columns = list ("ABC" ))
821
+ mask = np .array ([[True , False , True ], [False , True , True ]])
822
+
823
+ result = df .where (mask )
824
+ expected = DataFrame ([[0 , np .nan , 2 ], [np .nan , 4 , 5 ]], columns = list ("ABC" ))
825
+
826
+ tm .assert_frame_equal (result , expected )
827
+
828
+ # change type to category
829
+ df .A = df .A .astype ("category" )
830
+ df .B = df .B .astype ("category" )
831
+ df .C = df .C .astype ("category" )
832
+
833
+ result = df .where (mask )
834
+ A = pd .Categorical ([0 , np .nan ], categories = [0 , 3 ])
835
+ B = pd .Categorical ([np .nan , 4 ], categories = [1 , 4 ])
836
+ C = pd .Categorical ([2 , 5 ], categories = [2 , 5 ])
837
+ expected = DataFrame ({"A" : A , "B" : B , "C" : C })
838
+
839
+ tm .assert_frame_equal (result , expected )
840
+
818
841
@pytest .mark .parametrize ("kwargs" , [dict (), dict (other = None )])
819
842
def test_df_where_with_category (self , kwargs ):
820
843
# GH 16979
@@ -827,6 +850,7 @@ def test_df_where_with_category(self, kwargs):
827
850
828
851
result = df .A .where (mask [:, 0 ], ** kwargs )
829
852
expected = Series (pd .Categorical ([0 , np .nan ], categories = [0 , 3 ]), name = "A" )
853
+
830
854
tm .assert_series_equal (result , expected )
831
855
832
856
@pytest .mark .parametrize (
0 commit comments