Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ void initBindings(nb::module_& m)
}
});

PybindUtils::bindSet<tb::ReqIdsSet>(m, "ReqIdsSet");
NanobindUtils::bindSet<tb::ReqIdsSet>(m, "ReqIdsSet");

nb::enum_<tb::LlmRequestType>(m, "LlmRequestType")
.value("LLMREQUEST_TYPE_CONTEXT_AND_GENERATION", tb::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION)
Expand Down
13 changes: 12 additions & 1 deletion cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ using SizeType32 = tensorrt_llm::runtime::SizeType32;
using TokenIdType = tensorrt_llm::runtime::TokenIdType;
using VecTokens = std::vector<TokenIdType>;
using CudaStreamPtr = std::shared_ptr<tensorrt_llm::runtime::CudaStream>;
using CacheBlockIds = std::vector<std::vector<SizeType32>>;

NB_MAKE_OPAQUE(CacheBlockIds);

namespace
{
Expand Down Expand Up @@ -424,7 +427,15 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
.def("get_newly_allocated_block_ids", &BaseKVCacheManager::getNewlyAllocatedBlockIds)
.def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents);

nb::bind_vector<std::vector<std::vector<SizeType32>>>(m, "CacheBlockIds");
nb::bind_vector<CacheBlockIds>(m, "CacheBlockIds")
.def("__getstate__", [](CacheBlockIds const& v) { return nb::make_tuple(v); })
.def("__setstate__",
[](CacheBlockIds& self, nb::tuple const& t)
{
if (t.size() != 1)
throw std::runtime_error("Invalid state!");
new (&self) CacheBlockIds(nb::cast<std::vector<std::vector<SizeType32>>>(t[0]));
});

nb::enum_<tbk::CacheType>(m, "CacheType")
.value("SELF", tbk::CacheType::kSELF)
Expand Down
9 changes: 6 additions & 3 deletions cpp/tensorrt_llm/nanobind/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,9 +359,12 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
config.earlyStopping, config.noRepeatNgramSize, config.numReturnSequences, config.minP,
config.beamWidthArray);
};
auto SamplingConfigSetState = [](tr::SamplingConfig& self, nb::tuple t) -> tr::SamplingConfig
auto SamplingConfigSetState = [](tr::SamplingConfig& self, nb::tuple t)
{
assert(t.size() == 19);
if (t.size() != 19)
{
throw std::runtime_error("Invalid SamplingConfig state!");
}

tr::SamplingConfig config;
config.beamWidth = nb::cast<SizeType32>(t[0]);
Expand All @@ -384,7 +387,7 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
config.minP = nb::cast<OptVec<float>>(t[17]);
config.beamWidthArray = nb::cast<OptVec<std::vector<SizeType32>>>(t[18]);

return config;
new (&self) tr::SamplingConfig(config);
};

nb::class_<tr::SamplingConfig>(m, "SamplingConfig")
Expand Down
39 changes: 3 additions & 36 deletions cpp/tensorrt_llm/nanobind/common/bindTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,44 +21,11 @@
#include <nanobind/nanobind.h>
#include <nanobind/stl/string.h>

