35 changes: 34 additions & 1 deletion custom_ops/gpu_ops/cpp_extensions.cc
@@ -332,7 +332,9 @@ paddle::Tensor RebuildPaddingFunc(
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_encoder,
const paddle::optional<paddle::Tensor> &output_padding_offset,
int max_input_length);
const paddle::optional<paddle::Tensor> &first_token_out,
int max_input_length,
bool enable_logprob);

void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
const paddle::Tensor &stop_flags,
@@ -891,6 +893,31 @@ void SaveOutMmsgStatic(const paddle::Tensor& x,
int64_t rank_id,
bool save_each_rank);

void SpeculateGetLogits(const paddle::Tensor &draft_logits,
const paddle::Tensor &batch_token_num,
const paddle::Tensor &cu_batch_token_offset,
const paddle::Tensor &logits,
const paddle::Tensor &first_token_logits,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder);

void SpeculateInsertFirstToken(const paddle::Tensor &token_ids,
const paddle::Tensor &accept_tokens,
const paddle::Tensor &next_tokens,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &cu_batch_token_offset,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder);

void SpeculateGetTargetLogits(const paddle::Tensor &target_logits,
const paddle::Tensor &logits,
const paddle::Tensor &cu_batch_token_offset,
const paddle::Tensor &ori_cu_batch_token_offset,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &accept_num);
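
All three new speculative-decoding helpers locate each sequence's slice of a flattened logits buffer through `cu_batch_token_offset`. As a point of reference, here is a minimal host-side sketch, not part of the PR (names and shapes are assumptions), of how such offsets are conventionally built as an exclusive prefix sum over `batch_token_num`:

```cpp
// Illustrative only: cu_batch_token_offset as an exclusive prefix sum over
// batch_token_num. Sequence bi is assumed to own rows
// [cu_batch_token_offset[bi], cu_batch_token_offset[bi+1]) of the flattened
// [sum(batch_token_num), vocab_size] logits tensor.
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> batch_token_num = {3, 1, 5};  // tokens per sequence
  std::vector<int> cu_batch_token_offset(batch_token_num.size() + 1, 0);
  for (size_t i = 0; i < batch_token_num.size(); ++i) {
    cu_batch_token_offset[i + 1] =
        cu_batch_token_offset[i] + batch_token_num[i];
  }
  for (size_t bi = 0; bi < batch_token_num.size(); ++bi) {
    printf("seq %zu -> rows [%d, %d)\n", bi, cu_batch_token_offset[bi],
           cu_batch_token_offset[bi + 1]);
  }
  return 0;
}
```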

PYBIND11_MODULE(fastdeploy_ops, m) {

m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"),
@@ -1277,4 +1304,10 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("min_p_sampling", &MinPSamplingFromProbs, "min_p_sampling function");

m.def("save_output", &SaveOutMmsgStatic, "save_output function");

m.def("speculate_get_logits", &SpeculateGetLogits, "speculate_get_logits function");

m.def("speculate_insert_first_token", &SpeculateInsertFirstToken, "speculate_insert_first_token function");

m.def("speculate_get_target_logits", &SpeculateGetTargetLogits, "speculate_get_target_logits function");
}
53 changes: 42 additions & 11 deletions custom_ops/gpu_ops/rebuild_padding.cu
@@ -46,6 +46,7 @@ __global__ void RebuildPaddingKernel(T *output_data,

template <typename T, int VecSize>
__global__ void RebuildAppendPaddingKernel(T *output_data,
T *first_token_out,
const T *input_data,
const int *cu_seqlens_q,
const int *seq_len_this_time,
@@ -55,7 +56,8 @@ __global__ void RebuildAppendPaddingKernel(T *output_data,
const int max_input_length,
const int dim_embed,
const int64_t output_elem_nums,
const int bsz) {
const int bsz,
const bool enable_logprob) {
AlignedVector<T, VecSize> src_vec;
const int64_t global_idx = blockDim.x * blockIdx.x + threadIdx.x;
for (int64_t i = global_idx * VecSize; i < output_elem_nums;
@@ -77,6 +79,15 @@ __global__ void RebuildAppendPaddingKernel(T *output_data,
Load<T, VecSize>(&input_data[input_token_id * dim_embed + bias_idx],
&src_vec);
Store<T, VecSize>(src_vec, &output_data[i]);

if (enable_logprob && seq_len_encoder[bi] > 0) {
int first_token_seq_id = seq_len_encoder[bi] - 2;
const int first_token_id =
ori_token_id - cum_offset_bi + first_token_seq_id;
Load<T, VecSize>(&input_data[first_token_id * dim_embed + bias_idx],
&src_vec);
Store<T, VecSize>(src_vec, &first_token_out[i]);
}
}
}
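
When `enable_logprob` is set, the new branch above additionally copies, for each prefill sequence (`seq_len_encoder[bi] > 0`), the hidden state at prompt position `seq_len_encoder[bi] - 2` into `first_token_out`, presumably so the logprob of the first generated token can be recovered later. Both stores go through the repo's `AlignedVector` `Load`/`Store` helpers; a self-contained approximation of that vectorized access pattern (illustrative only, not the PR's actual headers) looks like this:

```cuda
// Minimal sketch of the vectorized Load/Store pattern the kernels above use.
// Assumes n is a multiple of VecSize, as the original kernel does.
#include <cstdio>
#include <cuda_runtime.h>

template <typename T, int VecSize>
struct AlignedVector {
  T val[VecSize];
};

template <typename T, int VecSize>
__device__ void Load(const T *addr, AlignedVector<T, VecSize> *vec) {
  *vec = *reinterpret_cast<const AlignedVector<T, VecSize> *>(addr);
}

template <typename T, int VecSize>
__device__ void Store(const AlignedVector<T, VecSize> &vec, T *addr) {
  *reinterpret_cast<AlignedVector<T, VecSize> *>(addr) = vec;
}

// Grid-stride copy, VecSize elements per iteration: the same access shape the
// rebuild-padding kernels use for each row gather.
template <typename T, int VecSize>
__global__ void VectorizedCopy(T *dst, const T *src, int64_t n) {
  AlignedVector<T, VecSize> vec;
  const int64_t idx = (blockDim.x * blockIdx.x + threadIdx.x) * VecSize;
  for (int64_t i = idx; i < n; i += gridDim.x * blockDim.x * VecSize) {
    Load<T, VecSize>(&src[i], &vec);
    Store<T, VecSize>(vec, &dst[i]);
  }
}

int main() {
  const int64_t n = 1024;
  float *src, *dst;
  cudaMalloc(&src, n * sizeof(float));
  cudaMalloc(&dst, n * sizeof(float));
  cudaMemset(src, 0, n * sizeof(float));
  VectorizedCopy<float, 4><<<4, 64>>>(dst, src, n);
  cudaDeviceSynchronize();
  cudaFree(src);
  cudaFree(dst);
  return 0;
}
```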

@@ -89,7 +100,9 @@ std::vector<paddle::Tensor> rebuild_padding(
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_encoder,
const paddle::optional<paddle::Tensor> &output_padding_offset,
int max_input_length) {
const paddle::optional<paddle::Tensor> &first_token_out,
int max_input_length,
bool enable_logprob) {
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
@@ -135,6 +148,10 @@ std::vector<paddle::Tensor> rebuild_padding(
RebuildAppendPaddingKernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, cu_stream>>>(
reinterpret_cast<DataType_ *>(out.data<data_t>()),
first_token_out.is_initialized()
? reinterpret_cast<DataType_ *>(const_cast<data_t *>(
first_token_out.get_ptr()->data<data_t>()))
: nullptr,
reinterpret_cast<const DataType_ *>(tmp_out.data<data_t>()),
cu_seqlens_q.data<int>(),
seq_len_this_time.data<int>(),
@@ -144,7 +161,8 @@
max_input_length,
dim_embed,
elem_nums,
bsz);
bsz,
enable_logprob);
} else {
RebuildPaddingKernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, cu_stream>>>(
@@ -169,7 +187,9 @@ paddle::Tensor RebuildPaddingFunc(
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_encoder,
const paddle::optional<paddle::Tensor> &output_padding_offset,
int max_input_length) {
const paddle::optional<paddle::Tensor> &first_token_out,
int max_input_length,
bool enable_logprob) {
switch (tmp_out.type()) {
case paddle::DataType::BFLOAT16: {
return rebuild_padding<paddle::DataType::BFLOAT16>(
@@ -179,7 +199,9 @@
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
max_input_length)[0];
first_token_out,
max_input_length,
enable_logprob)[0];
}
case paddle::DataType::FLOAT16: {
return rebuild_padding<paddle::DataType::FLOAT16>(
@@ -189,7 +211,9 @@
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
max_input_length)[0];
first_token_out,
max_input_length,
enable_logprob)[0];
}
case paddle::DataType::FLOAT32: {
return rebuild_padding<paddle::DataType::FLOAT32>(
@@ -199,7 +223,9 @@
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
max_input_length)[0];
first_token_out,
max_input_length,
enable_logprob)[0];
}
default: {
PD_THROW(
@@ -217,14 +243,18 @@ std::vector<paddle::Tensor> RebuildPadding(
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_encoder,
const paddle::optional<paddle::Tensor> &output_padding_offset,
int max_input_length) {
const paddle::optional<paddle::Tensor> &first_token_out,
int max_input_length,
bool enable_logprob) {
return {RebuildPaddingFunc(tmp_out,
cu_seqlens_q,
seq_len_this_time,
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
max_input_length)};
first_token_out,
max_input_length,
enable_logprob)};
}

std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
@@ -260,9 +290,10 @@ PD_BUILD_STATIC_OP(rebuild_padding)
"seq_len_this_time",
"seq_lens_decoder",
"seq_lens_encoder",
paddle::Optional("output_padding_offset")})
paddle::Optional("output_padding_offset"),
paddle::Optional("first_token_out")})
.Outputs({"out"})
.Attrs({"max_input_length: int"})
.Attrs({"max_input_length: int", "enable_logprob: bool"})
.SetKernelFn(PD_KERNEL(RebuildPadding))
.SetInferShapeFn(PD_INFER_SHAPE(RebuildPaddingInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(RebuildPaddingInferDtype));
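
For reference, a hypothetical C++ call site for the extended function (the tensor arguments are placeholders; in FastDeploy they come from the model forward pass). Leaving the optional empty with `enable_logprob=false` preserves the old behaviour; passing a preallocated buffer enables the extra first-token gather:

```cpp
// Hypothetical call sites, illustrative only: the new arguments reduce the op
// to its old behaviour when first_token_out is unset and enable_logprob is
// false. Assumes RebuildPaddingFunc is declared as in cpp_extensions.cc.
paddle::Tensor CallRebuildPadding(
    const paddle::Tensor &tmp_out,
    const paddle::Tensor &cu_seqlens_q,
    const paddle::Tensor &seq_len_this_time,
    const paddle::Tensor &seq_lens_decoder,
    const paddle::Tensor &seq_lens_encoder,
    const paddle::optional<paddle::Tensor> &output_padding_offset,
    paddle::Tensor &first_token_buf,  // preallocated, same dtype as tmp_out
    int max_input_length,
    bool want_logprob) {
  if (!want_logprob) {
    // Old path: no first-token gather.
    return RebuildPaddingFunc(tmp_out, cu_seqlens_q, seq_len_this_time,
                              seq_lens_decoder, seq_lens_encoder,
                              output_padding_offset,
                              paddle::optional<paddle::Tensor>(),  // unset
                              max_input_length,
                              /*enable_logprob=*/false);
  }
  // New path: the kernel also writes first-token hidden states into
  // first_token_buf for prefill sequences.
  return RebuildPaddingFunc(tmp_out, cu_seqlens_q, seq_len_this_time,
                            seq_lens_decoder, seq_lens_encoder,
                            output_padding_offset,
                            paddle::optional<paddle::Tensor>(first_token_buf),
                            max_input_length,
                            /*enable_logprob=*/true);
}
```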
@@ -0,0 +1,127 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stdio.h>
#include <string.h>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>
#include "paddle/extension.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

#define MAX_BSZ 512
#define K 20
#define MAX_DRAFT_TOKEN_NUM 6

struct batch_msgdata {
int tokens[MAX_DRAFT_TOKEN_NUM * (K + 1)];
float scores[MAX_DRAFT_TOKEN_NUM * (K + 1)];
int ranks[MAX_DRAFT_TOKEN_NUM];
};

struct msgdata {
long mtype;
int meta[3 + MAX_BSZ]; // stop_flag, message_flag, bsz, batch_token_nums
batch_msgdata mtext[MAX_BSZ];
};
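
A practical note on this layout (an observation, not something the PR changes): with MAX_BSZ=512, K=20, and MAX_DRAFT_TOKEN_NUM=6 the payload is about 518 KiB, far above the common Linux default MSGMAX of 8192 bytes, so the deployment is assumed to raise `kernel.msgmax`/`kernel.msgmnb` accordingly. A quick standalone size check:

```cpp
// Standalone size check for the message layout above (definitions copied).
#include <cstdio>

#define MAX_BSZ 512
#define K 20
#define MAX_DRAFT_TOKEN_NUM 6

struct batch_msgdata {
  int tokens[MAX_DRAFT_TOKEN_NUM * (K + 1)];
  float scores[MAX_DRAFT_TOKEN_NUM * (K + 1)];
  int ranks[MAX_DRAFT_TOKEN_NUM];
};

struct msgdata {
  long mtype;
  int meta[3 + MAX_BSZ];
  batch_msgdata mtext[MAX_BSZ];
};

int main() {
  // msgsnd/msgrcv payloads this large need raised System V queue limits.
  printf("payload bytes: %zu\n", sizeof(msgdata) - sizeof(long));
  return 0;
}
```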

void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens,
const paddle::Tensor& output_scores,
const paddle::Tensor& output_ranks,
int real_k,
int64_t rank_id,
bool wait_flag) {
struct msgdata msg_rcv;
int msg_queue_id = 1;

if (const char* inference_msg_queue_id_env_p =
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string inference_msg_queue_id_env_str(
inference_msg_queue_id_env_p);
int inference_msg_queue_id_from_env =
std::stoi(inference_msg_queue_id_env_str);
#ifdef SPECULATE_GET_WITH_OUTPUT_DEBUG
std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
<< inference_msg_queue_id_from_env << std::endl;
#endif
msg_queue_id = inference_msg_queue_id_from_env;
}
static key_t key = ftok("/dev/shm", msg_queue_id);

static int msgid = msgget(key, IPC_CREAT | 0666);
#ifdef SPECULATE_GET_WITH_OUTPUT_DEBUG
std::cout << "get_output_key: " << key << std::endl;
std::cout << "get_output msgid: " << msgid << std::endl;
#endif

int64_t* output_tokens_data =
const_cast<int64_t*>(output_tokens.data<int64_t>());
float* output_scores_data = const_cast<float*>(output_scores.data<float>());
int64_t* output_ranks_data =
const_cast<int64_t*>(output_ranks.data<int64_t>());
int ret = -1;
if (!wait_flag) {
ret = msgrcv(
msgid, &msg_rcv, sizeof(msg_rcv) - sizeof(long), 0, IPC_NOWAIT);
} else {
ret = msgrcv(msgid, &msg_rcv, sizeof(msg_rcv) - sizeof(long), 0, 0);
}
if (ret == -1) {
// read none
output_tokens_data[0] = -2; // stop_flag
output_tokens_data[1] = 0; // message_flag, Target: 3, Draft: 4
output_tokens_data[2] = 0; // bsz
return;
}

int bsz = msg_rcv.meta[2];  // per the meta layout above: [0]=stop_flag, [1]=message_flag, [2]=bsz
output_tokens_data[0] = (int64_t)msg_rcv.meta[0];
output_tokens_data[1] = (int64_t)msg_rcv.meta[1];
output_tokens_data[2] = (int64_t)msg_rcv.meta[2];

int output_tokens_offset = 3 + MAX_BSZ;
for (int i = 0; i < bsz; i++) {
int cur_token_num = msg_rcv.meta[3 + i];
output_tokens_data[3 + i] = (int64_t)cur_token_num; // batch_token_nums

auto* cur_output_token = output_tokens_data + output_tokens_offset +
i * (MAX_DRAFT_TOKEN_NUM * (K + 1));
auto* cur_output_score =
output_scores_data + i * (MAX_DRAFT_TOKEN_NUM * (K + 1));
auto* cur_batch_msg_rcv = &msg_rcv.mtext[i];
for (int j = 0; j < cur_token_num; j++) {
for (int k = 0; k < real_k + 1; k++) {
cur_output_token[j * (K + 1) + k] =
(int64_t)cur_batch_msg_rcv->tokens[j * (K + 1) + k];
cur_output_score[j * (K + 1) + k] =
cur_batch_msg_rcv->scores[j * (K + 1) + k];
}
output_ranks_data[i * MAX_DRAFT_TOKEN_NUM + j] =
(int64_t)cur_batch_msg_rcv->ranks[j];
}
}
return;
}

PD_BUILD_STATIC_OP(speculate_get_output_topk)
.Inputs({"output_tokens", "output_scores", "output_ranks"})
.Attrs({"real_k: int", "rank_id: int64_t", "wait_flag: bool"})
.Outputs({"output_tokens_out", "output_scores_out", "output_ranks_out"})
.SetInplaceMap({{"output_tokens", "output_tokens_out"},
{"output_scores", "output_scores_out"},
{"output_ranks", "output_ranks_out"}})
.SetKernelFn(PD_KERNEL(SpeculateGetOutMmsgTopK));
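
This file implements only the receiving side; the producing process is assumed to pack the same struct and `msgsnd` it on the same queue. A hypothetical sender sketch follows (values and names are placeholders, not part of this PR):

```cpp
// Hypothetical sender-side counterpart, illustrative only. Assumes the
// msgdata/batch_msgdata definitions above are in scope.
#include <sys/ipc.h>
#include <sys/msg.h>

void SpeculateSendTopKExample(int msgid, int bsz) {
  static msgdata msg_snd;  // static: ~518 KiB is too large for the stack
  msg_snd.mtype = 1;
  msg_snd.meta[0] = 1;    // stop_flag (placeholder value)
  msg_snd.meta[1] = 3;    // message_flag, Target: 3, Draft: 4 (per comment above)
  msg_snd.meta[2] = bsz;
  for (int i = 0; i < bsz; i++) {
    msg_snd.meta[3 + i] = 1;             // one token for sequence i this step
    msg_snd.mtext[i].tokens[0] = 42;     // top-1 token id (placeholder)
    msg_snd.mtext[i].scores[0] = -0.5f;  // its logprob (placeholder)
    msg_snd.mtext[i].ranks[0] = 0;       // sampled token's rank within top-k
  }
  // Blocking send, mirroring the receiver's blocking msgrcv(..., 0).
  msgsnd(msgid, &msg_snd, sizeof(msg_snd) - sizeof(long), 0);
}
```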