11
11
#include < executorch/backends/xnnpack/serialization/schema_generated.h>
12
12
#include < executorch/extension/threadpool/threadpool.h>
13
13
#include < executorch/runtime/executor/pte_data_map.h>
14
+ #include < string>
14
15
#include < unordered_map>
16
+ #include < vector>
15
17
16
18
#pragma clang diagnostic ignored "-Wmissing-prototypes"
17
19
#pragma clang diagnostic ignored "-Wglobal-constructors"
@@ -167,7 +169,8 @@ const uint8_t* getConstantDataPtr(
167
169
GraphPtr flatbuffer_graph,
168
170
const uint8_t * constant_data_ptr,
169
171
const NamedDataMap* named_data_map,
170
- std::vector<FreeableBuffer>& loaded_buffers_from_map) {
172
+ std::vector<FreeableBuffer>& freeable_buffers,
173
+ XNNWeightsCache* weights_cache) {
171
174
auto buffer_idx = tensor_value->constant_buffer_idx ();
172
175
if (buffer_idx) {
173
176
if (!constant_data_ptr) {
@@ -187,6 +190,15 @@ const uint8_t* getConstantDataPtr(
187
190
return constant_data_ptr + offset;
188
191
} else {
189
192
const std::string& data_name = constant_data_offset->named_key ()->str ();
193
+ #ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
194
+ Result<const uint8_t *> data_ptr =
195
+ weights_cache->load_unpacked_data (data_name);
196
+ if (!data_ptr.ok ()) {
197
+ ET_LOG (Error, " Failed to load weights from cache" );
198
+ return nullptr ;
199
+ }
200
+ return data_ptr.get ();
201
+ #else
190
202
Result<FreeableBuffer> buffer =
191
203
named_data_map->get_data (data_name.c_str ());
192
204
if (!buffer.ok ()) {
@@ -198,8 +210,9 @@ const uint8_t* getConstantDataPtr(
198
210
}
199
211
const uint8_t * data_ptr =
200
212
static_cast <const uint8_t *>(buffer.get ().data ());
201
- loaded_buffers_from_map .push_back (std::move (buffer.get ()));
213
+ freeable_buffers .push_back (std::move (buffer.get ()));
202
214
return data_ptr;
215
+ #endif
203
216
}
204
217
}
205
218
}
@@ -222,7 +235,8 @@ Error defineTensor(
222
235
std::vector<uint32_t >& output_ids,
223
236
CompileAllocator& allocator,
224
237
const NamedDataMap* named_data_map,
225
- std::vector<FreeableBuffer>& loaded_buffers_from_map) {
238
+ std::vector<FreeableBuffer>& freeable_buffers,
239
+ XNNWeightsCache* weights_cache) {
226
240
const fb_xnnpack::XNNTensorValue* tensor_value = nullptr ;
227
241
const fb_xnnpack::XNNQuantizedTensorValue* qtensor_value = nullptr ;
228
242
@@ -264,7 +278,8 @@ Error defineTensor(
264
278
flatbuffer_graph,
265
279
constant_data_ptr,
266
280
named_data_map,
267
- loaded_buffers_from_map);
281
+ freeable_buffers,
282
+ weights_cache);
268
283
269
284
xnn_status status;
270
285
// The type we might have to convert to
@@ -1999,9 +2014,9 @@ ET_NODISCARD Error XNNCompiler::compileModel(
1999
2014
const void * buffer_pointer,
2000
2015
size_t num_bytes,
2001
2016
XNNExecutor* executor,
2002
- MemoryAllocator* runtime_allocator ,
2003
- const NamedDataMap* named_data_map ,
2004
- xnn_workspace_t workspace ) {
2017
+ XNNWeightsCache* weights_cache ,
2018
+ xnn_workspace_t workspace ,
2019
+ const NamedDataMap* named_data_map ) {
2005
2020
Result<XNNHeader> header = XNNHeader::Parse (buffer_pointer, num_bytes);
2006
2021
const uint8_t * flatbuffer_data = nullptr ;
2007
2022
const uint8_t * constant_data = nullptr ;
@@ -2065,11 +2080,14 @@ ET_NODISCARD Error XNNCompiler::compileModel(
2065
2080
// Invalid ids do not need to be remapped
2066
2081
remapped_ids.emplace (XNN_INVALID_VALUE_ID, XNN_INVALID_VALUE_ID);
2067
2082
2083
+ // If weight cache is not on we hold onto all the unpacked buffers
2084
+ // and we free them at the end
2085
+ std::vector<FreeableBuffer> unpacked_buffers;
2086
+
2068
2087
// External Ids for inputs and outputs
2069
2088
std::vector<uint32_t > input_ids;
2070
2089
std::vector<uint32_t > output_ids;
2071
2090
Error err = Error::Ok;
2072
- std::vector<FreeableBuffer> loaded_buffers_from_map;
2073
2091
for (auto value : *flatbuffer_graph->xvalues ()) {
2074
2092
err = defineTensor (
2075
2093
subgraph.get (),
@@ -2081,7 +2099,8 @@ ET_NODISCARD Error XNNCompiler::compileModel(
2081
2099
output_ids,
2082
2100
compile_allocator,
2083
2101
named_data_map,
2084
- loaded_buffers_from_map);
2102
+ unpacked_buffers,
2103
+ weights_cache);
2085
2104
2086
2105
if (err != Error::Ok) {
2087
2106
return err;
@@ -2103,20 +2122,34 @@ ET_NODISCARD Error XNNCompiler::compileModel(
2103
2122
2104
2123
xnn_runtime_t runtime_ptr = nullptr ;
2105
2124
2125
+ // If the weights cache is not enabled, then XNNWeightsCache
2126
+ // just manages the unpacked weights until the runtime is created.
2127
+ #ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
2128
+ ET_CHECK_OR_RETURN_ERROR (
2129
+ unpacked_buffers.size () == 0 ,
2130
+ Internal,
2131
+ " Weight Cache is enabled, which means unpacked buffers should be owned by the cache" );
2132
+ xnn_weights_cache_t weights_cache_ptr =
2133
+ weights_cache->get_num_unpacked_data () > 0 ? weights_cache->get ()
2134
+ : nullptr ;
2135
+ #else
2136
+ xnn_weights_cache_t weights_cache_ptr = nullptr ;
2137
+ #endif
2138
+
2106
2139
#ifdef ENABLE_XNNPACK_SHARED_WORKSPACE
2107
2140
ET_CHECK_OR_RETURN_ERROR (
2108
2141
workspace != nullptr , Internal, " Failed to initialize XNNPACK workspace" );
2109
2142
status = xnn_create_runtime_v4 (
2110
2143
subgraph.get (),
2111
- /* weight_cache= */ nullptr , // TODO - support weight cache
2144
+ weights_cache_ptr,
2112
2145
workspace,
2113
2146
::executorch::extension::threadpool::get_pthreadpool (),
2114
2147
runtime_flags,
2115
2148
&runtime_ptr);
2116
2149
#else
2117
2150
status = xnn_create_runtime_v3 (
2118
2151
subgraph.get (),
2119
- /* weight_cache= */ nullptr , // TODO - support weight cache
2152
+ weights_cache_ptr,
2120
2153
::executorch::extension::threadpool::get_pthreadpool (),
2121
2154
runtime_flags,
2122
2155
&runtime_ptr);
@@ -2128,10 +2161,25 @@ ET_NODISCARD Error XNNCompiler::compileModel(
2128
2161
" XNN Runtime creation failed with code: %s" ,
2129
2162
xnn_status_to_string (status));
2130
2163
2164
+ #ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
2165
+ auto packed_weights_names = weights_cache->finalize_for_runtime ();
2166
+ ET_CHECK_OR_RETURN_ERROR (
2167
+ packed_weights_names.ok (),
2168
+ Internal,
2169
+ " Failed to finalize weights cache after creating the xnn runtime" )
2170
+ #else
2171
+ for (auto & buffer : unpacked_buffers) {
2172
+ buffer.Free ();
2173
+ }
2174
+ Result<std::vector<std::string>> packed_weights_names =
2175
+ std::vector<std::string>();
2176
+ #endif
2177
+
2131
2178
err = executor->initialize ( // NOLINT: runtime_ptr is non-null
2132
2179
runtime_ptr,
2133
2180
std::move (input_ids),
2134
- std::move (output_ids));
2181
+ std::move (output_ids),
2182
+ std::move (packed_weights_names.get ()));
2135
2183
2136
2184
return err;
2137
2185
};
0 commit comments