
Commit b4c382e

apivovarov authored and tensorflower-gardener committed
Reverts 1dd7358
PiperOrigin-RevId: 806063246
1 parent 5fd2771 commit b4c382e

16 files changed: +282 −40 lines changed


third_party/xla/xla/backends/gpu/runtime/BUILD

Lines changed: 20 additions & 2 deletions
@@ -890,7 +890,7 @@ cuda_library(
     ],
 )
 
-cuda_library(
+cc_library(
     name = "select_k_exec_stub",
     srcs = ["select_k_exec_stub.cc"],
     hdrs = ["select_k_exec.h"],
@@ -939,7 +939,6 @@ cc_library(
     name = "select_k_thunk",
     srcs = ["select_k_thunk.cc"],
     hdrs = ["select_k_thunk.h"],
-    tags = ["gpu"],
     deps = [
         ":thunk",
         ":thunk_proto_cc",
@@ -955,13 +954,32 @@ cc_library(
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
     ] + if_cuda_is_configured(
         [":select_k_exec_raft"],
         no_cuda = [":select_k_exec_stub"],
     ),
 )
 
+xla_cc_test(
+    name = "select_k_thunk_test",
+    srcs = ["select_k_thunk_test.cc"],
+    deps = [
+        ":select_k_thunk",
+        ":thunk",
+        ":thunk_proto_cc",
+        "//xla:literal_util",
+        "//xla:shape_util",
+        "//xla/codegen/emitters:kernel_arguments",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:buffer_assignment",
+        "//xla/tsl/platform:statusor",
+        "//xla/tsl/util/proto:proto_matchers",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
 cc_library(
     name = "memset_thunk",
     srcs = ["memset_thunk.cc"],

third_party/xla/xla/backends/gpu/runtime/select_k_exec_raft.cc

Lines changed: 29 additions & 3 deletions
@@ -82,6 +82,12 @@ class OwningScratchAllocator {
     return absl::NotFoundError("Pointer not found");
   }
 
+  se::DeviceMemoryAllocator* get_allocator() const { return allocator_; }
+
+  void set_allocator(se::DeviceMemoryAllocator* allocator) {
+    allocator_ = allocator;
+  }
+
  private:
   int device_ordinal_;
   se::DeviceMemoryAllocator* allocator_;
@@ -96,6 +102,14 @@ class XlaDeviceMemoryResource : public rmm::mr::device_memory_resource {
                           se::DeviceMemoryAllocator* allocator)
       : scratch_allocator_(device_ordinal, allocator) {}
 
+  se::DeviceMemoryAllocator* get_allocator() const {
+    return scratch_allocator_.get_allocator();
+  }
+
+  void set_allocator(se::DeviceMemoryAllocator* allocator) {
+    scratch_allocator_.set_allocator(allocator);
+  }
+
  protected:
   void* do_allocate(std::size_t bytes, rmm::cuda_stream_view stream) override {
     auto mem = scratch_allocator_.AllocateBytes(bytes);
@@ -122,6 +136,8 @@ class XlaDeviceMemoryResource : public rmm::mr::device_memory_resource {
 // RAII wrapper for RAFT resources bound to a CUDA stream
 struct RaftStreamResource : public se::Stream::Resource {
   raft::resources res;
+  std::shared_ptr<XlaDeviceMemoryResource> xla_dev_mem_res;
+  ~RaftStreamResource() override = default;
 
   // Factory to create a RaftStreamResource tied to a CUDA stream.
   // Sets up `raft::resources` with a custom XlaDeviceMemoryResource
@@ -138,9 +154,10 @@ struct RaftStreamResource : public se::Stream::Resource {
                                                  cudaStream_t cuda_stream) {
     // Assign our custom AllocatorForRaft for this device
     auto handle = std::make_unique<RaftStreamResource>();
-    raft::resource::set_workspace_resource(
-        handle->res,
-        std::make_shared<XlaDeviceMemoryResource>(device_ordinal, allocator));
+    handle->xla_dev_mem_res =
+        std::make_shared<XlaDeviceMemoryResource>(device_ordinal, allocator);
+    raft::resource::set_workspace_resource(handle->res,
+                                           handle->xla_dev_mem_res);
     // Set Cuda Stream
     raft::resource::set_cuda_stream(handle->res,
                                     rmm::cuda_stream_view{cuda_stream});
@@ -246,6 +263,8 @@ absl::Status select_k_exec(int device_ordinal,
   SelectAlgo algo = choose_select_k_algorithm<T>(batch, n, k);
   VLOG(3) << "select_k_exec_raft: "
           << "device_ordinal: " << device_ordinal << ", "
+          << "allocator: " << allocator << ", "
+          << "stream: " << stream << ", "
           << "data_in: " << data_in.opaque() << " (" << data_in.size() << "B)"
           << ", data_out: " << data_out.opaque() << " (" << data_out.size()
           << "B)"
@@ -268,6 +287,13 @@ absl::Status select_k_exec(int device_ordinal,
   TF_RET_CHECK(resContainer != nullptr)
       << "Failed to create or retrieve RaftStreamResource";
 
+  // resContainer is scoped to a single stream.
+  // Because a stream does not execute select_k_exec concurrently from multiple
+  // threads, it is safe to update the allocator without additional locking.
+  if (allocator != resContainer->xla_dev_mem_res->get_allocator()) {
+    resContainer->xla_dev_mem_res->set_allocator(allocator);
+  }
+
   try {
     // Wrap raw device pointers in RAFT matrix views
     auto input_view =
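
A minimal standalone sketch of the cache-and-rebind pattern this file change applies, using hypothetical stand-in types (Allocator, CachedMemoryResource, StreamResource) rather than the real se::DeviceMemoryAllocator, XlaDeviceMemoryResource, and RaftStreamResource: the per-stream resource keeps the memory resource it installed alive, and later calls only swap the backing allocator pointer when it differs instead of installing a fresh workspace resource each time.

#include <memory>

// Hypothetical stand-ins for the XLA/RAFT types touched by the patch.
struct Allocator {};

class CachedMemoryResource {
 public:
  explicit CachedMemoryResource(Allocator* allocator) : allocator_(allocator) {}
  Allocator* get_allocator() const { return allocator_; }
  void set_allocator(Allocator* allocator) { allocator_ = allocator; }

 private:
  Allocator* allocator_;
};

struct StreamResource {
  // Kept alive for the lifetime of the stream, mirroring xla_dev_mem_res.
  std::shared_ptr<CachedMemoryResource> mem_res =
      std::make_shared<CachedMemoryResource>(nullptr);
};

// Called once per launch. As the comment in the patch argues, a single stream
// does not run this concurrently from multiple threads, so no lock is needed.
void BindAllocator(StreamResource& res, Allocator* allocator) {
  if (allocator != res.mem_res->get_allocator()) {
    res.mem_res->set_allocator(allocator);
  }
}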

third_party/xla/xla/backends/gpu/runtime/select_k_thunk.cc

Lines changed: 11 additions & 0 deletions
@@ -22,6 +22,7 @@ limitations under the License.
 #include "absl/log/check.h"
 #include "absl/log/log.h"
 #include "absl/status/status.h"
+#include "absl/status/statusor.h"
 #include "absl/strings/str_cat.h"
 #include "xla/backends/gpu/runtime/select_k_exec.h"
 #include "xla/backends/gpu/runtime/thunk.h"
@@ -99,4 +100,14 @@ absl::Status SelectKThunk::ExecuteOnStream(const ExecuteParams& params) {
         primitive_util::LowercasePrimitiveTypeName(dtype_)));
   }
 }
+
+absl::StatusOr<ThunkProto> SelectKThunk::ToProto() const {
+  ThunkProto proto;
+  *proto.mutable_thunk_info() = thunk_info().ToProto();
+
+  SelectKThunkProto* select_k_thunk_proto = proto.mutable_select_k_thunk();
+  (void)select_k_thunk_proto;
+  // TODO(upwind): Add fields for SelectKThunkProto.
+  return proto;
+}
 }  // namespace xla::gpu

third_party/xla/xla/backends/gpu/runtime/select_k_thunk.h

Lines changed: 3 additions & 1 deletion
@@ -16,12 +16,12 @@ limitations under the License.
 #ifndef XLA_BACKENDS_GPU_RUNTIME_SELECT_K_THUNK_H_
 #define XLA_BACKENDS_GPU_RUNTIME_SELECT_K_THUNK_H_
 
-#include <cstddef>
 #include <cstdint>
 #include <string>
 #include <vector>
 
 #include "absl/status/status.h"
+#include "absl/status/statusor.h"
 #include "xla/backends/gpu/runtime/thunk.h"
 #include "xla/backends/gpu/runtime/thunk.pb.h"
 #include "xla/codegen/emitters/kernel_arguments.h"
@@ -61,6 +61,8 @@ class SelectKThunk : public Thunk {
     return args_;
   }
 
+  absl::StatusOr<ThunkProto> ToProto() const override;
+
  private:
   std::uint32_t batch_size_;
   std::uint32_t num_elements_;
third_party/xla/xla/backends/gpu/runtime/select_k_thunk_test.cc

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/backends/gpu/runtime/select_k_thunk.h"
+
+#include <memory>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "xla/backends/gpu/runtime/thunk.h"
+#include "xla/backends/gpu/runtime/thunk.pb.h"
+#include "xla/codegen/emitters/kernel_arguments.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/literal_util.h"
+#include "xla/service/buffer_assignment.h"
+#include "xla/shape_util.h"
+#include "xla/tsl/platform/statusor.h"
+#include "xla/tsl/util/proto/proto_matchers.h"
+
+namespace xla::gpu {
+namespace {
+
+using ::tsl::proto_testing::EqualsProto;
+
+TEST(SelectKThunkTest, ToProto) {
+  Thunk::ThunkInfo thunk_info;
+  thunk_info.profile_annotation = "profile_annotation";
+  thunk_info.execution_stream_id = 123;
+
+  BufferAllocation alloc0(/*index=*/0, /*size=*/20, /*color=*/0);
+  BufferAllocation::Slice slice0(&alloc0, /*offset=*/0, /*size=*/20);
+
+  BufferAllocation alloc1(/*index=*/1, /*size=*/12, /*color=*/0);
+  BufferAllocation::Slice slice1(&alloc1, /*offset=*/0, /*size=*/12);
+
+  BufferAllocation alloc2(/*index=*/2, /*size=*/12, /*color=*/0);
+  BufferAllocation::Slice slice2(&alloc2, /*offset=*/0, /*size=*/12);
+
+  emitters::KernelArgument arg0(ShapeUtil::MakeShape(F32, {1, 5}), slice0);
+  emitters::KernelArgument arg1(ShapeUtil::MakeShape(F32, {1, 3}), slice1);
+  emitters::KernelArgument arg2(ShapeUtil::MakeShape(U32, {1, 3}), slice2);
+  arg0.set_written(false);
+  arg1.set_written(true);
+  arg2.set_written(true);
+
+  emitters::KernelArguments kernel_arguments({arg0, arg1, arg2});
+
+  auto c1 = HloInstruction::CreateConstant(
+      LiteralUtil::CreateR2<float>({{.125f, 0.875f, .5f, .25f, 0.75f}}));
+  auto topKInst = HloInstruction::CreateCustomCall(
+      ShapeUtil::MakeShape(F32, {1, 5}), {c1.get()}, "__gpu$TopK");
+
+  SelectKThunk thunk(topKInst.get(), 1, 5, 3, F32, kernel_arguments);
+  TF_ASSERT_OK_AND_ASSIGN(ThunkProto proto, thunk.ToProto());
+  EXPECT_THAT(proto, EqualsProto(R"pb(
+                thunk_info { profile_annotation: "custom-call" }
+                select_k_thunk {}
+              )pb"));
+}
+
+}  // namespace
+}  // namespace xla::gpu

third_party/xla/xla/backends/gpu/runtime/thunk.proto

Lines changed: 5 additions & 0 deletions
@@ -133,6 +133,10 @@ message MemzeroThunkProto {
   xla.buffer_assignment.BufferAllocationSliceProto dest_buffer = 1;
 }
 
+message SelectKThunkProto {
+  // TODO(upwind): Add fields for SelectKThunkProto.
+}
+
 message ThunkProto {
   ThunkInfoProto thunk_info = 1;
 
@@ -155,6 +159,7 @@ message ThunkProto {
     HostExecuteDoneThunkProto host_execute_done_thunk = 17;
     DynamicSliceThunkProto dynamic_slice_thunk = 18;
     MemzeroThunkProto memzero_thunk = 19;
+    SelectKThunkProto select_k_thunk = 20;
   }
 }
 
third_party/xla/xla/debug_options_flags.cc

Lines changed: 7 additions & 0 deletions
@@ -451,6 +451,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
   opts.set_xla_detect_unstable_reductions(
       DebugOptions::UNSTABLE_REDUCTION_DETECTION_MODE_NONE);
   opts.set_xla_gpu_experimental_scaled_dot_with_triton(false);
+  opts.set_xla_gpu_experimental_use_raft_select_k(false);
   return opts;
 }
 
@@ -2556,6 +2557,12 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
       "that checks for unstable reductions in HLO computations. "
       "Acceptable values are: 'none', 'log', and 'crash'. 'none' is "
       "the default."));
+  flag_list->push_back(tsl::Flag(
+      "xla_gpu_experimental_use_raft_select_k",
+      bool_setter_for(
+          &DebugOptions::set_xla_gpu_experimental_use_raft_select_k),
+      debug_options->xla_gpu_experimental_use_raft_select_k(),
+      "If true, use the raft::matrix::select_k implementation of TopK."));
 }  // NOLINT(readability/fn_size)
 
 // Allocates flag_values and flag_objects; this function must not be called more
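
The new flag defaults to false, so the RAFT path stays off unless a user opts in. A minimal sketch of flipping it programmatically is below; the helper name EnableRaftSelectK and its HloModuleConfig argument are illustrative only, and the command-line equivalent would be passing --xla_gpu_experimental_use_raft_select_k=true through XLA_FLAGS.

#include "xla/debug_options_flags.h"
#include "xla/service/hlo_module_config.h"

// Sketch only: opt one module's compilation into the RAFT select_k path.
void EnableRaftSelectK(xla::HloModuleConfig& config) {
  // Start from the flag-derived defaults, then override the new option.
  xla::DebugOptions opts = xla::GetDebugOptionsFromFlags();
  opts.set_xla_gpu_experimental_use_raft_select_k(true);
  config.set_debug_options(opts);
}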

third_party/xla/xla/service/gpu/BUILD

Lines changed: 1 addition & 0 deletions
@@ -407,6 +407,7 @@ cc_library(
         "//xla/backends/gpu/runtime:ragged_all_to_all_thunk",
         "//xla/backends/gpu/runtime:recv_thunk",
         "//xla/backends/gpu/runtime:replica_id_thunk",
+        "//xla/backends/gpu/runtime:select_k_thunk",
         "//xla/backends/gpu/runtime:send_thunk",
        "//xla/backends/gpu/runtime:sequential_thunk",
        "//xla/backends/gpu/runtime:thunk",

third_party/xla/xla/service/gpu/gpu_compiler.cc

Lines changed: 3 additions & 4 deletions
@@ -724,6 +724,8 @@ absl::Status RunOptimizationPasses(
     const AlgebraicSimplifierOptions& layout_insensitive_algsimp_opts,
     absl::string_view platform_name) {
   const DebugOptions& debug_options = hlo_module->config().debug_options();
+  se::GpuComputeCapability gpu_version =
+      gpu_target_config.device_description.gpu_compute_capability();
 
   HloPassPipeline pipeline("optimization");
   AddHloVerifier(&pipeline, !debug_options.xla_ignore_channel_id());
@@ -738,7 +740,7 @@
     pipeline.AddPass<WindowedEinsumHandler>();
   }
   pipeline.AddPass<TopKSplitter>();
-  pipeline.AddPass<TopkSpecializer>();
+  pipeline.AddPass<TopkSpecializer>(gpu_version);
   pipeline.AddPass<TopkDecomposer>();
 
   pipeline.AddPass<DotDimensionSorter>();
@@ -876,9 +878,6 @@
   // Expand the sort op to support stable sorting if required.
   pipeline.AddPass<StableSortExpander>();
 
-  se::GpuComputeCapability gpu_version =
-      gpu_target_config.device_description.gpu_compute_capability();
-
   // Build simplification pipeline. The passes in here are run to a fixed
   // point.
   [&, &pipeline =

third_party/xla/xla/service/gpu/ir_emitter_unnested.cc

Lines changed: 41 additions & 12 deletions
@@ -110,6 +110,7 @@ limitations under the License.
 #include "xla/backends/gpu/runtime/ragged_all_to_all_thunk.h"
 #include "xla/backends/gpu/runtime/recv_thunk.h"
 #include "xla/backends/gpu/runtime/replica_id_thunk.h"
+#include "xla/backends/gpu/runtime/select_k_thunk.h"
 #include "xla/backends/gpu/runtime/send_thunk.h"
 #include "xla/backends/gpu/runtime/sequential_thunk.h"
 #include "xla/backends/gpu/runtime/thunk.h"
@@ -1428,25 +1429,53 @@ absl::Status IrEmitterUnnested::EmitTopKCustomCall(
           : std::tuple<size_t, size_t, size_t>{
                 1, data_shape.dimensions(0), top_elements_shape.dimensions(0)};
 
-  auto wavefront_size =
-      ir_emitter_context_->gpu_device_info().threads_per_warp();
-
-  // Load TopK custom kernel.
-  TF_ASSIGN_OR_RETURN(
-      CustomKernel kernel,
-      kernel::topk::GetTopKKernel("topk", data_shape.element_type(), n, k,
-                                  batch_size, platform_name(), wavefront_size));
-
   // Prepare kernel arguments.
   TF_ASSIGN_OR_RETURN(auto kernel_arguments,
                       emitters::KernelArguments::Create(
                           ir_emitter_context_->buffer_assignment(),
                           GetDefaultBufferAlignment(), instr));
 
-  auto thunk = std::make_unique<CustomKernelThunk>(instr, std::move(kernel),
-                                                   kernel_arguments);
-  AddThunkToThunkSequence(std::move(thunk));
+  auto dtype = data_shape.element_type();
+  bool is_cuda = std::holds_alternative<stream_executor::CudaComputeCapability>(
+      ir_emitter_context_->gpu_compute_capability());
+  if (is_cuda && instr->GetModule()
+                     ->config()
+                     .debug_options()
+                     .xla_gpu_experimental_use_raft_select_k()) {
+    // The heuristic for deciding when to use TopK Custom Kernel versus
+    // Raft::matrix::select_k was developed as part of the initial research
+    // in b/409009349.
+    // CustomCall TopK requires k <= 16 and n >= 1024
+    bool use_raft_select_k = false;
+    if (dtype == PrimitiveType::F32) {
+      use_raft_select_k =
+          (n < 1024) || (n == 1024 && k > 12) || (n > 1024 && k >= 8);
+    } else if (dtype == PrimitiveType::BF16) {
+      use_raft_select_k = n < 1024 || k >= 8;
+    }
+
+    VLOG(3) << "EmitTopKCustomCall: dtype=" << dtype << ", n=" << n
+            << ", k=" << k << ", use_raft_select_k=" << use_raft_select_k;
+
+    if (use_raft_select_k) {
+      AddThunkToThunkSequence(std::make_unique<SelectKThunk>(
+          instr, batch_size, n, k, dtype, kernel_arguments));
+      return absl::OkStatus();
+    }
+  }
+
+  auto wavefront_size =
+      ir_emitter_context_->gpu_device_info().threads_per_warp();
+
+  TF_RET_CHECK(k <= 16) << "CustomCall TopK requires k <= 16";
+  // Load TopK custom kernel.
+  TF_ASSIGN_OR_RETURN(
+      CustomKernel kernel,
+      kernel::topk::GetTopKKernel("topk", dtype, n, k, batch_size,
+                                  platform_name(), wavefront_size));
 
+  AddThunkToThunkSequence(std::make_unique<CustomKernelThunk>(
+      instr, std::move(kernel), kernel_arguments));
   return absl::OkStatus();
 }
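
For readability, the dispatch thresholds above can be restated as free functions. This is a sketch only: the patch keeps the logic inline in EmitTopKCustomCall, the cutoffs below are copied from it, and the patch attributes them to the study in b/409009349; the function names are illustrative.

#include <cstddef>

// True when raft::matrix::select_k is expected to beat the TopK custom kernel
// for an F32 input with n elements per batch row, selecting the top k.
bool UseRaftSelectKForF32(std::size_t n, std::size_t k) {
  // The custom kernel requires k <= 16 and n >= 1024; outside that envelope,
  // or near its edges, the patch routes to raft::matrix::select_k.
  return (n < 1024) || (n == 1024 && k > 12) || (n > 1024 && k >= 8);
}

// Same decision for BF16 inputs.
bool UseRaftSelectKForBF16(std::size_t n, std::size_t k) {
  return n < 1024 || k >= 8;
}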
