Skip to content

Commit 56adf7e

Browse files
author
Chao Liu
authored
GEMM with Multiple Source, GEMM+Bias+Add+FastGeLU example and ckProfiler (pytorch#241)
* ad gelu and fast_gelu * added GeLU and fast GeLU * clean up * add gemm+fastgelu example * add gemm+gelu instances * update profiler * clean up * clean up * adding gemm+bias+activation * clean * adding bias * clean * adding gemm multiple d * debugging * add gemm bias add fastgelu * rename, clean * refactoring; add readme * refactor * refactor * refactor * refactor * refactor * refactor * fix * fix * update example * update example * rename * update example * add ckProfiler * clean * clean * clean * clean * add comment * use type_convert * clean * clean element wise op
1 parent e4584d9 commit 56adf7e

File tree

41 files changed

+3358
-517
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+3358
-517
lines changed

example/01_gemm/gemm_xdl_fp16.cpp

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,28 +27,29 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
2727

2828
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
2929

30-
using ADataType = ck::half_t;
31-
using BDataType = ck::half_t;
32-
using CDataType = ck::half_t;
33-
using AccDataType = float;
30+
using ADataType = F16;
31+
using BDataType = F16;
32+
using AccDataType = F32;
33+
using CShuffleDataType = F32;
34+
using CDataType = F16;
3435

35-
using ALayout = ck::tensor_layout::gemm::RowMajor;
36-
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
37-
using CLayout = ck::tensor_layout::gemm::RowMajor;
36+
using ALayout = Row;
37+
using BLayout = Col;
38+
using CLayout = Row;
3839

39-
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
40-
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
41-
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
40+
using AElementOp = PassThrough;
41+
using BElementOp = PassThrough;
42+
using CElementOp = PassThrough;
4243

4344
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
4445

4546
// clang-format off
4647
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
47-
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
48-
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
49-
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
50-
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
51-
< Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
48+
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
49+
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
50+
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
51+
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
52+
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
5253
// clang-format on
5354

5455
using ReferenceGemmInstance = ck::tensor_operation::host::
@@ -69,7 +70,11 @@ int main(int argc, char* argv[])
6970
ck::index_t StrideB = 4096;
7071
ck::index_t StrideC = 4096;
7172

72-
if(argc == 4)
73+
if(argc == 1)
74+
{
75+
// no command-line arguments: keep the default values initialized above
76+
}
77+
else if(argc == 4)
7378
{
7479
do_verification = std::stoi(argv[1]);
7580
init_method = std::stoi(argv[2]);
@@ -93,7 +98,7 @@ int main(int argc, char* argv[])
9398
{
9499
printf("arg1: verification (0=no, 1=yes)\n");
95100
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
96-
printf("arg3: time kernel (0=n0, 1=yes)\n");
101+
printf("arg3: time kernel (0=no, 1=yes)\n");
97102
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
98103
exit(0);
99104
}

0 commit comments

Comments
 (0)