namespace PybindUtils
namespace NanobindUtils
{

namespace nb = nanobind;

// Registers a Python class called `name` in module `m` that wraps the list-like container `T`.
//
// `T` must provide push_back/pop_back/push_front/pop_front, size(), empty(), begin()/end() and a
// `value_type` member (e.g. std::list or std::deque).
//
// Exposed Python methods:
//   push_back(value) / push_front(value)  - append / prepend a copy of `value`
//   pop_back() / pop_front()              - remove the last / first element; raises IndexError when empty
//   __len__                               - number of elements
//   __iter__                              - iterator over the elements (keeps the container alive)
//   __getitem__(index)                    - copy of the element at `index`; raises IndexError when out of range
//   __setitem__(index, value)             - overwrite the element at `index`; raises IndexError when out of range
//
// Indices are taken as size_t, so negative (Python-style) indices are not supported.
template <typename T>
void bindList(nb::module_& m, std::string const& name)
{
    nb::class_<T>(m, name.c_str())
        .def(nb::init<>())
        .def("push_back", [](T& lst, const typename T::value_type& value) { lst.push_back(value); })
        .def("pop_back",
            [](T& lst)
            {
                // pop_back() on an empty container is undefined behavior in C++; surface it as IndexError.
                if (lst.empty())
                    throw nb::index_error("pop from empty list");
                lst.pop_back();
            })
        .def("push_front", [](T& lst, const typename T::value_type& value) { lst.push_front(value); })
        .def("pop_front",
            [](T& lst)
            {
                // Same guard as pop_back(): never call pop_front() on an empty container.
                if (lst.empty())
                    throw nb::index_error("pop from empty list");
                lst.pop_front();
            })
        .def("__len__", [](T const& lst) { return lst.size(); })
        .def(
            "__iter__", [](T& lst) { return nb::make_iterator(nb::type<T>(), "iterator", lst.begin(), lst.end()); },
            nb::keep_alive<0, 1>())
        .def("__getitem__",
            [](T const& lst, size_t index)
            {
                if (index >= lst.size())
                    throw nb::index_error();
                // Walk from the front: T is not required to offer random access.
                auto it = lst.begin();
                std::advance(it, index);
                return *it;
            })
        .def("__setitem__",
            [](T& lst, size_t index, const typename T::value_type& value)
            {
                if (index >= lst.size())
                    throw nb::index_error();
                auto it = lst.begin();
                std::advance(it, index);
                *it = value;
            });
}

template <typename T>
void bindSet(nb::module_& m, std::string const& name)
{
Expand Down Expand Up @@ -93,8 +60,8 @@ void bindSet(nb::module_& m, std::string const& name)
{
s.insert(item);
}
return s;
new (&v) T(s);
});
}

} // namespace PybindUtils
} // namespace NanobindUtils
123 changes: 35 additions & 88 deletions cpp/tensorrt_llm/nanobind/common/customCasters.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include <torch/csrc/autograd/variable.h>
#include <torch/extension.h>
#include <torch/torch.h>
#include <vector>

// Pybind requires to have a central include in order for type casters to work.
// Opaque bindings add a type caster, so they have the same requirement.
Expand All @@ -48,7 +49,6 @@ NB_MAKE_OPAQUE(tensorrt_llm::batch_manager::ReqIdsSet)
NB_MAKE_OPAQUE(std::vector<tensorrt_llm::batch_manager::SlotDecoderBuffers>)
NB_MAKE_OPAQUE(std::vector<tensorrt_llm::runtime::decoder_batch::Request>)
NB_MAKE_OPAQUE(std::vector<tensorrt_llm::runtime::SamplingConfig>)
NB_MAKE_OPAQUE(std::vector<std::vector<tensorrt_llm::runtime::SizeType32>>)

namespace nb = nanobind;

Expand Down Expand Up @@ -128,70 +128,6 @@ struct type_caster<tensorrt_llm::common::OptionalRef<T>>
}
};

// Type-caster implementation for path-like types `T`.
//
// `T` must expose `native()` returning std::string or std::wstring, a `value_type` of char or
// wchar_t, and be assignable from a pointer to a null-terminated string of that character type.
// C++ -> Python produces a `pathlib.Path`; Python -> C++ goes through the CPython filesystem
// converters (PyUnicode_FSConverter / PyUnicode_FSDecoder).
template <typename T>
struct PathCaster
{

private:
    // Builds a Python str from a narrow native path using the filesystem encoding.
    // Returns a new reference, or nullptr with a Python error set.
    static PyObject* unicode_from_fs_native(std::string const& w)
    {
        return PyUnicode_DecodeFSDefaultAndSize(w.c_str(), static_cast<Py_ssize_t>(w.size()));
    }

    // Builds a Python str from a wide native path.
    // Returns a new reference, or nullptr with a Python error set.
    static PyObject* unicode_from_fs_native(std::wstring const& w)
    {
        return PyUnicode_FromWideChar(w.c_str(), static_cast<Py_ssize_t>(w.size()));
    }

public:
    // Converts `path` to a `pathlib.Path` object. Returns a null handle when the native string
    // could not be turned into a Python str.
    static handle from_cpp(T const& path, rv_policy, cleanup_list*)
    {
        if (auto py_str = unicode_from_fs_native(path.native()))
        {
            // Only the path string is passed to Path(); the cleanup list is caster bookkeeping and
            // must not be forwarded as a Python call argument.
            return module_::import_("pathlib").attr("Path")(steal<object>(py_str)).release();
        }
        return nullptr;
    }

