Skip to content

Commit b910f4f

Browse files
committed
Formalize is_function_v w/a, extend invoke_simd test.
Signed-off-by: Konstantin S Bobrovsky <[email protected]>
1 parent 16e6ec7 commit b910f4f

File tree

2 files changed

+72
-10
lines changed

2 files changed

+72
-10
lines changed

sycl/include/sycl/ext/oneapi/experimental/invoke_simd.hpp

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -226,24 +226,42 @@ simd_call_helper(const void *obj_ptr,
226226
return f(simd_args...);
227227
};
228228

229+
#ifdef _GLIBCXX_RELEASE
230+
#if _GLIBCXX_RELEASE < 10
231+
#define __INVOKE_SIMD_USE_STD_IS_FUNCTION_WA
232+
#endif // _GLIBCXX_RELEASE < 10
233+
#endif // _GLIBCXX_RELEASE
234+
235+
#ifdef __INVOKE_SIMD_USE_STD_IS_FUNCTION_WA
229236
// TODO This is a workaround for libstdc++ version 9 buggy behavior which
230237
// returns false in the code below. Version 10 works fine. Once required
231238
// minimum libstdc++ version is bumped to 10, this w/a should be removed.
232239
// template <class F> bool foo(F &&f) {
233240
// return std::is_function_v<std::remove_reference_t<F>>;
234241
// }
235242
// where F is a function type with __regcall.
236-
template <class F> struct is_regcall_function : std::false_type {};
243+
template <class F> struct is_regcall_function_ptr_or_ref_v : std::false_type {};
237244

238245
template <class Ret, class... Args>
239-
struct is_regcall_function<Ret(__regcall *)(Args...)> : std::true_type {};
246+
struct is_regcall_function_ptr_or_ref_v<Ret(__regcall &)(Args...)>
247+
: std::true_type {};
240248

241249
template <class Ret, class... Args>
242-
struct is_regcall_function<Ret(__regcall &)(Args...)> : std::true_type {};
250+
struct is_regcall_function_ptr_or_ref_v<Ret(__regcall *)(Args...)>
251+
: std::true_type {};
243252

244253
template <class F>
245-
static constexpr bool is_regcall_function_v = is_regcall_function<F>::value;
246-
254+
static constexpr bool is_regcall_function_ptr_or_ref_v =
255+
is_regcall_function_ptr_or_ref_v<F>::value;
256+
#endif // __INVOKE_SIMD_USE_STD_IS_FUNCTION_WA
257+
258+
template <class Callable>
259+
static constexpr bool is_function_ptr_or_ref_v =
260+
std::is_function_v<std::remove_pointer_t<std::remove_reference_t<Callable>>>
261+
#ifdef __INVOKE_SIMD_USE_STD_IS_FUNCTION_WA
262+
|| is_regcall_function_ptr_or_ref_v<Callable>
263+
#endif // __INVOKE_SIMD_USE_STD_IS_FUNCTION_WA
264+
;
247265
} // namespace detail
248266

249267
// --- The main API
@@ -274,11 +292,7 @@ __attribute__((always_inline)) auto invoke_simd(sycl::sub_group sg,
274292
// is fine in this case.
275293
constexpr int N = detail::get_sg_size<Callable, T...>();
276294
using RetSpmd = detail::SpmdRetType<N, Callable, T...>;
277-
278-
using CallableNoRef = std::remove_reference_t<Callable>;
279-
using CallableNoRefNoPtr = std::remove_pointer_t<CallableNoRef>;
280-
constexpr bool is_function = std::is_function_v<CallableNoRefNoPtr> ||
281-
detail::is_regcall_function_v<Callable>;
295+
constexpr bool is_function = detail::is_function_ptr_or_ref_v<Callable>;
282296

283297
if constexpr (is_function) {
284298
return __builtin_invoke_simd<is_function, RetSpmd>(

sycl/test/invoke_simd/invoke_simd.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,3 +260,51 @@ SYCL_EXTERNAL auto barx(sub_group sg, float a, char ch,
260260
auto x = invoke_simd(sg, f, 1.f, uniform{a});
261261
static_assert(std::is_same_v<decltype(x), uniform<char>>);
262262
}
263+
264+
// Internal is_function_ref_v meta-API checks {
265+
template <class F> void assert_is_func(F &&f) {
266+
static_assert(
267+
sycl::ext::oneapi::experimental::detail::is_function_ptr_or_ref_v<F>);
268+
}
269+
270+
template <class F> void assert_is_not_func(F &&f) {
271+
static_assert(
272+
!sycl::ext::oneapi::experimental::detail::is_function_ptr_or_ref_v<F>);
273+
}
274+
275+
void ordinary_func();
276+
277+
// clang-format off
278+
void check_f(
279+
int(*func_ptr)(float*), int(__regcall* func_ptr_regcall)(float*),
280+
int(&func_ref)(float*), int(__regcall& func_ref_regcall)(float*),
281+
int(func)(float*), int(__regcall func_regcall)(float*)) {
282+
283+
assert_is_func(SIMD_CALLEE);
284+
assert_is_func(ordinary_func);
285+
286+
assert_is_func(func_ptr);
287+
assert_is_func(func_ptr_regcall);
288+
289+
assert_is_func(func_ref);
290+
assert_is_func(func_ref_regcall);
291+
292+
assert_is_func(func);
293+
assert_is_func(func_regcall);
294+
}
295+
// clang-format on
296+
297+
void check_not_f(char ch) {
298+
assert_is_not_func(SIMD_FUNCTOR{10});
299+
const auto capt_lambda = [=] [[gnu::regcall]] (simd<float, 16>, float) {
300+
// capturing lambda
301+
return ch;
302+
};
303+
const auto non_capt_lambda = [](simd<float, 16>, float) {
304+
// non-capturing lambda
305+
return 10;
306+
};
307+
assert_is_not_func(capt_lambda);
308+
assert_is_not_func(non_capt_lambda);
309+
}
310+
// }

0 commit comments

Comments
 (0)