From ed41d219c76ba75e37d17c64d2f7441b3419e726 Mon Sep 17 00:00:00 2001 From: lucylq Date: Tue, 4 Feb 2025 16:27:00 -0800 Subject: [PATCH] [executorch][runtime] Add NamedDataMap to method load Add NamedDataMap as an arg to: - Method - load_method - parseTensor Use NamedDataMap to resolve external tensors in parseTensor. Test that the PTE + PTD file run well inside method_test. Differential Revision: [D67127327](https://our.internmc.facebook.com/intern/diff/D67127327/) [ghstack-poisoned] --- extension/flat_tensor/test/targets.bzl | 4 +- runtime/executor/method.cpp | 10 ++- runtime/executor/method.h | 10 ++- runtime/executor/program.cpp | 6 +- runtime/executor/program.h | 3 +- runtime/executor/targets.bzl | 1 + runtime/executor/tensor_parser.h | 6 +- runtime/executor/tensor_parser_aten.cpp | 10 ++- runtime/executor/tensor_parser_exec_aten.cpp | 77 +++++++++++++++++++- runtime/executor/tensor_parser_portable.cpp | 7 +- runtime/executor/test/method_test.cpp | 33 +++++++++ runtime/executor/test/targets.bzl | 3 + test/models/targets.bzl | 2 +- 13 files changed, 152 insertions(+), 20 deletions(-) diff --git a/extension/flat_tensor/test/targets.bzl b/extension/flat_tensor/test/targets.bzl index ffa9b62c9f9..04e2e2b531e 100644 --- a/extension/flat_tensor/test/targets.bzl +++ b/extension/flat_tensor/test/targets.bzl @@ -35,8 +35,8 @@ def define_common_targets(is_fbcode=False): # The tests use this var to find the program file to load. This uses # an fbcode target path because the authoring/export tools # intentionally don't work in xplat (since they're host-only tools). - "ET_MODULE_LINEAR_PROGRAM": "$(location fbcode//executorch/test/models:exported_programs_with_data_separated[ModuleLinear.pte])", - "ET_MODULE_LINEAR_DATA": "$(location fbcode//executorch/test/models:exported_programs_with_data_separated[ModuleLinear.ptd])", + "ET_MODULE_LINEAR_PROGRAM": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.pte])", + "ET_MODULE_LINEAR_DATA": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.ptd])", } runtime.cxx_test( diff --git a/runtime/executor/method.cpp b/runtime/executor/method.cpp index c6fe98abcc8..4cf3177a314 100644 --- a/runtime/executor/method.cpp +++ b/runtime/executor/method.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -414,7 +415,8 @@ Error Method::parse_values() { auto t = deserialization::parseTensor( program_, memory_manager_, - static_cast(val)); + static_cast(val), + named_data_map_); if (!t.ok()) { ET_LOG( Error, @@ -607,7 +609,8 @@ Result Method::load( executorch_flatbuffer::ExecutionPlan* s_plan, const Program* program, MemoryManager* memory_manager, - EventTracer* event_tracer) { + EventTracer* event_tracer, + const NamedDataMap* named_data_map) { MemoryAllocator* temp_allocator = memory_manager->temp_allocator(); if (temp_allocator == nullptr) { PlatformMemoryAllocator* platform_allocator = @@ -619,7 +622,8 @@ Result Method::load( new (platform_allocator) PlatformMemoryAllocator(); temp_allocator = platform_allocator; } - Method method(program, memory_manager, event_tracer, temp_allocator); + Method method( + program, memory_manager, event_tracer, temp_allocator, named_data_map); Error err = method.init(s_plan); if (err != Error::Ok) { diff --git a/runtime/executor/method.h b/runtime/executor/method.h index 8b3330fb5a0..7ce0604724a 100644 --- a/runtime/executor/method.h +++ b/runtime/executor/method.h @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -54,6 +55,7 @@ class Method final { program_(rhs.program_), memory_manager_(rhs.memory_manager_), temp_allocator_(rhs.temp_allocator_), + named_data_map_(rhs.named_data_map_), serialization_plan_(rhs.serialization_plan_), event_tracer_(rhs.event_tracer_), n_value_(rhs.n_value_), @@ -271,11 +273,13 @@ class Method final { const Program* program, MemoryManager* memory_manager, EventTracer* event_tracer, - MemoryAllocator* temp_allocator) + MemoryAllocator* temp_allocator, + const NamedDataMap* named_data_map) : step_state_(), program_(program), memory_manager_(memory_manager), temp_allocator_(temp_allocator), + named_data_map_(named_data_map), serialization_plan_(nullptr), event_tracer_(event_tracer), n_value_(0), @@ -291,7 +295,8 @@ class Method final { executorch_flatbuffer::ExecutionPlan* s_plan, const Program* program, MemoryManager* memory_manager, - EventTracer* event_tracer); + EventTracer* event_tracer, + const NamedDataMap* named_data_map); /** * Initialize the method from its serialized representation. @@ -317,6 +322,7 @@ class Method final { const Program* program_; MemoryManager* memory_manager_; MemoryAllocator* temp_allocator_; + const NamedDataMap* named_data_map_; executorch_flatbuffer::ExecutionPlan* serialization_plan_; EventTracer* event_tracer_; diff --git a/runtime/executor/program.cpp b/runtime/executor/program.cpp index b32832d7718..3c970fe169d 100644 --- a/runtime/executor/program.cpp +++ b/runtime/executor/program.cpp @@ -240,7 +240,8 @@ Result Program::get_method_name(size_t plan_index) const { Result Program::load_method( const char* method_name, MemoryManager* memory_manager, - EventTracer* event_tracer) const { + EventTracer* event_tracer, + const NamedDataMap* named_data_map) const { EXECUTORCH_SCOPE_PROF("Program::load_method"); internal::event_tracer_create_event_block(event_tracer, "Default"); internal::EventTracerProfileMethodScope event_tracer_scope = @@ -257,7 +258,8 @@ Result Program::load_method( if (!plan.ok()) { return plan.error(); } - return Method::load(plan.get(), this, memory_manager, event_tracer); + return Method::load( + plan.get(), this, memory_manager, event_tracer, named_data_map); } Result Program::method_meta(const char* method_name) const { diff --git a/runtime/executor/program.h b/runtime/executor/program.h index f7469eb2192..08d5e392768 100644 --- a/runtime/executor/program.h +++ b/runtime/executor/program.h @@ -132,7 +132,8 @@ class Program final { Result load_method( const char* method_name, MemoryManager* memory_manager, - EventTracer* event_tracer = nullptr) const; + EventTracer* event_tracer = nullptr, + const NamedDataMap* named_data_map = nullptr) const; /** * Gathers metadata for the named method. diff --git a/runtime/executor/targets.bzl b/runtime/executor/targets.bzl index 158da5d1087..67163ed8789 100644 --- a/runtime/executor/targets.bzl +++ b/runtime/executor/targets.bzl @@ -79,6 +79,7 @@ def define_common_targets(): ":memory_manager", "//executorch/runtime/backend:interface", "//executorch/runtime/core:core", + "//executorch/runtime/core:named_data_map", "//executorch/runtime/core:evalue" + aten_suffix, "//executorch/runtime/core:event_tracer" + aten_suffix, "//executorch/runtime/core/exec_aten:lib" + aten_suffix, diff --git a/runtime/executor/tensor_parser.h b/runtime/executor/tensor_parser.h index 6b593afe7c0..5882f4606eb 100644 --- a/runtime/executor/tensor_parser.h +++ b/runtime/executor/tensor_parser.h @@ -21,7 +21,8 @@ namespace deserialization { ET_NODISCARD Result parseTensor( const Program* program, MemoryManager* memory_manager, - const executorch_flatbuffer::Tensor* s_tensor); + const executorch_flatbuffer::Tensor* s_tensor, + const NamedDataMap* named_data_map = nullptr); ET_NODISCARD Result> parseTensorList( const flatbuffers::Vector* tensor_indices, @@ -108,7 +109,8 @@ ET_NODISCARD Result getTensorDataPtr( const executorch_flatbuffer::Tensor* s_tensor, const Program* program, size_t nbytes, - HierarchicalAllocator* allocator); + HierarchicalAllocator* allocator, + const NamedDataMap* named_data_map = nullptr); } // namespace deserialization } // namespace runtime diff --git a/runtime/executor/tensor_parser_aten.cpp b/runtime/executor/tensor_parser_aten.cpp index d92e7e6eb90..ab9af3d0399 100644 --- a/runtime/executor/tensor_parser_aten.cpp +++ b/runtime/executor/tensor_parser_aten.cpp @@ -10,6 +10,7 @@ #include #include +#include #include #include #include @@ -31,7 +32,8 @@ void deleteNothing(void*) {} Result parseTensor( const Program* program, MemoryManager* memory_manager, - const executorch_flatbuffer::Tensor* s_tensor) { + const executorch_flatbuffer::Tensor* s_tensor, + const NamedDataMap* named_data_map) { EXECUTORCH_SCOPE_PROF("TensorParser::parseTensor"); ET_CHECK_OR_RETURN_ERROR( @@ -102,7 +104,11 @@ Result parseTensor( } else { // Now that we know how big the tensor is, find and assign its memory. Result data_ptr = getTensorDataPtr( - s_tensor, program, tensor.nbytes(), memory_manager->planned_memory()); + s_tensor, + program, + tensor.nbytes(), + memory_manager->planned_memory(), + named_data_map); if (!data_ptr.ok()) { ET_LOG( Error, diff --git a/runtime/executor/tensor_parser_exec_aten.cpp b/runtime/executor/tensor_parser_exec_aten.cpp index 4feae452995..c53c12b5cbe 100644 --- a/runtime/executor/tensor_parser_exec_aten.cpp +++ b/runtime/executor/tensor_parser_exec_aten.cpp @@ -19,6 +19,8 @@ namespace executorch { namespace runtime { namespace deserialization { +using executorch::aten::ScalarType; +using executorch::runtime::TensorLayout; // Provides access to private Program methods. class TensorParser final { public: @@ -113,7 +115,8 @@ ET_NODISCARD Result getTensorDataPtr( const executorch_flatbuffer::Tensor* s_tensor, const Program* program, size_t nbytes, - HierarchicalAllocator* allocator) { + HierarchicalAllocator* allocator, + const NamedDataMap* named_data_map) { auto data_buffer_idx = s_tensor->data_buffer_idx(); const executorch_flatbuffer::AllocationDetails* allocation_info = s_tensor->allocation_info(); @@ -132,8 +135,76 @@ ET_NODISCARD Result getTensorDataPtr( } return planned_ptr; - // Constant - } else if (data_buffer_idx > 0 && allocation_info == nullptr) { + } + // Constant, stored externally. + else if ( + allocation_info == nullptr && s_tensor->extra_tensor_info() != nullptr && + s_tensor->extra_tensor_info()->location() == + executorch_flatbuffer::TensorDataLocation::EXTERNAL) { + // Check that fqn is not null. + ET_CHECK_OR_RETURN_ERROR( + s_tensor->extra_tensor_info()->fully_qualified_name() != nullptr, + InvalidExternalData, + "Fully qualified name of external tensor is null"); + // Look up tensor in named data map. + Result tensor_layout_res = named_data_map->get_metadata( + s_tensor->extra_tensor_info()->fully_qualified_name()->c_str()); + if (!tensor_layout_res.ok()) { + return tensor_layout_res.error(); + } + const TensorLayout& tensor_layout = tensor_layout_res.get(); + + // Compatibility checking. + ET_CHECK_OR_RETURN_ERROR( + static_cast(s_tensor->scalar_type()) == + tensor_layout.scalar_type(), + InvalidExternalData, + "Scalar type mismatch. Expected %hhd, got %hhd.", + static_cast(s_tensor->scalar_type()), + static_cast(tensor_layout.scalar_type())); + ET_CHECK_OR_RETURN_ERROR( + nbytes == tensor_layout.nbytes(), + InvalidExternalData, + "Nbytes mismatch. Expected %zu, got %zu.", + nbytes, + tensor_layout.nbytes()); + int dim = s_tensor->sizes()->size(); + ET_CHECK_OR_RETURN_ERROR( + dim == tensor_layout.sizes().size(), + InvalidExternalData, + "Dim mismatch. Expected %d, got %zu.", + dim, + tensor_layout.sizes().size()); + for (int i = 0; i < dim; i++) { + ET_CHECK_OR_RETURN_ERROR( + s_tensor->sizes()->Get(i) == tensor_layout.sizes()[i], + InvalidExternalData, + "Sizes mismatch. Expected %d, got %d for size at index %d.", + s_tensor->sizes()->Get(i), + tensor_layout.sizes()[i], + i); + ET_CHECK_OR_RETURN_ERROR( + s_tensor->dim_order()->Get(i) == tensor_layout.dim_order()[i], + InvalidExternalData, + "Dim order mismatch. Expected %d, got %d for dim at index %d.", + s_tensor->dim_order()->Get(i), + tensor_layout.dim_order()[i], + i); + } + + Result data_res = named_data_map->get_data( + s_tensor->extra_tensor_info()->fully_qualified_name()->c_str()); + if (!data_res.ok()) { + return data_res.error(); + } + // The const_cast is 'ok' here because program and runtime should guarantee + // that this data is never modified. Temporary until we introduce the + // `get_and_persist_data` API from TODO(T214294528). + return const_cast(static_cast(data_res.get().data())); + } + + // Constant, stored in PTE file. + else if (data_buffer_idx > 0 && allocation_info == nullptr) { auto const_data = program->get_constant_buffer_data(data_buffer_idx, nbytes); if (!const_data.ok()) { diff --git a/runtime/executor/tensor_parser_portable.cpp b/runtime/executor/tensor_parser_portable.cpp index 414961e0ff3..a53295470fc 100644 --- a/runtime/executor/tensor_parser_portable.cpp +++ b/runtime/executor/tensor_parser_portable.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -27,7 +28,8 @@ using torch::executor::TensorImpl; Result parseTensor( const Program* program, MemoryManager* memory_manager, - const executorch_flatbuffer::Tensor* s_tensor) { + const executorch_flatbuffer::Tensor* s_tensor, + const NamedDataMap* named_data_map) { EXECUTORCH_SCOPE_PROF("TensorParser::parseTensor"); auto method_allocator = memory_manager->method_allocator(); @@ -146,7 +148,8 @@ Result parseTensor( s_tensor, program, tensor_impl->nbytes(), - memory_manager->planned_memory()); + memory_manager->planned_memory(), + named_data_map); if (!data_ptr.ok()) { ET_LOG( Error, diff --git a/runtime/executor/test/method_test.cpp b/runtime/executor/test/method_test.cpp index 8ef4cfcb369..ef9ca0c76b2 100644 --- a/runtime/executor/test/method_test.cpp +++ b/runtime/executor/test/method_test.cpp @@ -10,6 +10,7 @@ #include #include +#include #include #include #include @@ -21,6 +22,7 @@ using namespace ::testing; using executorch::aten::ArrayRef; +using executorch::extension::DataMap; using executorch::extension::prepare_input_tensors; using executorch::runtime::Error; using executorch::runtime::EValue; @@ -52,6 +54,21 @@ class MethodTest : public ::testing::Test { {module_name, std::make_unique(std::move(program.get()))}); } + void load_data_map(const char* path, const char* module_name) { + // Create a loader for the serialized data map. + Result loader = FileDataLoader::from(path); + ASSERT_EQ(loader.error(), Error::Ok); + loaders_.insert( + {module_name, + std::make_unique(std::move(loader.get()))}); + + Result data_map = DataMap::load(loaders_[module_name].get()); + EXPECT_EQ(data_map.error(), Error::Ok); + + data_maps_.insert( + {module_name, std::make_unique(std::move(data_map.get()))}); + } + void SetUp() override { executorch::runtime::runtime_init(); @@ -63,6 +80,10 @@ class MethodTest : public ::testing::Test { load_program( std::getenv("DEPRECATED_ET_MODULE_LINEAR_CONSTANT_BUFFER_PATH"), "linear_constant_buffer"); + + load_program( + std::getenv("ET_MODULE_LINEAR_PROGRAM_PATH"), "linear_program"); + load_data_map(std::getenv("ET_MODULE_LINEAR_DATA_PATH"), "linear_data"); } private: @@ -71,6 +92,7 @@ class MethodTest : public ::testing::Test { protected: std::unordered_map> programs_; + std::unordered_map> data_maps_; }; TEST_F(MethodTest, MoveTest) { @@ -303,6 +325,17 @@ TEST_F(MethodTest, ConstantBufferTest) { ASSERT_EQ(err, Error::Ok); } +TEST_F(MethodTest, ProgramDataSeparationTest) { + ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes); + Result method = programs_["linear_program"]->load_method( + "forward", &mmm.get(), nullptr, data_maps_["linear_data"].get()); + ASSERT_EQ(method.error(), Error::Ok); + + // Can execute the method. + Error err = method->execute(); + ASSERT_EQ(err, Error::Ok); +} + /* * TODO(T161163608): Test is disabled due to a resize bug in tensor_index_out of * the portable op lib diff --git a/runtime/executor/test/targets.bzl b/runtime/executor/test/targets.bzl index 72923e9868f..c5c50844e87 100644 --- a/runtime/executor/test/targets.bzl +++ b/runtime/executor/test/targets.bzl @@ -109,6 +109,8 @@ def define_common_targets(is_fbcode = False): "ET_MODULE_LINEAR_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleLinear.pte])", "ET_MODULE_MULTI_ENTRY_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleMultipleEntry.pte])", "ET_MODULE_SIMPLE_TRAIN_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleSimpleTrain.pte])", + "ET_MODULE_LINEAR_PROGRAM_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.pte])", + "ET_MODULE_LINEAR_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.ptd])", } runtime.cxx_test( @@ -135,6 +137,7 @@ def define_common_targets(is_fbcode = False): ":managed_memory_manager", "//executorch/runtime/executor:program", "//executorch/extension/data_loader:file_data_loader", + "//executorch/extension/flat_tensor:data_map", "//executorch/extension/runner_util:inputs", "//executorch/kernels/portable:generated_lib", ], diff --git a/test/models/targets.bzl b/test/models/targets.bzl index e2cd0f264b3..51ad10f0f01 100644 --- a/test/models/targets.bzl +++ b/test/models/targets.bzl @@ -92,7 +92,7 @@ def define_common_targets(): ) runtime.genrule( - name = "exported_programs_with_data_separated", + name = "exported_program_and_data", cmd = "$(exe :export_program) --modules ModuleLinear --external-constants --outdir $OUT", outs = { "ModuleLinear.pte": ["ModuleLinear.pte"],