    // Fills `value` from a Python object. Returns false, with the Python error cleared, when any
    // of the CPython conversion calls reported an error.
    bool from_python(handle src, uint8_t flags, cleanup_list* cleanup)
    {
        PyObject* native = nullptr;
        if constexpr (std::is_same_v<typename T::value_type, char>)
        {
            if (PyUnicode_FSConverter(src.ptr(), &native) != 0)
            {
                if (auto* c_str = PyBytes_AsString(native))
                {
                    // AsString returns a pointer to the internal buffer, which
                    // must not be free'd.
                    value = c_str;
                }
            }
        }
        else if constexpr (std::is_same_v<typename T::value_type, wchar_t>)
        {
            if (PyUnicode_FSDecoder(src.ptr(), &native) != 0)
            {
                if (auto* c_str = PyUnicode_AsWideCharString(native, nullptr))
                {
                    // AsWideCharString returns a new string that must be free'd.
                    value = c_str; // Copies the string.
                    PyMem_Free(c_str);
                }
            }
        }
        // `native` is still nullptr when the converter failed, hence XDECREF.
        Py_XDECREF(native);
        if (PyErr_Occurred())
        {
            // A failed conversion is reported through the return value, not a pending exception.
            PyErr_Clear();
            return false;
        }
        return true;
    }

    NB_TYPE_CASTER(T, const_name("os.PathLike"));
};

template <>
class type_caster<tensorrt_llm::executor::StreamPtr>
{
Expand Down Expand Up @@ -311,34 +247,45 @@ struct type_caster<at::Tensor>

bool from_python(nb::handle src, uint8_t, cleanup_list*) noexcept
{
nb::object capsule = nb::getattr(src, "__dlpack__")();
DLManagedTensor* dl_managed = static_cast<DLManagedTensor*>(PyCapsule_GetPointer(capsule.ptr(), "dltensor"));
PyCapsule_SetDestructor(capsule.ptr(), nullptr);
value = at::fromDLPack(dl_managed).alias();
return true;
PyObject* obj = src.ptr();
if (THPVariable_Check(obj))
{
value = THPVariable_Unpack(obj);
return true;
}
return false;
}

static handle from_cpp(at::Tensor tensor, rv_policy, cleanup_list*) noexcept
static handle from_cpp(at::Tensor src, rv_policy, cleanup_list*) noexcept
{
DLManagedTensor* dl_managed = at::toDLPack(tensor);
if (!dl_managed)
return nullptr;

nanobind::object capsule = nb::steal(PyCapsule_New(dl_managed, "dltensor",
[](PyObject* obj)
{
DLManagedTensor* dl = static_cast<DLManagedTensor*>(PyCapsule_GetPointer(obj, "dltensor"));
dl->deleter(dl);
}));
if (!capsule.is_valid())
return THPVariable_Wrap(src);
}
};

template <typename T>
struct type_caster<std::vector<std::reference_wrapper<T const>>>
{
using VectorType = std::vector<std::reference_wrapper<T const>>;

NB_TYPE_CASTER(VectorType, const_name("List[") + make_caster<T>::Name + const_name("]"));

bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept
{
// Not needed for our use case since we only convert C++ to Python
return false;
}

static handle from_cpp(VectorType const& src, rv_policy policy, cleanup_list* cleanup) noexcept
{

std::vector<T> result;
result.reserve(src.size());
for (auto const& ref : src)
{
dl_managed->deleter(dl_managed);
return nullptr;
result.push_back(ref.get());
}
nanobind::module_ torch = nanobind::module_::import_("torch");
nanobind::object result = torch.attr("from_dlpack")(capsule);
capsule.release();
return result.release();

return make_caster<std::vector<T>>::from_cpp(result, policy, cleanup);
}
};
} // namespace detail
Expand Down
64 changes: 24 additions & 40 deletions cpp/tensorrt_llm/nanobind/executor/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,58 +52,37 @@ struct dtype_traits<half>

