
Commit 3956085

add comments to batched_gemm (pytorch#186)
* add comments to batched_gemm
* formatting
* fix a typo in batched_gemm_documentation
* fix naming
1 parent 7c0b149 commit 3956085

File tree: 2 files changed, +48 -20 lines changed


include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp

Lines changed: 45 additions & 20 deletions
```diff
@@ -16,6 +16,31 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 
+/*
+ * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
+ *
+ * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
+ * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
+ * strided batches, but we can easily extend to other layouts. The returned offset can be either \p
+ * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
+ * limitations.
+ *
+ * \tparam Block2CTileMap Block2CTileMap::CalculateBottomIndex() takes in the id of a workgroup and
+ * returns the 2D index of the tile that it computes. \see
+ * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
+ *
+ * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
+ * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
+ * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
+ * device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link
+ * DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
+ * pointer offset into \p ComputePtrOffsetOfStridedBatch.
+ *
+ * \note \p Block2CTileMap allows customized mapping between a workgroup and the C-tile it computes.
+ * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion) to
+ * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
+ *
+ */
 template <typename GridwiseGemm,
           typename FloatAB,
           typename FloatC,
```
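The comment above notes that the offset computation "can easily extend to other layouts". As a rough sketch of what such an extension could look like (this functor and its table-based layout are hypothetical, not part of this commit or of ck), any type exposing the same three getters can be supplied for the `ComputePtrOffsetOfBatch` template parameter. The typedefs stand in for ck's, and the `__host__ __device__` qualifiers assume a HIP/CUDA compiler:

```cpp
#include <cstdint>

using index_t      = int32_t; // stand-ins for ck's integer typedefs
using long_index_t = int64_t;

// Hypothetical functor: per-batch offsets looked up from host-filled tables
// instead of derived from one uniform stride, e.g. for irregularly laid-out
// batches. It satisfies the same three-getter interface the kernel expects.
struct ComputePtrOffsetOfBatchFromTable
{
    const long_index_t* a_offsets_; // one entry per batch
    const long_index_t* b_offsets_;
    const long_index_t* c_offsets_;

    __host__ __device__ long_index_t GetAPtrOffset(index_t g_idx) const { return a_offsets_[g_idx]; }
    __host__ __device__ long_index_t GetBPtrOffset(index_t g_idx) const { return b_offsets_[g_idx]; }
    __host__ __device__ long_index_t GetCPtrOffset(index_t g_idx) const { return c_offsets_[g_idx]; }
};
```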
```diff
@@ -25,7 +50,7 @@ template <typename GridwiseGemm,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
-          typename ComputeBasePrtOfBatch,
+          typename ComputePtrOffsetOfBatch,
           typename Block2CTileMap,
           bool HasMainKBlockLoop>
 __global__ void
@@ -43,7 +68,7 @@ __global__ void
         const AElementwiseOperation a_element_op,
         const BElementwiseOperation b_element_op,
         const CElementwiseOperation c_element_op,
-        const ComputeBasePrtOfBatch compute_base_ptr_of_batch_,
+        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
         const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
@@ -52,11 +77,11 @@ __global__ void
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
 
     const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
     const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_base_ptr_of_batch_.GetBBasePtr(g_idx)));
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
     const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
 
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
```
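In the hunk above, `g_idx` is the batch a workgroup serves: its 1D block id divided by the number of workgroups each batch's C matrix needs (`__builtin_amdgcn_readfirstlane` broadcasts the result from the first active lane so it is treated as wavefront-uniform). A minimal host-side sketch of the same mapping arithmetic, with made-up sizes:

```cpp
#include <cstdio>

int main()
{
    const int grid_size            = 96; // total workgroups launched (made up)
    const int batch_count          = 4;
    const int num_blocks_per_batch = grid_size / batch_count; // 24

    // Same integer division as in the kernel: consecutive runs of
    // workgroups map to consecutive batches.
    for(int block_id = 0; block_id < grid_size; block_id += num_blocks_per_batch)
    {
        const int g_idx = block_id / num_blocks_per_batch;
        std::printf("workgroups %2d..%2d -> batch %d\n",
                    block_id, block_id + num_blocks_per_batch - 1, g_idx);
    }
    return 0;
}
```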

```diff
@@ -256,26 +281,26 @@ struct DeviceBatchedGemmXdl
         return globalblockid_to_m0_n0_block_cluster_adaptor;
     }
 
-    struct ComputeBasePtrOfStridedBatch
+    struct ComputePtrOffsetOfStridedBatch
     {
-        ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
-                                     index_t BatchStrideB,
-                                     index_t BatchStrideC)
+        ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
+                                       index_t BatchStrideB,
+                                       index_t BatchStrideC)
             : BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideC_(BatchStrideC)
         {
         }
 
-        __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
+        __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
         {
             return g_idx * static_cast<long_index_t>(BatchStrideA_);
         }
 
-        __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
+        __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
         {
             return g_idx * static_cast<long_index_t>(BatchStrideB_);
         }
 
-        __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
+        __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
         {
             return g_idx * static_cast<long_index_t>(BatchStrideC_);
         }
```
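The renamed getters keep the property the file-level comment calls out: one operand is cast to `long_index_t` before the multiply, so the product is computed in 64 bits and large per-batch offsets no longer wrap in 32-bit arithmetic, which is the "2GB limitations" the comment refers to. A standalone sketch (the typedefs mirror ck's; the stride value is made up):

```cpp
#include <cstdint>
#include <cstdio>

using index_t      = int32_t;
using long_index_t = int64_t;

// Same shape as ComputePtrOffsetOfStridedBatch::GetAPtrOffset above.
long_index_t get_ptr_offset(index_t g_idx, index_t batch_stride)
{
    return g_idx * static_cast<long_index_t>(batch_stride); // 64-bit product
}

int main()
{
    const index_t batch_stride = 1 << 30; // ~1G elements per batch (made up)
    const index_t g_idx        = 7;

    // 7 * 2^30 = 7516192768 elements: would wrap in 32-bit arithmetic,
    // but is exact in long_index_t.
    std::printf("%lld\n", static_cast<long long>(get_ptr_offset(g_idx, batch_stride)));
    return 0;
}
```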
```diff
@@ -359,9 +384,9 @@ struct DeviceBatchedGemmXdl
               DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)},
           c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)},
           c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
-          compute_base_ptr_of_batch_{a_grid_desc_k0_m_k1_.GetElementSpaceSize(),
-                                     b_grid_desc_k0_n_k1_.GetElementSpaceSize(),
-                                     c_grid_desc_m_n_.GetElementSpaceSize()},
+          compute_ptr_offset_of_batch_{a_grid_desc_k0_m_k1_.GetElementSpaceSize(),
+                                       b_grid_desc_k0_n_k1_.GetElementSpaceSize(),
+                                       c_grid_desc_m_n_.GetElementSpaceSize()},
           block_2_ctile_map_{},
           M01_{M01},
           N01_{N01},
```
```diff
@@ -388,7 +413,7 @@ struct DeviceBatchedGemmXdl
     BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
     CGridDesc_M_N c_grid_desc_m_n_;
     CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
-    ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
+    ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
     Block2CTileMap block_2_ctile_map_;
     index_t M01_;
     index_t N01_;
@@ -448,7 +473,7 @@ struct DeviceBatchedGemmXdl
                 AElementwiseOperation,
                 BElementwiseOperation,
                 CElementwiseOperation,
-                ComputeBasePtrOfStridedBatch,
+                ComputePtrOffsetOfStridedBatch,
                 remove_reference_t<Block2CTileMap>,
                 true>;
 
@@ -467,7 +492,7 @@ struct DeviceBatchedGemmXdl
                 arg.a_element_op_,
                 arg.b_element_op_,
                 arg.c_element_op_,
-                arg.compute_base_ptr_of_batch_,
+                arg.compute_ptr_offset_of_batch_,
                 arg.block_2_ctile_map_);
         }
         else
@@ -482,7 +507,7 @@ struct DeviceBatchedGemmXdl
                 AElementwiseOperation,
                 BElementwiseOperation,
                 CElementwiseOperation,
-                ComputeBasePtrOfStridedBatch,
+                ComputePtrOffsetOfStridedBatch,
                 remove_reference_t<Block2CTileMap>,
                 false>;
 
@@ -501,7 +526,7 @@ struct DeviceBatchedGemmXdl
                 arg.a_element_op_,
                 arg.b_element_op_,
                 arg.c_element_op_,
-                arg.compute_base_ptr_of_batch_,
+                arg.compute_ptr_offset_of_batch_,
                 arg.block_2_ctile_map_);
         }
```

include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp

Lines changed: 3 additions & 0 deletions
```diff
@@ -18,6 +18,9 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 
+/*
+ * \see \link device_batched_gemm_xdl.hpp kernel_batched_gemm_xdlops_v2r3() \endlink.
+ */
 template <typename GridwiseGemm,
           typename FloatAB,
           typename FloatC,
```
