Skip to content

Commit 2cbabbb

Browse files
author
Chao Liu
committed
use int instead of index_t in kernel wrapper
1 parent 0834bc7 commit 2cbabbb

File tree

1 file changed

+31
-15
lines changed

1 file changed

+31
-15
lines changed

composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -62,23 +62,39 @@ constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HasMainKBloc
6262
constexpr bool HasDoubleTailKBlockLoop = static_cast<bool>(CK_PARAM_HasDoubleTailKBlockLoop);
6363

6464
extern "C" __global__ void
65-
convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N,
66-
index_t C,
67-
index_t Hi,
68-
index_t Wi,
69-
index_t K,
70-
index_t Y,
71-
index_t X,
72-
index_t ConvStrideH,
73-
index_t ConvStrideW,
74-
index_t ConvDilationH,
75-
index_t ConvDilationW,
76-
index_t InLeftPadH,
77-
index_t InLeftPadW,
78-
index_t InRightPadH,
79-
index_t InRightPadW,
65+
convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(int N_,
66+
int C_,
67+
int Hi_,
68+
int Wi_,
69+
int K_,
70+
int Y_,
71+
int X_,
72+
int ConvStrideH_,
73+
int ConvStrideW_,
74+
int ConvDilationH_,
75+
int ConvDilationW_,
76+
int InLeftPadH_,
77+
int InLeftPadW_,
78+
int InRightPadH_,
79+
int InRightPadW_,
8080
void* p_desc_tuple)
8181
{
82+
index_t N = static_cast<index_t>(N_);
83+
index_t C = static_cast<index_t>(C_);
84+
index_t Hi = static_cast<index_t>(Hi_);
85+
index_t Wi = static_cast<index_t>(Wi_);
86+
index_t K = static_cast<index_t>(K_);
87+
index_t Y = static_cast<index_t>(Y_);
88+
index_t X = static_cast<index_t>(X_);
89+
index_t ConvStrideH = static_cast<index_t>(ConvStrideH_);
90+
index_t ConvStrideW = static_cast<index_t>(ConvStrideW_);
91+
index_t ConvDilationH = static_cast<index_t>(ConvDilationH_);
92+
index_t ConvDilationW = static_cast<index_t>(ConvDilationW_);
93+
index_t InLeftPadH = static_cast<index_t>(InLeftPadH_);
94+
index_t InLeftPadW = static_cast<index_t>(InLeftPadW_);
95+
index_t InRightPadH = static_cast<index_t>(InRightPadH_);
96+
index_t InRightPadW = static_cast<index_t>(InRightPadW_);
97+
8298
constexpr auto I0 = Number<0>{};
8399
constexpr auto I1 = Number<1>{};
84100
constexpr auto I2 = Number<2>{};

0 commit comments

Comments
 (0)