Skip to content

[Snippets][CPU] Support runtime offsets in ARM load/store emitters #31112

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Jul 3, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -12,10 +12,12 @@

#include "emitters/plugin/aarch64/jit_emitter.hpp"
#include "emitters/plugin/aarch64/jit_load_store_emitters.hpp"
#include "emitters/snippets/jit_snippets_call_args.hpp"
#include "emitters/utils.hpp"
#include "openvino/core/type.hpp"
#include "openvino/core/type/element_type.hpp"
#include "snippets/lowered/expression.hpp"
#include "snippets/lowered/loop_manager.hpp"
#include "snippets/op/broadcastload.hpp"
#include "snippets/op/load.hpp"
#include "snippets/op/store.hpp"
@@ -29,15 +31,64 @@ using jit_generator = dnnl::impl::cpu::aarch64::jit_generator;
using cpu_isa_t = dnnl::impl::cpu::aarch64::cpu_isa_t;
using ExpressionPtr = ov::snippets::lowered::ExpressionPtr;

jit_memory_emitter::jit_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr)
jit_memory_emitter::jit_memory_emitter(jit_generator* h,
cpu_isa_t isa,
const ExpressionPtr& expr,
emitter_in_out_map in_out_type)
: jit_emitter(h, isa) {
in_out_type_ = in_out_type;

const auto n = expr->get_node();
src_prc = n->get_input_element_type(0);
dst_prc = n->get_output_element_type(0);

const auto& memory_access = std::dynamic_pointer_cast<ov::snippets::modifier::MemoryAccess>(expr->get_node());
if (in_out_type_ == emitter_in_out_map::gpr_to_vec) {
OV_CPU_JIT_EMITTER_ASSERT(memory_access->is_memory_access_input_port(0), "must be input port - memory access");
count = memory_access->get_input_count();
compiled_byte_offset = memory_access->get_input_offset();
buffer_cluster_id = get_parent_buffer_cluster_id(expr);
} else if (in_out_type_ == emitter_in_out_map::vec_to_gpr) {
OV_CPU_JIT_EMITTER_ASSERT(memory_access->is_memory_access_output_port(0),
"must be output port - memory access");
count = memory_access->get_output_count();
compiled_byte_offset = memory_access->get_output_offset();
buffer_cluster_id = get_consumer_buffer_cluster_id(expr);
} else {
OV_CPU_JIT_EMITTER_THROW("unsupported in_out_type");
}

if (ov::snippets::utils::is_dynamic_value(compiled_byte_offset)) {
is_offset_runtime = true;
// Compiled byte offset is zero to manually `add` runtime offset before operation and `sub` after to reset
// pointer in the register
compiled_byte_offset = 0;
OV_CPU_JIT_EMITTER_ASSERT(buffer_cluster_id != SIZE_MAX, "Incorrect buffer offset in call_args");
}
}

size_t jit_memory_emitter::get_parent_buffer_cluster_id(const ov::snippets::lowered::ExpressionPtr& expr) {
OV_CPU_JIT_EMITTER_ASSERT(expr->get_input_count() == 1, "MemoryAccess must have one parent");
const auto& parent_expr = expr->get_input_port_connector(0)->get_source().get_expr();
if (const auto buffer = ov::as_type_ptr<ov::snippets::lowered::BufferExpression>(parent_expr)) {
return buffer->get_cluster_id();
}
return SIZE_MAX;
}

size_t jit_memory_emitter::get_consumer_buffer_cluster_id(const ov::snippets::lowered::ExpressionPtr& expr) {
OV_CPU_JIT_EMITTER_ASSERT(expr->get_output_count() == 1, "MemoryAccess must have one output");
const auto& consumers = expr->get_output_port_connector(0)->get_consumers();
for (const auto& consumer : consumers) {
if (const auto buffer = ov::as_type_ptr<ov::snippets::lowered::BufferExpression>(consumer.get_expr())) {
return buffer->get_cluster_id();
}
}
return SIZE_MAX;
}

