@@ -27,28 +27,29 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
27
27
28
28
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
29
29
30
- using ADataType = ck::half_t ;
31
- using BDataType = ck::half_t ;
32
- using CDataType = ck::half_t ;
33
- using AccDataType = float ;
30
+ using ADataType = F16;
31
+ using BDataType = F16;
32
+ using AccDataType = F32;
33
+ using CShuffleDataType = F32;
34
+ using CDataType = F16;
34
35
35
- using ALayout = ck::tensor_layout::gemm::RowMajor ;
36
- using BLayout = ck::tensor_layout::gemm::ColumnMajor ;
37
- using CLayout = ck::tensor_layout::gemm::RowMajor ;
36
+ using ALayout = Row ;
37
+ using BLayout = Col ;
38
+ using CLayout = Row ;
38
39
39
- using AElementOp = ck::tensor_operation::element_wise:: PassThrough;
40
- using BElementOp = ck::tensor_operation::element_wise:: PassThrough;
41
- using CElementOp = ck::tensor_operation::element_wise:: PassThrough;
40
+ using AElementOp = PassThrough;
41
+ using BElementOp = PassThrough;
42
+ using CElementOp = PassThrough;
42
43
43
44
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
44
45
45
46
// clang-format off
46
47
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
47
- // ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
48
- // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
49
- // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
50
- // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
51
- < Row, Col, Row, F16, F16, F16, F32, F32 , AElementOp, BElementOp, CElementOp, GemmDefault, 1 , 256 , 256 , 128 , 32 , 8 , 8 , 32 , 32 , 4 , 2 , S<4 , 64 , 1 >, S<1 , 0 , 2 >, S<1 , 0 , 2 >, 2 , 8 , 8 , 1 , S<4 , 64 , 1 >, S<1 , 0 , 2 >, S<1 , 0 , 2 >, 2 , 8 , 8 , 1 , 1 , 1 , S<1 , 32 , 1 , 8 >, 8 >;
48
+ // ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
49
+ // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
50
+ // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
51
+ // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
52
+ < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType , AElementOp, BElementOp, CElementOp, GemmDefault, 1 , 256 , 256 , 128 , 32 , 8 , 8 , 32 , 32 , 4 , 2 , S<4 , 64 , 1 >, S<1 , 0 , 2 >, S<1 , 0 , 2 >, 2 , 8 , 8 , 1 , S<4 , 64 , 1 >, S<1 , 0 , 2 >, S<1 , 0 , 2 >, 2 , 8 , 8 , 1 , 1 , 1 , S<1 , 32 , 1 , 8 >, 8 >;
52
53
// clang-format on
53
54
54
55
using ReferenceGemmInstance = ck::tensor_operation::host::
@@ -69,7 +70,11 @@ int main(int argc, char* argv[])
69
70
ck::index_t StrideB = 4096 ;
70
71
ck::index_t StrideC = 4096 ;
71
72
72
- if (argc == 4 )
73
+ if (argc == 1 )
74
+ {
75
+ // use default case
76
+ }
77
+ else if (argc == 4 )
73
78
{
74
79
do_verification = std::stoi (argv[1 ]);
75
80
init_method = std::stoi (argv[2 ]);
@@ -93,7 +98,7 @@ int main(int argc, char* argv[])
93
98
{
94
99
printf (" arg1: verification (0=no, 1=yes)\n " );
95
100
printf (" arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n " );
96
- printf (" arg3: time kernel (0=n0 , 1=yes)\n " );
101
+ printf (" arg3: time kernel (0=no , 1=yes)\n " );
97
102
printf (" arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n " );
98
103
exit (0 );
99
104
}
0 commit comments