namespace
{
// todo: Properly support FP8 and BF16 and verify functionality
tle::Tensor numpyToTensor(nb::ndarray<nb::numpy> const& array)
tle::Tensor numpyToTensor(nb::object const& object)
{
auto npDtype = array.dtype();
char kind = '\0';
switch (npDtype.code)
{
case static_cast<uint8_t>(nb::dlpack::dtype_code::Int):
kind = 'i'; // signed integer
break;
case static_cast<uint8_t>(nb::dlpack::dtype_code::UInt):
kind = 'u'; // unsigned integer
break;
case static_cast<uint8_t>(nb::dlpack::dtype_code::Float):
kind = 'f'; // floating point
break;
case static_cast<uint8_t>(nb::dlpack::dtype_code::Bfloat):
kind = 'f'; // brain floating point (treat as float kind)
break;
case static_cast<uint8_t>(nb::dlpack::dtype_code::Complex):
kind = 'c'; // complex
break;
default:
kind = 'V'; // void/other
break;
}
std::string dtype_name = nb::cast<std::string>(object.attr("dtype").attr("name"));
nb::object metadata = object.attr("dtype").attr("metadata");

tle::DataType dtype;
if (npDtype == nb::dtype<half>())
if (dtype_name == "float16")
{
dtype = tle::DataType::kFP16;
}
else if (npDtype == nb::dtype<float>())
else if (dtype_name == "float32")
{
dtype = tle::DataType::kFP32;
}
else if (npDtype == nb::dtype<int8_t>())
else if (dtype_name == "int8")
{
dtype = tle::DataType::kINT8;
}
else if (npDtype == nb::dtype<int32_t>())
else if (dtype_name == "int32")
{
dtype = tle::DataType::kINT32;
}
else if (npDtype == nb::dtype<int64_t>())
else if (dtype_name == "int64")
{
dtype = tle::DataType::kINT64;
}
else if (kind == 'V' && array.itemsize() == 1)
else if (dtype_name == "void8" && !metadata.is_none() && nb::cast<std::string>(metadata["dtype"]) == "float8")
{
dtype = tle::DataType::kFP8;
}
else if (kind == 'V' && array.itemsize() == 2)
else if (dtype_name == "void16" && !metadata.is_none() && nb::cast<std::string>(metadata["dtype"]) == "bfloat16")
{
dtype = tle::DataType::kBF16;
}
Expand All @@ -112,16 +91,21 @@ tle::Tensor numpyToTensor(nb::ndarray<nb::numpy> const& array)
TLLM_THROW("Unsupported numpy dtype.");
}

// todo: improve the following code
nb::object array_interface = object.attr("__array_interface__");
nb::object shape_obj = array_interface["shape"];
std::vector<int64_t> dims;
dims.reserve(array.ndim());
for (size_t i = 0; i < array.ndim(); ++i)
dims.reserve(nb::len(shape_obj));

for (size_t i = 0; i < nb::len(shape_obj); ++i)
{
dims.push_back(static_cast<int64_t>(array.shape(i)));
dims.push_back(nb::cast<int64_t>(shape_obj[i]));
}
tle::Shape shape(dims.data(), dims.size());

return tle::Tensor::of(dtype, const_cast<void*>(array.data()), shape);
nb::object data_obj = array_interface["data"];
uintptr_t addr = nb::cast<uintptr_t>(data_obj[0]);
void* data_ptr = reinterpret_cast<void*>(addr);
tle::Shape shape(dims.data(), dims.size());
return tle::Tensor::of(dtype, data_ptr, shape);
}

} // namespace
Expand Down Expand Up @@ -153,8 +137,8 @@ Executor::Executor(nb::bytes const& engineBuffer, std::string const& jsonConfigS
for (auto const& [rawName, rawArray] : managedWeights.value())
{
std::string name = nb::cast<std::string>(rawName);
nb::ndarray<nb::numpy> array = nb::cast<nb::ndarray<nb::numpy>>(rawArray);
managedWeightsMap->emplace(name, numpyToTensor(array));
nb::object array_obj = nb::cast<nb::object>(rawArray);
managedWeightsMap->emplace(name, numpyToTensor(array_obj));
}
}
mExecutor = std::make_unique<tle::Executor>(
Expand Down
Loading