Skip to content

Commit 113f6fb

Browse files
authored
Fix: missing clientId when serialize and deserialize response (#5231)
Signed-off-by: Kaiyu Xie <[email protected]>
1 parent 7246fd7 commit 113f6fb

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

cpp/tensorrt_llm/executor/serialization.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -946,22 +946,26 @@ Response Serialization::deserializeResponse(std::istream& is)
946946
{
947947
auto requestId = su::deserialize<IdType>(is);
948948
auto errOrResult = su::deserialize<std::variant<std::string, Result>>(is);
949+
auto clientId = su::deserialize<std::optional<IdType>>(is);
949950

950-
return std::holds_alternative<std::string>(errOrResult) ? Response{requestId, std::get<std::string>(errOrResult)}
951-
: Response{requestId, std::get<Result>(errOrResult)};
951+
return std::holds_alternative<std::string>(errOrResult)
952+
? Response{requestId, std::get<std::string>(errOrResult), clientId}
953+
: Response{requestId, std::get<Result>(errOrResult), clientId};
952954
}
953955

954956
void Serialization::serialize(Response const& response, std::ostream& os)
955957
{
956958
su::serialize(response.mImpl->mRequestId, os);
957959
su::serialize(response.mImpl->mErrOrResult, os);
960+
su::serialize(response.mImpl->mClientId, os);
958961
}
959962

960963
size_t Serialization::serializedSize(Response const& response)
961964
{
962965
size_t totalSize = 0;
963966
totalSize += su::serializedSize(response.mImpl->mRequestId);
964967
totalSize += su::serializedSize(response.mImpl->mErrOrResult);
968+
totalSize += su::serializedSize(response.mImpl->mClientId);
965969
return totalSize;
966970
}
967971

cpp/tests/unit_tests/executor/serializeUtilsTest.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ void compareResponse(texec::Response res, texec::Response res2)
160160
{
161161
compareResult(res.getResult(), res2.getResult());
162162
}
163+
EXPECT_EQ(res.getClientId(), res2.getClientId());
163164
}
164165

165166
template <typename T>
@@ -428,11 +429,15 @@ TEST(SerializeUtilsTest, ResultResponse)
428429
auto val = texec::Response(1, "my error msg");
429430
testSerializeDeserialize(val);
430431
}
432+
{
433+
auto val = texec::Response(1, "my error msg", 2);
434+
testSerializeDeserialize(val);
435+
}
431436
}
432437

433438
TEST(SerializeUtilsTest, VectorResponses)
434439
{
435-
int numResponses = 10;
440+
int numResponses = 15;
436441
std::vector<texec::Response> responsesIn;
437442
for (int i = 0; i < numResponses; ++i)
438443
{
@@ -443,11 +448,16 @@ TEST(SerializeUtilsTest, VectorResponses)
443448
std::nullopt, std::vector<texec::FinishReason>{texec::FinishReason::kEND_ID}};
444449
responsesIn.emplace_back(i, res);
445450
}
446-
else
451+
else if (i < 10)
447452
{
448453
std::string errMsg = "my_err_msg" + std::to_string(i);
449454
responsesIn.emplace_back(i, errMsg);
450455
}
456+
else
457+
{
458+
std::string errMsg = "my_err_msg" + std::to_string(i);
459+
responsesIn.emplace_back(i, errMsg, i + 1);
460+
}
451461
}
452462

453463
auto buffer = texec::Serialization::serialize(responsesIn);

0 commit comments

Comments
 (0)