jit_load_memory_emitter::jit_load_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr)
: jit_memory_emitter(h, isa, expr) {
: jit_memory_emitter(h, isa, expr, emitter_in_out_map::gpr_to_vec) {
bool is_supported_precision =
one_of(src_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8) &&
src_prc == dst_prc;
@@ -46,32 +97,78 @@ jit_load_memory_emitter::jit_load_memory_emitter(jit_generator* h, cpu_isa_t isa
const auto load = ov::as_type_ptr<snippets::op::Load>(expr->get_node());
OV_CPU_JIT_EMITTER_ASSERT(load != nullptr, "Expects Load expression");
count = load->get_count();
byte_offset = load->get_offset();
in_out_type_ = emitter_in_out_map::gpr_to_vec;
load_emitter = std::make_unique<jit_load_emitter>(h, isa, src_prc, dst_prc, count, byte_offset);
load_emitter = std::make_unique<jit_load_emitter>(h, isa, src_prc, dst_prc, count, compiled_byte_offset);
}

void jit_load_memory_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
emit_isa<dnnl::impl::cpu::aarch64::asimd>(in, out);
size_t jit_memory_emitter::get_aux_gprs_count() const {
// for runtime arguments
return is_offset_runtime ? 1 : 0;
}

std::vector<size_t> jit_memory_emitter::get_available_aux_gprs() const {
OV_CPU_JIT_EMITTER_ASSERT(IMPLICATION(is_offset_runtime, !aux_gpr_idxs.empty()),
"If offset is dynamic, memory emitter need to have one aux gpr at least!");
auto available_aux_gprs = aux_gpr_idxs;
if (is_offset_runtime) {
available_aux_gprs.pop_back();
}
return available_aux_gprs;
}

void jit_memory_emitter::emit_code_impl(const std::vector<size_t>& in_idxs,
const std::vector<size_t>& out_idxs,
const std::vector<size_t>& pool_vec_idxs,
const std::vector<size_t>& pool_gpr_idxs) const {
emitter_preamble(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs);

auto reg_runtime_params = dnnl::impl::cpu::aarch64::abi_param1;
XReg aux_gpr = is_offset_runtime ? XReg(static_cast<int>(aux_gpr_idxs.back())) : XReg(0);

XReg data_reg(0);
if (in_out_type_ == emitter_in_out_map::gpr_to_vec) {
data_reg = XReg(in_idxs[0]);
} else if (in_out_type_ == emitter_in_out_map::vec_to_gpr) {
data_reg = XReg(out_idxs[0]);
} else {
OV_CPU_JIT_EMITTER_THROW("Doesn't support isa ", host_isa_);
OV_CPU_JIT_EMITTER_THROW("unsupported in_out_type");
}

if (is_offset_runtime) {
// load the runtime offset from args.buffer_offsets[buffer_cluster_id]
h->ldr(aux_gpr,
ptr(reg_runtime_params,
static_cast<int32_t>(GET_OFF(buffer_offsets) + buffer_cluster_id * sizeof(size_t))));
// bump the pointer
// TODO: Consider ISA limitations on offset size - large offsets may require multiple operations
// for both h->add and h->sub instructions to handle cases where offset exceeds immediate
// value limits
h->add(data_reg, data_reg, aux_gpr);
}

emit_impl(in_idxs, out_idxs);

if (is_offset_runtime) {
// subtract back so we leave the pointer unchanged for the caller
// TODO: Consider ISA limitations on offset size - large offsets may require multiple operations
// for both h->add and h->sub instructions to handle cases where offset exceeds immediate
// value limits
h->sub(data_reg, data_reg, aux_gpr);
}

emitter_postamble();
}

template <cpu_isa_t isa>
void jit_load_memory_emitter::emit_isa(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
void jit_load_memory_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
OV_CPU_JIT_EMITTER_ASSERT(load_emitter != nullptr, "Load CPU emitter isn't initialized!");

load_emitter->emit_code(in, out, aux_vec_idxs, aux_gpr_idxs);
load_emitter->emit_code(in, out, aux_vec_idxs, get_available_aux_gprs());
}

void jit_load_memory_emitter::emit_data() const {
load_emitter->emit_data();
}

jit_load_broadcast_emitter::jit_load_broadcast_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr)
: jit_memory_emitter(h, isa, expr) {
: jit_memory_emitter(h, isa, expr, emitter_in_out_map::gpr_to_vec) {
OV_CPU_JIT_EMITTER_ASSERT(src_prc == dst_prc,
"Only support equal input and output types but gets ",
src_prc.get_type_name(),
@@ -81,8 +178,6 @@ jit_load_broadcast_emitter::jit_load_broadcast_emitter(jit_generator* h, cpu_isa

const auto broadcast_load = ov::as_type_ptr<snippets::op::BroadcastLoad>(expr->get_node());
OV_CPU_JIT_EMITTER_ASSERT(broadcast_load != nullptr, "Expects BroadcastLoad expression");
byte_offset = broadcast_load->get_offset();
in_out_type_ = emitter_in_out_map::gpr_to_vec;
}

void jit_load_broadcast_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
@@ -99,11 +194,11 @@ void jit_load_broadcast_emitter::emit_isa(const std::vector<size_t>& in, const s
auto src = XReg(in[0]);
auto dst = TReg(out[0]);

h->uni_ld1rw(dst.s, src, byte_offset);
h->uni_ld1rw(dst.s, src, compiled_byte_offset);
}

jit_store_memory_emitter::jit_store_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr)
: jit_memory_emitter(h, isa, expr) {
: jit_memory_emitter(h, isa, expr, emitter_in_out_map::vec_to_gpr) {
bool is_supported_precision =
one_of(dst_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8) &&
src_prc == dst_prc;
@@ -112,24 +207,12 @@ jit_store_memory_emitter::jit_store_memory_emitter(jit_generator* h, cpu_isa_t i
const auto store = ov::as_type_ptr<snippets::op::Store>(expr->get_node());
OV_CPU_JIT_EMITTER_ASSERT(store != nullptr, "Expects Store expression");
count = store->get_count();
byte_offset = store->get_offset();
in_out_type_ = emitter_in_out_map::vec_to_gpr;
store_emitter = std::make_unique<jit_store_emitter>(h, isa, src_prc, dst_prc, count, byte_offset);
store_emitter = std::make_unique<jit_store_emitter>(h, isa, src_prc, dst_prc, count, compiled_byte_offset);
}

void jit_store_memory_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
emit_isa<dnnl::impl::cpu::aarch64::asimd>(in, out);
} else {
OV_CPU_JIT_EMITTER_THROW("Doesn't support isa ", host_isa_);
}
}

template <cpu_isa_t isa>
void jit_store_memory_emitter::emit_isa(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
OV_CPU_JIT_EMITTER_ASSERT(store_emitter != nullptr, "Store CPU emitter isn't initialized!");

store_emitter->emit_code(in, out, aux_vec_idxs, aux_gpr_idxs);
store_emitter->emit_code(in, out, aux_vec_idxs, get_available_aux_gprs());
}

void jit_store_memory_emitter::emit_data() const {
Original file line number Diff line number Diff line change
@@ -21,14 +21,29 @@ class jit_memory_emitter : public jit_emitter {
public:
jit_memory_emitter(dnnl::impl::cpu::aarch64::jit_generator* h,
dnnl::impl::cpu::aarch64::cpu_isa_t isa,
const ov::snippets::lowered::ExpressionPtr& expr);
const ov::snippets::lowered::ExpressionPtr& expr,
emitter_in_out_map in_out_type);

size_t get_aux_gprs_count() const override;

void emit_code_impl(const std::vector<size_t>& in_idxs,
const std::vector<size_t>& out_idxs,
const std::vector<size_t>& pool_vec_idxs = {},
const std::vector<size_t>& pool_gpr_idxs = {}) const override;

std::vector<size_t> get_available_aux_gprs() const;

protected:
static size_t get_parent_buffer_cluster_id(const ov::snippets::lowered::ExpressionPtr& expr);
static size_t get_consumer_buffer_cluster_id(const ov::snippets::lowered::ExpressionPtr& expr);

ov::element::Type src_prc;
ov::element::Type dst_prc;

size_t count = 0;
size_t byte_offset = 0;
size_t compiled_byte_offset = 0;
size_t buffer_cluster_id = 0;
bool is_offset_runtime = false;
};

class jit_load_memory_emitter : public jit_memory_emitter {
@@ -43,9 +58,6 @@ class jit_load_memory_emitter : public jit_memory_emitter {

private:
void emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const override;

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void emit_isa(const std::vector<size_t>& in, const std::vector<size_t>& out) const;
void emit_data() const override;

private:
@@ -81,9 +93,6 @@ class jit_store_memory_emitter : public jit_memory_emitter {

private:
void emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const override;

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void emit_isa(const std::vector<size_t>& in, const std::vector<size_t>& out) const;
void emit_data() const override;

private:
Original file line number Diff line number Diff line change
@@ -20,8 +20,8 @@ namespace ov::intel_cpu {
# define SNIPPETS_MAX_DATA_PTR_COUNT 11
#endif

#define GET_OFF(field) offsetof(jit_snippets_call_args, field)
#define GET_OFF_LOOP_ARGS(field) offsetof(jit_snippets_call_args::loop_args_t, field)
#define GET_OFF(field) offsetof(ov::intel_cpu::jit_snippets_call_args, field)
#define GET_OFF_LOOP_ARGS(field) offsetof(ov::intel_cpu::jit_snippets_call_args::loop_args_t, field)

struct amx_tile_config_t {
dnnl_dim_t M = 0;
Original file line number Diff line number Diff line change
@@ -12,37 +12,37 @@ namespace snippets {

namespace {

const std::vector<InputShape> inputShape = {
{{}, {{1, 16}}},
{{}, {{1, 32}}},
{{}, {{1, 1}}},
{{}, {{1, 9}}},
{{}, {{1, 17}}},
{{}, {{1, 19}}},
{{}, {{1, 49}}},
{{}, {{1, 50}}},
{{}, {{5, 16}}},
{{}, {{5, 32}}},
{{}, {{5, 1}}},
{{}, {{5, 9}}},
{{}, {{5, 17}}},
{{}, {{5, 19}}},
{{}, {{5, 49}}},
{{}, {{5, 50}}},
{{}, {{1, 3, 128, 128}}},
{{}, {{1, 3, 128, 129}}},
{{}, {{1, 3, 128, 130}}},
{{}, {{1, 3, 128, 1}}},
{{}, {{1, 3, 128, 9}}},
{{}, {{1, 3, 128, 16}}},
{{}, {{1, 3, 128, 17}}},
{{}, {{1, 3, 128, 20}}},
const std::vector<std::vector<InputShape>> inputShape = {
{{{}, {{1, 16}}}},
{{{}, {{1, 32}}}},
{{{}, {{1, 1}}}},
{{{}, {{1, 9}}}},
{{{}, {{1, 17}}}},
{{{}, {{1, 19}}}},
{{{}, {{1, 49}}}},
{{{}, {{1, 50}}}},
{{{}, {{5, 16}}}},
{{{}, {{5, 32}}}},
{{{}, {{5, 1}}}},
{{{}, {{5, 9}}}},
{{{}, {{5, 17}}}},
{{{}, {{5, 19}}}},
{{{}, {{5, 49}}}},
{{{}, {{5, 50}}}},
{{{}, {{1, 3, 128, 128}}}},
{{{}, {{1, 3, 128, 129}}}},
{{{}, {{1, 3, 128, 130}}}},
{{{}, {{1, 3, 128, 1}}}},
{{{}, {{1, 3, 128, 9}}}},
{{{}, {{1, 3, 128, 16}}}},
{{{}, {{1, 3, 128, 17}}}},
{{{}, {{1, 3, 128, 20}}}},
// DS
{{-1, -1}, {{1, 16}, {1, 32}, {1, 1}, {1, 9}, {1, 17}, {1, 19}, {1, 49}, {1, 50}, {5, 16}, {1, 16}, {1, 9}}},
{{-1, -1, -1, -1}, {{1, 3, 128, 128}, {1, 3, 128, 129}, {1, 3, 128, 130}, {1, 3, 128, 1}, {1, 3, 128, 16}, {1, 3, 128, 1}}},
{{-1, -1, -1, 128}, {{1, 3, 128, 128}, {1, 3, 128, 128}, {1, 3, 64, 128}, {1, 3, 32, 128}, {1, 3, 64, 128}, {1, 3, 32, 128}}},
{{-1, -1, -1, 130}, {{1, 3, 8, 130}, {1, 3, 18, 130}, {1, 3, 8, 130}, {1, 3, 32, 130}, {1, 3, 18, 130}, {1, 3, 32, 130}}},
{{-1, -1, 128, -1}, {{1, 3, 128, 128}, {1, 3, 128, 129}, {1, 3, 128, 130}, {1, 3, 128, 1}, {1, 3, 128, 16}, {1, 3, 128, 1}}},
{{{-1, -1}, {{1, 16}, {1, 32}, {1, 1}, {1, 9}, {1, 17}, {1, 19}, {1, 49}, {1, 50}, {5, 16}, {1, 16}, {1, 9}}}},
{{{-1, -1, -1, -1}, {{1, 3, 128, 128}, {1, 3, 128, 129}, {1, 3, 128, 130}, {1, 3, 128, 1}, {1, 3, 128, 16}, {1, 3, 128, 1}}}},
{{{-1, -1, -1, 128}, {{1, 3, 128, 128}, {1, 3, 128, 128}, {1, 3, 64, 128}, {1, 3, 32, 128}, {1, 3, 64, 128}, {1, 3, 32, 128}}}},
{{{-1, -1, -1, 130}, {{1, 3, 8, 130}, {1, 3, 18, 130}, {1, 3, 8, 130}, {1, 3, 32, 130}, {1, 3, 18, 130}, {1, 3, 32, 130}}}},
{{{-1, -1, 128, -1}, {{1, 3, 128, 128}, {1, 3, 128, 129}, {1, 3, 128, 130}, {1, 3, 128, 1}, {1, 3, 128, 16}, {1, 3, 128, 1}}}},
};

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_Softmax, Softmax,
@@ -54,7 +54,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_Softmax, Softmax,
::testing::Values(ov::test::utils::DEVICE_CPU)),
Softmax::getTestCaseName);

const std::vector<std::pair<InputShape, InputShape>> inputShapesPair = {
const std::vector<std::vector<InputShape>> inputShapesPair = {
{{{}, {{1, 5, 16, 35}}}, {{}, {{1, 5, 16, 35}}}},
{{{}, {{1, 5, 16, 1}}}, {{}, {{1, 5, 16, 35}}}},
{{{}, {{1, 5, 16, 35}}}, {{}, {{1, 5, 1, 1}}}},
@@ -85,6 +85,14 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_AddSoftmax, AddSoftmax,
::testing::Values(ov::test::utils::DEVICE_CPU)),
AddSoftmax::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_SoftmaxSum, SoftmaxSum,
::testing::Combine(
::testing::ValuesIn(inputShapesPair),
::testing::Values(-1),
::testing::Values(1),
::testing::Values(1),
::testing::Values(ov::test::utils::DEVICE_CPU)),
SoftmaxSum::getTestCaseName);
} // namespace
} // namespace snippets
} // namespace test
Loading