Skip to content

Commit 812619a

Browse files
committed
Add checks for KERNEL_FLOAT_FP16_OPS_AVAILABLE in fp16.h
1 parent 023bc75 commit 812619a

File tree

4 files changed

+44
-33
lines changed

4 files changed

+44
-33
lines changed

include/kernel_float/approx.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ KERNEL_FLOAT_DEFINE_POLY(asin_poly, 3, 0.05167, -0.2057, 1.57)
8686
KERNEL_FLOAT_DEFINE_POLY(asin_poly, 4, -0.02103, 0.077, -0.2129, 1.57)
8787
KERNEL_FLOAT_DEFINE_POLY(asin_poly, 5, 0.009796, -0.03772, 0.0857, -0.2142, 1.57)
8888

89-
#if KERNEL_FLOAT_FP16_AVAILABLE
89+
#if KERNEL_FLOAT_FP16_OPS_AVAILABLE
9090
KERNEL_FLOAT_DEVICE half2_t flipsign(half2_t input, half2_t sign) {
9191
// Flip signbit of input when sign<0
9292
uint32_t result;
@@ -281,9 +281,9 @@ KERNEL_FLOAT_DEVICE half2_t tanh(half2_t x) {
281281
}
282282
}
283283

284-
#endif // KERNEL_FLOAT_FP16_AVAILABLE
284+
#endif // KERNEL_FLOAT_FP16_OPS_AVAILABLE
285285

286-
#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
286+
#if KERNEL_FLOAT_BF16_OPS_AVAILABLE
287287
KERNEL_FLOAT_DEVICE bfloat16x2_t make_bfloat162(bfloat16_t x) {
288288
return {x, x};
289289
}
@@ -363,7 +363,7 @@ KERNEL_FLOAT_DEVICE bfloat16x2_t exp(bfloat16x2_t arg) {
363363
transmute<bfloat16_t>(uint16_t(transmute<uint32_t>(a))),
364364
transmute<bfloat16_t>(uint16_t(transmute<uint32_t>(b)))};
365365
}
366-
#endif
366+
#endif // KERNEL_FLOAT_BF16_OPS_AVAILABLE
367367
} // namespace approx
368368

369369
namespace detail {
@@ -394,7 +394,7 @@ struct apply_impl<approx_level_policy<Level>, F, 1, T, T> {
394394
apply_impl<approx_level_policy<DEFAULT_LEVEL>, ops::FUN<T>, 2, T, T> {}; \
395395
}
396396

397-
#if KERNEL_FLOAT_FP16_AVAILABLE
397+
#if KERNEL_FLOAT_FP16_OPS_AVAILABLE
398398
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, sin, 4)
399399
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, cos, 4)
400400
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, rsqrt, 1)
@@ -406,7 +406,7 @@ KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, asin, 2)
406406
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, acos, 2)
407407
#endif
408408

409-
#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
409+
#if KERNEL_FLOAT_BF16_OPS_AVAILABLE
410410
KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, cos, 4)
411411
KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, sin, 4)
412412
KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, rcp, 1)

include/kernel_float/bf16.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ using bfloat16_t = __hip_bfloat16;
2929
using bfloat16x2_t = __hip_bfloat162;
3030
#endif
3131

