Skip to content

Commit 16effa7

Browse files
author
Chao Liu
committed
refactor
1 parent a91b68d commit 16effa7

19 files changed

+99
-91
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ message(STATUS "Build with HIP ${hip_VERSION}")
4343
message("HALF_INCLUDE_DIR: ${HALF_INCLUDE_DIR}")
4444

4545
# CMAKE_CXX_FLAGS
46+
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
4647
if(BUILD_DEV)
4748
string(APPEND CMAKE_CXX_FLAGS " -Werror -Weverything")
4849
endif()

composable_kernel/include/tensor_description/multi_index_transform.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ struct RightPad
377377
// at compile-time
378378
template <typename UpLengths,
379379
typename Coefficients,
380-
typename std::enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
380+
typename enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
381381
struct Embed
382382
{
383383
static constexpr index_t NDimUp = UpLengths::Size();

composable_kernel/include/tensor_description/multi_index_transform_helper.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ __host__ __device__ constexpr auto make_right_pad_transform(
4242

4343
template <typename UpLengths,
4444
typename Coefficients,
45-
typename std::enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
45+
typename enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
4646
__host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_lengths,
4747
const Coefficients& coefficients)
4848
{

composable_kernel/include/tensor_description/tensor_adaptor.hpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -454,9 +454,7 @@ __host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transf
454454
remove_cv_t<decltype(top_dim_hidden_ids)>>{transforms};
455455
}
456456

457-
template <typename X,
458-
typename... Xs,
459-
typename std::enable_if<sizeof...(Xs) >= 2, bool>::type = false>
457+
template <typename X, typename... Xs, typename enable_if<sizeof...(Xs) >= 2, bool>::type = false>
460458
__host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs)
461459
{
462460
return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...));

composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ __host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengt
3737

