diff --git a/src/libtorch.cc b/src/libtorch.cc
index 331b554..1297115 100644
--- a/src/libtorch.cc
+++ b/src/libtorch.cc
@@ -25,7 +25,9 @@
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 #include <stdint.h>
+
 #include <exception>
+
 #include "libtorch_utils.h"
 #include "triton/backend/backend_common.h"
 #include "triton/backend/backend_input_collector.h"
@@ -59,6 +61,13 @@
 
 namespace triton { namespace backend { namespace pytorch {
 
+// BackendConfiguration. Allows users to set parameters that apply across
+// models.
+struct BackendConfiguration {
+  BackendConfiguration() : gpu_memory_fraction_(1.0) {}
+  float gpu_memory_fraction_;
+};
+
 //
 // ModelState
 //
@@ -104,6 +113,15 @@ class ModelState : public BackendModel {
   bool EnabledWeightSharing() { return enable_weight_sharing_; }
   const std::vector<std::string>& ModelOutputs() { return output_names_; }
 
+  void SetMemoryFraction(float fraction)
+  {
+    c10::cuda::CUDACachingAllocator::init(1);
+    c10::cuda::CUDACachingAllocator::setMemoryFraction(fraction, 0);
+    LOG_MESSAGE(
+        TRITONSERVER_LOG_INFO,
+        (std::string("Memory Fraction: ") + std::to_string(fraction)).c_str());
+  }
+
  private:
   ModelState(TRITONBACKEND_Model* triton_model);
   TRITONSERVER_Error* AutoCompleteConfig();
@@ -136,6 +154,9 @@ class ModelState : public BackendModel {
   // Defaults to (false, false).
   std::pair<bool, bool> enable_nvfuser_pair_;
 
+  // Config settings that apply across models.
+  BackendConfiguration* backend_config_;
+
   // Model mapping for shared TorchScript model across all instances on the
   // same device. The key is a pair of isGPU and device index.
   std::map<
@@ -180,7 +201,8 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
       enable_weight_sharing_(false), enable_tensor_fuser_pair_({false, true}),
       enable_jit_profiling_pair_({false, true}),
       enable_jit_executor_pair_({false, true}),
-      enable_nvfuser_pair_({false, false})
+      enable_nvfuser_pair_({false, false}),
+      backend_config_(nullptr)
 {
   output_names_.clear();
 
@@ -198,6 +220,15 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
         io.MemberAsString("name", &io_name, &io_name_len));
     output_names_.emplace_back(io_name);
   }
+
+  TRITONBACKEND_Backend* backend;
+  THROW_IF_BACKEND_MODEL_ERROR(
+      TRITONBACKEND_ModelBackend(triton_model, &backend));
+  void* vstate;
+  THROW_IF_BACKEND_MODEL_ERROR(TRITONBACKEND_BackendState(backend, &vstate));
+  backend_config_ = reinterpret_cast<BackendConfiguration*>(vstate);
+
+  SetMemoryFraction(backend_config_->gpu_memory_fraction_);
 }
 
 TRITONSERVER_Error*
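Reviewer note on the hunks above (not part of the patch): SetMemoryFraction() drives the same per-process cap that torch.cuda.set_per_process_memory_fraction exposes in Python. The standalone sketch below illustrates the behavior; it assumes a CUDA-enabled libtorch build with at least one visible GPU, and the file name and the 0.5 value are purely illustrative.

    // fraction_demo.cc -- illustrative sketch only, mirrors
    // ModelState::SetMemoryFraction() from the diff above.
    #include <torch/torch.h>

    #include <c10/cuda/CUDACachingAllocator.h>

    #include <iostream>

    int main()
    {
      // The allocator must be initialized before the fraction can be set;
      // the patch passes 1, covering a single visible device.
      c10::cuda::CUDACachingAllocator::init(/*device_count=*/1);

      // Cap the caching allocator on device 0 at 50% of total device memory.
      // Allocations that would exceed the cap fail with a CUDA out-of-memory
      // error instead of letting the cache grow toward full capacity.
      c10::cuda::CUDACachingAllocator::setMemoryFraction(0.5, /*device=*/0);

      try {
        auto t = torch::empty(
            {1024, 1024}, torch::TensorOptions().device(torch::kCUDA, 0));
        std::cout << "allocated " << t.nbytes() << " bytes\n";
      } catch (const c10::Error& e) {
        std::cout << "allocation rejected: " << e.what() << "\n";
      }
      return 0;
    }

As written, the patch hard-codes device index 0 and init(1), so on multi-GPU hosts only device 0 is capped; worth confirming that is intended.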
@@ -1772,8 +1803,8 @@ ModelInstanceState::SetInputTensors(
   // The input must be in contiguous CPU/GPU memory.
   std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> alloc_perference;
   if (device_.is_cpu()) {
-    alloc_perference = {{TRITONSERVER_MEMORY_CPU_PINNED, 0},
-                        {TRITONSERVER_MEMORY_CPU, 0}};
+    alloc_perference = {
+        {TRITONSERVER_MEMORY_CPU_PINNED, 0}, {TRITONSERVER_MEMORY_CPU, 0}};
   } else {
     alloc_perference = {{TRITONSERVER_MEMORY_GPU, device_.index()}};
   }
@@ -2043,6 +2074,40 @@ TRITONBACKEND_Initialize(TRITONBACKEND_Backend* backend)
           .c_str());
   }
 
+  TRITONSERVER_Message* backend_config_message;
+  RETURN_IF_ERROR(
+      TRITONBACKEND_BackendConfig(backend, &backend_config_message));
+
+  const char* buffer;
+  size_t byte_size;
+  RETURN_IF_ERROR(TRITONSERVER_MessageSerializeToJson(
+      backend_config_message, &buffer, &byte_size));
+  LOG_MESSAGE(
+      TRITONSERVER_LOG_INFO,
+      (std::string("backend configuration:\n") + buffer).c_str());
+
+  triton::common::TritonJson::Value backend_config;
+  if (byte_size != 0) {
+    RETURN_IF_ERROR(backend_config.Parse(buffer, byte_size));
+  }
+
+  std::unique_ptr<BackendConfiguration> lconfig(new BackendConfiguration());
+  triton::common::TritonJson::Value cmdline;
+  if (backend_config.Find("cmdline", &cmdline)) {
+    triton::common::TritonJson::Value value;
+    std::string value_str;
+    if (cmdline.Find("gpu-memory-fraction", &value)) {
+      RETURN_IF_ERROR(value.AsString(&value_str));
+      double lvalue;
+      RETURN_IF_ERROR(ParseDoubleValue(value_str, &lvalue));
+      lconfig->gpu_memory_fraction_ = lvalue;
+    }
+  }
+  RETURN_IF_ERROR(TRITONBACKEND_BackendSetState(
+      backend, reinterpret_cast<void*>(lconfig.get())));
+
+  lconfig.release();
+
   return nullptr;  // success
 }
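Usage note (also not part of the patch): with this change the fraction would be supplied at server startup through Triton's generic backend-config flag, e.g. tritonserver --model-repository=/models --backend-config=pytorch,gpu-memory-fraction=0.5 (the 0.5 is illustrative). Settings passed this way surface in TRITONBACKEND_Initialize under the "cmdline" object of the backend-config JSON as strings, which is why the value is read with AsString() and converted through ParseDoubleValue().

Because lconfig.release() transfers ownership of the heap-allocated BackendConfiguration into the opaque backend state, the natural place to reclaim it is TRITONBACKEND_Finalize, which this diff does not touch. A sketch of the assumed companion hunk:

    // Assumed companion change, not in this diff: free the
    // BackendConfiguration stored by TRITONBACKEND_Initialize when the
    // backend itself unloads, so the allocation does not leak.
    extern "C" TRITONSERVER_Error*
    TRITONBACKEND_Finalize(TRITONBACKEND_Backend* backend)
    {
      void* vstate;
      RETURN_IF_ERROR(TRITONBACKEND_BackendState(backend, &vstate));
      delete reinterpret_cast<BackendConfiguration*>(vstate);
      return nullptr;  // success
    }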