7
7
#include < ATen/cuda/ThrustAllocator.h>
8
8
#include < c10/core/DeviceArray.h>
9
9
10
+ #include < thrust/count.h>
10
11
#include < thrust/device_ptr.h>
11
12
#include < thrust/device_vector.h>
12
13
#include < thrust/execution_policy.h>
13
14
#include < thrust/extrema.h>
15
+ #include < thrust/find.h>
14
16
#include < thrust/inner_product.h>
15
17
#include < thrust/iterator/constant_iterator.h>
16
18
#include < thrust/sequence.h>
19
21
namespace at {
20
22
namespace native {
21
23
24
+ template <typename scalar_t >
25
+ struct ModeImpl {
26
+ std::tuple<scalar_t , int64_t > operator ()(
27
+ scalar_t *iter_begin,
28
+ scalar_t *iter_end) {
29
+ at::cuda::ThrustAllocator thrust_allocator;
30
+ auto stream = at::cuda::getCurrentCUDAStream ();
31
+ auto policy = thrust::cuda::par (thrust_allocator).on (stream);
32
+
33
+ const auto n_element = iter_end - iter_begin;
34
+ auto cuda_allocator = at::cuda::getCUDADeviceAllocator ();
35
+ auto sort_buffer = c10::DeviceArray<int64_t >(*cuda_allocator, n_element);
36
+ auto sort_buffer_ptr = thrust::device_pointer_cast (sort_buffer.get ());
37
+ auto count_from_zero_iter = thrust::make_counting_iterator (int64_t {0 });
38
+ thrust::copy_n (policy, count_from_zero_iter, n_element, sort_buffer_ptr);
39
+
40
+
41
+ // Sort the input data. The original indices of the data are stored in
42
+ // sort_buffer_ptr
43
+ thrust::sort_by_key (policy, iter_begin, iter_end, sort_buffer_ptr);
44
+
45
+ // Count # of unique elements via an inner product between adjacent elements.
46
+ // Add 1 if two neighboring element are not equal.
47
+ int unique = 1 +
48
+ thrust::inner_product (
49
+ policy,
50
+ iter_begin,
51
+ iter_end - 1 ,
52
+ iter_begin + 1 ,
53
+ 0 ,
54
+ thrust::plus<int >(),
55
+ thrust::not_equal_to<scalar_t >());
56
+
57
+ // Count frequency of each element
58
+ auto keys = c10::DeviceArray<scalar_t >(*cuda_allocator, unique);
59
+ auto counts = c10::DeviceArray<int64_t >(*cuda_allocator, unique);
60
+
61
+ auto keys_ptr = thrust::device_pointer_cast (keys.get ());
62
+ auto counts_ptr = thrust::device_pointer_cast (counts.get ());
63
+
64
+ thrust::reduce_by_key (
65
+ policy,
66
+ iter_begin,
67
+ iter_end,
68
+ thrust::constant_iterator<int >(1 ),
69
+ keys_ptr,
70
+ counts_ptr);
71
+
72
+ // Find index of maximum count
73
+ auto it = thrust::max_element (policy, counts_ptr, counts_ptr + unique);
74
+ scalar_t mode = keys_ptr[it - counts_ptr];
75
+
76
+ // Find first index within which it occurs
77
+ auto position_iter = thrust::find (policy, iter_begin, iter_end, mode);
78
+
79
+ // Translate to original non-sorted index
80
+ TORCH_INTERNAL_ASSERT (position_iter != iter_end);
81
+ int64_t index = sort_buffer_ptr[position_iter - iter_begin];
82
+ return {mode, index };
83
+ }
84
+ };
85
+
86
+ struct EqualsMode {
87
+ bool mode;
88
+
89
+ C10_DEVICE bool operator ()(const uint8_t x) {
90
+ return static_cast <bool >(x) == mode;
91
+ }
92
+ };
93
+
94
+ template <>
95
+ struct ModeImpl <bool > {
96
+ std::tuple<bool , int64_t > operator ()(
97
+ const bool *first,
98
+ const bool *last) {
99
+ at::cuda::ThrustAllocator thrust_allocator;
100
+ auto stream = at::cuda::getCurrentCUDAStream ();
101
+ auto policy = thrust::cuda::par (thrust_allocator).on (stream);
102
+
103
+ // For bool, we can skip finding the unique elements since there
104
+ // are only two possible values.
105
+
106
+ // See NOTE [Loading boolean values]
107
+ auto first_bytes = reinterpret_cast <const uint8_t *>(first);
108
+ auto last_bytes = reinterpret_cast <const uint8_t *>(last);
109
+
110
+ const auto numel = last - first;
111
+ const auto num_true = thrust::count_if (
112
+ policy,
113
+ first_bytes,
114
+ last_bytes,
115
+ [] GPU_LAMBDA (uint8_t x) {
116
+ return static_cast <bool >(x);
117
+ }
118
+ );
119
+ const auto num_false = (numel - num_true);
120
+ const auto mode = num_true > num_false;
121
+
122
+ // Find first index within which it occurs
123
+ const auto position_iter = thrust::find_if (
124
+ policy, first_bytes, last_bytes, EqualsMode{mode});
125
+ const int64_t index = position_iter - first_bytes;
126
+ return {mode, index };
127
+ }
128
+ };
129
+
22
130
template <typename scalar_t >
23
131
void calculate_mode (
24
132
const TensorBase& values,
25
133
const TensorBase& indices,
26
134
const TensorBase& self,
27
135
std::vector<int64_t >& position,
28
136
int dim) {
29
- at::cuda::ThrustAllocator thrust_allocator;
30
- auto stream = at::cuda::getCurrentCUDAStream ();
31
- auto policy = thrust::cuda::par (thrust_allocator).on (stream);
32
137
33
138
TORCH_INTERNAL_ASSERT (self.is_contiguous ());
34
139
@@ -47,53 +152,9 @@ void calculate_mode(
47
152
scalar_t * iter_begin = data;
48
153
scalar_t * iter_end = data + n_element;
49
154
50
- auto cuda_allocator = at::cuda::getCUDADeviceAllocator ();
51
- auto sort_buffer = c10::DeviceArray<int64_t >(*cuda_allocator, n_element);
52
- auto sort_buffer_ptr = thrust::device_pointer_cast (sort_buffer.get ());
53
- auto count_from_zero_iter = thrust::make_counting_iterator (int64_t {0 });
54
- thrust::copy_n (policy, count_from_zero_iter, n_element, sort_buffer_ptr);
55
-
56
-
57
- // Sort the input data. The original indices of the data are stored in
58
- // sort_buffer_ptr
59
- thrust::sort_by_key (policy, iter_begin, iter_end, sort_buffer_ptr);
60
-
61
- // Count # of unique elements via an inner product between adjacent elements.
62
- // Add 1 if two neighboring element are not equal.
63
- int unique = 1 +
64
- thrust::inner_product (
65
- policy,
66
- iter_begin,
67
- iter_end - 1 ,
68
- iter_begin + 1 ,
69
- 0 ,
70
- thrust::plus<int >(),
71
- thrust::not_equal_to<scalar_t >());
72
-
73
- // Count frequency of each element
74
- auto keys = c10::DeviceArray<scalar_t >(*cuda_allocator, unique);
75
- auto counts = c10::DeviceArray<int64_t >(*cuda_allocator, unique);
76
-
77
- auto keys_ptr = thrust::device_pointer_cast (keys.get ());
78
- auto counts_ptr = thrust::device_pointer_cast (counts.get ());
79
-
80
- thrust::reduce_by_key (
81
- policy,
82
- iter_begin,
83
- iter_end,
84
- thrust::constant_iterator<int >(1 ),
85
- keys_ptr,
86
- counts_ptr);
87
-
88
- // Find index of maximum count
89
- auto it = thrust::max_element (policy, counts_ptr, counts_ptr + unique);
90
- scalar_t mode = keys_ptr[it - counts_ptr];
91
-
92
- // Find first index within which it occurs
93
- auto position_iter = thrust::find (policy, iter_begin, iter_end, mode);
94
-
95
- TORCH_INTERNAL_ASSERT (position_iter != iter_end);
96
- int64_t index = sort_buffer_ptr[position_iter - iter_begin];
155
+ scalar_t mode;
156
+ int64_t index ;
157
+ std::tie (mode, index ) = ModeImpl<scalar_t >{}(iter_begin, iter_end);
97
158
98
159
// Place mode, index in output
99
160
scalar_t * values_data = values.data_ptr <scalar_t >();
@@ -105,6 +166,7 @@ void calculate_mode(
105
166
indices_data += ensure_nonempty_stride (indices, i) * pos;
106
167
}
107
168
169
+ auto stream = at::cuda::getCurrentCUDAStream ();
108
170
AT_CUDA_CHECK (cudaMemcpyAsync (
109
171
values_data, &mode, sizeof (scalar_t ), cudaMemcpyHostToDevice, stream));
110
172
// memcpy_and_sync will synchronize results
0 commit comments