Skip to content

Commit 5fd2771

Browse files
ezhulenevtensorflower-gardener
authored andcommitted
[xla] Use PjRtFuture::Promise::ToShared instead of std::make_shared
PiperOrigin-RevId: 806062172
1 parent b644c6b commit 5fd2771

File tree

5 files changed

+57
-65
lines changed

5 files changed

+57
-65
lines changed

third_party/xla/xla/pjrt/pjrt_c_api_client.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2656,8 +2656,7 @@ void PjRtCApiBuffer::MakePromiseTrackEvent() {
26562656
PjRtFuture<> PjRtCApiBuffer::GetReadyFuture() {
26572657
if (readiness_promise_ == nullptr) {
26582658
auto [promise, future] = PjRtFuture<>::MakePromise();
2659-
readiness_promise_ =
2660-
std::make_shared<PjRtFuture<>::Promise>(std::move(promise));
2659+
readiness_promise_ = std::move(promise).ToShared();
26612660
readiness_future_ = std::move(future);
26622661
MakePromiseTrackEvent();
26632662
}

third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ GrpcClientSession::GrpcClientSession(
125125
Future<std::shared_ptr<IfrtResponse>> GrpcClientSession::Enqueue(
126126
std::unique_ptr<IfrtRequest> request) {
127127
auto [promise, future] = Future<std::shared_ptr<IfrtResponse>>::MakePromise();
128-
auto shared_promise = std::make_shared<decltype(promise)>(std::move(promise));
128+
auto shared_promise = std::move(promise).ToShared();
129129
absl::Status status = Enqueue(
130130
std::move(request),
131131
[promise = std::move(shared_promise),

third_party/xla/xla/python/ifrt_proxy/client/grpc_host_buffer.cc

Lines changed: 42 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -100,49 +100,48 @@ Future<> GrpcClientHostBufferStore::Store(uint64_t handle,
100100
std::unique_ptr<std::string> buffered_data;
101101

102102
auto reservation = ScopedAcquireSemaphore(store_throttler_);
103-
work_queue_->Schedule(
104-
[this, reservation = std::move(reservation), handle,
105-
promise = std::make_shared<Future<>::Promise>(std::move(promise)), data,
106-
flow]() mutable -> void {
107-
auto span = flow.Span<XFlowHelper::kRecv>();
108-
GrpcHostBufferStoreMetadata metadata;
109-
metadata.set_session_id(session_id_);
110-
metadata.set_handle(handle);
111-
metadata.set_buffer_size(data.size());
112-
VLOG(3) << "GrpcClientHostBufferStore::Store start "
113-
<< metadata.ShortDebugString();
114-
115-
::grpc::ClientContext context;
116-
context.AddMetadata("ifrt-proxy-grpc-host-buffer-store-metadata-bin",
117-
metadata.SerializeAsString());
118-
119-
GrpcHostBufferStoreResponse response;
120-
auto writer = stub_->HostBufferStore(&context, &response);
121-
122-
{
123-
tsl::profiler::TraceMe trace_me_send_data([size = data.size()]() {
124-
return tsl::profiler::TraceMeEncode(
125-
"GrpcClientHostBufferStore::StoreAsync_Send", {{"size", size}});
126-
});
127-
for (int64_t offset = 0; offset < data.size(); offset += kChunkSize) {
128-
GrpcHostBufferStoreRequest request;
129-
SetDataFromStringView(request, data.substr(offset, kChunkSize));
130-
writer->Write(request);
131-
}
132-
133-
if (!writer->WritesDone()) {
134-
absl::Status s = xla::FromGrpcStatus(writer->Finish());
135-
promise->Set(absl::InternalError(absl::StrCat(
136-
"Failed to write all host buffer chunks, Finish() returned: ",
137-
s.ToString())));
138-
return;
139-
}
140-
}
141-
142-
VLOG(3) << "GrpcClientHostBufferStore::Store done "
143-
<< metadata.ShortDebugString();
144-
promise->Set(xla::FromGrpcStatus(writer->Finish()));
103+
work_queue_->Schedule([this, reservation = std::move(reservation), handle,
104+
promise = std::move(promise).ToShared(), data,
105+
flow]() mutable -> void {
106+
auto span = flow.Span<XFlowHelper::kRecv>();
107+
GrpcHostBufferStoreMetadata metadata;
108+
metadata.set_session_id(session_id_);
109+
metadata.set_handle(handle);
110+
metadata.set_buffer_size(data.size());
111+
VLOG(3) << "GrpcClientHostBufferStore::Store start "
112+
<< metadata.ShortDebugString();
113+
114+
::grpc::ClientContext context;
115+
context.AddMetadata("ifrt-proxy-grpc-host-buffer-store-metadata-bin",
116+
metadata.SerializeAsString());
117+
118+
GrpcHostBufferStoreResponse response;
119+
auto writer = stub_->HostBufferStore(&context, &response);
120+
121+
{
122+
tsl::profiler::TraceMe trace_me_send_data([size = data.size()]() {
123+
return tsl::profiler::TraceMeEncode(
124+
"GrpcClientHostBufferStore::StoreAsync_Send", {{"size", size}});
145125
});
126+
for (int64_t offset = 0; offset < data.size(); offset += kChunkSize) {
127+
GrpcHostBufferStoreRequest request;
128+
SetDataFromStringView(request, data.substr(offset, kChunkSize));
129+
writer->Write(request);
130+
}
131+
132+
if (!writer->WritesDone()) {
133+
absl::Status s = xla::FromGrpcStatus(writer->Finish());
134+
promise->Set(absl::InternalError(absl::StrCat(
135+
"Failed to write all host buffer chunks, Finish() returned: ",
136+
s.ToString())));
137+
return;
138+
}
139+
}
140+
141+
VLOG(3) << "GrpcClientHostBufferStore::Store done "
142+
<< metadata.ShortDebugString();
143+
promise->Set(xla::FromGrpcStatus(writer->Finish()));
144+
});
146145
return std::move(future);
147146
}
148147

@@ -201,9 +200,7 @@ Future<absl::Cord> GrpcClientHostBufferStore::Lookup(uint64_t handle) {
201200

202201
auto reservation = ScopedAcquireSemaphore(lookup_throttler_);
203202
work_queue_->Schedule([this, reservation = std::move(reservation), handle,
204-
promise =
205-
std::make_shared<Future<absl::Cord>::Promise>(
206-
std::move(promise)),
203+
promise = std::move(promise).ToShared(),
207204
flow]() mutable -> void {
208205
auto span = flow.Span<XFlowHelper::kRecv>();
209206
GrpcHostBufferLookupRequest request;

third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -667,9 +667,7 @@ Future<BackendInterface::Response> IfrtBackend::AsyncExecute(
667667
++in_flight_count_;
668668
}
669669
auto [promise, future] = Future<Response>::MakePromise();
670-
auto f = [this,
671-
promise =
672-
std::make_shared<Future<Response>::Promise>(std::move(promise)),
670+
auto f = [this, promise = std::move(promise).ToShared(),
673671
handle_fn = std::move(handle_fn)]() mutable {
674672
promise->Set(handle_fn());
675673
{

third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -660,8 +660,7 @@ absl::StatusOr<ArrayRef> AssembleStringArrayFromSingleDeviceStringArrays(
660660
Future<BasicStringArray::Buffers>::MakePromise();
661661

662662
auto buffer_copier = [state = buffer_copying_state,
663-
promise = std::make_shared<decltype(buffers_promise)>(
664-
std::move(buffers_promise))](
663+
promise = std::move(buffers_promise).ToShared()](
665664
absl::StatusOr<BasicStringArray::Buffers> strbuf,
666665
int shard_index) mutable {
667666
absl::MutexLock lock(&state->mu);
@@ -1508,18 +1507,17 @@ absl::Status PjRtClient::CrossHostSendBuffers(
15081507
// keys together to reduce the number of threads used.
15091508
for (int i = 0; i < keys.size(); ++i) {
15101509
auto [promise, descriptor_future] = PjRtFuture<std::string>::MakePromise();
1511-
work_queue_->Schedule([this, k = keys[i],
1512-
promise = std::make_shared<decltype(promise)>(
1513-
std::move(promise))]() mutable {
1514-
std::string key = absl::StrCat(kKeyPrefix, k);
1515-
absl::StatusOr<std::string> descriptor =
1516-
kv_store_->Get(key, cross_host_transfer_timeout_);
1517-
if (!descriptor.ok()) {
1518-
LOG(FATAL) << "Failed to get descriptor for key " << key << ": "
1519-
<< descriptor.status();
1520-
}
1521-
promise->Set(std::move(*descriptor));
1522-
});
1510+
work_queue_->Schedule(
1511+
[this, k = keys[i], promise = std::move(promise).ToShared()]() mutable {
1512+
std::string key = absl::StrCat(kKeyPrefix, k);
1513+
absl::StatusOr<std::string> descriptor =
1514+
kv_store_->Get(key, cross_host_transfer_timeout_);
1515+
if (!descriptor.ok()) {
1516+
LOG(FATAL) << "Failed to get descriptor for key " << key << ": "
1517+
<< descriptor.status();
1518+
}
1519+
promise->Set(std::move(*descriptor));
1520+
});
15231521
auto on_done = [](absl::Status status, bool sends_were_enqueued) {
15241522
CHECK_OK(status);
15251523
};

0 commit comments

Comments
 (0)