Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 25 additions & 5 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ set(FBGEMM_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
set(FBGEMM_THIRDPARTY_DIR ${FBGEMM_BINARY_DIR}/third_party)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)


#All the source files that either use avx2 instructions statically or JIT
#avx2/avx512 instructions.
set(FBGEMM_AVX2_SRCS src/ExecuteKernel.cc
set(FBGEMM_GENERIC_SRCS src/ExecuteKernel.cc
src/ExecuteKernelU8S8.cc
src/Fbgemm.cc
src/FbgemmFP16.cc
Expand Down Expand Up @@ -62,6 +63,9 @@ else()
message(WARNING "OpenMP is not supported by the compiler")
endif()

#All the source files that use avx2 instructions statically
set(FBGEMM_AVX2_SRCS src/Utils_avx2.cc)

#All the source files that use avx512 instructions statically
set(FBGEMM_AVX512_SRCS src/Utils_avx512.cc)

Expand All @@ -74,14 +78,17 @@ set(FBGEMM_PUBLIC_HEADERS include/fbgemm/Fbgemm.h
include/fbgemm/FbgemmI8Spmdm.h)


# Each source group is built as its own OBJECT library so that it can carry
# its own compile options; the resulting object files are combined into the
# fbgemm library through $<TARGET_OBJECTS:...>.
add_library(fbgemm_generic OBJECT ${FBGEMM_GENERIC_SRCS})
add_library(fbgemm_avx2 OBJECT ${FBGEMM_AVX2_SRCS})
add_library(fbgemm_avx512 OBJECT ${FBGEMM_AVX512_SRCS})

set_target_properties(fbgemm_avx2 fbgemm_avx512 PROPERTIES
set_target_properties(fbgemm_generic fbgemm_avx2 fbgemm_avx512 PROPERTIES
CXX_STANDARD 11
CXX_EXTENSIONS NO
CXX_VISIBILITY_PRESET hidden)

# The "generic" sources still use avx2 instructions statically or JIT
# avx2/avx512 code, so they are compiled with the same instruction-set flags
# as the avx2-only sources.
target_compile_options(fbgemm_generic PRIVATE
"-m64" "-mavx2" "-mfma" "-masm=intel")
target_compile_options(fbgemm_avx2 PRIVATE
"-m64" "-mavx2" "-mfma" "-masm=intel")
target_compile_options(fbgemm_avx512 PRIVATE
Expand Down Expand Up @@ -132,6 +139,12 @@ if(NOT TARGET cpuinfo)
set_property(TARGET cpuinfo PROPERTY POSITION_INDEPENDENT_CODE ON)
endif()

# Header search paths for the generic objects. The project root and its
# include/ directory are PUBLIC (build-tree only); the asmjit and cpuinfo
# headers are PRIVATE to these sources.
target_include_directories(fbgemm_generic BEFORE
  PUBLIC
    $<BUILD_INTERFACE:${FBGEMM_SOURCE_DIR}>
    $<BUILD_INTERFACE:${FBGEMM_SOURCE_DIR}/include>
  PRIVATE
    "${ASMJIT_SRC_DIR}/src"
    "${CPUINFO_SOURCE_DIR}/include")

target_include_directories(fbgemm_avx2 BEFORE
PUBLIC $<BUILD_INTERFACE:${FBGEMM_SOURCE_DIR}>
PUBLIC $<BUILD_INTERFACE:${FBGEMM_SOURCE_DIR}/include>
Expand All @@ -145,17 +158,24 @@ target_include_directories(fbgemm_avx512 BEFORE
PRIVATE "${CPUINFO_SOURCE_DIR}/include")

if(FBGEMM_LIBRARY_TYPE STREQUAL "default")
add_library(fbgemm $<TARGET_OBJECTS:fbgemm_avx2>
add_library(fbgemm
$<TARGET_OBJECTS:fbgemm_generic>
$<TARGET_OBJECTS:fbgemm_avx2>
$<TARGET_OBJECTS:fbgemm_avx512>)
elseif(FBGEMM_LIBRARY_TYPE STREQUAL "shared")
add_library(fbgemm SHARED $<TARGET_OBJECTS:fbgemm_avx2>
add_library(fbgemm SHARED
$<TARGET_OBJECTS:fbgemm_generic>
$<TARGET_OBJECTS:fbgemm_avx2>
$<TARGET_OBJECTS:fbgemm_avx512>)
set_property(TARGET fbgemm_generic PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET fbgemm_avx2 PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET fbgemm_avx512 PROPERTY POSITION_INDEPENDENT_CODE ON)
set_target_properties(fbgemm PROPERTIES
CXX_VISIBILITY_PRESET hidden)
elseif(FBGEMM_LIBRARY_TYPE STREQUAL "static")
add_library(fbgemm STATIC $<TARGET_OBJECTS:fbgemm_avx2>
add_library(fbgemm STATIC
$<TARGET_OBJECTS:fbgemm_generic>
$<TARGET_OBJECTS:fbgemm_avx2>
$<TARGET_OBJECTS:fbgemm_avx512>)
# Every object library that is archived into the static fbgemm target is
# compiled with FBGEMM_STATIC -- including fbgemm_generic, whose objects are
# part of the archive as well -- so all translation units agree on the
# definition.
target_compile_definitions(fbgemm_generic PRIVATE FBGEMM_STATIC)
target_compile_definitions(fbgemm_avx2 PRIVATE FBGEMM_STATIC)
target_compile_definitions(fbgemm_avx512 PRIVATE FBGEMM_STATIC)
Expand Down
30 changes: 0 additions & 30 deletions include/fbgemm/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,34 +95,4 @@ void transpose_simd(
float* dst,
int ld_dst);

namespace internal {

/**
 * @brief Transpose a matrix using Intel AVX2.
 *
 * This is called if the code is running on a CPU with Intel AVX2 support.
 *
 * @param M The height of the matrix.
 * @param N The width of the matrix.
 * @param src The memory buffer of the source matrix.
 * @param ld_src The leading dimension of the source matrix.
 * @param dst The memory buffer of the destination (transposed) matrix.
 * @param ld_dst The leading dimension of the destination matrix.
 */
void transpose_8x8(
int M,
int N,
const float* src,
int ld_src,
float* dst,
int ld_dst);

/**
 * @brief Transpose a matrix using Intel AVX512.
 *
 * This is called if the code is running on a CPU with Intel AVX512 support.
 *
 * The parameters presumably have the same meaning as for transpose_8x8
 * (the signature is identical); the definition is not visible here --
 * TODO confirm.
 */
void transpose_16x16(
int M,
int N,
const float* src,
int ld_src,
float* dst,
int ld_dst);

} // namespace internal

} // namespace fbgemm
58 changes: 58 additions & 0 deletions src/TransposeUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once

namespace fbgemm {

/**
 * @brief Reference implementation of matrix transposition: B = A^T.
 * @param M The height of the matrix.
 * @param N The width of the matrix.
 * @param src The memory buffer of the source matrix A.
 * @param ld_src The leading dimension of the source matrix A.
 * @param dst The memory buffer of the destination matrix B.
 * @param ld_dst The leading dimension of the destination matrix B.
 */
void transpose_ref(
int M,
int N,
const float* src,
int ld_src,
float* dst,
int ld_dst);

namespace internal {

/**
 * @brief Transpose a matrix using Intel AVX2.
 *
 * This is called if the code is running on a CPU with Intel AVX2 support.
 *
 * @param M The height of the matrix.
 * @param N The width of the matrix.
 * @param src The memory buffer of the source matrix.
 * @param ld_src The leading dimension of the source matrix.
 * @param dst The memory buffer of the destination (transposed) matrix.
 * @param ld_dst The leading dimension of the destination matrix.
 */
void transpose_8x8(
int M,
int N,
const float* src,
int ld_src,
float* dst,
int ld_dst);

/**
 * @brief Transpose a matrix using Intel AVX512.
 *
 * This is called if the code is running on a CPU with Intel AVX512 support.
 *
 * The parameters presumably have the same meaning as for transpose_8x8
 * (the signature is identical); the definition is not visible here --
 * TODO confirm.
 */
void transpose_16x16(
int M,
int N,
const float* src,
int ld_src,
float* dst,
int ld_dst);

} // namespace internal

} // namespace fbgemm
167 changes: 2 additions & 165 deletions src/Utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
* LICENSE file in the root directory of this source tree.
*/
#include "fbgemm/Utils.h"
#include "TransposeUtils.h"
#include <cpuinfo.h>
#include <immintrin.h>
#include <cassert>
Expand Down Expand Up @@ -156,16 +157,7 @@ template void printMatrix<int32_t>(
size_t ld,
std::string name);

/**
* @brief Reference implementation of matrix transposition: B = A^T.
* @param M The height of the matrix.
* @param N The width of the matrix.
* @param src The memory buffer of the source matrix A.
* @param ld_src The leading dimension of the source matrix A.
* @param dst The memory buffer of the destination matrix B.
* @param ld_dst The leading dimension of the destination matrix B.
*/
inline void transpose_ref(
void transpose_ref(
int M,
int N,
const float* src,
Expand All @@ -179,161 +171,6 @@ inline void transpose_ref(
} // for each output row
}

/**
 * Transposes a single 4x4 tile of floats with SSE.
 *
 * Four rows of four floats are read starting at src (consecutive rows are
 * ld_src elements apart) and the transposed tile is written starting at dst
 * (consecutive rows are ld_dst elements apart). Loads and stores are the
 * unaligned variants, so neither pointer needs any particular alignment.
 */
inline void
transpose_kernel_4x4_sse(const float* src, int ld_src, float* dst, int ld_dst) {
  // One register per source row:
  //   r0 = a0 a1 a2 a3, r1 = b0 b1 b2 b3, r2 = c0 c1 c2 c3, r3 = d0 d1 d2 d3
  __m128 r0 = _mm_loadu_ps(src + 0 * ld_src);
  __m128 r1 = _mm_loadu_ps(src + 1 * ld_src);
  __m128 r2 = _mm_loadu_ps(src + 2 * ld_src);
  __m128 r3 = _mm_loadu_ps(src + 3 * ld_src);

  // In-register 4x4 transpose (SSE helper macro). Afterwards:
  //   r0 = a0 b0 c0 d0, r1 = a1 b1 c1 d1, r2 = a2 b2 c2 d2, r3 = a3 b3 c3 d3
  _MM_TRANSPOSE4_PS(r0, r1, r2, r3);

  // Each register now holds one row of the transposed tile.
  _mm_storeu_ps(dst + 0 * ld_dst, r0);
  _mm_storeu_ps(dst + 1 * ld_dst, r1);
  _mm_storeu_ps(dst + 2 * ld_dst, r2);
  _mm_storeu_ps(dst + 3 * ld_dst, r3);
}
/**
 * Transposes an M x N float matrix using 4x4 SSE tiles where possible.
 *
 * Full 4x4 tiles go through transpose_kernel_4x4_sse; whatever does not fit
 * into a full tile (a strip on the right and a strip at the bottom) is handed
 * to transpose_ref.
 */
inline void transpose_4x4(
    int M,
    int N,
    const float* src,
    int ld_src,
    float* dst,
    int ld_dst) {
  const int tile = 4;
  int row = 0;
  // Stays 0 when there is not even one full band of 4 rows, in which case the
  // bottom-strip call below covers the whole matrix.
  int col = 0;

  while (row + tile <= M) {
    col = 0;
    while (col + tile <= N) {
      transpose_kernel_4x4_sse(
          src + row * ld_src + col, ld_src, dst + row + col * ld_dst, ld_dst);
      col += tile;
    }
    row += tile;
  }

  // Right strip: the tiled rows, columns past the last full tile.
  transpose_ref(row, N - col, src + col, ld_src, dst + col * ld_dst, ld_dst);
  // Bottom strip: the rows past the last full tile, every column.
  transpose_ref(M - row, N, src + row * ld_src, ld_src, dst + row, ld_dst);
}

/**
* Transposes a single 8x8 tile of floats using 256-bit vector registers.
*
* Eight rows of eight floats are read starting at src (consecutive rows are
* ld_src elements apart) and the transposed tile is written starting at dst
* (consecutive rows are ld_dst elements apart). Loads and stores are the
* unaligned variants. The transpose is done in three stages: 32-bit
* interleave, 32-bit shuffle, then a 128-bit half swap.
*/
inline void transpose_kernel_8x8_avx2(
const float* src,
int ld_src,
float* dst,
int ld_dst) {
// load from src to registers
// a : a0 a1 a2 a3 a4 a5 a6 a7
// b : b0 b1 b2 b3 b4 b5 b6 b7
// c : c0 c1 c2 c3 c4 c5 c6 c7
// d : d0 d1 d2 d3 d4 d5 d6 d7
// e : e0 e1 e2 e3 e4 e5 e6 e7
// f : f0 f1 f2 f3 f4 f5 f6 f7
// g : g0 g1 g2 g3 g4 g5 g6 g7
// h : h0 h1 h2 h3 h4 h5 h6 h7
__m256 a = _mm256_loadu_ps(&src[0 * ld_src]);
__m256 b = _mm256_loadu_ps(&src[1 * ld_src]);
__m256 c = _mm256_loadu_ps(&src[2 * ld_src]);
__m256 d = _mm256_loadu_ps(&src[3 * ld_src]);
__m256 e = _mm256_loadu_ps(&src[4 * ld_src]);
__m256 f = _mm256_loadu_ps(&src[5 * ld_src]);
__m256 g = _mm256_loadu_ps(&src[6 * ld_src]);
__m256 h = _mm256_loadu_ps(&src[7 * ld_src]);

__m256 ab0145, ab2367, cd0145, cd2367, ef0145, ef2367, gh0145, gh2367;
__m256 abcd04, abcd15, efgh04, efgh15, abcd26, abcd37, efgh26, efgh37;
// unpacking and interleaving 32-bit elements
// (unpacklo/unpackhi work within each 128-bit half independently)
// ab0145 : a0 b0 a1 b1 a4 b4 a5 b5
// ab2367 : a2 b2 a3 b3 a6 b6 a7 b7
// cd0145 : c0 d0 c1 d1 c4 d4 c5 d5
// cd2367 : c2 d2 c3 d3 c6 d6 c7 d7
// ef0145 : e0 f0 e1 f1 e4 f4 e5 f5
// ef2367 : e2 f2 e3 f3 e6 f6 e7 f7
// gh0145 : g0 h0 g1 h1 g4 h4 g5 h5
// gh2367 : g2 h2 g3 h3 g6 h6 g7 h7
ab0145 = _mm256_unpacklo_ps(a, b);
ab2367 = _mm256_unpackhi_ps(a, b);
cd0145 = _mm256_unpacklo_ps(c, d);
cd2367 = _mm256_unpackhi_ps(c, d);
ef0145 = _mm256_unpacklo_ps(e, f);
ef2367 = _mm256_unpackhi_ps(e, f);
gh0145 = _mm256_unpacklo_ps(g, h);
gh2367 = _mm256_unpackhi_ps(g, h);

// shuffling the 32-bit elements
// (0x44 takes the low pair of each operand, 0xee the high pair)
// abcd04 : a0 b0 c0 d0 a4 b4 c4 d4
// abcd15 : a1 b1 c1 d1 a5 b5 c5 d5
// efgh04 : e0 f0 g0 h0 e4 f4 g4 h4
// efgh15 : e1 f1 g1 h1 e5 f5 g5 h5
// abcd26 : a2 b2 c2 d2 a6 b6 c6 d6
// abcd37 : a3 b3 c3 d3 a7 b7 c7 d7
// efgh26 : e2 f2 g2 h2 e6 f6 g6 h6
// efgh37 : e3 f3 g3 h3 e7 f7 g7 h7
abcd04 = _mm256_shuffle_ps(ab0145, cd0145, 0x44);
abcd15 = _mm256_shuffle_ps(ab0145, cd0145, 0xee);
efgh04 = _mm256_shuffle_ps(ef0145, gh0145, 0x44);
efgh15 = _mm256_shuffle_ps(ef0145, gh0145, 0xee);
abcd26 = _mm256_shuffle_ps(ab2367, cd2367, 0x44);
abcd37 = _mm256_shuffle_ps(ab2367, cd2367, 0xee);
efgh26 = _mm256_shuffle_ps(ef2367, gh2367, 0x44);
efgh37 = _mm256_shuffle_ps(ef2367, gh2367, 0xee);

// shuffling 128-bit elements
// (0x02 pairs the low halves of abcd.. and efgh.., 0x13 pairs the high halves;
// the abcd half always lands in the low 128 bits of the result)
// a : a0 b0 c0 d0 e0 f0 g0 h0
// b : a1 b1 c1 d1 e1 f1 g1 h1
// c : a2 b2 c2 d2 e2 f2 g2 h2
// d : a3 b3 c3 d3 e3 f3 g3 h3
// e : a4 b4 c4 d4 e4 f4 g4 h4
// f : a5 b5 c5 d5 e5 f5 g5 h5
// g : a6 b6 c6 d6 e6 f6 g6 h6
// h : a7 b7 c7 d7 e7 f7 g7 h7
a = _mm256_permute2f128_ps(efgh04, abcd04, 0x02);
b = _mm256_permute2f128_ps(efgh15, abcd15, 0x02);
c = _mm256_permute2f128_ps(efgh26, abcd26, 0x02);
d = _mm256_permute2f128_ps(efgh37, abcd37, 0x02);
e = _mm256_permute2f128_ps(efgh04, abcd04, 0x13);
f = _mm256_permute2f128_ps(efgh15, abcd15, 0x13);
g = _mm256_permute2f128_ps(efgh26, abcd26, 0x13);
h = _mm256_permute2f128_ps(efgh37, abcd37, 0x13);

// store from registers to dst
_mm256_storeu_ps(&dst[0 * ld_dst], a);
_mm256_storeu_ps(&dst[1 * ld_dst], b);
_mm256_storeu_ps(&dst[2 * ld_dst], c);
_mm256_storeu_ps(&dst[3 * ld_dst], d);
_mm256_storeu_ps(&dst[4 * ld_dst], e);
_mm256_storeu_ps(&dst[5 * ld_dst], f);
_mm256_storeu_ps(&dst[6 * ld_dst], g);
_mm256_storeu_ps(&dst[7 * ld_dst], h);
}

namespace internal {

/**
 * Transposes an M x N float matrix using 8x8 tiles where possible.
 *
 * Full 8x8 tiles go through transpose_kernel_8x8_avx2; the strip to the right
 * of the last full tile column and the strip below the last full tile row are
 * delegated to transpose_4x4.
 */
void transpose_8x8(
    int M,
    int N,
    const float* src,
    int ld_src,
    float* dst,
    int ld_dst) {
  const int tile = 8;
  int row = 0;
  // Stays 0 when there is not even one full band of 8 rows, in which case the
  // bottom-strip call below covers the whole matrix.
  int col = 0;

  while (row + tile <= M) {
    col = 0;
    while (col + tile <= N) {
      transpose_kernel_8x8_avx2(
          src + row * ld_src + col, ld_src, dst + row + col * ld_dst, ld_dst);
      col += tile;
    }
    row += tile;
  }

  // Right strip: the tiled rows, columns past the last full tile.
  transpose_4x4(row, N - col, src + col, ld_src, dst + col * ld_dst, ld_dst);
  // Bottom strip: the rows past the last full tile, every column.
  transpose_4x4(M - row, N, src + row * ld_src, ld_src, dst + row, ld_dst);
}

} // namespace internal

void transpose_simd(
int M,
int N,
Expand Down
Loading