@@ -2194,18 +2194,20 @@ void PjRtStreamExecutorBuffer::CopyToRemoteDevice(
2194
2194
2195
2195
PjRtFuture<> PjRtStreamExecutorBuffer::GetReadyFuture () {
2196
2196
absl::InlinedVector<BufferSequencingEventRef, 2 > definition_events;
2197
- PjRtFuture<>::Promise definition_promise;
2197
+ PjRtFuture<>::MoveOnlyPromise definition_promise;
2198
+ PjRtFuture<> definition_future;
2198
2199
{
2199
2200
absl::MutexLock lock (&mu_);
2200
2201
if (device_buffer () == nullptr ) {
2201
2202
return PjRtFuture<>(InvalidArgument (
2202
2203
" GetReadyFuture() called on deleted or donated buffer" ));
2203
2204
}
2204
- if (!definition_promise_ ) {
2205
+ if (!definition_future_ ) {
2205
2206
definition_events = device_buffer ()->definition_events ();
2206
- definition_promise_ = PjRtFuture<>::CreatePromise ();
2207
+ std::tie (definition_promise, definition_future_) =
2208
+ PjRtFuture<>::MakePromise ();
2207
2209
}
2208
- definition_promise = definition_promise_ ;
2210
+ definition_future = definition_future_ ;
2209
2211
}
2210
2212
2211
2213
if (!definition_events.empty ()) {
@@ -2214,12 +2216,13 @@ PjRtFuture<> PjRtStreamExecutorBuffer::GetReadyFuture() {
2214
2216
auto async_wait_for_events =
2215
2217
[definition_events = std::move (definition_events),
2216
2218
local_device_state = std::move (local_device_state),
2217
- definition_promise]() mutable {
2219
+ definition_promise = std::make_shared<PjRtFuture<>::Promise>(
2220
+ std::move (definition_promise))]() mutable {
2218
2221
std::unique_ptr<se::Stream> stream;
2219
2222
absl::Status defined_status =
2220
2223
definition_events[0 ]->GetDefinedStatus ();
2221
2224
if (!defined_status.ok ()) {
2222
- definition_promise. Set (defined_status);
2225
+ definition_promise-> Set (defined_status);
2223
2226
return ;
2224
2227
}
2225
2228
for (auto & event : definition_events) {
@@ -2242,28 +2245,29 @@ PjRtFuture<> PjRtStreamExecutorBuffer::GetReadyFuture() {
2242
2245
event_with_status = definition_events[0 ]]() mutable {
2243
2246
local_device_state->ReturnStreamToPool (
2244
2247
std::unique_ptr<se::Stream>(stream_ptr));
2245
- definition_promise.Set (event_with_status->GetDefinedStatus ());
2248
+ definition_promise->Set (
2249
+ event_with_status->GetDefinedStatus ());
2246
2250
});
2247
2251
if (!status.ok ()) {
2248
- definition_promise. Set (status);
2252
+ definition_promise-> Set (status);
2249
2253
return ;
2250
2254
}
2251
2255
} else {
2252
2256
// All events are already complete; set the `definition_promise`
2253
2257
// with the status of the buffer's first definition event which may
2254
2258
// have error status to propagate.
2255
- definition_promise. Set (definition_events[0 ]->GetDefinedStatus ());
2259
+ definition_promise-> Set (definition_events[0 ]->GetDefinedStatus ());
2256
2260
}
2257
2261
};
2258
2262
first_definition_event->ExecuteOrAddToFutureTasks (
2259
2263
absl::StrFormat (" async_wait_for_events_%p" , &async_wait_for_events),
2260
2264
std::move (async_wait_for_events));
2261
2265
}
2262
2266
2263
- return PjRtFuture<> (
2264
- std::move (definition_promise ),
2267
+ return PjRtFutureHelpers::WithProfiling (
2268
+ std::move (definition_future ),
2265
2269
/* on_block_start=*/
2266
- []() {
2270
+ [] {
2267
2271
tsl::profiler::TraceMeProducer traceme (
2268
2272
" PjRtStreamExecutorBuffer::Await" );
2269
2273
VLOG (3 ) << " PjRtStreamExecutorBuffer::Await" ;
0 commit comments