Skip to content

Commit b644c6b

Browse files
ezhulenevtensorflower-gardener
authored andcommitted
[xla] Migrate to PjRtFuture<>::MakePromise() API
PiperOrigin-RevId: 806057470
1 parent 7fa9f35 commit b644c6b

File tree

5 files changed

+33
-21
lines changed

5 files changed

+33
-21
lines changed

third_party/xla/xla/pjrt/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ cc_library(
109109
"@com_google_absl//absl/log:check",
110110
"@com_google_absl//absl/status",
111111
"@com_google_absl//absl/status:statusor",
112+
"@com_google_absl//absl/strings",
112113
"@com_google_absl//absl/synchronization",
113114
"@local_tsl//tsl/profiler/lib:traceme",
114115
],

third_party/xla/xla/pjrt/abstract_tracked_device_buffer.h

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,21 @@ limitations under the License.
1818

1919
#include <array>
2020
#include <memory>
21+
#include <vector>
2122

2223
#include "absl/base/thread_annotations.h"
24+
#include "absl/functional/any_invocable.h"
2325
#include "absl/log/check.h"
2426
#include "absl/status/status.h"
2527
#include "absl/status/statusor.h"
28+
#include "absl/strings/str_cat.h"
2629
#include "absl/synchronization/mutex.h"
30+
#include "xla/pjrt/device_event.h"
2731
#include "xla/pjrt/pjrt_client.h"
2832
#include "xla/pjrt/pjrt_future.h"
2933
#include "xla/pjrt/raw_buffer.h"
34+
#include "xla/tsl/concurrency/async_value.h"
35+
#include "xla/tsl/concurrency/ref_count.h"
3036

3137
namespace xla {
3238

@@ -224,9 +230,8 @@ class CommonPjRtBuffer : public PjRtBuffer {
224230

225231
absl::Status AcquireScopedRawBuffer(
226232
absl::AnyInvocable<absl::StatusOr<tsl::RCReference<PjRtDeviceEvent>>(
227-
tsl::RCReference<CommonPjRtRawBuffer> raw_buffer,
228-
std::vector<tsl::RCReference<tsl::AsyncValue>>
229-
definition_events) &&>
233+
tsl::RCReference<CommonPjRtRawBuffer> raw_buffer,
234+
std::vector<tsl::RCReference<tsl::AsyncValue>> definition_events) &&>
230235
scoped_acquire,
231236
const char* caller_name = "AcquireScopedRawBuffer");
232237

@@ -294,7 +299,7 @@ class CommonPjRtBuffer : public PjRtBuffer {
294299
}
295300

296301
mutable absl::Mutex mu_;
297-
PjRtFuture<>::Promise definition_promise_ ABSL_GUARDED_BY(mu_);
302+
PjRtFuture<> definition_future_ ABSL_GUARDED_BY(mu_);
298303
PjRtMemorySpace* const memory_space_;
299304

300305
private:

third_party/xla/xla/pjrt/common_pjrt_client.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1194,12 +1194,12 @@ PjRtFuture<> CommonPjRtBufferImpl::GetReadyFuture() {
11941194
return PjRtFuture<>(InvalidArgument(
11951195
"GetReadyFuture() called on deleted or donated buffer"));
11961196
}
1197-
if (!definition_promise_) {
1198-
definition_promise_ =
1199-
device_buffer()->GetReadyFuturePromise(memory_space());
1197+
if (!definition_future_) {
1198+
auto promise = device_buffer()->GetReadyFuturePromise(memory_space());
1199+
definition_future_ = client()->CreateFutureFromUserPromise(
1200+
memory_space(), "CommonPjRtBuffer", "Await", std::move(promise));
12001201
}
1201-
return client()->CreateFutureFromUserPromise(
1202-
memory_space(), "CommonPjRtBuffer", "Await", definition_promise_);
1202+
return definition_future_;
12031203
}
12041204

12051205
} // namespace xla

third_party/xla/xla/pjrt/pjrt_future.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,8 @@ class PjRtFutureBase : public PjRtFutureMoveControl<is_move_only> {
217217
}
218218
}
219219

