
Inline asm register operands allocated only into zmm0..zmm15, not using zmm16..zmm31. #68818


Description

@bjacob

See Compiler Explorer testcase: https://godbolt.org/z/nh44Tsdo1

The hot loop in this piece of code needs 17 live zmm values (16 accumulators plus the shared vector that is reloaded each iteration), all passed as register operands to inline asm. But the compiler only allocates zmm0..zmm15, so the generated code wastes time copying between zmm registers and spilling to the stack.

Testcase pasted here for completeness:

#include <immintrin.h>
#include <stdint.h>

static inline __m512 foo(__m512 acc, __m512 lhs, const float* rhs) {
    asm("vfmadd231ps %[rhs]%{1to16%}, %[lhs], %[acc]"
        : [acc] "+x"(acc)
        : [lhs] "x"(lhs), [rhs] "m"(*rhs)
        :);
    return acc;
}

void bar(void* out_tile, const void* lhs_panel, const void* rhs_panel, int K) {
    float* out_ptr = out_tile;
    const float* lhs_ptr = lhs_panel;
    const float* rhs_ptr = rhs_panel;
    __m512 acc[16];
    for (int i = 0; i < 16; ++i) {
        acc[i] = _mm512_loadu_ps(out_ptr + i * 16);
    }

    for (int32_t k = 0; k < K; ++k) {
        __m512 rhs = _mm512_loadu_ps(rhs_ptr);
        rhs_ptr += 16;
        acc[0] = foo(acc[0], rhs, lhs_ptr + 0);
        acc[1] = foo(acc[1], rhs, lhs_ptr + 1);
        acc[2] = foo(acc[2], rhs, lhs_ptr + 2);
        acc[3] = foo(acc[3], rhs, lhs_ptr + 3);
        acc[4] = foo(acc[4], rhs, lhs_ptr + 4);
        acc[5] = foo(acc[5], rhs, lhs_ptr + 5);
        acc[6] = foo(acc[6], rhs, lhs_ptr + 6);
        acc[7] = foo(acc[7], rhs, lhs_ptr + 7);
        acc[8] = foo(acc[8], rhs, lhs_ptr + 8);
        acc[9] = foo(acc[9], rhs, lhs_ptr + 9);
        acc[10] = foo(acc[10], rhs, lhs_ptr + 10);
        acc[11] = foo(acc[11], rhs, lhs_ptr + 11);
        acc[12] = foo(acc[12], rhs, lhs_ptr + 12);
        acc[13] = foo(acc[13], rhs, lhs_ptr + 13);
        acc[14] = foo(acc[14], rhs, lhs_ptr + 14);
        acc[15] = foo(acc[15], rhs, lhs_ptr + 15);
        lhs_ptr += 16;
    }

    for (int i = 0; i < 16; ++i) {
        _mm512_storeu_ps(out_ptr + i * 16, acc[i]);
    }
}

Compile with: -O2 -mavx512f
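
To reproduce locally (the filename testcase.c is just a placeholder), something along the lines of:

clang -O2 -mavx512f -S -o - testcase.c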

Result (hot loop excerpt):

.LBB0_1:                                # =>This Inner Loop Header: Depth=1
        vmovups zmmword ptr [rsp - 64], zmm0    # 64-byte Spill
        vmovups zmm0, zmmword ptr [rdx]
        vfmadd231ps     zmm3, zmm0, dword ptr [rsi]{1to16} # zmm3 = (zmm0 * mem) + zmm3
        vfmadd231ps     zmm15, zmm0, dword ptr [rsi + 4]{1to16} # zmm15 = (zmm0 * mem) + zmm15
        vfmadd231ps     zmm14, zmm0, dword ptr [rsi + 8]{1to16} # zmm14 = (zmm0 * mem) + zmm14
        vfmadd231ps     zmm13, zmm0, dword ptr [rsi + 12]{1to16} # zmm13 = (zmm0 * mem) + zmm13
        vfmadd231ps     zmm12, zmm0, dword ptr [rsi + 16]{1to16} # zmm12 = (zmm0 * mem) + zmm12
        vfmadd231ps     zmm11, zmm0, dword ptr [rsi + 20]{1to16} # zmm11 = (zmm0 * mem) + zmm11
        vfmadd231ps     zmm10, zmm0, dword ptr [rsi + 24]{1to16} # zmm10 = (zmm0 * mem) + zmm10
        vfmadd231ps     zmm9, zmm0, dword ptr [rsi + 28]{1to16} # zmm9 = (zmm0 * mem) + zmm9
        vfmadd231ps     zmm8, zmm0, dword ptr [rsi + 32]{1to16} # zmm8 = (zmm0 * mem) + zmm8
        vfmadd231ps     zmm7, zmm0, dword ptr [rsi + 36]{1to16} # zmm7 = (zmm0 * mem) + zmm7
        vfmadd231ps     zmm6, zmm0, dword ptr [rsi + 40]{1to16} # zmm6 = (zmm0 * mem) + zmm6
        vfmadd231ps     zmm5, zmm0, dword ptr [rsi + 44]{1to16} # zmm5 = (zmm0 * mem) + zmm5
        vfmadd231ps     zmm4, zmm0, dword ptr [rsi + 48]{1to16} # zmm4 = (zmm0 * mem) + zmm4
        vfmadd231ps     zmm2, zmm0, dword ptr [rsi + 52]{1to16} # zmm2 = (zmm0 * mem) + zmm2
        vmovaps zmm1, zmm15
        vmovaps zmm15, zmm14
        vmovaps zmm14, zmm13
        vmovaps zmm13, zmm12
        vmovaps zmm12, zmm11
        vmovaps zmm11, zmm10
        vmovaps zmm10, zmm9
        vmovaps zmm9, zmm8
        vmovaps zmm8, zmm7
        vmovaps zmm7, zmm6
        vmovaps zmm6, zmm5
        vmovaps zmm5, zmm4
        vmovaps zmm4, zmm2
        vmovups zmm2, zmmword ptr [rsp - 128]   # 64-byte Reload
        vfmadd231ps     zmm2, zmm0, dword ptr [rsi + 56]{1to16} # zmm2 = (zmm0 * mem) + zmm2
        vmovups zmmword ptr [rsp - 128], zmm2   # 64-byte Spill
        vmovaps zmm2, zmm4
        vmovaps zmm4, zmm5
        vmovaps zmm5, zmm6
        vmovaps zmm6, zmm7
        vmovaps zmm7, zmm8
        vmovaps zmm8, zmm9
        vmovaps zmm9, zmm10
        vmovaps zmm10, zmm11
        vmovaps zmm11, zmm12
        vmovaps zmm12, zmm13
        vmovaps zmm13, zmm14
        vmovaps zmm14, zmm15
        vmovaps zmm15, zmm1
        vmovups zmm1, zmmword ptr [rsp - 64]    # 64-byte Reload
        vfmadd231ps     zmm1, zmm0, dword ptr [rsi + 60]{1to16} # zmm1 = (zmm0 * mem) + zmm1
        vmovups zmmword ptr [rsp - 64], zmm1    # 64-byte Spill
        vmovups zmm0, zmmword ptr [rsp - 64]    # 64-byte Reload
        add     rdx, 64
        add     rsi, 64
        dec     ecx
        jne     .LBB0_1

The inefficient parts are all these vmovaps zmm, zmm copies between registers, plus the spills and reloads explicitly called out in the comments. Ideally the loop body would be just the vmovups zmm0, zmmword ptr [rdx] load followed by the 16 vfmadd231ps instructions, but that would require allocating asm zmm operands beyond zmm15.
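
For reference, a hedged sketch of a possible source-side workaround (not a fix for the allocator behavior itself): the "v" constraint is documented in both GCC and Clang as accepting any EVEX-encodable vector register, so rewriting foo as below should let operands be allocated into zmm16..zmm31 as well. I have not verified that this removes all the copies in this exact testcase.

#include <immintrin.h>

static inline __m512 foo_v(__m512 acc, __m512 lhs, const float* rhs) {
    // "v" (any EVEX-encodable vector register, zmm0..zmm31) instead of "x"
    // (limited to zmm0..zmm15). Assumption: this gives the register allocator
    // enough registers to keep all 16 accumulators plus the shared vector live.
    asm("vfmadd231ps %[rhs]%{1to16%}, %[lhs], %[acc]"
        : [acc] "+v"(acc)
        : [lhs] "v"(lhs), [rhs] "m"(*rhs)
        :);
    return acc;
}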
