Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 2 additions & 41 deletions sycl/include/sycl/reduction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1932,45 +1932,6 @@ void reduCGFuncMulti(handler &CGH, KernelType KernelFunc,
Rest(createReduOutAccs<false>(NWorkGroups, CGH, ReduTuple, ReduIndices));
}

// Declaration-only tag template; it is never defined here. It is passed to
// __sycl_reduction_kernel together with the user's kernel name to form the
// name of the main kernel of the atomic64 nd_range reduction.
namespace reduction::main_krn {
template <class KernelName> struct NDRangeAtomic64;
} // namespace reduction::main_krn

// Specialization for devices with the atomic64 aspect, which guarantees 64 bit
// floating point support for atomic reduction operation.
//
// Submits a single nd_range kernel through CGH that:
//   1. runs the user's KernelFunc with a freshly constructed reducer,
//   2. reduces every reducer element across the work-group with
//      reduce_over_group,
//   3. lets the work-item whose local linear id is 0 fold the work-group
//      result into the output via Reducer.atomic_combine.
//
// CGH        - handler the kernel is submitted to.
// KernelFunc - user's kernel, invoked as KernelFunc(nd_item<Dims>, Reducer).
// Range      - nd_range the kernel is launched over.
// Properties - kernel properties forwarded to parallel_for unchanged.
// Redu       - reduction object; supplies the output accessor through
//              getReadWriteAccessorToInitializedMem.
//
// Only instantiable when Reduction::has_float64_atomics is true.
template <typename KernelName, typename KernelType, int Dims,
          typename PropertiesT, class Reduction>
void reduCGFuncAtomic64(handler &CGH, KernelType KernelFunc,
                        const nd_range<Dims> &Range, PropertiesT Properties,
                        Reduction &Redu) {
  auto Out = Redu.getReadWriteAccessorToInitializedMem(CGH);
  static_assert(
      Reduction::has_float64_atomics,
      "Only suitable for reductions that have FP64 atomic operations.");
  size_t NElements = Reduction::num_elements;
  using Name =
      __sycl_reduction_kernel<reduction::main_krn::NDRangeAtomic64, KernelName>;
  CGH.parallel_for<Name>(Range, Properties, [=](nd_item<Dims> NDIt) {
    // Call user's function. Reducer.MValue gets initialized there.
    typename Reduction::reducer_type Reducer;
    KernelFunc(NDIt, Reducer);

    // If there are multiple values, reduce each separately
    // reduce_over_group is only defined for each T, not for span<T, ...>
    // The index is size_t so that it matches the type of NElements and the
    // loop condition does not mix signed and unsigned operands.
    for (size_t E = 0; E < NElements; ++E) {
      typename Reduction::binary_operation BOp;
      Reducer.getElement(E) =
          reduce_over_group(NDIt.get_group(), Reducer.getElement(E), BOp);
    }

    // Exactly one work-item per work-group publishes the group's result.
    if (NDIt.get_local_linear_id() == 0) {
      Reducer.atomic_combine(&Out[0]);
    }
  });
}

template <typename... Reductions, size_t... Is>
void associateReduAccsWithHandler(handler &CGH,
std::tuple<Reductions...> &ReduTuple,
Expand Down Expand Up @@ -2386,8 +2347,8 @@ void reduction_parallel_for(handler &CGH,
device D = detail::getDeviceFromHandler(CGH);

if (D.has(aspect::atomic64)) {
reduCGFuncAtomic64<KernelName>(CGH, KernelFunc, Range, Properties,
Redu);
reduCGFuncForNDRangeBothFastReduceAndAtomics<KernelName>(
CGH, KernelFunc, Range, Properties, Redu);
} else {
// Resort to basic implementation as well.
reduction_parallel_for_basic_impl<KernelName>(
Expand Down