@@ -16,6 +16,31 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 
+/*
+ * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
+ *
+ * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of the A, B, C
+ * matrices for a given batch index. For example, ComputePtrOffsetOfStridedBatch() computes the
+ * offsets of evenly strided batches, but it can easily be extended to other layouts. The returned
+ * offset can be either \p index_t or \p long_index_t. If it returns \p long_index_t, we are not
+ * subject to the 2 GB limitation.
+ *
+ * \tparam Block2CTileMap Block2CTileMap::CalculateBottomIndex() takes the id of a workgroup and
+ * returns the 2D index of the tile that it computes. \see
+ * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
+ *
+ * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that two workgroups can compute
+ * two tiles from different matrices. Keep in mind that these two matrices can share the same grid
+ * descriptor (as in BatchedGEMM) or use their own grid descriptors (as in GroupedGemm). \link
+ * device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link
+ * DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the pointer
+ * offset computation in \p ComputePtrOffsetOfStridedBatch.
+ *
+ * \note \p Block2CTileMap allows a customized mapping between a workgroup and the C-tile it
+ * computes. Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm
+ * fusions) to realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusions).
+ *
+ */
 template <typename GridwiseGemm,
           typename FloatAB,
           typename FloatC,
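For readers skimming the new Doxygen comment above, a small host-side sketch of the ComputePtrOffsetOfBatch concept may help. The type aliases and the demo functor below are illustrative assumptions that only mirror the interface the kernel expects (GetAPtrOffset/GetBPtrOffset/GetCPtrOffset returning a long_index_t); they are not the library's definitions.

```cpp
#include <cstdint>
#include <cstdio>

using index_t      = int32_t; // assumption: stands in for ck::index_t
using long_index_t = int64_t; // assumption: stands in for ck::long_index_t

// Evenly strided batches: the offset of batch g is simply g * BatchStride.
// Returning long_index_t keeps offsets beyond the 2 GB range representable.
struct ComputePtrOffsetOfStridedBatchSketch
{
    index_t BatchStrideA_, BatchStrideB_, BatchStrideC_;

    long_index_t GetAPtrOffset(index_t g) const { return g * static_cast<long_index_t>(BatchStrideA_); }
    long_index_t GetBPtrOffset(index_t g) const { return g * static_cast<long_index_t>(BatchStrideB_); }
    long_index_t GetCPtrOffset(index_t g) const { return g * static_cast<long_index_t>(BatchStrideC_); }
};

int main()
{
    // 1024 x 1024 matrices: 1,048,576 elements of A per batch. Batch 4096 is
    // already past 2^31 elements, which a 32-bit offset could not represent.
    const ComputePtrOffsetOfStridedBatchSketch offsets{1024 * 1024, 1024 * 1024, 1024 * 1024};
    for(index_t g : {0, 1, 4096})
        std::printf("batch %d: A offset = %lld elements\n", g, (long long)offsets.GetAPtrOffset(g));
}
```

Only the three getters matter to the kernel; a GroupedGemm-style layout could implement the same interface with precomputed per-group offsets instead of a single stride.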
@@ -25,7 +50,7 @@ template <typename GridwiseGemm,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
-          typename ComputeBasePrtOfBatch,
+          typename ComputePtrOffsetOfBatch,
           typename Block2CTileMap,
           bool HasMainKBlockLoop>
 __global__ void
@@ -43,7 +68,7 @@ __global__ void
         const AElementwiseOperation a_element_op,
         const BElementwiseOperation b_element_op,
         const CElementwiseOperation c_element_op,
-        const ComputeBasePrtOfBatch compute_base_ptr_of_batch_,
+        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
         const Block2CTileMap block_2_ctile_map)
 {
 #if (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
@@ -52,11 +77,11 @@ __global__ void
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
 
     const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
     const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_base_ptr_of_batch_.GetBBasePtr(g_idx)));
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
     const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
 
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
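The hunk above reads the batch index from the 1-D block id: the grid contains batch_count * num_blocks_per_batch workgroups, and a single integer division recovers which batch a workgroup belongs to. A host-side walk-through of that decomposition, using hypothetical variable names rather than the kernel's actual launch code:

```cpp
#include <cstdint>
#include <cstdio>

int main()
{
    const int64_t batch_count          = 4;
    const int64_t num_blocks_per_batch = 8;       // grid_size / batch_count in the real launch
    const int64_t batch_stride_a       = 1 << 20; // elements of A per batch (illustrative)

    for(int64_t block_id = 0; block_id < batch_count * num_blocks_per_batch; ++block_id)
    {
        const int64_t g_idx          = block_id / num_blocks_per_batch; // batch this workgroup handles
        const int64_t a_batch_offset = g_idx * batch_stride_a;          // what GetAPtrOffset(g_idx) returns
        if(block_id % num_blocks_per_batch == 0)
            std::printf("blocks %lld..%lld -> batch %lld, A element offset %lld\n",
                        (long long)block_id,
                        (long long)(block_id + num_blocks_per_batch - 1),
                        (long long)g_idx,
                        (long long)a_batch_offset);
    }
}
```

The __builtin_amdgcn_readfirstlane wrapping in the real kernel is a scalarization hint: every lane in the wavefront computes the same batch index and offsets, so broadcasting them from the first lane lets the compiler keep them in scalar registers.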
@@ -256,26 +281,26 @@ struct DeviceBatchedGemmXdl
         return globalblockid_to_m0_n0_block_cluster_adaptor;
     }
 
-    struct ComputeBasePtrOfStridedBatch
+    struct ComputePtrOffsetOfStridedBatch
     {
-        ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
-                                     index_t BatchStrideB,
-                                     index_t BatchStrideC)
+        ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
+                                       index_t BatchStrideB,
+                                       index_t BatchStrideC)
             : BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideC_(BatchStrideC)
         {
         }
 
-        __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
+        __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
         {
             return g_idx * static_cast<long_index_t>(BatchStrideA_);
         }
 
-        __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
+        __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
         {
             return g_idx * static_cast<long_index_t>(BatchStrideB_);
         }
 
-        __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
+        __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
         {
             return g_idx * static_cast<long_index_t>(BatchStrideC_);
         }
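The new Doxygen comment notes that layouts other than an even stride can satisfy the same interface (e.g. GroupedGemm, where groups have different sizes). A hypothetical sketch of such a variant; the name and members below are illustrative and do not come from this repository:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

using index_t      = int32_t;
using long_index_t = int64_t;

// One precomputed element offset per group, so groups need not be evenly spaced.
// (Host-side illustration only: a device-usable version would hold raw arrays.)
struct ComputePtrOffsetOfGroupedSketch
{
    std::vector<long_index_t> a_offsets_, b_offsets_, c_offsets_;

    long_index_t GetAPtrOffset(index_t g) const { return a_offsets_[g]; }
    long_index_t GetBPtrOffset(index_t g) const { return b_offsets_[g]; }
    long_index_t GetCPtrOffset(index_t g) const { return c_offsets_[g]; }
};

int main()
{
    // Offsets are running sums of (irregular) per-group matrix sizes.
    const ComputePtrOffsetOfGroupedSketch grouped{{0, 4096, 4096 + 65536},
                                                  {0, 2048, 2048 + 32768},
                                                  {0, 1024, 1024 + 16384}};
    std::printf("group 2: A element offset = %lld\n", (long long)grouped.GetAPtrOffset(2));
}
```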
@@ -359,9 +384,9 @@ struct DeviceBatchedGemmXdl
                   DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)},
               c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)},
               c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
-              compute_base_ptr_of_batch_{a_grid_desc_k0_m_k1_.GetElementSpaceSize(),
-                                         b_grid_desc_k0_n_k1_.GetElementSpaceSize(),
-                                         c_grid_desc_m_n_.GetElementSpaceSize()},
+              compute_ptr_offset_of_batch_{a_grid_desc_k0_m_k1_.GetElementSpaceSize(),
+                                           b_grid_desc_k0_n_k1_.GetElementSpaceSize(),
+                                           c_grid_desc_m_n_.GetElementSpaceSize()},
               block_2_ctile_map_{},
               M01_{M01},
               N01_{N01},
@@ -388,7 +413,7 @@ struct DeviceBatchedGemmXdl
         BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
         CGridDesc_M_N c_grid_desc_m_n_;
         CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
-        ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
+        ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
         Block2CTileMap block_2_ctile_map_;
         index_t M01_;
         index_t N01_;
@@ -448,7 +473,7 @@ struct DeviceBatchedGemmXdl
                 AElementwiseOperation,
                 BElementwiseOperation,
                 CElementwiseOperation,
-                ComputeBasePtrOfStridedBatch,
+                ComputePtrOffsetOfStridedBatch,
                 remove_reference_t<Block2CTileMap>,
                 true>;
 
@@ -467,7 +492,7 @@ struct DeviceBatchedGemmXdl
                 arg.a_element_op_,
                 arg.b_element_op_,
                 arg.c_element_op_,
-                arg.compute_base_ptr_of_batch_,
+                arg.compute_ptr_offset_of_batch_,
                 arg.block_2_ctile_map_);
         }
         else
@@ -482,7 +507,7 @@ struct DeviceBatchedGemmXdl
                 AElementwiseOperation,
                 BElementwiseOperation,
                 CElementwiseOperation,
-                ComputeBasePtrOfStridedBatch,
+                ComputePtrOffsetOfStridedBatch,
                 remove_reference_t<Block2CTileMap>,
                 false>;
 
@@ -501,7 +526,7 @@ struct DeviceBatchedGemmXdl
                 arg.a_element_op_,
                 arg.b_element_op_,
                 arg.c_element_op_,
-                arg.compute_base_ptr_of_batch_,
+                arg.compute_ptr_offset_of_batch_,
                 arg.block_2_ctile_map_);
         }
 