3838
template <typename... Lengths,
3939
typename... Strides,
40-
typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
40+
typename enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
4141
__host__ __device__ constexpr auto make_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
4242
const Tuple<Strides...>& strides)
4343
{

composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r2.hpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,24 +22,24 @@ namespace ck {
2222
// 2. CThreadBuffer is StaticBuffer
2323
// Also assume:
2424
// M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
25-
template <index_t BlockSize,
26-
typename FloatA,
27-
typename FloatB,
28-
typename FloatC,
29-
typename AKMBlockDesc,
30-
typename BKNBlockDesc,
31-
index_t M1PerThreadM11,
32-
index_t N1PerThreadN11,
33-
index_t KPerThread,
34-
index_t M1N1ThreadClusterM100,
35-
index_t M1N1ThreadClusterN100,
36-
index_t M1N1ThreadClusterM101,
37-
index_t M1N1ThreadClusterN101,
38-
index_t AThreadCopyScalarPerVector_M11,
39-
index_t BThreadCopyScalarPerVector_N11,
40-
typename std::enable_if<AKMBlockDesc::IsKnownAtCompileTime() &&
41-
BKNBlockDesc::IsKnownAtCompileTime(),
42-
bool>::type = false>
25+
template <
26+
index_t BlockSize,
27+
typename FloatA,
28+
typename FloatB,
29+
typename FloatC,
30+
typename AKMBlockDesc,
31+
typename BKNBlockDesc,
32+
index_t M1PerThreadM11,
33+
index_t N1PerThreadN11,
34+
index_t KPerThread,
35+
index_t M1N1ThreadClusterM100,
36+
index_t M1N1ThreadClusterN100,
37+
index_t M1N1ThreadClusterM101,
38+
index_t M1N1ThreadClusterN101,
39+
index_t AThreadCopyScalarPerVector_M11,
40+
index_t BThreadCopyScalarPerVector_N11,
41+
typename enable_if<AKMBlockDesc::IsKnownAtCompileTime() && BKNBlockDesc::IsKnownAtCompileTime(),
42+
bool>::type = false>
4343
struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
4444
{
4545
using AIndex = MultiIndex<3>;

composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r3.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ template <index_t BlockSize,
3838
// BM10BN10ThreadClusterBN101, ...>
3939
index_t AThreadCopyScalarPerVector_BM11,
4040
index_t BThreadCopyScalarPerVector_BN11,
41-
typename std::enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
42-
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
43-
bool>::type = false>
41+
typename enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
42+
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
43+
bool>::type = false>
4444
struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
4545
{
4646
using AIndex = MultiIndex<3>;

composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@ template <typename FloatA,
2121
typename TKLengths,
2222
typename TMLengths,
2323
typename TNLengths,
24-
typename std::enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
25-
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
26-
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
27-
bool>::type = false>
24+
typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
25+
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
26+
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
27+
bool>::type = false>
2828
struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1
2929
{
3030
__device__ constexpr ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1()
@@ -123,10 +123,10 @@ template <typename FloatA,
123123
typename TKLengths,
124124
typename TMLengths,
125125
typename TNLengths,
126-
typename std::enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
127-
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
128-
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
129-
bool>::type = false>
126+
typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
127+
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
128+
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
129+
bool>::type = false>
130130
struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
131131
{
132132
__device__ constexpr ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()

composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ template <typename FloatA,
1919
typename CDesc,
2020
index_t H,
2121
index_t W,
22-
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
23-
CDesc::IsKnownAtCompileTime(),
24-
bool>::type = false>
22+
typename enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
23+
CDesc::IsKnownAtCompileTime(),
24+
bool>::type = false>
2525
struct ThreadwiseGemmDlops_km_kn_mn_v3
2626
{
2727
template <typename ABuffer,

composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ namespace ck {
1515
template <typename Data,
1616
typename Desc,
1717
typename SliceLengths,
18-
typename std::enable_if<Desc::IsKnownAtCompileTime(), bool>::type = false>
18+
typename enable_if<Desc::IsKnownAtCompileTime(), bool>::type = false>
1919
struct ThreadwiseTensorSliceSet_v1
2020
{
2121
static constexpr index_t nDim = SliceLengths::Size();

composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ template <typename SrcData,
5757
InMemoryDataOperationEnum_t DstInMemOp,
5858
index_t DstScalarStrideInVector,
5959
bool DstResetCoordinateAfterRun,
60-
typename std::enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
60+
typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
6161
struct ThreadwiseTensorSliceTransfer_v1r3
6262
{
6363
static constexpr index_t nDim = SliceLengths::Size();
@@ -373,7 +373,7 @@ template <typename SrcData,
373373
index_t SrcScalarPerVector,
374374
index_t SrcScalarStrideInVector,
375375
bool SrcResetCoordinateAfterRun,
376-
typename std::enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
376+
typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
377377
struct ThreadwiseTensorSliceTransfer_v2
378378
{
379379
static constexpr index_t nDim = SliceLengths::Size();
@@ -1261,18 +1261,17 @@ struct ThreadwiseTensorSliceTransfer_v3
12611261
// 3. DstOriginIdx is known at compile-time
12621262
// 4. use direct address calculation
12631263
// 3. vector access on src
1264-
template <
1265-
typename SrcData,
1266-
typename DstData,
1267-
typename SrcDesc,
1268-
typename DstDesc,
1269-
typename SliceLengths,
1270-
typename DimAccessOrder,
1271-
index_t SrcVectorDim,
1272-
index_t SrcScalarPerVector,
1273-
index_t SrcScalarStrideInVector,
1274-
typename std::enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
1275-
bool>::type = false>
1264+
template <typename SrcData,
1265+
typename DstData,
1266+
typename SrcDesc,
1267+
typename DstDesc,
1268+
typename SliceLengths,
1269+
typename DimAccessOrder,
1270+
index_t SrcVectorDim,
1271+
index_t SrcScalarPerVector,
1272+
index_t SrcScalarStrideInVector,
1273+
typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
1274+
bool>::type = false>
12761275
struct ThreadwiseTensorSliceTransfer_v4
12771276
{
12781277
static constexpr index_t nDim = SliceLengths::Size();

composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -621,17 +621,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1
621621
// 3. DstOriginIdx is known at compile-time
622622
// 4. use direct address calculation
623623
// 3. vector access on src
624-
template <
625-
typename SrcData,
626-
typename DstData,
627-
typename SrcDesc,
628-
typename DstDesc,
629-
typename SliceLengths,
630-
typename DimAccessOrder,
631-
typename SrcVectorTensorLengths,
632-
typename SrcVectorTensorContiguousDimOrder,
633-
typename std::enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
634-
bool>::type = false>
624+
template <typename SrcData,
625+
typename DstData,
626+
typename SrcDesc,
627+
typename DstDesc,
628+
typename SliceLengths,
629+
typename DimAccessOrder,
630+
typename SrcVectorTensorLengths,
631+
typename SrcVectorTensorContiguousDimOrder,
632+
typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
633+
bool>::type = false>
635634
struct ThreadwiseTensorSliceTransfer_v4r1
636635
{
637636
static constexpr auto I0 = Number<0>{};

composable_kernel/include/utility/c_style_pointer_cast.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
#define CK_C_STYLE_POINTER_CAST_HPP
33

44
#include "type.hpp"
5+
#include "enable_if.hpp"
56

67
namespace ck {
78

89
template <typename PY,
910
typename PX,
10-
typename std::enable_if<is_pointer_v<PY> && is_pointer_v<PX>, bool>::type = false>
11+
typename enable_if<is_pointer_v<PY> && is_pointer_v<PX>, bool>::type = false>
1112
__host__ __device__ PY c_style_pointer_cast(PX p_x)
1213
{
1314
#pragma clang diagnostic push

composable_kernel/include/utility/common_header.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "functional2.hpp"
1515
#include "functional3.hpp"
1616
#include "functional4.hpp"
17+
#include "enable_if.hpp"
1718
#include "integral_constant.hpp"
1819
#include "math.hpp"
1920
#include "number.hpp"

composable_kernel/include/utility/dynamic_buffer.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include "amd_buffer_addressing.hpp"
55
#include "c_style_pointer_cast.hpp"
6+
#include "enable_if.hpp"
67

78
namespace ck {
89

@@ -38,7 +39,7 @@ struct DynamicBuffer
3839
}
3940

4041
template <typename X,
41-
typename std::enable_if<
42+
typename enable_if<
4243
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
4344
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
4445
bool>::type = false>
@@ -93,7 +94,7 @@ struct DynamicBuffer
9394
}
9495

9596
template <typename X,
96-
typename std::enable_if<
97+
typename enable_if<
9798
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
9899
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
99100
bool>::type = false>

composable_kernel/include/utility/enable_if.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#ifndef CK_ENABLE_IF_HPP
2+
#define CK_ENABLE_IF_HPP
3+
4+
namespace ck {
5+
6+
template <bool B, typename T = void>
7+
using enable_if = std::enable_if<B, T>;
8+
9+
template <bool B, typename T = void>
10+
using enable_if_t = typename std::enable_if<B, T>::type;
11+
12+
} // namespace ck
13+
#endif

composable_kernel/include/utility/math.hpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "integral_constant.hpp"
66
#include "number.hpp"
77
#include "type.hpp"
8+
#include "enable_if.hpp"
89

910
namespace ck {
1011
namespace math {
@@ -184,9 +185,7 @@ __host__ __device__ constexpr auto gcd(Number<X>, Number<Y>)
184185
return Number<r>{};
185186
}
186187

187-
template <typename X,
188-
typename... Ys,
189-
typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
188+
template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
190189
__host__ __device__ constexpr auto gcd(X x, Ys... ys)
191190
{
192191
return gcd(x, gcd(ys...));
@@ -199,9 +198,7 @@ __host__ __device__ constexpr auto lcm(X x, Y y)
199198
return (x * y) / gcd(x, y);
200199
}
201200

202-
template <typename X,
203-
typename... Ys,
204-
typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
201+
template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
205202
__host__ __device__ constexpr auto lcm(X x, Ys... ys)
206203
{
207204
return lcm(x, lcm(ys...));

composable_kernel/include/utility/tuple.hpp

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "integral_constant.hpp"
55
#include "sequence.hpp"
66
#include "type.hpp"
7+
#include "enable_if.hpp"
78

89
namespace ck {
910

@@ -20,10 +21,9 @@ struct TupleElement
2021
{
2122
__host__ __device__ constexpr TupleElement() = default;
2223

23-
template <
24-
typename T,
25-
typename std::enable_if<!is_same<remove_reference_t<remove_cv_t<T>>, TupleElement>::value,
26-
bool>::type = false>
24+
template <typename T,
25+
typename enable_if<!is_same<remove_reference_t<remove_cv_t<T>>, TupleElement>::value,
26+
bool>::type = false>
2727
__host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward<T>(v))
2828
{
2929
}
@@ -58,17 +58,16 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
5858
{
5959
__host__ __device__ constexpr TupleImpl() = default;
6060

61-
template <
62-
typename Y,
63-
typename std::enable_if<sizeof...(Is) == 1 && sizeof...(Xs) == 1 &&
64-
!is_same<remove_reference_t<remove_cv_t<Y>>, TupleImpl>::value,
65-
bool>::type = false>
61+
template <typename Y,
62+
typename enable_if<sizeof...(Is) == 1 && sizeof...(Xs) == 1 &&
63+
!is_same<remove_reference_t<remove_cv_t<Y>>, TupleImpl>::value,
64+
bool>::type = false>
6665
__host__ __device__ constexpr TupleImpl(Y&& y)
6766
: TupleElement<TupleElementKey<Is>, Xs>(std::forward<Y>(y))...
6867
{
6968
}
7069

71-
template <typename... Ys, typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
70+
template <typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
7271
__host__ __device__ constexpr TupleImpl(Ys&&... ys)
7372
: TupleElement<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))...
7473
{
@@ -102,16 +101,16 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
102101
__host__ __device__ constexpr Tuple() = default;
103102

104103
template <typename Y,
105-
typename std::enable_if<
106-
sizeof...(Xs) == 1 && !is_same<remove_reference_t<remove_cv_t<Y>>, Tuple>::value,
107-
bool>::type = false>
104+
typename enable_if<sizeof...(Xs) == 1 &&
105+
!is_same<remove_reference_t<remove_cv_t<Y>>, Tuple>::value,
106+
bool>::type = false>
108107
__host__ __device__ constexpr Tuple(Y&& y) : base(std::forward<Y>(y))
109108
{
110109
}
111110

112111
template <typename... Ys,
113-
typename std::enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2,
114-
bool>::type = false>
112+
typename enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2, bool>::type =
113+
false>
115114
__host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward<Ys>(ys)...)
116115
{
117116
}

0 commit comments

Comments
 (0)