Skip to content

Commit b6eaf3e

Browse files
authored
Pass gemm_descs for grouped gemm via __constant__ buff (pytorch#232)
* Moved gemm_descs_args into a constant buffer
* Used CK_CONSTANT_ADDRESS_SPACE instead of a global __constant__ variable
* Cleaned up
* Moved hipMemAlloc outside of the device op
* Added SetWorkSpacePointer
* Fixed the ignore statements
1 parent 7b1e2c3 commit b6eaf3e

File tree

4 files changed

+113
-113
lines changed

4 files changed

+113
-113
lines changed

example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ int main(int argc, char* argv[])
7878
exit(0);
7979
}
8080

81-
int group_count = 4;
81+
int group_count = rand() % 16 + 1;
8282

8383
// GEMM shape
8484
std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes;
@@ -189,12 +189,17 @@ int main(int argc, char* argv[])
189189
auto b_element_op = BElementOp{};
190190
auto c_element_op = CElementOp{};
191191

192-
// do GEMM
193192
auto gemm = DeviceGemmInstance{};
194193
auto invoker = gemm.MakeInvoker();
194+
195+
// do GEMM
195196
auto argument =
196197
gemm.MakeArgument(p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op);
197198

199+
DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument));
200+
201+
gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());
202+
198203
if(!gemm.IsSupportedArgument(argument))
199204
{
200205
throw std::runtime_error(

include/ck/tensor_operation/gpu/device/device_base.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ struct BaseOperator
4242

4343
virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
4444

45+
virtual void SetWorkSpacePointer(BaseArgument*, void*) const {}
46+
4547
virtual ~BaseOperator() {}
4648
};
4749

include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp

Lines changed: 98 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -24,57 +24,33 @@ template <typename GridwiseGemm,
2424
typename AElementwiseOperation,
2525
typename BElementwiseOperation,
2626
typename CElementwiseOperation,
27-
bool HasMainKBlockLoop,
28-
index_t MaxGroupCount>
27+
bool HasMainKBlockLoop>
2928
__global__ void
3029
#if CK_USE_LAUNCH_BOUNDS
3130
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
3231
#endif
33-
kernel_grouped_gemm_xdlops_v2r3(
34-
const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_descs,
35-
const index_t group_count,
36-
const AElementwiseOperation a_element_op,
37-
const BElementwiseOperation b_element_op,
38-
const CElementwiseOperation c_element_op)
32+
kernel_grouped_gemm_xdlops_v2r3(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
33+
const index_t group_count,
34+
const AElementwiseOperation a_element_op,
35+
const BElementwiseOperation b_element_op,
36+
const CElementwiseOperation c_element_op)
3937
{
4038
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
4139
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
4240

4341
const index_t block_id = get_block_1d_id();
4442

45-
#if 1
46-
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
47-
if(block_id >= gemm_descs[i].BlockStart_ && block_id < gemm_descs[i].BlockEnd_ &&
48-
i < group_count)
49-
{
50-
auto group_id = i;
51-
52-
GridwiseGemm::template Run<HasMainKBlockLoop>(
53-
gemm_descs[group_id].a_ptr,
54-
gemm_descs[group_id].b_ptr,
55-
gemm_descs[group_id].c_ptr,
56-
p_shared,
57-
gemm_descs[group_id].a_grid_desc_k0_m_k1_,
58-
gemm_descs[group_id].b_grid_desc_k0_n_k1_,
59-
gemm_descs[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
60-
a_element_op,
61-
b_element_op,
62-
c_element_op,
63-
gemm_descs[group_id].grouped_gemm_block_2_ctile_map_);
64-
}
65-
});
66-
#else
67-
const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(&gemm_descs);
43+
const auto gemm_desc_ptr =
44+
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
6845

6946
index_t group_id = 0;
70-
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
71-
group_id = (block_id >= gemm_descs[i].BlockStart && block_id < gemm_descs[i].BlockEnd &&
72-
i < group_count)
73-
? i
74-
: group_id;
75-
});
76-
77-
const index_t block_id_grp = block_id - gemm_desc_ptr[group_id].BlockStart;
47+
for(index_t i = 0; i < group_count; i++)
48+
{
49+
group_id =
50+
(block_id >= gemm_desc_ptr[i].BlockStart_ && block_id < gemm_desc_ptr[i].BlockEnd_)
51+
? i
52+
: group_id;
53+
}
7854

