You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
static void concat_f32_cuda(const float * x, const float * y, float * dst, const int ne0, int ne1, int ne2, int ne02, cudaStream_t stream) {
81
+
static void concat_f32_cuda(const float * x, const float * y, float * dst, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, int dim, cudaStream_t stream) {
29
82
int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;
30
83
dim3 gridDim(num_blocks, ne1, ne2);
31
-
concat_f32<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
84
+
if (dim == 0) {
85
+
concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne00);
86
+
return;
87
+
}
88
+
if (dim == 1) {
89
+
concat_f32_dim1<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne01);
90
+
return;
91
+
}
92
+
concat_f32_dim2<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
0 commit comments