32-
3332
template<>
3433
struct preferred_vector_size<bfloat16_t> {
3534
static constexpr size_t value = 2;
@@ -294,6 +293,6 @@ struct promote_type<half_t, bfloat16_t> {
294293
} // namespace kernel_float
295294

296295
#endif // KERNEL_FLOAT_FP16_AVAILABLE
297-
#endif // KERNEL_FLOAT_BF16_AVAILABLE
296+
#endif // KERNEL_FLOAT_BF16_AVAILABLE
298297

299298
#endif //KERNEL_FLOAT_BF16_H

include/kernel_float/fp16.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ KERNEL_FLOAT_FP16_UNARY_FUN(ceil, hceil, h2ceil)
9393
KERNEL_FLOAT_FP16_UNARY_FUN(rint, hrint, h2rint)
9494
KERNEL_FLOAT_FP16_UNARY_FUN(trunc, htrunc, h2trunc)
9595
KERNEL_FLOAT_FP16_UNARY_FUN(negate, __hneg, __hneg2)
96-
#endif // KERNEL_FLOAT_FP16_OPS_AVAILABLE
96+
#endif // KERNEL_FLOAT_FP16_OPS_AVAILABLE
9797

9898
#if KERNEL_FLOAT_IS_DEVICE
9999
#define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2) \
@@ -124,7 +124,7 @@ KERNEL_FLOAT_FP16_UNARY_FUN(negate, __hneg, __hneg2)
124124
#if KERNEL_FLOAT_IS_CUDA
125125
KERNEL_FLOAT_FP16_BINARY_FUN(min, __hmin, __hmin2)
126126
KERNEL_FLOAT_FP16_BINARY_FUN(max, __hmax, __hmax2)
127-
#endif // KERNEL_FLOAT_IS_CUDA
127+
#endif // KERNEL_FLOAT_IS_CUDA
128128

129129
KERNEL_FLOAT_FP16_BINARY_FUN(add, __hadd, __hadd2)
130130
KERNEL_FLOAT_FP16_BINARY_FUN(subtract, __hsub, __hsub2)
@@ -137,8 +137,9 @@ KERNEL_FLOAT_FP16_BINARY_FUN(less, __hlt, __hlt2)
137137
KERNEL_FLOAT_FP16_BINARY_FUN(less_equal, __hle, __hle2)
138138
KERNEL_FLOAT_FP16_BINARY_FUN(greater, __hgt, __hgt2)
139139
KERNEL_FLOAT_FP16_BINARY_FUN(greater_equal, __hge, __hge2)  // packed form must be >= (__hge2); __hgt2 is strict > and disagrees with scalar __hge on equal inputs
140-
#endif // KERNEL_FLOAT_FP16_OPS_AVAILABLE
140+
#endif // KERNEL_FLOAT_FP16_OPS_AVAILABLE
141141

142+
#if KERNEL_FLOAT_FP16_OPS_AVAILABLE
142143
#if KERNEL_FLOAT_IS_DEVICE
143144
namespace ops {
144145
template<>
@@ -175,7 +176,8 @@ struct apply_impl<accurate_policy, ops::fma<half_t>, 2, half_t, half_t, half_t,
175176

176177
KERNEL_FLOAT_FAST_F32_MAP(KERNEL_FLOAT_FAST_FP16_DISPATCH)
177178
} // namespace detail
178-
#endif
179+
#endif // KERNEL_FLOAT_IS_DEVICE
180+
#endif //KERNEL_FLOAT_FP16_OPS_AVAILABLE
179181

180182
#define KERNEL_FLOAT_FP16_CAST(T, TO_HALF, FROM_HALF) \
181183
namespace ops { \
@@ -240,6 +242,6 @@ KERNEL_FLOAT_VECTOR_ALIAS(half, half_t)
240242

241243
} // namespace kernel_float
242244

243-
#endif // KERNEL_FLOAT_FP16_AVAILABLE
245+
#endif // KERNEL_FLOAT_FP16_AVAILABLE
244246

245247
#endif //KERNEL_FLOAT_FP16_H

single_include/kernel_float.h

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717
//================================================================================
1818
// this file has been auto-generated, do not modify its contents!
19-
// date: 2025-08-21 10:13:04.148230
20-
// git hash: 4d0d49cad7962d3f9ba4f2a0abfa2faea3ec7efa
19+
// date: 2025-09-02 18:31:16.281730
20+
// git hash: 023bc75e8ec67145cdcb447c5fd9aa7d7f180cc6
2121
//================================================================================
2222

2323
#ifndef KERNEL_FLOAT_MACROS_H
@@ -59,10 +59,18 @@
5959
#define KERNEL_FLOAT_FP16_AVAILABLE (1)
6060
#endif // KERNEL_FLOAT_FP16_AVAILABLE
6161

62+
#ifndef KERNEL_FLOAT_FP16_OPS_AVAILABLE
63+
#define KERNEL_FLOAT_FP16_OPS_AVAILABLE ((KERNEL_FLOAT_IS_CUDA && __CUDA_ARCH__ >= 530) || KERNEL_FLOAT_IS_HIP)
64+
#endif
65+
6266
#ifndef KERNEL_FLOAT_BF16_AVAILABLE
6367
#define KERNEL_FLOAT_BF16_AVAILABLE (1)
6468
#endif // KERNEL_FLOAT_BF16_AVAILABLE
6569

70+
#ifndef KERNEL_FLOAT_BF16_OPS_AVAILABLE
71+
#define KERNEL_FLOAT_BF16_OPS_AVAILABLE ((KERNEL_FLOAT_IS_CUDA && __CUDA_ARCH__ >= 800) || KERNEL_FLOAT_IS_HIP)
72+
#endif
73+
6674
#ifndef KERNEL_FLOAT_FP8_AVAILABLE
6775
#ifdef __CUDACC_VER_MAJOR__
6876
#define KERNEL_FLOAT_FP8_AVAILABLE (__CUDACC_VER_MAJOR__ >= 12)
@@ -4171,6 +4179,7 @@ struct allow_float_fallback<half_t> {
41714179
#define KERNEL_FLOAT_FP16_UNARY_FUN(NAME, FUN1, FUN2)
41724180
#endif
41734181

4182+
#if KERNEL_FLOAT_FP16_OPS_AVAILABLE
41744183
KERNEL_FLOAT_FP16_UNARY_FUN(sin, hsin, h2sin)
41754184
KERNEL_FLOAT_FP16_UNARY_FUN(cos, hcos, h2cos)
41764185

@@ -4191,6 +4200,7 @@ KERNEL_FLOAT_FP16_UNARY_FUN(ceil, hceil, h2ceil)
41914200
KERNEL_FLOAT_FP16_UNARY_FUN(rint, hrint, h2rint)
41924201
KERNEL_FLOAT_FP16_UNARY_FUN(trunc, htrunc, h2trunc)
41934202
KERNEL_FLOAT_FP16_UNARY_FUN(negate, __hneg, __hneg2)
4203+
#endif // KERNEL_FLOAT_FP16_OPS_AVAILABLE
41944204

41954205
#if KERNEL_FLOAT_IS_DEVICE
41964206
#define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2) \
@@ -4217,10 +4227,11 @@ KERNEL_FLOAT_FP16_UNARY_FUN(negate, __hneg, __hneg2)
42174227
#endif
42184228

42194229
// These are not available in HIP
4230+
#if KERNEL_FLOAT_FP16_OPS_AVAILABLE
42204231
#if KERNEL_FLOAT_IS_CUDA
42214232
KERNEL_FLOAT_FP16_BINARY_FUN(min, __hmin, __hmin2)
42224233
KERNEL_FLOAT_FP16_BINARY_FUN(max, __hmax, __hmax2)
4223-
#endif
4234+
#endif // KERNEL_FLOAT_IS_CUDA
42244235

42254236
KERNEL_FLOAT_FP16_BINARY_FUN(add, __hadd, __hadd2)
42264237
KERNEL_FLOAT_FP16_BINARY_FUN(subtract, __hsub, __hsub2)
@@ -4233,7 +4244,9 @@ KERNEL_FLOAT_FP16_BINARY_FUN(less, __hlt, __hlt2)
42334244
KERNEL_FLOAT_FP16_BINARY_FUN(less_equal, __hle, __hle2)
42344245
KERNEL_FLOAT_FP16_BINARY_FUN(greater, __hgt, __hgt2)
42354246
KERNEL_FLOAT_FP16_BINARY_FUN(greater_equal, __hge, __hge2)  // packed form must be >= (__hge2); __hgt2 is strict > and disagrees with scalar __hge on equal inputs
4247+
#endif // KERNEL_FLOAT_FP16_OPS_AVAILABLE
42364248

4249+
#if KERNEL_FLOAT_FP16_OPS_AVAILABLE
42374250
#if KERNEL_FLOAT_IS_DEVICE
42384251
namespace ops {
42394252
template<>
@@ -4270,7 +4283,8 @@ struct apply_impl<accurate_policy, ops::fma<half_t>, 2, half_t, half_t, half_t,
42704283

42714284
KERNEL_FLOAT_FAST_F32_MAP(KERNEL_FLOAT_FAST_FP16_DISPATCH)
42724285
} // namespace detail
4273-
#endif
4286+
#endif // KERNEL_FLOAT_IS_DEVICE
4287+
#endif //KERNEL_FLOAT_FP16_OPS_AVAILABLE
42744288

42754289
#define KERNEL_FLOAT_FP16_CAST(T, TO_HALF, FROM_HALF) \
42764290
namespace ops { \
@@ -4335,7 +4349,7 @@ KERNEL_FLOAT_VECTOR_ALIAS(half, half_t)
43354349

43364350
} // namespace kernel_float
43374351

4338-
#endif
4352+
#endif // KERNEL_FLOAT_FP16_AVAILABLE
43394353

43404354
#endif //KERNEL_FLOAT_FP16_H
43414355
#ifndef KERNEL_FLOAT_BF16_H
@@ -4369,10 +4383,6 @@ using bfloat16_t = __hip_bfloat16;
43694383
using bfloat16x2_t = __hip_bfloat162;
43704384
#endif
43714385

4372-
#if KERNEL_FLOAT_IS_CUDA && __CUDA_ARCH__ >= 800
4373-
#define KERNEL_FLOAT_BF16_OPS_SUPPORTED 1
4374-
#endif
4375-
43764386
template<>
43774387
struct preferred_vector_size<bfloat16_t> {
43784388
static constexpr size_t value = 2;
@@ -4420,7 +4430,7 @@ struct allow_float_fallback<bfloat16_t> {
44204430
}; \
44214431
}
44224432

4423-
#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
4433+
#if KERNEL_FLOAT_BF16_OPS_AVAILABLE
44244434
KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin)
44254435
KERNEL_FLOAT_BF16_UNARY_FUN(cos, ::hcos, ::h2cos)
44264436

@@ -4496,7 +4506,7 @@ KERNEL_FLOAT_BF16_UNARY_FUN(negate, hip_hneg, hip_hneg2)
44964506
}; \
44974507
}
44984508

