Commit 02550bc

peterbell10 authored and pytorchmergebot committed
Support non-standard bools in CUDA mode (pytorch#79393)
Closes pytorch#54789

For the `fused_mode` kernel, this just uses `c10::load`, but the `apply_mode` function is a bit harder because it uses `thrust`. Instead, I've added a second dedicated path for bool which also uses only 2 thrust calls instead of the normal 6, by exploiting the fact that bools have only two possible values.

In the following `timeit` benchmark, which calls the `apply_mode` version, I see execution time drop from 16.9 ms to 2.2 ms (which is still terrible, but my main goal is fixing the bool handling).

```python
import torch
a = torch.randint(0, 2, size=(100, 4096), device='cuda', dtype=torch.bool)
%timeit a.mode(1)
```

Pull Request resolved: pytorch#79393
Approved by: https://github.com/ngimel
1 parent 5880a66 commit 02550bc
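For context, a "non-standard" bool here is a `torch.bool` tensor whose backing bytes are not exactly 0 or 1. The sketch below shows one way such a tensor can arise and the call this commit fixes; the `view(torch.bool)` reinterpretation is illustrative, not code from this PR:

```python
import torch

# Any nonzero byte should read as True, but before this fix the CUDA mode
# kernels compared raw bytes, so e.g. 2 and 1 looked like different values.
raw = torch.randint(0, 255, (100, 4096), dtype=torch.uint8, device='cuda')
nonstandard = raw.view(torch.bool)  # same storage, reinterpreted as bool

values, indices = nonstandard.mode(1)  # now well-defined on CUDA
```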

File tree

3 files changed: +116 −56 lines

aten/src/ATen/native/cuda/TensorModeKernel.cu

Lines changed: 112 additions & 50 deletions
```diff
@@ -7,10 +7,12 @@
 #include <ATen/cuda/ThrustAllocator.h>
 #include <c10/core/DeviceArray.h>
 
+#include <thrust/count.h>
 #include <thrust/device_ptr.h>
 #include <thrust/device_vector.h>
 #include <thrust/execution_policy.h>
 #include <thrust/extrema.h>
+#include <thrust/find.h>
 #include <thrust/inner_product.h>
 #include <thrust/iterator/constant_iterator.h>
 #include <thrust/sequence.h>
@@ -19,16 +21,119 @@
 namespace at {
 namespace native {
 
+template <typename scalar_t>
+struct ModeImpl {
+  std::tuple<scalar_t, int64_t> operator()(
+      scalar_t *iter_begin,
+      scalar_t *iter_end) {
+    at::cuda::ThrustAllocator thrust_allocator;
+    auto stream = at::cuda::getCurrentCUDAStream();
+    auto policy = thrust::cuda::par(thrust_allocator).on(stream);
+
+    const auto n_element = iter_end - iter_begin;
+    auto cuda_allocator = at::cuda::getCUDADeviceAllocator();
+    auto sort_buffer = c10::DeviceArray<int64_t>(*cuda_allocator, n_element);
+    auto sort_buffer_ptr = thrust::device_pointer_cast(sort_buffer.get());
+    auto count_from_zero_iter = thrust::make_counting_iterator(int64_t{0});
+    thrust::copy_n(policy, count_from_zero_iter, n_element, sort_buffer_ptr);
+
+
+    // Sort the input data. The original indices of the data are stored in
+    // sort_buffer_ptr
+    thrust::sort_by_key(policy, iter_begin, iter_end, sort_buffer_ptr);
+
+    // Count # of unique elements via an inner product between adjacent elements.
+    // Add 1 if two neighboring element are not equal.
+    int unique = 1 +
+        thrust::inner_product(
+            policy,
+            iter_begin,
+            iter_end - 1,
+            iter_begin + 1,
+            0,
+            thrust::plus<int>(),
+            thrust::not_equal_to<scalar_t>());
+
+    // Count frequency of each element
+    auto keys = c10::DeviceArray<scalar_t>(*cuda_allocator, unique);
+    auto counts = c10::DeviceArray<int64_t>(*cuda_allocator, unique);
+
+    auto keys_ptr = thrust::device_pointer_cast(keys.get());
+    auto counts_ptr = thrust::device_pointer_cast(counts.get());
+
+    thrust::reduce_by_key(
+        policy,
+        iter_begin,
+        iter_end,
+        thrust::constant_iterator<int>(1),
+        keys_ptr,
+        counts_ptr);
+
+    // Find index of maximum count
+    auto it = thrust::max_element(policy, counts_ptr, counts_ptr + unique);
+    scalar_t mode = keys_ptr[it - counts_ptr];
+
+    // Find first index within which it occurs
+    auto position_iter = thrust::find(policy, iter_begin, iter_end, mode);
+
+    // Translate to original non-sorted index
+    TORCH_INTERNAL_ASSERT(position_iter != iter_end);
+    int64_t index = sort_buffer_ptr[position_iter - iter_begin];
+    return {mode, index};
+  }
+};
+
+struct EqualsMode {
+  bool mode;
+
+  C10_DEVICE bool operator()(const uint8_t x) {
+    return static_cast<bool>(x) == mode;
+  }
+};
+
+template <>
+struct ModeImpl<bool> {
+  std::tuple<bool, int64_t> operator()(
+      const bool *first,
+      const bool *last) {
+    at::cuda::ThrustAllocator thrust_allocator;
+    auto stream = at::cuda::getCurrentCUDAStream();
+    auto policy = thrust::cuda::par(thrust_allocator).on(stream);
+
+    // For bool, we can skip finding the unique elements since there
+    // are only two possible values.
+
+    // See NOTE [Loading boolean values]
+    auto first_bytes = reinterpret_cast<const uint8_t*>(first);
+    auto last_bytes = reinterpret_cast<const uint8_t*>(last);
+
+    const auto numel = last - first;
+    const auto num_true = thrust::count_if(
+        policy,
+        first_bytes,
+        last_bytes,
+        [] GPU_LAMBDA (uint8_t x) {
+          return static_cast<bool>(x);
+        }
+    );
+    const auto num_false = (numel - num_true);
+    const auto mode = num_true > num_false;
+
+    // Find first index within which it occurs
+    const auto position_iter = thrust::find_if(
+        policy, first_bytes, last_bytes, EqualsMode{mode});
+    const int64_t index = position_iter - first_bytes;
+    return {mode, index};
+  }
+};
+
 template <typename scalar_t>
 void calculate_mode(
     const TensorBase& values,
     const TensorBase& indices,
     const TensorBase& self,
     std::vector<int64_t>& position,
     int dim) {
-  at::cuda::ThrustAllocator thrust_allocator;
-  auto stream = at::cuda::getCurrentCUDAStream();
-  auto policy = thrust::cuda::par(thrust_allocator).on(stream);
 
   TORCH_INTERNAL_ASSERT(self.is_contiguous());
 
@@ -47,53 +152,9 @@ void calculate_mode(
   scalar_t* iter_begin = data;
   scalar_t* iter_end = data + n_element;
 
-  auto cuda_allocator = at::cuda::getCUDADeviceAllocator();
-  auto sort_buffer = c10::DeviceArray<int64_t>(*cuda_allocator, n_element);
-  auto sort_buffer_ptr = thrust::device_pointer_cast(sort_buffer.get());
-  auto count_from_zero_iter = thrust::make_counting_iterator(int64_t{0});
-  thrust::copy_n(policy, count_from_zero_iter, n_element, sort_buffer_ptr);
-
-
-  // Sort the input data. The original indices of the data are stored in
-  // sort_buffer_ptr
-  thrust::sort_by_key(policy, iter_begin, iter_end, sort_buffer_ptr);
-
-  // Count # of unique elements via an inner product between adjacent elements.
-  // Add 1 if two neighboring element are not equal.
-  int unique = 1 +
-      thrust::inner_product(
-          policy,
-          iter_begin,
-          iter_end - 1,
-          iter_begin + 1,
-          0,
-          thrust::plus<int>(),
-          thrust::not_equal_to<scalar_t>());
-
-  // Count frequency of each element
-  auto keys = c10::DeviceArray<scalar_t>(*cuda_allocator, unique);
-  auto counts = c10::DeviceArray<int64_t>(*cuda_allocator, unique);
-
-  auto keys_ptr = thrust::device_pointer_cast(keys.get());
-  auto counts_ptr = thrust::device_pointer_cast(counts.get());
-
-  thrust::reduce_by_key(
-      policy,
-      iter_begin,
-      iter_end,
-      thrust::constant_iterator<int>(1),
-      keys_ptr,
-      counts_ptr);
-
-  // Find index of maximum count
-  auto it = thrust::max_element(policy, counts_ptr, counts_ptr + unique);
-  scalar_t mode = keys_ptr[it - counts_ptr];
-
-  // Find first index within which it occurs
-  auto position_iter = thrust::find(policy, iter_begin, iter_end, mode);
-
-  TORCH_INTERNAL_ASSERT(position_iter != iter_end);
-  int64_t index = sort_buffer_ptr[position_iter - iter_begin];
+  scalar_t mode;
+  int64_t index;
+  std::tie(mode, index) = ModeImpl<scalar_t>{}(iter_begin, iter_end);
 
   // Place mode, index in output
   scalar_t* values_data = values.data_ptr<scalar_t>();
@@ -105,6 +166,7 @@ void calculate_mode(
     indices_data += ensure_nonempty_stride(indices, i) * pos;
   }
 
+  auto stream = at::cuda::getCurrentCUDAStream();
   AT_CUDA_CHECK(cudaMemcpyAsync(
       values_data, &mode, sizeof(scalar_t), cudaMemcpyHostToDevice, stream));
   //memcpy_and_sync will synchronize results
```
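In plain terms, the new `ModeImpl<bool>` specialization works because a two-valued type needs no sort: one count decides the mode, and one scan finds its first occurrence. A rough Python analogue of that logic (illustrative only, not the kernel itself; ties resolve to `False`, matching the kernel's `num_true > num_false` test):

```python
import torch

def bool_mode_1d(a: torch.Tensor):
    """Illustrative analogue of ModeImpl<bool>: count, then find."""
    assert a.dtype == torch.bool
    num_true = int(a.sum())                 # plays the role of thrust::count_if
    mode = num_true > a.numel() - num_true  # majority wins; a tie yields False
    index = int((a == mode).nonzero()[0])   # plays the role of thrust::find_if
    return mode, index
```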

aten/src/ATen/native/cuda/TensorModeKernel.cuh

Lines changed: 4 additions & 4 deletions
```diff
@@ -232,10 +232,10 @@ __global__ void compute_mode(
 
   // Each thread loads up to two elements from the Tensor into shared memory
   if (tidx < sliceSize) {
-    smem[tidx] = input[linearOffset + tidx];
+    smem[tidx] = c10::load(&input[linearOffset + tidx]);
  }
   if (stidx < sliceSize) {
-    smem[stidx] = input[linearOffset + stidx];
+    smem[stidx] = c10::load(&input[linearOffset + stidx]);
  }
 
   // Next, we initialize a boolean region of the buffer, offset by the loaded
@@ -396,11 +396,11 @@ __global__ void compute_mode(
   unsigned mode_index[2] = {0u, 0u};
   if (tidx * 2 < sliceSize) {
     const unsigned idx = tidx * 2;
-    mode_index[0] = input[linearOffset + idx] == mode ? idx : 0u;
+    mode_index[0] = c10::load(&input[linearOffset + idx]) == mode ? idx : 0u;
   }
   if (tidx * 2 + 1 < sliceSize) {
     const unsigned idx = tidx * 2 + 1;
-    mode_index[1] = input[linearOffset + idx] == mode ? idx : 0u;
+    mode_index[1] = c10::load(&input[linearOffset + idx]) == mode ? idx : 0u;
   }
 
   struct MaxIndexOp {
```
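The `c10::load` calls above normalize each stored byte to a valid `bool` before it is compared or copied, instead of reading it raw. A pure-Python illustration of the failure mode (hypothetical byte values, not taken from the PR):

```python
import struct

stored = struct.pack('B', 2)     # a "true" bool whose byte happens to be 2
canonical = struct.pack('B', 1)  # the canonical true byte

print(stored == canonical)  # False: raw-byte comparison misclassifies the value
print(stored[0] != 0)       # True: normalizing to a bool first gives the answer
```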

torch/testing/_internal/common_methods_invocations.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -15317,8 +15317,6 @@ def error_inputs_mean(op_info, device, **kwargs):
           skips=(
               # Resized a non-empty tensor but did not warn about it
               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
-              DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_non_standard_bool_values',
-                           device_type='cuda'),
           ),
           sample_inputs_func=sample_inputs_mode,),
    MvlGammaInfo(variant_test_name='mvlgamma_p_1',
```
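Dropping this skip re-enables `test_non_standard_bool_values` for `mode` on CUDA. The property that test exercises is roughly the following (a sketch with an assumed construction; the real test's setup differs):

```python
import torch

raw = torch.randint(0, 255, (100, 4096), dtype=torch.uint8, device='cuda')
nonstandard = raw.view(torch.bool)  # bytes 0..254 reinterpreted as bool
standard = raw != 0                 # same truth values, canonical 0/1 bytes

# mode must not depend on the underlying byte representation
v1, i1 = nonstandard.mode(1)
v2, i2 = standard.mode(1)
assert torch.equal(v1, v2)
```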
