@@ -62,23 +62,39 @@ constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HasMainKBloc
62
62
// Compile-time flag: the value of CK_PARAM_HasDoubleTailKBlockLoop converted to bool.
// NOTE(review): CK_PARAM_HasDoubleTailKBlockLoop is not defined in this part of the
// file; presumably it is supplied as a build-time definition -- confirm.
constexpr bool HasDoubleTailKBlockLoop{static_cast<bool>(CK_PARAM_HasDoubleTailKBlockLoop)};
63
63
64
64
extern " C" __global__ void
65
- convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare (index_t N ,
66
- index_t C ,
67
- index_t Hi ,
68
- index_t Wi ,
69
- index_t K ,
70
- index_t Y ,
71
- index_t X ,
72
- index_t ConvStrideH ,
73
- index_t ConvStrideW ,
74
- index_t ConvDilationH ,
75
- index_t ConvDilationW ,
76
- index_t InLeftPadH ,
77
- index_t InLeftPadW ,
78
- index_t InRightPadH ,
79
- index_t InRightPadW ,
65
+ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare (int N_ ,
66
+ int C_ ,
67
+ int Hi_ ,
68
+ int Wi_ ,
69
+ int K_ ,
70
+ int Y_ ,
71
+ int X_ ,
72
+ int ConvStrideH_ ,
73
+ int ConvStrideW_ ,
74
+ int ConvDilationH_ ,
75
+ int ConvDilationW_ ,
76
+ int InLeftPadH_ ,
77
+ int InLeftPadW_ ,
78
+ int InRightPadH_ ,
79
+ int InRightPadW_ ,
80
80
void * p_desc_tuple)
81
81
{
82
+ index_t N = static_cast <index_t >(N_);
83
+ index_t C = static_cast <index_t >(C_);
84
+ index_t Hi = static_cast <index_t >(Hi_);
85
+ index_t Wi = static_cast <index_t >(Wi_);
86
+ index_t K = static_cast <index_t >(K_);
87
+ index_t Y = static_cast <index_t >(Y_);
88
+ index_t X = static_cast <index_t >(X_);
89
+ index_t ConvStrideH = static_cast <index_t >(ConvStrideH_);
90
+ index_t ConvStrideW = static_cast <index_t >(ConvStrideW_);
91
+ index_t ConvDilationH = static_cast <index_t >(ConvDilationH_);
92
+ index_t ConvDilationW = static_cast <index_t >(ConvDilationW_);
93
+ index_t InLeftPadH = static_cast <index_t >(InLeftPadH_);
94
+ index_t InLeftPadW = static_cast <index_t >(InLeftPadW_);
95
+ index_t InRightPadH = static_cast <index_t >(InRightPadH_);
96
+ index_t InRightPadW = static_cast <index_t >(InRightPadW_);
97
+
82
98
constexpr auto I0 = Number<0 >{};
83
99
constexpr auto I1 = Number<1 >{};
84
100
constexpr auto I2 = Number<2 >{};
0 commit comments