220+
explicit operator bool() const { return static_cast<bool>(promise_); }
221+
220222
protected:
221223
static constexpr bool IsMoveOnly() { return is_move_only; }
222224

third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2194,18 +2194,20 @@ void PjRtStreamExecutorBuffer::CopyToRemoteDevice(
21942194

21952195
PjRtFuture<> PjRtStreamExecutorBuffer::GetReadyFuture() {
21962196
absl::InlinedVector<BufferSequencingEventRef, 2> definition_events;
2197-
PjRtFuture<>::Promise definition_promise;
2197+
PjRtFuture<>::MoveOnlyPromise definition_promise;
2198+
PjRtFuture<> definition_future;
21982199
{
21992200
absl::MutexLock lock(&mu_);
22002201
if (device_buffer() == nullptr) {
22012202
return PjRtFuture<>(InvalidArgument(
22022203
"GetReadyFuture() called on deleted or donated buffer"));
22032204
}
2204-
if (!definition_promise_) {
2205+
if (!definition_future_) {
22052206
definition_events = device_buffer()->definition_events();
2206-
definition_promise_ = PjRtFuture<>::CreatePromise();
2207+
std::tie(definition_promise, definition_future_) =
2208+
PjRtFuture<>::MakePromise();
22072209
}
2208-
definition_promise = definition_promise_;
2210+
definition_future = definition_future_;
22092211
}
22102212

22112213
if (!definition_events.empty()) {
@@ -2214,12 +2216,13 @@ PjRtFuture<> PjRtStreamExecutorBuffer::GetReadyFuture() {
22142216
auto async_wait_for_events =
22152217
[definition_events = std::move(definition_events),
22162218
local_device_state = std::move(local_device_state),
2217-
definition_promise]() mutable {
2219+
definition_promise = std::make_shared<PjRtFuture<>::Promise>(
2220+
std::move(definition_promise))]() mutable {
22182221
std::unique_ptr<se::Stream> stream;
22192222
absl::Status defined_status =
22202223
definition_events[0]->GetDefinedStatus();
22212224
if (!defined_status.ok()) {
2222-
definition_promise.Set(defined_status);
2225+
definition_promise->Set(defined_status);
22232226
return;
22242227
}
22252228
for (auto& event : definition_events) {
@@ -2242,28 +2245,29 @@ PjRtFuture<> PjRtStreamExecutorBuffer::GetReadyFuture() {
22422245
event_with_status = definition_events[0]]() mutable {
22432246
local_device_state->ReturnStreamToPool(
22442247
std::unique_ptr<se::Stream>(stream_ptr));
2245-
definition_promise.Set(event_with_status->GetDefinedStatus());
2248+
definition_promise->Set(
2249+
event_with_status->GetDefinedStatus());
22462250
});
22472251
if (!status.ok()) {
2248-
definition_promise.Set(status);
2252+
definition_promise->Set(status);
22492253
return;
22502254
}
22512255
} else {
22522256
// All events are already complete; set the `definition_promise`
22532257
// with the status of the buffer's first definition event which may
22542258
// have error status to propagate.
2255-
definition_promise.Set(definition_events[0]->GetDefinedStatus());
2259+
definition_promise->Set(definition_events[0]->GetDefinedStatus());
22562260
}
22572261
};
22582262
first_definition_event->ExecuteOrAddToFutureTasks(
22592263
absl::StrFormat("async_wait_for_events_%p", &async_wait_for_events),
22602264
std::move(async_wait_for_events));
22612265
}
22622266

2263-
return PjRtFuture<>(
2264-
std::move(definition_promise),
2267+
return PjRtFutureHelpers::WithProfiling(
2268+
std::move(definition_future),
22652269
/*on_block_start=*/
2266-
[]() {
2270+
[] {
22672271
tsl::profiler::TraceMeProducer traceme(
22682272
"PjRtStreamExecutorBuffer::Await");
22692273
VLOG(3) << "PjRtStreamExecutorBuffer::Await";

0 commit comments

Comments
 (0)