Skip to content

Commit 304c567

Browse files
authored
[ExecuTorch][#10375] Add extension.BundledModule to Wrap extension.Module with Bundled Program Logic (#10744)
1 parent 1656854 commit 304c567

12 files changed

+519
-23
lines changed

devtools/bundled_program/bundled_program.cpp

+18-3
Original file line numberDiff line numberDiff line change
@@ -260,9 +260,16 @@ ET_NODISCARD Error load_bundled_input(
260260
if (!method_test.ok()) {
261261
return method_test.error();
262262
}
263-
263+
auto test_cases = method_test.get()->test_cases();
264+
ET_CHECK_OR_RETURN_ERROR(
265+
testset_idx < test_cases->size(),
266+
InvalidArgument,
267+
"testset_idx %zu is out of range [0, %u]",
268+
testset_idx,
269+
test_cases->size());
264270
auto bundled_inputs =
265-
method_test.get()->test_cases()->Get(testset_idx)->inputs();
271+
test_cases->Get(static_cast<flatbuffers::uoffset_t>(testset_idx))
272+
->inputs();
266273

267274
for (size_t input_idx = 0; input_idx < method.inputs_size(); input_idx++) {
268275
auto bundled_input = bundled_inputs->GetMutableObject(input_idx);
@@ -359,8 +366,16 @@ ET_NODISCARD Error verify_method_outputs(
359366
return method_test.error();
360367
}
361368

369+
auto test_cases = method_test.get()->test_cases();
370+
ET_CHECK_OR_RETURN_ERROR(
371+
testset_idx < test_cases->size(),
372+
InvalidArgument,
373+
"testset_idx %zu is out of range [0, %u]",
374+
testset_idx,
375+
test_cases->size());
362376
auto bundled_expected_outputs =
363-
method_test.get()->test_cases()->Get(testset_idx)->expected_outputs();
377+
test_cases->Get(static_cast<flatbuffers::uoffset_t>(testset_idx))
378+
->expected_outputs();
364379

365380
if (bundled_expected_outputs->size() == 0) {
366381
// No bundled expected outputs, so we can't verify the method outputs.

devtools/bundled_program/schema/targets.bzl

+1
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def define_common_targets():
7474
visibility = [
7575
"//executorch/devtools/bundled_program/...",
7676
"//executorch/extension/pybindings/...",
77+
"//executorch/extension/module/...",
7778
],
7879
exported_headers = {
7980
OUTPUT_BUNDLED_HEADER: ":{}[{}]".format(BUNDLED_GEN_RULE_NAME, OUTPUT_BUNDLED_HEADER),

extension/module/bundled_module.cpp

+112
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/extension/module/bundled_module.h>
10+
11+
#include <executorch/devtools/bundled_program/bundled_program.h>
12+
#include <executorch/devtools/bundled_program/schema/bundled_program_schema_generated.h>
13+
#include <executorch/extension/data_loader/buffer_data_loader.h>
14+
#include <executorch/extension/data_loader/file_data_loader.h>
15+
16+
namespace executorch {
17+
namespace extension {
18+
19+
namespace {
20+
std::unique_ptr<BufferDataLoader> program_data_loader(
21+
const void* bundled_program_ptr) {
22+
auto bundled_program =
23+
bundled_program_flatbuffer::GetBundledProgram(bundled_program_ptr);
24+
// the program inside the bundled program
25+
auto program = bundled_program->program();
26+
return std::make_unique<BufferDataLoader>(program->data(), program->size());
27+
}
28+
} // namespace
29+
30+
BundledModule::BundledModule(
31+
const void* bundled_program_ptr,
32+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
33+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
34+
std::unique_ptr<runtime::EventTracer> event_tracer,
35+
std::unique_ptr<runtime::DataLoader> data_map_loader)
36+
: Module(
37+
program_data_loader(bundled_program_ptr),
38+
std::move(memory_allocator),
39+
std::move(temp_allocator),
40+
std::move(event_tracer),
41+
std::move(data_map_loader)),
42+
bundled_program_ptr_(bundled_program_ptr) {}
43+
44+
runtime::Result<std::unique_ptr<BundledModule>> BundledModule::from_file(
45+
const std::string& file_path,
46+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
47+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
48+
std::unique_ptr<runtime::EventTracer> event_tracer,
49+
std::unique_ptr<runtime::DataLoader> data_map_loader) {
50+
auto data_loader_result = FileDataLoader::from(file_path.c_str());
51+
if (!data_loader_result.ok()) {
52+
return data_loader_result.error();
53+
}
54+
55+
auto file_size_result = data_loader_result->size();
56+
if (!file_size_result.ok()) {
57+
return file_size_result.error();
58+
}
59+
60+
size_t file_size = file_size_result.get();
61+
auto file_data = std::make_unique<uint8_t[]>(file_size);
62+
auto buffer_result =
63+
data_loader_result->load_into(0, file_size, {}, file_data.get());
64+
if (buffer_result != runtime::Error::Ok) {
65+
return buffer_result;
66+
}
67+
68+
// Pass ownership of the data to BundledModule
69+
auto bm = std::make_unique<BundledModule>(
70+
file_data.release(),
71+
std::move(memory_allocator),
72+
std::move(temp_allocator),
73+
std::move(event_tracer),
74+
std::move(data_map_loader));
75+
76+
bm->is_loaded_from_file_ = true;
77+
78+
return bm;
79+
}
80+
81+
runtime::Result<std::vector<runtime::EValue>> BundledModule::execute(
82+
const std::string& method_name,
83+
const size_t testset_idx) {
84+
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
85+
auto& method = methods_.at(method_name).method;
86+
87+
ET_CHECK_OK_OR_RETURN_ERROR(
88+
executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input(
89+
*method, bundled_program_ptr_, testset_idx));
90+
ET_CHECK_OK_OR_RETURN_ERROR(method->execute());
91+
92+
const auto outputs_size = method->outputs_size();
93+
std::vector<runtime::EValue> outputs(outputs_size);
94+
ET_CHECK_OK_OR_RETURN_ERROR(
95+
method->get_outputs(outputs.data(), outputs_size));
96+
97+
return outputs;
98+
}
99+
100+
runtime::Error BundledModule::verify_method_outputs(
101+
const std::string& method_name,
102+
const size_t testset_idx,
103+
double rtol,
104+
double atol) {
105+
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
106+
auto& method = methods_.at(method_name).method;
107+
return executorch::BUNDLED_PROGRAM_NAMESPACE::verify_method_outputs(
108+
*method, bundled_program_ptr_, testset_idx, rtol, atol);
109+
}
110+
111+
} // namespace extension
112+
} // namespace executorch

extension/module/bundled_module.h

+123
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/extension/module/module.h>
12+
13+
namespace executorch {
14+
namespace extension {
15+
16+
/**
17+
* A facade class for loading bundled programs and executing methods within
18+
* them.
19+
*/
20+
class BundledModule : public Module {
21+
public:
22+
/**
23+
* Constructs an instance with the bundled program buffer pointer.
24+
*
25+
* This constructor reads the program from bundled program buffer to load the
26+
* module with data loader. The bundled program pointer is preserved so that
27+
* the portion outside of program is accessible.
28+
*
29+
* @param[in] bundled_program_ptr A DataLoader used for loading program data.
30+
* @param[in] memory_allocator A MemoryAllocator used for memory management.
31+
* @param[in] temp_allocator A MemoryAllocator to use when allocating
32+
* temporary data during kernel or delegate execution.
33+
* @param[in] event_tracer A EventTracer used for tracking and logging events.
34+
* @param[in] data_map_loader A DataLoader used for loading external weights.
35+
*/
36+
explicit BundledModule(
37+
const void* bundled_program_ptr,
38+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
39+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr,
40+
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr,
41+
std::unique_ptr<runtime::DataLoader> data_map_loader = nullptr);
42+
43+
// Disallow copying
44+
BundledModule(const BundledModule&) = delete;
45+
BundledModule& operator=(const BundledModule&) = delete;
46+
// Disallow copying
47+
BundledModule(BundledModule&&) = delete;
48+
BundledModule& operator=(BundledModule&&) = delete;
49+
// Default destructor
50+
~BundledModule() {
51+
if (is_loaded_from_file_) {
52+
delete[] static_cast<const uint8_t*>(bundled_program_ptr_);
53+
}
54+
}
55+
56+
/**
57+
* Constructs an instance by loading a bundled program from a file with
58+
* specified memory locking behavior.
59+
*
60+
* @param[in] file_path The path to the ExecuTorch bundled program file to
61+
* load.
62+
* @param[in] memory_allocator A MemoryAllocator used for memory management.
63+
* @param[in] temp_allocator A MemoryAllocator to use when allocating
64+
* temporary data during kernel or delegate execution.
65+
* @param[in] event_tracer A EventTracer used for tracking and logging events.
66+
* @param[in] data_map_loader A DataLoader used for loading external weights.
67+
*/
68+
ET_NODISCARD static runtime::Result<std::unique_ptr<BundledModule>> from_file(
69+
const std::string& file_path,
70+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
71+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr,
72+
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr,
73+
std::unique_ptr<runtime::DataLoader> data_map_loader = nullptr);
74+
75+
using Module::execute;
76+
77+
/**
78+
* Execute a specific method with the input value at the given `testset_idx`
79+
* from the bundle to the method. Loads the program and method before
80+
* executing if needed.
81+
*
82+
* This function is a wrapper of `load_bundled_input` in `bundled_program`.
83+
*
84+
* @param[in] method_name The name of the method to execute.
85+
* @param[in] testset_idx The index of the input value to be passed to the
86+
* method.
87+
*
88+
* @returns Return Error::Ok on a successful load, or the error happens during
89+
* execution.
90+
*/
91+
ET_NODISCARD
92+
runtime::Result<std::vector<runtime::EValue>> execute(
93+
const std::string& method_name,
94+
const size_t testset_idx);
95+
96+
/**
97+
* Verify the output of a specific method with the expected output from the
98+
* program bundle at the given `testset_idx`.
99+
*
100+
* This function is a wrapper of `verify_method_outputs` in `bundled_program`.
101+
*
102+
* @param[in] method_name The name of the method to extract outputs from.
103+
* @param[in] testset_idx The index of expected output needs to be compared.
104+
* @param[in] rtol Relative tolerance used for data comparsion.
105+
* @param[in] atol Absolute tolerance used for data comparsion.
106+
*
107+
* @returns Return Error::Ok if two outputs match, or the error happens during
108+
* execution.
109+
*/
110+
ET_NODISCARD
111+
runtime::Error verify_method_outputs(
112+
const std::string& method_name,
113+
const size_t testset_idx,
114+
double rtol = 1e-5,
115+
double atol = 1e-8);
116+
117+
private:
118+
const void* bundled_program_ptr_;
119+
bool is_loaded_from_file_ = false;
120+
};
121+
122+
} // namespace extension
123+
} // namespace executorch

extension/module/module.cpp

-10
Original file line numberDiff line numberDiff line change
@@ -302,15 +302,5 @@ runtime::Error Module::set_output(
302302
output_tensor.mutable_data_ptr(), output_tensor.nbytes(), output_index);
303303
}
304304

305-
ET_NODISCARD inline runtime::Result<Method*> Module::get_method(
306-
const std::string& method_name) {
307-
ET_CHECK_OR_RETURN_ERROR(
308-
methods_.count(method_name) > 0,
309-
InvalidArgument,
310-
"no such method in program: %s",
311-
method_name.c_str());
312-
return methods_[method_name].method.get();
313-
}
314-
315305
} // namespace extension
316306
} // namespace executorch

extension/module/module.h

-10
Original file line numberDiff line numberDiff line change
@@ -491,16 +491,6 @@ class Module {
491491
std::unique_ptr<NamedDataMap> data_map_;
492492

493493
protected:
494-
/**
495-
* Get a method by method name.
496-
*
497-
* @param[in] method_name The name of the method to get.
498-
*
499-
* @returns A Result object containing either a pointer to the requested
500-
* method or an error to indicate failure.
501-
*/
502-
ET_NODISCARD inline runtime::Result<Method*> get_method(
503-
const std::string& method_name);
504494
std::unordered_map<std::string, MethodHolder> methods_;
505495

506496
friend class ExecuTorchJni;

extension/module/targets.bzl

+22
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,25 @@ def define_common_targets():
3131
"//executorch/runtime/executor:program_no_prim_ops" + aten_suffix,
3232
],
3333
)
34+
35+
runtime.cxx_library(
36+
name = "bundled_module" + aten_suffix,
37+
srcs = [
38+
"bundled_module.cpp",
39+
],
40+
exported_headers = [
41+
"bundled_module.h",
42+
],
43+
visibility = [
44+
"@EXECUTORCH_CLIENTS",
45+
],
46+
deps = [
47+
"//executorch/extension/data_loader:buffer_data_loader",
48+
"//executorch/extension/data_loader:file_data_loader",
49+
"//executorch/devtools/bundled_program:runtime" + aten_suffix,
50+
"//executorch/devtools/bundled_program/schema:bundled_program_schema_fbs",
51+
],
52+
exported_deps = [
53+
"//executorch/extension/module:module" + aten_suffix,
54+
],
55+
)

0 commit comments

Comments
 (0)