 * LICENSE file in the root directory of this source tree.
 */
#include "fbgemm/Utils.h"
+#include "TransposeUtils.h"
#include <cpuinfo.h>
#include <immintrin.h>
#include <cassert>
@@ -156,16 +157,7 @@ template void printMatrix<int32_t>(
    size_t ld,
    std::string name);

-/**
- * @brief Reference implementation of matrix transposition: B = A^T.
- * @param M The height of the matrix.
- * @param N The width of the matrix.
- * @param src The memory buffer of the source matrix A.
- * @param ld_src The leading dimension of the source matrix A.
- * @param dst The memory buffer of the destination matrix B.
- * @param ld_dst The leading dimension of the destination matrix B.
- */
-inline void transpose_ref(
+void transpose_ref(
    int M,
    int N,
    const float* src,
@@ -179,161 +171,6 @@ inline void transpose_ref(
  } // for each output row
}

-inline void
-transpose_kernel_4x4_sse(const float* src, int ld_src, float* dst, int ld_dst) {
-  // load from src to registers
-  // a : a0 a1 a2 a3
-  // b : b0 b1 b2 b3
-  // c : c0 c1 c2 c3
-  // d : d0 d1 d2 d3
-  __m128 a = _mm_loadu_ps(&src[0 * ld_src]);
-  __m128 b = _mm_loadu_ps(&src[1 * ld_src]);
-  __m128 c = _mm_loadu_ps(&src[2 * ld_src]);
-  __m128 d = _mm_loadu_ps(&src[3 * ld_src]);
-
-  // transpose the 4x4 matrix formed by 32-bit elements: Macro from SSE
-  // a : a0 b0 c0 d0
-  // b : a1 b1 c1 d1
-  // c : a2 b2 c2 d2
-  // d : a3 b3 c3 d3
-  _MM_TRANSPOSE4_PS(a, b, c, d);
-
-  // store from registers to dst
-  _mm_storeu_ps(&dst[0 * ld_dst], a);
-  _mm_storeu_ps(&dst[1 * ld_dst], b);
-  _mm_storeu_ps(&dst[2 * ld_dst], c);
-  _mm_storeu_ps(&dst[3 * ld_dst], d);
-}
-inline void transpose_4x4(
-    int M,
-    int N,
-    const float* src,
-    int ld_src,
-    float* dst,
-    int ld_dst) {
-  int ib = 0, jb = 0;
-  for (ib = 0; ib + 4 <= M; ib += 4) {
-    for (jb = 0; jb + 4 <= N; jb += 4) {
-      transpose_kernel_4x4_sse(
-          &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
-    }
-  }
-  transpose_ref(ib, N - jb, &src[jb], ld_src, &dst[jb * ld_dst], ld_dst);
-  transpose_ref(M - ib, N, &src[ib * ld_src], ld_src, &dst[ib], ld_dst);
-}
-
-inline void transpose_kernel_8x8_avx2(
-    const float* src,
-    int ld_src,
-    float* dst,
-    int ld_dst) {
-  // load from src to registers
-  // a : a0 a1 a2 a3 a4 a5 a6 a7
-  // b : b0 b1 b2 b3 b4 b5 b6 b7
-  // c : c0 c1 c2 c3 c4 c5 c6 c7
-  // d : d0 d1 d2 d3 d4 d5 d6 d7
-  // e : e0 e1 e2 e3 e4 e5 e6 e7
-  // f : f0 f1 f2 f3 f4 f5 f6 f7
-  // g : g0 g1 g2 g3 g4 g5 g6 g7
-  // h : h0 h1 h2 h3 h4 h5 h6 h7
-  __m256 a = _mm256_loadu_ps(&src[0 * ld_src]);
-  __m256 b = _mm256_loadu_ps(&src[1 * ld_src]);
-  __m256 c = _mm256_loadu_ps(&src[2 * ld_src]);
-  __m256 d = _mm256_loadu_ps(&src[3 * ld_src]);
-  __m256 e = _mm256_loadu_ps(&src[4 * ld_src]);
-  __m256 f = _mm256_loadu_ps(&src[5 * ld_src]);
-  __m256 g = _mm256_loadu_ps(&src[6 * ld_src]);
-  __m256 h = _mm256_loadu_ps(&src[7 * ld_src]);
-
-  __m256 ab0145, ab2367, cd0145, cd2367, ef0145, ef2367, gh0145, gh2367;
-  __m256 abcd04, abcd15, efgh04, efgh15, abcd26, abcd37, efgh26, efgh37;
-  // unpacking and interleaving 32-bit elements
-  // ab0145 : a0 b0 a1 b1 a4 b4 a5 b5
-  // ab2367 : a2 b2 a3 b3 a6 b6 a7 b7
-  // cd0145 : c0 d0 c1 d1 c4 d4 c5 d5
-  // cd2367 : c2 d2 c3 d3 c6 d6 c7 d7
-  // ef0145 : e0 f0 e1 f1 e4 f4 e5 f5
-  // ef2367 : e2 f2 e3 f3 e6 f6 e7 f7
-  // gh0145 : g0 h0 g1 h1 g4 h4 g5 h5
-  // gh2367 : g2 h2 g3 h3 g6 h6 g7 h7
-  ab0145 = _mm256_unpacklo_ps(a, b);
-  ab2367 = _mm256_unpackhi_ps(a, b);
-  cd0145 = _mm256_unpacklo_ps(c, d);
-  cd2367 = _mm256_unpackhi_ps(c, d);
-  ef0145 = _mm256_unpacklo_ps(e, f);
-  ef2367 = _mm256_unpackhi_ps(e, f);
-  gh0145 = _mm256_unpacklo_ps(g, h);
-  gh2367 = _mm256_unpackhi_ps(g, h);
-
-  // shuffling the 32-bit elements
-  // abcd04 : a0 b0 c0 d0 a4 b4 c4 d4
-  // abcd15 : a1 b1 c1 d1 a5 b5 c5 d5
-  // efgh04 : e0 f0 g0 h0 e4 f4 g4 h4
-  // efgh15 : e1 f1 g1 h1 e5 f5 g5 h5
-  // abcd26 : a2 b2 c2 d2 a6 b6 c6 d6
-  // abcd37 : a3 b3 c3 d3 a7 b7 c7 d7
-  // efgh26 : e2 f2 g2 h2 e6 f6 g6 h6
-  // efgh37 : e3 f3 g3 h3 e7 f7 g7 h7
-  abcd04 = _mm256_shuffle_ps(ab0145, cd0145, 0x44);
-  abcd15 = _mm256_shuffle_ps(ab0145, cd0145, 0xee);
-  efgh04 = _mm256_shuffle_ps(ef0145, gh0145, 0x44);
-  efgh15 = _mm256_shuffle_ps(ef0145, gh0145, 0xee);
-  abcd26 = _mm256_shuffle_ps(ab2367, cd2367, 0x44);
-  abcd37 = _mm256_shuffle_ps(ab2367, cd2367, 0xee);
-  efgh26 = _mm256_shuffle_ps(ef2367, gh2367, 0x44);
-  efgh37 = _mm256_shuffle_ps(ef2367, gh2367, 0xee);
-
-  // shuffling 128-bit elements
-  // a : a0 b0 c0 d0 e0 f0 g0 h0
-  // b : a1 b1 c1 d1 e1 f1 g1 h1
-  // c : a2 b2 c2 d2 e2 f2 g2 h2
-  // d : a3 b3 c3 d3 e3 f3 g3 h3
-  // e : a4 b4 c4 d4 e4 f4 g4 h4
-  // f : a5 b5 c5 d5 e5 f5 g5 h5
-  // g : a6 b6 c6 d6 e6 f6 g6 h6
-  // h : a7 b7 c7 d7 e7 f7 g7 h7
-  a = _mm256_permute2f128_ps(efgh04, abcd04, 0x02);
-  b = _mm256_permute2f128_ps(efgh15, abcd15, 0x02);
-  c = _mm256_permute2f128_ps(efgh26, abcd26, 0x02);
-  d = _mm256_permute2f128_ps(efgh37, abcd37, 0x02);
-  e = _mm256_permute2f128_ps(efgh04, abcd04, 0x13);
-  f = _mm256_permute2f128_ps(efgh15, abcd15, 0x13);
-  g = _mm256_permute2f128_ps(efgh26, abcd26, 0x13);
-  h = _mm256_permute2f128_ps(efgh37, abcd37, 0x13);
-
-  // store from registers to dst
-  _mm256_storeu_ps(&dst[0 * ld_dst], a);
-  _mm256_storeu_ps(&dst[1 * ld_dst], b);
-  _mm256_storeu_ps(&dst[2 * ld_dst], c);
-  _mm256_storeu_ps(&dst[3 * ld_dst], d);
-  _mm256_storeu_ps(&dst[4 * ld_dst], e);
-  _mm256_storeu_ps(&dst[5 * ld_dst], f);
-  _mm256_storeu_ps(&dst[6 * ld_dst], g);
-  _mm256_storeu_ps(&dst[7 * ld_dst], h);
-}
-
-namespace internal {
-
-void transpose_8x8(
-    int M,
-    int N,
-    const float* src,
-    int ld_src,
-    float* dst,
-    int ld_dst) {
-  int ib = 0, jb = 0;
-  for (ib = 0; ib + 8 <= M; ib += 8) {
-    for (jb = 0; jb + 8 <= N; jb += 8) {
-      transpose_kernel_8x8_avx2(
-          &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
-    }
-  }
-  transpose_4x4(ib, N - jb, &src[jb], ld_src, &dst[jb * ld_dst], ld_dst);
-  transpose_4x4(M - ib, N, &src[ib * ld_src], ld_src, &dst[ib], ld_dst);
-}
-
-} // namespace internal
-
void transpose_simd(
    int M,
    int N,
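Side note, not part of the commit above: a minimal standalone sketch of the 4x4 SSE transpose idiom that transpose_kernel_4x4_sse is built on, handy for trying the intrinsic sequence in isolation. The main() harness and the fixed 4x4 buffers are illustrative only.

// Standalone sketch (not from this commit): one 4x4 row-major tile,
// transposed entirely in SSE registers via _MM_TRANSPOSE4_PS.
#include <immintrin.h>
#include <cstdio>

int main() {
  float src[16], dst[16];
  for (int i = 0; i < 16; ++i) {
    src[i] = static_cast<float>(i); // row-major: element (r, c) holds 4*r + c
  }
  // Load the four source rows into registers.
  __m128 a = _mm_loadu_ps(&src[0]);
  __m128 b = _mm_loadu_ps(&src[4]);
  __m128 c = _mm_loadu_ps(&src[8]);
  __m128 d = _mm_loadu_ps(&src[12]);
  // Transpose the 4x4 tile of 32-bit elements in registers.
  _MM_TRANSPOSE4_PS(a, b, c, d);
  // Store the transposed rows (former columns).
  _mm_storeu_ps(&dst[0], a);
  _mm_storeu_ps(&dst[4], b);
  _mm_storeu_ps(&dst[8], c);
  _mm_storeu_ps(&dst[12], d);
  // Now dst[j * 4 + i] == src[i * 4 + j].
  for (int i = 0; i < 4; ++i) {
    printf("%.0f %.0f %.0f %.0f\n",
           dst[i * 4 + 0], dst[i * 4 + 1], dst[i * 4 + 2], dst[i * 4 + 3]);
  }
  return 0;
}

As in the removed transpose_4x4 above, full 4x4 tiles go through this in-register transpose, and leftover rows/columns fall back to the scalar transpose_ref.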