diff --git a/extension/runner_util/inputs.cpp b/extension/runner_util/inputs.cpp index c33716be679..11cd176b5d1 100644 --- a/extension/runner_util/inputs.cpp +++ b/extension/runner_util/inputs.cpp @@ -22,15 +22,36 @@ using executorch::runtime::TensorInfo; namespace executorch { namespace extension { -Result<BufferCleanup> prepare_input_tensors(Method& method) { +Result<BufferCleanup> prepare_input_tensors( + Method& method, + PrepareInputTensorsOptions options) { MethodMeta method_meta = method.method_meta(); size_t num_inputs = method_meta.num_inputs(); - size_t num_allocated = 0; + + // A large number of small allocations could exhaust the heap even if the + // total size is smaller than the limit. + ET_CHECK_OR_RETURN_ERROR( + num_inputs <= options.max_inputs, + InvalidProgram, + "Too many inputs: %zu > %zu", + num_inputs, + options.max_inputs); + + // Allocate memory for the inputs array void** inputs = (void**)malloc(num_inputs * sizeof(void*)); + ET_CHECK_OR_RETURN_ERROR( + inputs != nullptr, + MemoryAllocationFailed, + "malloc(%zd) failed", + num_inputs * sizeof(void*)); + // Allocate memory for each input tensor. + size_t total_size = 0; + size_t num_allocated = 0; for (size_t i = 0; i < num_inputs; i++) { auto tag = method_meta.input_tag(i); if (!tag.ok()) { + // The BufferCleanup will free the inputs when it goes out of scope. BufferCleanup cleanup({inputs, num_allocated}); return tag.error(); } @@ -40,10 +61,29 @@ Result<BufferCleanup> prepare_input_tensors(Method& method) { } Result<TensorInfo> tensor_meta = method_meta.input_tensor_meta(i); if (!tensor_meta.ok()) { + BufferCleanup cleanup({inputs, num_allocated}); return tensor_meta.error(); } // This input is a tensor. Allocate a buffer for it. 
- void* data_ptr = malloc(tensor_meta->nbytes()); + size_t tensor_size = tensor_meta->nbytes(); + total_size += tensor_size; + if (total_size > options.max_total_allocation_size) { + ET_LOG( + Error, + "Allocating %zu bytes for input %zu would exceed " + "max_total_allocation_size %zu", + tensor_size, + i, + options.max_total_allocation_size); + BufferCleanup cleanup({inputs, num_allocated}); + return Error::InvalidProgram; + } + void* data_ptr = malloc(tensor_size); + if (data_ptr == nullptr) { + ET_LOG(Error, "malloc(%zu) failed for input %zu", tensor_size, i); + BufferCleanup cleanup({inputs, num_allocated}); + return Error::MemoryAllocationFailed; + } inputs[num_allocated++] = data_ptr; // Create the tensor and set it as the input. @@ -52,11 +92,11 @@ Result<BufferCleanup> prepare_input_tensors(Method& method) { if (err != Error::Ok) { ET_LOG( Error, "Failed to prepare input %zu: 0x%" PRIx32, i, (uint32_t)err); - // The BufferCleanup will free the inputs when it goes out of scope. BufferCleanup cleanup({inputs, num_allocated}); return err; } } + return BufferCleanup({inputs, num_allocated}); } diff --git a/extension/runner_util/inputs.h b/extension/runner_util/inputs.h index b933bca8073..73722c0d7bf 100644 --- a/extension/runner_util/inputs.h +++ b/extension/runner_util/inputs.h @@ -51,18 +51,37 @@ class BufferCleanup final { executorch::runtime::Span<void*> buffers_; }; +/// Defines options for `prepare_input_tensors()`. +struct PrepareInputTensorsOptions { + /** + * The maximum total size in bytes of all input tensors. If the total size of + * all inputs exceeds this, an error is returned. This prevents allocating too + * much memory if the PTE file is malformed. + */ + size_t max_total_allocation_size = 1024 * 1024 * 1024; + + /** + * The maximum number of inputs to allocate. If the number of inputs exceeds + * this, an error is returned. This prevents allocating too much memory if the + * PTE file is malformed. 
+ */ + size_t max_inputs = 1024; +}; + /** * Allocates input tensors for the provided Method, filling them with ones. Does * not modify inputs that are not Tensors. * * @param[in] method The Method that owns the inputs to prepare. + * @param[in] options Extra options for preparing the inputs. * * @returns On success, an object that owns any allocated tensor memory. It must * remain alive when calling `method->execute()`. * @returns An error on failure. */ executorch::runtime::Result<BufferCleanup> prepare_input_tensors( - executorch::runtime::Method& method); + executorch::runtime::Method& method, + PrepareInputTensorsOptions options = {}); namespace internal { /** diff --git a/extension/runner_util/test/inputs_test.cpp b/extension/runner_util/test/inputs_test.cpp index 829c5265d56..7d6799fa9ab 100644 --- a/extension/runner_util/test/inputs_test.cpp +++ b/extension/runner_util/test/inputs_test.cpp @@ -28,6 +28,7 @@ using executorch::runtime::EValue; using executorch::runtime::MemoryAllocator; using executorch::runtime::MemoryManager; using executorch::runtime::Method; +using executorch::runtime::MethodMeta; using executorch::runtime::Program; using executorch::runtime::Result; using executorch::runtime::Span; @@ -100,6 +101,35 @@ TEST_F(InputsTest, Smoke) { // the pointers. } +TEST_F(InputsTest, ExceedingInputCountLimitFails) { + // The smoke test above demonstrated that we can prepare inputs with the + // default limits. It should fail if we lower the max below the number of + // actual inputs. 
+ MethodMeta method_meta = method_->method_meta(); + size_t num_inputs = method_meta.num_inputs(); + ASSERT_GE(num_inputs, 1); + executorch::extension::PrepareInputTensorsOptions options; + options.max_inputs = num_inputs - 1; + + Result<BufferCleanup> input_buffers = + prepare_input_tensors(*method_, options); + ASSERT_NE(input_buffers.error(), Error::Ok); +} + +TEST_F(InputsTest, ExceedingInputAllocationLimitFails) { + // The smoke test above demonstrated that we can prepare inputs with the + // default limits. It should fail if we lower the max below the actual + // allocation size. + executorch::extension::PrepareInputTensorsOptions options; + // The input tensors are float32, so 1 byte will always be smaller than any + // non-empty input tensor. + options.max_total_allocation_size = 1; + + Result<BufferCleanup> input_buffers = + prepare_input_tensors(*method_, options); + ASSERT_NE(input_buffers.error(), Error::Ok); +} + TEST(BufferCleanupTest, Smoke) { // Returns the size of the buffer at index `i`. auto test_buffer_size = [](size_t i) {