7955
GridwiseGemm::template Run<HasMainKBlockLoop>(
8056
gemm_desc_ptr[group_id].a_ptr,
@@ -87,11 +63,9 @@ __global__ void
8763
a_element_op,
8864
b_element_op,
8965
c_element_op,
90-
gemm_desc_ptr[group_id].block_2_ctile_map_,
91-
block_id_grp);
92-
#endif
66+
gemm_desc_ptr[group_id].grouped_gemm_block_2_ctile_map_);
9367
#else
94-
ignore = gemm_descs;
68+
ignore = gemm_descs_const;
9569
ignore = group_count;
9670
ignore = a_element_op;
9771
ignore = b_element_op;
@@ -388,6 +362,8 @@ struct DeviceGroupedGemmXdl
388362
{
389363
grid_size_ = 0;
390364

365+
gemm_descs_args_workspace_ = nullptr;
366+
391367
group_count_ = ck::type_convert<ck::index_t>(gemm_shapes.size());
392368

393369
if(!(group_count_ == ck::type_convert<ck::index_t>(p_a.size()) &&
@@ -461,6 +437,8 @@ struct DeviceGroupedGemmXdl
461437

462438
std::vector<GemmDescKernelArg> gemm_desc_kernel_arg_;
463439

440+
void* gemm_descs_args_workspace_;
441+
464442
index_t grid_size_;
465443
};
466444

@@ -471,49 +449,49 @@ struct DeviceGroupedGemmXdl
471449

472450
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
473451
{
474-
StaticallyIndexedArray<GemmDescKernelArg, MaxGroupCount> gemm_desc_kernel_args;
475-
476452
bool has_main_k_block_loop = true;
477453

478-
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
479-
if(i < arg.gemm_desc_kernel_arg_.size())
454+
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
455+
{
456+
std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{"
457+
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
458+
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
459+
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}";
460+
461+
std::cout << ", arg.b_grid_desc_k0_n_k1_{"
462+
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
463+
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
464+
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}";
465+
466+
std::cout << ", arg.c_grid_desc_m_n_{ "
467+
<< arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_.GetLength(I0) << ", "
468+
<< arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_.GetLength(I1) << "}"
469+
<< std::endl;
470+
471+
if(!GridwiseGemm::CheckValidity(
472+
arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_,
473+
arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_,
474+
arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_,
475+
arg.gemm_desc_kernel_arg_[i].grouped_gemm_block_2_ctile_map_))
480476
{
481-
gemm_desc_kernel_args(i) = arg.gemm_desc_kernel_arg_[i];
482-
483-
std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{"
484-
<< gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
485-
<< gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
486-
<< gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}";
487-
488-
std::cout << ", arg.b_grid_desc_k0_n_k1_{"
489-
<< gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
490-
<< gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
491-
<< gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}";
492-
493-
std::cout << ", arg.c_grid_desc_m_n_{ "
494-
<< gemm_desc_kernel_args[i].c_grid_desc_m_n_.GetLength(I0) << ", "
495-
<< gemm_desc_kernel_args[i].c_grid_desc_m_n_.GetLength(I1) << "}"
496-
<< std::endl;
497-
498-
if(!GridwiseGemm::CheckValidity(
499-
gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_,
500-
gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_,
501-
gemm_desc_kernel_args[i].c_grid_desc_m_n_,
502-
gemm_desc_kernel_args[i].grouped_gemm_block_2_ctile_map_))
503-
{
504-
throw std::runtime_error(
505-
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
506-
}
507-
508-
const auto K = gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I0) *
509-
gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I2);
510-
511-
if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop)
512-
{
513-
throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
514-
}
477+
throw std::runtime_error(
478+
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
515479
}
516-
});
480+
481+
const auto K = arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0) *
482+
arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2);
483+
484+
if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop)
485+
{
486+
throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
487+
}
488+
}
489+
490+
hipGetErrorString(
491+
hipMemcpy(arg.gemm_descs_args_workspace_,
492+
arg.gemm_desc_kernel_arg_.data(),
493+
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmDescKernelArg),
494+
hipMemcpyHostToDevice));
517495

