From c6cf993c996fbe8c75abfaa991139dc1ef0df412 Mon Sep 17 00:00:00 2001
From: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
Date: Mon, 21 Jul 2025 08:28:54 -0700
Subject: [PATCH] fix: bindings unit tests for nanobind

Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
---
 .../nanobind/batch_manager/bindings.cpp       |   2 +-
 .../nanobind/batch_manager/kvCacheManager.cpp |  13 +-
 cpp/tensorrt_llm/nanobind/bindings.cpp        |   9 +-
 cpp/tensorrt_llm/nanobind/common/bindTypes.h  |  39 +-----
 .../nanobind/common/customCasters.h           | 123 +++++-------------
 .../nanobind/executor/executor.cpp            |  64 ++++-----
 .../nanobind/executor/request.cpp             |  51 +++++---
 cpp/tensorrt_llm/pybind/bindings.cpp          |   5 +-
 tests/unittest/bindings/test_bindings_ut.py   |   8 --
 .../bindings/test_executor_bindings.py        |  57 +++++---
 10 files changed, 157 insertions(+), 214 deletions(-)

diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
index e4ba7b05382..fb0153f5ff8 100644
--- a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
+++ b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
@@ -79,7 +79,7 @@ void initBindings(nb::module_& m)
         }
     });

-    PybindUtils::bindSet<tb::ReqIdsSet>(m, "ReqIdsSet");
+    NanobindUtils::bindSet<tb::ReqIdsSet>(m, "ReqIdsSet");

     nb::enum_<tb::LlmRequestType>(m, "LlmRequestType")
         .value("LLMREQUEST_TYPE_CONTEXT_AND_GENERATION", tb::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION)
diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
index 6028db86ff9..74049eaf96b 100644
--- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
+++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
@@ -48,6 +48,9 @@
 using SizeType32 = tensorrt_llm::runtime::SizeType32;
 using TokenIdType = tensorrt_llm::runtime::TokenIdType;
 using VecTokens = std::vector<TokenIdType>;
 using CudaStreamPtr = std::shared_ptr<tensorrt_llm::runtime::CudaStream>;
+using CacheBlockIds = std::vector<std::vector<SizeType32>>;
+
+NB_MAKE_OPAQUE(CacheBlockIds);

 namespace
 {
@@ -424,7 +427,15 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
         .def("get_newly_allocated_block_ids", &BaseKVCacheManager::getNewlyAllocatedBlockIds)
         .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents);

-    nb::bind_vector<std::vector<std::vector<SizeType32>>>(m, "CacheBlockIds");
+    nb::bind_vector<CacheBlockIds>(m, "CacheBlockIds")
+        .def("__getstate__", [](CacheBlockIds const& v) { return nb::make_tuple(v); })
+        .def("__setstate__",
+            [](CacheBlockIds& self, nb::tuple const& t)
+            {
+                if (t.size() != 1)
+                    throw std::runtime_error("Invalid state!");
+                new (&self) CacheBlockIds(nb::cast<std::vector<std::vector<SizeType32>>>(t[0]));
+            });

     nb::enum_<tbk::CacheType>(m, "CacheType")
         .value("SELF", tbk::CacheType::kSELF)
diff --git a/cpp/tensorrt_llm/nanobind/bindings.cpp b/cpp/tensorrt_llm/nanobind/bindings.cpp
index 470ddeb546a..43a985658dd 100644
--- a/cpp/tensorrt_llm/nanobind/bindings.cpp
+++ b/cpp/tensorrt_llm/nanobind/bindings.cpp
@@ -359,9 +359,12 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
             config.earlyStopping, config.noRepeatNgramSize, config.numReturnSequences, config.minP,
             config.beamWidthArray);
     };
-    auto SamplingConfigSetState = [](tr::SamplingConfig& self, nb::tuple t) -> tr::SamplingConfig
+    auto SamplingConfigSetState = [](tr::SamplingConfig& self, nb::tuple t)
     {
-        assert(t.size() == 19);
+        if (t.size() != 19)
+        {
+            throw std::runtime_error("Invalid SamplingConfig state!");
+        }

         tr::SamplingConfig config;
         config.beamWidth = nb::cast<SizeType32>(t[0]);
@@ -384,7 +387,7 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
         config.minP = nb::cast<std::optional<std::vector<float>>>(t[17]);
         config.beamWidthArray = nb::cast<std::optional<std::vector<std::vector<SizeType32>>>>(t[18]);

-        return config;
+        new (&self) tr::SamplingConfig(config);
     };

     nb::class_<tr::SamplingConfig>(m, "SamplingConfig")
diff --git a/cpp/tensorrt_llm/nanobind/common/bindTypes.h b/cpp/tensorrt_llm/nanobind/common/bindTypes.h
index 5cd714e458a..6312907b88f 100644
--- a/cpp/tensorrt_llm/nanobind/common/bindTypes.h
+++ b/cpp/tensorrt_llm/nanobind/common/bindTypes.h
@@ -21,44 +21,11 @@
 #include
 #include

-namespace PybindUtils
+namespace NanobindUtils
 {

 namespace nb = nanobind;

-template <typename T>
-void bindList(nb::module_& m, std::string const& name)
-{
-    nb::class_<T>(m, name.c_str())
-        .def(nb::init<>())
-        .def("push_back", [](T& lst, const typename T::value_type& value) { lst.push_back(value); })
-        .def("pop_back", [](T& lst) { lst.pop_back(); })
-        .def("push_front", [](T& lst, const typename T::value_type& value) { lst.push_front(value); })
-        .def("pop_front", [](T& lst) { lst.pop_front(); })
-        .def("__len__", [](T const& lst) { return lst.size(); })
-        .def(
-            "__iter__", [](T& lst) { return nb::make_iterator(nb::type<T>(), "iterator", lst.begin(), lst.end()); },
-            nb::keep_alive<0, 1>())
-        .def("__getitem__",
-            [](T const& lst, size_t index)
-            {
-                if (index >= lst.size())
-                    throw nb::index_error();
-                auto it = lst.begin();
-                std::advance(it, index);
-                return *it;
-            })
-        .def("__setitem__",
-            [](T& lst, size_t index, const typename T::value_type& value)
-            {
-                if (index >= lst.size())
-                    throw nb::index_error();
-                auto it = lst.begin();
-                std::advance(it, index);
-                *it = value;
-            });
-}
-
 template <typename T>
 void bindSet(nb::module_& m, std::string const& name)
 {
@@ -93,8 +60,8 @@ void bindSet(nb::module_& m, std::string const& name)
             {
                 s.insert(item);
             }
-            return s;
+            new (&v) T(s);
         });
 }

-} // namespace PybindUtils
+} // namespace NanobindUtils
diff --git a/cpp/tensorrt_llm/nanobind/common/customCasters.h b/cpp/tensorrt_llm/nanobind/common/customCasters.h
index 7cfa07d249a..2739ccd569e 100644
--- a/cpp/tensorrt_llm/nanobind/common/customCasters.h
+++ b/cpp/tensorrt_llm/nanobind/common/customCasters.h
@@ -38,6 +38,7 @@
 #include
 #include
 #include
+#include <torch/csrc/autograd/python_variable.h>

 // Pybind requires to have a central include in order for type casters to work.
 // Opaque bindings add a type caster, so they have the same requirement.
@@ -48,7 +49,6 @@
 NB_MAKE_OPAQUE(tensorrt_llm::batch_manager::ReqIdsSet)
 NB_MAKE_OPAQUE(std::vector)
 NB_MAKE_OPAQUE(std::vector)
 NB_MAKE_OPAQUE(std::vector)
-NB_MAKE_OPAQUE(std::vector<std::vector<SizeType32>>)

 namespace nb = nanobind;

@@ -128,70 +128,6 @@ struct type_caster<std::deque<T>>
     }
 };

-template <typename T>
-struct PathCaster
-{
-
-private:
-    static PyObject* unicode_from_fs_native(std::string const& w)
-    {
-        return PyUnicode_DecodeFSDefaultAndSize(w.c_str(), ssize_t(w.size()));
-    }
-
-    static PyObject* unicode_from_fs_native(std::wstring const& w)
-    {
-        return PyUnicode_FromWideChar(w.c_str(), ssize_t(w.size()));
-    }
-
-public:
-    static handle from_cpp(T const& path, rv_policy, cleanup_list* cleanup)
-    {
-        if (auto py_str = unicode_from_fs_native(path.native()))
-        {
-            return module_::import_("pathlib").attr("Path")(steal(py_str), cleanup).release();
-        }
-        return nullptr;
-    }
-
-    bool from_python(handle src, uint8_t flags, cleanup_list* cleanup)
-    {
-        PyObject* native = nullptr;
-        if constexpr (std::is_same_v<typename T::value_type, char>)
-        {
-            if (PyUnicode_FSConverter(src.ptr(), &native) != 0)
-            {
-                if (auto* c_str = PyBytes_AsString(native))
-                {
-                    // AsString returns a pointer to the internal buffer, which
-                    // must not be free'd.
-                    value = c_str;
-                }
-            }
-        }
-        else if constexpr (std::is_same_v<typename T::value_type, wchar_t>)
-        {
-            if (PyUnicode_FSDecoder(src.ptr(), &native) != 0)
-            {
-                if (auto* c_str = PyUnicode_AsWideCharString(native, nullptr))
-                {
-                    // AsWideCharString returns a new string that must be free'd.
-                    value = c_str; // Copies the string.
-                    PyMem_Free(c_str);
-                }
-            }
-        }
-        Py_XDECREF(native);
-        if (PyErr_Occurred())
-        {
-            PyErr_Clear();
-            return false;
-        }
-        return true;
-    }
-
-    NB_TYPE_CASTER(T, const_name("os.PathLike"));
-};
-
 template <>
 class type_caster
 {
@@ -311,34 +247,45 @@ struct type_caster<at::Tensor>

     bool from_python(nb::handle src, uint8_t, cleanup_list*) noexcept
     {
-        nb::object capsule = nb::getattr(src, "__dlpack__")();
-        DLManagedTensor* dl_managed = static_cast<DLManagedTensor*>(PyCapsule_GetPointer(capsule.ptr(), "dltensor"));
-        PyCapsule_SetDestructor(capsule.ptr(), nullptr);
-        value = at::fromDLPack(dl_managed).alias();
-        return true;
+        PyObject* obj = src.ptr();
+        if (THPVariable_Check(obj))
+        {
+            value = THPVariable_Unpack(obj);
+            return true;
+        }
+        return false;
     }

-    static handle from_cpp(at::Tensor tensor, rv_policy, cleanup_list*) noexcept
+    static handle from_cpp(at::Tensor src, rv_policy, cleanup_list*) noexcept
     {
-        DLManagedTensor* dl_managed = at::toDLPack(tensor);
-        if (!dl_managed)
-            return nullptr;
-
-        nanobind::object capsule = nb::steal(PyCapsule_New(dl_managed, "dltensor",
-            [](PyObject* obj)
-            {
-                DLManagedTensor* dl = static_cast<DLManagedTensor*>(PyCapsule_GetPointer(obj, "dltensor"));
-                dl->deleter(dl);
-            }));
-        if (!capsule.is_valid())
+        return THPVariable_Wrap(src);
+    }
+};
+
+template <typename T>
+struct type_caster<std::vector<std::reference_wrapper<T>>>
+{
+    using VectorType = std::vector<std::reference_wrapper<T>>;
+
+    NB_TYPE_CASTER(VectorType, const_name("List[") + make_caster<T>::Name + const_name("]"));
+
+    bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept
+    {
+        // Not needed for our use case since we only convert C++ to Python
+        return false;
+    }
+
+    static handle from_cpp(VectorType const& src, rv_policy policy, cleanup_list* cleanup) noexcept
+    {
+
+        std::vector<T> result;
+        result.reserve(src.size());
+        for (auto const& ref : src)
         {
-            dl_managed->deleter(dl_managed);
-            return nullptr;
+            result.push_back(ref.get());
         }
-        nanobind::module_ torch = nanobind::module_::import_("torch");
-        nanobind::object result = torch.attr("from_dlpack")(capsule);
-        capsule.release();
-        return result.release();
+
+        return make_caster<std::vector<T>>::from_cpp(result, policy, cleanup);
     }
 };
 } // namespace detail
diff --git a/cpp/tensorrt_llm/nanobind/executor/executor.cpp b/cpp/tensorrt_llm/nanobind/executor/executor.cpp
index 59c7d2a3dc1..5b916c4b184 100644
--- a/cpp/tensorrt_llm/nanobind/executor/executor.cpp
+++ b/cpp/tensorrt_llm/nanobind/executor/executor.cpp
@@ -52,58 +52,37 @@ struct dtype_traits

 namespace
 {
-// todo: Properly support FP8 and BF16 and verify functionality
-tle::Tensor numpyToTensor(nb::ndarray<nb::numpy> const& array)
+tle::Tensor numpyToTensor(nb::object const& object)
 {
-    auto npDtype = array.dtype();
-    char kind = '\0';
-    switch (npDtype.code)
-    {
-    case static_cast<uint8_t>(nb::dlpack::dtype_code::Int):
-        kind = 'i'; // signed integer
-        break;
-    case static_cast<uint8_t>(nb::dlpack::dtype_code::UInt):
-        kind = 'u'; // unsigned integer
-        break;
-    case static_cast<uint8_t>(nb::dlpack::dtype_code::Float):
-        kind = 'f'; // floating point
-        break;
-    case static_cast<uint8_t>(nb::dlpack::dtype_code::Bfloat):
-        kind = 'f'; // brain floating point (treat as float kind)
-        break;
-    case static_cast<uint8_t>(nb::dlpack::dtype_code::Complex):
-        kind = 'c'; // complex
-        break;
-    default:
-        kind = 'V'; // void/other
-        break;
-    }
+    std::string dtype_name = nb::cast<std::string>(object.attr("dtype").attr("name"));
+    nb::object metadata = object.attr("dtype").attr("metadata");

     tle::DataType dtype;
-    if (npDtype == nb::dtype<half>())
+    if (dtype_name == "float16")
     {
         dtype = tle::DataType::kFP16;
     }
-    else if (npDtype == nb::dtype<float>())
+    else if (dtype_name == "float32")
     {
         dtype = tle::DataType::kFP32;
     }
-    else if (npDtype == nb::dtype<std::int8_t>())
+    else if (dtype_name == "int8")
     {
         dtype = tle::DataType::kINT8;
     }
-    else if (npDtype == nb::dtype<std::int32_t>())
+    else if (dtype_name == "int32")
     {
         dtype = tle::DataType::kINT32;
     }
-    else if (npDtype == nb::dtype<std::int64_t>())
+    else if (dtype_name == "int64")
     {
         dtype = tle::DataType::kINT64;
     }
-    else if (kind == 'V' && array.itemsize() == 1)
+    else if (dtype_name == "void8" && !metadata.is_none() && nb::cast<std::string>(metadata["dtype"]) == "float8")
     {
         dtype = tle::DataType::kFP8;
     }
-    else if (kind == 'V' && array.itemsize() == 2)
+    else if (dtype_name == "void16" && !metadata.is_none() && nb::cast<std::string>(metadata["dtype"]) == "bfloat16")
     {
         dtype = tle::DataType::kBF16;
     }
@@ -112,16 +91,21 @@ tle::Tensor numpyToTensor(nb::object const& object)
         TLLM_THROW("Unsupported numpy dtype.");
     }

-    // todo: improve the following code
+    nb::object array_interface = object.attr("__array_interface__");
+    nb::object shape_obj = array_interface["shape"];
     std::vector<tle::Shape::DimType64> dims;
-    dims.reserve(array.ndim());
-    for (size_t i = 0; i < array.ndim(); ++i)
+    dims.reserve(nb::len(shape_obj));
+
+    for (size_t i = 0; i < nb::len(shape_obj); ++i)
     {
-        dims.push_back(static_cast<tle::Shape::DimType64>(array.shape(i)));
+        dims.push_back(nb::cast<tle::Shape::DimType64>(shape_obj[i]));
     }
-    tle::Shape shape(dims.data(), dims.size());

-    return tle::Tensor::of(dtype, const_cast<void*>(array.data()), shape);
+    nb::object data_obj = array_interface["data"];
+    uintptr_t addr = nb::cast<uintptr_t>(data_obj[0]);
+    void* data_ptr = reinterpret_cast<void*>(addr);
+    tle::Shape shape(dims.data(), dims.size());
+    return tle::Tensor::of(dtype, data_ptr, shape);
 }

 } // namespace
@@ -153,8 +137,8 @@ Executor::Executor(nb::bytes const& engineBuffer, std::string const& jsonConfigS
         for (auto const& [rawName, rawArray] : managedWeights.value())
         {
             std::string name = nb::cast<std::string>(rawName);
-            nb::ndarray<nb::numpy> array = nb::cast<nb::ndarray<nb::numpy>>(rawArray);
-            managedWeightsMap->emplace(name, numpyToTensor(array));
+            nb::object array_obj = nb::cast<nb::object>(rawArray);
+            managedWeightsMap->emplace(name, numpyToTensor(array_obj));
         }
     }
     mExecutor = std::make_unique<tle::Executor>(
diff --git a/cpp/tensorrt_llm/nanobind/executor/request.cpp b/cpp/tensorrt_llm/nanobind/executor/request.cpp
index 9c3d34aa8fd..e2ed1fb2d19 100644
--- a/cpp/tensorrt_llm/nanobind/executor/request.cpp
+++ b/cpp/tensorrt_llm/nanobind/executor/request.cpp
@@ -445,13 +445,18 @@ void initRequestBindings(nb::module_& m)
                 std::vector<uint8_t>(opaque_state_str_view.begin(), opaque_state_str_view.end()),
                 nb::cast<std::optional<VecTokens>>(state[3]));
         }
-        new (&contextPhaseParams) tle::ContextPhaseParams(nb::cast<VecTokens>(state[0]),
-            nb::cast<tle::ContextPhaseParams::RequestIdType>(state[1]), nb::cast<std::optional<VecTokens>>(state[3]));
+        else
+        {
+            new (&contextPhaseParams) tle::ContextPhaseParams(nb::cast<VecTokens>(state[0]),
+                nb::cast<tle::ContextPhaseParams::RequestIdType>(state[1]),
+                nb::cast<std::optional<VecTokens>>(state[3]));
+        }
     };

     nb::class_<tle::ContextPhaseParams>(m, "ContextPhaseParams")
-        .def("__init__",
-            [](tle::ContextPhaseParams const& self, VecTokens const& first_gen_tokens,
+        .def(
+            "__init__",
+            [](tle::ContextPhaseParams& self, VecTokens const& first_gen_tokens,
                 tle::ContextPhaseParams::RequestIdType req_id, std::optional<nb::bytes> const& opaque_state,
                 std::optional<VecTokens> const& draft_tokens)
             {
                 if (opaque_state)
                 {
                     auto opaque_state_str_view =
                         std::string_view(opaque_state.value().c_str(), opaque_state.value().size());
-                    return std::make_unique<tle::ContextPhaseParams>(first_gen_tokens, req_id,
+                    new (&self) tle::ContextPhaseParams(first_gen_tokens, req_id,
                         std::vector<uint8_t>(opaque_state_str_view.begin(), opaque_state_str_view.end()), draft_tokens);
                 }
-                return std::make_unique<tle::ContextPhaseParams>(first_gen_tokens, req_id, draft_tokens);
-            })
+                else
+                {
+                    new (&self) tle::ContextPhaseParams(first_gen_tokens, req_id, draft_tokens);
+                }
+            },
+            nb::arg("first_gen_tokens"), nb::arg("req_id"), nb::arg("opaque_state").none(),
+            nb::arg("draft_tokens").none())
         .def_prop_ro("first_gen_tokens", [](tle::ContextPhaseParams const& self) { return self.getFirstGenTokens(); })
         .def_prop_ro("draft_tokens", [](tle::ContextPhaseParams const& self) { return self.getDraftTokens(); })
         .def_prop_ro("req_id", &tle::ContextPhaseParams::getReqId)
@@ -486,14 +496,14 @@
         return nb::make_tuple(self.getEagleChoices(), self.isGreedySampling(), self.getPosteriorThreshold(),
             self.useDynamicTree(), self.getDynamicTreeMaxTopK());
     };
-    auto EagleDecodingConfigSetstate = [](tle::EagleConfig& eagleConfig, nb::tuple const& state)
+    auto EagleDecodingConfigSetstate = [](tle::EagleConfig& self, nb::tuple const& state)
     {
         if (state.size() != 5)
         {
             throw std::runtime_error("Invalid EagleConfig state!");
         }
-        new (&eagleConfig) tle::EagleConfig(nb::cast<std::optional<tle::EagleChoices>>(state[0]),
-            nb::cast<bool>(state[1]), nb::cast<std::optional<float>>(state[2]), nb::cast<bool>(state[3]),
+        new (&self) tle::EagleConfig(nb::cast<std::optional<tle::EagleChoices>>(state[0]), nb::cast<bool>(state[1]),
+            nb::cast<std::optional<float>>(state[2]), nb::cast<bool>(state[3]),
             nb::cast<std::optional<SizeType32>>(state[4]));
     };
     nb::class_<tle::EagleConfig>(m, "EagleConfig")
@@ -522,13 +532,13 @@
     auto guidedDecodingParamsGetstate
         = [](tle::GuidedDecodingParams const& self) { return nb::make_tuple(self.getGuideType(), self.getGuide()); };

-    auto guidedDecodingParamsSetstate = [](tle::GuidedDecodingParams& guidedDecodingParams, nb::tuple const& state)
+    auto guidedDecodingParamsSetstate = [](tle::GuidedDecodingParams& self, nb::tuple const& state)
     {
         if (state.size() != 2)
         {
             throw std::runtime_error("Invalid GuidedDecodingParams state!");
         }
-        new (&guidedDecodingParams) tle::GuidedDecodingParams(
+        new (&self) tle::GuidedDecodingParams(
             nb::cast<tle::GuidedDecodingParams::GuideType>(state[0]), nb::cast<std::optional<std::string>>(state[1]));
     };

@@ -553,13 +563,13 @@ void initRequestBindings(nb::module_& m)
             self.getCrossAttentionMask(), self.getEagleConfig(), self.getSkipCrossAttnBlocks(),
             self.getGuidedDecodingParams());
     };
-    auto requestSetstate = [](tle::Request& request, nb::tuple const& state)
+    auto requestSetstate = [](tle::Request& self, nb::tuple const& state)
     {
         if (state.size() != 33)
        {
             throw std::runtime_error("Invalid Request state!");
         }
-        new (&request) tle::Request(nb::cast<VecTokens>(state[0]), nb::cast<SizeType32>(state[1]),
+        new (&self) tle::Request(nb::cast<VecTokens>(state[0]), nb::cast<SizeType32>(state[1]),
             nb::cast<bool>(state[2]), nb::cast<tle::SamplingConfig>(state[3]), nb::cast<tle::OutputConfig>(state[4]),
             nb::cast<std::optional<SizeType32>>(state[5]), nb::cast<std::optional<SizeType32>>(state[6]),
             nb::cast<std::optional<std::vector<SizeType32>>>(state[7]),
@@ -797,13 +807,13 @@ void initRequestBindings(nb::module_& m)
         return nb::make_tuple(self.timingMetrics, self.kvCacheMetrics, self.speculativeDecoding, self.firstIter,
             self.lastIter, self.iter);
     };
-    auto requestPerfMetricsSetstate = [](tle::RequestPerfMetrics& requestPerfMetrics, nb::tuple const& state)
+    auto requestPerfMetricsSetstate = [](tle::RequestPerfMetrics& self, nb::tuple const& state)
    {
         if (state.size() != 6)
         {
             throw std::runtime_error("Invalid RequestPerfMetrics state!");
         }
-        new (&requestPerfMetrics) tle::RequestPerfMetrics{nb::cast<tle::RequestPerfMetrics::TimingMetrics>(state[0]),
+        new (&self) tle::RequestPerfMetrics{nb::cast<tle::RequestPerfMetrics::TimingMetrics>(state[0]),
             nb::cast<tle::RequestPerfMetrics::KvCacheMetrics>(state[1]),
             nb::cast<tle::RequestPerfMetrics::SpeculativeDecodingMetrics>(state[2]),
             nb::cast<std::optional<tle::IterationType>>(state[3]),
@@ -824,19 +834,17 @@ void initRequestBindings(nb::module_& m)
         .def("__setstate__", requestPerfMetricsSetstate);

     nb::class_<tle::AdditionalOutput>(m, "AdditionalOutput")
-        .def("__init__ ",
-            [](tle::AdditionalOutput const& self, std::string const& name, tle::Tensor const& output)
-            { return std::make_unique<tle::AdditionalOutput>(name, output); })
+        .def(nb::init<std::string, tle::Tensor>(), nb::arg("name"), nb::arg("output"))
         .def_rw("name", &tle::AdditionalOutput::name)
         .def_rw("output", &tle::AdditionalOutput::output);

-    auto resultSetstate = [](tle::Result& result, nb::tuple const& state)
+    auto resultSetstate = [](tle::Result& self, nb::tuple const& state)
     {
         if (state.size() != 13)
         {
             throw std::runtime_error("Invalid Request state!");
         }
-        new (&result) tle::Result();
+        tle::Result result;
         result.isFinal = nb::cast<bool>(state[0]);
         result.outputTokenIds = nb::cast<std::vector<VecTokens>>(state[1]);
         result.cumLogProbs = nb::cast<std::optional<std::vector<float>>>(state[2]);
@@ -850,6 +858,7 @@ void initRequestBindings(nb::module_& m)
         result.decodingIter = nb::cast<SizeType32>(state[10]);
         result.contextPhaseParams = nb::cast<std::optional<tle::ContextPhaseParams>>(state[11]);
         result.requestPerfMetrics = nb::cast<std::optional<tle::RequestPerfMetrics>>(state[12]);
+        new (&self) tle::Result(result);
     };

     auto resultGetstate = [](tle::Result const& self)
diff --git a/cpp/tensorrt_llm/pybind/bindings.cpp b/cpp/tensorrt_llm/pybind/bindings.cpp
index 962071c4857..a004c872a7f 100644
--- a/cpp/tensorrt_llm/pybind/bindings.cpp
+++ b/cpp/tensorrt_llm/pybind/bindings.cpp
@@ -355,7 +355,10 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
     };
     auto SamplingConfigSetState = [](py::tuple t) -> tr::SamplingConfig
     {
-        assert(t.size() == 19);
+        if (t.size() != 19)
+        {
+            throw std::runtime_error("Invalid SamplingConfig state!");
+        }

         tr::SamplingConfig config;
         config.beamWidth = t[0].cast<SizeType32>();
diff --git a/tests/unittest/bindings/test_bindings_ut.py b/tests/unittest/bindings/test_bindings_ut.py
index 6fd46040b66..e12fd52cb4b 100644
--- a/tests/unittest/bindings/test_bindings_ut.py
+++ b/tests/unittest/bindings/test_bindings_ut.py
@@ -5,7 +5,6 @@
 from pathlib import Path

 import numpy as np
-import pytest
 import torch
 from utils.runtime_defaults import assert_runtime_defaults_are_parsed_correctly

@@ -310,8 +309,6 @@ def parse_runtime_defaults(defaults_dict: dict | None = None):
                                                  strict_keys=strict_keys)


-@pytest.mark.skipif(_tb.binding_type == "nanobind",
-                    reason="Test not supported for nanobind yet")
 def test_llm_request():
     beam_width = 2
     sampling_config = _tb.SamplingConfig(beam_width)
@@ -421,8 +418,6 @@ def test_Mpicomm():
     assert size2 == session_size


-@pytest.mark.skipif(_tb.binding_type == "nanobind",
-                    reason="Test not supported for nanobind yet")
 def test_SamplingConfig_pickle():
     config = _tb.SamplingConfig()
     config.beam_width = 5
@@ -447,7 +442,6 @@ def test_SamplingConfig_pickle():
     config.beam_width_array = [[2, 3, 4, 5]]

     config1 = pickle.loads(pickle.dumps(config))
-
     assert config1 == config


@@ -502,8 +496,6 @@ def test_KvCache_events_binding():
     torch.cuda.empty_cache()


-@pytest.mark.skipif(_tb.binding_type == "nanobind",
-                    reason="Test not supported for nanobind yet")
 def test_ReqIdsSet_pickle():
     ids = _tb.internal.batch_manager.ReqIdsSet()
     ids1 = pickle.loads(pickle.dumps(ids))
diff --git a/tests/unittest/bindings/test_executor_bindings.py b/tests/unittest/bindings/test_executor_bindings.py
index 08082584cda..c59e69fa38f 100644
--- a/tests/unittest/bindings/test_executor_bindings.py
+++ b/tests/unittest/bindings/test_executor_bindings.py
@@ -14,9 +14,9 @@
 from binding_test_utils import *
 from pydantic import BaseModel

-import tensorrt_llm.bindings as _tb
 import tensorrt_llm.bindings.executor as trtllm
 import tensorrt_llm.version as trtllm_version
+from tensorrt_llm._utils import torch_to_numpy
 from tensorrt_llm.models.modeling_utils import PretrainedConfig

 _sys.path.append(_os.path.join(_os.path.dirname(__file__), '..'))
@@ -67,6 +67,40 @@ def test_executor_from_memory(model_files, model_path):
                                trtllm.ModelType.DECODER_ONLY, executor_config)


+def test_executor_with_managed_weights(model_files, model_path):
+    """Test executor constructor with standard dtypes in managed weights."""
+
+    executor_config = trtllm.ExecutorConfig(
+        1, kv_cache_config=trtllm.KvCacheConfig(free_gpu_memory_fraction=0.5))
+    engine_buffer = open(model_path / "rank0.engine", mode="rb").read()
+    json_config_str = open(model_path / "config.json", 'r').read()
+
+    managed_weights = {
+        "weight_float32":
+        np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32),
+        "weight_int32":
+        np.array([[1, 2], [3, 4]], dtype=np.int32),
+        "weight_int64":
+        np.array([[1, 2], [3, 4]], dtype=np.int64),
+        "weight_int8":
+        np.array([[1, 2], [3, 4]], dtype=np.int8),
+        "weight_fp16":
+        np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float16),
+        "weight_bf16":
+        torch_to_numpy(
+            torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.bfloat16)),
+        "weight_fp8":
+        torch_to_numpy(
+            torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float8_e4m3fn)),
+    }
+
+    executor = trtllm.Executor(engine_buffer, json_config_str,
+                               trtllm.ModelType.DECODER_ONLY, executor_config,
+                               managed_weights)
+
+    assert executor.can_enqueue_requests() == True
+
+
 def test_executor_invalid_ctor():
     executor_config = trtllm.ExecutorConfig(
         1, kv_cache_config=trtllm.KvCacheConfig(free_gpu_memory_fraction=0.5))
@@ -485,8 +519,6 @@ def test_get_num_responses_ready(streaming: bool,
     assert executor.get_num_responses_ready() == num_expected_responses


-@pytest.mark.skipif(_tb.binding_type == "nanobind",
-                    reason="Test not supported for nanobind yet")
 @pytest.mark.parametrize("batching_type", [trtllm.BatchingType.INFLIGHT])
 @pytest.mark.parametrize("streaming", [False, True])
 @pytest.mark.parametrize("beam_width", [1])
@@ -691,8 +723,6 @@ def verify_output(beam_tokens, test_data, given_input_lengths):
     verify_output(tokens, test_data, given_input_lengths)


-@pytest.mark.skipif(_tb.binding_type == "nanobind",
-                    reason="Test not supported for nanobind yet")
 @pytest.mark.parametrize("streaming", [False, True])
 @pytest.mark.parametrize("beam_width", [1])
 def test_finish_reason(streaming: bool, beam_width: int, model_files,
@@ -1117,8 +1147,6 @@ def test_spec_dec_fast_logits_info():
     assert fast_logits_info.draft_participant_id == 5


-@pytest.mark.skipif(_tb.binding_type == "nanobind",
-                    reason="Test not supported for nanobind yet")
 def test_result():
     result = trtllm.Result()
     result.is_final = True
@@ -1156,8 +1184,6 @@ def test_result():
     assert (additional_output.output == torch.ones(1, 4, 100)).all()


-@pytest.mark.skipif(_tb.binding_type == "nanobind",
-                    reason="Test not supported for nanobind yet")
 def test_result_pickle():
     result = trtllm.Result()
     result.is_final = True
@@ -1171,6 +1197,9 @@ def test_result_pickle():
     result.sequence_index = 1
     result.is_sequence_final = True
     result.decoding_iter = 1
+    result.context_phase_params = trtllm.ContextPhaseParams([1, 2], 123,
+                                                            bytes([0, 1]),
+                                                            [10, 20, 30])
     result.request_perf_metrics = trtllm.RequestPerfMetrics()
     result.request_perf_metrics.last_iter = 33
     result_str = pickle.dumps(result)
@@ -1186,6 +1215,10 @@ def test_result_pickle():
     assert result.sequence_index == result_copy.sequence_index
     assert result.is_sequence_final == result_copy.is_sequence_final
     assert result.decoding_iter == result_copy.decoding_iter
+    assert result.context_phase_params.req_id == result_copy.context_phase_params.req_id
+    assert result.context_phase_params.first_gen_tokens == result_copy.context_phase_params.first_gen_tokens
+    assert result.context_phase_params.draft_tokens == result_copy.context_phase_params.draft_tokens
+    assert result.context_phase_params.opaque_state == result_copy.context_phase_params.opaque_state
     assert result.request_perf_metrics.last_iter == result_copy.request_perf_metrics.last_iter


@@ -1504,8 +1537,6 @@ def test_eagle_config():
         assert getattr(config, k) == v


-@pytest.mark.skipif(_tb.binding_type == "nanobind",
-                    reason="Test not supported for nanobind yet")
 def test_eagle_config_pickle():
     config = trtllm.EagleConfig([[0, 0], [0, 1]], False, 0.5)
     config_copy = pickle.loads(pickle.dumps(config))
@@ -1878,8 +1909,6 @@ def logits_post_processor(req_id: int, logits: torch.Tensor,
     assert tokens[-max_tokens:] == [42] * max_tokens


-@pytest.mark.skipif(_tb.binding_type == "nanobind",
-                    reason="Test not supported for nanobind yet")
 def test_logits_post_processor_batched(model_files, model_path):

     # Define the logits post-processor callback
@@ -2154,8 +2183,6 @@ def test_request_perf_metrics_kv_cache(model_path):
     assert kv_cache_metrics.kv_cache_hit_rate == 1.0


-@pytest.mark.skipif(_tb.binding_type == "nanobind",
-                    reason="Test not supported for nanobind yet")
 @pytest.mark.parametrize("exclude_input_from_output", [False, True])
 def test_request_perf_metrics_draft(model_path_draft_tokens_external,
                                     exclude_input_from_output: bool):
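
The recurring pattern behind these fixes is nanobind's pickle protocol: unlike pybind11, where `__setstate__` can return a freshly constructed object, nanobind calls `__setstate__` with a reference to *uninitialized* storage, so the binding must construct the object in place with placement new. Below is a minimal, self-contained sketch of that idiom; the `Point` type and `point_ext` module name are illustrative only and not part of this patch.

```cpp
// Sketch of the nanobind pickle idiom applied throughout this patch.
#include <nanobind/nanobind.h>

#include <new>
#include <stdexcept>

namespace nb = nanobind;

struct Point
{
    int x = 0;
    int y = 0;
};

NB_MODULE(point_ext, m)
{
    nb::class_<Point>(m, "Point")
        .def(nb::init<>())
        .def_rw("x", &Point::x)
        .def_rw("y", &Point::y)
        // __getstate__ packs the object into a picklable tuple.
        .def("__getstate__", [](Point const& p) { return nb::make_tuple(p.x, p.y); })
        // __setstate__ receives a reference to uninitialized storage, so the
        // object must be constructed in place with placement new; returning a
        // value (the pybind11 habit) leaves `self` unconstructed.
        .def("__setstate__",
            [](Point& self, nb::tuple const& t)
            {
                if (t.size() != 2)
                    throw std::runtime_error("Invalid Point state!");
                new (&self) Point{nb::cast<int>(t[0]), nb::cast<int>(t[1])};
            });
}
```

With bindings of this shape, `pickle.loads(pickle.dumps(p))` round-trips on the Python side, which is exactly what the re-enabled `test_SamplingConfig_pickle`, `test_ReqIdsSet_pickle`, `test_result_pickle`, and `test_eagle_config_pickle` cases exercise.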