Skip to content

Commit 4c96f10

Browse files
dskhudia authored and facebook-github-bot committed
Move avx2 specific code in different source files (#28)
Summary: Pull Request resolved: #28 Pull Request resolved: pytorch/pytorch#14516 This is the first diff in a series of diffs that will separate out avx2-specific code into separate files. The goal is to compile as little code as possible with avx2 and avx512 compiler flags. Reviewed By: jianyuh Differential Revision: D13248376 fbshipit-source-id: 2347f3687c2cbd5c6d21d7365c6f9bd87ee96641
1 parent 0d5a159 commit 4c96f10

File tree

6 files changed

+255
-202
lines changed

6 files changed

+255
-202
lines changed

CMakeLists.txt

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@ set(FBGEMM_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
2222
set(FBGEMM_THIRDPARTY_DIR ${FBGEMM_BINARY_DIR}/third_party)
2323
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
2424

25+
2526
#All the source files that either use avx2 instructions statically or JIT
2627
#avx2/avx512 instructions.
27-
set(FBGEMM_AVX2_SRCS src/ExecuteKernel.cc
28+
set(FBGEMM_GENERIC_SRCS src/ExecuteKernel.cc
2829
src/ExecuteKernelU8S8.cc
2930
src/Fbgemm.cc
3031
src/FbgemmFP16.cc
@@ -62,6 +63,9 @@ else()
6263
message(WARNING "OpenMP is not supported by the compiler")
6364
endif()
6465

66+
#All the source files that use avx2 instructions statically
67+
set(FBGEMM_AVX2_SRCS src/Utils_avx2.cc)
68+
6569
#All the source files that use avx512 instructions statically
6670
set(FBGEMM_AVX512_SRCS src/Utils_avx512.cc)
6771

@@ -74,14 +78,17 @@ set(FBGEMM_PUBLIC_HEADERS include/fbgemm/Fbgemm.h
7478
include/fbgemm/FbgemmI8Spmdm.h)
7579

7680

81+
add_library(fbgemm_generic OBJECT ${FBGEMM_GENERIC_SRCS})
7782
add_library(fbgemm_avx2 OBJECT ${FBGEMM_AVX2_SRCS})
7883
add_library(fbgemm_avx512 OBJECT ${FBGEMM_AVX512_SRCS})
7984

80-
set_target_properties(fbgemm_avx2 fbgemm_avx512 PROPERTIES
85+
set_target_properties(fbgemm_generic fbgemm_avx2 fbgemm_avx512 PROPERTIES
8186
CXX_STANDARD 11
8287
CXX_EXTENSIONS NO
8388
CXX_VISIBILITY_PRESET hidden)
8489

90+
target_compile_options(fbgemm_generic PRIVATE
91+
"-m64" "-mavx2" "-mfma" "-masm=intel")
8592
target_compile_options(fbgemm_avx2 PRIVATE
8693
"-m64" "-mavx2" "-mfma" "-masm=intel")
8794
target_compile_options(fbgemm_avx512 PRIVATE
@@ -132,6 +139,12 @@ if(NOT TARGET cpuinfo)
132139
set_property(TARGET cpuinfo PROPERTY POSITION_INDEPENDENT_CODE ON)
133140
endif()
134141

142+
target_include_directories(fbgemm_generic BEFORE
143+
PUBLIC $<BUILD_INTERFACE:${FBGEMM_SOURCE_DIR}>
144+
PUBLIC $<BUILD_INTERFACE:${FBGEMM_SOURCE_DIR}/include>
145+
PRIVATE "${ASMJIT_SRC_DIR}/src"
146+
PRIVATE "${CPUINFO_SOURCE_DIR}/include")
147+
135148
target_include_directories(fbgemm_avx2 BEFORE
136149
PUBLIC $<BUILD_INTERFACE:${FBGEMM_SOURCE_DIR}>
137150
PUBLIC $<BUILD_INTERFACE:${FBGEMM_SOURCE_DIR}/include>
@@ -145,17 +158,24 @@ target_include_directories(fbgemm_avx512 BEFORE
145158
PRIVATE "${CPUINFO_SOURCE_DIR}/include")
146159

147160
if(FBGEMM_LIBRARY_TYPE STREQUAL "default")
148-
add_library(fbgemm $<TARGET_OBJECTS:fbgemm_avx2>
161+
add_library(fbgemm
162+
$<TARGET_OBJECTS:fbgemm_generic>
163+
$<TARGET_OBJECTS:fbgemm_avx2>
149164
$<TARGET_OBJECTS:fbgemm_avx512>)
150165
elseif(FBGEMM_LIBRARY_TYPE STREQUAL "shared")
151-
add_library(fbgemm SHARED $<TARGET_OBJECTS:fbgemm_avx2>
166+
add_library(fbgemm SHARED
167+
$<TARGET_OBJECTS:fbgemm_generic>
168+
$<TARGET_OBJECTS:fbgemm_avx2>
152169
$<TARGET_OBJECTS:fbgemm_avx512>)
170+
set_property(TARGET fbgemm_generic PROPERTY POSITION_INDEPENDENT_CODE ON)
153171
set_property(TARGET fbgemm_avx2 PROPERTY POSITION_INDEPENDENT_CODE ON)
154172
set_property(TARGET fbgemm_avx512 PROPERTY POSITION_INDEPENDENT_CODE ON)
155173
set_target_properties(fbgemm PROPERTIES
156174
CXX_VISIBILITY_PRESET hidden)
157175
elseif(FBGEMM_LIBRARY_TYPE STREQUAL "static")
158-
add_library(fbgemm STATIC $<TARGET_OBJECTS:fbgemm_avx2>
176+
add_library(fbgemm STATIC
177+
$<TARGET_OBJECTS:fbgemm_generic>
178+
$<TARGET_OBJECTS:fbgemm_avx2>
159179
$<TARGET_OBJECTS:fbgemm_avx512>)
160180
target_compile_definitions(fbgemm_avx2 PRIVATE FBGEMM_STATIC)
161181
target_compile_definitions(fbgemm_avx512 PRIVATE FBGEMM_STATIC)

include/fbgemm/Utils.h

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -95,34 +95,4 @@ void transpose_simd(
9595
float* dst,
9696
int ld_dst);
9797

98-
namespace internal {
99-
100-
/**
101-
* @brief Transpose a matrix using Intel AVX2.
102-
*
103-
* This is called if the code is running on a CPU with Intel AVX2 support.
104-
*/
105-
void transpose_8x8(
106-
int M,
107-
int N,
108-
const float* src,
109-
int ld_src,
110-
float* dst,
111-
int ld_dst);
112-
113-
/**
114-
* @brief Transpose a matrix using Intel AVX512.
115-
*
116-
* This is called if the code is running on a CPU with Intel AVX512 support.
117-
*/
118-
void transpose_16x16(
119-
int M,
120-
int N,
121-
const float* src,
122-
int ld_src,
123-
float* dst,
124-
int ld_dst);
125-
126-
} // namespace internal
127-
12898
} // namespace fbgemm

src/TransposeUtils.h

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
* All rights reserved.
4+
* This source code is licensed under the BSD-style license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
#pragma once
8+
9+
namespace fbgemm {
10+
11+
/**
12+
* @brief Reference implementation of matrix transposition: B = A^T.
13+
* @param M The height of the matrix.
14+
* @param N The width of the matrix.
15+
* @param src The memory buffer of the source matrix A.
16+
* @param ld_src The leading dimension of the source matrix A.
17+
* @param dst The memory buffer of the destination matrix B.
18+
* @param ld_dst The leading dimension of the destination matrix B.
19+
*/
20+
void transpose_ref(
21+
int M,
22+
int N,
23+
const float* src,
24+
int ld_src,
25+
float* dst,
26+
int ld_dst);
27+
28+
namespace internal {
29+
30+
/**
31+
* @brief Transpose a matrix using Intel AVX2.
32+
*
33+
* This is called if the code is running on a CPU with Intel AVX2 support.
34+
*/
35+
void transpose_8x8(
36+
int M,
37+
int N,
38+
const float* src,
39+
int ld_src,
40+
float* dst,
41+
int ld_dst);
42+
43+
/**
44+
* @brief Transpose a matrix using Intel AVX512.
45+
*
46+
* This is called if the code is running on a CPU with Intel AVX512 support.
47+
*/
48+
void transpose_16x16(
49+
int M,
50+
int N,
51+
const float* src,
52+
int ld_src,
53+
float* dst,
54+
int ld_dst);
55+
56+
} // namespace internal
57+
58+
} // namespace fbgemm

src/Utils.cc

Lines changed: 2 additions & 165 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
* LICENSE file in the root directory of this source tree.
66
*/
77
#include "fbgemm/Utils.h"
8+
#include "TransposeUtils.h"
89
#include <cpuinfo.h>
910
#include <immintrin.h>
1011
#include <cassert>
@@ -156,16 +157,7 @@ template void printMatrix<int32_t>(
156157
size_t ld,
157158
std::string name);
158159

159-
/**
160-
* @brief Reference implementation of matrix transposition: B = A^T.
161-
* @param M The height of the matrix.
162-
* @param N The width of the matrix.
163-
* @param src The memory buffer of the source matrix A.
164-
* @param ld_src The leading dimension of the source matrix A.
165-
* @param dst The memory buffer of the destination matrix B.
166-
* @param ld_dst The leading dimension of the destination matrix B.
167-
*/
168-
inline void transpose_ref(
160+
void transpose_ref(
169161
int M,
170162
int N,
171163
const float* src,
@@ -179,161 +171,6 @@ inline void transpose_ref(
179171
} // for each output row
180172
}
181173

182-
inline void
183-
transpose_kernel_4x4_sse(const float* src, int ld_src, float* dst, int ld_dst) {
184-
// load from src to registers
185-
// a : a0 a1 a2 a3
186-
// b : b0 b1 b2 b3
187-
// c : c0 c1 c2 c3
188-
// d : d0 d1 d2 d3
189-
__m128 a = _mm_loadu_ps(&src[0 * ld_src]);
190-
__m128 b = _mm_loadu_ps(&src[1 * ld_src]);
191-
__m128 c = _mm_loadu_ps(&src[2 * ld_src]);
192-
__m128 d = _mm_loadu_ps(&src[3 * ld_src]);
193-
194-
// transpose the 4x4 matrix formed by 32-bit elements: Macro from SSE
195-
// a : a0 b0 c0 d0
196-
// b : a1 b1 c1 d1
197-
// c : a2 b2 c2 d2
198-
// d : a3 b3 c3 d3
199-
_MM_TRANSPOSE4_PS(a, b, c, d);
200-
201-
// store from registers to dst
202-
_mm_storeu_ps(&dst[0 * ld_dst], a);
203-
_mm_storeu_ps(&dst[1 * ld_dst], b);
204-
_mm_storeu_ps(&dst[2 * ld_dst], c);
205-
_mm_storeu_ps(&dst[3 * ld_dst], d);
206-
}
207-
inline void transpose_4x4(
208-
int M,
209-
int N,
210-
const float* src,
211-
int ld_src,
212-
float* dst,
213-
int ld_dst) {
214-
int ib = 0, jb = 0;
215-
for (ib = 0; ib + 4 <= M; ib += 4) {
216-
for (jb = 0; jb + 4 <= N; jb += 4) {
217-
transpose_kernel_4x4_sse(
218-
&src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
219-
}
220-
}
221-
transpose_ref(ib, N - jb, &src[jb], ld_src, &dst[jb * ld_dst], ld_dst);
222-
transpose_ref(M - ib, N, &src[ib * ld_src], ld_src, &dst[ib], ld_dst);
223-
}
224-
225-
inline void transpose_kernel_8x8_avx2(
226-
const float* src,
227-
int ld_src,
228-
float* dst,
229-
int ld_dst) {
230-
// load from src to registers
231-
// a : a0 a1 a2 a3 a4 a5 a6 a7
232-
// b : b0 b1 b2 b3 b4 b5 b6 b7
233-
// c : c0 c1 c2 c3 c4 c5 c6 c7
234-
// d : d0 d1 d2 d3 d4 d5 d6 d7
235-
// e : e0 e1 e2 e3 e4 e5 e6 e7
236-
// f : f0 f1 f2 f3 f4 f5 f6 f7
237-
// g : g0 g1 g2 g3 g4 g5 g6 g7
238-
// h : h0 h1 h2 h3 h4 h5 h6 h7
239-
__m256 a = _mm256_loadu_ps(&src[0 * ld_src]);
240-
__m256 b = _mm256_loadu_ps(&src[1 * ld_src]);
241-
__m256 c = _mm256_loadu_ps(&src[2 * ld_src]);
242-
__m256 d = _mm256_loadu_ps(&src[3 * ld_src]);
243-
__m256 e = _mm256_loadu_ps(&src[4 * ld_src]);
244-
__m256 f = _mm256_loadu_ps(&src[5 * ld_src]);
245-
__m256 g = _mm256_loadu_ps(&src[6 * ld_src]);
246-
__m256 h = _mm256_loadu_ps(&src[7 * ld_src]);
247-
248-
__m256 ab0145, ab2367, cd0145, cd2367, ef0145, ef2367, gh0145, gh2367;
249-
__m256 abcd04, abcd15, efgh04, efgh15, abcd26, abcd37, efgh26, efgh37;
250-
// unpacking and interleaving 32-bit elements
251-
// ab0145 : a0 b0 a1 b1 a4 b4 a5 b5
252-
// ab2367 : a2 b2 a3 b3 a6 b6 a7 b7
253-
// cd0145 : c0 d0 c1 d1 c4 d4 c5 d5
254-
// cd2367 : c2 d2 c3 d3 c6 d6 c7 d7
255-
// ef0145 : e0 f0 e1 f1 e4 f4 e5 f5
256-
// ef2367 : e2 f2 e3 f3 e6 f6 e7 f7
257-
// gh0145 : g0 h0 g1 h1 g4 h4 g5 h5
258-
// gh2367 : g2 h2 g3 h3 g6 h6 g7 h7
259-
ab0145 = _mm256_unpacklo_ps(a, b);
260-
ab2367 = _mm256_unpackhi_ps(a, b);
261-
cd0145 = _mm256_unpacklo_ps(c, d);
262-
cd2367 = _mm256_unpackhi_ps(c, d);
263-
ef0145 = _mm256_unpacklo_ps(e, f);
264-
ef2367 = _mm256_unpackhi_ps(e, f);
265-
gh0145 = _mm256_unpacklo_ps(g, h);
266-
gh2367 = _mm256_unpackhi_ps(g, h);
267-
268-
// shuffling the 32-bit elements
269-
// abcd04 : a0 b0 c0 d0 a4 b4 c4 d4
270-
// abcd15 : a1 b1 c1 d1 a5 b5 c5 d5
271-
// efgh04 : e0 f0 g0 h0 e4 f4 g4 h4
272-
// efgh15 : e1 f1 g1 h1 e5 f5 g5 h5
273-
// abcd26 : a2 b2 c2 d2 a6 b6 c6 d6
274-
// abcd37 : a3 b3 c3 d3 a7 b7 c7 d7
275-
// efgh26 : e2 f2 g2 h2 e6 f6 g6 h6
276-
// efgh37 : e3 f3 g3 h3 e7 f7 g7 h7
277-
abcd04 = _mm256_shuffle_ps(ab0145, cd0145, 0x44);
278-
abcd15 = _mm256_shuffle_ps(ab0145, cd0145, 0xee);
279-
efgh04 = _mm256_shuffle_ps(ef0145, gh0145, 0x44);
280-
efgh15 = _mm256_shuffle_ps(ef0145, gh0145, 0xee);
281-
abcd26 = _mm256_shuffle_ps(ab2367, cd2367, 0x44);
282-
abcd37 = _mm256_shuffle_ps(ab2367, cd2367, 0xee);
283-
efgh26 = _mm256_shuffle_ps(ef2367, gh2367, 0x44);
284-
efgh37 = _mm256_shuffle_ps(ef2367, gh2367, 0xee);
285-
286-
// shuffling 128-bit elements
287-
// a : a0 b0 c0 d0 e0 f0 g0 h0
288-
// b : a1 b1 c1 d1 e1 f1 g1 h1
289-
// c : a2 b2 c2 d2 e2 f2 g2 h2
290-
// d : a3 b3 c3 d3 e3 f3 g3 h3
291-
// e : a4 b4 c4 d4 e4 f4 g4 h4
292-
// f : a5 b5 c5 d5 e5 f5 g5 h5
293-
// g : a6 b6 c6 d6 e6 f6 g6 h6
294-
// h : a7 b7 c7 d7 e7 f7 g7 h7
295-
a = _mm256_permute2f128_ps(efgh04, abcd04, 0x02);
296-
b = _mm256_permute2f128_ps(efgh15, abcd15, 0x02);
297-
c = _mm256_permute2f128_ps(efgh26, abcd26, 0x02);
298-
d = _mm256_permute2f128_ps(efgh37, abcd37, 0x02);
299-
e = _mm256_permute2f128_ps(efgh04, abcd04, 0x13);
300-
f = _mm256_permute2f128_ps(efgh15, abcd15, 0x13);
301-
g = _mm256_permute2f128_ps(efgh26, abcd26, 0x13);
302-
h = _mm256_permute2f128_ps(efgh37, abcd37, 0x13);
303-
304-
// store from registers to dst
305-
_mm256_storeu_ps(&dst[0 * ld_dst], a);
306-
_mm256_storeu_ps(&dst[1 * ld_dst], b);
307-
_mm256_storeu_ps(&dst[2 * ld_dst], c);
308-
_mm256_storeu_ps(&dst[3 * ld_dst], d);
309-
_mm256_storeu_ps(&dst[4 * ld_dst], e);
310-
_mm256_storeu_ps(&dst[5 * ld_dst], f);
311-
_mm256_storeu_ps(&dst[6 * ld_dst], g);
312-
_mm256_storeu_ps(&dst[7 * ld_dst], h);
313-
}
314-
315-
namespace internal {
316-
317-
void transpose_8x8(
318-
int M,
319-
int N,
320-
const float* src,
321-
int ld_src,
322-
float* dst,
323-
int ld_dst) {
324-
int ib = 0, jb = 0;
325-
for (ib = 0; ib + 8 <= M; ib += 8) {
326-
for (jb = 0; jb + 8 <= N; jb += 8) {
327-
transpose_kernel_8x8_avx2(
328-
&src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
329-
}
330-
}
331-
transpose_4x4(ib, N - jb, &src[jb], ld_src, &dst[jb * ld_dst], ld_dst);
332-
transpose_4x4(M - ib, N, &src[ib * ld_src], ld_src, &dst[ib], ld_dst);
333-
}
334-
335-
} // namespace internal
336-
337174
void transpose_simd(
338175
int M,
339176
int N,

0 commit comments

Comments
 (0)