diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index a76b3e21558..d9e8c206f46 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -198,7 +198,7 @@ set(TRT_LIB TensorRT::NvInfer) get_filename_component(TRT_LLM_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR} PATH) set(3RDPARTY_DIR ${TRT_LLM_ROOT_DIR}/3rdparty) -if(BINDING_TYPE STREQUAL "pybind") +if(BINDING_TYPE STREQUAL "pybind" OR BUILD_DEEP_EP) add_subdirectory(${3RDPARTY_DIR}/pybind11 ${CMAKE_CURRENT_BINARY_DIR}/pybind11) endif() @@ -217,7 +217,7 @@ include_directories( ${3RDPARTY_DIR}/cutlass/tools/util/include ${3RDPARTY_DIR}/NVTX/include ${3RDPARTY_DIR}/json/include) -if(BINDING_TYPE STREQUAL "pybind") +if(BINDING_TYPE STREQUAL "pybind" OR BUILD_DEEP_EP) include_directories(${3RDPARTY_DIR}/pybind11/include) endif() if(BINDING_TYPE STREQUAL "nanobind") diff --git a/cpp/include/tensorrt_llm/batch_manager/runtimeBuffers.h b/cpp/include/tensorrt_llm/batch_manager/runtimeBuffers.h index 13bde6d07a5..fa43d084b27 100644 --- a/cpp/include/tensorrt_llm/batch_manager/runtimeBuffers.h +++ b/cpp/include/tensorrt_llm/batch_manager/runtimeBuffers.h @@ -168,7 +168,7 @@ class RuntimeBuffers public: //! Additional buffers depending on model type - std::unique_ptr transformerBuffers; + std::shared_ptr transformerBuffers; std::unique_ptr rnnStateBuffers; //! Encoder-Decoder diff --git a/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp b/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp index 691fb9c7efd..e8b71d065f3 100644 --- a/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp +++ b/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp @@ -84,7 +84,7 @@ void RuntimeBuffers::create(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, if (modelConfig.isTransformerBased()) { - transformerBuffers = std::make_unique(maxBatchSize, maxBeamWidth, maxAttentionWindowVec, + transformerBuffers = std::make_shared(maxBatchSize, maxBeamWidth, maxAttentionWindowVec, maxAttentionWindow, sinkTokenLen, runtime, modelConfig, worldConfig); } if (modelConfig.isRnnBased()) diff --git a/cpp/tensorrt_llm/nanobind/CMakeLists.txt b/cpp/tensorrt_llm/nanobind/CMakeLists.txt index d2e7eac20c2..3d570f024d7 100755 --- a/cpp/tensorrt_llm/nanobind/CMakeLists.txt +++ b/cpp/tensorrt_llm/nanobind/CMakeLists.txt @@ -3,7 +3,23 @@ set(TRTLLM_NB_MODULE ${TRTLLM_NB_MODULE} PARENT_SCOPE) -set(SRCS ../runtime/ipcNvlsMemory.cu bindings.cpp) +set(SRCS + batch_manager/algorithms.cpp + batch_manager/bindings.cpp + batch_manager/buffers.cpp + batch_manager/cacheTransceiver.cpp + batch_manager/kvCacheManager.cpp + batch_manager/llmRequest.cpp + executor/bindings.cpp + executor/executor.cpp + executor/executorConfig.cpp + executor/request.cpp + runtime/bindings.cpp + testing/modelSpecBinding.cpp + runtime/moeBindings.cpp + userbuffers/bindings.cpp + ../runtime/ipcNvlsMemory.cu + bindings.cpp) include_directories(${PROJECT_SOURCE_DIR}/include) @@ -14,20 +30,29 @@ set_property(TARGET ${TRTLLM_NB_MODULE} PROPERTY POSITION_INDEPENDENT_CODE ON) target_link_directories(${TRTLLM_NB_MODULE} PUBLIC "${TORCH_INSTALL_PREFIX}/lib") +if(ENABLE_NVSHMEM) + target_link_libraries(${TRTLLM_NB_MODULE} PUBLIC nvshmem::nvshmem_host + nvshmem::nvshmem_device) +endif() + target_link_libraries( ${TRTLLM_NB_MODULE} - PUBLIC ${SHARED_TARGET} ${UNDEFINED_FLAG} ${NO_AS_NEEDED_FLAG} - ${Python3_LIBRARIES} ${TORCH_LIBRARIES} torch_python) - + PUBLIC ${SHARED_TARGET} + ${UNDEFINED_FLAG} + ${NO_AS_NEEDED_FLAG} + ${Python3_LIBRARIES} + ${TORCH_LIBRARIES} + torch_python + ${CUDA_NVML_LIB}) target_compile_definitions( ${TRTLLM_NB_MODULE} 
PUBLIC TRTLLM_NB_MODULE=${TRTLLM_NB_MODULE} - NB_DETAILED_ERROR_MESSAGES=1) + PYBIND11_DETAILED_ERROR_MESSAGES=1) if(NOT WIN32) set_target_properties( ${TRTLLM_NB_MODULE} PROPERTIES LINK_FLAGS - "-Wl,-rpath,'$ORIGIN/libs' -Wl,-rpath,'$ORIGIN/../nvidia/nccl/lib' ${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}" + "-Wl,-rpath,'$ORIGIN/libs' -Wl,-rpath,'$ORIGIN/../nvidia/nccl/lib' -Wl,-rpath,'${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/lib/stubs' ${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}" ) endif() diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/algorithms.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/algorithms.cpp new file mode 100644 index 00000000000..637401555e8 --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/batch_manager/algorithms.cpp @@ -0,0 +1,178 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "algorithms.h" +#include "tensorrt_llm/batch_manager/allocateKvCache.h" +#include "tensorrt_llm/batch_manager/assignReqSeqSlots.h" +#include "tensorrt_llm/batch_manager/capacityScheduler.h" +#include "tensorrt_llm/batch_manager/createNewDecoderRequests.h" +#include "tensorrt_llm/batch_manager/handleContextLogits.h" +#include "tensorrt_llm/batch_manager/handleGenerationLogits.h" +#include "tensorrt_llm/batch_manager/kvCacheManager.h" +#include "tensorrt_llm/batch_manager/llmRequest.h" +#include "tensorrt_llm/batch_manager/logitsPostProcessor.h" +#include "tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.h" +#include "tensorrt_llm/batch_manager/medusaBuffers.h" +#include "tensorrt_llm/batch_manager/microBatchScheduler.h" +#include "tensorrt_llm/batch_manager/pauseRequests.h" +#include "tensorrt_llm/batch_manager/peftCacheManager.h" +#include "tensorrt_llm/batch_manager/runtimeBuffers.h" +#include "tensorrt_llm/batch_manager/updateDecoderBuffers.h" +#include "tensorrt_llm/nanobind/common/customCasters.h" +#include "tensorrt_llm/runtime/decoderState.h" +#include "tensorrt_llm/runtime/torch.h" +#include "tensorrt_llm/runtime/torchView.h" + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace nb = nanobind; + +namespace tr = tensorrt_llm::runtime; +using namespace tensorrt_llm::batch_manager; + +void tensorrt_llm::nanobind::batch_manager::algorithms::initBindings(nb::module_& m) +{ + nb::class_(m, CapacityScheduler::name) + .def(nb::init(), + nb::arg("max_num_requests"), nb::arg("capacity_scheduler_policy"), nb::arg("has_kv_cache_manager"), + nb::arg("two_step_lookahead") = false, nb::arg("no_schedule_until_state") = LlmRequestState::kCONTEXT_INIT, + nb::arg("no_schedule_after_state") = LlmRequestState::kGENERATION_COMPLETE) + .def("__call__", &CapacityScheduler::operator(), nb::arg("active_requests"), + nb::arg("kv_cache_manager") = nullptr, nb::arg("peft_cache_manager") = nullptr, + nb::arg("cross_kv_cache_manager") = nullptr) + .def("name", [](CapacityScheduler const&) { return 
CapacityScheduler::name; }); + + nb::class_(m, MicroBatchScheduler::name) + .def(nb::init, std::optional, LlmRequestState, + LlmRequestState>(), + nb::arg("ctx_chunk_config") = std::nullopt, nb::arg("max_context_length") = std::nullopt, + nb::arg("no_schedule_until_state") = LlmRequestState::kCONTEXT_INIT, + nb::arg("no_schedule_after_state") = LlmRequestState::kGENERATION_COMPLETE) + .def("__call__", &MicroBatchScheduler::operator(), nb::arg("active_requests"), nb::arg("inflight_req_ids"), + nb::arg("max_batch_size_runtime"), nb::arg("max_num_tokens_runtime")) + .def("name", [](MicroBatchScheduler const&) { return MicroBatchScheduler::name; }); + + nb::class_(m, PauseRequests::name) + .def(nb::init(), nb::arg("max_input_len")) + .def("__call__", &PauseRequests::operator(), nb::arg("requests_to_pause"), nb::arg("inflight_req_ids"), + nb::arg("req_ids_to_pause"), nb::arg("pause_flagged"), nb::arg("seq_slot_manager"), + nb::arg("kv_cache_manager") = std::nullopt, nb::arg("cross_kv_cache_manager") = std::nullopt, + nb::arg("peft_cache_manager") = std::nullopt) + .def("name", [](PauseRequests const&) { return PauseRequests::name; }); + + nb::class_(m, AssignReqSeqSlots::name) + .def(nb::init<>()) + .def("__call__", &AssignReqSeqSlots::operator(), nb::arg("seq_slot_manager"), nb::arg("context_requests"), + nb::arg("generation_requests")) + .def("name", [](AssignReqSeqSlots const&) { return AssignReqSeqSlots::name; }); + + nb::class_(m, AllocateKvCache::name) + .def(nb::init<>()) + .def("__call__", &AllocateKvCache::operator(), nb::arg("kv_cache_manager"), nb::arg("context_requests"), + nb::arg("generation_requests"), nb::arg("model_config"), nb::arg("cross_kv_cache_manager") = std::nullopt) + .def("name", [](AllocateKvCache const&) { return AllocateKvCache::name; }); + + nb::class_(m, HandleContextLogits::name) + .def(nb::init<>()) + .def( + "__call__", + [](HandleContextLogits const& self, DecoderInputBuffers& inputBuffers, RequestVector const& contextRequests, + at::Tensor const& logits, std::vector const& numContextLogitsVec, + tr::ModelConfig const& modelConfig, tr::BufferManager const& manager, + OptionalRef medusaBuffers = std::nullopt) + { + return self(inputBuffers, contextRequests, tr::TorchView::of(logits), numContextLogitsVec, modelConfig, + manager, medusaBuffers); + }, + nb::arg("decoder_input_buffers"), nb::arg("context_requests"), nb::arg("logits"), + nb::arg("num_context_logits"), nb::arg("model_config"), nb::arg("buffer_manager"), + nb::arg("medusa_buffers") = std::nullopt) + .def("name", [](HandleContextLogits const&) { return HandleContextLogits::name; }); + + nb::class_(m, HandleGenerationLogits::name) + .def(nb::init<>()) + .def( + "__call__", + [](HandleGenerationLogits const& self, DecoderInputBuffers& inputBuffers, + RequestVector const& generationRequests, at::Tensor const& logits, tr::SizeType32 logitsIndex, + tr::ModelConfig const& modelConfig, tr::BufferManager const& manager, + OptionalRef genRuntimeBuffers = std::nullopt, + OptionalRef medusaBuffers = std::nullopt) + { + self(inputBuffers, generationRequests, tr::TorchView::of(logits), logitsIndex, modelConfig, manager, + genRuntimeBuffers, medusaBuffers); + }, + nb::arg("decoder_input_buffers"), nb::arg("generation_requests"), nb::arg("logits"), + nb::arg("logits_index"), nb::arg("model_config"), nb::arg("buffer_manager"), + nb::arg("gen_runtime_buffers") = std::nullopt, nb::arg("medusa_buffers") = std::nullopt) + .def("name", [](HandleGenerationLogits const&) { return HandleGenerationLogits::name; }); + + 
nb::class_(m, MakeDecodingBatchInputOutput::name) + .def(nb::init<>()) + .def("__call__", &MakeDecodingBatchInputOutput::operator(), nb::arg("context_requests"), + nb::arg("generation_requests"), nb::arg("decoder_input_buffers"), nb::arg("decoder_state"), + nb::arg("model_config"), nb::arg("max_num_sequences"), nb::arg("fused_runtime_buffers") = std::nullopt) + .def("name", [](MakeDecodingBatchInputOutput const&) { return MakeDecodingBatchInputOutput::name; }); + + nb::class_(m, LogitsPostProcessor::name) + .def(nb::init<>()) + .def("__call__", &LogitsPostProcessor::operator(), nb::arg("context_requests"), nb::arg("generation_requests"), + nb::arg("replicate_logits_post_processor"), nb::arg("decoder_buffers"), nb::arg("world_config"), + nb::arg("runtime"), nb::arg("logits_post_processor_batched") = std::nullopt) + .def("name", [](LogitsPostProcessor const&) { return LogitsPostProcessor::name; }); + + nb::class_(m, CreateNewDecoderRequests::name) + .def(nb::init(), nb::arg("speculative_decoding_fast_logits"), + nb::arg("is_leader_in_orch_mode"), nb::arg("is_normalize_log_probs")) + .def( + "__call__", + [](CreateNewDecoderRequests& self, tr::ModelConfig const& modelConfig, tr::WorldConfig const& worldConfig, + executor::DecodingConfig const& decodingConfig, RequestVector const& contextRequests, + tr::BufferManager const& bufferManager, nvinfer1::DataType logitsType, + DecoderInputBuffers& inputBuffers, runtime::decoder::DecoderState& decoderState, + tensorrt_llm::runtime::CudaStream const& runtimeStream, + tensorrt_llm::runtime::CudaStream const& decoderStream, SizeType32 maxSequenceLength, + SizeType32 beamWidth, OptionalRef medusaBuffers = std::nullopt) + { + auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs] = self(modelConfig, + worldConfig, decodingConfig, contextRequests, bufferManager, logitsType, inputBuffers, decoderState, + runtimeStream, decoderStream, maxSequenceLength, beamWidth, medusaBuffers); + + return std::tuple{runtime::Torch::tensor(batchSlots), std::move(samplingConfigs), + std::move(lookaheadPrompt), std::move(lookaheadAlgoConfigs)}; + }, + nb::arg("model_config"), nb::arg("world_config"), nb::arg("decoding_config"), nb::arg("context_requests"), + nb::arg("buffer_manager"), nb::arg("logits_type"), nb::arg("decoder_input_buffers"), + nb::arg("decoder_state"), nb::arg("runtime_stream"), nb::arg("decoder_stream"), + nb::arg("max_sequence_length"), nb::arg("beam_width"), nb::arg("medusa_buffers") = std::nullopt) + .def("name", [](CreateNewDecoderRequests const&) { return CreateNewDecoderRequests::name; }); + + nb::class_(m, UpdateDecoderBuffers::name) + .def(nb::init<>()) + .def("__call__", &UpdateDecoderBuffers::operator(), nb::arg("model_config"), nb::arg("decoder_output_buffers"), + nb::arg("copy_buffer_manager"), nb::arg("decoder_state"), nb::arg("return_log_probs"), + nb::arg("decoder_finish_event")) + .def("name", [](UpdateDecoderBuffers const&) { return UpdateDecoderBuffers::name; }); +} diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/algorithms.h b/cpp/tensorrt_llm/nanobind/batch_manager/algorithms.h new file mode 100644 index 00000000000..cac81d73f27 --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/batch_manager/algorithms.h @@ -0,0 +1,29 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace nb = nanobind; + +namespace tensorrt_llm::nanobind::batch_manager::algorithms +{ + +void initBindings(nb::module_& m); + +} diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp new file mode 100644 index 00000000000..d44a957aad9 --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp @@ -0,0 +1,525 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "bindings.h" +#include "tensorrt_llm/nanobind/common/customCasters.h" + +#include "tensorrt_llm/batch_manager/common.h" +#include "tensorrt_llm/batch_manager/decoderBuffers.h" +#include "tensorrt_llm/batch_manager/medusaBuffers.h" +#include "tensorrt_llm/batch_manager/microBatchScheduler.h" +#include "tensorrt_llm/batch_manager/peftCacheManager.h" +#include "tensorrt_llm/batch_manager/rnnStateManager.h" +#include "tensorrt_llm/batch_manager/runtimeBuffers.h" +#include "tensorrt_llm/batch_manager/sequenceSlotManager.h" +#include "tensorrt_llm/nanobind/common/bindTypes.h" +#include "tensorrt_llm/runtime/gptDecoderBatched.h" +#include "tensorrt_llm/runtime/runtimeKernels.h" +#include "tensorrt_llm/runtime/torch.h" +#include "tensorrt_llm/runtime/torchView.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace nb = nanobind; +namespace tb = tensorrt_llm::batch_manager; +namespace tle = tensorrt_llm::executor; +namespace tr = tensorrt_llm::runtime; + +using namespace tensorrt_llm::runtime; + +namespace tensorrt_llm::nanobind::batch_manager +{ + +void initBindings(nb::module_& m) +{ + using GenLlmReq = tb::GenericLlmRequest; + + // Create and register exceptions in module scope + nb::exception(m, "PeftTaskNotCachedException"); + nb::exception(m, "LoraCacheFullException"); + + // Register with no captures + nb::register_exception_translator( + [](std::exception_ptr const& p, void*) + { + try + { + if (p) + std::rethrow_exception(p); + } + catch (const tb::PeftTaskNotCachedException& e) + { + PyErr_SetString(nb::type().ptr(), e.what()); + } + catch (const tr::LoraCacheFullException& e) + { + PyErr_SetString(nb::type().ptr(), e.what()); + } + }); + + PybindUtils::bindSet(m, "ReqIdsSet"); + + nb::enum_(m, "LlmRequestType") + .value("LLMREQUEST_TYPE_CONTEXT_AND_GENERATION", tb::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION) + .value("LLMREQUEST_TYPE_CONTEXT_ONLY", 
tb::LLMREQUEST_TYPE_CONTEXT_ONLY) + .value("LLMREQUEST_TYPE_GENERATION_ONLY", tb::LLMREQUEST_TYPE_GENERATION_ONLY) + .export_values(); + + nb::class_(m, "ContextChunkingConfig") + .def(nb::init(), nb::arg("chunking_policy"), + nb::arg("chunk_unit_size")) + .def_rw("chunking_policy", &tb::batch_scheduler::ContextChunkingConfig::chunkingPolicy) + .def_rw("chunk_unit_size", &tb::batch_scheduler::ContextChunkingConfig::chunkUnitSize); + + nb::class_(m, "GenericLlmRequest") + .def("set_exclude_input_from_output", &GenLlmReq::setExcludeInputFromOutput, nb::arg("exclude")) + .def("get_num_tokens", &GenLlmReq::getNumTokens, nb::arg("beam")) + .def_prop_ro("max_beam_num_tokens", &GenLlmReq::getMaxBeamNumTokens) + .def("get_token", &GenLlmReq::getToken, nb::arg("beam"), nb::arg("pos")) + .def("get_tokens", nb::overload_cast(&GenLlmReq::getTokens, nb::const_), nb::arg("beam")) + .def("get_tokens", nb::overload_cast<>(&GenLlmReq::getTokens, nb::const_)) + .def("get_last_tokens", nb::overload_cast(&GenLlmReq::getLastTokens), nb::arg("beam")) + .def("get_last_tokens", nb::overload_cast<>(&GenLlmReq::getLastTokens)) + .def("get_beam_width_by_iter", &GenLlmReq::getBeamWidthByIter, nb::arg("for_next_iteration") = false) + .def_prop_ro("max_num_generated_tokens", &GenLlmReq::getMaxNumGeneratedTokens) + .def("add_new_token", &GenLlmReq::addNewToken, nb::arg("token"), nb::arg("beam")) + .def("add_new_tokens", &GenLlmReq::addNewTokens, nb::arg("beam_tokens")) + .def_prop_ro("num_draft_tokens", &GenLlmReq::getNumDraftTokens) + .def("set_generated_tokens", &GenLlmReq::setGeneratedTokens, nb::arg("generated_beam_tokens")) + .def("pause", &GenLlmReq::pause, nb::arg("max_input_len")) + .def_prop_rw("max_sent_token_len", &GenLlmReq::getMaxSentTokenLen, &GenLlmReq::setMaxSentTokenLen) + .def_prop_ro("prompt_embedding_table", &GenLlmReq::getPromptEmbeddingTable) + .def_prop_ro("multimodal_embedding", &GenLlmReq::getMultimodalEmbedding) + .def_prop_ro("mrope_rotary_cos_sin", &GenLlmReq::getMropeRotaryCosSin) + .def_prop_ro("bad_words_list", &GenLlmReq::getBadWordsList) + .def_prop_rw("draft_logits", &GenLlmReq::getDraftLogits, &GenLlmReq::setDraftLogits) + .def_prop_ro("embedding_bias", &GenLlmReq::getEmbeddingBias) + .def_prop_rw("lora_config", &GenLlmReq::getLoraConfig, &GenLlmReq::setLoraConfig) + .def_prop_rw("lora_weights", &GenLlmReq::getLoraWeights, &GenLlmReq::setLoraWeights) + .def_prop_ro("stop_words_list", &GenLlmReq::getStopWordsList) + .def_prop_ro("context_logits", &GenLlmReq::getContextLogitsHost) + .def_prop_ro("generation_logits", &GenLlmReq::getGenerationLogitsHost) + .def_prop_ro("prompt_vocab_size", &GenLlmReq::getPromptVocabSize) + .def_prop_ro("mrope_position_deltas", &GenLlmReq::getMropePositionDeltas) + .def_prop_ro("lora_task_id", &GenLlmReq::getLoraTaskId) + .def_prop_ro("lookahead_config", &GenLlmReq::getLookaheadConfig) + .def_prop_rw("context_chunk_size", &GenLlmReq::getContextChunkSize, &GenLlmReq::setContextChunkSize) + .def_prop_rw("decoding_iter", &GenLlmReq::getDecodingIter, &GenLlmReq::setDecodingIter) + .def_rw("request_id", &GenLlmReq::mRequestId) + .def_rw("prompt_len", &GenLlmReq::mPromptLen) + .def_rw("max_new_tokens", &GenLlmReq::mMaxNewTokens) + .def_rw("sampling_config", &GenLlmReq::mSamplingConfig) + .def_prop_rw("state", &GenLlmReq::getState, &GenLlmReq::setState) + .def_prop_rw("streaming", &GenLlmReq::isStreaming, &GenLlmReq::setStreaming) + .def_rw("end_id", &GenLlmReq::mEndId) + .def_rw("pad_id", &GenLlmReq::mPadId) + .def_rw("seq_slot", &GenLlmReq::mSeqSlot) + 
.def_prop_ro("return_log_probs", &GenLlmReq::returnLogProbs) + .def_prop_ro("return_context_logits", &GenLlmReq::getReturnContextLogits) + .def_prop_ro("return_generation_logits", &GenLlmReq::getReturnGenerationLogits) + .def_prop_ro("log_probs", nb::overload_cast<>(&GenLlmReq::getLogProbs, nb::const_)) + .def("get_log_probs", nb::overload_cast(&GenLlmReq::getLogProbs, nb::const_)) + .def("set_log_probs", &GenLlmReq::setLogProbs, nb::arg("log_probs"), nb::arg("beam")) + .def("set_return_encoder_output", &GenLlmReq::setReturnEncoderOutput, nb::arg("return_encoder_output")) + .def("get_return_encoder_output", &GenLlmReq::getReturnEncoderOutput) + .def("priority", nb::overload_cast<>(&GenLlmReq::priority, nb::const_)) + .def("set_priority", nb::overload_cast(&GenLlmReq::setPriority)) + .def_prop_ro("cum_log_probs", &GenLlmReq::getCumLogProbs) + .def("set_cum_log_prob", &GenLlmReq::setCumLogProb, nb::arg("cum_log_prob"), nb::arg("beam")) + .def("update_num_tokens_per_iteration", &GenLlmReq::updateNumTokensPerIteration, + nb::arg("num_tokens_per_iteration"), nb::arg("model_config")) + .def_prop_ro("orig_prompt_len", &GenLlmReq::getOrigPromptLen) + .def("has_draft_tokens", &GenLlmReq::hasDraftTokens) + .def("move_to_next_context_chunk", &GenLlmReq::moveToNextContextChunk) + .def_prop_ro("is_last_context_chunk", &GenLlmReq::isLastContextChunk) + .def_prop_ro("is_first_context_chunk", &GenLlmReq::isFirstContextChunk) + .def_prop_ro("context_remaining_length", &GenLlmReq::getContextRemainingLength) + .def_prop_ro("context_logits", &GenLlmReq::getContextLogitsHost) + .def_prop_ro("num_draft_tokens", &GenLlmReq::getNumDraftTokens) + .def("set_finished_reason", &GenLlmReq::setFinishedReason, nb::arg("finish_reason"), nb::arg("beam")) + .def_prop_ro("is_finished", &GenLlmReq::isFinished) + .def_prop_ro("is_finished_due_to_length", &GenLlmReq::isFinishedDueToLength) + .def_prop_rw( + "context_current_position", &GenLlmReq::getContextCurrentPosition, &GenLlmReq::setContextCurrentPosition) + .def_prop_ro("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen) + .def_prop_rw("guided_decoding_params", &GenLlmReq::getGuidedDecodingParams, &GenLlmReq::setGuidedDecodingParams) + .def_prop_ro("context_phase_params", &GenLlmReq::getContextPhaseParams) + .def_prop_ro("is_context_only_request", &GenLlmReq::isContextOnlyRequest) + .def_prop_ro("is_generation_only_request", &GenLlmReq::isGenerationOnlyRequest) + .def_prop_ro("is_generation_complete_state", &GenLlmReq::isGenerationCompleteState) + .def_prop_ro("is_context_finished", &GenLlmReq::isContextFinished) + .def_prop_ro("is_disagg_generation_init_state", &GenLlmReq::isDisaggGenerationInitState) + .def_prop_ro("is_disagg_generation_transmission_complete", &GenLlmReq::isDisaggGenerationTransmissionComplete) + .def_prop_ro( + "is_disagg_generation_transmission_in_progress", &GenLlmReq::isDisaggGenerationTransmissionInProgress) + .def_prop_ro("is_context_init_state", &GenLlmReq::isContextInitState) + .def_prop_ro("is_generation_in_progress_state", &GenLlmReq::isGenerationInProgressState) + .def_prop_ro("is_disagg_context_transmission_state", &GenLlmReq::isDisaggContextTransmissionState) + .def_prop_ro("is_disagg_context_complete_state", &GenLlmReq::isDisaggContextCompleteState) + .def_prop_ro("stage", &GenLlmReq::getRequestStage) + .def_prop_ro("kv_cache_transfer_time_ms", &GenLlmReq::getKvCacheTransferTimeMS) + .def_prop_ro("kv_cache_size", &GenLlmReq::getKvCacheSize) + .def_prop_ro("avg_decoded_tokens_per_iter", &GenLlmReq::getAvgDecodedTokensPerIter) 
+ .def_prop_ro("alloc_total_blocks", &GenLlmReq::getAllocTotalBlocksPerRequest) + .def_prop_ro("alloc_new_blocks", &GenLlmReq::getAllocNewBlocksPerRequest) + .def("alloc_context_logits", &GenLlmReq::allocContextLogitsHost, nb::arg("vocab_size"), nb::arg("logit_dtype")) + .def_prop_ro("reused_blocks", &GenLlmReq::getReusedBlocksPerRequest) + .def_prop_ro("missed_blocks", &GenLlmReq::getMissedBlocksPerRequest) + .def_prop_ro("kv_cache_hit_rate", &GenLlmReq::getKVCacheHitRatePerRequest) + .def_prop_ro("llm_request_type", &GenLlmReq::getLlmRequestType) + .def_prop_ro("multimodal_hashes", + [](GenLlmReq& self) + { + std::optional>> hashes = std::nullopt; + if (self.getMultimodalHashes()) + { + hashes = *self.getMultimodalHashes().value(); + } + return hashes; + }) + .def_prop_ro("multimodal_positions", + [](GenLlmReq& self) + { + std::optional> positions = std::nullopt; + if (self.getMultimodalPositions()) + { + positions = *self.getMultimodalPositions().value(); + } + return positions; + }) + .def_prop_ro("multimodal_lengths", + [](GenLlmReq& self) + { + std::optional> lengths = std::nullopt; + if (self.getMultimodalLengths()) + { + lengths = *self.getMultimodalLengths().value(); + } + return lengths; + }) + .def_prop_ro("position_ids", + [](GenLlmReq& self) + { + std::optional> positionIds = std::nullopt; + if (self.getPositionIds()) + { + positionIds = *self.getPositionIds().value(); + } + return positionIds; + }) + .def_prop_rw( + "draft_tokens", + [](GenLlmReq& self) + { + std::optional draftTokens = std::nullopt; + if (self.hasDraftTokens()) + { + draftTokens = *self.getDraftTokens(); + } + return draftTokens; + }, + [](GenLlmReq& self, std::optional const& draftTokens) + { + if (draftTokens) + { + self.setDraftTokens(std::make_shared(draftTokens.value())); + } + }) + .def_prop_rw("is_dummy_request", &GenLlmReq::isDummyRequest, &GenLlmReq::setIsDummyRequest) + .def_prop_ro("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics); + + nb::class_(m, "LlmRequest", nb::dynamic_attr()) + .def( + "__init__", + [](tb::LlmRequest* self, tb::LlmRequest::RequestIdType request_id, + tb::LlmRequest::SizeType32 max_new_tokens, std::vector input_tokens, + runtime::SamplingConfig sampling_config, bool is_streaming, + std::optional end_id, std::optional pad_id, + std::optional embedding_bias, std::optional bad_words_list, + std::optional stop_words_list, + std::optional> position_ids, + std::optional prompt_embedding_table, + std::optional prompt_vocab_size, + std::optional>> multimodal_hashes, + std::optional> multimodal_positions, + std::optional> multimodal_lengths, + std::optional multimodal_embedding, std::optional mrope_rotary_cos_sin, + std::optional mrope_position_deltas, + std::optional lora_task_id, std::optional lora_weights, + std::optional lora_config, + std::optional lookahead_config, + std::optional kv_cache_retention_config, bool return_log_probs, + bool return_context_logits, bool return_generation_logits, + std::optional draft_tokens, std::optional draft_logits, + bool exclude_input_from_output, + std::optional logits_post_processor, + bool apply_logits_post_processor_batched, std::optional encoder_input_tokens, + bool return_encoder_output, std::optional client_id, + executor::PriorityType priority, std::optional encoder_input_features, + std::optional encoder_output_length, + std::optional cross_attention_mask, tb::LlmRequestType llm_request_type, + std::optional input_token_extra_ids, + tb::LlmRequest::SizeType32 num_return_sequences, std::optional eagle_config, + std::optional 
skip_cross_attn_blocks, bool return_perf_metrics, + std::optional guided_decoding_params, + std::optional language_adapter_uid, + std::optional allotted_time_ms, + std::optional context_phase_params) + { + auto makeOptionalTensor = [](std::optional const& atTensor, bool unsqueeze = false) + { + std::optional tensorPtr = std::nullopt; + if (atTensor) + { + tensorPtr = tr::TorchView::of(atTensor.value()); + if (unsqueeze) + { + (*tensorPtr)->unsqueeze(0); + } + } + return tensorPtr; + }; + + auto embedding_bias_tensor_ptr = makeOptionalTensor(embedding_bias, true); + auto bad_words_list_tensor_ptr = makeOptionalTensor(bad_words_list, true); + auto stop_words_list_tensor_ptr = makeOptionalTensor(stop_words_list, true); + auto prompt_embedding_table_tensor_ptr = makeOptionalTensor(prompt_embedding_table); + auto multimodal_embedding_tensor_ptr = makeOptionalTensor(multimodal_embedding); + auto lora_weights_tensor_ptr = makeOptionalTensor(lora_weights); + auto mrope_rotary_cos_sin_tensor_ptr = makeOptionalTensor(mrope_rotary_cos_sin); + auto lora_config_tensor_ptr = makeOptionalTensor(lora_config); + auto draft_logits_tensor_ptr = makeOptionalTensor(draft_logits); + auto encoder_input_features_tensor_ptr = makeOptionalTensor(encoder_input_features); + auto cross_attention_mask_tensor_ptr = makeOptionalTensor(cross_attention_mask); + auto skip_cross_attn_blocks_tensor_ptr = makeOptionalTensor(skip_cross_attn_blocks); + + // 49 parameters + new (self) tb::LlmRequest{request_id, max_new_tokens, input_tokens, sampling_config, is_streaming, + end_id, pad_id, embedding_bias_tensor_ptr, bad_words_list_tensor_ptr, stop_words_list_tensor_ptr, + position_ids, prompt_embedding_table_tensor_ptr, prompt_vocab_size, multimodal_hashes, + multimodal_positions, multimodal_lengths, multimodal_embedding_tensor_ptr, + mrope_rotary_cos_sin_tensor_ptr, mrope_position_deltas, lora_task_id, lora_weights_tensor_ptr, + lora_config_tensor_ptr, lookahead_config, kv_cache_retention_config, return_log_probs, + return_context_logits, return_generation_logits, draft_tokens, draft_logits_tensor_ptr, + exclude_input_from_output, logits_post_processor, apply_logits_post_processor_batched, + encoder_input_tokens, return_encoder_output, client_id, priority, encoder_input_features_tensor_ptr, + encoder_output_length, cross_attention_mask_tensor_ptr, llm_request_type, input_token_extra_ids, + num_return_sequences, eagle_config, skip_cross_attn_blocks_tensor_ptr, return_perf_metrics, + guided_decoding_params, language_adapter_uid, allotted_time_ms, context_phase_params}; + }, + nb::arg("request_id"), nb::arg("max_new_tokens"), nb::arg("input_tokens"), nb::arg("sampling_config"), + nb::arg("is_streaming"), nb::arg("end_id") = std::nullopt, nb::arg("pad_id") = std::nullopt, + nb::arg("embedding_bias") = std::nullopt, nb::arg("bad_words_list") = std::nullopt, + nb::arg("stop_words_list") = std::nullopt, nb::arg("position_ids") = std::nullopt, + nb::arg("prompt_embedding_table") = std::nullopt, nb::arg("prompt_vocab_size") = std::nullopt, + nb::arg("multimodal_hashes") = std::nullopt, nb::arg("multimodal_positions") = std::nullopt, + nb::arg("multimodal_lengths") = std::nullopt, nb::arg("multimodal_embedding") = std::nullopt, + nb::arg("mrope_rotary_cos_sin") = std::nullopt, nb::arg("mrope_position_deltas") = std::nullopt, + nb::arg("lora_task_id") = std::nullopt, nb::arg("lora_weights") = std::nullopt, + nb::arg("lora_config") = std::nullopt, nb::arg("lookahead_config") = std::nullopt, + nb::arg("kv_cache_retention_config") = 
std::nullopt, nb::arg("return_log_probs") = false, + nb::arg("return_context_logits") = false, nb::arg("return_generation_logits") = false, + nb::arg("draft_tokens") = std::nullopt, nb::arg("draft_logits") = std::nullopt, + nb::arg("exclude_input_from_output") = false, nb::arg("logits_post_processor") = std::nullopt, + nb::arg("apply_logits_post_processor_batched") = false, nb::arg("encoder_input_tokens") = std::nullopt, + nb::arg("return_encoder_output") = false, nb::arg("client_id") = std::nullopt, + nb::arg("priority") = executor::Request::kDefaultPriority, nb::arg("encoder_input_features") = std::nullopt, + nb::arg("encoder_output_len") = std::nullopt, nb::arg("cross_attention_mask") = std::nullopt, + nb::arg("llm_request_type") = tb::LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, + nb::arg("input_token_extra_ids") = std::nullopt, nb::arg("num_return_sequences") = 1, + nb::arg("eagle_config") = std::nullopt, nb::arg("skip_cross_attn_blocks") = std::nullopt, + nb::arg("return_perf_metrics") = false, nb::arg("guided_decoding_params") = std::nullopt, + nb::arg("language_adapter_uid") = std::nullopt, nb::arg("allotted_time_ms") = std::nullopt, + nb::arg("context_phase_params") = std::nullopt) + .def("validate", &tb::LlmRequest::validate, nb::arg("max_input_len"), nb::arg("max_seq_len"), + nb::arg("max_draft_len"), nb::arg("vocab_size_padded"), nb::arg("max_endocer_input_len") = std::nullopt, + nb::arg("enable_kv_cache_reuse") = false) + .def("create_response", &tb::LlmRequest::createResponse, nb::arg("use_fast_logits") = false, + nb::arg("mpi_world_rank") = 0) + .def("create_result", &tb::LlmRequest::createResult, nb::arg("use_fast_logits") = false, + nb::arg("mpi_world_rank") = 0) + .def("create_serialized_result", + [](tb::LlmRequest& self, bool use_fast_logits = false, int mpi_world_rank = 0) + { + std::vector serialized_result; + bool is_final = false; + self.createSerializedResult(serialized_result, is_final, use_fast_logits, mpi_world_rank); + return std::make_tuple(nb::bytes(serialized_result.data(), serialized_result.size()), is_final); + }) + .def("move_prompt_embedding_table_to_gpu", &tb::LlmRequest::movePromptEmbeddingTableToGpu, nb::arg("manager")) + .def("move_lora_weights_to_gpu", &tb::LlmRequest::moveLoraWeightsToGpu, nb::arg("manager")) + .def("finish_by_reason", &tb::LlmRequest::finishByReason, nb::arg("finish_reason")) + .def("set_first_scheduled_time", &tb::LlmRequest::setFirstScheduledTime) + .def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, nb::arg("iter_counter")); + + nb::class_(m, "SequenceSlotManager") + .def(nb::init(), nb::arg("max_num_slots"), + nb::arg("max_sequence_idle_microseconds")) + .def("get_sequence_slot", &tb::SequenceSlotManager::getSequenceSlot, nb::arg("start_flag"), + nb::arg("sequence_id")) + .def("free_sequence_slot", &tb::SequenceSlotManager::freeSequenceSlot, nb::arg("sequence_id")) + .def("free_idle_sequence_slots", &tb::SequenceSlotManager::freeIdleSequenceSlots); + + nb::class_(m, "RnnStateManager") + .def(nb::init(), + nb::arg("max_num_sequences"), nb::arg("model_config"), nb::arg("world_config"), nb::arg("buffer_manager")); + + nb::class_(m, "DecoderInputBuffers") + .def(nb::init(), + nb::arg("max_num_sequences"), nb::arg("max_batch_size"), nb::arg("max_tokens_per_engine_step"), + nb::arg("manager")) + .def_rw("setup_batch_slots", &tb::DecoderInputBuffers::setupBatchSlots) + .def_rw("setup_batch_slots_device", &tb::DecoderInputBuffers::setupBatchSlotsDevice) + .def_rw("fill_values", 
&tb::DecoderInputBuffers::fillValues) + .def_rw("fill_values_device", &tb::DecoderInputBuffers::fillValuesDevice) + .def_rw("inputs_ids", &tb::DecoderInputBuffers::inputsIds) + .def_rw("forward_batch_slots", &tb::DecoderInputBuffers::forwardBatchSlots) + .def_rw("logits", &tb::DecoderInputBuffers::logits); + + nb::class_(m, "DecoderOutputBuffers") + .def_rw("sequence_lengths_host", &tb::DecoderOutputBuffers::sequenceLengthsHost) + .def_rw("finished_sum_host", &tb::DecoderOutputBuffers::finishedSumHost) + .def_prop_ro("new_output_tokens_host", + [](tb::DecoderOutputBuffers& self) { return tr::Torch::tensor(self.newOutputTokensHost); }) + .def_rw("cum_log_probs_host", &tb::DecoderOutputBuffers::cumLogProbsHost) + .def_rw("log_probs_host", &tb::DecoderOutputBuffers::logProbsHost) + .def_rw("finish_reasons_host", &tb::DecoderOutputBuffers::finishReasonsHost); + + nb::class_(m, "SlotDecoderBuffers") + .def(nb::init(), + nb::arg("max_beam_width"), nb::arg("max_seq_len"), nb::arg("buffer_manager")) + .def_rw("output_ids", &tb::SlotDecoderBuffers::outputIds) + .def_rw("output_ids_host", &tb::SlotDecoderBuffers::outputIdsHost) + .def_rw("sequence_lengths_host", &tb::SlotDecoderBuffers::sequenceLengthsHost) + .def_rw("cum_log_probs", &tb::SlotDecoderBuffers::cumLogProbs) + .def_rw("cum_log_probs_host", &tb::SlotDecoderBuffers::cumLogProbsHost) + .def_rw("log_probs", &tb::SlotDecoderBuffers::logProbs) + .def_rw("log_probs_host", &tb::SlotDecoderBuffers::logProbsHost) + .def_rw("finish_reasons_host", &tb::SlotDecoderBuffers::finishReasonsHost); + + nb::class_(m, "MedusaBuffers") + .def(nb::init(), + nb::arg("max_beam_width"), nb::arg("max_seq_len"), nb::arg("buffer_manager"), nb::arg("model_config"), + nb::arg("world_config"), nb::arg("decoding_config"), nb::arg("runtime")); + + m.def( + "add_new_tokens_to_requests", + [](std::vector>& requests, + std::vector const& tokens, int beam_idx) + { + TLLM_CHECK_WITH_INFO(requests.size() == tokens.size(), "Expected the same number of requests and tokens."); + + for (int i = 0; i < requests.size(); ++i) + { + requests[i]->addNewToken(tokens[i], beam_idx); + } + }, + nb::arg("requests"), nb::arg("tokens"), nb::arg("beam_idx"), + "Add new tokens to multiple LLM requests. 
The tokens vector should contain tokens for beam beam_idx of all " + "requests in order."); + + m.def( + "make_decoding_batch_input", + [](std::vector>& contextRequests, + std::vector>& genRequests, tr::ITensor::SharedPtr logits, int beamWidth, + std::vector const& numContextLogitsPrefixSum, tb::DecoderInputBuffers const& decoderInputBuffers, + runtime::decoder::DecoderState& decoderState, tr::BufferManager const& manager) + { + std::vector activeSlots; + std::vector generationSteps; + std::vector> logitsVec = {{}}; + + for (int i = 0; i < contextRequests.size(); ++i) + { + if (contextRequests[i]->isLastContextChunk()) + { + activeSlots.push_back(*contextRequests[i]->mSeqSlot); + generationSteps.push_back(contextRequests[i]->getDecodingIter()); + auto contextLogitsOffset = numContextLogitsPrefixSum[i + 1] - 1; + tr::ITensor::SharedPtr logitsView = ITensor::slice(logits, contextLogitsOffset, 1); + + if (beamWidth > 1) + { + // Tile logits of context requests + auto const logitsShape = logitsView->getShape(); + auto const logitsType = logitsView->getDataType(); + auto decoderLogits = manager.gpu(ITensor::makeShape({beamWidth, logitsShape.d[1]}), logitsType); + tensorrt_llm::runtime::kernels::tileTensor( + *decoderLogits, *logitsView, beamWidth, manager.getStream()); + decoderLogits->unsqueeze(0); + logitsVec[0].push_back(std::move(decoderLogits)); + } + else + { + logitsView->unsqueeze(1); + logitsVec[0].push_back(std::move(logitsView)); + } + } + } + + auto genLogitsOffset = numContextLogitsPrefixSum.back(); + for (int i = 0; i < genRequests.size(); ++i) + { + if (genRequests[i]->isGenerationInProgressState()) + { + activeSlots.push_back(*genRequests[i]->mSeqSlot); + generationSteps.push_back(genRequests[i]->getDecodingIter()); + + auto logitsOffset = genLogitsOffset + i * beamWidth; + auto numberOfLogits = beamWidth; + tr::ITensor::SharedPtr logitsView = ITensor::slice(logits, logitsOffset, numberOfLogits); + logitsView->unsqueeze(0); + logitsVec[0].push_back(std::move(logitsView)); + } + } + + auto& batchSlots = decoderInputBuffers.forwardBatchSlots; + batchSlots[0]->resize(activeSlots.size()); + auto batchSlotsRange = tr::BufferRange(*batchSlots[0]); + for (int i = 0; i < activeSlots.size(); ++i) + { + batchSlotsRange[i] = activeSlots[i]; + } + + auto decodingInput = std::make_unique(logitsVec, 1); + decodingInput->batchSlots = batchSlots; + + auto const maxBeamWidth = decoderState.getMaxBeamWidth(); + if (maxBeamWidth > 1) + { + // For Variable-Beam-Width-Search + decoderState.getJointDecodingInput().generationSteps = generationSteps; + } + + return decodingInput; + }, + nb::arg("context_requests"), nb::arg("generation_requests"), nb::arg("logits"), nb::arg("beam_width"), + nb::arg("num_context_logits_prefix_sum"), nb::arg("decoder_input_buffers"), nb::arg("decoder_state"), + nb::arg("buffer_manager"), "Make decoding batch input."); +} + +} // namespace tensorrt_llm::nanobind::batch_manager diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.h b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.h new file mode 100644 index 00000000000..3d5a0f5d5b2 --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.h @@ -0,0 +1,28 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +namespace nb = nanobind; + +namespace tensorrt_llm::nanobind::batch_manager +{ + +void initBindings(nb::module_& m); + +} diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/buffers.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/buffers.cpp new file mode 100644 index 00000000000..b6edcca1c24 --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/batch_manager/buffers.cpp @@ -0,0 +1,108 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "buffers.h" +#include "tensorrt_llm/nanobind/common/customCasters.h" + +#include "tensorrt_llm/batch_manager/kvCacheManager.h" +#include "tensorrt_llm/batch_manager/runtimeBuffers.h" +#include "tensorrt_llm/batch_manager/transformerBuffers.h" + +#include +#include +#include +#include +#include +#include + +namespace nb = nanobind; +namespace tb = tensorrt_llm::batch_manager; +namespace tr = tensorrt_llm::runtime; + +using tr::SizeType32; + +namespace tensorrt_llm::nanobind::batch_manager +{ + +void Buffers::initBindings(nb::module_& m) +{ + nb::class_(m, "TransformerBuffers") + .def(nb::init const&, SizeType32, SizeType32, + runtime::TllmRuntime const&, runtime::ModelConfig const&, runtime::WorldConfig const&>(), + nb::arg("max_batch_size"), nb::arg("max_beam_width"), nb::arg("max_attention_window_vec"), + nb::arg("max_attention_window"), nb::arg("sink_token_len"), nb::arg("runtime"), nb::arg("model_config"), + nb::arg("world_config")) + .def("reshape", &tb::TransformerBuffers::reshape, nb::arg("num_sequences"), nb::arg("num_input_tokens")) + .def("reshape_kv_tensors", &tb::TransformerBuffers::reshapeKvTensors, nb::arg("max_batch_size"), + nb::arg("max_beam_width"), nb::arg("max_blocks_per_seq"), nb::arg("kv_cache_type"), nb::arg("num_pools"), + nb::arg("buffer_manager")) + .def("get_buffers", &tb::TransformerBuffers::getBuffers, nb::arg("input_buffers"), nb::arg("output_buffers"), + nb::arg("model_config")) + .def("copy_position_ids", &tb::TransformerBuffers::copyPositionIds, nb::arg("runtime"), + nb::arg("position_ids_host"), nb::arg("is_chat_glm"), nb::arg("decoder_position_ids")) + .def("copy_kv_block_offsets", &tb::TransformerBuffers::copyKvBlockOffsets, nb::arg("context_requests"), + nb::arg("gen_requests"), nb::arg("kv_cache_manager"), nb::arg("cross_kv_cache_manager"), + nb::arg("buffer_manager")) + .def("copy_cache_indirection", &tb::TransformerBuffers::copyCacheIndirection, nb::arg("gen_requests"), + 
nb::arg("decoder_cache_indirection_output"), nb::arg("runtime")) + .def_rw("past_key_value_lengths", &tb::TransformerBuffers::pastKeyValueLengths) + .def_rw("position_ids", &tb::TransformerBuffers::positionIds) + .def_rw("max_attention_windows", &tb::TransformerBuffers::maxAttentionWindows) + .def_rw("sink_token_lengths", &tb::TransformerBuffers::sinkTokenLengths) + .def_rw("cache_indirection", &tb::TransformerBuffers::cacheIndirection) + .def_rw("kv_cache_block_offsets_host", &tb::TransformerBuffers::kvCacheBlockOffsetsHost) + .def_rw("kv_cache_block_offsets_device", &tb::TransformerBuffers::kvCacheBlockOffsetsDevice) + .def_rw("cross_kv_cache_block_pool_pointers", &tb::TransformerBuffers::crossKvCacheBlockPoolPointers) + .def_rw("cross_kv_cache_block_offsets_host", &tb::TransformerBuffers::crossKvCacheBlockOffsetsHost) + .def_rw("cross_kv_cache_block_offsets_device", &tb::TransformerBuffers::crossKvCacheBlockOffsetsDevice) + .def_rw("cache_indir_batched_copy_src_offsets", &tb::TransformerBuffers::cacheIndirBatchedCopySrcOffsets) + .def_rw("cache_indir_batched_copy_dst_offsets", &tb::TransformerBuffers::cacheIndirBatchedCopyDstOffsets) + .def_rw("cache_indir_batched_copy_sizes", &tb::TransformerBuffers::cacheIndirBatchedCopySizes) + .def_rw("fill_values_alt", &tb::TransformerBuffers::fillValuesAlt) + .def_rw("fill_values_alt_device", &tb::TransformerBuffers::fillValuesAltDevice) + .def_rw("seq_slots_alt", &tb::TransformerBuffers::seqSlotsAlt) + .def_rw("seq_slots_alt_device", &tb::TransformerBuffers::seqSlotsAltDevice); + + nb::class_(m, "RuntimeBuffers") + .def(nb::init const&, SizeType32, SizeType32, + runtime::TllmRuntime const&, runtime::ModelConfig const&, runtime::WorldConfig const&, + executor::DecodingConfig const&, bool, std::optional>(), + nb::arg("max_batch_size"), nb::arg("max_beam_width"), nb::arg("max_attention_window_vec"), + nb::arg("max_attention_window"), nb::arg("sink_token_len"), nb::arg("runtime"), nb::arg("model_config"), + nb::arg("world_config"), nb::arg("decoding_config"), nb::arg("gather_generation_logits"), + nb::arg("max_num_tokens") = std::nullopt) + .def_prop_rw( + "transformer_buffers", [](tb::RuntimeBuffers& self) { return self.transformerBuffers; }, + [](tb::RuntimeBuffers& self, std::shared_ptr val) + { self.transformerBuffers = val; }) + .def_rw("num_context_logits", &tb::RuntimeBuffers::numContextLogits) + .def_rw("cache_indir_decoder_io_batched_copy_src_offsets", + &tb::RuntimeBuffers::cacheIndirDecoderIOBatchedCopySrcOffsets) + .def_rw("cache_indir_decoder_io_batched_copy_dst_offsets", + &tb::RuntimeBuffers::cacheIndirDecoderIOBatchedCopyDstOffsets) + .def_rw("cache_indir_decoder_io_batched_copy_sizes", &tb::RuntimeBuffers::cacheIndirDecoderIOBatchedCopySizes) + .def_rw("logits", &tb::RuntimeBuffers::logits) + .def_rw("seq_slots", &tb::RuntimeBuffers::seqSlots) + .def_rw("seq_slots_device", &tb::RuntimeBuffers::seqSlotsDevice) + .def_rw("cache_indir_decoder_io_batched_copy_src_offsets_slice_device", + &tb::RuntimeBuffers::mCacheIndirDecoderIOBatchedCopySrcOffsetsSliceDevice) + .def_rw("cache_indir_decoder_io_batched_copy_dst_offsets_slice_device", + &tb::RuntimeBuffers::mCacheIndirDecoderIOBatchedCopyDstOffsetsSliceDevice) + .def_rw("cache_indir_decoder_io_batched_copy_copy_sizes_device", + &tb::RuntimeBuffers::mCacheIndirDecoderIOBatchedCopyCopySizesDevice); +} +} // namespace tensorrt_llm::nanobind::batch_manager diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/buffers.h b/cpp/tensorrt_llm/nanobind/batch_manager/buffers.h new file mode 100644 
index 00000000000..34df07e4073 --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/batch_manager/buffers.h @@ -0,0 +1,29 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +namespace nb = nanobind; + +namespace tensorrt_llm::nanobind::batch_manager +{ +class Buffers +{ +public: + static void initBindings(nb::module_& m); +}; +} // namespace tensorrt_llm::nanobind::batch_manager diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp new file mode 100644 index 00000000000..abac6d17ed8 --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp @@ -0,0 +1,110 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "cacheTransceiver.h" +#include "tensorrt_llm/batch_manager/cacheTransceiver.h" +#include "tensorrt_llm/batch_manager/kvCacheManager.h" +#include "tensorrt_llm/executor/executor.h" +#include "tensorrt_llm/nanobind/common/customCasters.h" +#include +#include +#include +#include +#include +#include +#include + +using SizeType32 = tensorrt_llm::runtime::SizeType32; + +namespace tb = tensorrt_llm::batch_manager; +namespace nb = nanobind; + +namespace +{ + +class PyCacheTransceiver : public tb::BaseCacheTransceiver +{ +public: + // using BaseCacheTransceiver::BaseCacheTransceiver; // Inherit constructors + NB_TRAMPOLINE(tb::BaseCacheTransceiver, 6); + + void respondAndSendAsync(tb::LlmRequest* llmRequest) override + { + NB_OVERRIDE_PURE(respondAndSendAsync, llmRequest); + } + + void requestAndReceiveSync(tb::LlmRequest* llmRequest) override + { + NB_OVERRIDE_PURE(requestAndReceiveSync, llmRequest); + } + + void requestAndReceiveAsync(tb::LlmRequest* llmRequest) override + { + NB_OVERRIDE_PURE(requestAndReceiveAsync, llmRequest); + } + + void checkContextTransferStatus(std::optional const& atLeastRequestNum = std::nullopt) override + { + NB_OVERRIDE_PURE(checkContextTransferStatus, atLeastRequestNum); + } + + void checkGenTransferStatus(std::optional const& atLeastRequestNum = std::nullopt) override + { + NB_OVERRIDE_PURE(checkGenTransferStatus, atLeastRequestNum); + } + + bool checkGenTransferComplete() const override + { + NB_OVERRIDE_PURE(checkGenTransferComplete); + } +}; +} // namespace + +void tb::CacheTransceiverBindings::initBindings(nb::module_& m) +{ + nb::class_(m, "BaseCacheTransceiver") + .def("respond_and_send_async", &BaseCacheTransceiver::respondAndSendAsync) + .def("request_and_receive_sync", &BaseCacheTransceiver::requestAndReceiveSync) + .def("request_and_receive_async", &BaseCacheTransceiver::requestAndReceiveAsync) + .def("check_context_transfer_status", &BaseCacheTransceiver::checkContextTransferStatus) + .def("check_gen_transfer_status", &BaseCacheTransceiver::checkGenTransferStatus) + .def("check_gen_transfer_complete", &BaseCacheTransceiver::checkGenTransferComplete); + + nb::enum_(m, "CommType") + .value("UNKNOWN", tb::CacheTransceiver::CommType::UNKNOWN) + .value("MPI", tb::CacheTransceiver::CommType::MPI) + .value("UCX", tb::CacheTransceiver::CommType::UCX) + .value("NIXL", tb::CacheTransceiver::CommType::NIXL); + + nb::enum_(m, "AttentionType") + .value("DEFAULT", executor::kv_cache::CacheState::AttentionType::kDEFAULT) + .value("MLA", executor::kv_cache::CacheState::AttentionType::kMLA); + + nb::class_(m, "CacheTransceiver") + .def(nb::init, SizeType32, SizeType32, runtime::WorldConfig, nvinfer1::DataType, + executor::kv_cache::CacheState::AttentionType, std::optional>(), + nb::arg("cache_manager"), nb::arg("comm_type"), nb::arg("num_kv_heads_per_layer"), nb::arg("size_per_head"), + nb::arg("tokens_per_block"), nb::arg("world_config"), nb::arg("dtype"), nb::arg("attention_type"), + nb::arg("cache_transceiver_config") = std::nullopt); + + nb::class_(m, "CacheTransBufferManager") + .def(nb::init>(), nb::arg("cache_manager"), + nb::arg("max_num_tokens") = std::nullopt) + .def_static("pre_alloc_buffer_size", &tb::kv_cache_manager::CacheTransBufferManager::preAllocBufferSize, + nb::arg("max_num_tokens") = std::nullopt); +} diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.h b/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.h new file mode 100644 index 00000000000..90fc63d4fde --- /dev/null +++ 
b/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.h @@ -0,0 +1,29 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +namespace nb = nanobind; + +namespace tensorrt_llm::batch_manager +{ +class CacheTransceiverBindings +{ +public: + static void initBindings(nb::module_& m); +}; +} // namespace tensorrt_llm::batch_manager diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp new file mode 100644 index 00000000000..f1c398d31f0 --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp @@ -0,0 +1,478 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "kvCacheManager.h" +#include "tensorrt_llm/batch_manager/kvCacheManager.h" +#include "tensorrt_llm/batch_manager/peftCacheManager.h" +#include "tensorrt_llm/nanobind/common/bindTypes.h" +#include "tensorrt_llm/nanobind/common/customCasters.h" +#include "tensorrt_llm/runtime/torch.h" +#include "tensorrt_llm/runtime/torchView.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tb = tensorrt_llm::batch_manager; +namespace tbk = tensorrt_llm::batch_manager::kv_cache_manager; +namespace tr = tensorrt_llm::runtime; +namespace nb = nanobind; +using BlockKey = tbk::BlockKey; +using VecUniqueTokens = tensorrt_llm::runtime::VecUniqueTokens; +using SizeType32 = tensorrt_llm::runtime::SizeType32; +using TokenIdType = tensorrt_llm::runtime::TokenIdType; +using VecTokens = std::vector; +using CudaStreamPtr = std::shared_ptr; + +namespace +{ +std::optional from_torch(std::optional torchPtr) +{ + if (torchPtr) + { + return tr::TorchView::of(torchPtr.value()); + } + return std::nullopt; +} + +class PyKvCacheManager : public tbk::BaseKVCacheManager +{ +public: + NB_TRAMPOLINE(tbk::BaseKVCacheManager, 28); + + // using BaseKVCacheManager::BaseKVCacheManager; // Inherit constructors + void allocatePools(bool useUvm = false) override + { + NB_OVERRIDE_PURE(allocatePools, useUvm); + } + + void releasePools() override + { + NB_OVERRIDE_PURE(releasePools); + } + + void startScheduling() override + { + NB_OVERRIDE_PURE(startScheduling); + } + + SizeType32 getTokensPerBlock() const override + { + NB_OVERRIDE_PURE(getTokensPerBlock); + } + + SizeType32 getMaxNumBlocks() const override + { + NB_OVERRIDE_PURE(getMaxNumBlocks); + } + + SizeType32 getNumPools() const override + { + NB_OVERRIDE_PURE(getNumPools); + } + + tbk::KvCacheStats getKvCacheStats() const override + { + NB_OVERRIDE_PURE(getKvCacheStats); + } + + void addToken(tb::LlmRequest::RequestIdType requestId) override + { + NB_OVERRIDE_PURE(addToken, requestId); + } + + void addSequence(tb::LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, + tensorrt_llm::common::OptionalRef llmRequest = std::nullopt) override + { + NB_OVERRIDE_PURE(addSequence, requestId, inputLength, beamWidth, llmRequest); + } + + void removeSequence(tb::LlmRequest::RequestIdType requestId, + tensorrt_llm::common::OptionalRef llmRequest = std::nullopt) override + { + NB_OVERRIDE_PURE(removeSequence, requestId, llmRequest); + } + + tbk::GenerationRequest const& getSequence(tb::LlmRequest::RequestIdType requestId) const override + { + NB_OVERRIDE_PURE(getSequence, requestId); + } + + void schedulingRemoveSequence(tb::LlmRequest::RequestIdType requestId) override + { + NB_OVERRIDE_PURE(schedulingRemoveSequence, requestId); + } + + tensorrt_llm::runtime::ITensor::SharedPtr getBlockPoolPointers() const override + { + NB_OVERRIDE_PURE(getBlockPoolPointers); + } + + tensorrt_llm::runtime::ITensor::SharedPtr getLayerToPoolMapping() const override + { + NB_OVERRIDE_PURE(getLayerToPoolMapping); + } + + void getBlockOffsetsOfBatch(tensorrt_llm::runtime::ITensor& output, SizeType32 firstBatchSlotIdx, + SizeType32 batchSize, SizeType32 beamWidth) const override + { + NB_OVERRIDE_PURE(getBlockOffsetsOfBatch, output, firstBatchSlotIdx, batchSize, beamWidth); + } + + SizeType32 copyBlockOffsets(tensorrt_llm::runtime::ITensor& output, SizeType32 outputSlotOffset, + tb::LlmRequest::RequestIdType requestId) const override + { + 
NB_OVERRIDE_PURE(copyBlockOffsets, output, outputSlotOffset, requestId); + } + + bool isEnableBlockReuse() const override + { + NB_OVERRIDE_PURE(isEnableBlockReuse); + } + + void rewindKVCache(tb::LlmRequest::RequestIdType requestId, SizeType32 rewindLengths) override + { + NB_OVERRIDE_PURE(rewindKVCache, requestId, rewindLengths); + } + + bool isCrossKv() const override + { + NB_OVERRIDE_PURE(isCrossKv); + } + + std::optional findNewContextBlock( + VecUniqueTokens const& uniqueTokens, tb::LlmRequest const& llmRequest) const override + { + NB_OVERRIDE_PURE(findNewContextBlock, uniqueTokens, llmRequest); + } + + void storeContextBlocks(tb::LlmRequest const& llmRequest) override + { + NB_OVERRIDE_PURE(storeContextBlocks, llmRequest); + } + + std::vector> const& getCacheBlockIds( + tb::LlmRequest::RequestIdType requestId, SizeType32 windowSize) const override + { + NB_OVERRIDE_PURE(getCacheBlockIds, requestId, windowSize); + } + + std::vector>> getBatchCacheBlockIds( + std::vector const& requestIds, SizeType32 windowSize) const override + { + NB_OVERRIDE_PURE(getBatchCacheBlockIds, requestIds, windowSize); + } + + std::vector getNewlyAllocatedBlockIds( + tb::LlmRequest::RequestIdType requestId, SizeType32 windowSize) const override + { + NB_OVERRIDE_PURE(getNewlyAllocatedBlockIds, requestId, windowSize); + } + + SizeType32 getUsedNumBlocks() const override + { + NB_OVERRIDE_PURE(getUsedNumBlocks); + } + + SizeType32 getNumFreeBlocks() const override + { + NB_OVERRIDE_PURE(getNumFreeBlocks); + } + + tbk::BlockManager const& getBlockManager() const override + { + NB_OVERRIDE_PURE(getBlockManager); + } + + std::deque getLatestEvents( + std::optional timeout = std::nullopt) const override + { + NB_OVERRIDE_PURE(getLatestEvents, timeout); + } + + tensorrt_llm::runtime::ITensor::SharedPtr getPrimaryPool(SizeType32 layer_idx) const override + { + NB_OVERRIDE_PURE(getPrimaryPool, layer_idx); + } + + SizeType32 getPoolLayerIdx(SizeType32 layer_idx) const override + { + NB_OVERRIDE_PURE(getPoolLayerIdx, layer_idx); + } + + void refreshBlocks() override + { + NB_OVERRIDE_PURE(refreshBlocks); + } + + void flushIterationEvents() override + { + NB_OVERRIDE_PURE(flushIterationEvents); + } +}; + +// TODO: Deduplicate executor bindings KvCacheStats +class PyBasePeftCacheManager : public tb::BasePeftCacheManager +{ +public: + ~PyBasePeftCacheManager() override = default; + + NB_TRAMPOLINE(tb::BasePeftCacheManager, 8); + + void addRequestPeft(tb::BasePeftCacheManager::LlmRequestPtr llmRequest, bool tryGpuCache = true) override + { + NB_OVERRIDE_PURE(addRequestPeft, llmRequest, tryGpuCache); + } + + tb::BasePeftCacheManager::PeftTable ensureBatch(tb::RequestVector const& contextRequests, + tb::RequestVector const& generationRequests, bool resetGpuCache = false) override + { + NB_OVERRIDE_PURE(ensureBatch, contextRequests, generationRequests, resetGpuCache); + } + + void resetDeviceCache() override + { + NB_OVERRIDE_PURE(resetDeviceCache); + } + + void markRequestDone(tb::LlmRequest const& llmReq, bool pause = false) override + { + NB_OVERRIDE_PURE(markRequestDone, llmReq, pause); + } + + tr::SizeType32 getMaxDevicePages() const override + { + NB_OVERRIDE_PURE(getMaxDevicePages); + } + + tr::SizeType32 getMaxHostPages() const override + { + NB_OVERRIDE_PURE(getMaxHostPages); + } + + tr::SizeType32 determineNumPages(std::shared_ptr llmRequest) const override + { + NB_OVERRIDE_PURE(determineNumPages, llmRequest); + } + + bool enabled() const override + { + NB_OVERRIDE_PURE(enabled); + } +}; +} // namespace + +void 
tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m) +{ + nb::class_(m, "KvCacheStats") + .def(nb::init<>()) + .def_rw("max_num_blocks", &tbk::KvCacheStats::maxNumBlocks) + .def_rw("free_num_blocks", &tbk::KvCacheStats::freeNumBlocks) + .def_rw("used_num_blocks", &tbk::KvCacheStats::usedNumBlocks) + .def_rw("tokens_per_block", &tbk::KvCacheStats::toksPerBlock) + .def_rw("alloc_total_blocks", &tbk::KvCacheStats::allocTotalBlocks) + .def_rw("alloc_new_blocks", &tbk::KvCacheStats::allocNewBlocks) + .def_rw("reused_blocks", &tbk::KvCacheStats::reusedBlocks) + .def_rw("missed_blocks", &tbk::KvCacheStats::missedBlocks) + .def_rw("cache_hit_rate", &tbk::KvCacheStats::cacheHitRate) + .def_rw("num_free_blocks_per_window_size", &tbk::KvCacheStats::numFreeBlocksPerWindowSize); + + nb::class_(m, "TempAttentionWindowInputs") + .def(nb::init<>()) + .def_rw("paged_context_fmha", &tbk::TempAttentionWindowInputs::pagedContextFMHA) + .def_rw("max_input_len", &tbk::TempAttentionWindowInputs::maxInputLen) + .def_rw("max_num_tokens", &tbk::TempAttentionWindowInputs::maxNumTokens); + + nb::class_(m, "BlockKey") + .def(nb::init<>()) + .def(nb::init>(), nb::arg("tokens"), + nb::arg("lora_task_id") = std::nullopt) + .def(nb::init, VecUniqueTokens const&>(), nb::arg("uses_extra_ids"), + nb::arg("lora_task_id"), nb::arg("unique_tokens")) + .def_ro("uses_extra_ids", &tbk::BlockKey::usesExtraIds) + .def_ro("lora_task_id", &tbk::BlockKey::loraTaskId) + .def_ro("unique_tokens", &tbk::BlockKey::uniqueTokens); + + nb::class_(m, "BlockKeyHasher") + .def_static("hash", &tbk::BlockKeyHasher::hash, nb::arg("block_key"), nb::arg("parent_hash") = 0); + + nb::class_(m, "KVCacheEventManager") + .def(nb::init(), nb::arg("max_kv_event_entries")); + + nb::class_(m, "BaseKVCacheManager") + .def_static("calculate_max_num_blocks", &tbk::BaseKVCacheManager::calculateMaxNumBlocks, nb::arg("config"), + nb::arg("is_cross_attention"), nb::arg("dtype"), nb::arg("model_config"), nb::arg("world_config"), + nb::arg("window_size_to_layers"), nb::arg("allotted_primary_mem_bytes"), + nb::arg("allotted_secondary_mem_bytes"), nb::arg("extra_cost_memory"), nb::arg("kv_factor")) + .def("allocate_pools", &BaseKVCacheManager::allocatePools) + .def("release_pools", &BaseKVCacheManager::releasePools) + .def("start_scheduling", &BaseKVCacheManager::startScheduling) + .def_prop_ro("tokens_per_block", &BaseKVCacheManager::getTokensPerBlock) + .def_prop_ro("max_num_blocks", &BaseKVCacheManager::getMaxNumBlocks) + .def_prop_ro("num_pools", &BaseKVCacheManager::getNumPools) + .def("get_kv_cache_stats", &BaseKVCacheManager::getKvCacheStats) + .def_prop_ro("max_blocks_per_seq", + [](tbk::BaseKVCacheManager& self) { return self.getOffsetTableDimensions().maxBlocksPerSeq; }) + .def("get_needed_blocks_one_step", &BaseKVCacheManager::getNeededBlocksOneStep) + .def("get_remaining_blocks_to_completion", &BaseKVCacheManager::getRemainingBlocksToCompletion) + .def("add_token", &BaseKVCacheManager::addToken) + .def("add_sequence", &BaseKVCacheManager::addSequence) + .def("remove_sequence", &BaseKVCacheManager::removeSequence) + .def("scheduling_remove_sequence", &BaseKVCacheManager::schedulingRemoveSequence) + .def("get_block_pool_pointers", + [](tbk::BaseKVCacheManager& self) + { + std::optional block_pool_pointers{std::nullopt}; + auto tensor = self.getBlockPoolPointers(); + if (tensor) + { + std::shared_ptr _tensor = std::move(tensor); + block_pool_pointers = tr::Torch::tensor(_tensor); + } + return block_pool_pointers; + }) + 
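
For orientation, a minimal Python sketch of how the KvCacheStats and BlockKey bindings above are exercised. The import path tensorrt_llm.bindings.internal.batch_manager is an assumption based on the submodule layout in bindings.cpp; everything else follows the argument names bound above.

    # Illustrative only; import path assumed.
    from tensorrt_llm.bindings.internal.batch_manager import (
        BlockKey, BlockKeyHasher, KvCacheStats)

    stats = KvCacheStats()             # plain struct, all fields are read/write
    stats.max_num_blocks = 1024
    print(stats.max_num_blocks)

    key = BlockKey([1, 2, 3])          # tokens; lora_task_id defaults to None
    print(hex(BlockKeyHasher.hash(key)))   # parent_hash defaults to 0
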
.def("get_layer_to_pool_mapping", + [](tbk::BaseKVCacheManager& self) + { + std::optional layer_to_pool_mapping{std::nullopt}; + auto tensor = self.getLayerToPoolMapping(); + if (tensor) + { + std::shared_ptr _tensor = std::move(tensor); + layer_to_pool_mapping = tr::Torch::tensor(_tensor); + } + return layer_to_pool_mapping; + }) + .def("get_primary_pool_data", + [](tbk::BaseKVCacheManager& self, SizeType32 layer_idx) -> at::Tensor + { + auto pool = tr::Torch::tensor(self.getPrimaryPool(layer_idx)); + auto pool_layer_idx = self.getPoolLayerIdx(layer_idx); + return pool.index({torch::indexing::Slice(), pool_layer_idx}); + }) + .def("get_block_offsets_of_batch", + [](tbk::BaseKVCacheManager& self, at::Tensor output, SizeType32 firstBatchSlotIdx, SizeType32 batchSize, + SizeType32 beamWidth) + { + auto _output = from_torch(output); + TLLM_CHECK_WITH_INFO(_output.has_value(), "Invalid output tensor."); + self.getBlockOffsetsOfBatch(*(_output.value()), firstBatchSlotIdx, batchSize, beamWidth); + }) + .def("copy_block_offsets", + [](tbk::BaseKVCacheManager& self, at::Tensor output, SizeType32 outputSlotOffset, + tb::LlmRequest::RequestIdType requestId) + { + auto _output = from_torch(output); + TLLM_CHECK_WITH_INFO(_output.has_value(), "Invalid output tensor."); + auto maxBlockCount = self.copyBlockOffsets(*(_output.value()), outputSlotOffset, requestId); + return maxBlockCount; + }) + .def("copy_batch_block_offsets", + [](tbk::BaseKVCacheManager& self, at::Tensor output, + std::vector const& requestIds, SizeType32 const beamWidth, + SizeType32 const offset) + { + auto _output = from_torch(output); + TLLM_CHECK_WITH_INFO(_output.has_value(), "Invalid output tensor."); + for (size_t i = 0; i < requestIds.size(); ++i) + { + self.copyBlockOffsets(*(_output.value()), i * beamWidth + offset, requestIds[i]); + } + }) + .def( + "get_latest_events", + [](tbk::BaseKVCacheManager& self, std::optional timeout_ms = std::nullopt) + { + if (timeout_ms) + { + return self.getLatestEvents(std::chrono::milliseconds(static_cast(*timeout_ms))); + } + return self.getLatestEvents(std::nullopt); + }, + nb::arg("timeout_ms") = std::nullopt) + .def_prop_ro("enable_block_reuse", &BaseKVCacheManager::isEnableBlockReuse) + .def("rewind_kv_cache", &BaseKVCacheManager::rewindKVCache) + .def_prop_ro("cross_kv", &BaseKVCacheManager::isCrossKv) + .def("store_context_blocks", &BaseKVCacheManager::storeContextBlocks) + .def("get_cache_block_ids", &BaseKVCacheManager::getCacheBlockIds) + .def("get_batch_cache_block_ids", &BaseKVCacheManager::getBatchCacheBlockIds) + .def("get_newly_allocated_block_ids", &BaseKVCacheManager::getNewlyAllocatedBlockIds) + .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents); + + nb::bind_vector>>(m, "CacheBlockIds"); + + nb::enum_(m, "CacheType") + .value("SELF", tbk::CacheType::kSELF) + .value("CROSS", tbk::CacheType::kCROSS) + .value("SELFKONLY", tbk::CacheType::kSELFKONLY); + + nb::class_(m, "KVCacheManager") + .def(nb::init const&, SizeType32, SizeType32, + std::map> const&, SizeType32, SizeType32, + std::vector const&, std::optional const&, + nvinfer1::DataType, SizeType32, int64_t, std::optional, bool, bool, + tbk::CacheType, std::optional, + std::shared_ptr, bool, bool>(), + nb::arg("num_kv_heads_per_layer"), nb::arg("size_per_head"), nb::arg("tokens_per_block"), + nb::arg("blocks_per_window"), nb::arg("max_num_sequences"), nb::arg("max_beam_width"), + nb::arg("max_attention_window_vec"), nb::arg("temp_attention_window_inputs").none(), nb::arg("dtype"), + 
nb::arg("sink_token_length"), nb::arg("stream"), nb::arg("max_sequence_length").none(), + nb::arg("enable_block_reuse") = false, nb::arg("onboard_blocks") = true, + nb::arg("cache_type") = tbk::CacheType::kSELF, nb::arg("secondary_offload_min_priority") = std::nullopt, + nb::arg("event_manager") = nullptr, nb::arg("enable_partial_reuse") = true, + nb::arg("copy_on_partial_reuse") = true); +} + +void tb::BasePeftCacheManagerBindings::initBindings(nb::module_& m) +{ + nb::class_(m, "BasePeftCacheManager") + .def("add_request_peft", &tb::BasePeftCacheManager::addRequestPeft, nb::arg("request"), + nb::arg("try_gpu_cache") = true) + .def( + "ensure_batch", + [](tb::BasePeftCacheManager& self, tb::RequestVector const& contextRequests, + tb::RequestVector const& generationRequests, bool resetGpuCache) + { + nb::gil_scoped_release release; + return self.ensureBatch(contextRequests, generationRequests, resetGpuCache); + }, + nb::arg("context_requests"), nb::arg("generation_requests"), nb::arg("reset_gpu_cache") = false) + .def("reset_device_cache", &tb::BasePeftCacheManager::resetDeviceCache) + .def("mark_request_done", &tb::BasePeftCacheManager::markRequestDone, nb::arg("request"), + nb::arg("pause") = false) + .def_prop_ro("max_device_pages", &tb::BasePeftCacheManager::getMaxDevicePages) + .def_prop_ro("max_host_pages", &tb::BasePeftCacheManager::getMaxHostPages) + .def("determine_num_pages", &tb::BasePeftCacheManager::determineNumPages, nb::arg("request")) + .def_prop_ro("enabled", &tb::BasePeftCacheManager::enabled); + + nb::class_(m, "PeftCacheManager") + .def(nb::init(), + nb::arg("config"), nb::arg("model_config"), nb::arg("world_config"), nb::arg("buffer_manager")); + + nb::class_(m, "NoOpPeftCacheManager").def(nb::init<>()); +} diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.h b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.h new file mode 100644 index 00000000000..786c0d391df --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.h @@ -0,0 +1,39 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +namespace nb = nanobind; + +namespace tensorrt_llm::batch_manager::kv_cache_manager +{ +class KVCacheManagerBindings +{ +public: + static void initBindings(nb::module_& m); +}; +} // namespace tensorrt_llm::batch_manager::kv_cache_manager + +namespace tensorrt_llm::batch_manager +{ +class BasePeftCacheManagerBindings +{ +public: + static void initBindings(nb::module_& m); +}; +} // namespace tensorrt_llm::batch_manager diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp new file mode 100644 index 00000000000..d8f45cb865f --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp @@ -0,0 +1,131 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "llmRequest.h" +#include "tensorrt_llm/nanobind/common/customCasters.h" + +#include "tensorrt_llm/batch_manager/llmRequest.h" +#include "tensorrt_llm/nanobind/common/bindTypes.h" +#include "tensorrt_llm/runtime/torch.h" +#include "tensorrt_llm/runtime/torchUtils.h" +#include "tensorrt_llm/runtime/torchView.h" + +#include +#include + +#include + +namespace tb = tensorrt_llm::batch_manager; +namespace tr = tensorrt_llm::runtime; +namespace tle = tensorrt_llm::executor; + +using namespace tensorrt_llm::nanobind::batch_manager; + +using LlmRequestPtr = std::shared_ptr; +using RequestList = std::list; + +namespace +{ + +std::optional from_torch(std::optional torchPtr) +{ + if (torchPtr) + { + return tr::TorchView::of(torchPtr.value()); + } + return std::nullopt; +} + +} // namespace + +std::optional LlmRequest::callbackAdapter( + std::optional callback) +{ + if (!callback) + { + return std::nullopt; + } + + return [callback](RequestIdType reqId, tr::ITensor::SharedPtr& tensor, tb::LlmRequest::BeamTokens const& tokens, + tr::BufferManager::CudaStreamPtr stream, std::optional clientId) + { + at::Tensor atTensor = tr::Torch::tensor(tensor); + callback.value()(reqId, atTensor, tokens, runtime::TorchUtils::stream(*stream).unwrap(), clientId); + }; +} + +std::shared_ptr LlmRequest::toTrtLlm() const +{ + + auto const draftTokens = std::make_shared>(*mDraftTokens.get()); + auto const optDraftTokens = std::optional>>(draftTokens); + auto const encoderInputTokens = mEncoderTokens.has_value() + ? 
std::make_shared>(*mEncoderTokens.value().get()) + : nullptr; + auto const optEncoderInputTokens = std::optional>>(encoderInputTokens); + // 49 parameters + return std::make_shared( // + mRequestId, // + mMaxNewTokens, // + std::make_shared>(mTokens.at(0)), // + mSamplingConfig, // + mIsStreaming, // + mEndId, // + mPadId, // + from_torch(mEmbeddingBias), // + from_torch(mBadWordsList), // + from_torch(mStopWordsList), // + mPositionIds, // + from_torch(mPromptEmbeddingTable), // + mPromptVocabSize, // + mMultimodalHashes, // + mMultimodalPositions, // + mMultimodalLengths, // + from_torch(mMultimodalEmbedding), // + from_torch(mMropeRotaryCosSin), // + mMropePositionDeltas, // + mLoraTaskId, // + from_torch(mLoraWeights), // + from_torch(mLoraConfig), // + mLookaheadConfig, // + mKvCacheRetentionConfig, // + mReturnLogProbs, // + mReturnContextLogits, // + mReturnGenerationLogits, // + optDraftTokens, // + from_torch(mDraftLogits), // + mExcludeInputFromOutput, // + callbackAdapter(mLogitsPostProcessor), // + mApplyLogitsPostProcessorBatched, // + optEncoderInputTokens, // + mReturnEncoderOutput, // + mClientId, // + mPriority, // + from_torch(mEncoderInputFeatures), // + mEncoderOutputLength, // + from_torch(mCrossAttentionMask), // + getLlmRequestType(), // + std::nullopt, // inputTokenExtraIds + mNumReturnSequences, // + mEagleConfig, // + from_torch(mSkipCrossAttnBlocks), // + false, // returnPerfMetrics + mGuidedDecodingParams, // + mLanguageAdapterUid, // + mAllottedTimeMs, // + mContextPhaseParams // + ); +} diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h new file mode 100644 index 00000000000..624dc55112d --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h @@ -0,0 +1,160 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "tensorrt_llm/batch_manager/llmRequest.h" + +#include +#include +#include +#include +#include + +namespace nb = nanobind; + +namespace tensorrt_llm::nanobind::batch_manager +{ + +namespace tb = tensorrt_llm::batch_manager; + +/* Unfortunately, torch's default nanobind bindings don't know about c10::cuda::CUDAStream, + * so we have to pass the more generic c10::Stream, and convert it back to a full-fledged + * torch.cuda.Stream in python. 
See example in test/bindings/test_gpt_manager.py + */ +class LlmRequest : public tb::GenericLlmRequest +{ +public: + using Base = GenericLlmRequest; + using TensorPtr = Base::TensorPtr; + using SizeType32 = Base::SizeType32; + using TokenIdType = Base::TokenIdType; + using RequestIdType = Base::RequestIdType; + using LoraTaskIdType = Base::LoraTaskIdType; + using VecLogProbs = Base::VecLogProbs; + using BeamTokens = Base::BeamTokens; + using VecTokens = Base::VecTokens; + using VecTokenExtraIds = Base::VecTokenExtraIds; + using LogitsPostProcessor = Base::LogitsPostProcessor; + + // 49 parameters + LlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::vector inputTokens, + runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional endId = std::nullopt, + std::optional padId = std::nullopt, std::optional embeddingBias = std::nullopt, + std::optional badWordsList = std::nullopt, std::optional stopWordsList = std::nullopt, + std::optional> positionIds = std::nullopt, + std::optional promptEmbeddingTable = std::nullopt, + std::optional promptVocabSize = std::nullopt, + std::optional>> multimodalHashes = std::nullopt, + std::optional> multimodalPositions = std::nullopt, + std::optional> multimodalLengths = std::nullopt, + std::optional multimodalEmbedding = std::nullopt, + std::optional mropeRotaryCosSin = std::nullopt, + std::optional mropePositionDeltas = std::nullopt, + std::optional loraTaskId = std::nullopt, std::optional loraWeights = std::nullopt, + std::optional loraConfig = std::nullopt, + std::optional lookaheadConfig = std::nullopt, + std::optional kvCacheRetentionConfig = std::nullopt, + bool returnLogProbs = false, bool returnContextLogits = false, bool returnGenerationLogits = false, + std::optional draftTokens = std::nullopt, std::optional draftLogits = std::nullopt, + bool excludeInputFromOutput = false, std::optional logitsPostProcessor = std::nullopt, + bool applyLogitsPostProcessorBatched = false, std::optional encoderInputTokens = std::nullopt, + bool returnEncoderOutput = false, std::optional clientId = std::nullopt, + executor::PriorityType priority = executor::Request::kDefaultPriority, + std::optional encoderInputFeatures = std::nullopt, + std::optional encoderOutputLength = std::nullopt, + std::optional crossAttentionMask = std::nullopt, + tb::LlmRequestType llmRequestType = tb::LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, + std::optional inputTokenExtraIds = std::nullopt, SizeType32 numReturnSequences = 1, + std::optional eagleConfig = std::nullopt, + std::optional skipCrossAttnBlocks = std::nullopt, bool returnPerfMetrics = false, + std::optional guidedDecodingParams = std::nullopt, + std::optional languageAdapterUid = std::nullopt, + std::optional allottedTimeMs = std::nullopt, + std::optional const& contextPhaseParams = std::nullopt) + : Base(requestId, // + maxNewTokens, // + std::make_shared>(std::move(inputTokens)), // + samplingConfig, // + isStreaming, // + endId, // + padId, // + embeddingBias, // + badWordsList, // + stopWordsList, // + positionIds.has_value() ? std::make_shared>(std::move(positionIds.value())) // + : std::optional>>(std::nullopt), // + promptEmbeddingTable, // + promptVocabSize, // + multimodalHashes.has_value() + ? std::make_optional( + std::make_shared>>(std::move(multimodalHashes.value()))) // + : std::optional>>>(std::nullopt), // + multimodalPositions.has_value() + ? std::make_shared>(std::move(multimodalPositions.value())) // + : std::optional>>(std::nullopt), // + multimodalLengths.has_value() + ? 
std::make_shared>(std::move(multimodalLengths.value())) // + : std::optional>>(std::nullopt), // + multimodalEmbedding, // + mropeRotaryCosSin, // + mropePositionDeltas, // + loraTaskId, // + loraWeights, // + loraConfig, // + lookaheadConfig, // + kvCacheRetentionConfig, // + returnLogProbs, // + returnContextLogits, // + returnGenerationLogits, // + draftTokens.has_value() ? std::make_shared(std::move(draftTokens.value())) // + : std::make_shared(), // + draftLogits, // + excludeInputFromOutput, // + logitsPostProcessor, // + applyLogitsPostProcessorBatched, // + encoderInputTokens ? std::make_optional(std::make_shared(std::move(*encoderInputTokens))) // + : std::optional>(std::nullopt), // + returnEncoderOutput, // + clientId, // + priority, // + encoderInputFeatures, // + encoderOutputLength, // + crossAttentionMask, // + llmRequestType, // + inputTokenExtraIds // + ? std::make_optional(std::make_shared(std::move(*inputTokenExtraIds))) // + : std::optional>(std::nullopt), // + numReturnSequences, // + eagleConfig, // + skipCrossAttnBlocks, // + returnPerfMetrics, // + guidedDecodingParams, // + languageAdapterUid, // + allottedTimeMs, // + contextPhaseParams // + ) + { + } + + static std::optional callbackAdapter( + std::optional callback); + + [[nodiscard]] std::shared_ptr toTrtLlm() const; +}; + +} // namespace tensorrt_llm::nanobind::batch_manager diff --git a/cpp/tensorrt_llm/nanobind/bindings.cpp b/cpp/tensorrt_llm/nanobind/bindings.cpp index adc82587433..dd01d21cced 100644 --- a/cpp/tensorrt_llm/nanobind/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/bindings.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,14 +15,483 @@ * limitations under the License. 
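
The LlmRequest wrapper above keeps tensors as torch.Tensor on the Python side and converts them to ITensor views in toTrtLlm() when a request is handed to the C++ executor. A construction sketch follows; the keyword names mirror the C++ parameters and are assumed to match what batch_manager/bindings.cpp exposes, and the import paths are assumed as well.

    # Illustrative only; keyword names and import paths assumed.
    from tensorrt_llm.bindings import SamplingConfig
    from tensorrt_llm.bindings.internal.batch_manager import LlmRequest

    request = LlmRequest(
        request_id=0,
        max_new_tokens=16,
        input_tokens=[1, 2, 3],
        sampling_config=SamplingConfig(beam_width=1),
        is_streaming=False)
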
*/ +#include "tensorrt_llm/nanobind/common/customCasters.h" #include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "tensorrt_llm/batch_manager/peftCacheManagerConfig.h" +#include "tensorrt_llm/common/quantization.h" +#include "tensorrt_llm/nanobind/batch_manager/algorithms.h" +#include "tensorrt_llm/nanobind/batch_manager/bindings.h" +#include "tensorrt_llm/nanobind/batch_manager/buffers.h" +#include "tensorrt_llm/nanobind/batch_manager/cacheTransceiver.h" +#include "tensorrt_llm/nanobind/batch_manager/kvCacheManager.h" +#include "tensorrt_llm/nanobind/batch_manager/llmRequest.h" +#include "tensorrt_llm/nanobind/executor/bindings.h" +#include "tensorrt_llm/nanobind/runtime/bindings.h" +#include "tensorrt_llm/nanobind/testing/modelSpecBinding.h" +#include "tensorrt_llm/nanobind/userbuffers/bindings.h" +#include "tensorrt_llm/runtime/common.h" +#include "tensorrt_llm/runtime/cudaStream.h" +#include "tensorrt_llm/runtime/gptJsonConfig.h" +#include "tensorrt_llm/runtime/ipcNvlsMemory.h" +#include "tensorrt_llm/runtime/memoryCounters.h" +#include "tensorrt_llm/runtime/samplingConfig.h" +#include "tensorrt_llm/runtime/utils/mpiUtils.h" + +namespace nb = nanobind; +namespace tb = tensorrt_llm::batch_manager; +namespace tbk = tensorrt_llm::batch_manager::kv_cache_manager; +namespace tpb = tensorrt_llm::nanobind::batch_manager; +namespace tc = tensorrt_llm::common; +namespace tr = tensorrt_llm::runtime; +namespace tle = tensorrt_llm::executor; +using SizeType32 = tr::SizeType32; +using TokenIdType = tr::TokenIdType; +template +using OptVec = std::optional>; #if not defined(TRTLLM_NB_MODULE) #error "TRTLLM_NB_MODULE must be defined" #endif +namespace +{ +tr::SamplingConfig makeSamplingConfig(std::vector const& configs) +{ + return tr::SamplingConfig(configs); +} +} // namespace + NB_MODULE(TRTLLM_NB_MODULE, m) { m.doc() = "TensorRT-LLM Python bindings for C++ runtime"; m.attr("binding_type") = "nanobind"; + nb::set_leak_warnings(false); + + // Create MpiComm binding first since it's used in the executor bindings + nb::class_(m, "MpiComm") + .def_static("rank", + []() + { + auto& session = tensorrt_llm::mpi::MpiComm::session(); + return session.tensorrt_llm::mpi::MpiComm::getRank(); + }) + .def_static("size", + []() + { + auto& session = tensorrt_llm::mpi::MpiComm::session(); + return session.tensorrt_llm::mpi::MpiComm::getSize(); + }) + .def_static("local_size", + []() + { + auto& session = tensorrt_llm::mpi::MpiComm::localSession(); + return session.tensorrt_llm::mpi::MpiComm::getSize(); + }) + .def_static("local_init", []() { tensorrt_llm::mpi::MpiComm::localSession(); }) + .def_static("set_raw_mpi_session_by_fortran_handle", + [](int64_t fortran_handle) { tensorrt_llm::mpi::MpiComm::setRawSessionByFortran(fortran_handle); }) + .def_static("split", + [](size_t color, size_t rank) + { + auto& world = tensorrt_llm::mpi::MpiComm::world(); + tensorrt_llm::mpi::MpiComm::setSession(world.split(color, rank)); + }); + + nb::class_(m, "CudaStream") + .def( + "__init__", + [](tr::CudaStream* self, nb::object py_stream) + { + cudaStream_t stream = reinterpret_cast(nb::cast(py_stream)); + new (self) tr::CudaStream{stream}; + }, + nb::arg("stream_ptr")) + .def("get_device", &tr::CudaStream::getDevice); + + // Create submodule for executor bindings. 
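
The MpiComm and CudaStream bindings above are enough to query the local rank and hand an existing torch stream to the C++ runtime; CudaStream takes the raw cudaStream_t as an integer. The import path is assumed, and the sketch assumes a CUDA device is visible.

    # Illustrative only; import path assumed, CUDA device required.
    import torch
    from tensorrt_llm.bindings import CudaStream, MpiComm

    print(MpiComm.rank(), MpiComm.size())

    torch_stream = torch.cuda.Stream()
    stream = CudaStream(torch_stream.cuda_stream)   # wrap the raw stream pointer
    print(stream.get_device())
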
+ auto mExecutor = m.def_submodule("executor", "Executor bindings"); + auto mInternal = m.def_submodule("internal", "Internal submodule of TRTLLM runtime"); + auto mInternalRuntime = mInternal.def_submodule("runtime", "Runtime internal bindings"); + auto mInternalTesting = mInternal.def_submodule("testing", "Testing internal bindings"); + auto mInternalBatchManager = mInternal.def_submodule("batch_manager", "Batch manager internal bindings"); + + tensorrt_llm::nanobind::executor::initBindings(mExecutor); + tensorrt_llm::nanobind::runtime::initBindingsEarly(mInternalRuntime); + + auto buildInfo = m.def_submodule("BuildInfo"); + buildInfo.attr("ENABLE_MULTI_DEVICE") = nb::int_(ENABLE_MULTI_DEVICE); + + nb::class_(m, "PeftCacheManagerConfig") + .def(nb::init, std::optional, std::optional>(), + nb::arg("num_host_module_layer") = 0, nb::arg("num_device_module_layer") = 0, + nb::arg("optimal_adapter_size") = 8, nb::arg("max_adapter_size") = 64, nb::arg("num_put_workers") = 1, + nb::arg("num_ensure_workers") = 1, nb::arg("num_copy_streams") = 1, + nb::arg("max_pages_per_block_host") = 24, nb::arg("max_pages_per_block_device") = 8, + nb::arg("device_cache_percent") = std::nullopt, nb::arg("host_cache_size") = std::nullopt, + nb::arg("lora_prefetch_dir") = std::nullopt) + .def_rw("num_host_module_layer", &tb::PeftCacheManagerConfig::numHostModuleLayer) + .def_rw("num_device_module_layer", &tb::PeftCacheManagerConfig::numDeviceModuleLayer) + .def_rw("optimal_adapter_size", &tb::PeftCacheManagerConfig::optimalAdapterSize) + .def_rw("max_adapter_size", &tb::PeftCacheManagerConfig::maxAdapterSize) + .def_rw("num_put_workers", &tb::PeftCacheManagerConfig::numPutWorkers) + .def_rw("num_ensure_workers", &tb::PeftCacheManagerConfig::numEnsureWorkers) + .def_rw("num_copy_streams", &tb::PeftCacheManagerConfig::numCopyStreams) + .def_rw("max_pages_per_block_host", &tb::PeftCacheManagerConfig::maxPagesPerBlockHost) + .def_rw("max_pages_per_block_device", &tb::PeftCacheManagerConfig::maxPagesPerBlockDevice) + .def_rw("device_cache_percent", &tb::PeftCacheManagerConfig::deviceCachePercent) + .def_rw("host_cache_size", &tb::PeftCacheManagerConfig::hostCacheSize) + .def_rw("lora_prefetch_dir", &tb::PeftCacheManagerConfig::loraPrefetchDir); + + nb::enum_(m, "DataType") + .value("FLOAT", nvinfer1::DataType::kFLOAT) + .value("HALF", nvinfer1::DataType::kHALF) + .value("INT8", nvinfer1::DataType::kINT8) + .value("INT32", nvinfer1::DataType::kINT32) + .value("BOOL", nvinfer1::DataType::kBOOL) + .value("UINT8", nvinfer1::DataType::kUINT8) + .value("FP8", nvinfer1::DataType::kFP8) + .value("BF16", nvinfer1::DataType::kBF16) + .value("INT64", nvinfer1::DataType::kINT64) + .export_values(); + + nb::enum_(m, "GptModelVariant") + .value("GPT", tr::ModelConfig::ModelVariant::kGpt) + .value("GLM", tr::ModelConfig::ModelVariant::kGlm) + .value("CHATGLM", tr::ModelConfig::ModelVariant::kChatGlm) + .value("MAMBA", tr::ModelConfig::ModelVariant::kMamba) + .value("RECURRENTGEMMA", tr::ModelConfig::ModelVariant::kRecurrentGemma); + + nb::enum_(m, "KVCacheType") + .value("CONTINUOUS", tr::ModelConfig::KVCacheType::kCONTINUOUS) + .value("PAGED", tr::ModelConfig::KVCacheType::kPAGED) + .value("DISABLED", tr::ModelConfig::KVCacheType::kDISABLED) + .def("from_string", tr::ModelConfig::KVCacheTypeFromString); + + nb::enum_(m, "LayerType") + .value("ATTENTION", tr::ModelConfig::LayerType::kATTENTION) + .value("RECURRENT", tr::ModelConfig::LayerType::kRECURRENT); + + nb::enum_(m, "LoraModuleType") + .value("INVALID", 
tr::LoraModule::ModuleType::kINVALID) + .value("ATTN_QKV", tr::LoraModule::ModuleType::kATTN_QKV) + .value("ATTN_Q", tr::LoraModule::ModuleType::kATTN_Q) + .value("ATTN_K", tr::LoraModule::ModuleType::kATTN_K) + .value("ATTN_V", tr::LoraModule::ModuleType::kATTN_V) + .value("ATTN_DENSE", tr::LoraModule::ModuleType::kATTN_DENSE) + .value("MLP_H_TO_4H", tr::LoraModule::ModuleType::kMLP_H_TO_4H) + .value("MLP_4H_TO_H", tr::LoraModule::ModuleType::kMLP_4H_TO_H) + .value("MLP_GATE", tr::LoraModule::ModuleType::kMLP_GATE) + .value("CROSS_ATTN_QKV", tr::LoraModule::ModuleType::kCROSS_ATTN_QKV) + .value("CROSS_ATTN_Q", tr::LoraModule::ModuleType::kCROSS_ATTN_Q) + .value("CROSS_ATTN_K", tr::LoraModule::ModuleType::kCROSS_ATTN_K) + .value("CROSS_ATTN_V", tr::LoraModule::ModuleType::kCROSS_ATTN_V) + .value("CROSS_ATTN_DENSE", tr::LoraModule::ModuleType::kCROSS_ATTN_DENSE) + .value("MOE_H_TO_4H", tr::LoraModule::ModuleType::kMOE_H_TO_4H) + .value("MOE_4H_TO_H", tr::LoraModule::ModuleType::kMOE_4H_TO_H) + .value("MOE_GATE", tr::LoraModule::ModuleType::kMOE_GATE) + .value("MOE_ROUTER", tr::LoraModule::ModuleType::kMOE_ROUTER) + .value("MLP_ROUTER", tr::LoraModule::ModuleType::kMLP_ROUTER) + .value("MLP_GATE_UP", tr::LoraModule::ModuleType::kMLP_GATE_UP); + + nb::class_(m, "LoraModule") + .def(nb::init(), + nb::arg("module_type"), nb::arg("in_dim"), nb::arg("out_dim"), nb::arg("in_dim_first"), + nb::arg("out_dim_first"), nb::arg("in_tp_split_dim"), nb::arg("out_tp_split_dim")) + .def_prop_ro("module_type", &tr::LoraModule::name) + .def_prop_ro("in_dim", &tr::LoraModule::inDim) + .def_prop_ro("out_dim", &tr::LoraModule::outDim) + .def_prop_ro("in_dim_first", &tr::LoraModule::inDimFirst) + .def_prop_ro("out_dim_first", &tr::LoraModule::outDimFirst) + .def_prop_ro("in_tp_split_dim", &tr::LoraModule::inTpSplitDim) + .def_prop_ro("out_tp_split_dim", &tr::LoraModule::outTpSplitDim) + .def_static("create_lora_modules", &tr::LoraModule::createLoraModules, nb::arg("lora_module_names"), + nb::arg("hidden_size"), nb::arg("mlp_hidden_size"), nb::arg("num_attention_heads"), + nb::arg("num_kv_attention_heads"), nb::arg("attention_head_size"), nb::arg("tp_size") = 1, + nb::arg("num_experts") = 0); + + nb::class_(m, "QuantMode") + .def_static("none", &tc::QuantMode::none) + .def_static("int4_weights", &tc::QuantMode::int4Weights) + .def_static("int8_weights", &tc::QuantMode::int8Weights) + .def_static("activations", &tc::QuantMode::activations) + .def_static("per_channel_scaling", &tc::QuantMode::perChannelScaling) + .def_static("per_token_scaling", &tc::QuantMode::perTokenScaling) + .def_static("per_group_scaling", &tc::QuantMode::perGroupScaling) + .def_static("int8_kv_cache", &tc::QuantMode::int8KvCache) + .def_static("fp8_kv_cache", &tc::QuantMode::fp8KvCache) + .def_static("fp8_qdq", &tc::QuantMode::fp8Qdq) + .def_prop_ro("value", &tc::QuantMode::value) + .def("is_set", &tc::QuantMode::isSet, nb::arg("mode")) + .def_prop_ro("has_int4_weights", &tc::QuantMode::hasInt4Weights) + .def_prop_ro("has_int8_weights", &tc::QuantMode::hasInt8Weights) + .def_prop_ro("has_activations", &tc::QuantMode::hasActivations) + .def_prop_ro("has_per_channel_scaling", &tc::QuantMode::hasPerChannelScaling) + .def_prop_ro("has_per_token_scaling", &tc::QuantMode::hasPerTokenScaling) + .def_prop_ro("has_per_group_scaling", &tc::QuantMode::hasPerGroupScaling) + .def_prop_ro("has_static_activation_scaling", &tc::QuantMode::hasStaticActivationScaling) + .def_prop_ro("has_int8_kv_cache", &tc::QuantMode::hasInt8KvCache) + 
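
QuantMode is a flags-style value, so the static factories above compose and the has_* properties query individual bits. A short sketch with an assumed import path:

    # Illustrative only; import path assumed.
    from tensorrt_llm.bindings import QuantMode

    mode = QuantMode.int8_weights()
    print(mode.has_int8_weights)                  # True
    print(mode.is_set(QuantMode.int8_weights()))  # True
    print(hex(mode.value))
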
.def_prop_ro("has_fp8_kv_cache", &tc::QuantMode::hasFp8KvCache) + .def_prop_ro("has_fp8_qdq", &tc::QuantMode::hasFp8Qdq) + .def_prop_ro("has_nvfp4", &tc::QuantMode::hasNvfp4) + .def_prop_ro("has_w4a8_mxfp4_fp8", &tc::QuantMode::hasW4a8Mxfp4Fp8) + .def_prop_ro("has_kv_cache_quant", &tc::QuantMode::hasKvCacheQuant) + .def_static("from_description", &tc::QuantMode::fromDescription, nb::arg("quantize_weights"), + nb::arg("quantize_activations"), nb::arg("per_token"), nb::arg("per_channel"), nb::arg("per_group"), + nb::arg("use_int4_weights"), nb::arg("use_int8_kv_cache"), nb::arg("use_fp8_kv_kache"), + nb::arg("use_fp8_qdq"), nb::arg("use_fp8_rowwise"), nb::arg("use_w4a8_qserve"), nb::arg("use_nvfp4"), + nb::arg("use_fp8_block_scales"), nb::arg("use_w4a8_mxfp4_fp8")) + .def_static("use_smooth_quant", &tc::QuantMode::useSmoothQuant, nb::arg("per_token") = false, + nb::arg("per_channel") = false) + .def_static("use_weight_only", &tc::QuantMode::useWeightOnly, nb::arg("use_int4_weights") = false, + nb::arg("per_group") = false) + .def_static("from_quant_algo", &tc::QuantMode::fromQuantAlgo, nb::arg("quant_algo") = nb::none(), + nb::arg("kv_cache_quant_algo") = nb::none()) + .def(nb::self + nb::self) + .def(nb::self += nb::self) + .def(nb::self - nb::self) + .def(nb::self -= nb::self) + .def(nb::self == nb::self) + .def(nb::self != nb::self); + + nb::class_(m, "ModelConfig") + .def(nb::init(), + nb::arg("vocab_size"), nb::arg("num_layers"), nb::arg("num_attention_layers"), nb::arg("num_rnn_layers"), + nb::arg("num_heads"), nb::arg("hidden_size"), nb::arg("data_type")) + .def_prop_ro("vocab_size", &tr::ModelConfig::getVocabSize) + .def("vocab_size_padded", &tr::ModelConfig::getVocabSizePadded, nb::arg("world_size")) + .def("num_layers", &tr::ModelConfig::getNbLayers, nb::arg("pipeline_parallelism") = 1, + nb::arg("pipeline_parallelism_rank") = 0) + .def("num_attention_layers", &tr::ModelConfig::getNbAttentionLayers, nb::arg("pipeline_parallelism") = 1, + nb::arg("pipeline_parallelism_rank") = 0) + .def("num_rnn_layers", &tr::ModelConfig::getNbRnnLayers, nb::arg("pipeline_parallelism") = 1, + nb::arg("pipeline_parallelism_rank") = 0) + .def("num_kv_heads", &tr::ModelConfig::getNbKvHeads, nb::arg("layer_idx")) + .def("set_num_kv_heads", &tr::ModelConfig::setNbKvHeads, nb::arg("num_kv_heads")) + .def_prop_ro("num_heads", &tr::ModelConfig::getNbHeads) + .def_prop_ro("hidden_size", &tr::ModelConfig::getHiddenSize) + .def_prop_ro("size_per_head", &tr::ModelConfig::getSizePerHead) + .def_prop_ro("data_type", &tr::ModelConfig::getDataType) + .def_prop_ro("speculative_decoding_mode", &tr::ModelConfig::getSpeculativeDecodingMode) + .def_prop_rw("head_size", &tr::ModelConfig::getSizePerHead, &tr::ModelConfig::setSizePerHead) + .def_prop_rw( + "num_kv_heads_per_layer", &tr::ModelConfig::getNumKvHeadsPerLayer, &tr::ModelConfig::setNumKvHeadsPerLayer) + .def_prop_rw("use_gpt_attention_plugin", + nb::overload_cast<>(&tr::ModelConfig::useGptAttentionPlugin, nb::const_), + nb::overload_cast(&tr::ModelConfig::useGptAttentionPlugin)) + .def_prop_rw("use_packed_input", nb::overload_cast<>(&tr::ModelConfig::usePackedInput, nb::const_), + nb::overload_cast(&tr::ModelConfig::usePackedInput)) + .def_prop_rw("kv_cache_type", nb::overload_cast<>(&tr::ModelConfig::getKVCacheType, nb::const_), + nb::overload_cast(&tr::ModelConfig::setKVCacheType)) + .def_prop_rw("tokens_per_block", &tr::ModelConfig::getTokensPerBlock, &tr::ModelConfig::setTokensPerBlock) + .def_prop_rw("quant_mode", &tr::ModelConfig::getQuantMode, 
&tr::ModelConfig::setQuantMode) + .def_prop_ro("supports_inflight_batching", &tr::ModelConfig::supportsInflightBatching) + .def_prop_rw("max_batch_size", &tr::ModelConfig::getMaxBatchSize, &tr::ModelConfig::setMaxBatchSize) + .def_prop_rw("max_beam_width", &tr::ModelConfig::getMaxBeamWidth, &tr::ModelConfig::setMaxBeamWidth) + .def_prop_rw("max_input_len", &tr::ModelConfig::getMaxInputLen, &tr::ModelConfig::setMaxInputLen) + .def_prop_rw("max_seq_len", &tr::ModelConfig::getMaxSequenceLen, &tr::ModelConfig::setMaxSequenceLen) + .def_prop_rw("max_num_tokens", &tr::ModelConfig::getMaxNumTokens, &tr::ModelConfig::setMaxNumTokens) + .def_prop_rw("max_prompt_embedding_table_size", &tr::ModelConfig::getMaxPromptEmbeddingTableSize, + &tr::ModelConfig::setMaxPromptEmbeddingTableSize) + .def_prop_ro("use_prompt_tuning", &tr::ModelConfig::usePromptTuning) + .def_prop_ro("use_mrope", &tr::ModelConfig::useMrope) + .def_prop_rw("use_lora_plugin", nb::overload_cast<>(&tr::ModelConfig::useLoraPlugin, nb::const_), + nb::overload_cast(&tr::ModelConfig::useLoraPlugin)) + .def_prop_rw("layer_types", &tr::ModelConfig::getLayerTypes, &tr::ModelConfig::setLayerTypes) + .def_prop_rw("compute_context_logits", nb::overload_cast<>(&tr::ModelConfig::computeContextLogits, nb::const_), + nb::overload_cast(&tr::ModelConfig::computeContextLogits)) + .def_prop_rw("compute_generation_logits", + nb::overload_cast<>(&tr::ModelConfig::computeGenerationLogits, nb::const_), + nb::overload_cast(&tr::ModelConfig::computeGenerationLogits)) + .def_prop_rw("model_variant", &tr::ModelConfig::getModelVariant, &tr::ModelConfig::setModelVariant) + .def_prop_rw("use_cross_attention", &tr::ModelConfig::useCrossAttention, &tr::ModelConfig::setUseCrossAttention) + .def_prop_rw("lora_modules", &tr::ModelConfig::getLoraModules, &tr::ModelConfig::setLoraModules) + .def_prop_rw("max_lora_rank", &tr::ModelConfig::getMaxLoraRank, &tr::ModelConfig::setMaxLoraRank) + .def_prop_rw("mlp_hidden_size", &tr::ModelConfig::getMlpHiddenSize, &tr::ModelConfig::setMlpHiddenSize) + .def_prop_rw("size_per_head", &tr::ModelConfig::getSizePerHead, &tr::ModelConfig::setSizePerHead); + + nb::class_(m, "WorldConfig") + .def(nb::init> const&, bool>(), + nb::arg("tensor_parallelism") = 1, nb::arg("pipeline_parallelism") = 1, nb::arg("context_parallelism") = 1, + nb::arg("rank") = 0, nb::arg("gpus_per_node") = tr::WorldConfig::kDefaultGpusPerNode, + nb::arg("device_ids") = nb::none(), nb::arg("enable_attention_dp") = false) + .def_prop_ro("size", &tr::WorldConfig::getSize) + .def_prop_ro("tensor_parallelism", &tr::WorldConfig::getTensorParallelism) + .def_prop_ro("pipeline_parallelism", &tr::WorldConfig::getPipelineParallelism) + .def_prop_ro("context_parallelism", &tr::WorldConfig::getContextParallelism) + .def_prop_ro("is_tensor_parallel", &tr::WorldConfig::isTensorParallel) + .def_prop_ro("is_pipeline_parallel", &tr::WorldConfig::isPipelineParallel) + .def_prop_ro("is_context_parallel", &tr::WorldConfig::isContextParallel) + .def_prop_ro("rank", &tr::WorldConfig::getRank) + .def_prop_ro("local_rank", &tr::WorldConfig::getLocalRank) + .def_prop_ro("node_rank", &tr::WorldConfig::getNodeRank) + .def_prop_ro("gpus_per_node", &tr::WorldConfig::getGpusPerNode) + .def_prop_ro("gpus_per_group", &tr::WorldConfig::getGpusPerGroup) + .def_prop_ro("device", &tr::WorldConfig::getDevice) + .def_prop_ro("pipeline_parallel_rank", &tr::WorldConfig::getPipelineParallelRank) + .def_prop_ro("tensor_parallel_rank", &tr::WorldConfig::getTensorParallelRank) + 
.def_prop_ro("context_parallel_rank", &tr::WorldConfig::getContextParallelRank) + .def_prop_ro("enable_attention_dp", &tr::WorldConfig::enableAttentionDP) + .def_static("mpi", + nb::overload_cast, std::optional, + std::optional, std::optional> const&, bool>(&tr::WorldConfig::mpi), + nb::arg("gpus_per_node") = tr::WorldConfig::kDefaultGpusPerNode, nb::arg("tensor_parallelism") = nb::none(), + nb::arg("pipeline_parallelism") = nb::none(), nb::arg("context_parallelism") = nb::none(), + nb::arg("device_ids") = nb::none(), nb::arg("enable_attention_dp") = false); + + auto SamplingConfigGetState = [](tr::SamplingConfig const& config) -> nb::tuple + { + return nb::make_tuple(config.beamWidth, config.temperature, config.minLength, config.repetitionPenalty, + config.presencePenalty, config.frequencyPenalty, config.topK, config.topP, config.randomSeed, + config.topPDecay, config.topPMin, config.topPResetIds, config.beamSearchDiversityRate, config.lengthPenalty, + config.earlyStopping, config.noRepeatNgramSize, config.numReturnSequences, config.minP, + config.beamWidthArray); + }; + auto SamplingConfigSetState = [](tr::SamplingConfig& self, nb::tuple t) -> tr::SamplingConfig + { + assert(t.size() == 19); + + tr::SamplingConfig config; + config.beamWidth = nb::cast(t[0]); + config.temperature = nb::cast>(t[1]); + config.minLength = nb::cast>(t[2]); + config.repetitionPenalty = nb::cast>(t[3]); + config.presencePenalty = nb::cast>(t[4]); + config.frequencyPenalty = nb::cast>(t[5]); + config.topK = nb::cast>(t[6]); + config.topP = nb::cast>(t[7]); + config.randomSeed = nb::cast>(t[8]); + config.topPDecay = nb::cast>(t[9]); + config.topPMin = nb::cast>(t[10]); + config.topPResetIds = nb::cast>(t[11]); + config.beamSearchDiversityRate = nb::cast>(t[12]); + config.lengthPenalty = nb::cast>(t[13]); + config.earlyStopping = nb::cast>(t[14]); + config.noRepeatNgramSize = nb::cast>(t[15]); + config.numReturnSequences = nb::cast(t[16]); + config.minP = nb::cast>(t[17]); + config.beamWidthArray = nb::cast>>(t[18]); + + return config; + }; + + nb::class_(m, "SamplingConfig") + .def(nb::init(), nb::arg("beam_width") = 1) + .def(nb::init>(), + nb::arg("executor_sample_config"), nb::arg("external_draft_tokens_config") = std::nullopt) + .def_rw("beam_width", &tr::SamplingConfig::beamWidth) + .def_rw("temperature", &tr::SamplingConfig::temperature) + .def_rw("min_length", &tr::SamplingConfig::minLength) + .def_rw("repetition_penalty", &tr::SamplingConfig::repetitionPenalty) + .def_rw("presence_penalty", &tr::SamplingConfig::presencePenalty) + .def_rw("frequency_penalty", &tr::SamplingConfig::frequencyPenalty) + .def_rw("top_k", &tr::SamplingConfig::topK) + .def_rw("top_p", &tr::SamplingConfig::topP) + .def_rw("random_seed", &tr::SamplingConfig::randomSeed) + .def_rw("top_p_decay", &tr::SamplingConfig::topPDecay) + .def_rw("top_p_min", &tr::SamplingConfig::topPMin) + .def_rw("top_p_reset_ids", &tr::SamplingConfig::topPResetIds) + .def_rw("beam_search_diversity_rate", &tr::SamplingConfig::beamSearchDiversityRate) + .def_rw("length_penalty", &tr::SamplingConfig::lengthPenalty) + .def_rw("early_stopping", &tr::SamplingConfig::earlyStopping) + .def_rw("no_repeat_ngram_size", &tr::SamplingConfig::noRepeatNgramSize) + .def_rw("num_return_sequences", &tr::SamplingConfig::numReturnSequences) + .def_rw("min_p", &tr::SamplingConfig::minP) + .def_rw("beam_width_array", &tr::SamplingConfig::beamWidthArray) + .def_rw("normalize_log_probs", &tr::SamplingConfig::normalizeLogProbs) + .def("__getstate__", SamplingConfigGetState) + 
.def("__setstate__", SamplingConfigSetState) + .def("__eq__", &tr::SamplingConfig::operator==); + + nb::bind_vector>(m, "SamplingConfigVector"); + + m.def("make_sampling_config", &makeSamplingConfig, nb::arg("configs")); + + nb::class_(m, "GptJsonConfig") + .def(nb::init>(), + nb::arg("name"), nb::arg("version"), nb::arg("precision"), nb::arg("tensor_parallelism"), + nb::arg("pipeline_parallelism"), nb::arg("context_parallelism"), nb::arg("gpus_per_node"), + nb::arg("model_config"), nb::arg("runtime_defaults") = nb::none()) + .def_static("parse", nb::overload_cast(&tr::GptJsonConfig::parse), nb::arg("json")) + .def_static( + "parse_file", nb::overload_cast(&tr::GptJsonConfig::parse), nb::arg("path")) + .def_prop_ro("model_config", &tr::GptJsonConfig::getModelConfig) + .def_prop_ro("name", &tr::GptJsonConfig::getName) + .def_prop_ro("version", &tr::GptJsonConfig::getVersion) + .def_prop_ro("precision", &tr::GptJsonConfig::getPrecision) + .def_prop_ro("tensor_parallelism", &tr::GptJsonConfig::getTensorParallelism) + .def_prop_ro("pipeline_parallelism", &tr::GptJsonConfig::getPipelineParallelism) + .def_prop_ro("context_parallelism", &tr::GptJsonConfig::getContextParallelism) + .def_prop_ro("gpus_per_node", &tr::GptJsonConfig::getGpusPerNode) + .def_prop_ro("world_size", &tr::GptJsonConfig::getWorldSize) + .def_prop_ro("runtime_defaults", &tr::GptJsonConfig::getRuntimeDefaults) + .def("engine_filename", + nb::overload_cast( + &tr::GptJsonConfig::engineFilename, nb::const_), + nb::arg("world_config"), nb::arg("model")) + .def("engine_filename", + nb::overload_cast(&tr::GptJsonConfig::engineFilename, nb::const_), + nb::arg("world_config")); + + nb::enum_(m, "LlmRequestState") + .value("UNKNOWN", tb::LlmRequestState::kUNKNOWN) + .value("ENCODER_INIT", tb::LlmRequestState::kENCODER_INIT) + .value("CONTEXT_INIT", tb::LlmRequestState::kCONTEXT_INIT) + .value("GENERATION_IN_PROGRESS", tb::LlmRequestState::kGENERATION_IN_PROGRESS) + .value("GENERATION_TO_COMPLETE", tb::LlmRequestState::kGENERATION_TO_COMPLETE) + .value("GENERATION_COMPLETE", tb::LlmRequestState::kGENERATION_COMPLETE) + .value("DISAGG_GENERATION_INIT", tb::LlmRequestState::kDISAGG_GENERATION_INIT) + .value("DISAGG_CONTEXT_TRANS_IN_PROGRESS", tb::LlmRequestState::kDISAGG_CONTEXT_TRANS_IN_PROGRESS) + .value("DISAGG_CONTEXT_COMPLETE", tb::LlmRequestState::kDISAGG_CONTEXT_COMPLETE) + .value("DISAGG_GENERATION_TRANS_IN_PROGRESS", tb::LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS) + .value("DISAGG_GENERATION_TRANS_COMPLETE", tb::LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE) + .value("DISAGG_CONTEXT_INIT_AND_TRANS", tb::LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS); + + nb::class_(m, "MemoryCounters") + .def_static("instance", &tr::MemoryCounters::getInstance, nb::rv_policy::reference) + .def_prop_ro("gpu", &tr::MemoryCounters::getGpu) + .def_prop_ro("cpu", &tr::MemoryCounters::getCpu) + .def_prop_ro("pinned", &tr::MemoryCounters::getPinned) + .def_prop_ro("uvm", &tr::MemoryCounters::getUVM); + + tensorrt_llm::nanobind::runtime::initBindings(mInternalRuntime); + tensorrt_llm::nanobind::testing::initBindings(mInternalTesting); + tpb::initBindings(mInternalBatchManager); + tb::kv_cache_manager::KVCacheManagerBindings::initBindings(mInternalBatchManager); + tb::BasePeftCacheManagerBindings::initBindings(mInternalBatchManager); + tb::CacheTransceiverBindings::initBindings(mInternalBatchManager); + tpb::Buffers::initBindings(mInternalBatchManager); + + auto mInternalAlgorithms = mInternal.def_submodule("algorithms", "Algorithms 
internal bindings"); + tpb::algorithms::initBindings(mInternalAlgorithms); + + auto mUserbuffers = mInternal.def_submodule("userbuffers", "User buffers internal bindings"); + tensorrt_llm::kernels::userbuffers::UserBufferBindings::initBindings(mUserbuffers); + + // NVLS allocators + nb::class_(m, "IpcNvlsHandle") + .def(nb::init<>()) + .def_rw("uc_ptr", &tr::IpcNvlsHandle::uc_ptr) + .def_rw("mc_ptr", &tr::IpcNvlsHandle::mc_ptr) + .def_rw("size", &tr::IpcNvlsHandle::size) + .def("get_ipc_ptrs", + [](tr::IpcNvlsHandle& self) { return reinterpret_cast(self.ipc_uc_ptrs.data()); }); + + m.def("ipc_nvls_allocate", &tr::ipcNvlsAllocate, nb::rv_policy::reference); + m.def("ipc_nvls_free", &tr::ipcNvlsFree); + m.def("ipc_nvls_supported", &tr::ipcNvlsSupported); } diff --git a/cpp/tensorrt_llm/nanobind/common/bindTypes.h b/cpp/tensorrt_llm/nanobind/common/bindTypes.h new file mode 100644 index 00000000000..5cd714e458a --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/common/bindTypes.h @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace PybindUtils +{ + +namespace nb = nanobind; + +template +void bindList(nb::module_& m, std::string const& name) +{ + nb::class_(m, name.c_str()) + .def(nb::init<>()) + .def("push_back", [](T& lst, const typename T::value_type& value) { lst.push_back(value); }) + .def("pop_back", [](T& lst) { lst.pop_back(); }) + .def("push_front", [](T& lst, const typename T::value_type& value) { lst.push_front(value); }) + .def("pop_front", [](T& lst) { lst.pop_front(); }) + .def("__len__", [](T const& lst) { return lst.size(); }) + .def( + "__iter__", [](T& lst) { return nb::make_iterator(nb::type(), "iterator", lst.begin(), lst.end()); }, + nb::keep_alive<0, 1>()) + .def("__getitem__", + [](T const& lst, size_t index) + { + if (index >= lst.size()) + throw nb::index_error(); + auto it = lst.begin(); + std::advance(it, index); + return *it; + }) + .def("__setitem__", + [](T& lst, size_t index, const typename T::value_type& value) + { + if (index >= lst.size()) + throw nb::index_error(); + auto it = lst.begin(); + std::advance(it, index); + *it = value; + }); +} + +template +void bindSet(nb::module_& m, std::string const& name) +{ + nb::class_(m, name.c_str()) + .def(nb::init<>()) + .def("clear", &T::clear) + .def("size", &T::size) + .def("insert", [](T& s, typename T::value_type const& value) { s.insert(value); }) + .def("erase", nb::overload_cast(&T::erase)) + .def("__len__", [](T const& lst) { return lst.size(); }) + .def("__contains__", [](T const& s, typename T::value_type x) { return s.find(x) != s.end(); }) + .def( + "__iter__", [](T& s) { return nb::make_iterator(nb::type(), "iterator", s.begin(), s.end()); }, + nb::keep_alive<0, 1>()) + .def("__eq__", [](T const& s, T const& other) { return s == other; }) + .def("__getstate__", + [](T const& v) + { + /* Return a tuple 
that fully encodes the state of the object */ + return nb::make_tuple(std::vector(v.begin(), v.end())); + }) + .def("__setstate__", + [](T& v, nb::tuple const& t) + { + if (t.size() != 1) + throw std::runtime_error("Invalid state!"); + /* Create a new C++ instance */ + T s; + /* Assign any additional state */ + auto state_list = nb::cast>(t[0]); + for (auto& item : state_list) + { + s.insert(item); + } + return s; + }); +} + +} // namespace PybindUtils diff --git a/cpp/tensorrt_llm/nanobind/common/customCasters.h b/cpp/tensorrt_llm/nanobind/common/customCasters.h new file mode 100644 index 00000000000..7cfa07d249a --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/common/customCasters.h @@ -0,0 +1,345 @@ +/* + * Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "tensorrt_llm/batch_manager/common.h" +#include "tensorrt_llm/batch_manager/decoderBuffers.h" +#include "tensorrt_llm/common/optionalRef.h" +#include "tensorrt_llm/runtime/cudaStream.h" +#include "tensorrt_llm/runtime/request.h" +#include "tensorrt_llm/runtime/samplingConfig.h" +#include "tensorrt_llm/runtime/torch.h" +#include "tensorrt_llm/runtime/torchView.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Pybind requires to have a central include in order for type casters to work. +// Opaque bindings add a type caster, so they have the same requirement. 
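
bindList and bindSet give opaque STL containers a small Python-native interface (len, iteration, indexing or membership, and pickling via __getstate__/__setstate__). The sketch below is hypothetical: it assumes some set type, here called ReqIdsSet, is bound elsewhere through bindSet under that name.

    # Hypothetical: assumes a container bound via PybindUtils::bindSet as "ReqIdsSet".
    from tensorrt_llm.bindings.internal.batch_manager import ReqIdsSet

    ids = ReqIdsSet()
    ids.insert(7)
    ids.insert(9)
    print(len(ids), 7 in ids)
    for req_id in ids:
        print(req_id)
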
+// See the warning in https://pybind11.readthedocs.io/en/stable/advanced/cast/custom.html + +// Opaque bindings +NB_MAKE_OPAQUE(tensorrt_llm::batch_manager::ReqIdsSet) +NB_MAKE_OPAQUE(std::vector) +NB_MAKE_OPAQUE(std::vector) +NB_MAKE_OPAQUE(std::vector) +NB_MAKE_OPAQUE(std::vector>) + +namespace nb = nanobind; + +// Custom casters +namespace NB_NAMESPACE +{ + +namespace detail +{ + +template +struct type_caster> +{ + using Type = std::deque; + NB_TYPE_CASTER(Type, const_name("List")); + + bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept + { + sequence seq(src, nanobind::detail::borrow_t{}); + value.clear(); + make_caster caster; + for (auto const& item : seq) + { + if (!caster.from_python(item, flags, cleanup)) + return false; + value.push_back(caster.operator T&()); + } + return true; + } + + static handle from_cpp(Type const& deque, rv_policy policy, cleanup_list* cleanup) noexcept + { + nb::list list; + + for (auto const& item : deque) + { + nb::object py_item = steal(make_caster::from_cpp(item, policy, cleanup)); + if (!py_item) + return {}; + list.append(py_item); + } + return list.release(); + } +}; + +template +struct type_caster> +{ + using value_conv = make_caster; + + NB_TYPE_CASTER(tensorrt_llm::common::OptionalRef, value_conv::Name); + + bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) + { + if (src.is_none()) + { + // If the Python object is None, create an empty OptionalRef + value = tensorrt_llm::common::OptionalRef(); + return true; + } + + value_conv conv; + if (!conv.from_python(src, flags, cleanup)) + return false; + + // Create an OptionalRef with a reference to the converted value + value = tensorrt_llm::common::OptionalRef(conv); + return true; + } + + static handle from_cpp(tensorrt_llm::common::OptionalRef const& src, rv_policy policy, cleanup_list* cleanup) + { + if (!src.has_value()) + return none().release(); + + return value_conv::from_cpp(*src, policy, cleanup); + } +}; + +template +struct PathCaster +{ + +private: + static PyObject* unicode_from_fs_native(std::string const& w) + { + return PyUnicode_DecodeFSDefaultAndSize(w.c_str(), ssize_t(w.size())); + } + + static PyObject* unicode_from_fs_native(std::wstring const& w) + { + return PyUnicode_FromWideChar(w.c_str(), ssize_t(w.size())); + } + +public: + static handle from_cpp(T const& path, rv_policy, cleanup_list* cleanup) + { + if (auto py_str = unicode_from_fs_native(path.native())) + { + return module_::import_("pathlib").attr("Path")(steal(py_str), cleanup).release(); + } + return nullptr; + } + + bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) + { + PyObject* native = nullptr; + if constexpr (std::is_same_v) + { + if (PyUnicode_FSConverter(src.ptr(), &native) != 0) + { + if (auto* c_str = PyBytes_AsString(native)) + { + // AsString returns a pointer to the internal buffer, which + // must not be free'd. + value = c_str; + } + } + } + else if constexpr (std::is_same_v) + { + if (PyUnicode_FSDecoder(src.ptr(), &native) != 0) + { + if (auto* c_str = PyUnicode_AsWideCharString(native, nullptr)) + { + // AsWideCharString returns a new string that must be free'd. + value = c_str; // Copies the string. 
+ PyMem_Free(c_str); + } + } + } + Py_XDECREF(native); + if (PyErr_Occurred()) + { + PyErr_Clear(); + return false; + } + return true; + } + + NB_TYPE_CASTER(T, const_name("os.PathLike")); +}; + +template <> +class type_caster +{ +public: + NB_TYPE_CASTER(tensorrt_llm::executor::StreamPtr, const_name("int")); + + bool from_python([[maybe_unused]] handle src, uint8_t flags, cleanup_list* cleanup) + { + auto stream_ptr = nanobind::cast(src); + value = std::make_shared(reinterpret_cast(stream_ptr)); + + return true; + } + + static handle from_cpp( + tensorrt_llm::executor::StreamPtr const& src, rv_policy /* policy */, cleanup_list* /* cleanup */) + { + // Return cudaStream_t as integer. + return PyLong_FromVoidPtr(src->get()); + } +}; + +template <> +struct type_caster +{ +public: + NB_TYPE_CASTER(tensorrt_llm::executor::Tensor, const_name("torch.Tensor")); + + // Convert PyObject(torch.Tensor) -> tensorrt_llm::executor::Tensor + bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) + { + PyObject* obj = src.ptr(); + if (THPVariable_Check(obj)) + { + at::Tensor const& t = THPVariable_Unpack(obj); + value = tensorrt_llm::executor::detail::ofITensor(tensorrt_llm::runtime::TorchView::of(t)); + return true; + } + return false; + } + + // Convert tensorrt_llm::executor::Tensor -> PyObject(torch.Tensor) + static handle from_cpp( + tensorrt_llm::executor::Tensor const& src, rv_policy /* policy */, cleanup_list* /* cleanup */) + { + return THPVariable_Wrap(tensorrt_llm::runtime::Torch::tensor(tensorrt_llm::executor::detail::toITensor(src))); + } +}; + +template <> +struct type_caster +{ +public: + NB_TYPE_CASTER(tensorrt_llm::runtime::ITensor::SharedPtr, const_name("torch.Tensor")); + + // Convert PyObject(torch.Tensor) -> tensorrt_llm::runtime::ITensor::SharedPtr + bool from_python(handle src, uint8_t, cleanup_list*) + { + PyObject* obj = src.ptr(); + if (THPVariable_Check(obj)) + { + at::Tensor const& t = THPVariable_Unpack(obj); + value = std::move(tensorrt_llm::runtime::TorchView::of(t)); + return true; + } + return false; + } + + // Convert tensorrt_llm::runtime::ITensor::SharedPtr -> PyObject(torch.Tensor) + static handle from_cpp( + tensorrt_llm::runtime::ITensor::SharedPtr const& src, rv_policy /* policy */, cleanup_list* /* cleanup */) + { + if (src == nullptr) + { + return none().release(); + } + return THPVariable_Wrap(tensorrt_llm::runtime::Torch::tensor(src)); + } +}; + +template <> +struct type_caster +{ +public: + NB_TYPE_CASTER(tensorrt_llm::runtime::ITensor::SharedConstPtr, const_name("torch.Tensor")); + + // Convert PyObject(torch.Tensor) -> tensorrt_llm::runtime::ITensor::SharedConstPtr + bool from_python(handle src, uint8_t, cleanup_list*) + { + PyObject* obj = src.ptr(); + if (THPVariable_Check(obj)) + { + at::Tensor const& t = THPVariable_Unpack(obj); + value = std::move(tensorrt_llm::runtime::TorchView::of(t)); + return true; + } + return false; + } + + // Convert tensorrt_llm::runtime::ITensor::SharedConstPtr -> PyObject(torch.Tensor) + static handle from_cpp( + tensorrt_llm::runtime::ITensor::SharedConstPtr const& src, rv_policy /* policy */, cleanup_list* /* cleanup */) + { + if (src == nullptr) + { + return none().release(); + } + return THPVariable_Wrap(tensorrt_llm::runtime::Torch::tensor( + reinterpret_cast(src))); + } +}; + +template <> +struct type_caster +{ + NB_TYPE_CASTER(at::Tensor, const_name("torch.Tensor")); + + bool from_python(nb::handle src, uint8_t, cleanup_list*) noexcept + { + nb::object capsule = nb::getattr(src, "__dlpack__")(); + 
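+        // The __dlpack__ protocol returns a capsule owning a DLManagedTensor.
+        // at::fromDLPack takes over that ownership below, so the capsule's
+        // destructor is cleared first to avoid a double free.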
DLManagedTensor* dl_managed = static_cast(PyCapsule_GetPointer(capsule.ptr(), "dltensor")); + PyCapsule_SetDestructor(capsule.ptr(), nullptr); + value = at::fromDLPack(dl_managed).alias(); + return true; + } + + static handle from_cpp(at::Tensor tensor, rv_policy, cleanup_list*) noexcept + { + DLManagedTensor* dl_managed = at::toDLPack(tensor); + if (!dl_managed) + return nullptr; + + nanobind::object capsule = nb::steal(PyCapsule_New(dl_managed, "dltensor", + [](PyObject* obj) + { + DLManagedTensor* dl = static_cast(PyCapsule_GetPointer(obj, "dltensor")); + dl->deleter(dl); + })); + if (!capsule.is_valid()) + { + dl_managed->deleter(dl_managed); + return nullptr; + } + nanobind::module_ torch = nanobind::module_::import_("torch"); + nanobind::object result = torch.attr("from_dlpack")(capsule); + capsule.release(); + return result.release(); + } +}; +} // namespace detail +} // namespace NB_NAMESPACE diff --git a/cpp/tensorrt_llm/nanobind/executor/bindings.cpp b/cpp/tensorrt_llm/nanobind/executor/bindings.cpp new file mode 100644 index 00000000000..d3f482df899 --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/executor/bindings.cpp @@ -0,0 +1,263 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "bindings.h" +#include "executor.h" +#include "executorConfig.h" +#include "request.h" +#include "tensorrt_llm/executor/executor.h" +#include "tensorrt_llm/executor/types.h" +#include "tensorrt_llm/nanobind/common/customCasters.h" + +#include +#include +#include +#include +#include +#include + +namespace nb = nanobind; +namespace tle = tensorrt_llm::executor; +using SizeType32 = tle::SizeType32; + +namespace tensorrt_llm::nanobind::executor +{ + +template +void instantiateEventDiff(nb::module_& m, std::string const& name) +{ + nb::class_>(m, ("KVCacheEventDiff" + name).c_str()) + .def_ro("old_value", &tle::KVCacheEventDiff::oldValue) + .def_ro("new_value", &tle::KVCacheEventDiff::newValue); +} + +void initBindings(nb::module_& m) +{ + m.attr("__version__") = tle::version(); + nb::enum_(m, "ModelType") + .value("DECODER_ONLY", tle::ModelType::kDECODER_ONLY) + .value("ENCODER_ONLY", tle::ModelType::kENCODER_ONLY) + .value("ENCODER_DECODER", tle::ModelType::kENCODER_DECODER); + + auto decodingModeGetstate = [](tle::DecodingMode const& self) { return nb::make_tuple(self.getState()); }; + auto decodingModeSetstate = [](tle::DecodingMode& self, nb::tuple const& state) + { + if (state.size() != 1) + { + throw std::runtime_error("Invalid state!"); + } + new (&self) tle::DecodingMode(nb::cast(state[0])); + }; + nb::class_(m, "DecodingMode") + .def("Auto", &tle::DecodingMode::Auto) + .def("TopK", &tle::DecodingMode::TopK) + .def("TopP", &tle::DecodingMode::TopP) + .def("TopKTopP", &tle::DecodingMode::TopKTopP) + .def("BeamSearch", &tle::DecodingMode::BeamSearch) + .def("Medusa", &tle::DecodingMode::Medusa) + .def("Lookahead", &tle::DecodingMode::Lookahead) + .def("ExplicitDraftTokens", &tle::DecodingMode::ExplicitDraftTokens) + .def("Eagle", &tle::DecodingMode::Eagle) + .def("isAuto", &tle::DecodingMode::isAuto) + .def("isTopK", &tle::DecodingMode::isTopK) + .def("isTopP", &tle::DecodingMode::isTopP) + .def("isTopKorTopP", &tle::DecodingMode::isTopKorTopP) + .def("isTopKandTopP", &tle::DecodingMode::isTopKandTopP) + .def("isBeamSearch", &tle::DecodingMode::isBeamSearch) + .def("isMedusa", &tle::DecodingMode::isMedusa) + .def("isLookahead", &tle::DecodingMode::isLookahead) + .def("isExplicitDraftTokens", &tle::DecodingMode::isExplicitDraftTokens) + .def("isEagle", &tle::DecodingMode::isEagle) + .def("useVariableBeamWidthSearch", &tle::DecodingMode::useVariableBeamWidthSearch) + .def_prop_ro("name", &tle::DecodingMode::getName) + .def("__getstate__", decodingModeGetstate) + .def("__setstate__", decodingModeSetstate); + + nb::enum_(m, "CapacitySchedulerPolicy") + .value("MAX_UTILIZATION", tle::CapacitySchedulerPolicy::kMAX_UTILIZATION) + .value("GUARANTEED_NO_EVICT", tle::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT) + .value("STATIC_BATCH", tle::CapacitySchedulerPolicy::kSTATIC_BATCH); + + nb::enum_(m, "ContextChunkingPolicy") + .value("EQUAL_PROGRESS", tle::ContextChunkingPolicy::kEQUAL_PROGRESS) + .value("FIRST_COME_FIRST_SERVED", tle::ContextChunkingPolicy::kFIRST_COME_FIRST_SERVED); + + nb::enum_(m, "CommunicationType").value("MPI", tle::CommunicationType::kMPI); + + nb::enum_(m, "CommunicationMode") + .value("LEADER", tle::CommunicationMode::kLEADER) + .value("ORCHESTRATOR", tle::CommunicationMode::kORCHESTRATOR); + + nb::class_(m, "KvCacheStats") + .def(nb::init<>()) + .def_rw("max_num_blocks", &tle::KvCacheStats::maxNumBlocks) + .def_rw("free_num_blocks", &tle::KvCacheStats::freeNumBlocks) + .def_rw("used_num_blocks", &tle::KvCacheStats::usedNumBlocks) + 
.def_rw("tokens_per_block", &tle::KvCacheStats::tokensPerBlock) + .def_rw("alloc_total_blocks", &tle::KvCacheStats::allocTotalBlocks) + .def_rw("alloc_new_blocks", &tle::KvCacheStats::allocNewBlocks) + .def_rw("reused_blocks", &tle::KvCacheStats::reusedBlocks) + .def_rw("missed_blocks", &tle::KvCacheStats::missedBlocks) + .def_rw("cache_hit_rate", &tle::KvCacheStats::cacheHitRate); + + nb::class_(m, "StaticBatchingStats") + .def(nb::init<>()) + .def_rw("num_scheduled_requests", &tle::StaticBatchingStats::numScheduledRequests) + .def_rw("num_context_requests", &tle::StaticBatchingStats::numContextRequests) + .def_rw("num_ctx_tokens", &tle::StaticBatchingStats::numCtxTokens) + .def_rw("num_gen_tokens", &tle::StaticBatchingStats::numGenTokens) + .def_rw("empty_gen_slots", &tle::StaticBatchingStats::emptyGenSlots); + + nb::class_(m, "InflightBatchingStats") + .def(nb::init<>()) + .def_rw("num_scheduled_requests", &tle::InflightBatchingStats::numScheduledRequests) + .def_rw("num_context_requests", &tle::InflightBatchingStats::numContextRequests) + .def_rw("num_gen_requests", &tle::InflightBatchingStats::numGenRequests) + .def_rw("num_paused_requests", &tle::InflightBatchingStats::numPausedRequests) + .def_rw("num_ctx_tokens", &tle::InflightBatchingStats::numCtxTokens) + .def_rw("micro_batch_id", &tle::InflightBatchingStats::microBatchId) + .def_rw("avg_num_decoded_tokens_per_iter", &tle::InflightBatchingStats::avgNumDecodedTokensPerIter); + + nb::class_(m, "SpecDecodingStats") + .def(nb::init<>()) + .def_rw("num_draft_tokens", &tle::SpecDecodingStats::numDraftTokens) + .def_rw("num_accepted_tokens", &tle::SpecDecodingStats::numAcceptedTokens) + .def_rw("num_requests_with_draft_tokens", &tle::SpecDecodingStats::numRequestsWithDraftTokens) + .def_rw("acceptance_length", &tle::SpecDecodingStats::acceptanceLength) + .def_rw("iter_latency_ms", &tle::SpecDecodingStats::iterLatencyMS) + .def_rw("draft_overhead", &tle::SpecDecodingStats::draftOverhead); + + nb::class_(m, "IterationStats") + .def(nb::init<>()) + .def_rw("timestamp", &tle::IterationStats::timestamp) + .def_rw("iter", &tle::IterationStats::iter) + .def_rw("iter_latency_ms", &tle::IterationStats::iterLatencyMS) + .def_rw("new_active_requests_queue_latency_ms", &tle::IterationStats::newActiveRequestsQueueLatencyMS) + .def_rw("num_new_active_requests", &tle::IterationStats::numNewActiveRequests) + .def_rw("num_active_requests", &tle::IterationStats::numActiveRequests) + .def_rw("num_queued_requests", &tle::IterationStats::numQueuedRequests) + .def_rw("num_completed_requests", &tle::IterationStats::numCompletedRequests) + .def_rw("max_num_active_requests", &tle::IterationStats::maxNumActiveRequests) + .def_rw("gpu_mem_usage", &tle::IterationStats::gpuMemUsage) + .def_rw("cpu_mem_usage", &tle::IterationStats::cpuMemUsage) + .def_rw("pinned_mem_usage", &tle::IterationStats::pinnedMemUsage) + .def_rw("kv_cache_stats", &tle::IterationStats::kvCacheStats) + .def_rw("cross_kv_cache_stats", &tle::IterationStats::crossKvCacheStats) + .def_rw("static_batching_stats", &tle::IterationStats::staticBatchingStats) + .def_rw("inflight_batching_stats", &tle::IterationStats::inflightBatchingStats) + .def_rw("specdec_stats", &tle::IterationStats::specDecodingStats) + .def("to_json_str", + [](tle::IterationStats const& iterationStats) + { return tle::JsonSerialization::toJsonStr(iterationStats); }); + + nb::class_(m, "DebugTensorsPerIteration") + .def(nb::init<>()) + .def_rw("iter", &tle::DebugTensorsPerIteration::iter) + .def_rw("debug_tensors", 
&tle::DebugTensorsPerIteration::debugTensors); + + nb::enum_(m, "RequestStage") + .value("QUEUED", tle::RequestStage::kQUEUED) + .value("ENCODER_IN_PROGRESS", tle::RequestStage::kENCODER_IN_PROGRESS) + .value("CONTEXT_IN_PROGRESS", tle::RequestStage::kCONTEXT_IN_PROGRESS) + .value("GENERATION_IN_PROGRESS", tle::RequestStage::kGENERATION_IN_PROGRESS) + .value("GENERATION_COMPLETE", tle::RequestStage::kGENERATION_COMPLETE); + + nb::class_(m, "DisServingRequestStats") + .def(nb::init<>()) + .def_rw("kv_cache_transfer_ms", &tle::DisServingRequestStats::kvCacheTransferMS) + .def_rw("kv_cache_size", &tle::DisServingRequestStats::kvCacheSize); + + nb::class_(m, "RequestStats") + .def(nb::init<>()) + .def_rw("id", &tle::RequestStats::id) + .def_rw("stage", &tle::RequestStats::stage) + .def_rw("context_prefill_position", &tle::RequestStats::contextPrefillPosition) + .def_rw("num_generated_tokens", &tle::RequestStats::numGeneratedTokens) + .def_rw("avg_num_decoded_tokens_per_iter", &tle::RequestStats::avgNumDecodedTokensPerIter) + .def_rw("scheduled", &tle::RequestStats::scheduled) + .def_rw("paused", &tle::RequestStats::paused) + .def_rw("dis_serving_stats", &tle::RequestStats::disServingStats) + .def_rw("alloc_total_blocks_per_request", &tle::RequestStats::allocTotalBlocksPerRequest) + .def_rw("alloc_new_blocks_per_request", &tle::RequestStats::allocNewBlocksPerRequest) + .def_rw("reused_blocks_per_request", &tle::RequestStats::reusedBlocksPerRequest) + .def_rw("missed_blocks_per_request", &tle::RequestStats::missedBlocksPerRequest) + .def_rw("kv_cache_hit_rate_per_request", &tle::RequestStats::kvCacheHitRatePerRequest) + .def("to_json_str", + [](tle::RequestStats const& iterationStats) { return tle::JsonSerialization::toJsonStr(iterationStats); }); + + nb::class_(m, "RequestStatsPerIteration") + .def(nb::init<>()) + .def_rw("iter", &tle::RequestStatsPerIteration::iter) + .def_rw("request_stats", &tle::RequestStatsPerIteration::requestStats) + .def("to_json_str", + [](tle::RequestStatsPerIteration const& iterationStats) + { return tle::JsonSerialization::toJsonStr(iterationStats); }); + + nb::module_ executor_kv_cache = m.def_submodule("kv_cache", "Executor KV Cache Manager"); + + nb::class_(executor_kv_cache, "KVCacheCreatedData") + .def_ro("num_blocks_per_cache_level", &tle::KVCacheCreatedData::numBlocksPerCacheLevel); + + nb::class_(executor_kv_cache, "UniqueToken") + .def_ro("token_id", &tensorrt_llm::runtime::UniqueToken::tokenId) + .def_ro("token_extra_id", &tensorrt_llm::runtime::UniqueToken::tokenExtraId); + + nb::class_(executor_kv_cache, "KVCacheStoredBlockData") + .def_ro("block_hash", &tle::KVCacheStoredBlockData::blockHash) + .def_ro("tokens", &tle::KVCacheStoredBlockData::tokens) + .def_ro("lora_id", &tle::KVCacheStoredBlockData::loraId) + .def_ro("cache_level", &tle::KVCacheStoredBlockData::cacheLevel) + .def_ro("priority", &tle::KVCacheStoredBlockData::priority); + + nb::class_(executor_kv_cache, "KVCacheStoredData") + .def_ro("parent_hash", &tle::KVCacheStoredData::parentHash) + .def_ro("blocks", &tle::KVCacheStoredData::blocks); + + nb::class_(executor_kv_cache, "KVCacheRemovedData") + .def_ro("block_hashes", &tle::KVCacheRemovedData::blockHashes); + + instantiateEventDiff(executor_kv_cache, "Int"); + + nb::class_(executor_kv_cache, "KVCacheUpdatedData") + .def_ro("block_hash", &tle::KVCacheUpdatedData::blockHash) + .def_ro("cache_level", &tle::KVCacheUpdatedData::cacheLevel) + .def_ro("priority", &tle::KVCacheUpdatedData::priority); + + nb::class_(executor_kv_cache, 
"KVCacheEvent") + .def_ro("event_id", &tle::KVCacheEvent::eventId) + .def_ro("data", &tle::KVCacheEvent::data) + .def_ro("window_size", &tle::KVCacheEvent::windowSize); + + nb::class_(executor_kv_cache, "KVCacheEventManager") + .def( + "get_latest_events", + [](tle::KVCacheEventManager& self, std::optional timeout_ms = std::nullopt) + { + if (timeout_ms) + { + return self.getLatestEvents(std::chrono::milliseconds(static_cast(*timeout_ms))); + } + return self.getLatestEvents(std::nullopt); + }, + nb::arg("timeout_ms") = std::nullopt); + + tensorrt_llm::nanobind::executor::initRequestBindings(m); + tensorrt_llm::nanobind::executor::initConfigBindings(m); + tensorrt_llm::nanobind::executor::Executor::initBindings(m); +} + +} // namespace tensorrt_llm::nanobind::executor diff --git a/cpp/tensorrt_llm/nanobind/executor/bindings.h b/cpp/tensorrt_llm/nanobind/executor/bindings.h new file mode 100644 index 00000000000..4df52c2d34e --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/executor/bindings.h @@ -0,0 +1,29 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +namespace nb = nanobind; + +namespace tensorrt_llm::nanobind::executor +{ + +// Register bindings for executor API. +void initBindings(nb::module_& m); + +} // namespace tensorrt_llm::nanobind::executor diff --git a/cpp/tensorrt_llm/nanobind/executor/executor.cpp b/cpp/tensorrt_llm/nanobind/executor/executor.cpp new file mode 100644 index 00000000000..59c7d2a3dc1 --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/executor/executor.cpp @@ -0,0 +1,241 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "executor.h" +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/executor/tensor.h" +#include "tensorrt_llm/nanobind/common/customCasters.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace nb = nanobind; +namespace tle = tensorrt_llm::executor; + +namespace nanobind::detail +{ + +template <> +struct dtype_traits +{ + static constexpr dlpack::dtype value{ + (uint8_t) dlpack::dtype_code::Float, // type code + 16, // size in bits + 1 // lanes (simd), usually set to 1 + }; + static constexpr auto name = const_name("float16"); +}; +} // namespace nanobind::detail + +namespace +{ +// todo: Properly support FP8 and BF16 and verify functionality +tle::Tensor numpyToTensor(nb::ndarray const& array) +{ + auto npDtype = array.dtype(); + char kind = '\0'; + switch (npDtype.code) + { + case static_cast(nb::dlpack::dtype_code::Int): + kind = 'i'; // signed integer + break; + case static_cast(nb::dlpack::dtype_code::UInt): + kind = 'u'; // unsigned integer + break; + case static_cast(nb::dlpack::dtype_code::Float): + kind = 'f'; // floating point + break; + case static_cast(nb::dlpack::dtype_code::Bfloat): + kind = 'f'; // brain floating point (treat as float kind) + break; + case static_cast(nb::dlpack::dtype_code::Complex): + kind = 'c'; // complex + break; + default: + kind = 'V'; // void/other + break; + } + tle::DataType dtype; + if (npDtype == nb::dtype()) + { + dtype = tle::DataType::kFP16; + } + else if (npDtype == nb::dtype()) + { + dtype = tle::DataType::kFP32; + } + else if (npDtype == nb::dtype()) + { + dtype = tle::DataType::kINT8; + } + else if (npDtype == nb::dtype()) + { + dtype = tle::DataType::kINT32; + } + else if (npDtype == nb::dtype()) + { + dtype = tle::DataType::kINT64; + } + else if (kind == 'V' && array.itemsize() == 1) + { + dtype = tle::DataType::kFP8; + } + else if (kind == 'V' && array.itemsize() == 2) + { + dtype = tle::DataType::kBF16; + } + else + { + TLLM_THROW("Unsupported numpy dtype."); + } + + // todo: improve the following code + std::vector dims; + dims.reserve(array.ndim()); + for (size_t i = 0; i < array.ndim(); ++i) + { + dims.push_back(static_cast(array.shape(i))); + } + tle::Shape shape(dims.data(), dims.size()); + + return tle::Tensor::of(dtype, const_cast(array.data()), shape); +} + +} // namespace + +namespace tensorrt_llm::nanobind::executor +{ + +Executor::Executor( + std::filesystem::path const& modelPath, tle::ModelType modelType, tle::ExecutorConfig const& executorConfig) +{ + mExecutor = std::make_unique(modelPath, modelType, executorConfig); +} + +Executor::Executor(std::filesystem::path const& encoderModelPath, std::filesystem::path const& decoderModelPath, + tle::ModelType modelType, tle::ExecutorConfig const& executorConfig) +{ + mExecutor = std::make_unique(encoderModelPath, decoderModelPath, modelType, executorConfig); +} + +Executor::Executor(nb::bytes const& engineBuffer, std::string const& jsonConfigStr, tle::ModelType modelType, + tle::ExecutorConfig const& executorConfig, std::optional managedWeights) +{ + uint8_t const* data = static_cast(engineBuffer.data()); + size_t size = engineBuffer.size(); + std::optional> managedWeightsMap = std::nullopt; + if (managedWeights.has_value() && !managedWeights.value().empty()) + { + managedWeightsMap = std::map(); + for (auto const& [rawName, rawArray] : managedWeights.value()) + { + std::string name = nb::cast(rawName); + nb::ndarray array = nb::cast>(rawArray); + 
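+            // numpyToTensor (defined above) appears to wrap the NumPy buffer in
+            // a tle::Tensor rather than copy it, so the Python-side arrays
+            // presumably need to outlive their use by the Executor.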
managedWeightsMap->emplace(name, numpyToTensor(array)); + } + } + mExecutor = std::make_unique( + tle::BufferView(data, size), jsonConfigStr, modelType, executorConfig, managedWeightsMap); +} + +Executor::Executor(std::string const& encoderEngineBuffer, std::string const& encoderJsonConfigStr, + std::string const& decoderEngineBuffer, std::string const& decoderJsonConfigStr, tle::ModelType modelType, + tle::ExecutorConfig const& executorConfig) +{ + uint8_t const* encoderData = reinterpret_cast(encoderEngineBuffer.data()); + size_t encoderSize = encoderEngineBuffer.size(); + uint8_t const* decoderData = reinterpret_cast(decoderEngineBuffer.data()); + size_t decoderSize = decoderEngineBuffer.size(); + mExecutor = std::make_unique(tle::BufferView(encoderData, encoderSize), encoderJsonConfigStr, + tle::BufferView(decoderData, decoderSize), decoderJsonConfigStr, modelType, executorConfig); +} + +nb::object Executor::enter() +{ + TLLM_CHECK(static_cast(mExecutor)); + return nb::cast(this); +} + +void Executor::exit( + [[maybe_unused]] nb::handle type, [[maybe_unused]] nb::handle value, [[maybe_unused]] nb::handle traceback) +{ + shutdown(); + mExecutor = nullptr; +} + +void Executor::shutdown() +{ + // NOTE: we must release the GIL here. Executor has spawned a thread for the execution loop. That thread must be + // able to do forward progress for the shutdown process to succeed. It takes the GIL during its callbacks, so + // we release it now. Note that we shouldn't do anything related to python objects after that. + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + nb::gil_scoped_release release; + mExecutor->shutdown(); + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +void Executor::initBindings(nb::module_& m) +{ + nb::class_(m, "Executor") + .def(nb::init(), + nb::arg("model_path"), nb::arg("model_type"), nb::arg("executor_config")) + .def(nb::init(), + nb::arg("encoder_model_path"), nb::arg("decoder_model_path"), nb::arg("model_type"), + nb::arg("executor_config")) + .def(nb::init(), + nb::arg("engine_buffer"), nb::arg("json_config_str"), nb::arg("model_type"), nb::arg("executor_config"), + nb::arg("managed_weights") = nb::dict()) + .def(nb::init(), + nb::arg("encoder_engine_buffer"), nb::arg("encoder_json_config_str"), nb::arg("decoder_engine_buffer"), + nb::arg("decoder_json_config_str"), nb::arg("model_type"), nb::arg("executor_config")) + .def("shutdown", &Executor::shutdown) + .def("__enter__", &Executor::enter) + .def("__exit__", &Executor::exit) + .def("enqueue_request", &Executor::enqueueRequest, nb::arg("request")) + .def("enqueue_requests", &Executor::enqueueRequests, nb::arg("requests")) + .def("await_responses", + nb::overload_cast const&>(&Executor::awaitResponses), + nb::arg("timeout") = nb::none()) + .def("await_responses", + nb::overload_cast const&>( + &Executor::awaitResponses), + nb::arg("id"), nb::arg("timeout") = nb::none()) + .def("await_responses", + nb::overload_cast const&, std::optional const&>( + &Executor::awaitResponses), + nb::arg("ids"), nb::arg("timeout") = nb::none()) + .def("get_num_responses_ready", &Executor::getNumResponsesReady, nb::arg("id") = nb::none()) + .def("cancel_request", &Executor::cancelRequest, nb::arg("id") = nb::none()) + .def("get_latest_iteration_stats", &Executor::getLatestIterationStats) + .def("get_latest_request_stats", &Executor::getLatestRequestStats) + .def("get_latest_debug_tensors", &Executor::getLatestDebugTensors) + .def("can_enqueue_requests", &Executor::canEnqueueRequests) + .def("get_kv_cache_event_manager", 
&Executor::getKVCacheEventManager); +} + +} // namespace tensorrt_llm::nanobind::executor diff --git a/cpp/tensorrt_llm/nanobind/executor/executor.h b/cpp/tensorrt_llm/nanobind/executor/executor.h new file mode 100644 index 00000000000..22c24abb4bf --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/executor/executor.h @@ -0,0 +1,129 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "tensorrt_llm/executor/executor.h" +#include "tensorrt_llm/executor/types.h" +#include + +namespace nb = nanobind; +namespace tle = tensorrt_llm::executor; + +namespace tensorrt_llm::nanobind::executor +{ + +class Executor +{ +public: + Executor( + std::filesystem::path const& modelPath, tle::ModelType modelType, tle::ExecutorConfig const& executorConfig); + + Executor(std::filesystem::path const& encoderModelPath, std::filesystem::path const& decoderModelPath, + tle::ModelType modelType, tle::ExecutorConfig const& executorConfig); + + Executor(nb::bytes const& engineBuffer, std::string const& jsonConfigStr, tle::ModelType modelType, + tle::ExecutorConfig const& executorConfig, std::optional managedWeights); + + Executor(std::string const& encoderEngineBuffer, std::string const& encoderJsonConfigStr, + std::string const& decoderEngineBuffer, std::string const& decoderJsonConfigStr, tle::ModelType modelType, + tle::ExecutorConfig const& executorConfig); + + nb::object enter(); + void exit( + [[maybe_unused]] nb::handle type, [[maybe_unused]] nb::handle value, [[maybe_unused]] nb::handle traceback); + void shutdown(); + + [[nodiscard]] tle::IdType enqueueRequest(tle::Request const& request) + { + return mExecutor->enqueueRequest(request); + } + + [[nodiscard]] std::vector enqueueRequests(std::vector const& requests) + { + return mExecutor->enqueueRequests(requests); + } + + [[nodiscard]] std::vector awaitResponses( + std::optional const& timeout = std::nullopt) + { + // Await responses blocks until a response is received. Release GIL so that it can be ran in a background + // thread. + nb::gil_scoped_release release; + return mExecutor->awaitResponses(timeout); + } + + [[nodiscard]] std::vector awaitResponses( + tle::IdType const& requestId, std::optional const& timeout = std::nullopt) + { + // Await responses blocks until a response is received. Release GIL so that it can be ran in a background + // thread. + nb::gil_scoped_release release; + return mExecutor->awaitResponses(requestId, timeout); + } + + [[nodiscard]] std::vector> awaitResponses(std::vector const& requestIds, + std::optional const& timeout = std::nullopt) + { + // Await responses blocks until a response is received. Release GIL so that it can be ran in a background + // thread. 
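+        // A sketch of the Python-side call served by this overload (keyword
+        // names follow the bindings in executor.cpp; the timedelta-to-chrono
+        // conversion is assumed to be handled by the chrono caster):
+        //   responses = executor.await_responses(ids=[id_a, id_b],
+        //                                        timeout=datetime.timedelta(seconds=1))
+        // One list of responses is returned per requested id.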
+ nb::gil_scoped_release release; + return mExecutor->awaitResponses(requestIds, timeout); + } + + [[nodiscard]] tle::SizeType32 getNumResponsesReady(std::optional const& requestId = std::nullopt) const + { + return mExecutor->getNumResponsesReady(requestId); + } + + void cancelRequest(tle::IdType requestId) + { + mExecutor->cancelRequest(requestId); + } + + std::deque getLatestIterationStats() + { + return mExecutor->getLatestIterationStats(); + } + + std::deque getLatestRequestStats() + { + return mExecutor->getLatestRequestStats(); + } + + std::deque getLatestDebugTensors() + { + return mExecutor->getLatestDebugTensors(); + } + + [[nodiscard]] bool canEnqueueRequests() const + { + return mExecutor->canEnqueueRequests(); + } + + [[nodiscard]] std::optional> getKVCacheEventManager() const + { + return mExecutor->getKVCacheEventManager(); + } + + static void initBindings(nb::module_& m); + +private: + std::unique_ptr mExecutor; +}; + +} // namespace tensorrt_llm::nanobind::executor diff --git a/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp b/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp new file mode 100644 index 00000000000..c2d9fe25dff --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp @@ -0,0 +1,616 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "executorConfig.h" +#include "tensorrt_llm/executor/executor.h" +#include "tensorrt_llm/executor/types.h" +#include "tensorrt_llm/nanobind/common/customCasters.h" +#include "tensorrt_llm/runtime/cudaStream.h" +#include "tensorrt_llm/runtime/utils/mpiUtils.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace nb = nanobind; +namespace tle = tensorrt_llm::executor; +using SizeType32 = tle::SizeType32; +using RuntimeDefaults = tensorrt_llm::runtime::RuntimeDefaults; + +namespace tensorrt_llm::nanobind::executor +{ + +void initConfigBindings(nb::module_& m) +{ + nb::enum_(m, "BatchingType") + .value("STATIC", tle::BatchingType::kSTATIC) + .value("INFLIGHT", tle::BatchingType::kINFLIGHT); + + auto dynamicBatchConfigGetstate = [](tle::DynamicBatchConfig const& self) + { + return nb::make_tuple(self.getEnableBatchSizeTuning(), self.getEnableMaxNumTokensTuning(), + self.getDynamicBatchMovingAverageWindow(), self.getBatchSizeTable()); + }; + auto dynamicBatchConfigSetstate = [](tle::DynamicBatchConfig& self, nb::tuple const& state) + { + if (state.size() != 4) + { + throw std::runtime_error("Invalid state!"); + } + new (&self) tle::DynamicBatchConfig(nb::cast(state[0]), nb::cast(state[1]), + nb::cast(state[2]), nb::cast>>(state[3])); + }; + nb::class_(m, "DynamicBatchConfig") + .def(nb::init(), nb::arg("enable_batch_size_tuning"), + nb::arg("enable_max_num_tokens_tuning"), nb::arg("dynamic_batch_moving_average_window")) + .def_prop_ro("enable_batch_size_tuning", &tle::DynamicBatchConfig::getEnableBatchSizeTuning) + .def_prop_ro("enable_max_num_tokens_tuning", &tle::DynamicBatchConfig::getEnableMaxNumTokensTuning) + .def_prop_ro( + "dynamic_batch_moving_average_window", &tle::DynamicBatchConfig::getDynamicBatchMovingAverageWindow) + .def("__getstate__", dynamicBatchConfigGetstate) + .def("__setstate__", dynamicBatchConfigSetstate); + + auto schedulerConfigSetstate = [](tle::SchedulerConfig& self, nb::tuple const& state) + { + if (state.size() != 3) + { + throw std::runtime_error("Invalid state!"); + } + new (&self) tle::SchedulerConfig(nb::cast(state[0]), + nb::cast>(state[1]), + nb::cast>(state[2])); + }; + auto schedulerConfigGetstate = [](tle::SchedulerConfig const& self) + { + return nb::make_tuple( + self.getCapacitySchedulerPolicy(), self.getContextChunkingPolicy(), self.getDynamicBatchConfig()); + }; + nb::class_(m, "SchedulerConfig") + .def(nb::init, + std::optional>(), + nb::arg("capacity_scheduler_policy") = tle::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT, + nb::arg("context_chunking_policy") = nb::none(), nb::arg("dynamic_batch_config") = nb::none()) + .def_prop_ro("capacity_scheduler_policy", &tle::SchedulerConfig::getCapacitySchedulerPolicy) + .def_prop_ro("context_chunking_policy", &tle::SchedulerConfig::getContextChunkingPolicy) + .def_prop_ro("dynamic_batch_config", &tle::SchedulerConfig::getDynamicBatchConfig) + .def("__getstate__", schedulerConfigGetstate) + .def("__setstate__", schedulerConfigSetstate); + + nb::class_(m, "RuntimeDefaults") + .def(nb::init>, std::optional>(), + nb::arg("max_attention_window") = nb::none(), nb::arg("sink_token_length") = nb::none()) + .def_ro("max_attention_window", &RuntimeDefaults::maxAttentionWindowVec) + .def_ro("sink_token_length", &RuntimeDefaults::sinkTokenLength); + + auto kvCacheConfigGetstate = [](tle::KvCacheConfig const& self) + { + return nb::make_tuple(self.getEnableBlockReuse(), self.getMaxTokens(), 
self.getMaxAttentionWindowVec(), + self.getSinkTokenLength(), self.getFreeGpuMemoryFraction(), self.getHostCacheSize(), + self.getOnboardBlocks(), self.getCrossKvCacheFraction(), self.getSecondaryOffloadMinPriority(), + self.getEventBufferMaxSize(), self.getEnablePartialReuse(), self.getCopyOnPartialReuse(), self.getUseUvm()); + }; + auto kvCacheConfigSetstate = [](tle::KvCacheConfig& self, nb::tuple const& state) + { + if (state.size() != 13) + { + throw std::runtime_error("Invalid state!"); + } + new (&self) tle::KvCacheConfig(nb::cast(state[0]), nb::cast>(state[1]), + nb::cast>>(state[2]), nb::cast>(state[3]), + nb::cast>(state[4]), nb::cast>(state[5]), + nb::cast(state[6]), nb::cast>(state[7]), + nb::cast>(state[8]), nb::cast(state[9]), + nb::cast(state[10]), nb::cast(state[11]), nb::cast(state[12])); + }; + nb::class_(m, "KvCacheConfig") + .def(nb::init const&, std::optional> const&, + std::optional const&, std::optional const&, std::optional const&, bool, + std::optional const&, std::optional, size_t const&, bool, bool, bool, + std::optional const&>(), + nb::arg("enable_block_reuse") = true, nb::arg("max_tokens") = nb::none(), + nb::arg("max_attention_window") = nb::none(), nb::arg("sink_token_length") = nb::none(), + nb::arg("free_gpu_memory_fraction") = nb::none(), nb::arg("host_cache_size") = nb::none(), + nb::arg("onboard_blocks") = true, nb::arg("cross_kv_cache_fraction") = nb::none(), + nb::arg("secondary_offload_min_priority") = nb::none(), nb::arg("event_buffer_max_size") = 0, nb::kw_only(), + nb::arg("enable_partial_reuse") = true, nb::arg("copy_on_partial_reuse") = true, nb::arg("use_uvm") = false, + nb::arg("runtime_defaults") = nb::none()) + .def_prop_rw( + "enable_block_reuse", &tle::KvCacheConfig::getEnableBlockReuse, &tle::KvCacheConfig::setEnableBlockReuse) + .def_prop_rw("max_tokens", &tle::KvCacheConfig::getMaxTokens, &tle::KvCacheConfig::setMaxTokens) + .def_prop_rw("max_attention_window", &tle::KvCacheConfig::getMaxAttentionWindowVec, + &tle::KvCacheConfig::setMaxAttentionWindowVec) + .def_prop_rw( + "sink_token_length", &tle::KvCacheConfig::getSinkTokenLength, &tle::KvCacheConfig::setSinkTokenLength) + .def_prop_rw("free_gpu_memory_fraction", &tle::KvCacheConfig::getFreeGpuMemoryFraction, + &tle::KvCacheConfig::setFreeGpuMemoryFraction) + .def_prop_rw("host_cache_size", &tle::KvCacheConfig::getHostCacheSize, &tle::KvCacheConfig::setHostCacheSize) + .def_prop_rw("onboard_blocks", &tle::KvCacheConfig::getOnboardBlocks, &tle::KvCacheConfig::setOnboardBlocks) + .def_prop_rw("cross_kv_cache_fraction", &tle::KvCacheConfig::getCrossKvCacheFraction, + &tle::KvCacheConfig::setCrossKvCacheFraction) + .def_prop_rw("secondary_offload_min_priority", &tle::KvCacheConfig::getSecondaryOffloadMinPriority, + &tle::KvCacheConfig::setSecondaryOffloadMinPriority) + .def_prop_rw("event_buffer_max_size", &tle::KvCacheConfig::getEventBufferMaxSize, + &tle::KvCacheConfig::setEventBufferMaxSize) + .def_prop_rw("enable_partial_reuse", &tle::KvCacheConfig::getEnablePartialReuse, + &tle::KvCacheConfig::setEnablePartialReuse) + .def_prop_rw("copy_on_partial_reuse", &tle::KvCacheConfig::getCopyOnPartialReuse, + &tle::KvCacheConfig::setCopyOnPartialReuse) + .def_prop_rw("use_uvm", &tle::KvCacheConfig::getUseUvm, &tle::KvCacheConfig::setUseUvm) + .def("fill_empty_fields_from_runtime_defaults", &tle::KvCacheConfig::fillEmptyFieldsFromRuntimeDefaults) + .def("__getstate__", kvCacheConfigGetstate) + .def("__setstate__", kvCacheConfigSetstate); + + nb::class_(m, "OrchestratorConfig") + 
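+        // Thin wrapper over tle::OrchestratorConfig: the properties below expose
+        // whether this process acts as the orchestrator, the worker executable to
+        // launch, an optional orchestrator-leader communicator, and whether the
+        // orchestrator spawns the worker processes itself.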
.def(nb::init, bool>(), nb::arg("is_orchestrator") = true, + nb::arg("worker_executable_path") = "", nb::arg("orch_leader_comm").none() = nullptr, + nb::arg("spawn_processes") = true) + .def_prop_rw( + "is_orchestrator", &tle::OrchestratorConfig::getIsOrchestrator, &tle::OrchestratorConfig::setIsOrchestrator) + .def_prop_rw("worker_executable_path", &tle::OrchestratorConfig::getWorkerExecutablePath, + &tle::OrchestratorConfig::setWorkerExecutablePath) + .def_prop_rw("orch_leader_comm", &tle::OrchestratorConfig::getOrchLeaderComm, + &tle::OrchestratorConfig::setOrchLeaderComm) + .def_prop_rw("spawn_processes", &tle::OrchestratorConfig::getSpawnProcesses, + &tle::OrchestratorConfig::setSpawnProcesses); + + auto parallelConfigGetstate = [](tle::ParallelConfig const& self) + { + return nb::make_tuple(self.getCommunicationType(), self.getCommunicationMode(), self.getDeviceIds(), + self.getParticipantIds(), self.getOrchestratorConfig(), self.getNumNodes()); + }; + auto parallelConfigSetstate = [](tle::ParallelConfig& self, nb::tuple const& state) + { + if (state.size() != 6) + { + throw std::runtime_error("Invalid state!"); + } + new (&self) tle::ParallelConfig(nb::cast(state[0]), + nb::cast(state[1]), nb::cast>>(state[2]), + nb::cast>>(state[3]), + nb::cast>(state[4]), nb::cast>(state[5])); + }; + nb::class_(m, "ParallelConfig") + .def(nb::init> const&, + std::optional> const&, std::optional const&, + std::optional const&>(), + nb::arg("communication_type") = tle::CommunicationType::kMPI, + nb::arg("communication_mode") = tle::CommunicationMode::kLEADER, nb::arg("device_ids") = nb::none(), + nb::arg("participant_ids") = nb::none(), nb::arg("orchestrator_config") = nb::none(), + nb::arg("num_nodes") = nb::none()) + .def_prop_rw("communication_type", &tle::ParallelConfig::getCommunicationType, + &tle::ParallelConfig::setCommunicationType) + .def_prop_rw("communication_mode", &tle::ParallelConfig::getCommunicationMode, + &tle::ParallelConfig::setCommunicationMode) + .def_prop_rw("device_ids", &tle::ParallelConfig::getDeviceIds, &tle::ParallelConfig::setDeviceIds) + .def_prop_rw( + "participant_ids", &tle::ParallelConfig::getParticipantIds, &tle::ParallelConfig::setParticipantIds) + .def_prop_rw("orchestrator_config", &tle::ParallelConfig::getOrchestratorConfig, + &tle::ParallelConfig::setOrchestratorConfig) + .def_prop_rw("num_nodes", &tle::ParallelConfig::getNumNodes, &tle::ParallelConfig::setNumNodes) + .def("__getstate__", parallelConfigGetstate) + .def("__setstate__", parallelConfigSetstate); + + auto peftCacheConfigSetstate = [](tle::PeftCacheConfig& self, nb::tuple const& state) + { + if (state.size() != 11) + { + throw std::runtime_error("Invalid state!"); + } + new (&self) tle::PeftCacheConfig(nb::cast(state[0]), nb::cast(state[1]), + nb::cast(state[2]), nb::cast(state[3]), nb::cast(state[4]), + nb::cast(state[5]), nb::cast(state[6]), nb::cast(state[7]), + nb::cast(state[8]), nb::cast>(state[9]), + nb::cast>(state[10])); + }; + auto peftCacheConfigGetstate = [](tle::PeftCacheConfig const& self) + { + return nb::make_tuple(self.getNumHostModuleLayer(), self.getNumDeviceModuleLayer(), + self.getOptimalAdapterSize(), self.getMaxAdapterSize(), self.getNumPutWorkers(), self.getNumEnsureWorkers(), + self.getNumCopyStreams(), self.getMaxPagesPerBlockHost(), self.getMaxPagesPerBlockDevice(), + self.getDeviceCachePercent(), self.getHostCacheSize()); + }; + nb::class_(m, "PeftCacheConfig") + .def(nb::init const&, std::optional const&, + std::optional const&>(), + nb::arg("num_host_module_layer") = 
0, nb::arg("num_device_module_layer") = 0, + nb::arg("optimal_adapter_size") = 8, nb::arg("max_adapter_size") = 64, nb::arg("num_put_workers") = 1, + nb::arg("num_ensure_workers") = 1, nb::arg("num_copy_streams") = 1, + nb::arg("max_pages_per_block_host") = 24, nb::arg("max_pages_per_block_device") = 8, + nb::arg("device_cache_percent") = nb::none(), nb::arg("host_cache_size") = nb::none(), + nb::arg("lora_prefetch_dir") = nb::none()) + .def_prop_ro("num_host_module_layer", &tle::PeftCacheConfig::getNumHostModuleLayer) + .def_prop_ro("num_device_module_layer", &tle::PeftCacheConfig::getNumDeviceModuleLayer) + .def_prop_ro("optimal_adapter_size", &tle::PeftCacheConfig::getOptimalAdapterSize) + .def_prop_ro("max_adapter_size", &tle::PeftCacheConfig::getMaxAdapterSize) + .def_prop_ro("num_put_workers", &tle::PeftCacheConfig::getNumPutWorkers) + .def_prop_ro("num_ensure_workers", &tle::PeftCacheConfig::getNumEnsureWorkers) + .def_prop_ro("num_copy_streams", &tle::PeftCacheConfig::getNumCopyStreams) + .def_prop_ro("max_pages_per_block_host", &tle::PeftCacheConfig::getMaxPagesPerBlockHost) + .def_prop_ro("max_pages_per_block_device", &tle::PeftCacheConfig::getMaxPagesPerBlockDevice) + .def_prop_ro("device_cache_percent", &tle::PeftCacheConfig::getDeviceCachePercent) + .def_prop_ro("host_cache_size", &tle::PeftCacheConfig::getHostCacheSize) + .def_prop_ro("lora_prefetch_dir", &tle::PeftCacheConfig::getLoraPrefetchDir) + .def("__getstate__", peftCacheConfigGetstate) + .def("__setstate__", peftCacheConfigSetstate); + + auto decodingConfigGetstate = [](tle::DecodingConfig const& self) + { + return nb::make_tuple( + self.getDecodingMode(), self.getLookaheadDecodingConfig(), self.getMedusaChoices(), self.getEagleConfig()); + }; + auto decodingConfigSetstate = [](tle::DecodingConfig& self, nb::tuple const& state) + { + if (state.size() != 4) + { + throw std::runtime_error("Invalid state!"); + } + new (&self) tle::DecodingConfig(nb::cast>(state[0]), // DecodingMode + nb::cast>(state[1]), // LookaheadDecodingConfig + nb::cast>(state[2]), // MedusaChoices + nb::cast>(state[3]) // EagleConfig + ); + }; + nb::class_(m, "DecodingConfig") + .def(nb::init, std::optional, + std::optional, std::optional>(), + nb::arg("decoding_mode") = nb::none(), nb::arg("lookahead_decoding_config") = nb::none(), + nb::arg("medusa_choices") = nb::none(), nb::arg("eagle_config") = nb::none()) + .def_prop_rw("decoding_mode", &tle::DecodingConfig::getDecodingMode, &tle::DecodingConfig::setDecodingMode) + .def_prop_rw("lookahead_decoding_config", &tle::DecodingConfig::getLookaheadDecodingConfig, + &tle::DecodingConfig::setLookaheadDecodingConfig) + .def_prop_rw("medusa_choices", &tle::DecodingConfig::getMedusaChoices, &tle::DecodingConfig::setMedusaChoices) + .def_prop_rw("eagle_config", &tle::DecodingConfig::getEagleConfig, &tle::DecodingConfig::setEagleConfig) + .def("__getstate__", decodingConfigGetstate) + .def("__setstate__", decodingConfigSetstate); + + auto debugConfigGetstate = [](tle::DebugConfig const& self) + { + return nb::make_tuple(self.getDebugInputTensors(), self.getDebugOutputTensors(), self.getDebugTensorNames(), + self.getDebugTensorsMaxIterations()); + }; + auto debugConfigSetstate = [](tle::DebugConfig& self, nb::tuple const& state) + { + if (state.size() != 4) + { + throw std::runtime_error("Invalid state!"); + } + new (&self) tle::DebugConfig(nb::cast(state[0]), nb::cast(state[1]), + nb::cast>(state[2]), nb::cast(state[3])); + }; + nb::class_(m, "DebugConfig") + .def(nb::init, SizeType32>(), 
nb::arg("debug_input_tensors") = false, + nb::arg("debug_output_tensors") = false, nb::arg("debug_tensor_names") = nb::none(), + nb::arg("debug_tensors_max_iterations") = false) + .def_prop_rw( + "debug_input_tensors", &tle::DebugConfig::getDebugInputTensors, &tle::DebugConfig::setDebugInputTensors) + .def_prop_rw( + "debug_output_tensors", &tle::DebugConfig::getDebugOutputTensors, &tle::DebugConfig::setDebugOutputTensors) + .def_prop_rw( + "debug_tensor_names", &tle::DebugConfig::getDebugTensorNames, &tle::DebugConfig::setDebugTensorNames) + .def_prop_rw("debug_tensors_max_iterations", &tle::DebugConfig::getDebugTensorsMaxIterations, + &tle::DebugConfig::setDebugTensorsMaxIterations) + .def("__getstate__", debugConfigGetstate) + .def("__setstate__", debugConfigSetstate); + + auto logitsPostProcessorConfigGetstate = [](tle::LogitsPostProcessorConfig const& self) + { return nb::make_tuple(self.getProcessorMap(), self.getProcessorBatched(), self.getReplicate()); }; + + auto logitsPostProcessorConfigSetstate = [](tle::LogitsPostProcessorConfig& self, nb::tuple const& state) + { + if (state.size() != 3) + { + throw std::runtime_error("Invalid LogitsPostProcessorConfig state!"); + } + new (&self) tle::LogitsPostProcessorConfig(nb::cast>(state[0]), + nb::cast>(state[1]), nb::cast(state[2])); + }; + + nb::class_(m, "LogitsPostProcessorConfig") + .def(nb::init, std::optional, + bool>(), + nb::arg("processor_map") = nb::none(), nb::arg("processor_batched") = nb::none(), + nb::arg("replicate") = true) + .def_prop_rw("processor_map", &tle::LogitsPostProcessorConfig::getProcessorMap, + &tle::LogitsPostProcessorConfig::setProcessorMap) + .def_prop_rw("processor_batched", &tle::LogitsPostProcessorConfig::getProcessorBatched, + &tle::LogitsPostProcessorConfig::setProcessorBatched) + .def_prop_rw( + "replicate", &tle::LogitsPostProcessorConfig::getReplicate, &tle::LogitsPostProcessorConfig::setReplicate) + .def("__getstate__", logitsPostProcessorConfigGetstate) + .def("__setstate__", logitsPostProcessorConfigSetstate); + + auto extendedRuntimePerfKnobConfigSetstate = [](tle::ExtendedRuntimePerfKnobConfig& self, nb::tuple const& state) + { + if (state.size() != 4) + { + throw std::runtime_error("Invalid extendedRuntimePerfKnobConfig state!"); + } + new (&self) tle::ExtendedRuntimePerfKnobConfig(nb::cast(state[0]), nb::cast(state[1]), + nb::cast(state[2]), nb::cast(state[2])); + }; + auto extendedRuntimePerfKnobConfigGetstate = [](tle::ExtendedRuntimePerfKnobConfig const& self) + { + return nb::make_tuple(self.getMultiBlockMode(), self.getEnableContextFMHAFP32Acc(), self.getCudaGraphMode(), + self.getCudaGraphCacheSize()); + }; + nb::class_(m, "ExtendedRuntimePerfKnobConfig") + .def( + nb::init(), nb::arg("multi_block_mode") = true, nb::arg("enable_context_fmha_fp32_acc") = false) + .def_prop_rw("multi_block_mode", &tle::ExtendedRuntimePerfKnobConfig::getMultiBlockMode, + &tle::ExtendedRuntimePerfKnobConfig::setMultiBlockMode) + .def_prop_rw("enable_context_fmha_fp32_acc", &tle::ExtendedRuntimePerfKnobConfig::getEnableContextFMHAFP32Acc, + &tle::ExtendedRuntimePerfKnobConfig::setEnableContextFMHAFP32Acc) + .def_prop_rw("cuda_graph_mode", &tle::ExtendedRuntimePerfKnobConfig::getCudaGraphMode, + &tle::ExtendedRuntimePerfKnobConfig::setCudaGraphMode) + .def_prop_rw("cuda_graph_cache_size", &tle::ExtendedRuntimePerfKnobConfig::getCudaGraphCacheSize, + &tle::ExtendedRuntimePerfKnobConfig::setCudaGraphCacheSize) + .def("__getstate__", extendedRuntimePerfKnobConfigGetstate) + .def("__setstate__", 
extendedRuntimePerfKnobConfigSetstate); + + auto SpeculativeDecodingConfigGetState + = [](tle::SpeculativeDecodingConfig const& self) { return nb::make_tuple(self.fastLogits); }; + auto SpeculativeDecodingConfigSetState = [](tle::SpeculativeDecodingConfig& self, nb::tuple const& state) + { + if (state.size() != 1) + { + throw std::runtime_error("Invalid SpeculativeDecodingConfig state!"); + } + new (&self) tle::SpeculativeDecodingConfig(nb::cast(state[0])); + }; + nb::class_(m, "SpeculativeDecodingConfig") + .def(nb::init(), nb::arg("fast_logits") = false) + .def_rw("fast_logits", &tle::SpeculativeDecodingConfig::fastLogits) + .def("__getstate__", SpeculativeDecodingConfigGetState) + .def("__setstate__", SpeculativeDecodingConfigSetState); + + // Guided decoding config + auto pyGuidedDecodingConfig = nb::class_(m, "GuidedDecodingConfig"); + + nb::enum_(pyGuidedDecodingConfig, "GuidedDecodingBackend") + .value("XGRAMMAR", tle::GuidedDecodingConfig::GuidedDecodingBackend::kXGRAMMAR) + .value("LLGUIDANCE", tle::GuidedDecodingConfig::GuidedDecodingBackend::kLLGUIDANCE); + + auto guidedDecodingConfigGetstate = [](tle::GuidedDecodingConfig const& self) { + return nb::make_tuple( + self.getBackend(), self.getEncodedVocab(), self.getTokenizerStr(), self.getStopTokenIds()); + }; + auto guidedDecodingConfigSetstate = [](tle::GuidedDecodingConfig& self, nb::tuple state) + { + if (state.size() != 4) + { + throw std::runtime_error("Invalid GuidedDecodingConfig state!"); + } + new (&self) tle::GuidedDecodingConfig(nb::cast(state[0]), + nb::cast>>(state[1]), nb::cast>(state[2]), + nb::cast>>(state[3])); + }; + + pyGuidedDecodingConfig + .def(nb::init>, + std::optional, std::optional>>(), + nb::arg("backend"), nb::arg("encoded_vocab") = nb::none(), nb::arg("tokenizer_str") = nb::none(), + nb::arg("stop_token_ids") = nb::none()) + .def_prop_rw("backend", &tle::GuidedDecodingConfig::getBackend, &tle::GuidedDecodingConfig::setBackend) + .def_prop_rw( + "encoded_vocab", &tle::GuidedDecodingConfig::getEncodedVocab, &tle::GuidedDecodingConfig::setEncodedVocab) + .def_prop_rw( + "tokenizer_str", &tle::GuidedDecodingConfig::getTokenizerStr, &tle::GuidedDecodingConfig::setTokenizerStr) + .def_prop_rw( + "stop_token_ids", &tle::GuidedDecodingConfig::getStopTokenIds, &tle::GuidedDecodingConfig::setStopTokenIds) + .def("__getstate__", guidedDecodingConfigGetstate) + .def("__setstate__", guidedDecodingConfigSetstate); + + auto cacheTransceiverConfigGetstate + = [](tle::CacheTransceiverConfig const& self) { return nb::make_tuple(self.getMaxNumTokens()); }; + auto cacheTransceiverConfigSetstate = [](tle::CacheTransceiverConfig& self, nb::tuple const& state) + { + if (state.size() != 1) + { + throw std::runtime_error("Invalid CacheTransceiverConfig state!"); + } + new (&self) tle::CacheTransceiverConfig(nb::cast>(state[0])); + }; + + nb::class_(m, "CacheTransceiverConfig") + .def(nb::init>(), nb::arg("max_num_tokens") = nb::none()) + .def_prop_rw("max_num_tokens", &tle::CacheTransceiverConfig::getMaxNumTokens, + &tle::CacheTransceiverConfig::setMaxNumTokens) + .def("__getstate__", cacheTransceiverConfigGetstate) + .def("__setstate__", cacheTransceiverConfigSetstate); + + auto executorConfigGetState = [](nb::object const& self) + { + auto& c = nb::cast(self); + // Return a tuple containing C++ data and the Python __dict__ + auto cpp_states = nb::make_tuple(c.getMaxBeamWidth(), c.getSchedulerConfig(), c.getKvCacheConfig(), + c.getEnableChunkedContext(), c.getNormalizeLogProbs(), c.getIterStatsMaxIterations(), + 
c.getRequestStatsMaxIterations(), c.getBatchingType(), c.getMaxBatchSize(), c.getMaxNumTokens(), + c.getParallelConfig(), c.getPeftCacheConfig(), c.getLogitsPostProcessorConfig(), c.getDecodingConfig(), + c.getUseGpuDirectStorage(), c.getGpuWeightsPercent(), c.getMaxQueueSize(), + c.getExtendedRuntimePerfKnobConfig(), c.getDebugConfig(), c.getRecvPollPeriodMs(), + c.getMaxSeqIdleMicroseconds(), c.getSpecDecConfig(), c.getGuidedDecodingConfig(), + c.getAdditionalModelOutputs(), c.getCacheTransceiverConfig(), c.getGatherGenerationLogits(), + c.getPromptTableOffloading(), c.getEnableTrtOverlap()); + auto pickle_tuple = nb::make_tuple(cpp_states, nb::getattr(self, "__dict__")); + return pickle_tuple; + }; + + auto executorConfigSetState = [](nb::object self, nb::tuple const& state) + { + if (state.size() != 2) + { + throw std::runtime_error("Invalid state!"); + } + + auto cpp_states = nb::cast(state[0]); + if (cpp_states.size() != 28) + { + throw std::runtime_error("Invalid cpp_states!"); + } + + // Restore C++ data + tle::ExecutorConfig* cpp_self = nb::inst_ptr(self); + new (cpp_self) tle::ExecutorConfig( // + nb::cast(cpp_states[0]), // MaxBeamWidth + nb::cast(cpp_states[1]), // SchedulerConfig + nb::cast(cpp_states[2]), // KvCacheConfig + nb::cast(cpp_states[3]), // EnableChunkedContext + nb::cast(cpp_states[4]), // NormalizeLogProbs + nb::cast(cpp_states[5]), // IterStatsMaxIterations + nb::cast(cpp_states[6]), // RequestStatsMaxIterations + nb::cast(cpp_states[7]), // BatchingType + nb::cast>(cpp_states[8]), // MaxBatchSize + nb::cast>(cpp_states[9]), // MaxNumTokens + nb::cast>(cpp_states[10]), // ParallelConfig + nb::cast>(cpp_states[11]), // PeftCacheConfig + nb::cast>(cpp_states[12]), // LogitsPostProcessorConfig + nb::cast>(cpp_states[13]), // DecodingConfig + nb::cast(cpp_states[14]), // UseGpuDirectStorage + nb::cast(cpp_states[15]), // GpuWeightsPercent + nb::cast>(cpp_states[16]), // MaxQueueSize + nb::cast(cpp_states[17]), // ExtendedRuntimePerfKnobConfig + nb::cast>(cpp_states[18]), // DebugConfig + nb::cast(cpp_states[19]), // RecvPollPeriodMs + nb::cast(cpp_states[20]), // MaxSeqIdleMicroseconds + nb::cast>(cpp_states[21]), // SpecDecConfig + nb::cast>(cpp_states[22]), // GuidedDecodingConfig + nb::cast>>(cpp_states[23]), // AdditionalModelOutputs + nb::cast>(cpp_states[24]), // CacheTransceiverConfig + nb::cast(cpp_states[25]), // GatherGenerationLogits + nb::cast(cpp_states[26]), // PromptTableOffloading + nb::cast(cpp_states[27]) // EnableTrtOverlap + ); + + // Restore Python data + auto py_state = nb::cast(state[1]); + self.attr("__dict__").attr("update")(py_state); + + nb::inst_mark_ready(self); + }; + + nb::class_(m, "ExecutorConfig", nb::dynamic_attr()) + .def(nb::init< // + SizeType32, // MaxBeamWidth + tle::SchedulerConfig const&, // SchedulerConfig + tle::KvCacheConfig const&, // KvCacheConfig + bool, // EnableChunkedContext + bool, // NormalizeLogProbs + SizeType32, // IterStatsMaxIterations + SizeType32, // RequestStatsMaxIterations + tle::BatchingType, // BatchingType + std::optional, // MaxBatchSize + std::optional, // MaxNumTokens + std::optional, // ParallelConfig + tle::PeftCacheConfig const&, // PeftCacheConfig + std::optional, // LogitsPostProcessorConfig + std::optional, // DecodingConfig + bool, // UseGpuDirectStorage + float, // GpuWeightsPercent + std::optional, // MaxQueueSize + tle::ExtendedRuntimePerfKnobConfig const&, // ExtendedRuntimePerfKnobConfig + std::optional, // DebugConfig + SizeType32, // RecvPollPeriodMs + uint64_t, // 
MaxSeqIdleMicroseconds + std::optional, // SpecDecConfig + std::optional, // GuidedDecodingConfig + std::optional>, // AdditionalModelOutputs + std::optional, // CacheTransceiverConfig + bool, // GatherGenerationLogits + bool, // PromptTableOffloading + bool // EnableTrtOverlap + >(), + nb::arg("max_beam_width") = 1, nb::arg("scheduler_config") = tle::SchedulerConfig(), + nb::arg("kv_cache_config") = tle::KvCacheConfig(), nb::arg("enable_chunked_context") = false, + nb::arg("normalize_log_probs") = true, + nb::arg("iter_stats_max_iterations") = tle::ExecutorConfig::kDefaultIterStatsMaxIterations, + nb::arg("request_stats_max_iterations") = tle::ExecutorConfig::kDefaultRequestStatsMaxIterations, + nb::arg("batching_type") = tle::BatchingType::kINFLIGHT, nb::arg("max_batch_size") = nb::none(), + nb::arg("max_num_tokens") = nb::none(), nb::arg("parallel_config") = nb::none(), + nb::arg("peft_cache_config") = tle::PeftCacheConfig(), nb::arg("logits_post_processor_config") = nb::none(), + nb::arg("decoding_config") = nb::none(), nb::arg("use_gpu_direct_storage") = false, + nb::arg("gpu_weights_percent") = 1.0, nb::arg("max_queue_size") = nb::none(), + nb::arg("extended_runtime_perf_knob_config") = tle::ExtendedRuntimePerfKnobConfig(), + nb::arg("debug_config") = nb::none(), nb::arg("recv_poll_period_ms") = 0, + nb::arg("max_seq_idle_microseconds") = tle::ExecutorConfig::kDefaultMaxSeqIdleMicroseconds, + nb::arg("spec_dec_config") = nb::none(), nb::arg("guided_decoding_config") = nb::none(), + nb::arg("additional_model_outputs") = nb::none(), nb::arg("cache_transceiver_config") = nb::none(), + nb::arg("gather_generation_logits") = false, nb::arg("mm_embedding_offloading") = false, + nb::arg("enable_trt_overlap") = false) + .def_prop_rw("max_beam_width", &tle::ExecutorConfig::getMaxBeamWidth, &tle::ExecutorConfig::setMaxBeamWidth) + .def_prop_rw("max_batch_size", &tle::ExecutorConfig::getMaxBatchSize, &tle::ExecutorConfig::setMaxBatchSize) + .def_prop_rw("max_num_tokens", &tle::ExecutorConfig::getMaxNumTokens, &tle::ExecutorConfig::setMaxNumTokens) + .def_prop_rw( + "scheduler_config", &tle::ExecutorConfig::getSchedulerConfigRef, &tle::ExecutorConfig::setSchedulerConfig) + .def_prop_rw( + "kv_cache_config", &tle::ExecutorConfig::getKvCacheConfigRef, &tle::ExecutorConfig::setKvCacheConfig) + .def_prop_rw("enable_chunked_context", &tle::ExecutorConfig::getEnableChunkedContext, + &tle::ExecutorConfig::setEnableChunkedContext) + .def_prop_rw("normalize_log_probs", &tle::ExecutorConfig::getNormalizeLogProbs, + &tle::ExecutorConfig::setNormalizeLogProbs) + .def_prop_rw("iter_stats_max_iterations", &tle::ExecutorConfig::getIterStatsMaxIterations, + &tle::ExecutorConfig::setIterStatsMaxIterations) + .def_prop_rw("request_stats_max_iterations", &tle::ExecutorConfig::getRequestStatsMaxIterations, + &tle::ExecutorConfig::setRequestStatsMaxIterations) + .def_prop_rw("batching_type", &tle::ExecutorConfig::getBatchingType, &tle::ExecutorConfig::setBatchingType) + .def_prop_rw( + "parallel_config", &tle::ExecutorConfig::getParallelConfig, &tle::ExecutorConfig::setParallelConfig) + .def_prop_rw( + "peft_cache_config", &tle::ExecutorConfig::getPeftCacheConfig, &tle::ExecutorConfig::setPeftCacheConfig) + .def_prop_rw("logits_post_processor_config", &tle::ExecutorConfig::getLogitsPostProcessorConfig, + &tle::ExecutorConfig::setLogitsPostProcessorConfig) + .def_prop_rw( + "decoding_config", &tle::ExecutorConfig::getDecodingConfig, &tle::ExecutorConfig::setDecodingConfig) + .def_prop_rw("use_gpu_direct_storage", 
&tle::ExecutorConfig::getUseGpuDirectStorage, + &tle::ExecutorConfig::setUseGpuDirectStorage) + .def_prop_rw("gpu_weights_percent", &tle::ExecutorConfig::getGpuWeightsPercent, + &tle::ExecutorConfig::setGpuWeightsPercent) + .def_prop_rw("max_queue_size", &tle::ExecutorConfig::getMaxQueueSize, &tle::ExecutorConfig::setMaxQueueSize) + .def_prop_rw("extended_runtime_perf_knob_config", &tle::ExecutorConfig::getExtendedRuntimePerfKnobConfig, + &tle::ExecutorConfig::setExtendedRuntimePerfKnobConfig) + .def_prop_rw("debug_config", &tle::ExecutorConfig::getDebugConfig, &tle::ExecutorConfig::setDebugConfig) + .def_prop_rw( + "recv_poll_period_ms", &tle::ExecutorConfig::getRecvPollPeriodMs, &tle::ExecutorConfig::setRecvPollPeriodMs) + .def_prop_rw("max_seq_idle_microseconds", &tle::ExecutorConfig::getMaxSeqIdleMicroseconds, + &tle::ExecutorConfig::setMaxSeqIdleMicroseconds) + .def_prop_rw("spec_dec_config", &tle::ExecutorConfig::getSpecDecConfig, &tle::ExecutorConfig::setSpecDecConfig) + .def_prop_rw("guided_decoding_config", &tle::ExecutorConfig::getGuidedDecodingConfig, + &tle::ExecutorConfig::setGuidedDecodingConfig) + .def_prop_rw("additional_model_outputs", &tle::ExecutorConfig::getAdditionalModelOutputs, + &tle::ExecutorConfig::setAdditionalModelOutputs) + .def_prop_rw("cache_transceiver_config", &tle::ExecutorConfig::getCacheTransceiverConfig, + &tle::ExecutorConfig::setCacheTransceiverConfig) + .def_prop_rw("gather_generation_logits", &tle::ExecutorConfig::getGatherGenerationLogits, + &tle::ExecutorConfig::setGatherGenerationLogits) + .def_prop_rw("mm_embedding_offloading", &tle::ExecutorConfig::getPromptTableOffloading, + &tle::ExecutorConfig::setPromptTableOffloading) + .def_prop_rw( + "enable_trt_overlap", &tle::ExecutorConfig::getEnableTrtOverlap, &tle::ExecutorConfig::setEnableTrtOverlap) + .def("__getstate__", executorConfigGetState) + .def("__setstate__", executorConfigSetState); +} + +} // namespace tensorrt_llm::nanobind::executor diff --git a/cpp/tensorrt_llm/nanobind/executor/executorConfig.h b/cpp/tensorrt_llm/nanobind/executor/executorConfig.h new file mode 100644 index 00000000000..5b63e7c5a3e --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/executor/executorConfig.h @@ -0,0 +1,30 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace nb = nanobind; + +namespace tensorrt_llm::nanobind::executor +{ + +// Register bindings for executor API. +void initConfigBindings(nb::module_& m); + +} // namespace tensorrt_llm::nanobind::executor diff --git a/cpp/tensorrt_llm/nanobind/executor/request.cpp b/cpp/tensorrt_llm/nanobind/executor/request.cpp new file mode 100644 index 00000000000..9c3d34aa8fd --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/executor/request.cpp @@ -0,0 +1,935 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "request.h" +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/executor/executor.h" +#include "tensorrt_llm/executor/serializeUtils.h" +#include "tensorrt_llm/executor/tensor.h" +#include "tensorrt_llm/executor/types.h" +#include "tensorrt_llm/nanobind/common/customCasters.h" +#include "tensorrt_llm/runtime/cudaStream.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace nb = nanobind; +namespace tle = tensorrt_llm::executor; +using Tensor = tle::Tensor; +using SizeType32 = tle::SizeType32; +using FloatType = tle::FloatType; +using VecTokens = tle::VecTokens; +using IdType = tle::IdType; +using VecTokenExtraIds = tle::VecTokenExtraIds; + +namespace tensorrt_llm::nanobind::executor +{ + +void initRequestBindings(nb::module_& m) +{ + nb::enum_(m, "RequestType") + .value("REQUEST_TYPE_CONTEXT_AND_GENERATION", tle::RequestType::REQUEST_TYPE_CONTEXT_AND_GENERATION) + .value("REQUEST_TYPE_CONTEXT_ONLY", tle::RequestType::REQUEST_TYPE_CONTEXT_ONLY) + .value("REQUEST_TYPE_GENERATION_ONLY", tle::RequestType::REQUEST_TYPE_GENERATION_ONLY); + + nb::enum_(m, "FinishReason") + .value("NOT_FINISHED", tle::FinishReason::kNOT_FINISHED) + .value("END_ID", tle::FinishReason::kEND_ID) + .value("STOP_WORDS", tle::FinishReason::kSTOP_WORDS) + .value("LENGTH", tle::FinishReason::kLENGTH) + .value("TIMED_OUT", tle::FinishReason::kTIMED_OUT) + .value("CANCELLED", tle::FinishReason::kCANCELLED); + + nb::enum_(m, "KvCacheTransferMode") + .value("DRAM", tle::KvCacheTransferMode::DRAM) + .value("GDS", tle::KvCacheTransferMode::GDS) + .value("POSIX_DEBUG_FALLBACK", tle::KvCacheTransferMode::POSIX_DEBUG_FALLBACK); + + auto samplingConfigGetstate = [](tle::SamplingConfig const& self) + { + return nb::make_tuple(self.getBeamWidth(), self.getTopK(), self.getTopP(), self.getTopPMin(), + self.getTopPResetIds(), self.getTopPDecay(), self.getSeed(), self.getTemperature(), self.getMinTokens(), + self.getBeamSearchDiversityRate(), self.getRepetitionPenalty(), self.getPresencePenalty(), + self.getFrequencyPenalty(), self.getLengthPenalty(), self.getEarlyStopping(), self.getNoRepeatNgramSize(), + self.getNumReturnSequences(), self.getMinP(), self.getBeamWidthArray()); + }; + auto samplingConfigSetstate = [](tle::SamplingConfig& samplingConfig, nb::tuple const& state) + { + if (state.size() != 19) + { + throw std::runtime_error("Invalid SamplingConfig state!"); + } + new (&samplingConfig) tle::SamplingConfig(nb::cast(state[0]), // BeamWidth + nb::cast>(state[1]), // TopK + nb::cast>(state[2]), // TopP + nb::cast>(state[3]), // TopPMin + nb::cast>(state[4]), // TopPResetIds + nb::cast>(state[5]), // TopPDecay + nb::cast>(state[6]), // Seed + nb::cast>(state[7]), // Temperature + nb::cast>(state[8]), // MinTokens + nb::cast>(state[9]), // BeamSearchDiversityRate + nb::cast>(state[10]), // 
RepetitionPenalty + nb::cast>(state[11]), // PresencePenalty + nb::cast>(state[12]), // FrequencyPenalty + nb::cast>(state[13]), // LengthPenalty + nb::cast>(state[14]), // EarlyStopping + nb::cast>(state[15]), // NoRepeatNgramSize + nb::cast>(state[16]), // NumReturnSequences + nb::cast>(state[17]), // MinP + nb::cast>>(state[18]) // BeamWidthArray + ); + }; + nb::class_(m, "SamplingConfig") + .def(nb::init const&, // beamWidth + std::optional const&, // topP + std::optional const&, // topPMin + std::optional const&, // topPResetIds + std::optional const&, // topPDecay + std::optional const&, // seed + std::optional const&, // temperature + std::optional const&, // minTokens + std::optional const&, // beamSearchDiversityRate + std::optional const&, // repetitionPenalty + std::optional const&, // presencePenalty + std::optional const&, // frequencyPenalty + std::optional const&, // lengthPenalty + std::optional const&, // earlyStopping + std::optional const&, // noRepeatNgramSize + std::optional const&, // numReturnSequences + std::optional const&, // minP + std::optional> const& // beamWidthArray + >(), + // clang-format off + nb::arg("beam_width") = 1, + nb::kw_only(), + nb::arg("top_k") = nb::none(), + nb::arg("top_p") = nb::none(), + nb::arg("top_p_min") = nb::none(), + nb::arg("top_p_reset_ids") = nb::none(), + nb::arg("top_p_decay") = nb::none(), + nb::arg("seed") = nb::none(), + nb::arg("temperature") = nb::none(), + nb::arg("min_tokens") = nb::none(), + nb::arg("beam_search_diversity_rate") = nb::none(), + nb::arg("repetition_penalty") = nb::none(), + nb::arg("presence_penalty") = nb::none(), + nb::arg("frequency_penalty") = nb::none(), + nb::arg("length_penalty") = nb::none(), + nb::arg("early_stopping") = nb::none(), + nb::arg("no_repeat_ngram_size") = nb::none(), + nb::arg("num_return_sequences") = nb::none(), + nb::arg("min_p") = nb::none(), + nb::arg("beam_width_array") = nb::none()) // clang-format on + .def_prop_rw("beam_width", &tle::SamplingConfig::getBeamWidth, &tle::SamplingConfig::setBeamWidth) + .def_prop_rw("top_k", &tle::SamplingConfig::getTopK, &tle::SamplingConfig::setTopK) + .def_prop_rw("top_p", &tle::SamplingConfig::getTopP, &tle::SamplingConfig::setTopP) + .def_prop_rw("top_p_min", &tle::SamplingConfig::getTopPMin, &tle::SamplingConfig::setTopPMin) + .def_prop_rw("top_p_reset_ids", &tle::SamplingConfig::getTopPResetIds, &tle::SamplingConfig::setTopPResetIds) + .def_prop_rw("top_p_decay", &tle::SamplingConfig::getTopPDecay, &tle::SamplingConfig::setTopPDecay) + .def_prop_rw("seed", &tle::SamplingConfig::getSeed, &tle::SamplingConfig::setSeed) + .def_prop_rw("temperature", &tle::SamplingConfig::getTemperature, &tle::SamplingConfig::setTemperature) + .def_prop_rw("min_tokens", &tle::SamplingConfig::getMinTokens, &tle::SamplingConfig::setMinTokens) + .def_prop_rw("beam_search_diversity_rate", &tle::SamplingConfig::getBeamSearchDiversityRate, + &tle::SamplingConfig::setBeamSearchDiversityRate) + .def_prop_rw("repetition_penalty", &tle::SamplingConfig::getRepetitionPenalty, + &tle::SamplingConfig::setRepetitionPenalty) + .def_prop_rw("presence_penalty", &tle::SamplingConfig::getPresencePenalty, + [](tle::SamplingConfig& self, std::optional v) { self.setPresencePenalty(v); }) + .def_prop_rw( + "frequency_penalty", &tle::SamplingConfig::getFrequencyPenalty, &tle::SamplingConfig::setFrequencyPenalty) + .def_prop_rw("length_penalty", &tle::SamplingConfig::getLengthPenalty, &tle::SamplingConfig::setLengthPenalty) + .def_prop_rw("early_stopping", 
&tle::SamplingConfig::getEarlyStopping, &tle::SamplingConfig::setEarlyStopping) + .def_prop_rw("no_repeat_ngram_size", &tle::SamplingConfig::getNoRepeatNgramSize, + &tle::SamplingConfig::setNoRepeatNgramSize) + .def_prop_rw("num_return_sequences", &tle::SamplingConfig::getNumReturnSequences, + &tle::SamplingConfig::setNumReturnSequences) + .def_prop_rw("min_p", &tle::SamplingConfig::getMinP, &tle::SamplingConfig::setMinP) + .def_prop_rw( + "beam_width_array", &tle::SamplingConfig::getBeamWidthArray, &tle::SamplingConfig::setBeamWidthArray) + .def("__getstate__", samplingConfigGetstate) + .def("__setstate__", samplingConfigSetstate); + + auto additionalModelOutputGetstate + = [](tle::AdditionalModelOutput const& self) { return nb::make_tuple(self.name, self.gatherContext); }; + auto additionalModelOutputSetstate = [](tle::AdditionalModelOutput& additionalModelOutput, nb::tuple const& state) + { + if (state.size() != 2) + { + throw std::runtime_error("Invalid AdditionalModelOutput state!"); + } + new (&additionalModelOutput) + tle::AdditionalModelOutput(nb::cast(state[0]), nb::cast(state[1])); + }; + nb::class_(m, "AdditionalModelOutput") + .def(nb::init(), nb::arg("name"), nb::arg("gather_context") = false) + .def_rw("name", &tle::AdditionalModelOutput::name) + .def_rw("gather_context", &tle::AdditionalModelOutput::gatherContext) + .def("__getstate__", additionalModelOutputGetstate) + .def("__setstate__", additionalModelOutputSetstate); + + auto outputConfigGetstate = [](tle::OutputConfig const& self) + { + return nb::make_tuple(self.returnLogProbs, self.returnContextLogits, self.returnGenerationLogits, + self.excludeInputFromOutput, self.returnEncoderOutput, self.returnPerfMetrics, self.additionalModelOutputs); + }; + auto outputConfigSetstate = [](tle::OutputConfig& outputConfig, nb::tuple const& state) + { + if (state.size() != 7) + { + throw std::runtime_error("Invalid OutputConfig state!"); + } + new (&outputConfig) tle::OutputConfig(nb::cast(state[0]), nb::cast(state[1]), + nb::cast(state[2]), nb::cast(state[3]), nb::cast(state[4]), nb::cast(state[5]), + nb::cast>>(state[6])); + }; + nb::class_(m, "OutputConfig") + .def(nb::init>>(), + nb::arg("return_log_probs").none() = false, nb::arg("return_context_logits") = false, + nb::arg("return_generation_logits") = false, nb::arg("exclude_input_from_output") = false, + nb::arg("return_encoder_output") = false, nb::arg("return_perf_metrics") = false, + nb::arg("additional_model_outputs") = nb::none()) + .def_rw("return_log_probs", &tle::OutputConfig::returnLogProbs) + .def_rw("return_context_logits", &tle::OutputConfig::returnContextLogits) + .def_rw("return_generation_logits", &tle::OutputConfig::returnGenerationLogits) + .def_rw("exclude_input_from_output", &tle::OutputConfig::excludeInputFromOutput) + .def_rw("return_encoder_output", &tle::OutputConfig::returnEncoderOutput) + .def_rw("return_perf_metrics", &tle::OutputConfig::returnPerfMetrics) + .def_rw("additional_model_outputs", &tle::OutputConfig::additionalModelOutputs) + .def("__getstate__", outputConfigGetstate) + .def("__setstate__", outputConfigSetstate); + + auto externalDraftTokensConfigGetstate = [](tle::ExternalDraftTokensConfig const& self) + { return nb::make_tuple(self.getTokens(), self.getLogits(), self.getAcceptanceThreshold()); }; + auto externalDraftTokensConfigSetstate + = [](tle::ExternalDraftTokensConfig& externalDraftTokensConfig, nb::tuple const& state) + { + if (state.size() != 3) + { + throw std::runtime_error("Invalid ExternalDraftTokensConfig state!"); + } + 
new (&externalDraftTokensConfig) tle::ExternalDraftTokensConfig(nb::cast(state[0]), + nb::cast>(state[1]), nb::cast>(state[2])); + }; + nb::class_(m, "ExternalDraftTokensConfig") + .def(nb::init, std::optional const&, std::optional>(), + nb::arg("tokens"), nb::arg("logits") = nb::none(), nb::arg("acceptance_threshold") = nb::none(), + nb::arg("fast_logits") = nb::none()) + .def_prop_ro("tokens", &tle::ExternalDraftTokensConfig::getTokens) + .def_prop_ro("logits", &tle::ExternalDraftTokensConfig::getLogits) + .def_prop_ro("acceptance_threshold", &tle::ExternalDraftTokensConfig::getAcceptanceThreshold) + .def("__getstate__", externalDraftTokensConfigGetstate) + .def("__setstate__", externalDraftTokensConfigSetstate) + .def_prop_ro("fast_logits", &tle::ExternalDraftTokensConfig::getFastLogits); + + auto promptTuningConfigGetstate = [](tle::PromptTuningConfig const& self) + { return nb::make_tuple(self.getEmbeddingTable(), self.getInputTokenExtraIds()); }; + auto promptTuningConfigSetstate = [](tle::PromptTuningConfig& promptTuningConfig, nb::tuple const& state) + { + if (state.size() != 2) + { + throw std::runtime_error("Invalid PromptTuningConfig state!"); + } + new (&promptTuningConfig) + tle::PromptTuningConfig(nb::cast(state[0]), nb::cast>(state[1])); + }; + nb::class_(m, "PromptTuningConfig") + .def(nb::init>(), nb::arg("embedding_table"), + nb::arg("input_token_extra_ids") = nb::none()) + .def_prop_ro("embedding_table", &tle::PromptTuningConfig::getEmbeddingTable) + .def_prop_ro("input_token_extra_ids", &tle::PromptTuningConfig::getInputTokenExtraIds) + .def("__getstate__", promptTuningConfigGetstate) + .def("__setstate__", promptTuningConfigSetstate); + + auto loraConfigGetstate = [](tle::LoraConfig const& self) + { return nb::make_tuple(self.getTaskId(), self.getWeights(), self.getConfig()); }; + auto loraConfigSetstate = [](tle::LoraConfig& loraConfig, nb::tuple const& state) + { + if (state.size() != 3) + { + throw std::runtime_error("Invalid LoraConfig state!"); + } + new (&loraConfig) tle::LoraConfig(nb::cast(state[0]), nb::cast>(state[1]), + nb::cast>(state[2])); + }; + nb::class_(m, "LoraConfig") + .def(nb::init, std::optional>(), nb::arg("task_id"), + nb::arg("weights") = nb::none(), nb::arg("config") = nb::none()) + .def_prop_ro("task_id", &tle::LoraConfig::getTaskId) + .def_prop_ro("weights", &tle::LoraConfig::getWeights) + .def_prop_ro("config", &tle::LoraConfig::getConfig) + .def("__getstate__", loraConfigGetstate) + .def("__setstate__", loraConfigSetstate); + + auto multimodalInputGetstate = [](tle::MultimodalInput const& self) + { return nb::make_tuple(self.getMultimodalHashes(), self.getMultimodalPositions(), self.getMultimodalLengths()); }; + auto multimodalInputSetstate = [](tle::MultimodalInput& multimodalInput, nb::tuple const& state) + { + if (state.size() != 3) + { + throw std::runtime_error("Invalid MultimodalInput state!"); + } + new (&multimodalInput) tle::MultimodalInput(nb::cast>>(state[0]), + nb::cast>(state[1]), nb::cast>(state[2])); + }; + nb::class_(m, "MultimodalInput") + .def(nb::init>, std::vector, std::vector>(), + nb::arg("multimodal_hashes"), nb::arg("multimodal_positions"), nb::arg("multimodal_lengths")) + .def_prop_ro("multimodal_hashes", &tle::MultimodalInput::getMultimodalHashes) + .def_prop_ro("multimodal_positions", &tle::MultimodalInput::getMultimodalPositions) + .def_prop_ro("multimodal_lengths", &tle::MultimodalInput::getMultimodalLengths) + .def("__getstate__", multimodalInputGetstate) + .def("__setstate__", multimodalInputSetstate); + + 
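// Illustrative, self-contained sketch of the pickling pattern used throughout this file, shown
// for a hypothetical `Point` type in a hypothetical `pickling_sketch` module (neither is part of
// TensorRT-LLM). `__getstate__` reduces the object to a plain tuple; `__setstate__` is handed
// uninitialized storage, so it validates the tuple and constructs the object in place with
// placement new, mirroring the *Getstate/*Setstate lambdas above.

#include <nanobind/nanobind.h>

#include <new>
#include <stdexcept>

namespace nb = nanobind;

struct Point
{
    Point(int x, int y)
        : x{x}
        , y{y}
    {
    }

    int x;
    int y;
};

NB_MODULE(pickling_sketch, m)
{
    nb::class_<Point>(m, "Point")
        .def(nb::init<int, int>(), nb::arg("x"), nb::arg("y"))
        .def_rw("x", &Point::x)
        .def_rw("y", &Point::y)
        // Reduce to a tuple of members, analogous to the getstate lambdas above.
        .def("__getstate__", [](Point const& self) { return nb::make_tuple(self.x, self.y); })
        // Rebuild from the tuple; `self` refers to uninitialized storage, hence placement new.
        .def("__setstate__",
            [](Point& self, nb::tuple const& state)
            {
                if (state.size() != 2)
                {
                    throw std::runtime_error("Invalid Point state!");
                }
                new (&self) Point(nb::cast<int>(state[0]), nb::cast<int>(state[1]));
            });
}

// Types that also carry Python-side attributes (such as the ExecutorConfig binding earlier in this
// diff, declared with nb::dynamic_attr) extend the same idea by pickling a (cpp_states, __dict__)
// pair and finishing the in-place construction with nb::inst_mark_ready.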
auto MropeConfigGetstate = [](tle::MropeConfig const& self) + { return nb::make_tuple(self.getMRopeRotaryCosSin(), self.getMRopePositionDeltas()); }; + auto MropeConfigSetstate = [](tle::MropeConfig& mropeConfig, nb::tuple const& state) + { + if (state.size() != 2) + { + throw std::runtime_error("Invalid MropeConfig state!"); + } + new (&mropeConfig) tle::MropeConfig(nb::cast(state[0]), nb::cast(state[1])); + }; + nb::class_(m, "MropeConfig") + .def(nb::init(), nb::arg("mrope_rotary_cos_sin"), nb::arg("mrope_position_deltas")) + .def_prop_ro("mrope_rotary_cos_sin", &tle::MropeConfig::getMRopeRotaryCosSin) + .def_prop_ro("mrope_position_deltas", &tle::MropeConfig::getMRopePositionDeltas) + .def("__getstate__", MropeConfigGetstate) + .def("__setstate__", MropeConfigSetstate); + + auto lookaheadDecodingConfigGetstate = [](tle::LookaheadDecodingConfig const& self) + { return nb::make_tuple(self.getWindowSize(), self.getNgramSize(), self.getVerificationSetSize()); }; + auto lookaheadDecodingConfigSetstate + = [](tle::LookaheadDecodingConfig& lookaheadDecodingConfig, nb::tuple const& state) + { + if (state.size() != 3) + { + throw std::runtime_error("Invalid LookaheadDecodingConfig state!"); + } + new (&lookaheadDecodingConfig) tle::LookaheadDecodingConfig( + nb::cast(state[0]), nb::cast(state[1]), nb::cast(state[2])); + }; + nb::class_(m, "LookaheadDecodingConfig") + .def(nb::init(), nb::arg("max_window_size"), nb::arg("max_ngram_size"), + nb::arg("max_verification_set_size")) + .def_prop_ro("max_window_size", &tle::LookaheadDecodingConfig::getWindowSize) + .def_prop_ro("max_ngram_size", &tle::LookaheadDecodingConfig::getNgramSize) + .def_prop_ro("max_verification_set_size", &tle::LookaheadDecodingConfig::getVerificationSetSize) + .def("calculate_speculative_resource", &tle::LookaheadDecodingConfig::calculateSpeculativeResource) + .def_static( + "calculate_speculative_resource_tuple", &tle::LookaheadDecodingConfig::calculateSpeculativeResourceTuple) + .def("__getstate__", lookaheadDecodingConfigGetstate) + .def("__setstate__", lookaheadDecodingConfigSetstate) + .def_static("get_default_lookahead_decoding_window", + []() { return tle::LookaheadDecodingConfig::kDefaultLookaheadDecodingWindow; }) + .def_static("get_default_lookahead_decoding_ngram", + []() { return tle::LookaheadDecodingConfig::kDefaultLookaheadDecodingNgram; }) + .def_static("get_default_lookahead_decoding_verification_set", + []() { return tle::LookaheadDecodingConfig::kDefaultLookaheadDecodingVerificationSet; }); + + auto TokenRangeRetentionConfigGetstate = [](tle::KvCacheRetentionConfig::TokenRangeRetentionConfig const& self) + { return nb::make_tuple(self.tokenStart, self.tokenEnd, self.priority, self.durationMs); }; + auto TokenRangeRetentionConfigSetstate + = [](tle::KvCacheRetentionConfig::TokenRangeRetentionConfig& tokenRangeRetentionConfig, nb::tuple const& state) + { + if (state.size() != 4) + { + throw std::runtime_error("Invalid state!"); + } + new (&tokenRangeRetentionConfig) tle::KvCacheRetentionConfig::TokenRangeRetentionConfig( + nb::cast(state[0]), nb::cast>(state[1]), + nb::cast(state[2]), nb::cast>(state[3])); + }; + auto kvCacheRetentionConfigGetstate = [](tle::KvCacheRetentionConfig const& self) + { + return nb::make_tuple(self.getTokenRangeRetentionConfigs(), self.getDecodeRetentionPriority(), + self.getDecodeDurationMs(), self.getTransferMode(), self.getDirectory()); + }; + auto kvCacheRetentionConfigSetstate + = [](tle::KvCacheRetentionConfig& kvCacheRetentionConfig, nb::tuple const& state) + { + if 
(state.size() != 5) + { + throw std::runtime_error("Invalid state!"); + } + new (&kvCacheRetentionConfig) tle::KvCacheRetentionConfig( + nb::cast>(state[0]), + nb::cast(state[1]), nb::cast>(state[2]), + nb::cast(state[3]), nb::cast>(state[4])); + }; + + auto kvCacheRetentionConfig = nb::class_(m, "KvCacheRetentionConfig"); + + nb::class_( + kvCacheRetentionConfig, "TokenRangeRetentionConfig") + .def(nb::init, tle::RetentionPriority, + std::optional>(), + nb::arg("token_start"), nb::arg("token_end"), nb::arg("priority"), nb::arg("duration_ms") = nb::none()) + .def_rw("token_start", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::tokenStart) + .def_rw("token_end", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::tokenEnd) + .def_rw("priority", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::priority) + .def_rw("duration_ms", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::durationMs) + .def("__getstate__", TokenRangeRetentionConfigGetstate) + .def("__setstate__", TokenRangeRetentionConfigSetstate) + .def("__eq__", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::operator==); + + // There's a circular dependency between the declaration of the TokenRangeRetentionPriority and + // KvCacheRetentionConfig bindings. Defer definition of the KvCacheRetentionConfig bindings until the + // TokenRangeRetentionPriority bindings have been defined. + kvCacheRetentionConfig + .def(nb::init, tle::RetentionPriority, + std::optional, tle::KvCacheTransferMode, std::optional>(), + nb::arg("token_range_retention_configs"), + nb::arg("decode_retention_priority") = tle::KvCacheRetentionConfig::kDefaultRetentionPriority, + nb::arg("decode_duration_ms") = nb::none(), nb::arg("transfer_mode") = tle::KvCacheTransferMode::DRAM, + nb::arg("directory") = nb::none()) + .def_prop_ro("token_range_retention_configs", &tle::KvCacheRetentionConfig::getTokenRangeRetentionConfigs) + .def_prop_ro("decode_retention_priority", &tle::KvCacheRetentionConfig::getDecodeRetentionPriority) + .def_prop_ro("decode_duration_ms", &tle::KvCacheRetentionConfig::getDecodeDurationMs) + .def_prop_ro("transfer_mode", &tle::KvCacheRetentionConfig::getTransferMode) + .def_prop_ro("directory", &tle::KvCacheRetentionConfig::getDirectory) + .def("__getstate__", kvCacheRetentionConfigGetstate) + .def("__setstate__", kvCacheRetentionConfigSetstate) + .def("__eq__", &tle::KvCacheRetentionConfig::operator==); + + auto ContextPhaseParamsGetState = [](tle::ContextPhaseParams const& self) + { + if (self.getState() != nullptr) + { + auto serializedState = self.getSerializedState(); + return nb::make_tuple(self.getFirstGenTokens(), self.getReqId(), + nb::bytes(serializedState.data(), serializedState.size()), self.getDraftTokens()); + } + return nb::make_tuple(self.getFirstGenTokens(), self.getReqId(), nb::none(), self.getDraftTokens()); + }; + + auto ContextPhaseParamsSetState = [](tle::ContextPhaseParams& contextPhaseParams, nb::tuple const& state) + { + if (state.size() != 4) + { + throw std::runtime_error("Invalid ContextPhaseParams state!"); + } + if (!state[2].is_none()) + { + auto opaque_state = nb::cast(state[2]); + auto opaque_state_str_view = std::string_view(opaque_state.c_str(), opaque_state.size()); + new (&contextPhaseParams) tle::ContextPhaseParams(nb::cast(state[0]), + nb::cast(state[1]), + std::vector(opaque_state_str_view.begin(), opaque_state_str_view.end()), + nb::cast>(state[3])); + } + new (&contextPhaseParams) tle::ContextPhaseParams(nb::cast(state[0]), + nb::cast(state[1]), nb::cast>(state[3])); 
+ }; + + nb::class_(m, "ContextPhaseParams") + .def("__init__", + [](tle::ContextPhaseParams const& self, VecTokens const& first_gen_tokens, + tle::ContextPhaseParams::RequestIdType req_id, std::optional const& opaque_state, + std::optional const& draft_tokens) + { + if (opaque_state) + { + auto opaque_state_str_view + = std::string_view(opaque_state.value().c_str(), opaque_state.value().size()); + return std::make_unique(first_gen_tokens, req_id, + std::vector(opaque_state_str_view.begin(), opaque_state_str_view.end()), draft_tokens); + } + return std::make_unique(first_gen_tokens, req_id, draft_tokens); + }) + .def_prop_ro("first_gen_tokens", [](tle::ContextPhaseParams const& self) { return self.getFirstGenTokens(); }) + .def_prop_ro("draft_tokens", [](tle::ContextPhaseParams const& self) { return self.getDraftTokens(); }) + .def_prop_ro("req_id", &tle::ContextPhaseParams::getReqId) + .def_prop_ro("opaque_state", + [](tle::ContextPhaseParams const& self) + { + std::optional opaque_state{std::nullopt}; + if (self.getState() != nullptr) + { + auto serializedState = self.getSerializedState(); + opaque_state = nb::bytes(serializedState.data(), serializedState.size()); + } + return opaque_state; + }) + .def("__getstate__", ContextPhaseParamsGetState) + .def("__setstate__", ContextPhaseParamsSetState); + + auto EagleDecodingConfigGetstate = [](tle::EagleConfig const& self) + { + return nb::make_tuple(self.getEagleChoices(), self.isGreedySampling(), self.getPosteriorThreshold(), + self.useDynamicTree(), self.getDynamicTreeMaxTopK()); + }; + auto EagleDecodingConfigSetstate = [](tle::EagleConfig& eagleConfig, nb::tuple const& state) + { + if (state.size() != 5) + { + throw std::runtime_error("Invalid EagleConfig state!"); + } + new (&eagleConfig) tle::EagleConfig(nb::cast>(state[0]), + nb::cast(state[1]), nb::cast>(state[2]), nb::cast(state[3]), + nb::cast>(state[4])); + }; + nb::class_(m, "EagleConfig") + .def(nb::init, bool, std::optional, bool, std::optional>(), + nb::arg("eagle_choices") = nb::none(), nb::arg("greedy_sampling") = true, + nb::arg("posterior_threshold") = nb::none(), nb::arg("use_dynamic_tree") = false, + nb::arg("dynamic_tree_max_topK") = nb::none()) + .def_prop_ro("eagle_choices", &tle::EagleConfig::getEagleChoices) + .def_prop_ro("greedy_sampling", &tle::EagleConfig::isGreedySampling) + .def_prop_ro("posterior_threshold", &tle::EagleConfig::getPosteriorThreshold) + .def_prop_ro("use_dynamic_tree", &tle::EagleConfig::useDynamicTree) + .def_prop_ro("dynamic_tree_max_topK", &tle::EagleConfig::getDynamicTreeMaxTopK) + .def("__getstate__", EagleDecodingConfigGetstate) + .def("__setstate__", EagleDecodingConfigSetstate); + + // Guided decoding params + auto pyGuidedDecodingParams = nb::class_(m, "GuidedDecodingParams"); + + nb::enum_(pyGuidedDecodingParams, "GuideType") + .value("JSON", tle::GuidedDecodingParams::GuideType::kJSON) + .value("JSON_SCHEMA", tle::GuidedDecodingParams::GuideType::kJSON_SCHEMA) + .value("REGEX", tle::GuidedDecodingParams::GuideType::kREGEX) + .value("EBNF_GRAMMAR", tle::GuidedDecodingParams::GuideType::kEBNF_GRAMMAR) + .value("STRUCTURAL_TAG", tle::GuidedDecodingParams::GuideType::kSTRUCTURAL_TAG); + + auto guidedDecodingParamsGetstate + = [](tle::GuidedDecodingParams const& self) { return nb::make_tuple(self.getGuideType(), self.getGuide()); }; + + auto guidedDecodingParamsSetstate = [](tle::GuidedDecodingParams& guidedDecodingParams, nb::tuple const& state) + { + if (state.size() != 2) + { + throw std::runtime_error("Invalid GuidedDecodingParams 
state!"); + } + new (&guidedDecodingParams) tle::GuidedDecodingParams( + nb::cast(state[0]), nb::cast>(state[1])); + }; + + pyGuidedDecodingParams + .def(nb::init>(), nb::arg("guide_type"), + nb::arg("guide") = nb::none()) + .def_prop_ro("guide_type", &tle::GuidedDecodingParams::getGuideType) + .def_prop_ro("guide", &tle::GuidedDecodingParams::getGuide) + .def("__getstate__", guidedDecodingParamsGetstate) + .def("__setstate__", guidedDecodingParamsSetstate); + + auto requestGetstate = [](tle::Request const& self) + { + return nb::make_tuple(self.getInputTokenIds(), self.getMaxTokens(), self.getStreaming(), + self.getSamplingConfig(), self.getOutputConfig(), self.getEndId(), self.getPadId(), self.getPositionIds(), + self.getBadWords(), self.getStopWords(), self.getEmbeddingBias(), self.getExternalDraftTokensConfig(), + self.getPromptTuningConfig(), self.getMultimodalInput(), self.getMultimodalEmbedding(), + self.getMropeConfig(), self.getLoraConfig(), self.getLookaheadConfig(), self.getKvCacheRetentionConfig(), + self.getLogitsPostProcessorName(), self.getLogitsPostProcessor(), self.getEncoderInputTokenIds(), + self.getClientId(), self.getReturnAllGeneratedTokens(), self.getPriority(), self.getRequestType(), + self.getContextPhaseParams(), self.getEncoderInputFeatures(), self.getEncoderOutputLength(), + self.getCrossAttentionMask(), self.getEagleConfig(), self.getSkipCrossAttnBlocks(), + self.getGuidedDecodingParams()); + }; + auto requestSetstate = [](tle::Request& request, nb::tuple const& state) + { + if (state.size() != 33) + { + throw std::runtime_error("Invalid Request state!"); + } + new (&request) tle::Request(nb::cast(state[0]), nb::cast(state[1]), + nb::cast(state[2]), nb::cast(state[3]), nb::cast(state[4]), + nb::cast>(state[5]), nb::cast>(state[6]), + nb::cast>>(state[7]), + nb::cast>>(state[8]), + nb::cast>>(state[9]), nb::cast>(state[10]), + nb::cast>(state[11]), + nb::cast>(state[12]), + nb::cast>(state[13]), nb::cast>(state[14]), + nb::cast>(state[15]), nb::cast>(state[16]), + nb::cast>(state[17]), + nb::cast>(state[18]), + nb::cast>(state[19]), + nb::cast>(state[20]), nb::cast>(state[21]), + nb::cast>(state[22]), nb::cast(state[23]), + nb::cast(state[24]), nb::cast(state[25]), + nb::cast>(state[26]), + nb::cast>(state[27]), nb::cast>(state[28]), + nb::cast>(state[29]), 1, nb::cast>(state[30]), + nb::cast>(state[31]), + nb::cast>(state[32])); + }; + + nb::class_ request(m, "Request", nb::dynamic_attr()); + request + .def(nb::init const&, // endId + std::optional const&, // padId + std::optional>, // positionIds + std::optional>, // badWords + std::optional>, // stopWords + std::optional, // embeddingBias + std::optional, // externalDraftTokensConfig + std::optional, // pTuningConfig + std::optional, // multimodalInput + std::optional, // multimodalEmbedding + std::optional, // mRopeConfig + std::optional, // loraConfig + std::optional, // lookaheadConfig + std::optional, // kvCacheRetentionConfig + std::optional, // logitsPostProcessorName + std::optional, // logitsPostProcessor + std::optional, // encoderInputTokenIds + std::optional, // clientId + bool, // returnAllGeneratedTokens + tle::PriorityType, // priority + tle::RequestType, // type + std::optional, // contextPhaseParams + std::optional, // encoderInputFeatures + std::optional, // encoderOutputLength + std::optional, // crossAttentionMask + SizeType32, // numReturnSequences + std::optional, // eagleConfig + std::optional, // skipCrossAttnBlocks + std::optional, // guidedDecodingParams + std::optional, // 
languageAdapterUid + std::optional // allottedTimeMs + >(), + // clang-format off + nb::arg("input_token_ids"), + nb::arg("max_tokens"), + nb::kw_only(), + nb::arg("streaming") = false, + nb::arg("sampling_config") = tle::SamplingConfig(), + nb::arg("output_config") = tle::OutputConfig(), + nb::arg("end_id") = nb::none(), + nb::arg("pad_id") = nb::none(), + nb::arg("position_ids") = nb::none(), + nb::arg("bad_words") = nb::none(), + nb::arg("stop_words") = nb::none(), + nb::arg("embedding_bias") = nb::none(), + nb::arg("external_draft_tokens_config") = nb::none(), + nb::arg("prompt_tuning_config") = nb::none(), + nb::arg("multimodal_input") = nb::none(), + nb::arg("multimodal_embedding") = nb::none(), + nb::arg("mrope_config") = nb::none(), + nb::arg("lora_config") = nb::none(), + nb::arg("lookahead_config") = nb::none(), + nb::arg("kv_cache_retention_config") = nb::none(), + nb::arg("logits_post_processor_name") = nb::none(), + nb::arg("logits_post_processor") = nb::none(), + nb::arg("encoder_input_token_ids") = nb::none(), + nb::arg("client_id") = nb::none(), + nb::arg("return_all_generated_tokens") = false, + nb::arg("priority") = tle::Request::kDefaultPriority, + nb::arg("type") = tle::RequestType::REQUEST_TYPE_CONTEXT_AND_GENERATION, + nb::arg("context_phase_params") = nb::none(), + nb::arg("encoder_input_features") = nb::none(), + nb::arg("encoder_output_length") = nb::none(), + nb::arg("cross_attention_mask") = nb::none(), + nb::arg("num_return_sequences") = 1, + nb::arg("eagle_config") = nb::none(), + nb::arg("skip_cross_attn_blocks") = nb::none(), + nb::arg("guided_decoding_params") = nb::none(), + nb::arg("language_adapter_uid") = nb::none(), + nb::arg("allotted_time_ms") = nb::none() + ) // clang-format on + .def_prop_ro("input_token_ids", &tle::Request::getInputTokenIds) + .def_prop_ro("max_tokens", &tle::Request::getMaxTokens) + .def_prop_rw("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming) + .def_prop_rw("sampling_config", &tle::Request::getSamplingConfig, &tle::Request::setSamplingConfig) + .def_prop_rw("output_config", &tle::Request::getOutputConfig, &tle::Request::setOutputConfig) + .def_prop_rw("end_id", &tle::Request::getEndId, &tle::Request::setEndId) + .def_prop_rw("pad_id", &tle::Request::getPadId, &tle::Request::setPadId) + .def_prop_rw("position_ids", &tle::Request::getPositionIds, &tle::Request::setPositionIds) + .def_prop_rw("bad_words", &tle::Request::getBadWords, &tle::Request::setBadWords) + .def_prop_rw("stop_words", &tle::Request::getStopWords, &tle::Request::setStopWords) + .def_prop_rw("embedding_bias", &tle::Request::getEmbeddingBias, &tle::Request::setEmbeddingBias) + .def_prop_rw("external_draft_tokens_config", &tle::Request::getExternalDraftTokensConfig, + &tle::Request::setExternalDraftTokensConfig) + .def_prop_rw("prompt_tuning_config", &tle::Request::getPromptTuningConfig, &tle::Request::setPromptTuningConfig) + .def_prop_rw("multimodal_input", &tle::Request::getMultimodalInput, &tle::Request::setMultimodalInput) + .def_prop_rw( + "multimodal_embedding", &tle::Request::getMultimodalEmbedding, &tle::Request::setMultimodalEmbedding) + .def_prop_rw("mrope_config", &tle::Request::getMropeConfig, &tle::Request::setMropeConfig) + .def_prop_rw("lora_config", &tle::Request::getLoraConfig, &tle::Request::setLoraConfig) + .def_prop_rw("lookahead_config", &tle::Request::getLookaheadConfig, &tle::Request::setLookaheadConfig) + .def_prop_rw("kv_cache_retention_config", &tle::Request::getKvCacheRetentionConfig, + 
&tle::Request::setKvCacheRetentionConfig) + .def_prop_rw("logits_post_processor_name", &tle::Request::getLogitsPostProcessorName, + &tle::Request::setLogitsPostProcessorName) + .def_prop_rw( + "logits_post_processor", &tle::Request::getLogitsPostProcessor, &tle::Request::setLogitsPostProcessor) + .def_prop_rw( + "encoder_input_token_ids", &tle::Request::getEncoderInputTokenIds, &tle::Request::setEncoderInputTokenIds) + .def_prop_rw("client_id", &tle::Request::getClientId, &tle::Request::setClientId) + .def_prop_rw("return_all_generated_tokens", &tle::Request::getReturnAllGeneratedTokens, + &tle::Request::setReturnAllGeneratedTokens) + .def_prop_rw("request_type", &tle::Request::getRequestType, &tle::Request::setRequestType) + .def_prop_rw( + "encoder_input_features", &tle::Request::getEncoderInputFeatures, &tle::Request::setEncoderInputFeatures) + .def_prop_rw("cross_attention_mask", &tle::Request::getCrossAttentionMask, &tle::Request::setCrossAttentionMask) + .def_prop_rw("eagle_config", &tle::Request::getEagleConfig, &tle::Request::setEagleConfig) + .def_prop_rw( + "skip_cross_attn_blocks", &tle::Request::getSkipCrossAttnBlocks, &tle::Request::setSkipCrossAttnBlocks) + .def_prop_rw( + "guided_decoding_params", &tle::Request::getGuidedDecodingParams, &tle::Request::setGuidedDecodingParams) + .def_prop_rw("allotted_time_ms", &tle::Request::getAllottedTimeMs, &tle::Request::setAllottedTimeMs) + .def_prop_rw("context_phase_params", &tle::Request::getContextPhaseParams, &tle::Request::setContextPhaseParams) + .def("__getstate__", requestGetstate) + .def("__setstate__", requestSetstate); + request.attr("BATCHED_POST_PROCESSOR_NAME") = tle::Request::kBatchedPostProcessorName; + + nb::class_(m, "SpeculativeDecodingFastLogitsInfo") + .def(nb::init<>()) + .def_rw("draft_request_id", &tle::SpeculativeDecodingFastLogitsInfo::draftRequestId) + .def_rw("draft_participant_id", &tle::SpeculativeDecodingFastLogitsInfo::draftParticipantId) + .def("to_tensor", &tle::SpeculativeDecodingFastLogitsInfo::toTensor); + + auto requestPerfMetrics = nb::class_(m, "RequestPerfMetrics"); + + auto timingMetricsGetstate = [](tle::RequestPerfMetrics::TimingMetrics const& self) + { + return nb::make_tuple(self.arrivalTime, self.firstScheduledTime, self.firstTokenTime, self.lastTokenTime, + self.kvCacheTransferStart, self.kvCacheTransferEnd, self.kvCacheSize); + }; + auto timingMetricsSetstate = [](tle::RequestPerfMetrics::TimingMetrics& timingMetrics, nb::tuple const& state) + { + if (state.size() != 7) + { + throw std::runtime_error("Invalid TimingMetrics state!"); + } + new (&timingMetrics) + tle::RequestPerfMetrics::TimingMetrics{nb::cast(state[0]), + nb::cast(state[1]), + nb::cast(state[2]), + nb::cast(state[3]), + nb::cast(state[4]), + nb::cast(state[5]), nb::cast(state[6])}; + }; + nb::class_(m, "TimingMetrics") + .def(nb::init<>()) + .def_rw("arrival_time", &tle::RequestPerfMetrics::TimingMetrics::arrivalTime) + .def_rw("first_scheduled_time", &tle::RequestPerfMetrics::TimingMetrics::firstScheduledTime) + .def_rw("first_token_time", &tle::RequestPerfMetrics::TimingMetrics::firstTokenTime) + .def_rw("last_token_time", &tle::RequestPerfMetrics::TimingMetrics::lastTokenTime) + .def_rw("kv_cache_transfer_start", &tle::RequestPerfMetrics::TimingMetrics::kvCacheTransferStart) + .def_rw("kv_cache_transfer_end", &tle::RequestPerfMetrics::TimingMetrics::kvCacheTransferEnd) + .def_rw("kv_cache_size", &tle::RequestPerfMetrics::TimingMetrics::kvCacheSize) + .def("__getstate__", timingMetricsGetstate) + .def("__setstate__", 
timingMetricsSetstate); + + auto kvCacheMetricsGetstate = [](tle::RequestPerfMetrics::KvCacheMetrics const& self) + { + return nb::make_tuple(self.numTotalAllocatedBlocks, self.numNewAllocatedBlocks, self.numReusedBlocks, + self.numMissedBlocks, self.kvCacheHitRate); + }; + auto kvCacheMetricsSetstate = [](tle::RequestPerfMetrics::KvCacheMetrics& kvCacheMetrics, nb::tuple const& state) + { + if (state.size() != 5) + { + throw std::runtime_error("Invalid KvCacheMetrics state!"); + } + new (&kvCacheMetrics) + tle::RequestPerfMetrics::KvCacheMetrics{nb::cast(state[0]), nb::cast(state[1]), + nb::cast(state[2]), nb::cast(state[3]), nb::cast(state[4])}; + }; + nb::class_(m, "KvCacheMetrics") + .def(nb::init<>()) + .def_rw("num_total_allocated_blocks", &tle::RequestPerfMetrics::KvCacheMetrics::numTotalAllocatedBlocks) + .def_rw("num_new_allocated_blocks", &tle::RequestPerfMetrics::KvCacheMetrics::numNewAllocatedBlocks) + .def_rw("num_reused_blocks", &tle::RequestPerfMetrics::KvCacheMetrics::numReusedBlocks) + .def_rw("num_missed_blocks", &tle::RequestPerfMetrics::KvCacheMetrics::numMissedBlocks) + .def_rw("kv_cache_hit_rate", &tle::RequestPerfMetrics::KvCacheMetrics::kvCacheHitRate) + .def("__getstate__", kvCacheMetricsGetstate) + .def("__setstate__", kvCacheMetricsSetstate); + + auto speculativeDecodingMetricsGetstate = [](tle::RequestPerfMetrics::SpeculativeDecodingMetrics const& self) + { return nb::make_tuple(self.acceptanceRate, self.totalAcceptedDraftTokens, self.totalDraftTokens); }; + auto speculativeDecodingMetricsSetstate + = [](tle::RequestPerfMetrics::SpeculativeDecodingMetrics& speculativeDecodingMetrics, nb::tuple const& state) + { + if (state.size() != 3) + { + throw std::runtime_error("Invalid SpeculativeDecodingMetrics state!"); + } + new (&speculativeDecodingMetrics) tle::RequestPerfMetrics::SpeculativeDecodingMetrics{ + nb::cast(state[0]), nb::cast(state[1]), nb::cast(state[2])}; + }; + + nb::class_(m, "SpeculativeDecodingMetrics") + .def(nb::init<>()) + .def_rw("acceptance_rate", &tle::RequestPerfMetrics::SpeculativeDecodingMetrics::acceptanceRate) + .def_rw("total_accepted_draft_tokens", + &tle::RequestPerfMetrics::SpeculativeDecodingMetrics::totalAcceptedDraftTokens) + .def_rw("total_draft_tokens", &tle::RequestPerfMetrics::SpeculativeDecodingMetrics::totalDraftTokens) + .def("__getstate__", speculativeDecodingMetricsGetstate) + .def("__setstate__", speculativeDecodingMetricsSetstate); + + auto requestPerfMetricsGetstate = [](tle::RequestPerfMetrics const& self) + { + return nb::make_tuple(self.timingMetrics, self.kvCacheMetrics, self.speculativeDecoding, self.firstIter, + self.lastIter, self.iter); + }; + auto requestPerfMetricsSetstate = [](tle::RequestPerfMetrics& requestPerfMetrics, nb::tuple const& state) + { + if (state.size() != 6) + { + throw std::runtime_error("Invalid RequestPerfMetrics state!"); + } + new (&requestPerfMetrics) tle::RequestPerfMetrics{nb::cast(state[0]), + nb::cast(state[1]), + nb::cast(state[2]), + nb::cast>(state[3]), + nb::cast>(state[4]), + nb::cast>(state[5])}; + }; + + // There's a circular dependency between the declaration of the TimingMetrics and RequestPerfMetrics bindings. + // Defer definition of the RequestPerfMetrics bindings until the TimingMetrics have been defined. 
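// Standalone sketch of the declare-first, define-later pattern referenced by the comments above
// and used for KvCacheRetentionConfig and RequestPerfMetrics in this file. The `Outer`/`Inner`
// names and the `deferred_sketch` module are hypothetical, not TensorRT-LLM types: the outer
// nb::class_ handle is created up front, the dependent type is registered (here nested under the
// outer class, as with TokenRangeRetentionConfig), and only then are the outer type's members
// bound, so that signatures and defaults involving the dependent type resolve against an
// already-registered binding.

#include <nanobind/nanobind.h>

namespace nb = nanobind;

struct Inner
{
    int value{0};
};

struct Outer
{
    Inner inner;
};

NB_MODULE(deferred_sketch, m)
{
    // Declare the outer class first so it exists as a Python type object.
    auto outer = nb::class_<Outer>(m, "Outer");

    // Register the dependent type, scoped under the outer class (it becomes Outer.Inner).
    nb::class_<Inner>(outer, "Inner").def(nb::init<>()).def_rw("value", &Inner::value);

    // Defer the outer class' own members until the dependent type is bound.
    outer.def(nb::init<>()).def_rw("inner", &Outer::inner);
}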
+ requestPerfMetrics.def(nb::init<>()) + .def_rw("timing_metrics", &tle::RequestPerfMetrics::timingMetrics) + .def_rw("kv_cache_metrics", &tle::RequestPerfMetrics::kvCacheMetrics) + .def_rw("speculative_decoding", &tle::RequestPerfMetrics::speculativeDecoding) + .def_rw("first_iter", &tle::RequestPerfMetrics::firstIter) + .def_rw("last_iter", &tle::RequestPerfMetrics::lastIter) + .def_rw("iter", &tle::RequestPerfMetrics::iter) + .def("__getstate__", requestPerfMetricsGetstate) + .def("__setstate__", requestPerfMetricsSetstate); + + nb::class_(m, "AdditionalOutput") + .def("__init__ ", + [](tle::AdditionalOutput const& self, std::string const& name, tle::Tensor const& output) + { return std::make_unique(name, output); }) + .def_rw("name", &tle::AdditionalOutput::name) + .def_rw("output", &tle::AdditionalOutput::output); + + auto resultSetstate = [](tle::Result& result, nb::tuple const& state) + { + if (state.size() != 13) + { + throw std::runtime_error("Invalid Request state!"); + } + new (&result) tle::Result(); + result.isFinal = nb::cast(state[0]); + result.outputTokenIds = nb::cast>(state[1]); + result.cumLogProbs = nb::cast>>(state[2]); + result.logProbs = nb::cast>>>(state[3]); + result.contextLogits = nb::cast>(state[4]); + result.generationLogits = nb::cast>(state[5]); + result.encoderOutput = nb::cast>(state[6]); + result.finishReasons = nb::cast>(state[7]); + result.sequenceIndex = nb::cast(state[8]); + result.isSequenceFinal = nb::cast(state[9]); + result.decodingIter = nb::cast(state[10]); + result.contextPhaseParams = nb::cast>(state[11]); + result.requestPerfMetrics = nb::cast>(state[12]); + }; + + auto resultGetstate = [](tle::Result const& self) + { + return nb::make_tuple(self.isFinal, self.outputTokenIds, self.cumLogProbs, self.logProbs, self.contextLogits, + self.generationLogits, self.encoderOutput, self.finishReasons, self.sequenceIndex, self.isSequenceFinal, + self.decodingIter, self.contextPhaseParams, self.requestPerfMetrics); + }; + + nb::class_(m, "Result") + .def(nb::init<>()) + .def_rw("is_final", &tle::Result::isFinal) + .def_rw("output_token_ids", &tle::Result::outputTokenIds) + .def_rw("cum_log_probs", &tle::Result::cumLogProbs) + .def_rw("log_probs", &tle::Result::logProbs) + .def_rw("context_logits", &tle::Result::contextLogits) + .def_rw("generation_logits", &tle::Result::generationLogits) + .def_rw("spec_dec_fast_logits_info", &tle::Result::specDecFastLogitsInfo) + .def_rw("encoder_output", &tle::Result::encoderOutput) + .def_rw("finish_reasons", &tle::Result::finishReasons) + .def_rw("sequence_index", &tle::Result::sequenceIndex) + .def_rw("is_sequence_final", &tle::Result::isSequenceFinal) + .def_rw("decoding_iter", &tle::Result::decodingIter) + .def_rw("context_phase_params", &tle::Result::contextPhaseParams) + .def_rw("request_perf_metrics", &tle::Result::requestPerfMetrics) + .def_rw("additional_outputs", &tle::Result::additionalOutputs) + .def("__getstate__", resultGetstate) + .def("__setstate__", resultSetstate); + + m.def("deserialize_result", + [](nb::bytes& x) + { + std::string str(x.c_str(), x.size()); + std::istringstream is(str); + return tle::serialize_utils::deserialize(is); + }); + + auto responseGetstate = [](tle::Response const& self) + { return nb::make_tuple(self.getRequestId(), self.getResult(), self.getClientId()); }; + + auto responseSetstate = [](tle::Response& response, nb::tuple const& state) + { + if (state.size() != 3) + { + throw std::runtime_error("Invalid Request state!"); + } + new (&response) tle::Response( + 
nb::cast(state[0]), nb::cast(state[1]), nb::cast(state[2])); + }; + + nb::class_(m, "Response") + .def(nb::init>(), nb::arg("request_id"), nb::arg("error_msg"), + nb::arg("client_id") = std::nullopt) + .def(nb::init>(), nb::arg("request_id"), nb::arg("result"), + nb::arg("client_id") = std::nullopt) + .def_prop_ro("request_id", &tle::Response::getRequestId) + .def_prop_ro("client_id", &tle::Response::getClientId) + .def("has_error", &tle::Response::hasError) + .def_prop_ro("error_msg", &tle::Response::getErrorMsg) + .def_prop_ro("result", &tle::Response::getResult) + .def("clear_context_logits", + [](tle::Response& self) + { + if (!self.hasError()) + { + auto& result = const_cast(self.getResult()); + result.contextLogits.reset(); + } + }) + .def("clear_generation_logits", + [](tle::Response& self) + { + if (!self.hasError()) + { + auto& result = const_cast(self.getResult()); + result.generationLogits.reset(); + } + }) + .def("__getstate__", responseGetstate) + .def("__setstate__", responseSetstate); +} + +} // namespace tensorrt_llm::nanobind::executor diff --git a/cpp/tensorrt_llm/nanobind/executor/request.h b/cpp/tensorrt_llm/nanobind/executor/request.h new file mode 100644 index 00000000000..5a5cf9acbee --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/executor/request.h @@ -0,0 +1,29 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +namespace nb = nanobind; + +namespace tensorrt_llm::nanobind::executor +{ + +// Register bindings for executor API. +void initRequestBindings(nb::module_& m); + +} // namespace tensorrt_llm::nanobind::executor diff --git a/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp b/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp new file mode 100644 index 00000000000..f3be85bbbf2 --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp @@ -0,0 +1,388 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "bindings.h" +#include "moeBindings.h" +#include "tensorrt_llm/kernels/communicationKernels/allReduceWorkspace.h" +#include "tensorrt_llm/kernels/communicationKernels/customLowPrecisionAllReduceKernels.h" +#include "tensorrt_llm/kernels/customAllReduceKernels.h" +#include "tensorrt_llm/kernels/delayStream.h" +#include "tensorrt_llm/nanobind/common/customCasters.h" +#include "tensorrt_llm/runtime/cudaEvent.h" +#include "tensorrt_llm/runtime/cudaStream.h" +#include "tensorrt_llm/runtime/decoderState.h" +#include "tensorrt_llm/runtime/decodingInput.h" +#include "tensorrt_llm/runtime/decodingOutput.h" +#include "tensorrt_llm/runtime/gptDecoder.h" +#include "tensorrt_llm/runtime/gptDecoderBatched.h" +#include "tensorrt_llm/runtime/iBuffer.h" +#include "tensorrt_llm/runtime/iGptDecoderBatched.h" +#include "tensorrt_llm/runtime/iTensor.h" +#include "tensorrt_llm/runtime/ipcUtils.h" +#include "tensorrt_llm/runtime/lookaheadBuffers.h" +#include "tensorrt_llm/runtime/loraCache.h" +#include "tensorrt_llm/runtime/mcastGPUBuffer.h" +#include "tensorrt_llm/runtime/request.h" +#include "tensorrt_llm/runtime/speculativeDecodingMode.h" +#include "tensorrt_llm/runtime/tllmRuntime.h" +#include "tensorrt_llm/runtime/torchView.h" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace tr = tensorrt_llm::runtime; +namespace te = tensorrt_llm::executor; + +class PyIGptDecoder : public tr::IGptDecoder +{ +public: + NB_TRAMPOLINE(tr::IGptDecoder, 5); + + void setup(tr::SamplingConfig const& samplingConfig, size_t batchSize, + tr::DecodingInput::TensorConstPtr const& batchSlots, + std::optional const& output = std::nullopt, + std::optional explicitDraftTokensDType = std::nullopt, + std::optional> const& lookaheadPrompt = std::nullopt, + std::optional> const& lookaheadAlgoConfigs = std::nullopt) override + { + NB_OVERRIDE_PURE(setup, samplingConfig, batchSize, batchSlots, output, explicitDraftTokensDType, + lookaheadPrompt, lookaheadAlgoConfigs); + } + + void forwardAsync(tr::DecodingOutput& output, tr::DecodingInput const& input) override + { + NB_OVERRIDE_PURE(forwardAsync, output, input); + } + + void forwardSync(tr::DecodingOutput& output, tr::DecodingInput const& input) override + { + NB_OVERRIDE_PURE(forwardSync, output, input); + } + + tr::SamplingConfig const& getSamplingConfig() override + { + NB_OVERRIDE_PURE(getSamplingConfig); + } + + void disableLookahead(std::optional const& samplingConfig, tr::SizeType32 batchSize, + tr::DecodingInput::TensorConstPtr batchSlots) override + { + NB_OVERRIDE_PURE(disableLookahead, samplingConfig, batchSize, batchSlots); + } +}; + +namespace tensorrt_llm::nanobind::runtime +{ + +void initBindings(nb::module_& m) +{ + + nb::class_(m, "TaskLayerModuleConfig") + .def(nb::init<>()) + .def_rw("page_id", &tr::LoraCache::TaskLayerModuleConfig::pageId) + .def_rw("slot_idx", &tr::LoraCache::TaskLayerModuleConfig::slotIdx) + .def_rw("in_size", &tr::LoraCache::TaskLayerModuleConfig::inSize) + .def_rw("out_size", &tr::LoraCache::TaskLayerModuleConfig::outSize) + .def_rw("module_id", &tr::LoraCache::TaskLayerModuleConfig::moduleId) + .def_rw("layer_id", &tr::LoraCache::TaskLayerModuleConfig::layerId) + .def_rw("adapter_size", &tr::LoraCache::TaskLayerModuleConfig::adapterSize) + .def_rw("num_slots", &tr::LoraCache::TaskLayerModuleConfig::numSlots) + .def_rw("weights_in_pointer", &tr::LoraCache::TaskLayerModuleConfig::weightsInPointer) + .def_rw("weights_out_pointer", 
&tr::LoraCache::TaskLayerModuleConfig::weightsOutPointer) + .def_rw("scaling_vec_pointer", &tr::LoraCache::TaskLayerModuleConfig::scalingVecPointer) + .def(nb::self == nb::self); + + nb::class_(m, "BufferManager") + .def(nb::init(), nb::arg("stream"), nb::arg("trim_pool") = false) + .def_prop_ro("stream", &tr::BufferManager::getStream); + + nb::class_(m, "TllmRuntime") + .def( + "__init__", + [](tr::TllmRuntime* self, std::filesystem::path engine_path, float gpu_weights_percent = 1.0f, + bool use_shape_inference = true) + { + // Using default logger by passing nullptr + new (self) + tr::TllmRuntime(tr::RawEngine(engine_path), nullptr, gpu_weights_percent, use_shape_inference); + }, + nb::arg("engine_path"), nb::arg("gpu_weights_percent") = 1.0f, nb::arg("use_shape_inference") = true) + .def( + "__init__", + [](tr::TllmRuntime* self, nb::ndarray engine_buffer, float gpu_weights_percent = 1.0f, + bool use_shape_inference = true) + { + if (engine_buffer.ndim() != 1) + throw std::runtime_error("Expected 1-D array for engine buffer"); + new (self) tr::TllmRuntime(tr::RawEngine(engine_buffer.data(), engine_buffer.size()), nullptr, + gpu_weights_percent, use_shape_inference); + }, + nb::arg("engine_buffer"), nb::arg("gpu_weights_percent") = 1.0f, nb::arg("use_shape_inference") = true) + .def_prop_ro("num_contexts", &tr::TllmRuntime::getNbContexts) + .def_prop_ro("num_profiles", &tr::TllmRuntime::getNbProfiles) + .def("get_opt_profile_id", &tr::TllmRuntime::getOptProfileId, nb::arg("num_tokens"), nb::arg("split_points")) + .def("clear_contexts", &tr::TllmRuntime::clearContexts) + .def("execute_context", &tr::TllmRuntime::executeContext, nb::arg("context_id")) + .def_prop_ro("stream_ptr", &tr::TllmRuntime::getStreamPtr) + .def_prop_ro("buffer_manager", + static_cast(&tr::TllmRuntime::getBufferManager)) + .def("set_layer_profiler", &tr::TllmRuntime::setLayerProfiler) + .def("has_layer_profiler", &tr::TllmRuntime::hasLayerProfiler, nb::arg("context_id")) + .def_prop_ro("layer_profiler_info", &tr::TllmRuntime::getLayerProfileInfo) + .def("report_to_profiler", &tr::TllmRuntime::reportToProfiler, nb::arg("context_id")) + .def_prop_ro("logits_dtype_from_engine", + [](tr::TllmRuntime& self) { return self.getEngine().getTensorDataType("logits"); }); + + nb::class_(m, "Request") + .def(nb::init, + std::optional>(), + nb::arg("ids"), nb::arg("input_len"), nb::arg("max_new_tokens") = std::nullopt, + nb::arg("end_id") = std::nullopt) + .def_rw("ids", &tr::decoder_batch::Request::ids) + .def_rw("input_len", &tr::decoder_batch::Request::inputLen) + .def_rw("max_new_tokens", &tr::decoder_batch::Request::maxNewTokens) + .def_rw("end_id", &tr::decoder_batch::Request::endId) + .def_rw("draft_logits", &tr::decoder_batch::Request::draftLogits) + .def_rw("embedding_bias", &tr::decoder_batch::Request::embeddingBias) + .def_rw("bad_words_list", &tr::decoder_batch::Request::badWordsList) + .def_rw("stop_words_list", &tr::decoder_batch::Request::stopWordsList) + .def_rw("generated_tokens_per_engine_step", &tr::decoder_batch::Request::generatedTokensPerEngineStep) + .def_rw("medusa_paths", &tr::decoder_batch::Request::medusaPaths) + .def_rw("medusa_tree_ids", &tr::decoder_batch::Request::medusaTreeIds) + .def_rw("lookahead_runtime_config", &tr::decoder_batch::Request::lookaheadRuntimeConfig); + nb::bind_vector>(m, "RequestVector"); + + nb::class_(m, "DecoderBatchInput") + .def(nb::init>, tr::SizeType32>(), nb::arg("logits"), + nb::arg("max_decoding_engine_tokens")) + .def(nb::init>(), nb::arg("logits")) + .def_rw("logits", 
&tr::decoder_batch::Input::logits) + .def_rw("max_decoder_steps", &tr::decoder_batch::Input::maxDecoderSteps) + .def_rw("batch_slots", &tr::decoder_batch::Input::batchSlots); + + nb::class_(m, "LookaheadDecodingBuffers") + .def(nb::init(), nb::arg("max_num_sequences"), + nb::arg("max_tokens_per_step"), nb::arg("buffer_manager")) + .def_rw("generation_lengths", &tr::LookaheadDecodingBuffers::generationLengths) + .def_rw("position_offsets", &tr::LookaheadDecodingBuffers::positionOffsets) + .def_rw("packed_masks", &tr::LookaheadDecodingBuffers::packedMasks) + .def_rw("position_ids", &tr::LookaheadDecodingBuffers::positionIds); + + nb::class_(m, "ExplicitDraftTokensBuffersInputs") + .def("create", &tr::ExplicitDraftTokensBuffers::Inputs::create, nb::arg("max_num_sequences"), + nb::arg("runtime"), nb::arg("model_config"), nb::arg("world_config")) + .def_rw("temperatures", &tr::ExplicitDraftTokensBuffers::Inputs::temperatures) + .def_rw("position_ids_base", &tr::ExplicitDraftTokensBuffers::Inputs::positionIdsBase) + .def_rw("generation_lengths", &tr::ExplicitDraftTokensBuffers::Inputs::generationLengths) + .def_rw("random_data_sample", &tr::ExplicitDraftTokensBuffers::Inputs::randomDataSample) + .def_rw("random_data_validation", &tr::ExplicitDraftTokensBuffers::Inputs::randomDataValidation) + .def_rw("draft_tokens", &tr::ExplicitDraftTokensBuffers::Inputs::draftTokens) + .def_rw("draft_indices", &tr::ExplicitDraftTokensBuffers::Inputs::draftIndices) + .def_rw("draft_probs", &tr::ExplicitDraftTokensBuffers::Inputs::draftProbs) + .def_rw("packed_masks", &tr::ExplicitDraftTokensBuffers::Inputs::packedMasks) + .def_rw("position_ids", &tr::ExplicitDraftTokensBuffers::Inputs::positionIds) + .def_rw("max_gen_length_host", &tr::ExplicitDraftTokensBuffers::Inputs::maxGenLengthHost) + .def_rw("generation_lengths_host", &tr::ExplicitDraftTokensBuffers::Inputs::generationLengthsHost); + + nb::class_(m, "DecodingInput"); + nb::class_(m, "DecodingOutput"); + + nb::class_(m, "CudaEvent") + .def(nb::init(), nb::arg("flags") = cudaEventDisableTiming) + .def("synchronize", &tr::CudaEvent::synchronize); + + nb::class_(m, "IGptDecoder") + .def( + "setup", + [](tr::IGptDecoder& self, tr::SamplingConfig const& samplingConfig, size_t batchSize, + at::Tensor const& batchSlots, std::optional const& output = std::nullopt, + std::optional explicitDraftTokensDType = std::nullopt, + std::optional> const& lookaheadPrompt = std::nullopt, + std::optional> const& lookaheadAlgoConfigs = std::nullopt) + { + auto tensorPtrBatchSlots = tr::TorchView::of(batchSlots); + self.setup(samplingConfig, batchSize, std::move(tensorPtrBatchSlots), output, explicitDraftTokensDType, + lookaheadPrompt, lookaheadAlgoConfigs); + }, + nb::arg("sampling_config"), nb::arg("batch_size"), nb::arg("batch_slots"), nb::arg("output") = std::nullopt, + nb::arg("explicit_draft_tokens_d_type") = std::nullopt, nb::arg("lookahead_prompt") = std::nullopt, + nb::arg("lookahead_algo_configs") = std::nullopt); + + nb::class_(m, "DecoderState") + .def(nb::init<>()) + .def("setup", &tr::decoder::DecoderState::setup, nb::arg("max_batch_size"), nb::arg("max_beam_width"), + nb::arg("max_attention_window"), nb::arg("sink_token_length"), nb::arg("max_sequence_length"), + nb::arg("dtype"), nb::arg("model_config"), nb::arg("world_config"), nb::arg("buffer_manager")) + .def("setup_cache_indirection", &tr::decoder::DecoderState::setupCacheIndirection, nb::arg("max_batch_size"), + nb::arg("max_beam_width"), nb::arg("max_attention_window"), nb::arg("buffer_manager")) + 
.def("setup_speculative_decoding", &tr::decoder::DecoderState::setupSpeculativeDecoding, + nb::arg("speculative_decoding_mode"), nb::arg("max_tokens_per_engine_step"), nb::arg("dtype"), + nb::arg("model_config"), nb::arg("world_config"), nb::arg("buffer_manager")) + .def_prop_ro("joint_decoding_input", &tr::decoder::DecoderState::getJointDecodingInput) + .def_prop_ro("joint_decoding_output", &tr::decoder::DecoderState::getJointDecodingOutput) + .def_prop_ro("cache_indirection_input", &tr::decoder::DecoderState::getCacheIndirectionInput) + .def_prop_ro("cache_indirection_output", &tr::decoder::DecoderState::getCacheIndirectionOutput) + .def_prop_ro( + "sequence_lengths", nb::overload_cast<>(&tr::decoder::DecoderState::getSequenceLengths, nb::const_)) + .def("get_sequence_lengths", + nb::overload_cast(&tr::decoder::DecoderState::getSequenceLengths, nb::const_), + nb::arg("batch_idx")) + .def_prop_ro("all_new_tokens", &tr::decoder::DecoderState::getAllNewTokens) + .def_prop_ro("finished_sum", &tr::decoder::DecoderState::getFinishedSum) + .def_prop_ro("finish_reasons", &tr::decoder::DecoderState::getFinishReasons) + .def_prop_ro("ids", nb::overload_cast<>(&tr::decoder::DecoderState::getIds, nb::const_)) + .def("get_ids", nb::overload_cast(&tr::decoder::DecoderState::getIds, nb::const_), + nb::arg("batch_idx")) + .def_prop_ro("gathered_ids", nb::overload_cast<>(&tr::decoder::DecoderState::getGatheredIds, nb::const_)) + .def("get_gathered_ids", + nb::overload_cast(&tr::decoder::DecoderState::getGatheredIds, nb::const_), + nb::arg("batch_idx")) + .def_prop_ro("parent_ids", &tr::decoder::DecoderState::getParentIds) + .def_prop_ro("cum_log_probs", nb::overload_cast<>(&tr::decoder::DecoderState::getCumLogProbs, nb::const_)) + .def("get_cum_log_probs", + nb::overload_cast(&tr::decoder::DecoderState::getCumLogProbs, nb::const_), + nb::arg("batch_idx")) + .def_prop_ro("log_probs", nb::overload_cast<>(&tr::decoder::DecoderState::getLogProbs, nb::const_)) + .def("get_log_probs", nb::overload_cast(&tr::decoder::DecoderState::getLogProbs, nb::const_), + nb::arg("batch_idx")) + .def_prop_ro("next_draft_tokens", &tr::decoder::DecoderState::getNextDraftTokens) + .def_prop_ro("prev_draft_tokens_lengths", &tr::decoder::DecoderState::getPrevDraftTokensLengths) + .def_prop_ro("next_draft_tokens_lengths", &tr::decoder::DecoderState::getNextDraftTokensLengths) + .def_prop_ro("accepted_lengths_cum_sum", &tr::decoder::DecoderState::getAcceptedLengthsCumSum) + .def_prop_ro("accepted_packed_paths", &tr::decoder::DecoderState::getAcceptedPackedPaths) + .def_prop_ro("finished_steps", &tr::decoder::DecoderState::getFinishedSteps) + .def_prop_ro("max_beam_width", &tr::decoder::DecoderState::getMaxBeamWidth) + .def_prop_ro("max_sequence_length", &tr::decoder::DecoderState::getMaxSequenceLength) + .def_prop_ro("max_decoding_decoder_tokens", &tr::decoder::DecoderState::getMaxDecodingDecoderTokens) + .def_prop_ro("max_decoding_engine_tokens", &tr::decoder::DecoderState::getMaxDecodingEngineTokens) + .def_prop_ro("num_decoding_engine_tokens", + nb::overload_cast<>(&tr::decoder::DecoderState::getNumDecodingEngineTokens, nb::const_)) + .def("get_num_decoding_engine_tokens", + nb::overload_cast(&tr::decoder::DecoderState::getNumDecodingEngineTokens, nb::const_), + nb::arg("batch_idx")) + .def("set_num_decoding_engine_tokens", &tr::decoder::DecoderState::setNumDecodingEngineTokens, + nb::arg("batch_idx"), nb::arg("num_tokens")) + .def_prop_ro("speculative_decoding_mode", &tr::decoder::DecoderState::getSpeculativeDecodingMode) + 
.def_prop_rw("generation_steps", &tr::decoder::DecoderState::getGenerationSteps, + &tr::decoder::DecoderState::setGenerationSteps); + + nb::class_(m, "GptDecoderBatched") + .def(nb::init(), nb::arg("stream")) + .def("setup", &tr::GptDecoderBatched::setup, nb::arg("mode"), nb::arg("max_batch_size"), + nb::arg("max_beam_width"), nb::arg("dtype"), nb::arg("model_config"), nb::arg("world_config")) + .def("forward_async", &tr::GptDecoderBatched::forwardAsync, nb::arg("output"), nb::arg("input")) + .def("underlying_decoder", &tr::GptDecoderBatched::getUnderlyingDecoder, nb::rv_policy::reference) + .def("finalize", &tr::GptDecoderBatched::finalize, nb::arg("decoder_state"), nb::arg("batch_idx"), + nb::arg("sampling_config"), nb::arg("streaming")) + .def_prop_ro( + "decoder_stream", + [](tr::GptDecoderBatched& self) -> tr::CudaStream const& { return *self.getDecoderStream(); }, + nb::rv_policy::reference); + + m.def( + "lamport_initialize_all", + [](intptr_t buffer_0, intptr_t buffer_1, intptr_t buffer_2, size_t size) + { + tr::lamportInitializeAll(reinterpret_cast(buffer_0), reinterpret_cast(buffer_1), + reinterpret_cast(buffer_2), size); + }, + "Lamport initialize all buffers"); + m.def( + "lamport_initialize", + [](intptr_t buffer, size_t size) + { tensorrt_llm::kernels::ar_fusion::lamport_initialize(reinterpret_cast(buffer), size, 0); }, + "Lmaport initialize buffer"); + m.def( + "delay_kernel", + [](int64_t delay_micro_secs, nb::object py_stream) + { + // Get the raw stream handle from PyTorch stream object + auto stream_ptr = nb::cast(py_stream.attr("cuda_stream")); + cudaStream_t stream = reinterpret_cast(stream_ptr); + tensorrt_llm::kernels::invokeDelayStreamKernel(delay_micro_secs, stream); + }, + "Delay kernel launch on the default stream"); + m.def( + "max_workspace_size_lowprecision", + [](int32_t tp_size) { return tensorrt_llm::kernels::max_workspace_size_lowprecision(tp_size); }, + "Calculate the maximum workspace size needed for low precision all-reduce operations"); + + nb::class_(m, "McastGPUBuffer") + .def(nb::init()) + .def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer) + .def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer); + + nb::enum_(m, "AllReduceFusionOp") + .value("NONE", tensorrt_llm::kernels::AllReduceFusionOp::NONE) + .value("RESIDUAL_RMS_NORM", tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_NORM) + .value("LAST_PROCESS_FOR_UB", tensorrt_llm::kernels::AllReduceFusionOp::LAST_PROCESS_FOR_UB) + .value("RESIDUAL_RMS_PREPOST_NORM", tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_PREPOST_NORM) + .value("RESIDUAL_RMS_NORM_QUANT_FP8", tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_FP8) + .value("RESIDUAL_RMS_NORM_QUANT_NVFP4", tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4) + .value("RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4", + tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4) + .value("RESIDUAL_RMS_NORM_OUT_QUANT_FP8", + tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_FP8); + + nb::enum_(m, "AllReduceStrategy") + .value("NCCL", tensorrt_llm::kernels::AllReduceStrategyType::NCCL) + .value("MIN_LATENCY", tensorrt_llm::kernels::AllReduceStrategyType::MIN_LATENCY) + .value("AUTO", tensorrt_llm::kernels::AllReduceStrategyType::AUTO) + .value("UB", tensorrt_llm::kernels::AllReduceStrategyType::UB) + .value("ONESHOT", tensorrt_llm::kernels::AllReduceStrategyType::ONESHOT) + .value("TWOSHOT", 
tensorrt_llm::kernels::AllReduceStrategyType::TWOSHOT); + + // Initialize MoeLoadBalancer bindings + initMoeBindings(m); +} + +void initBindingsEarly(nb::module_& m) +{ + nb::class_(m, "SpeculativeDecodingMode") + .def(nb::init(), nb::arg("state")) + .def_static("NoneType", &tr::SpeculativeDecodingMode::None) + .def_static("DraftTokensExternal", &tr::SpeculativeDecodingMode::DraftTokensExternal) + .def_static("Medusa", &tr::SpeculativeDecodingMode::Medusa) + .def_static("Eagle", &tr::SpeculativeDecodingMode::Eagle) + .def_static("LookaheadDecoding", &tr::SpeculativeDecodingMode::LookaheadDecoding) + .def_static("ExplicitDraftTokens", &tr::SpeculativeDecodingMode::ExplicitDraftTokens) + .def_prop_ro("is_none", &tr::SpeculativeDecodingMode::isNone) + .def_prop_ro("is_draft_tokens_external", &tr::SpeculativeDecodingMode::isDraftTokensExternal) + .def_prop_ro("is_medusa", &tr::SpeculativeDecodingMode::isMedusa) + .def_prop_ro("is_eagle", &tr::SpeculativeDecodingMode::isEagle) + .def_prop_ro("is_lookahead_decoding", &tr::SpeculativeDecodingMode::isLookaheadDecoding) + .def_prop_ro("is_explicit_draft_tokens", &tr::SpeculativeDecodingMode::isExplicitDraftTokens) + .def_prop_ro("updates_position_ids", &tr::SpeculativeDecodingMode::updatesPositionIds) + .def_prop_ro("requires_attention_mask", &tr::SpeculativeDecodingMode::requiresAttentionMask) + .def_prop_ro("predicts_draft_tokens", &tr::SpeculativeDecodingMode::predictsDraftTokens) + .def_prop_ro("needs_kv_cache_rewind", &tr::SpeculativeDecodingMode::needsKVCacheRewind) + .def_prop_ro("variable_draft_length", &tr::SpeculativeDecodingMode::variableDraftLength) + .def_prop_ro("has_draft_logits", &tr::SpeculativeDecodingMode::hasDraftLogits) + .def_prop_ro("needs_decoder_prologue", &tr::SpeculativeDecodingMode::needsDecoderPrologue); +} +} // namespace tensorrt_llm::nanobind::runtime diff --git a/cpp/tensorrt_llm/nanobind/runtime/bindings.h b/cpp/tensorrt_llm/nanobind/runtime/bindings.h new file mode 100644 index 00000000000..410dac80b05 --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/runtime/bindings.h @@ -0,0 +1,30 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace nb = nanobind; + +namespace tensorrt_llm::nanobind::runtime +{ + +void initBindings(nb::module_& m); +void initBindingsEarly(nb::module_& m); + +} // namespace tensorrt_llm::nanobind::runtime diff --git a/cpp/tensorrt_llm/nanobind/runtime/moeBindings.cpp b/cpp/tensorrt_llm/nanobind/runtime/moeBindings.cpp new file mode 100644 index 00000000000..c26fa84b661 --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/runtime/moeBindings.cpp @@ -0,0 +1,124 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "moeBindings.h" +#include "tensorrt_llm/nanobind/common/customCasters.h" +#include "tensorrt_llm/runtime/moeLoadBalancer/hostAccessibleDeviceAllocator.h" +#include "tensorrt_llm/runtime/moeLoadBalancer/moeLoadBalancer.h" +#include +#include +#include + +namespace nb = nanobind; +namespace tr = tensorrt_llm::runtime; +namespace tk = tensorrt_llm::kernels; + +namespace tensorrt_llm::nanobind::runtime +{ + +void pyDoReplication(tk::MoeLoadBalanceMetaInfo const& metaInfo, std::vector& expertLoadFactor, + tr::MoePlacementCpuInfo* cpuPlacement) +{ + TLLM_CHECK_WITH_INFO( + metaInfo.expertCount == expertLoadFactor.size(), "expert_count and expert_load_factor size mismatch"); + tr::doReplication(metaInfo, expertLoadFactor.data(), cpuPlacement); +}; + +void pyDoPlacement(tk::MoeLoadBalanceMetaInfo const& metaInfo, std::vector& expertLoadFactor, + tr::MoePlacementCpuInfo* cpuPlacement) +{ + TLLM_CHECK_WITH_INFO( + metaInfo.expertCount == expertLoadFactor.size(), "expert_count and expert_load_factor size mismatch"); + tr::doPlacement(metaInfo, expertLoadFactor.data(), cpuPlacement); +}; + +void initMoeBindings(nb::module_& m) +{ + // Bind MoeWeight struct + nb::class_(m, "MoeWeight") + .def(nb::init<>()) + .def_prop_rw("weight_ptr", &tr::MoeWeight::getWeightPtr, &tr::MoeWeight::setWeightPtr) + .def_rw("height", &tr::MoeWeight::mHeight) + .def_rw("width", &tr::MoeWeight::mWidth) + .def_rw("pitch", &tr::MoeWeight::mPitch) + .def("__repr__", + [](tr::MoeWeight const& self) + { + return ""; + }); + + // Bind MoeLoadBalanceMetaInfo struct + nb::class_(m, "MoeLoadBalanceMetaInfo") + .def(nb::init(), nb::arg("expert_count"), nb::arg("top_k"), nb::arg("ep_rank"), + nb::arg("ep_size"), nb::arg("slot_count_per_rank")) + .def_rw("expert_count", &tk::MoeLoadBalanceMetaInfo::expertCount) + .def_rw("top_k", &tk::MoeLoadBalanceMetaInfo::topK) + .def_rw("ep_rank", &tk::MoeLoadBalanceMetaInfo::epRank) + .def_rw("ep_size", &tk::MoeLoadBalanceMetaInfo::epSize) + .def_rw("slot_count_per_rank", &tk::MoeLoadBalanceMetaInfo::slotCountPerRank); + + // Bind MoePlacementCpuInfo struct + nb::class_(m, "MoePlacementCpuInfo") + .def(nb::init<>()) + .def_rw("expert_replica_count", &tr::MoePlacementCpuInfo::expertReplicaCount) + .def_rw("rank_expert_ids", &tr::MoePlacementCpuInfo::rankExpertIds); + + // Bind SingleLayerMoeLoadBalancer class + nb::class_(m, "SingleLayerMoeLoadBalancer") + .def("add_single_weight_slot", &tr::SingleLayerMoeLoadBalancer::addSingleWeightSlot, nb::arg("slot_id"), + nb::arg("name"), nb::arg("weight_slot"), "Add a single weight slot for a specific slot ID") + .def("add_single_host_weight", &tr::SingleLayerMoeLoadBalancer::addSingleHostWeight, nb::arg("expert_id"), + nb::arg("name"), nb::arg("host_weight"), "Add a single host weight for a specific expert ID") + .def("set_initial_weight_assignments", &tr::SingleLayerMoeLoadBalancer::setInitialWeightAssignments, + 
nb::arg("initial_weight_assignments"), "Set initial weight assignments for each slot") + .def("get_pointer", &tr::SingleLayerMoeLoadBalancer::getSelfPtr, + "Get the pointer of the SingleLayerMoeLoadBalancer") + .def("get_layer_id", &tr::SingleLayerMoeLoadBalancer::getLayerId, + "Get the layer id of the SingleLayerMoeLoadBalancer"); + + // Bind MoeLoadBalancer class + nb::class_(m, "MoeLoadBalancer") + .def(nb::init(), nb::arg("ep_rank"), nb::arg("ep_size"), nb::arg("layer_updates_per_iter"), + "Initialize the MoeLoadBalancer with the specified expert parallel rank, size, and update frequency") + .def("set_use_gpu_memcpy", &tr::MoeLoadBalancer::setUseGpuMemcpy, nb::arg("use_gpu_memcpy"), + "Set whether to use GPU memcpy for weight updates") + .def("add_layer", &tr::MoeLoadBalancer::AddLayer, nb::arg("expert_count"), nb::arg("top_k"), + nb::arg("slot_count_per_rank"), "Add a new MOE layer to the load balancer") + .def("finalize_model", &tr::MoeLoadBalancer::finalizeModel, + "Finalize the model structure, must be called after all layers are added") + .def("set_warm_up_iter_count", &tr::MoeLoadBalancer::setWarmUpIterCount, nb::arg("iter_count"), + "Set the number of warm-up iterations") + .def("start_iter", &tr::MoeLoadBalancer::startIter, nb::arg("iter_id"), nb::arg("enable_statistic"), + nb::arg("enable_update_weights"), "Start a new iteration with the given ID and settings") + .def("end_iter", &tr::MoeLoadBalancer::endIter, nb::arg("iter_id"), "End the iteration with the given ID") + .def("shutdown", &tr::MoeLoadBalancer::shutdown, "Shutdown the load balancer and clean up resources"); + + m.def("is_host_accessible_device_memory_supported", &tr::HostAccessibleDeviceAllocator::isSupported, + "If current system support host accessible device memory"); + + // Bind do_replication function for testing + m.def("do_replication", &pyDoReplication, nb::arg("meta_info"), nb::arg("expert_load_factor"), + nb::arg("cpu_placement"), "Do replication"); + + // Bind do_placement function for testing + m.def("do_placement", &pyDoPlacement, nb::arg("meta_info"), nb::arg("expert_load_factor"), nb::arg("cpu_placement"), + "Do placement"); +} + +} // namespace tensorrt_llm::nanobind::runtime diff --git a/cpp/tensorrt_llm/nanobind/runtime/moeBindings.h b/cpp/tensorrt_llm/nanobind/runtime/moeBindings.h new file mode 100644 index 00000000000..73b9a3ceec8 --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/runtime/moeBindings.h @@ -0,0 +1,29 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include + +namespace nb = nanobind; + +namespace tensorrt_llm::nanobind::runtime +{ + +void initMoeBindings(nb::module_& m); + +} // namespace tensorrt_llm::nanobind::runtime diff --git a/cpp/tensorrt_llm/nanobind/testing/modelSpecBinding.cpp b/cpp/tensorrt_llm/nanobind/testing/modelSpecBinding.cpp new file mode 100644 index 00000000000..caef94c5def --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/testing/modelSpecBinding.cpp @@ -0,0 +1,87 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "modelSpecBinding.h" +#include "tensorrt_llm/nanobind/common/customCasters.h" +#include "tensorrt_llm/testing/modelSpec.h" + +#include + +namespace nb = nanobind; +using tensorrt_llm::testing::ModelSpec; +using tensorrt_llm::testing::KVCacheType; +using tensorrt_llm::testing::QuantMethod; +using tensorrt_llm::testing::OutputContentType; + +namespace tensorrt_llm::nanobind::testing +{ + +void initBindings(nb::module_& m) +{ + nb::enum_(m, "QuantMethod", nb::is_arithmetic(), "Quantization Method") + .value("NONE", QuantMethod::kNONE, "No Quantization") + .value("SMOOTH_QUANT", QuantMethod::kSMOOTH_QUANT, "Smooth Quantization"); + + nb::enum_(m, "OutputContentType", nb::is_arithmetic(), "Output Content Type") + .value("NONE", OutputContentType::kNONE, "No Output Content") + .value("CONTEXT_LOGITS", OutputContentType::kCONTEXT_LOGITS, "Context Logits") + .value("GENERATION_LOGITS", OutputContentType::kGENERATION_LOGITS, "Generation Logits") + .value("LOG_PROBS", OutputContentType::kLOG_PROBS, "Log Probs") + .value("CUM_LOG_PROBS", OutputContentType::kCUM_LOG_PROBS, "Cumulative Log"); + + nb::class_(m, "ModelSpec") + .def(nb::init()) + .def("use_gpt_plugin", &ModelSpec::useGptAttentionPlugin, nb::rv_policy::reference_internal) + .def("use_packed_input", &ModelSpec::usePackedInput, nb::rv_policy::reference_internal) + .def("set_kv_cache_type", &ModelSpec::setKVCacheType, nb::rv_policy::reference_internal) + .def("use_decoder_per_request", &ModelSpec::useDecoderPerRequest, nb::rv_policy::reference_internal) + .def("use_tensor_parallelism", &ModelSpec::useTensorParallelism, nb::rv_policy::reference_internal) + .def("use_pipeline_parallelism", &ModelSpec::usePipelineParallelism, nb::rv_policy::reference_internal) + .def("use_context_parallelism", &ModelSpec::useContextParallelism, nb::rv_policy::reference_internal) + .def("set_draft_tokens", &ModelSpec::setDraftTokens, nb::rv_policy::reference_internal) + .def("use_accept_by_logits", &ModelSpec::useAcceptByLogits, nb::rv_policy::reference_internal) + .def("use_mamba_plugin", &ModelSpec::useMambaPlugin, nb::rv_policy::reference_internal) + .def("gather_logits", &ModelSpec::gatherLogits, nb::rv_policy::reference_internal) + .def("replace_logits", &ModelSpec::replaceLogits, nb::rv_policy::reference_internal) + .def("return_log_probs", &ModelSpec::returnLogProbs, 
nb::rv_policy::reference_internal) + .def("smoke_test", &ModelSpec::smokeTest, nb::rv_policy::reference_internal) + .def("use_medusa", &ModelSpec::useMedusa, nb::rv_policy::reference_internal) + .def("use_eagle", &ModelSpec::useEagle, nb::rv_policy::reference_internal) + .def("use_lookahead_decoding", &ModelSpec::useLookaheadDecoding, nb::rv_policy::reference_internal) + .def("use_explicit_draft_tokens_decoding", &ModelSpec::useExplicitDraftTokensDecoding, + nb::rv_policy::reference_internal) + .def("use_draft_tokens_external_decoding", &ModelSpec::useDraftTokensExternalDecoding, + nb::rv_policy::reference_internal) + .def("use_logits", &ModelSpec::useLogits) + .def("use_multiple_profiles", &ModelSpec::useMultipleProfiles, nb::rv_policy::reference_internal) + .def("set_max_input_length", &ModelSpec::setMaxInputLength, nb::rv_policy::reference_internal) + .def("set_max_output_length", &ModelSpec::setMaxOutputLength, nb::rv_policy::reference_internal) + .def("set_quant_method", &ModelSpec::setQuantMethod, nb::rv_policy::reference_internal) + .def("use_lora_plugin", &ModelSpec::useLoraPlugin, nb::rv_policy::reference_internal) + .def("get_input_file", &ModelSpec::getInputFile) + .def("get_model_path", &ModelSpec::getModelPath) + .def("get_results_file", &ModelSpec::getResultsFile) + .def("get_generation_logits_file", &ModelSpec::getGenerationLogitsFile) + .def("get_context_logits_file", &ModelSpec::getContextLogitsFile) + .def("get_cum_log_probs_file", &ModelSpec::getCumLogProbsFile) + .def("get_log_probs_file", &ModelSpec::getLogProbsFile) + .def("enable_context_fmha_fp32_acc", &ModelSpec::enableContextFMHAFp32Acc, nb::rv_policy::reference_internal) + .def("get_enable_context_fmha_fp32_acc", &ModelSpec::getEnableContextFMHAFp32Acc) + .def("__copy__", [](ModelSpec const& self) { return ModelSpec(self); }); +} + +} // namespace tensorrt_llm::nanobind::testing diff --git a/cpp/tensorrt_llm/nanobind/testing/modelSpecBinding.h b/cpp/tensorrt_llm/nanobind/testing/modelSpecBinding.h new file mode 100644 index 00000000000..1aababc6ff8 --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/testing/modelSpecBinding.h @@ -0,0 +1,29 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace nb = nanobind; + +namespace tensorrt_llm::nanobind::testing +{ + +void initBindings(nb::module_& m); + +} // namespace tensorrt_llm::nanobind::testing diff --git a/cpp/tensorrt_llm/nanobind/userbuffers/bindings.cpp b/cpp/tensorrt_llm/nanobind/userbuffers/bindings.cpp new file mode 100644 index 00000000000..82e0d0a1f0c --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/userbuffers/bindings.cpp @@ -0,0 +1,47 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "bindings.h" +#include "tensorrt_llm/kernels/userbuffers/ub_interface.h" +#include "tensorrt_llm/kernels/userbuffers/userbuffersManager.h" +#include "tensorrt_llm/nanobind/common/customCasters.h" +#include + +namespace nb = nanobind; +namespace tub = tensorrt_llm::runtime::ub; + +namespace tensorrt_llm::kernels::userbuffers +{ + +void UserBufferBindings::initBindings(nb::module_& m) +{ + nb::class_(m, "UBBuffer") + .def_ro("size", &tub::UBBuffer::size) + .def_prop_ro("addr", [](tub::UBBuffer& self) { return reinterpret_cast(self.addr); }) + .def_ro("handle", &tub::UBBuffer::handle) + .def("invalid", &tub::UBBuffer::invalid); + + m.def("ub_initialize", [](int tp_size) { tub::ub_initialize(tp_size); }); + m.def("ub_is_initialized", &tub::ub_is_initialized); + m.def("ub_allocate", [](size_t bytes) { return tub::ub_allocate(bytes); }); + m.def("ub_deallocate", [](intptr_t addr) { return tub::ub_deallocate(reinterpret_cast(addr)); }); + m.def("ub_get", &tub::ub_get); + m.def("ub_supported", &tub::ub_supported); + + m.def("initialize_userbuffers_manager", &tub::initialize_userbuffers_manager); +} +} // namespace tensorrt_llm::kernels::userbuffers diff --git a/cpp/tensorrt_llm/nanobind/userbuffers/bindings.h b/cpp/tensorrt_llm/nanobind/userbuffers/bindings.h new file mode 100644 index 00000000000..15728bf6c1d --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/userbuffers/bindings.h @@ -0,0 +1,30 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +namespace nb = nanobind; + +namespace tensorrt_llm::kernels::userbuffers +{ +class UserBufferBindings +{ +public: + static void initBindings(nb::module_& m); +}; +} // namespace tensorrt_llm::kernels::userbuffers diff --git a/cpp/tensorrt_llm/pybind/bindings.cpp b/cpp/tensorrt_llm/pybind/bindings.cpp index 1a5841d4b7a..962071c4857 100644 --- a/cpp/tensorrt_llm/pybind/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/bindings.cpp @@ -170,7 +170,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) .value("CONTINUOUS", tr::ModelConfig::KVCacheType::kCONTINUOUS) .value("PAGED", tr::ModelConfig::KVCacheType::kPAGED) .value("DISABLED", tr::ModelConfig::KVCacheType::kDISABLED) - .def(py::init(&tr::ModelConfig::KVCacheTypeFromString)); + .def("from_string", &tr::ModelConfig::KVCacheTypeFromString); py::enum_(m, "LayerType") .value("ATTENTION", tr::ModelConfig::LayerType::kATTENTION) diff --git a/cpp/tensorrt_llm/pybind/executor/bindings.cpp b/cpp/tensorrt_llm/pybind/executor/bindings.cpp index d09157e1a8b..a8f6aaef73d 100644 --- a/cpp/tensorrt_llm/pybind/executor/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/executor/bindings.cpp @@ -244,7 +244,17 @@ void initBindings(pybind11::module_& m) py::class_>( executor_kv_cache, "KVCacheEventManager") - .def("get_latest_events", &tle::KVCacheEventManager::getLatestEvents, py::arg("timeout") = std::nullopt); + .def( + "get_latest_events", + [](tle::KVCacheEventManager& self, std::optional timeout_ms = std::nullopt) + { + if (timeout_ms) + { + return self.getLatestEvents(std::chrono::milliseconds(static_cast(*timeout_ms))); + } + return self.getLatestEvents(std::nullopt); + }, + py::arg("timeout_ms") = std::nullopt); tensorrt_llm::pybind::executor::initRequestBindings(m); tensorrt_llm::pybind::executor::initConfigBindings(m); diff --git a/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp b/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp index 71a0b4af724..a413b8c9e67 100644 --- a/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp +++ b/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp @@ -336,7 +336,7 @@ void initConfigBindings(pybind11::module_& m) throw std::runtime_error("Invalid extendedRuntimePerfKnobConfig state!"); } return tle::ExtendedRuntimePerfKnobConfig( - state[0].cast(), state[1].cast(), state[2].cast(), state[2].cast()); + state[0].cast(), state[1].cast(), state[2].cast(), state[3].cast()); }; auto extendedRuntimePerfKnobConfigGetstate = [](tle::ExtendedRuntimePerfKnobConfig const& self) { diff --git a/examples/models/core/llama/summarize_long.py b/examples/models/core/llama/summarize_long.py index 9f127bc32a6..cee2e07fdd5 100644 --- a/examples/models/core/llama/summarize_long.py +++ b/examples/models/core/llama/summarize_long.py @@ -97,7 +97,7 @@ def TRTLLaMA(args, config): quantization_config = pretrained_config['quantization'] build_config = config['build_config'] - kv_cache_type = KVCacheType(build_config['kv_cache_type']) + kv_cache_type = KVCacheType.from_string(build_config['kv_cache_type']) plugin_config = build_config['plugin_config'] dtype = pretrained_config['dtype'] diff --git a/examples/models/core/qwen2audio/run.py b/examples/models/core/qwen2audio/run.py index e0d495a67f8..93e161c7e08 100644 --- a/examples/models/core/qwen2audio/run.py +++ b/examples/models/core/qwen2audio/run.py @@ -122,7 +122,8 @@ def get_model(self): num_kv_heads = config["pretrained_config"].get("num_key_value_heads", num_heads) if "kv_cache_type" in config["build_config"]: - kv_cache_type = 
KVCacheType(config["build_config"]["kv_cache_type"]) + kv_cache_type = KVCacheType.from_string( + config["build_config"]["kv_cache_type"]) else: kv_cache_type = KVCacheType.CONTINUOUS diff --git a/examples/models/core/qwenvl/run.py b/examples/models/core/qwenvl/run.py index a04c2b142e3..06ce341a9a0 100644 --- a/examples/models/core/qwenvl/run.py +++ b/examples/models/core/qwenvl/run.py @@ -118,7 +118,8 @@ def get_model(self): num_kv_heads = config["pretrained_config"].get("num_key_value_heads", num_heads) if "kv_cache_type" in config["build_config"]: - kv_cache_type = KVCacheType(config["build_config"]["kv_cache_type"]) + kv_cache_type = KVCacheType.from_string( + config["build_config"]["kv_cache_type"]) else: kv_cache_type = KVCacheType.CONTINUOUS diff --git a/jenkins/Build.groovy b/jenkins/Build.groovy index bb8fd7816ce..77e12ee5100 100644 --- a/jenkins/Build.groovy +++ b/jenkins/Build.groovy @@ -47,6 +47,12 @@ CONFIG_LINUX_AARCH64 = "linux_aarch64" @Field def CONFIG_LINUX_AARCH64_LLVM = "linux_aarch64_LLVM" +@Field +def CONFIG_LINUX_X86_64_NANOBIND = "linux_x86_64_Nanobind" + +@Field +def CONFIG_LINUX_AARCH64_NANOBIND = "linux_aarch64_Nanobind" + @Field def BUILD_CONFIGS = [ // Vanilla TARNAME is used for packaging in runLLMPackage @@ -56,6 +62,11 @@ def BUILD_CONFIGS = [ (TARNAME) : "TensorRT-LLM.tar.gz", (WHEEL_ARCHS): "80-real;86-real;89-real;90-real;100-real;120-real", ], + (CONFIG_LINUX_X86_64_NANOBIND) : [ + (WHEEL_EXTRA_ARGS) : "--binding_type nanobind --extra-cmake-vars ENABLE_MULTI_DEVICE=1 --extra-cmake-vars WARNING_IS_ERROR=ON --extra-cmake-vars NIXL_ROOT=/opt/nvidia/nvda_nixl --micro_benchmarks", + (TARNAME) : "nanobind-TensorRT-LLM.tar.gz", + (WHEEL_ARCHS): "80-real;86-real;89-real;90-real;100-real;120-real", + ], (CONFIG_LINUX_X86_64_SINGLE_DEVICE) : [ (WHEEL_EXTRA_ARGS) : "--extra-cmake-vars ENABLE_MULTI_DEVICE=0 --extra-cmake-vars WARNING_IS_ERROR=ON --extra-cmake-vars ENABLE_UCX=0 --micro_benchmarks", (TARNAME) : "single-device-TensorRT-LLM.tar.gz", @@ -71,6 +82,11 @@ def BUILD_CONFIGS = [ (TARNAME) : "TensorRT-LLM-GH200.tar.gz", (WHEEL_ARCHS): "90-real;100-real;120-real", ], + (CONFIG_LINUX_AARCH64_NANOBIND): [ + (WHEEL_EXTRA_ARGS) : "--binding_type nanobind --extra-cmake-vars WARNING_IS_ERROR=ON", + (TARNAME) : "nanobind-TensorRT-LLM-GH200.tar.gz", + (WHEEL_ARCHS): "90-real;100-real;120-real", + ], (CONFIG_LINUX_AARCH64_LLVM) : [ (WHEEL_EXTRA_ARGS) : "--extra-cmake-vars WARNING_IS_ERROR=ON -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_CUDA_HOST_COMPILER=clang -DCMAKE_LINKER_TYPE=LLD", (TARNAME) : "llvm-TensorRT-LLM-GH200.tar.gz", @@ -523,6 +539,8 @@ def launchStages(pipeline, cpu_arch, enableFailFast, globalVars) pipeline, cpu_arch == AARCH64_TRIPLE ? CONFIG_LINUX_AARCH64 : CONFIG_LINUX_X86_64_VANILLA), "Build TRT-LLM LLVM": [LLM_DOCKER_IMAGE] + prepareLLMBuild( pipeline, cpu_arch == AARCH64_TRIPLE ? CONFIG_LINUX_AARCH64_LLVM : CONFIG_LINUX_X86_64_LLVM), + "Build TRT-LLM Nanobind": [LLM_DOCKER_IMAGE] + prepareLLMBuild( + pipeline, cpu_arch == AARCH64_TRIPLE ? 
CONFIG_LINUX_AARCH64_NANOBIND : CONFIG_LINUX_X86_64_NANOBIND), ] if (cpu_arch == X86_64_TRIPLE) { diff --git a/jenkins/L0_Test.groovy b/jenkins/L0_Test.groovy index 6f6ae7c1186..35e7140ebda 100644 --- a/jenkins/L0_Test.groovy +++ b/jenkins/L0_Test.groovy @@ -64,6 +64,9 @@ def LLVM_CONFIG = "LLVM" @Field LINUX_AARCH64_CONFIG = "linux_aarch64" +@Field +def NANOBIND_CONFIG = "Nanobind" + @Field def BUILD_CONFIGS = [ // Vanilla TARNAME is used for packaging in runLLMPackage @@ -71,6 +74,7 @@ def BUILD_CONFIGS = [ (SINGLE_DEVICE_CONFIG) : [(TARNAME) : "single-device-TensorRT-LLM.tar.gz"], (LLVM_CONFIG) : [(TARNAME) : "llvm-TensorRT-LLM.tar.gz"], (LINUX_AARCH64_CONFIG) : [(TARNAME) : "TensorRT-LLM-GH200.tar.gz"], + (NANOBIND_CONFIG) : [(TARNAME) : "nanobind-TensorRT-LLM.tar.gz"], ] // TODO: Move common variables to an unified location @@ -1724,6 +1728,7 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null) "A10-TensorRT-4": ["a10", "l0_a10", 4, 6], "A10-TensorRT-5": ["a10", "l0_a10", 5, 6], "A10-TensorRT-6": ["a10", "l0_a10", 6, 6], + "A10-Nanobind": ["a10", "l0_a10_nanobind", 1, 1], "A30-Triton-1": ["a30", "l0_a30", 1, 1], "A30-PyTorch-1": ["a30", "l0_a30", 1, 2], "A30-PyTorch-2": ["a30", "l0_a30", 2, 2], @@ -1800,6 +1805,9 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null) if (key.contains("llvm")) { config = LLVM_CONFIG } + if (key.contains("Nanobind")) { + config = NANOBIND_CONFIG + } runLLMTestlistOnPlatform(pipeline, values[0], values[1], config, key.contains("Perf"), key, values[2], values[3]) }]]} fullSet = parallelJobs.keySet() diff --git a/tensorrt_llm/builder.py b/tensorrt_llm/builder.py index e2dc543ac42..11d528a853d 100644 --- a/tensorrt_llm/builder.py +++ b/tensorrt_llm/builder.py @@ -593,7 +593,7 @@ def from_dict(cls, config, plugin_config=None): defaults.get('max_prompt_embedding_table_size')) if "kv_cache_type" in config and config["kv_cache_type"] is not None: - kv_cache_type = KVCacheType(config.pop('kv_cache_type')) + kv_cache_type = KVCacheType.from_string(config.pop('kv_cache_type')) else: kv_cache_type = None gather_context_logits = config.pop( diff --git a/tensorrt_llm/commands/build.py b/tensorrt_llm/commands/build.py index a47e1485b71..e6b55f6e040 100644 --- a/tensorrt_llm/commands/build.py +++ b/tensorrt_llm/commands/build.py @@ -38,6 +38,23 @@ from tensorrt_llm.quantization.mode import QuantAlgo +def enum_type(enum_class): + + def parse_enum(value): + if isinstance(value, enum_class): + return value + + if isinstance(value, str): + return enum_class.from_string(value) + + valid_values = [e.name for e in enum_class] + raise argparse.ArgumentTypeError( + f"Invalid value '{value}' of type {type(value).__name__}. Expected one of {valid_values}" + ) + + return parse_enum + + def parse_arguments(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter) @@ -131,7 +148,7 @@ def parse_arguments(): parser.add_argument( '--kv_cache_type', default=argparse.SUPPRESS, - type=KVCacheType, + type=enum_type(KVCacheType), help= "Set KV cache type (continuous, paged, or disabled). For disabled case, KV cache is disabled and only context phase is allowed." 
) diff --git a/tensorrt_llm/runtime/model_runner.py b/tensorrt_llm/runtime/model_runner.py index 486c58f6d15..a9f0fe8de40 100644 --- a/tensorrt_llm/runtime/model_runner.py +++ b/tensorrt_llm/runtime/model_runner.py @@ -86,7 +86,7 @@ def _builder_to_model_config(config: dict) -> Tuple[ModelConfig, dict]: dtype = builder_config['precision'] tp_size = builder_config['tensor_parallel'] pp_size = builder_config.get('pipeline_parallel', 1) - kv_cache_type = KVCacheType(builder_config.get('kv_cache_type')) + kv_cache_type = KVCacheType.from_string(builder_config.get('kv_cache_type')) world_size = tp_size * pp_size assert world_size == mpi_world_size(), \ f'Engine world size ({tp_size} * {pp_size}) != Runtime world size ({mpi_world_size()})' diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index 2f63ab45f3a..5799ea27945 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -190,3 +190,18 @@ l0_a10: tests: - stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-MAX_UTILIZATION-pytorch-stress-test] - stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-GUARANTEED_NO_EVICT-pytorch-stress-test] +l0_a10_nanobind: +- condition: + ranges: + system_gpu_count: + gte: 1 + lte: 1 + wildcards: + gpu: + - '*a10*' + linux_distribution_name: ubuntu* + terms: + stage: pre_merge + backend: tensorrt + tests: + - unittest/bindings diff --git a/tests/unittest/bindings/test_bindings_ut.py b/tests/unittest/bindings/test_bindings_ut.py index 774accb080f..6fd46040b66 100644 --- a/tests/unittest/bindings/test_bindings_ut.py +++ b/tests/unittest/bindings/test_bindings_ut.py @@ -5,6 +5,7 @@ from pathlib import Path import numpy as np +import pytest import torch from utils.runtime_defaults import assert_runtime_defaults_are_parsed_correctly @@ -309,6 +310,8 @@ def parse_runtime_defaults(defaults_dict: dict | None = None): strict_keys=strict_keys) +@pytest.mark.skipif(_tb.binding_type == "nanobind", + reason="Test not supported for nanobind yet") def test_llm_request(): beam_width = 2 sampling_config = _tb.SamplingConfig(beam_width) @@ -418,6 +421,8 @@ def test_Mpicomm(): assert size2 == session_size +@pytest.mark.skipif(_tb.binding_type == "nanobind", + reason="Test not supported for nanobind yet") def test_SamplingConfig_pickle(): config = _tb.SamplingConfig() config.beam_width = 5 @@ -497,6 +502,8 @@ def test_KvCache_events_binding(): torch.cuda.empty_cache() +@pytest.mark.skipif(_tb.binding_type == "nanobind", + reason="Test not supported for nanobind yet") def test_ReqIdsSet_pickle(): ids = _tb.internal.batch_manager.ReqIdsSet() ids1 = pickle.loads(pickle.dumps(ids)) diff --git a/tests/unittest/bindings/test_executor_bindings.py b/tests/unittest/bindings/test_executor_bindings.py index 5d9460ffef0..56f71df079d 100644 --- a/tests/unittest/bindings/test_executor_bindings.py +++ b/tests/unittest/bindings/test_executor_bindings.py @@ -14,6 +14,7 @@ from binding_test_utils import * from pydantic import BaseModel +import tensorrt_llm.bindings as _tb import tensorrt_llm.bindings.executor as trtllm import tensorrt_llm.version as trtllm_version from tensorrt_llm.models.modeling_utils import PretrainedConfig @@ -484,6 +485,8 @@ def test_get_num_responses_ready(streaming: bool, assert executor.get_num_responses_ready() == num_expected_responses +@pytest.mark.skipif(_tb.binding_type == "nanobind", + 
reason="Test not supported for nanobind yet") @pytest.mark.parametrize("batching_type", [trtllm.BatchingType.INFLIGHT]) @pytest.mark.parametrize("streaming", [False, True]) @pytest.mark.parametrize("beam_width", [1]) @@ -688,6 +691,8 @@ def verify_output(beam_tokens, test_data, given_input_lengths): verify_output(tokens, test_data, given_input_lengths) +@pytest.mark.skipif(_tb.binding_type == "nanobind", + reason="Test not supported for nanobind yet") @pytest.mark.parametrize("streaming", [False, True]) @pytest.mark.parametrize("beam_width", [1]) def test_finish_reason(streaming: bool, beam_width: int, model_files, @@ -1112,6 +1117,8 @@ def test_spec_dec_fast_logits_info(): assert fast_logits_info.draft_participant_id == 5 +@pytest.mark.skipif(_tb.binding_type == "nanobind", + reason="Test not supported for nanobind yet") def test_result(): result = trtllm.Result() result.is_final = True @@ -1149,6 +1156,8 @@ def test_result(): assert (additional_output.output == torch.ones(1, 4, 100)).all() +@pytest.mark.skipif(_tb.binding_type == "nanobind", + reason="Test not supported for nanobind yet") def test_result_pickle(): result = trtllm.Result() result.is_final = True @@ -1495,6 +1504,8 @@ def test_eagle_config(): assert getattr(config, k) == v +@pytest.mark.skipif(_tb.binding_type == "nanobind", + reason="Test not supported for nanobind yet") def test_eagle_config_pickle(): config = trtllm.EagleConfig([[0, 0], [0, 1]], False, 0.5) config_copy = pickle.loads(pickle.dumps(config)) @@ -1867,6 +1878,8 @@ def logits_post_processor(req_id: int, logits: torch.Tensor, assert tokens[-max_tokens:] == [42] * max_tokens +@pytest.mark.skipif(_tb.binding_type == "nanobind", + reason="Test not supported for nanobind yet") def test_logits_post_processor_batched(model_files, model_path): # Define the logits post-processor callback @@ -2141,6 +2154,8 @@ def test_request_perf_metrics_kv_cache(model_path): assert kv_cache_metrics.kv_cache_hit_rate == 1.0 +@pytest.mark.skipif(_tb.binding_type == "nanobind", + reason="Test not supported for nanobind yet") @pytest.mark.parametrize("exclude_input_from_output", [False, True]) def test_request_perf_metrics_draft(model_path_draft_tokens_external, exclude_input_from_output: bool): @@ -2221,7 +2236,7 @@ def test_kv_event_stream_timeout(model_path): assert len(events) == 1 start = datetime.datetime.now() - events = cache_manager.get_latest_events(datetime.timedelta(seconds=1)) + events = cache_manager.get_latest_events(1000) end = datetime.datetime.now() # Make sure that it actually waited assert abs(end - start) > datetime.timedelta(milliseconds=900)