4499-
#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
4509+
#if KERNEL_FLOAT_BF16_OPS_AVAILABLE
45004510
KERNEL_FLOAT_BF16_BINARY_FUN(add, __hadd, __hadd2)
45014511
KERNEL_FLOAT_BF16_BINARY_FUN(subtract, __hsub, __hsub2)
45024512
KERNEL_FLOAT_BF16_BINARY_FUN(multiply, __hmul, __hmul2)
@@ -4512,7 +4522,7 @@ KERNEL_FLOAT_BF16_BINARY_FUN(greater, __hgt, __hgt2)
45124522
KERNEL_FLOAT_BF16_BINARY_FUN(greater_equal, __hge, __hge2)  // packed form must be >= (__hge2); __hgt2 is strict > and disagrees with scalar __hge on equal inputs
45134523
#endif
45144524

4515-
#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
4525+
#if KERNEL_FLOAT_BF16_OPS_AVAILABLE
45164526
namespace ops {
45174527
template<>
45184528
struct fma<bfloat16_t> {
@@ -4583,7 +4593,7 @@ KERNEL_FLOAT_FAST_F32_MAP(KERNEL_FLOAT_FAST_BF16_DISPATCH)
45834593
KERNEL_FLOAT_BF16_CAST(float, __float2bfloat16(input), __bfloat162float(input))
45844594
KERNEL_FLOAT_BF16_CAST(double, __double2bfloat16(input), __bfloat162float(input))
45854595

4586-
#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
4596+
#if KERNEL_FLOAT_BF16_OPS_AVAILABLE
45874597
// clang-format off
45884598
// there are no official char casts. Instead, cast to int and then to char
45894599
KERNEL_FLOAT_BF16_CAST(char, __int2bfloat16_rn(input), (char)__bfloat162int_rz(input));
@@ -4637,7 +4647,7 @@ struct promote_type<half_t, bfloat16_t> {
46374647
} // namespace kernel_float
46384648

46394649
#endif // KERNEL_FLOAT_FP16_AVAILABLE
4640-
#endif
4650+
#endif // KERNEL_FLOAT_BF16_AVAILABLE
46414651

46424652
#endif //KERNEL_FLOAT_BF16_H
46434653
#pragma once
@@ -4728,7 +4738,7 @@ KERNEL_FLOAT_DEFINE_POLY(asin_poly, 3, 0.05167, -0.2057, 1.57)
47284738
KERNEL_FLOAT_DEFINE_POLY(asin_poly, 4, -0.02103, 0.077, -0.2129, 1.57)
47294739
KERNEL_FLOAT_DEFINE_POLY(asin_poly, 5, 0.009796, -0.03772, 0.0857, -0.2142, 1.57)
47304740

4731-
#if KERNEL_FLOAT_FP16_AVAILABLE
4741+
#if KERNEL_FLOAT_FP16_OPS_AVAILABLE
47324742
KERNEL_FLOAT_DEVICE half2_t flipsign(half2_t input, half2_t sign) {
47334743
// Flip signbit of input when sign<0
47344744
uint32_t result;
@@ -4923,9 +4933,9 @@ KERNEL_FLOAT_DEVICE half2_t tanh(half2_t x) {
49234933
}
49244934
}
49254935

4926-
#endif // KERNEL_FLOAT_FP16_AVAILABLE
4936+
#endif // KERNEL_FLOAT_FP16_OPS_AVAILABLE
49274937

4928-
#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
4938+
#if KERNEL_FLOAT_BF16_OPS_AVAILABLE
49294939
KERNEL_FLOAT_DEVICE bfloat16x2_t make_bfloat162(bfloat16_t x) {
49304940
return {x, x};
49314941
}
@@ -5005,7 +5015,7 @@ KERNEL_FLOAT_DEVICE bfloat16x2_t exp(bfloat16x2_t arg) {
50055015
transmute<bfloat16_t>(uint16_t(transmute<uint32_t>(a))),
50065016
transmute<bfloat16_t>(uint16_t(transmute<uint32_t>(b)))};
50075017
}
5008-
#endif
5018+
#endif // KERNEL_FLOAT_BF16_OPS_AVAILABLE
50095019
} // namespace approx
50105020

50115021
namespace detail {
@@ -5036,7 +5046,7 @@ struct apply_impl<approx_level_policy<Level>, F, 1, T, T> {
50365046
apply_impl<approx_level_policy<DEFAULT_LEVEL>, ops::FUN<T>, 2, T, T> {}; \
50375047
}
50385048

5039-
#if KERNEL_FLOAT_FP16_AVAILABLE
5049+
#if KERNEL_FLOAT_FP16_OPS_AVAILABLE
50405050
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, sin, 4)
50415051
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, cos, 4)
50425052
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, rsqrt, 1)
@@ -5048,7 +5058,7 @@ KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, asin, 2)
50485058
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, acos, 2)
50495059
#endif
50505060

5051-
#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
5061+
#if KERNEL_FLOAT_BF16_OPS_AVAILABLE
50525062
KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, cos, 4)
50535063
KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, sin, 4)
50545064
KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, rcp, 1)

0 commit comments

Comments
 (0)