@@ -24,57 +24,33 @@ template <typename GridwiseGemm,
24
24
typename AElementwiseOperation,
25
25
typename BElementwiseOperation,
26
26
typename CElementwiseOperation,
27
- bool HasMainKBlockLoop,
28
- index_t MaxGroupCount>
27
+ bool HasMainKBlockLoop>
29
28
__global__ void
30
29
#if CK_USE_LAUNCH_BOUNDS
31
30
__launch_bounds__ (CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
32
31
#endif
33
- kernel_grouped_gemm_xdlops_v2r3 (
34
- const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_descs,
35
- const index_t group_count,
36
- const AElementwiseOperation a_element_op,
37
- const BElementwiseOperation b_element_op,
38
- const CElementwiseOperation c_element_op)
32
+ kernel_grouped_gemm_xdlops_v2r3 (const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
33
+ const index_t group_count,
34
+ const AElementwiseOperation a_element_op,
35
+ const BElementwiseOperation b_element_op,
36
+ const CElementwiseOperation c_element_op)
39
37
{
40
38
#if (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
41
39
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte ()];
42
40
43
41
const index_t block_id = get_block_1d_id ();
44
42
45
- #if 1
46
- static_for<0 , MaxGroupCount, 1 >{}([&](auto i) {
47
- if (block_id >= gemm_descs[i].BlockStart_ && block_id < gemm_descs[i].BlockEnd_ &&
48
- i < group_count)
49
- {
50
- auto group_id = i;
51
-
52
- GridwiseGemm::template Run<HasMainKBlockLoop>(
53
- gemm_descs[group_id].a_ptr ,
54
- gemm_descs[group_id].b_ptr ,
55
- gemm_descs[group_id].c_ptr ,
56
- p_shared,
57
- gemm_descs[group_id].a_grid_desc_k0_m_k1_ ,
58
- gemm_descs[group_id].b_grid_desc_k0_n_k1_ ,
59
- gemm_descs[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ ,
60
- a_element_op,
61
- b_element_op,
62
- c_element_op,
63
- gemm_descs[group_id].grouped_gemm_block_2_ctile_map_ );
64
- }
65
- });
66
- #else
67
- const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(&gemm_descs);
43
+ const auto gemm_desc_ptr =
44
+ reinterpret_cast <const GemmDesc*>(cast_pointer_to_generic_address_space (gemm_descs_const));
68
45
69
46
index_t group_id = 0 ;
70
- static_for<0, MaxGroupCount, 1>{}([&](auto i) {
71
- group_id = (block_id >= gemm_descs[i].BlockStart && block_id < gemm_descs[i].BlockEnd &&
72
- i < group_count)
73
- ? i
74
- : group_id;
75
- });
76
-
77
- const index_t block_id_grp = block_id - gemm_desc_ptr[group_id].BlockStart;
47
+ for (index_t i = 0 ; i < group_count; i++)
48
+ {
49
+ group_id =
50
+ (block_id >= gemm_desc_ptr[i].BlockStart_ && block_id < gemm_desc_ptr[i].BlockEnd_ )
51
+ ? i
52
+ : group_id;
53
+ }
78
54
79
55
GridwiseGemm::template Run<HasMainKBlockLoop>(
80
56
gemm_desc_ptr[group_id].a_ptr ,
@@ -87,11 +63,9 @@ __global__ void
87
63
a_element_op,
88
64
b_element_op,
89
65
c_element_op,
90
- gemm_desc_ptr[group_id].block_2_ctile_map_,
91
- block_id_grp);
92
- #endif
66
+ gemm_desc_ptr[group_id].grouped_gemm_block_2_ctile_map_ );
93
67
#else
94
- ignore = gemm_descs ;
68
+ ignore = gemm_descs_const ;
95
69
ignore = group_count;
96
70
ignore = a_element_op;
97
71
ignore = b_element_op;
@@ -388,6 +362,8 @@ struct DeviceGroupedGemmXdl
388
362
{
389
363
grid_size_ = 0 ;
390
364
365
+ gemm_descs_args_workspace_ = nullptr ;
366
+
391
367
group_count_ = ck::type_convert<ck::index_t >(gemm_shapes.size ());
392
368
393
369
if (!(group_count_ == ck::type_convert<ck::index_t >(p_a.size ()) &&
@@ -461,6 +437,8 @@ struct DeviceGroupedGemmXdl
461
437
462
438
std::vector<GemmDescKernelArg> gemm_desc_kernel_arg_;
463
439
440
+ void * gemm_descs_args_workspace_;
441
+
464
442
index_t grid_size_;
465
443
};
466
444
@@ -471,49 +449,49 @@ struct DeviceGroupedGemmXdl
471
449
472
450
float Run (const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
473
451
{
474
- StaticallyIndexedArray<GemmDescKernelArg, MaxGroupCount> gemm_desc_kernel_args;
475
-
476
452
bool has_main_k_block_loop = true ;
477
453
478
- static_for<0 , MaxGroupCount, 1 >{}([&](auto i) {
479
- if (i < arg.gemm_desc_kernel_arg_ .size ())
454
+ for (std::size_t i = 0 ; i < arg.gemm_desc_kernel_arg_ .size (); i++)
455
+ {
456
+ std::cout << " group: " << i << " arg.a_grid_desc_k0_m_k1_{"
457
+ << arg.gemm_desc_kernel_arg_ [i].a_grid_desc_k0_m_k1_ .GetLength (I0) << " , "
458
+ << arg.gemm_desc_kernel_arg_ [i].a_grid_desc_k0_m_k1_ .GetLength (I1) << " , "
459
+ << arg.gemm_desc_kernel_arg_ [i].a_grid_desc_k0_m_k1_ .GetLength (I2) << " }" ;
460
+
461
+ std::cout << " , arg.b_grid_desc_k0_n_k1_{"
462
+ << arg.gemm_desc_kernel_arg_ [i].b_grid_desc_k0_n_k1_ .GetLength (I0) << " , "
463
+ << arg.gemm_desc_kernel_arg_ [i].b_grid_desc_k0_n_k1_ .GetLength (I1) << " , "
464
+ << arg.gemm_desc_kernel_arg_ [i].b_grid_desc_k0_n_k1_ .GetLength (I2) << " }" ;
465
+
466
+ std::cout << " , arg.c_grid_desc_m_n_{ "
467
+ << arg.gemm_desc_kernel_arg_ [i].c_grid_desc_m_n_ .GetLength (I0) << " , "
468
+ << arg.gemm_desc_kernel_arg_ [i].c_grid_desc_m_n_ .GetLength (I1) << " }"
469
+ << std::endl;
470
+
471
+ if (!GridwiseGemm::CheckValidity (
472
+ arg.gemm_desc_kernel_arg_ [i].a_grid_desc_k0_m_k1_ ,
473
+ arg.gemm_desc_kernel_arg_ [i].b_grid_desc_k0_n_k1_ ,
474
+ arg.gemm_desc_kernel_arg_ [i].c_grid_desc_m_n_ ,
475
+ arg.gemm_desc_kernel_arg_ [i].grouped_gemm_block_2_ctile_map_ ))
480
476
{
481
- gemm_desc_kernel_args (i) = arg.gemm_desc_kernel_arg_ [i];
482
-
483
- std::cout << " group: " << i << " arg.a_grid_desc_k0_m_k1_{"
484
- << gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_ .GetLength (I0) << " , "
485
- << gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_ .GetLength (I1) << " , "
486
- << gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_ .GetLength (I2) << " }" ;
487
-
488
- std::cout << " , arg.b_grid_desc_k0_n_k1_{"
489
- << gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_ .GetLength (I0) << " , "
490
- << gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_ .GetLength (I1) << " , "
491
- << gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_ .GetLength (I2) << " }" ;
492
-
493
- std::cout << " , arg.c_grid_desc_m_n_{ "
494
- << gemm_desc_kernel_args[i].c_grid_desc_m_n_ .GetLength (I0) << " , "
495
- << gemm_desc_kernel_args[i].c_grid_desc_m_n_ .GetLength (I1) << " }"
496
- << std::endl;
497
-
498
- if (!GridwiseGemm::CheckValidity (
499
- gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_ ,
500
- gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_ ,
501
- gemm_desc_kernel_args[i].c_grid_desc_m_n_ ,
502
- gemm_desc_kernel_args[i].grouped_gemm_block_2_ctile_map_ ))
503
- {
504
- throw std::runtime_error (
505
- " wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting" );
506
- }
507
-
508
- const auto K = gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_ .GetLength (I0) *
509
- gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_ .GetLength (I2);
510
-
511
- if (GridwiseGemm::CalculateHasMainKBlockLoop (K) != has_main_k_block_loop)
512
- {
513
- throw std::runtime_error (" wrong! not all gemm has_main_k_block_loop" );
514
- }
477
+ throw std::runtime_error (
478
+ " wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting" );
515
479
}
516
- });
480
+
481
+ const auto K = arg.gemm_desc_kernel_arg_ [i].a_grid_desc_k0_m_k1_ .GetLength (I0) *
482
+ arg.gemm_desc_kernel_arg_ [i].a_grid_desc_k0_m_k1_ .GetLength (I2);
483
+
484
+ if (GridwiseGemm::CalculateHasMainKBlockLoop (K) != has_main_k_block_loop)
485
+ {
486
+ throw std::runtime_error (" wrong! not all gemm has_main_k_block_loop" );
487
+ }
488
+ }
489
+
490
+ hipGetErrorString (
491
+ hipMemcpy (arg.gemm_descs_args_workspace_ ,
492
+ arg.gemm_desc_kernel_arg_ .data (),
493
+ arg.gemm_desc_kernel_arg_ .size () * sizeof (GemmDescKernelArg),
494
+ hipMemcpyHostToDevice));
517
495
518
496
float ave_time = 0 ;
519
497
@@ -523,47 +501,47 @@ struct DeviceGroupedGemmXdl
523
501
kernel_grouped_gemm_xdlops_v2r3<GridwiseGemm,
524
502
ADataType, // TODO: distinguish A/B datatype
525
503
CDataType,
526
- remove_reference_t < GemmDescKernelArg> ,
504
+ GemmDescKernelArg,
527
505
AElementwiseOperation,
528
506
BElementwiseOperation,
529
507
CElementwiseOperation,
530
- true ,
531
- MaxGroupCount>;
532
-
533
- ave_time = launch_and_time_kernel ( stream_config,
534
- kernel,
535
- dim3 (arg.grid_size_ ),
536
- dim3 (BlockSize),
537
- 0 ,
538
- gemm_desc_kernel_args ,
539
- arg.gemm_desc_kernel_arg_ .size (),
540
- arg.a_element_op_ ,
541
- arg.b_element_op_ ,
542
- arg.c_element_op_ );
508
+ true >;
509
+
510
+ ave_time = launch_and_time_kernel (
511
+ stream_config,
512
+ kernel,
513
+ dim3 (arg.grid_size_ ),
514
+ dim3 (BlockSize),
515
+ 0 ,
516
+ cast_pointer_to_constant_address_space (arg. gemm_descs_args_workspace_ ) ,
517
+ arg.gemm_desc_kernel_arg_ .size (),
518
+ arg.a_element_op_ ,
519
+ arg.b_element_op_ ,
520
+ arg.c_element_op_ );
543
521
}
544
522
else
545
523
{
546
524
const auto kernel =
547
525
kernel_grouped_gemm_xdlops_v2r3<GridwiseGemm,
548
526
ADataType, // TODO: distinguish A/B datatype
549
527
CDataType,
550
- remove_reference_t < GemmDescKernelArg> ,
528
+ GemmDescKernelArg,
551
529
AElementwiseOperation,
552
530
BElementwiseOperation,
553
531
CElementwiseOperation,
554
- false ,
555
- MaxGroupCount>;
556
-
557
- ave_time = launch_and_time_kernel ( stream_config,
558
- kernel,
559
- dim3 (arg.grid_size_ ),
560
- dim3 (BlockSize),
561
- 0 ,
562
- gemm_desc_kernel_args ,
563
- arg.gemm_desc_kernel_arg_ .size (),
564
- arg.a_element_op_ ,
565
- arg.b_element_op_ ,
566
- arg.c_element_op_ );
532
+ false >;
533
+
534
+ ave_time = launch_and_time_kernel (
535
+ stream_config,
536
+ kernel,
537
+ dim3 (arg.grid_size_ ),
538
+ dim3 (BlockSize),
539
+ 0 ,
540
+ cast_pointer_to_constant_address_space (arg. gemm_descs_args_workspace_ ) ,
541
+ arg.gemm_desc_kernel_arg_ .size (),
542
+ arg.a_element_op_ ,
543
+ arg.b_element_op_ ,
544
+ arg.c_element_op_ );
567
545
}
568
546
569
547
return ave_time;
@@ -652,6 +630,16 @@ struct DeviceGroupedGemmXdl
652
630
653
631
return str.str ();
654
632
}
633
+
634
+ size_t GetWorkSpaceSize (const BaseArgument* p_arg) const override
635
+ {
636
+ return dynamic_cast <const Argument*>(p_arg)->group_count_ * sizeof (GemmDescKernelArg);
637
+ }
638
+
639
+ void SetWorkSpacePointer (BaseArgument* p_arg, void * workspace_ptr) const override
640
+ {
641
+ dynamic_cast <Argument*>(p_arg)->gemm_descs_args_workspace_ = workspace_ptr;
642
+ }
655
643
};
656
644
657
645
} // namespace device
0 commit comments