diff --git a/extension/module/module.cpp b/extension/module/module.cpp index 6c534b8d560..ec01323edc7 100644 --- a/extension/module/module.cpp +++ b/extension/module/module.cpp @@ -302,5 +302,15 @@ runtime::Error Module::set_output( output_tensor.mutable_data_ptr(), output_tensor.nbytes(), output_index); } +ET_NODISCARD inline runtime::Result Module::get_method( + const std::string& method_name) { + ET_CHECK_OR_RETURN_ERROR( + methods_.count(method_name) > 0, + InvalidArgument, + "no such method in program: %s", + method_name.c_str()); + return methods_[method_name].method.get(); +} + } // namespace extension } // namespace executorch diff --git a/extension/module/module.h b/extension/module/module.h index 73c7328ee0a..8cedb79c06e 100644 --- a/extension/module/module.h +++ b/extension/module/module.h @@ -493,6 +493,16 @@ class Module { std::unique_ptr data_map_; protected: + /** + * Get a method by method name. + * + * @param[in] method_name The name of the method to get. + * + * @returns A Result object containing either a pointer to the requested + * method or an error to indicate failure. + */ + ET_NODISCARD inline runtime::Result get_method( + const std::string& method_name); std::unordered_map methods_; friend class ExecuTorchJni;