@@ -234,11 +234,10 @@ def test_aggregate_item_by_item(df):
234
234
K = len (result .columns )
235
235
236
236
# GH5782
237
- # odd comparisons can result here, so cast to make easy
238
- exp = Series (np .array ([foo ] * K ), index = list ("BCD" ), dtype = np .float64 , name = "foo" )
237
+ exp = Series (np .array ([foo ] * K ), index = list ("BCD" ), name = "foo" )
239
238
tm .assert_series_equal (result .xs ("foo" ), exp )
240
239
241
- exp = Series (np .array ([bar ] * K ), index = list ("BCD" ), dtype = np . float64 , name = "bar" )
240
+ exp = Series (np .array ([bar ] * K ), index = list ("BCD" ), name = "bar" )
242
241
tm .assert_almost_equal (result .xs ("bar" ), exp )
243
242
244
243
def aggfun (ser ):
@@ -442,6 +441,57 @@ def test_bool_agg_dtype(op):
442
441
assert is_integer_dtype (result )
443
442
444
443
444
+ @pytest .mark .parametrize (
445
+ "keys, agg_index" ,
446
+ [
447
+ (["a" ], Index ([1 ], name = "a" )),
448
+ (["a" , "b" ], MultiIndex ([[1 ], [2 ]], [[0 ], [0 ]], names = ["a" , "b" ])),
449
+ ],
450
+ )
451
+ @pytest .mark .parametrize (
452
+ "input_dtype" , ["bool" , "int32" , "int64" , "float32" , "float64" ]
453
+ )
454
+ @pytest .mark .parametrize (
455
+ "result_dtype" , ["bool" , "int32" , "int64" , "float32" , "float64" ]
456
+ )
457
+ @pytest .mark .parametrize ("method" , ["apply" , "aggregate" , "transform" ])
458
+ def test_callable_result_dtype_frame (
459
+ keys , agg_index , input_dtype , result_dtype , method
460
+ ):
461
+ # GH 21240
462
+ df = DataFrame ({"a" : [1 ], "b" : [2 ], "c" : [True ]})
463
+ df ["c" ] = df ["c" ].astype (input_dtype )
464
+ op = getattr (df .groupby (keys )[["c" ]], method )
465
+ result = op (lambda x : x .astype (result_dtype ).iloc [0 ])
466
+ expected_index = pd .RangeIndex (0 , 1 ) if method == "transform" else agg_index
467
+ expected = DataFrame ({"c" : [df ["c" ].iloc [0 ]]}, index = expected_index ).astype (
468
+ result_dtype
469
+ )
470
+ if method == "apply" :
471
+ expected .columns .names = [0 ]
472
+ tm .assert_frame_equal (result , expected )
473
+
474
+
475
+ @pytest .mark .parametrize (
476
+ "keys, agg_index" ,
477
+ [
478
+ (["a" ], Index ([1 ], name = "a" )),
479
+ (["a" , "b" ], MultiIndex ([[1 ], [2 ]], [[0 ], [0 ]], names = ["a" , "b" ])),
480
+ ],
481
+ )
482
+ @pytest .mark .parametrize ("input" , [True , 1 , 1.0 ])
483
+ @pytest .mark .parametrize ("dtype" , [bool , int , float ])
484
+ @pytest .mark .parametrize ("method" , ["apply" , "aggregate" , "transform" ])
485
+ def test_callable_result_dtype_series (keys , agg_index , input , dtype , method ):
486
+ # GH 21240
487
+ df = DataFrame ({"a" : [1 ], "b" : [2 ], "c" : [input ]})
488
+ op = getattr (df .groupby (keys )["c" ], method )
489
+ result = op (lambda x : x .astype (dtype ).iloc [0 ])
490
+ expected_index = pd .RangeIndex (0 , 1 ) if method == "transform" else agg_index
491
+ expected = Series ([df ["c" ].iloc [0 ]], index = expected_index , name = "c" ).astype (dtype )
492
+ tm .assert_series_equal (result , expected )
493
+
494
+
445
495
def test_order_aggregate_multiple_funcs ():
446
496
# GH 25692
447
497
df = DataFrame ({"A" : [1 , 1 , 2 , 2 ], "B" : [1 , 2 , 3 , 4 ]})
@@ -462,7 +512,9 @@ def test_uint64_type_handling(dtype, how):
462
512
expected = df .groupby ("y" ).agg ({"x" : how })
463
513
df .x = df .x .astype (dtype )
464
514
result = df .groupby ("y" ).agg ({"x" : how })
465
- result .x = result .x .astype (np .int64 )
515
+ if how not in ("mean" , "median" ):
516
+ # mean and median always result in floats
517
+ result .x = result .x .astype (np .int64 )
466
518
tm .assert_frame_equal (result , expected , check_exact = True )
467
519
468
520
@@ -849,7 +901,11 @@ def test_multiindex_custom_func(func):
849
901
data = [[1 , 4 , 2 ], [5 , 7 , 1 ]]
850
902
df = DataFrame (data , columns = MultiIndex .from_arrays ([[1 , 1 , 2 ], [3 , 4 , 3 ]]))
851
903
result = df .groupby (np .array ([0 , 1 ])).agg (func )
852
- expected_dict = {(1 , 3 ): {0 : 1 , 1 : 5 }, (1 , 4 ): {0 : 4 , 1 : 7 }, (2 , 3 ): {0 : 2 , 1 : 1 }}
904
+ expected_dict = {
905
+ (1 , 3 ): {0 : 1.0 , 1 : 5.0 },
906
+ (1 , 4 ): {0 : 4.0 , 1 : 7.0 },
907
+ (2 , 3 ): {0 : 2.0 , 1 : 1.0 },
908
+ }
853
909
expected = DataFrame (expected_dict )
854
910
tm .assert_frame_equal (result , expected )
855
911
0 commit comments