
Q5: Slightly faster AVX2 implementation #1197

Merged
merged 1 commit into from Apr 26, 2023
ggml.c: 26 changes, 19 additions & 7 deletions
@@ -328,6 +328,7 @@ static ggml_fp16_t table_exp_f16[1 << 16];
 // precomputed f32 table for f16 (256 KB)
 static float table_f32_f16[1 << 16];
 
+#if defined(__ARM_NEON)
 #define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
 #define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
 #define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
@@ -339,7 +340,7 @@ static float table_f32_f16[1 << 16];
 
 // precomputed tables for expanding 8bits to 8 bytes (shl 4)
 static const uint64_t table_b2b_u[1 << 8] = { B8(00, 10) };
-static const uint64_t table_b2b_i[1 << 8] = { B8(F0, 00) };
+#endif
 
 // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
 // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
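The bit-expansion tables are now compiled only for NEON; the AVX2 path below computes the same expansion in registers instead. For reference, a minimal scalar sketch of what each table_b2b_u entry holds (b2b_u_ref is an invented name, not a function in ggml.c): byte i of entry b is 0x10 when bit i of b is set, and 0x00 otherwise.

    // Scalar sketch of a table_b2b_u entry; b2b_u_ref is a hypothetical name.
    #include <stdint.h>

    static uint64_t b2b_u_ref(uint8_t b) {
        uint64_t r = 0;
        for (int i = 0; i < 8; ++i) {
            if ((b >> i) & 1) {
                r |= (uint64_t)0x10 << (8 * i); // byte i = 0x10 iff bit i of b is set
            }
        }
        return r; // equals table_b2b_u[b] as generated by B8(00, 10)
    }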
@@ -490,6 +491,19 @@ static inline int hsum_i32_4(const __m128i a) {
 }
 
 #if __AVX2__ || __AVX512F__
+// spread 32 bits to 32 bytes { 0x00, 0xFF }
+static inline __m256i bytes_from_bits_32(const uint8_t * x) {
+    uint32_t x32;
+    memcpy(&x32, x, sizeof(uint32_t));
+    const __m256i shuf_mask = _mm256_set_epi64x(
+            0x0303030303030303, 0x0202020202020202,
+            0x0101010101010101, 0x0000000000000000);
+    __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
+    const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
+    bytes = _mm256_or_si256(bytes, bit_mask);
+    return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
+}
+
 // Unpack 32 4-bit fields into 32 bytes
 // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
 static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
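How the new helper works: _mm256_set1_epi32 broadcasts the 32-bit qh word into every element, the byte shuffle then copies source byte i/8 into output byte i, and the OR with 0x7fbfdfeff7fbfdfe (byte j of that constant has exactly bit j cleared) drives output byte i to 0xFF precisely when bit i of the input was set; the final compare against all-ones turns that into a clean 0x00/0xFF mask. A scalar reference of the same contract (a sketch; the name bytes_from_bits_32_ref is invented for illustration):

    // Scalar reference for bytes_from_bits_32: bit i of the 32-bit input
    // becomes byte i of the output, 0xFF if set, 0x00 if clear.
    #include <stdint.h>
    #include <string.h>

    static void bytes_from_bits_32_ref(const uint8_t * x, uint8_t out[32]) {
        uint32_t x32;
        memcpy(&x32, x, sizeof(uint32_t));
        for (int i = 0; i < 32; ++i) {
            out[i] = ((x32 >> i) & 1) ? 0xFF : 0x00;
        }
    }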
@@ -3367,9 +3381,8 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
         const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
 
         __m256i bx = bytes_from_nibbles_32(x[i].qs);
-        const __m256i bxhi = _mm256_set_epi64x(
-            table_b2b_i[x[i].qh[3]], table_b2b_i[x[i].qh[2]],
-            table_b2b_i[x[i].qh[1]], table_b2b_i[x[i].qh[0]]);
+        __m256i bxhi = bytes_from_bits_32(x[i].qh);
+        bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
         bx = _mm256_or_si256(bx, bxhi);
 
         __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
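For q5_0 the dequantized value is d * (v - 16) for the 5-bit quant v, so the high-bit mask is used inverted: _mm256_andnot_si256 leaves 0xF0 in exactly the lanes whose fifth bit is clear, and OR-ing that into the low nibble sign-extends those lanes to nibble - 16. That reproduces the bytes the old table_b2b_i lookups (B8(F0, 00)) supplied, without four 64-bit table loads per block. A standalone scalar check of the identity (a sketch, not code from this PR):

    // Checks that OR-ing 0xF0 into the low nibble when the high bit is clear
    // yields v - 16 as a signed byte, for every 5-bit value v. Sketch only.
    #include <assert.h>
    #include <stdint.h>

    int main(void) {
        for (int v = 0; v < 32; ++v) {
            int8_t lo  = v & 0x0F;                              // low nibble from qs
            int8_t out = (v & 0x10) ? lo : (int8_t)(lo | 0xF0); // the mask trick
            assert(out == (int8_t)(v - 16));                    // q5_0 offset
        }
        return 0;
    }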
@@ -3501,9 +3514,8 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
         summs += GGML_FP16_TO_FP32(x[i].m) * (y[i].s0 + y[i].s1);
 
         __m256i bx = bytes_from_nibbles_32(x[i].qs);
-        const __m256i bxhi = _mm256_set_epi64x(
-            table_b2b_u[x[i].qh[3]], table_b2b_u[x[i].qh[2]],
-            table_b2b_u[x[i].qh[1]], table_b2b_u[x[i].qh[0]]);
+        __m256i bxhi = bytes_from_bits_32(x[i].qh);
+        bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
         bx = _mm256_or_si256(bx, bxhi);
 
         const __m256 dy = _mm256_broadcast_ss(&y[i].d);
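q5_1 stores unsigned quants with a separate minimum, so here the mask is used directly: AND with 0x10 keeps 0x10 only in the lanes whose fifth bit is set, and the OR reconstructs the full 5-bit value in [0, 31], matching the old table_b2b_u (B8(00, 10)) entries. Scalar sketch of that identity (illustration only):

    // Checks that (low nibble) | (high bit ? 0x10 : 0) rebuilds the 5-bit value.
    #include <assert.h>
    #include <stdint.h>

    int main(void) {
        for (int v = 0; v < 32; ++v) {
            uint8_t lo  = v & 0x0F;                     // low nibble from qs
            uint8_t out = lo | ((v & 0x10) ? 0x10 : 0); // the mask trick
            assert(out == v);                           // full 5-bit quant
        }
        return 0;
    }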