518496
float ave_time = 0;
519497

@@ -523,47 +501,47 @@ struct DeviceGroupedGemmXdl
523501
kernel_grouped_gemm_xdlops_v2r3<GridwiseGemm,
524502
ADataType, // TODO: distiguish A/B datatype
525503
CDataType,
526-
remove_reference_t<GemmDescKernelArg>,
504+
GemmDescKernelArg,
527505
AElementwiseOperation,
528506
BElementwiseOperation,
529507
CElementwiseOperation,
530-
true,
531-
MaxGroupCount>;
532-
533-
ave_time = launch_and_time_kernel(stream_config,
534-
kernel,
535-
dim3(arg.grid_size_),
536-
dim3(BlockSize),
537-
0,
538-
gemm_desc_kernel_args,
539-
arg.gemm_desc_kernel_arg_.size(),
540-
arg.a_element_op_,
541-
arg.b_element_op_,
542-
arg.c_element_op_);
508+
true>;
509+
510+
ave_time = launch_and_time_kernel(
511+
stream_config,
512+
kernel,
513+
dim3(arg.grid_size_),
514+
dim3(BlockSize),
515+
0,
516+
cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_),
517+
arg.gemm_desc_kernel_arg_.size(),
518+
arg.a_element_op_,
519+
arg.b_element_op_,
520+
arg.c_element_op_);
543521
}
544522
else
545523
{
546524
const auto kernel =
547525
kernel_grouped_gemm_xdlops_v2r3<GridwiseGemm,
548526
ADataType, // TODO: distiguish A/B datatype
549527
CDataType,
550-
remove_reference_t<GemmDescKernelArg>,
528+
GemmDescKernelArg,
551529
AElementwiseOperation,
552530
BElementwiseOperation,
553531
CElementwiseOperation,
554-
false,
555-
MaxGroupCount>;
556-
557-
ave_time = launch_and_time_kernel(stream_config,
558-
kernel,
559-
dim3(arg.grid_size_),
560-
dim3(BlockSize),
561-
0,
562-
gemm_desc_kernel_args,
563-
arg.gemm_desc_kernel_arg_.size(),
564-
arg.a_element_op_,
565-
arg.b_element_op_,
566-
arg.c_element_op_);
532+
false>;
533+
534+
ave_time = launch_and_time_kernel(
535+
stream_config,
536+
kernel,
537+
dim3(arg.grid_size_),
538+
dim3(BlockSize),
539+
0,
540+
cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_),
541+
arg.gemm_desc_kernel_arg_.size(),
542+
arg.a_element_op_,
543+
arg.b_element_op_,
544+
arg.c_element_op_);
567545
}
568546

569547
return ave_time;
@@ -652,6 +630,16 @@ struct DeviceGroupedGemmXdl
652630

653631
return str.str();
654632
}
633+
634+
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
635+
{
636+
return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GemmDescKernelArg);
637+
}
638+
639+
void SetWorkSpacePointer(BaseArgument* p_arg, void* workspace_ptr) const override
640+
{
641+
dynamic_cast<Argument*>(p_arg)->gemm_descs_args_workspace_ = workspace_ptr;
642+
}
655643
};
656644

657645
} // namespace device

test/grouped_gemm/grouped_gemm_fp16.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,15 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
141141
auto c_element_op = PassThrough{};
142142

143143
// do GEMM
144-
auto invoker_ptr = groupedGemmPtr->MakeInvokerPointer();
144+
auto invoker_ptr = groupedGemmPtr->MakeInvokerPointer();
145+
145146
auto argument_ptr = groupedGemmPtr->MakeArgumentPointer(
146147
p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op);
147148

149+
DeviceMem gemm_desc_workspace(groupedGemmPtr->GetWorkSpaceSize(argument_ptr.get()));
150+
151+
groupedGemmPtr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());
152+
148153
invoker_ptr->Run(argument_ptr.get());
149154

150155
for(std::size_t i = 0; i < gemm_shapes.size(); i++)

0 commit comments

Comments (0)