Skip to content

Commit 2a903f9

Browse files
pytorchbotmcr229
andauthored
[XNNPACK][Weights Cache] Enable in XNNPACK (#9297)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #9155 by @mcr229 ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/mcr229/11/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/mcr229/11/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/mcr229/10/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/mcr229/11/orig @diff-train-skip-merge --------- Co-authored-by: Max Ren <[email protected]>
1 parent d08d938 commit 2a903f9

File tree

8 files changed

+131
-31
lines changed

8 files changed

+131
-31
lines changed

backends/xnnpack/CMakeLists.txt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,19 @@ option(EXECUTORCH_XNNPACK_SHARED_WORKSPACE
3737
# Keeping this OFF by default due to regressions in decode and model load with
3838
# kleidi kernels
3939
option(EXECUTORCH_XNNPACK_ENABLE_KLEIDI "Enable Arm Kleidi kernels" OFF)
40+
41+
# Turning this on caches weights between partitions and methods. If weights
42+
# are shared across methods/partitions then this can reduce load time and
43+
# memory usage
44+
45+
# Keeping this off maintains existing behavior. Turning this on serializes
46+
# execution and initialization of delegates, to be revisited
47+
option(EXECUTORCH_XNNPACK_ENABLE_WEIGHT_CACHE
48+
"Enable weights cache to cache and manage all packed weights" OFF)
49+
50+
if(EXECUTORCH_XNNPACK_ENABLE_WEIGHT_CACHE)
51+
add_definitions(-DENABLE_XNNPACK_WEIGHTS_CACHE)
52+
endif()
4053
if(EXECUTORCH_XNNPACK_SHARED_WORKSPACE)
4154
add_definitions(-DENABLE_XNNPACK_SHARED_WORKSPACE)
4255
endif()

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 60 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
#include <executorch/backends/xnnpack/serialization/schema_generated.h>
1212
#include <executorch/extension/threadpool/threadpool.h>
1313
#include <executorch/runtime/executor/pte_data_map.h>
14+
#include <string>
1415
#include <unordered_map>
16+
#include <vector>
1517

1618
#pragma clang diagnostic ignored "-Wmissing-prototypes"
1719
#pragma clang diagnostic ignored "-Wglobal-constructors"
@@ -167,7 +169,8 @@ const uint8_t* getConstantDataPtr(
167169
GraphPtr flatbuffer_graph,
168170
const uint8_t* constant_data_ptr,
169171
const NamedDataMap* named_data_map,
170-
std::vector<FreeableBuffer>& loaded_buffers_from_map) {
172+
std::vector<FreeableBuffer>& freeable_buffers,
173+
XNNWeightsCache* weights_cache) {
171174
auto buffer_idx = tensor_value->constant_buffer_idx();
172175
if (buffer_idx) {
173176
if (!constant_data_ptr) {
@@ -187,6 +190,15 @@ const uint8_t* getConstantDataPtr(
187190
return constant_data_ptr + offset;
188191
} else {
189192
const std::string& data_name = constant_data_offset->named_key()->str();
193+
#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
194+
Result<const uint8_t*> data_ptr =
195+
weights_cache->load_unpacked_data(data_name);
196+
if (!data_ptr.ok()) {
197+
ET_LOG(Error, "Failed to load weights from cache");
198+
return nullptr;
199+
}
200+
return data_ptr.get();
201+
#else
190202
Result<FreeableBuffer> buffer =
191203
named_data_map->get_data(data_name.c_str());
192204
if (!buffer.ok()) {
@@ -198,8 +210,9 @@ const uint8_t* getConstantDataPtr(
198210
}
199211
const uint8_t* data_ptr =
200212
static_cast<const uint8_t*>(buffer.get().data());
201-
loaded_buffers_from_map.push_back(std::move(buffer.get()));
213+
freeable_buffers.push_back(std::move(buffer.get()));
202214
return data_ptr;
215+
#endif
203216
}
204217
}
205218
}
@@ -222,7 +235,8 @@ Error defineTensor(
222235
std::vector<uint32_t>& output_ids,
223236
CompileAllocator& allocator,
224237
const NamedDataMap* named_data_map,
225-
std::vector<FreeableBuffer>& loaded_buffers_from_map) {
238+
std::vector<FreeableBuffer>& freeable_buffers,
239+
XNNWeightsCache* weights_cache) {
226240
const fb_xnnpack::XNNTensorValue* tensor_value = nullptr;
227241
const fb_xnnpack::XNNQuantizedTensorValue* qtensor_value = nullptr;
228242

@@ -264,7 +278,8 @@ Error defineTensor(
264278
flatbuffer_graph,
265279
constant_data_ptr,
266280
named_data_map,
267-
loaded_buffers_from_map);
281+
freeable_buffers,
282+
weights_cache);
268283

269284
xnn_status status;
270285
// The type we might have to convert to
@@ -1999,9 +2014,9 @@ ET_NODISCARD Error XNNCompiler::compileModel(
19992014
const void* buffer_pointer,
20002015
size_t num_bytes,
20012016
XNNExecutor* executor,
2002-
MemoryAllocator* runtime_allocator,
2003-
const NamedDataMap* named_data_map,
2004-
xnn_workspace_t workspace) {
2017+
XNNWeightsCache* weights_cache,
2018+
xnn_workspace_t workspace,
2019+
const NamedDataMap* named_data_map) {
20052020
Result<XNNHeader> header = XNNHeader::Parse(buffer_pointer, num_bytes);
20062021
const uint8_t* flatbuffer_data = nullptr;
20072022
const uint8_t* constant_data = nullptr;
@@ -2065,11 +2080,14 @@ ET_NODISCARD Error XNNCompiler::compileModel(
20652080
// Invalid ids do not need to be remapped
20662081
remapped_ids.emplace(XNN_INVALID_VALUE_ID, XNN_INVALID_VALUE_ID);
20672082

2083+
// If weight cache is not on we hold onto all the unpacked buffers
2084+
// and we free them at the end
2085+
std::vector<FreeableBuffer> unpacked_buffers;
2086+
20682087
// External Ids for inputs and outputs
20692088
std::vector<uint32_t> input_ids;
20702089
std::vector<uint32_t> output_ids;
20712090
Error err = Error::Ok;
2072-
std::vector<FreeableBuffer> loaded_buffers_from_map;
20732091
for (auto value : *flatbuffer_graph->xvalues()) {
20742092
err = defineTensor(
20752093
subgraph.get(),
@@ -2081,7 +2099,8 @@ ET_NODISCARD Error XNNCompiler::compileModel(
20812099
output_ids,
20822100
compile_allocator,
20832101
named_data_map,
2084-
loaded_buffers_from_map);
2102+
unpacked_buffers,
2103+
weights_cache);
20852104

20862105
if (err != Error::Ok) {
20872106
return err;
@@ -2103,20 +2122,34 @@ ET_NODISCARD Error XNNCompiler::compileModel(
21032122

21042123
xnn_runtime_t runtime_ptr = nullptr;
21052124

2125+
// XNNWeightsCache if weights cache is not enabled, then XNNWeightsCache
2126+
// just manages the unpacked weights until the runtime is created.
2127+
#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
2128+
ET_CHECK_OR_RETURN_ERROR(
2129+
unpacked_buffers.size() == 0,
2130+
Internal,
2131+
"Weight Cache is enabled, which means unpacked buffers should be owned by the cache");
2132+
xnn_weights_cache_t weights_cache_ptr =
2133+
weights_cache->get_num_unpacked_data() > 0 ? weights_cache->get()
2134+
: nullptr;
2135+
#else
2136+
xnn_weights_cache_t weights_cache_ptr = nullptr;
2137+
#endif
2138+
21062139
#ifdef ENABLE_XNNPACK_SHARED_WORKSPACE
21072140
ET_CHECK_OR_RETURN_ERROR(
21082141
workspace != nullptr, Internal, "Failed to initialize XNNPACK workspace");
21092142
status = xnn_create_runtime_v4(
21102143
subgraph.get(),
2111-
/*weight_cache=*/nullptr, // TODO - support weight cache
2144+
weights_cache_ptr,
21122145
workspace,
21132146
::executorch::extension::threadpool::get_pthreadpool(),
21142147
runtime_flags,
21152148
&runtime_ptr);
21162149
#else
21172150
status = xnn_create_runtime_v3(
21182151
subgraph.get(),
2119-
/*weight_cache=*/nullptr, // TODO - support weight cache
2152+
weights_cache_ptr,
21202153
::executorch::extension::threadpool::get_pthreadpool(),
21212154
runtime_flags,
21222155
&runtime_ptr);
@@ -2128,10 +2161,25 @@ ET_NODISCARD Error XNNCompiler::compileModel(
21282161
"XNN Runtime creation failed with code: %s",
21292162
xnn_status_to_string(status));
21302163

2164+
#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
2165+
auto packed_weights_names = weights_cache->finalize_for_runtime();
2166+
ET_CHECK_OR_RETURN_ERROR(
2167+
packed_weights_names.ok(),
2168+
Internal,
2169+
"Failed to finalize weights cache after creating the xnn runtime")
2170+
#else
2171+
for (auto& buffer : unpacked_buffers) {
2172+
buffer.Free();
2173+
}
2174+
Result<std::vector<std::string>> packed_weights_names =
2175+
std::vector<std::string>();
2176+
#endif
2177+
21312178
err = executor->initialize( // NOLINT: runtime_ptr is non-null
21322179
runtime_ptr,
21332180
std::move(input_ids),
2134-
std::move(output_ids));
2181+
std::move(output_ids),
2182+
std::move(packed_weights_names.get()));
21352183

21362184
return err;
21372185
};

backends/xnnpack/runtime/XNNCompiler.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,9 @@
99
#pragma once
1010

1111
#include <executorch/backends/xnnpack/runtime/XNNExecutor.h>
12+
#include <executorch/backends/xnnpack/runtime/XNNWeightsCache.h>
1213
#include <executorch/runtime/platform/compiler.h>
13-
1414
#include <xnnpack.h>
15-
#include <memory>
16-
#include <vector>
1715

1816
namespace executorch {
1917
namespace backends {
@@ -29,9 +27,9 @@ class XNNCompiler {
2927
const void* buffer_pointer,
3028
size_t num_bytes,
3129
XNNExecutor* executor,
32-
executorch::runtime::MemoryAllocator* runtime_allocator,
33-
const executorch::runtime::NamedDataMap* named_data_map,
34-
xnn_workspace_t workspace);
30+
XNNWeightsCache* weights_cache,
31+
xnn_workspace_t workspace,
32+
const NamedDataMap* named_data_map);
3533
};
3634

3735
} // namespace delegate

backends/xnnpack/runtime/XNNExecutor.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ using executorch::runtime::kTensorDimensionLimit;
3030
ET_NODISCARD Error XNNExecutor::initialize(
3131
xnn_runtime_t runtime,
3232
std::vector<uint32_t>&& input_ids,
33-
std::vector<uint32_t>&& output_ids) {
33+
std::vector<uint32_t>&& output_ids,
34+
std::vector<std::string>&& packed_data_names) {
3435
runtime_ = std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)>(
3536
runtime, xnn_delete_runtime);
3637

@@ -51,6 +52,7 @@ ET_NODISCARD Error XNNExecutor::initialize(
5152
std::sort(output_ids_.begin(), output_ids_.end());
5253

5354
externals_.resize(input_ids_.size() + output_ids_.size());
55+
packed_data_names_ = std::move(packed_data_names);
5456

5557
return Error::Ok;
5658
}

backends/xnnpack/runtime/XNNExecutor.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class XNNExecutor {
3434
std::vector<uint32_t> input_ids_;
3535
std::vector<uint32_t> output_ids_;
3636
std::vector<xnn_external_value> externals_;
37+
std::vector<std::string> packed_data_names_;
3738

3839
public:
3940
XNNExecutor() = default;
@@ -46,6 +47,10 @@ class XNNExecutor {
4647
return output_ids_.size();
4748
}
4849

50+
inline std::vector<std::string> get_packed_data_names() {
51+
return packed_data_names_;
52+
}
53+
4954
/**
5055
* Initialize the XNNExecutor with a given runtime and input/output ids.
5156
* The input/output ids are expected to be sorted in order of their
@@ -54,7 +59,8 @@ class XNNExecutor {
5459
ET_NODISCARD executorch::runtime::Error initialize(
5560
xnn_runtime_t runtime,
5661
std::vector<uint32_t>&& input_ids,
57-
std::vector<uint32_t>&& output_ids);
62+
std::vector<uint32_t>&& output_ids,
63+
std::vector<std::string>&& packed_data_names);
5864

5965
/**
6066
* Prepares the arguments for runtime graph execution.

backends/xnnpack/runtime/XNNPACKBackend.cpp

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
*/
88

99
#include <executorch/backends/xnnpack/runtime/XNNCompiler.h>
10+
#include <executorch/backends/xnnpack/runtime/XNNWeightsCache.h>
1011
#include <executorch/runtime/backend/interface.h>
1112
#include <executorch/runtime/core/error.h>
1213
#include <executorch/runtime/core/evalue.h>
@@ -20,6 +21,7 @@
2021
namespace executorch {
2122
namespace backends {
2223

24+
using executorch::backends::xnnpack::delegate::XNNWeightsCache;
2325
using executorch::runtime::ArrayRef;
2426
using executorch::runtime::Backend;
2527
using executorch::runtime::BackendExecutionContext;
@@ -81,13 +83,18 @@ class XnnpackBackend final : public ::executorch::runtime::BackendInterface {
8183
}
8284

8385
const NamedDataMap* named_data_map = context.get_named_data_map();
84-
85-
#ifdef ENABLE_XNNPACK_SHARED_WORKSPACE
86-
// This is needed to serialize access to xnn_create_runtime which is not
8786
// thread safe. This can happen when multiple threads call init() on
8887
// the same backend instance.
88+
#ifdef ENABLE_XNNPACK_SHARED_WORKSPACE
8989
const std::lock_guard<std::mutex> lock(workspace_mutex_);
9090
#endif
91+
92+
#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
93+
const std::lock_guard<std::mutex> lock_weight_cache(weights_cache_mutex_);
94+
weights_cache_->initialize_for_runtime(
95+
context.get_runtime_allocator(), named_data_map);
96+
#endif
97+
9198
// Executor has been allocated but not constructed, ensure that runtime_ is
9299
// nullptr by constructing it in place here. NOTE: Since we use placement
93100
// new and since this type is not trivially destructible, we must call the
@@ -97,9 +104,9 @@ class XnnpackBackend final : public ::executorch::runtime::BackendInterface {
97104
processed->data(),
98105
processed->size(),
99106
executor,
100-
context.get_runtime_allocator(),
101-
named_data_map,
102-
workspace_.get());
107+
weights_cache_.get(),
108+
workspace_.get(),
109+
named_data_map);
103110
// This backend does not need its processed data after compiling the model.
104111
processed->Free();
105112

@@ -125,6 +132,10 @@ class XnnpackBackend final : public ::executorch::runtime::BackendInterface {
125132
const std::lock_guard<std::mutex> lock(workspace_mutex_);
126133
#endif
127134

135+
#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
136+
const std::lock_guard<std::mutex> lock_weights_cache(weights_cache_mutex_);
137+
#endif
138+
128139
// Prepare Inputs/Outputs and Propagate Input Shapes
129140
Error err = executor->prepare_args(args);
130141
if (err != Error::Ok) {
@@ -145,16 +156,24 @@ class XnnpackBackend final : public ::executorch::runtime::BackendInterface {
145156

146157
void destroy(DelegateHandle* handle) const override {
147158
if (handle != nullptr) {
148-
#ifdef ENABLE_XNNPACK_SHARED_WORKSPACE
149159
// This is needed to serialize access to xnn_delete_runtime which is not
150160
// thread safe. This can happen when multiple threads call destroy() on
151161
// the same backend instance.
162+
#ifdef ENABLE_XNNPACK_SHARED_WORKSPACE
152163
const std::lock_guard<std::mutex> lock(workspace_mutex_);
153164
#endif
165+
154166
auto executor = static_cast<xnnpack::delegate::XNNExecutor*>(handle);
167+
155168
#ifdef ENABLE_XNNPACK_PROFILING
156169
executor->print_avg_op_timings();
157170
#endif
171+
172+
#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
173+
const std::lock_guard<std::mutex> lock_weights_cache(
174+
weights_cache_mutex_);
175+
weights_cache_->delete_packed_data(executor->get_packed_data_names());
176+
#endif
158177
// XNNExecutor is not trivially destructible. Since this was constructed
159178
// manually in init(), we must destroy it manually here.
160179
executor->~XNNExecutor();
@@ -167,6 +186,15 @@ class XnnpackBackend final : public ::executorch::runtime::BackendInterface {
167186
std::unique_ptr<xnn_workspace, decltype(&xnn_release_workspace)> workspace_{
168187
nullptr,
169188
&xnn_release_workspace};
189+
190+
// Weights cache is global to all delegate instances.
191+
mutable std::mutex weights_cache_mutex_;
192+
std::unique_ptr<XNNWeightsCache> weights_cache_ =
193+
std::make_unique<XNNWeightsCache>();
194+
195+
// Lock Hierarchy for Mutexes:
196+
// workspace_mutex_
197+
// weights_cache_mutex_
170198
};
171199

172200
namespace {

backends/xnnpack/targets.bzl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,15 @@ def _get_preprocessor_flags():
66
Disable if someone explicitly specified a config option,
77
else Enable otherwise
88
"""
9-
if native.read_config("executorch", "xnnpack_workspace_sharing", "0") == "0":
10-
return []
9+
preprocessor_flags = []
10+
if native.read_config("executorch", "xnnpack_workspace_sharing", "0") != "0":
11+
preprocessor_flags.append("-DENABLE_XNNPACK_SHARED_WORKSPACE")
12+
13+
if native.read_config("executorch", "xnnpack_weights_cache", "0") != "0":
14+
preprocessor_flags.append("-DENABLE_XNNPACK_WEIGHTS_CACHE")
1115

1216
# Enable if not disabled through config
13-
return ["-DENABLE_XNNPACK_SHARED_WORKSPACE"]
17+
return preprocessor_flags
1418

1519
def define_common_targets():
1620
runtime.cxx_library(

backends/xnnpack/test/runtime/test_xnnexecutor.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ TEST(XNNExecutorTest, ArgumentWithTooManyDimensions) {
7474
},
7575
{
7676
1,
77-
}),
77+
},
78+
{}),
7879
Error::Ok);
7980
TensorFactory<executorch::aten::ScalarType::Int> tf;
8081
auto input_tensor = tf.make({1, 1, 1, 1, 1, 1, 1, 1, 1}, {42});

0 commit comments

Comments
 (0)